In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
import torchvision.datasets as datasets
from torchvision.datasets import ImageFolder
import torch.nn.functional as F
from torch.utils.data import DataLoader
from PIL import Image
import matplotlib.pyplot as plt

#### Training the CNN

In [2]:
n_epochs = 5
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.9
log_interval = 10  # print/record the training loss every `log_interval` batches

# Fix torch's global RNG so weight initialization, DataLoader shuffling and
# dropout masks are reproducible across Restart Kernel -> Run All.
random_seed = 1
torch.manual_seed(random_seed)

# Normalize inputs with the global MNIST pixel mean and standard deviation.
mnist_mean = 0.1307
mnist_std = 0.3081
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((mnist_mean,), (mnist_std,))])

In [3]:
# Download (if needed) and load both MNIST splits with the normalizing
# `transform` from the configuration cell.
trainset = datasets.MNIST("/files/", train=True, download=True, transform=transform)

# Reshuffle training batches every epoch so SGD sees a different order.
train_loader = DataLoader(trainset, batch_size=batch_size_train, shuffle=True)

testset = datasets.MNIST("/files/", train=False, download=True, transform=transform)

# test() aggregates over the whole test set, so order does not affect the
# metrics; a fixed order keeps evaluation passes deterministic.
test_loader = DataLoader(testset, batch_size=batch_size_test, shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /files/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting /files/MNIST/raw/train-images-idx3-ubyte.gz to /files/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /files/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting /files/MNIST/raw/train-labels-idx1-ubyte.gz to /files/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /files/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting /files/MNIST/raw/t10k-images-idx3-ubyte.gz to /files/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /files/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting /files/MNIST/raw/t10k-labels-idx1-ubyte.gz to /files/MNIST/raw



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [4]:
class Net(nn.Module):
  """Two-convolution CNN digit classifier for 1x28x28 inputs.

  forward() returns per-class log-probabilities (log_softmax over 10
  classes), so the matching training/eval loss is ``F.nll_loss``.
  """

  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(1, 10, 5)    # 1x28x28 -> 10x24x24, pooled to 12x12
    self.conv2 = nn.Conv2d(10, 20, 5)   # 10x12x12 -> 20x8x8, pooled to 4x4
    self.conv2_drop = nn.Dropout2d()
    self.fc1 = nn.Linear(320, 50)       # 320 = 20 channels * 4 * 4
    self.fc2 = nn.Linear(50, 10)

  def forward(self, x):
    """Map a batch of images to log-probabilities of shape (batch, 10)."""
    x = F.relu(F.max_pool2d(self.conv1(x), 2))
    x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
    # Flatten each sample separately. Unlike view(-1, 320), a wrongly sized
    # input cannot be silently folded into extra "batch" rows; fc1 raises.
    x = torch.flatten(x, 1)
    x = F.relu(self.fc1(x))
    x = F.dropout(x, training=self.training)
    x = self.fc2(x)
    # Explicit class dimension; the implicit-dim form is deprecated and warns.
    return F.log_softmax(x, dim=1)

mnist_cnn = Net()

optimizer = optim.SGD(mnist_cnn.parameters(), lr=learning_rate, momentum=momentum)

# Metric histories, e.g. for plotting loss curves:
#   train_losses / train_counter: batch loss and number of training samples
#     seen so far, appended by train() every `log_interval` batches.
#   test_losses: average test loss, appended once per test() call.
#   test_counter: training samples seen at each test() call. The training loop
#     calls test() once after every epoch (not before training starts), so the
#     points are 1..n_epochs epochs' worth of samples, the same length as
#     test_losses.
train_losses = []
train_counter = []
test_losses = []
test_counter = [i * len(train_loader.dataset) for i in range(1, n_epochs + 1)]

In [5]:
def train(epoch):
  """Run one training epoch of ``mnist_cnn`` over ``train_loader``.

  Every ``log_interval`` batches (starting with the first), prints progress.
  It also appends the current batch loss to ``train_losses`` and appends to
  ``train_counter`` the number of training samples processed before that
  batch, counted across all epochs.

  Args:
    epoch: 1-based epoch number, used in the log line and to offset
      ``train_counter`` by the samples seen in earlier epochs.
  """
  mnist_cnn.train()
  n_train = len(train_loader.dataset)
  for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    output = mnist_cnn(data)

    # The network already returns log-probabilities (log_softmax), so NLL is
    # the matching loss, as in test(); CrossEntropyLoss would re-apply log_softmax.
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()

    # Log on the first batch of every `log_interval`-batch window.
    if batch_idx % log_interval == 0:
      loss_value = loss.item()
      # All batches before this one are full, so this is the exact sample count.
      samples_before = batch_idx * train_loader.batch_size
      print("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(epoch, samples_before, n_train, 100. * batch_idx / len(train_loader), loss_value))
      train_losses.append(loss_value)
      train_counter.append(samples_before + (epoch - 1) * n_train)

In [7]:
def test():
  """Evaluate ``mnist_cnn`` on ``test_loader``.

  Appends the per-sample average NLL loss to ``test_losses``. Prints that
  loss together with the count and percentage of correct predictions.
  """
  mnist_cnn.eval()
  test_loss = 0.0
  correct = 0

  with torch.no_grad():
    for data, target in test_loader:
      output = mnist_cnn(data)
      # Sum (not average) per batch, so dividing by the dataset size below is
      # exact even when the last batch is smaller.
      test_loss += F.nll_loss(output, target, reduction="sum").item()
      pred = output.argmax(dim=1)
      correct += pred.eq(target).sum().item()

  n_test = len(test_loader.dataset)
  test_loss /= n_test
  test_losses.append(test_loss)

  print("\nTest set: Avg. loss: {:.4f}, Accuracy : {}/{} ({:.0f}%)\n".format(test_loss, correct, n_test, 100. * correct / n_test))

In [8]:
for epoch in range(1, 1 + n_epochs):
  # One full optimization pass over the training data ...
  train(epoch)
  # ... followed by an evaluation on the held-out test set.
  test()








Test set: Avg. loss: 0.1061, Accuracy : 9656/10000 (97%)


Test set: Avg. loss: 0.0711, Accuracy : 9799/10000 (98%)


Test set: Avg. loss: 0.0612, Accuracy : 9821/10000 (98%)


Test set: Avg. loss: 0.0527, Accuracy : 9840/10000 (98%)


Test set: Avg. loss: 0.0498, Accuracy : 9850/10000 (98%)



#### Generating Image Labels

In [9]:
import shutil
import os

# Folder of WGAN-generated digit images to classify.
GENERATED_DIR = "../input/wgan-generated-mnist/Generated_MNIST"

# Same input format the classifier was trained on: a 28x28 grayscale tensor
# standardized with the MNIST mean/std. A fixed (h, w) size is required
# because Resize(28) keeps the aspect ratio, and the network needs exactly 28x28.
# The pipeline is deliberately NOT named `transform`. Rebinding that global
# would make a re-run of the MNIST loading cell use this pipeline instead.
generated_transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Inference only: disable dropout explicitly (rather than relying on test()
# having been the last call) and skip autograd bookkeeping.
mnist_cnn.eval()
with torch.no_grad():
    for img in os.listdir(GENERATED_DIR):
        src_path = os.path.join(GENERATED_DIR, img)
        # Close the file handle once the grayscale copy has been made.
        with Image.open(src_path) as raw_image:
            image = raw_image.convert("L")

        image_tensor = generated_transform(image)

        # Add a batch dimension and take the most likely digit.
        log_probs = mnist_cnn(image_tensor.unsqueeze(0))
        label = log_probs.argmax(dim=1).item()

        # Sort each image into a folder named after its predicted digit.
        os.makedirs(f"./{label}", exist_ok=True)
        shutil.copy(src_path, f"./{label}/{img}")

