In [1]:
#import modules to obtain MNIST dataset 
import numpy as np 
import torch 
from torch.autograd import Variable
from torch.utils.data import TensorDataset, DataLoader
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms
import torch.nn.functional as F

In [2]:
#OOP Style model implementation 
class LeNet(nn.Module):
  
  #Initialize: Define the layers in the network 
  def __init__(self):
    super(LeNet, self).__init__()
    self.conv1 = nn.Sequential(nn.Conv2d(1,6,kernel_size = 5), nn.BatchNorm2d(6), nn.ReLU(), nn.AvgPool2d(kernel_size=2, stride=2))
    self.conv2 = nn.Sequential(nn.Conv2d(6,16,kernel_size = 5), nn.BatchNorm2d(16), nn.ReLU(), nn.AvgPool2d(kernel_size=2, stride=2))
    self.fc1 = nn.Sequential(nn.Linear(400,120)) 
    self.relu = nn.ReLU()
    self.fc2 = nn.Sequential(nn.Linear(120, 84))
    self.relu1 = nn.ReLU()
    self.fc3 = nn.Sequential(nn.Linear(84, 10))
    self.softmax = nn.Softmax()

  #Using the layers initialized, build a computational graph and forward pass 
  def forward(self, x):
    out = self.conv1(x)
    out = self.conv2(out)
    out = out.reshape(out.shape[0], -1)
    out = self.fc1(out)
    out = self.relu(out)
    out = self.fc2(out)
    out = self.relu1(out)
    out = self.fc3(out)
    return out

In [3]:
#Loading the dataset and preprocessing
batch_size = 128
train_dataset = torchvision.datasets.MNIST(root = './data',
                                           train = True,
                                           transform = transforms.Compose([
                                                  transforms.Resize((32,32)),
                                                  transforms.ToTensor(),
                                                  transforms.Normalize(mean = (0.1307,), std = (0.3081,))]),
                                           download = True)
test_dataset = torchvision.datasets.MNIST(root = './data',
                                          train = False,
                                          transform = transforms.Compose([
                                                  transforms.Resize((32,32)),
                                                  transforms.ToTensor(),
                                                  transforms.Normalize(mean = (0.1325,), std = (0.3105,))]),
                                          download=True)

train_loader = torch.utils.data.DataLoader(dataset = train_dataset,
                                           batch_size = batch_size,
                                           shuffle = True)
test_loader = torch.utils.data.DataLoader(dataset = test_dataset,
                                           batch_size = batch_size,
                                           shuffle = True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [4]:
#Mount model onto GPU 
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = LeNet().to(device)

#pick loss function and optimizer 
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr = 1e-2)

#Training loop 
for epoch in range(10):
  print("Epoch",epoch)
  for batch_idx, (x_var, y_var) in enumerate(train_loader):
    x_batch = Variable(x_var.to(device))
    #print(x_batch.size())
    y_batch = Variable(y_var.to(device))
    y_pred = model(x_batch)
    loss = criterion(y_pred, y_batch)
    optimizer.zero_grad()
    if batch_idx%200 == 0:
      print("loss: ",loss.data)
    loss.backward()
    optimizer.step()

Epoch 0
loss:  tensor(2.3046, device='cuda:0')
loss:  tensor(2.0622, device='cuda:0')
loss:  tensor(1.0920, device='cuda:0')
Epoch 1
loss:  tensor(0.7713, device='cuda:0')
loss:  tensor(0.4185, device='cuda:0')
loss:  tensor(0.2995, device='cuda:0')
Epoch 2
loss:  tensor(0.3274, device='cuda:0')
loss:  tensor(0.2901, device='cuda:0')
loss:  tensor(0.0882, device='cuda:0')
Epoch 3
loss:  tensor(0.1898, device='cuda:0')
loss:  tensor(0.1256, device='cuda:0')
loss:  tensor(0.0534, device='cuda:0')
Epoch 4
loss:  tensor(0.0881, device='cuda:0')
loss:  tensor(0.0689, device='cuda:0')
loss:  tensor(0.0865, device='cuda:0')
Epoch 5
loss:  tensor(0.1055, device='cuda:0')
loss:  tensor(0.0752, device='cuda:0')
loss:  tensor(0.0644, device='cuda:0')
Epoch 6
loss:  tensor(0.1323, device='cuda:0')
loss:  tensor(0.0850, device='cuda:0')
loss:  tensor(0.0869, device='cuda:0')
Epoch 7
loss:  tensor(0.0919, device='cuda:0')
loss:  tensor(0.1808, device='cuda:0')
loss:  tensor(0.0580, device='cuda:0')


In [6]:
#Fit model and compute accuracy in the wild (test set)
correct = 0
for batch_idx_test, (x_var_test, y_var_test) in enumerate(test_loader):
  x_batch_test = x_var_test.to(device)
  y_batch_test = y_var_test.to(device)
  y_pred_test = model(x_batch_test)

  _, preds = torch.max(y_pred_test, 1)
  correct += torch.sum(preds == y_batch_test.data)

accuracy = 100 * (correct.double() / len(test_loader.dataset))

print("Test Set Accuracy: {:.2f}".format(accuracy))

Test Set Accuracy: 98.39
