Evaluate the performance of different optimizers on a LeNet-5 network trained on the MNIST dataset. At a minimum, evaluate SGD, AdaGrad, and RMSprop.

In [1]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision 
import torchvision.transforms as transforms
import time

In [2]:
import numpy as np
from datetime import datetime 

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torchvision import datasets, transforms

import matplotlib.pyplot as plt

# check device
#DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
# Preparing for Data
print('==> Preparing data..')

# NOTE(review): the triple-quoted string below is not executed as code. It holds a
# disabled augmentation/normalization pipeline for 3-channel, 32-pixel-crop images
# (the class list looks like CIFAR-10 -- confirm) and is unrelated to the
# single-channel MNIST data used later. Because it is the last expression in the
# cell, the notebook echoes the string back as the cell output.
"""
# Training Data augmentation
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# Testing Data preparation
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

#classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

"""

==> Preparing data..


"\n# Training Data augmentation\ntransform_train = transforms.Compose([\n    transforms.RandomCrop(32, padding=4),\n    transforms.RandomHorizontalFlip(),\n    transforms.ToTensor(),\n    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n])\n# Testing Data preparation\ntransform_test = transforms.Compose([\n    transforms.ToTensor(),\n    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n])\n\n#classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n\n"

In [4]:
# Preparing for Data
print('==> Preparing data..')

# define transforms
# The pipeline is bound to the name `transforms` because main() passes that name
# to the MNIST datasets. That binding replaces the `torchvision.transforms`
# module alias, so the pipeline is built through the fully qualified
# `torchvision.transforms` path: the cell can then be re-run without hitting
# "'Compose' object has no attribute 'Compose'".
# Resize to 32x32 so the images match the 16 * 5 * 5 input size of LeNet.fc1.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
])


==> Preparing data..


In [5]:
#Defining the convolutional neural network
class LeNet(nn.Module):
    """LeNet-style CNN: two 5x5 conv + 2x2 max-pool stages, then three linear layers.

    ``fc1`` expects 16 * 5 * 5 features, which is what a 1-channel 32x32 input
    yields after the two conv/pool stages (32 -> 28 -> 14 -> 10 -> 5).
    ``forward`` returns the 10 raw output scores of ``fc3``; no softmax or
    log-softmax is applied.
    """

    def __init__(self):
        super().__init__()

        # Layers are created in a fixed order (conv1, pool, conv2, fc1, fc2, fc3)
        # so that seeded parameter initialisation and state_dict keys stay stable.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Map a batch of images ``x`` to a batch of 10 raw class scores."""
        # Convolutional feature extractor: conv -> ReLU -> pool, twice.
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))

        # Collapse everything except the batch dimension for the linear layers.
        flat = features.flatten(start_dim=1)

        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [6]:
model1 = LeNet()

In [7]:
def count_parameters(model):
    """Return the number of trainable parameters of ``model``.

    Args:
        model: any object exposing ``parameters()`` that yields tensors
            (e.g. an ``nn.Module``).

    Returns:
        int: total element count over all parameters with ``requires_grad`` set.
    """
    # Count the parameters of the argument, not of a global model instance.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


print(f'The model has {count_parameters(model1):,} trainable parameters')

The model has 61,706 trainable parameters


In [8]:
from torchvision import models
from torchsummary import summary

In [9]:
print(model1)

# The recorded output shows the summary table twice: summary() prints it, and the
# notebook then echoes the value returned by the last expression of the cell.
# The trailing semicolon suppresses that second copy.
summary(model1);

LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)
Layer (type:depth-idx)                   Param #
├─Conv2d: 1-1                            156
├─MaxPool2d: 1-2                         --
├─Conv2d: 1-3                            2,416
├─Linear: 1-4                            48,120
├─Linear: 1-5                            10,164
├─Linear: 1-6                            850
Total params: 61,706
Trainable params: 61,706
Non-trainable params: 0


Layer (type:depth-idx)                   Param #
├─Conv2d: 1-1                            156
├─MaxPool2d: 1-2                         --
├─Conv2d: 1-3                            2,416
├─Linear: 1-4                            48,120
├─Linear: 1-5                            10,164
├─Linear: 1-6                            850
Total params: 61,706
Trainable params: 61,706
Non-trainable params: 0

In [10]:
########################################################################
# 3. Define a loss function
# ^^^^^^^^^^^^^^^^^^^^^^^^^
# Classification cross-entropy loss. train() reads it through the module-level
# name `criterion`. No optimizer is created in this cell; main() constructs it
# for the model it trains.

criterion = nn.CrossEntropyLoss()


