In [10]:
#import modules to obtain MNIST dataset 
import torch 
import torch.nn as nn
import torchvision
from torchvision import transforms

In [27]:
#OOP Style model implementation 
class LeNet(nn.Module):
  
  #Initialize: Define the layers in the network 
  def __init__(self):
    super(LeNet, self).__init__()
    self.conv1 = nn.Sequential(nn.Conv2d(1,6,kernel_size = 5), nn.BatchNorm2d(6), nn.ReLU(), nn.AvgPool2d(kernel_size=2, stride=2))
    self.conv2 = nn.Sequential(nn.Conv2d(6,16,kernel_size = 5), nn.BatchNorm2d(16), nn.ReLU(), nn.AvgPool2d(kernel_size=2, stride=2))
    self.fc1 = nn.Sequential(nn.Linear(400,120)) 
    self.relu = nn.ReLU()
    self.fc2 = nn.Sequential(nn.Linear(120, 84))
    self.relu1 = nn.ReLU()
    self.fc3 = nn.Sequential(nn.Linear(84, 10))
    self.softmax = nn.Softmax()

  #Using the layers initialized, build a computational graph and forward pass 
  def forward(self, x):
    out = self.conv1(x)
    out = self.conv2(out)
    out = out.reshape(out.shape[0], -1)
    out = self.fc1(out)
    out = self.relu(out)
    out = self.fc2(out)
    out = self.relu1(out)
    out = self.fc3(out)
    return out

In [31]:
#Loading the dataset and preprocessing
batch_size = 128
train_dataset = torchvision.datasets.MNIST(root = './data',
                                           train = True,
                                           transform = transforms.Compose([
                                                  transforms.Resize((32,32)),
                                                  transforms.ToTensor(),
                                                  transforms.Normalize(mean = (0.1307,), std = (0.3081,))]),
                                           download = True)
test_dataset = torchvision.datasets.MNIST(root = './data',
                                          train = False,
                                          transform = transforms.Compose([
                                                  transforms.Resize((32,32)),
                                                  transforms.ToTensor(),
                                                  transforms.Normalize(mean = (0.1325,), std = (0.3105,))]),
                                          download=True)

train_loader = torch.utils.data.DataLoader(dataset = train_dataset,
                                           batch_size = batch_size,
                                           shuffle = True)
test_loader = torch.utils.data.DataLoader(dataset = test_dataset,
                                           batch_size = batch_size,
                                           shuffle = True)

In [32]:
#Mount model onto GPU 
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = LeNet().to(device)

#pick loss function and optimizer 
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
#Training loop 
for epoch in range(10):
  print("Epoch {} / 10".format(epoch))
  print("*"*10)
  correct = 0
  for batch_idx, (x_var, y_var) in enumerate(train_loader):
    x_batch = x_var.to(device)
    #print(x_batch.size())
    y_batch = y_var.to(device)
    y_pred = model(x_batch)
    loss = criterion(y_pred, y_batch)
    optimizer.zero_grad()
    _, preds = torch.max(y_pred, 1)
    correct += torch.sum(preds == y_batch.data)
    accuracy = 100 * (correct.double() / len(train_loader.dataset))
    loss.backward()
    optimizer.step()
  print('Current Epoch loss: {:.4f}, accuracy: {:.2f} %'.format(loss, accuracy))

Epoch 0 / 10
**********
Current Epoch loss: 0.1473, accuracy: 93.59 %
Epoch 1 / 10
**********
Current Epoch loss: 0.0856, accuracy: 98.02 %
Epoch 2 / 10
**********
Current Epoch loss: 0.0101, accuracy: 98.59 %
Epoch 3 / 10
**********
Current Epoch loss: 0.0411, accuracy: 98.72 %
Epoch 4 / 10
**********
Current Epoch loss: 0.0387, accuracy: 98.88 %
Epoch 5 / 10
**********
Current Epoch loss: 0.0080, accuracy: 99.09 %
Epoch 6 / 10
**********
Current Epoch loss: 0.0354, accuracy: 99.16 %
Epoch 7 / 10
**********
Current Epoch loss: 0.0438, accuracy: 99.29 %
Epoch 8 / 10
**********
Current Epoch loss: 0.0039, accuracy: 99.31 %
Epoch 9 / 10
**********
Current Epoch loss: 0.0693, accuracy: 99.40 %


In [33]:
#Fit model and compute accuracy in the wild (test set)
correct = 0
for batch_idx_test, (x_var_test, y_var_test) in enumerate(test_loader):
  x_batch_test = x_var_test.to(device)
  y_batch_test = y_var_test.to(device)
  y_pred_test = model(x_batch_test)

  _, preds = torch.max(y_pred_test, 1)
  correct += torch.sum(preds == y_batch_test.data)

accuracy = 100 * (correct.double() / len(test_loader.dataset))

print("Test Set Accuracy: {:.2f}".format(accuracy))

Test Set Accuracy: 99.06