In [11]:
def train(model, device, train_loader, optimizer, epoch):
    """Run one training epoch over ``train_loader`` and report progress.

    The loss comes from the module-level ``criterion`` and the per-batch
    accuracy from ``calculate_accuracy``; both are defined elsewhere in this
    notebook.

    Args:
        model: network to optimise; it is switched to training mode.
        device: device that each batch is moved to before the forward pass.
        train_loader: iterable of ``(data, target)`` batches exposing
            ``__len__`` and a ``dataset`` attribute (used for the progress line).
        optimizer: optimizer already bound to ``model``'s parameters.
        epoch: epoch number, used only in the progress line.

    Returns:
        tuple[float, float]: ``(mean_loss, mean_accuracy)`` over the samples seen
        in this epoch, or ``(0.0, 0.0)`` if the loader produced no batches.
        The mean loss assumes ``criterion`` returns a per-batch mean (the
        default reduction of ``nn.CrossEntropyLoss``).
    """
    model.train()
    loss_sum = 0.0
    acc_sum = 0.0
    samples_seen = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        # Forward pass
        y_pred = model(data)
        loss = criterion(y_pred, target)
        acc = calculate_accuracy(y_pred, target)

        # Backward pass: clear stale gradients, back-propagate, update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight batch means by batch size so a smaller final batch is handled correctly.
        batch_samples = target.shape[0]
        loss_sum += loss.item() * batch_samples
        acc_sum += acc.item() * batch_samples
        samples_seen += batch_samples

        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

    if samples_seen == 0:
        return 0.0, 0.0
    return loss_sum / samples_seen, acc_sum / samples_seen
            

In [12]:
def test( model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [13]:
def calculate_accuracy(y_pred, y):
    """Return the fraction of predictions in ``y_pred`` that match ``y``.

    The predicted class is the arg-max of ``y_pred`` along dimension 1. The
    result is a 0-dim float tensor equal to the number of matches divided by
    ``y.shape[0]``.
    """
    predicted = y_pred.argmax(dim=1)
    # Reshape the labels to the prediction layout before comparing element-wise.
    hits = predicted.eq(y.view_as(predicted)).sum()
    return hits.float() / y.shape[0]

In [14]:
def main(optimizer_name='sgd', lr=0.05, epochs=50):
    """Train and evaluate LeNet on MNIST with the chosen optimizer.

    Relies on names defined elsewhere in this notebook: ``LeNet``, ``train``,
    ``test`` and the module-level ``transforms`` pipeline passed to the datasets.

    Args:
        optimizer_name: one of ``'sgd'`` (momentum 0.9), ``'adagrad'`` or
            ``'rmsprop'``, case-insensitive. All use ``weight_decay=5e-4``.
        lr: learning rate handed to the optimizer. The default is the value
            the original SGD run used; other optimizers may need a different one.
        epochs: number of train/test passes.

    Raises:
        ValueError: if ``optimizer_name`` is not one of the supported names.
    """
    # Supported optimizers and their optimizer-specific keyword arguments.
    optimizer_configs = {
        'sgd': (optim.SGD, {'momentum': 0.9}),
        'adagrad': (optim.Adagrad, {}),
        'rmsprop': (optim.RMSprop, {}),
    }
    optimizer_key = optimizer_name.lower()
    if optimizer_key not in optimizer_configs:
        # Fail before any download or training work is started.
        raise ValueError('Unknown optimizer {!r}; expected one of {}'.format(
            optimizer_name, sorted(optimizer_configs)))

    time0 = time.time()
    # Training settings
    batch_size = 128
    test_batch_size = 100
    no_cuda = True
    save_model = False
    use_cuda = not no_cuda and torch.cuda.is_available()
    torch.manual_seed(100)
    device = torch.device("cuda" if use_cuda else "cpu")

    trainset = torchvision.datasets.MNIST(root='mnist_data', train=True, download=True, transform=transforms)
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
    testset = torchvision.datasets.MNIST(root='mnist_data', train=False, download=True, transform=transforms)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=test_batch_size, shuffle=False)

    model = LeNet().to(device)
    optimizer_cls, optimizer_kwargs = optimizer_configs[optimizer_key]
    optimizer = optimizer_cls(model.parameters(), lr=lr, weight_decay=5e-4, **optimizer_kwargs)

    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)

    if save_model:
        # Name the checkpoint after the dataset and optimizer actually used.
        torch.save(model.state_dict(), "mnist_lenet_{}.pt".format(optimizer_key))
    time1 = time.time()
    print ('Traning and Testing total excution time is: %s seconds ' % (time1-time0))
if __name__ == '__main__':
    main()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to mnist_data\MNIST\raw\train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting mnist_data\MNIST\raw\train-images-idx3-ubyte.gz to mnist_data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to mnist_data\MNIST\raw\train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting mnist_data\MNIST\raw\train-labels-idx1-ubyte.gz to mnist_data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to mnist_data\MNIST\raw\t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting mnist_data\MNIST\raw\t10k-images-idx3-ubyte.gz to mnist_data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to mnist_data\MNIST\raw\t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting mnist_data\MNIST\raw\t10k-labels-idx1-ubyte.gz to mnist_data\MNIST\raw


Test set: Average loss: -9.9974, Accuracy: 9562/10000 (96%)


Test set: Average loss: -9.7625, Accuracy: 9757/10000 (98%)


Test set: Average loss: -11.9398, Accuracy: 9812/10000 (98%)




Test set: Average loss: -12.0242, Accuracy: 9834/10000 (98%)


Test set: Average loss: -11.8336, Accuracy: 9841/10000 (98%)


Test set: Average loss: -11.0473, Accuracy: 9857/10000 (99%)




Test set: Average loss: -12.0771, Accuracy: 9850/10000 (98%)


Test set: Average loss: -12.1467, Accuracy: 9866/10000 (99%)


Test set: Average loss: -12.3412, Accuracy: 9861/10000 (99%)


Test set: Average loss: -12.4255, Accuracy: 9877/10000 (99%)




Test set: Average loss: -12.7245, Accuracy: 9874/10000 (99%)


Test set: Average loss: -12.6471, Accuracy: 9888/10000 (99%)


Test set: Average loss: -13.6070, Accuracy: 9879/10000 (99%)




Test set: Average loss: -12.5844, Accuracy: 9829/10000 (98%)


Test set: Average loss: -13.8461, Accuracy: 9883/10000 (99%)


Test set: Average loss: -13.9519, Accuracy: 9896/10000 (99%)




Test set: Average loss: -12.3107, Accuracy: 9858/10000 (99%)


Test set: Average loss: -13.8454, Accuracy: 9894/10000 (99%)


Test set: Average loss: -13.1197, Accuracy: 9897/10000 (99%)


Test set: Average loss: -12.2206, Accuracy: 9865/10000 (99%)




Test set: Average loss: -13.4151, Accuracy: 9856/10000 (99%)


Test set: Average loss: -13.9274, Accuracy: 9897/10000 (99%)


Test set: Average loss: -13.8088, Accuracy: 9871/10000 (99%)




Test set: Average loss: -13.3304, Accuracy: 9885/10000 (99%)


Test set: Average loss: -12.7554, Accuracy: 9869/10000 (99%)


Test set: Average loss: -13.5766, Accuracy: 9865/10000 (99%)


Test set: Average loss: -13.7203, Accuracy: 9882/10000 (99%)




Test set: Average loss: -13.3476, Accuracy: 9866/10000 (99%)


Test set: Average loss: -13.1564, Accuracy: 9879/10000 (99%)


Test set: Average loss: -13.9921, Accuracy: 9886/10000 (99%)




Test set: Average loss: -13.2957, Accuracy: 9886/10000 (99%)


Test set: Average loss: -14.0938, Accuracy: 9857/10000 (99%)


Test set: Average loss: -13.4150, Accuracy: 9889/10000 (99%)




Test set: Average loss: -13.5627, Accuracy: 9861/10000 (99%)


Test set: Average loss: -13.8661, Accuracy: 9883/10000 (99%)


Test set: Average loss: -13.7926, Accuracy: 9838/10000 (98%)


Test set: Average loss: -13.6301, Accuracy: 9881/10000 (99%)




Test set: Average loss: -13.7773, Accuracy: 9908/10000 (99%)


Test set: Average loss: -14.6170, Accuracy: 9870/10000 (99%)


Test set: Average loss: -13.5872, Accuracy: 9902/10000 (99%)




Test set: Average loss: -12.7501, Accuracy: 9895/10000 (99%)


Test set: Average loss: -14.5600, Accuracy: 9883/10000 (99%)


Test set: Average loss: -12.8394, Accuracy: 9899/10000 (99%)




Test set: Average loss: -13.4196, Accuracy: 9833/10000 (98%)


Test set: Average loss: -14.6497, Accuracy: 9890/10000 (99%)


Test set: Average loss: -12.7521, Accuracy: 9888/10000 (99%)


Test set: Average loss: -13.9589, Accuracy: 9876/10000 (99%)




Test set: Average loss: -14.4487, Accuracy: 9886/10000 (99%)


Test set: Average loss: -12.9762, Accuracy: 9877/10000 (99%)


Test set: Average loss: -13.5171, Accuracy: 9876/10000 (99%)

Traning and Testing total excution time is: 1094.8859181404114 seconds 


In [None]:
#%config InlineBackend.figure_format = 'retina'