Imports

In [None]:
%reset

import os
import time

import numpy as np
import pandas as pd

import torch

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Subset

from torchvision import datasets
from torchvision import transforms
from torch import tensor 

import matplotlib.pyplot as plt


# **Model settings**

In [5]:
# --- Optimisation settings -------------------------------------------------
# Seed handed to torch.manual_seed.
RANDOM_SEED = 1
# Learning rates swept by the SGD experiments.
LEARNING_RATE_list = [
    0.001,
    0.01,
    0.1,
    0.5,
    1,
]
# Momentum coefficients swept by the SGD-with-momentum experiment.
MOMENTUM_list = [
    0.8,
    0.85,
    0.9,
    0.92,
    0.95,
]
# Mini-batch size shared by every DataLoader.
BATCH_SIZE = 125
# Number of passes over the training subset per configuration.
NUM_EPOCHS = 10

# --- Model / data shape ----------------------------------------------------
# Pixels per 28x28 image.
NUM_FEATURES = 28 * 28
# Number of target classes (size of the classifier output).
NUM_CLASSES = 10

# --- Miscellaneous ---------------------------------------------------------
# True -> the network is built with a single input channel instead of three.
GRAYSCALE = True

# MNIST Dataset

In [None]:
# transforms.ToTensor() rescales the input images to the 0-1 range.
train_dataset = datasets.MNIST(
    root='data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

test_dataset = datasets.MNIST(
    root='data',
    train=False,
    transform=transforms.ToTensor(),
)

# Training set: MNIST training images with indices 0..1499 (1500 images).
subset_indices1 = list(range(1500))
train_set = Subset(train_dataset, subset_indices1)

# Validation set: MNIST test images with indices 0..249 (250 images).
subset_indices2 = list(range(250))
validation_set = Subset(test_dataset, subset_indices2)

# Test set: MNIST test images with indices 250..499 (250 images),
# i.e. disjoint from the validation set.
subset_indices3 = list(range(250, 500))
test_set = Subset(test_dataset, subset_indices3)

# Only the training loader shuffles; evaluation loaders keep dataset order.
train_loader = DataLoader(
    dataset=train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

validation_loader = DataLoader(
    dataset=validation_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

test_loader = DataLoader(
    dataset=test_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

#Model

In [7]:
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with one pixel of zero padding.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride (default 1).

    Returns:
        The configured ``nn.Conv2d`` layer.
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions plus a shortcut connection.

    Args:
        inplanes: Number of channels entering the block.
        planes: Number of channels produced by both convolutions.
        stride: Stride of the first convolution (default 1).
        downsample: Optional module applied to the input to build the
            shortcut branch; when ``None`` the input itself is the shortcut.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # First conv + norm (carries the stride), followed by the shared ReLU.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        # Second conv + norm, always stride 1.
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        # Shortcut projection (may be None) and bookkeeping.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        shortcut = x if self.downsample is None else self.downsample(x)

        main = self.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))

        main += shortcut
        return self.relu(main)




class ResNet(nn.Module):
    """ResNet classifier: conv stem, four residual stages, linear head.

    Args:
        block: Residual block class. It must expose an ``expansion`` class
            attribute and be constructible as
            ``block(inplanes, planes, stride, downsample)``.
        layers: Sequence of four ints, the number of blocks in each stage.
        num_classes: Size of the output layer.
        grayscale: If truthy the stem expects 1 input channel, otherwise 3.

    ``forward`` returns a ``(logits, probas)`` tuple, where ``probas`` is the
    softmax of ``logits`` over the class dimension.
    """

    def __init__(self, block, layers, num_classes, grayscale):
        # Initialise nn.Module first so that attribute assignment below goes
        # through a fully set-up module.
        super().__init__()
        # Running channel count, updated by _make_layer as stages are built.
        self.inplanes = 64
        in_dim = 1 if grayscale else 3

        self.conv1 = nn.Conv2d(in_dim, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # Adaptive pooling always yields a 1x1 map, so the classifier input
        # size is 512 * expansion for any input resolution. For 28x28 inputs
        # the map is already 1x1 at this point and the pooling is an identity.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal init with std = sqrt(2 / (k_h * k_w * out_channels)).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, (2. / n)**.5)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block receives ``stride`` and, when the shape of the
        shortcut would not match the main path, a 1x1-conv + batch-norm
        ``downsample`` projection.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(logits, probas)`` for a batch of images ``x``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Collapse whatever spatial extent is left to 1x1 before flattening.
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        logits = self.fc(x)
        probas = F.softmax(logits, dim=1)
        return logits, probas



def resnet18(num_classes):
    """Constructs a ResNet-18 model.

    Args:
        num_classes: Size of the classifier output layer.

    Returns:
        A ``ResNet`` built from ``BasicBlock`` with two blocks per stage.
        The number of input channels follows the module-level ``GRAYSCALE``
        flag.
    """
    # Honour the argument instead of silently using the global NUM_CLASSES.
    model = ResNet(block=BasicBlock,
                   layers=[2, 2, 2, 2],
                   num_classes=num_classes,
                   grayscale=GRAYSCALE)
    return model

In [9]:
# Accuracy helper shared by the training and validation code.
def compute_accuracy(model, data_loader):
    """Return the classification accuracy of ``model`` in percent.

    Args:
        model: Callable that maps a batch of features to a
            ``(logits, probas)`` pair.
        data_loader: Iterable yielding ``(features, targets)`` batches.

    Returns:
        A 0-dim float tensor holding the percentage of examples whose
        highest-probability class equals the target. If ``data_loader``
        yields no examples, ``tensor(0.)`` is returned.
    """
    correct_pred, num_examples = 0, 0

    # Evaluation never needs gradients; disabling them here means callers do
    # not have to remember to wrap the call themselves.
    with torch.no_grad():
        for features, targets in data_loader:
            _, probas = model(features)
            _, predicted_labels = torch.max(probas, 1)
            num_examples += targets.size(0)
            correct_pred += (predicted_labels == targets).sum()

    # With no examples ``correct_pred`` is still the int 0 (no ``.float()``)
    # and the ratio would be 0/0, so handle that case explicitly.
    if num_examples == 0:
        return torch.tensor(0.0)
    return correct_pred.float() / num_examples * 100

#Training and Validation

In [None]:
# Seed PyTorch's global random number generator so that the random weight
# initialisation done when models are built is repeatable between runs.
torch.manual_seed(RANDOM_SEED)

#SGD

In [None]:
# Learning-rate sweep for plain SGD.
# accuracy_score_SGD[i] is the validation accuracy reached with
# hyperparameters_SGD[i].
accuracy_score_SGD = []
hyperparameters_SGD = []

for learning_rate in LEARNING_RATE_list:

    # Re-seed before every configuration so each learning rate starts from
    # the same global RNG state (and hence the same initial weights) no
    # matter how many configurations were run before it.
    torch.manual_seed(RANDOM_SEED)

    # A fresh model and optimizer per configuration; rebinding the name lets
    # the previous model be garbage-collected.
    model_SGD = resnet18(NUM_CLASSES)
    optimizer_SGD = torch.optim.SGD(model_SGD.parameters(), lr=learning_rate)

    start_time = time.time()
    for epoch in range(NUM_EPOCHS):

        model_SGD.train()
        for batch_idx, (features, targets) in enumerate(train_loader):

            ### FORWARD AND BACK PROP
            logits, probas = model_SGD(features)
            cost = F.cross_entropy(logits, targets)
            optimizer_SGD.zero_grad()

            cost.backward()

            ### UPDATE MODEL PARAMETERS
            optimizer_SGD.step()

            ### LOGGING
            if not batch_idx % 6:
                print ('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f'
                    %(epoch+1, NUM_EPOCHS, batch_idx,
                      len(train_loader), cost))

        model_SGD.eval()
        with torch.no_grad(): # save memory during inference
            print('Epoch: %03d/%03d | Train: %.3f%%' % (
                  epoch+1, NUM_EPOCHS,
                  compute_accuracy(model_SGD, train_loader)))

        print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))

    print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))

    # Make sure the model is in evaluation mode even if NUM_EPOCHS is 0.
    model_SGD.eval()
    with torch.no_grad(): # save memory during inference
        accuracy_score_SGD.append(compute_accuracy(model_SGD, validation_loader))
        hyperparameters_SGD.append(learning_rate)
        print(accuracy_score_SGD)
        print(hyperparameters_SGD)

# Plain Python floats, convenient for plotting.
list_acc_SGD = torch.FloatTensor(accuracy_score_SGD).tolist()

#SGD with momentum

In [None]:
# Grid search over (learning rate, momentum) for SGD with momentum.
# Results are appended learning-rate-major / momentum-minor:
# accuracy_score_SGD_momentum[i] belongs to hyperparameters_SGD_momentum[i].
accuracy_score_SGD_momentum = []
hyperparameters_SGD_momentum = []

for learning_rate in LEARNING_RATE_list:
    for momentum in MOMENTUM_list:

        # Re-seed before every configuration so each (lr, momentum) pair
        # starts from the same global RNG state (and hence the same initial
        # weights) no matter how many configurations were run before it.
        torch.manual_seed(RANDOM_SEED)

        # A fresh model and optimizer per configuration; rebinding the name
        # lets the previous model be garbage-collected.
        model_SGD_momentum = resnet18(NUM_CLASSES)
        optimizer_SGD_momentum = torch.optim.SGD(
            model_SGD_momentum.parameters(),
            lr=learning_rate,
            momentum=momentum)

        start_time = time.time()
        for epoch in range(NUM_EPOCHS):

            model_SGD_momentum.train()
            for batch_idx, (features, targets) in enumerate(train_loader):

                ### FORWARD AND BACK PROP
                logits, probas = model_SGD_momentum(features)
                cost = F.cross_entropy(logits, targets)
                optimizer_SGD_momentum.zero_grad()

                cost.backward()

                ### UPDATE MODEL PARAMETERS
                optimizer_SGD_momentum.step()

                ### LOGGING
                if not batch_idx % 6:
                    print ('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f'
                        %(epoch+1, NUM_EPOCHS, batch_idx,
                          len(train_loader), cost))

            model_SGD_momentum.eval()
            with torch.no_grad(): # save memory during inference
                print('Epoch: %03d/%03d | Train: %.3f%%' % (
                      epoch+1, NUM_EPOCHS,
                      compute_accuracy(model_SGD_momentum, train_loader)))

            print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))

        print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))

        # Make sure the model is in evaluation mode even if NUM_EPOCHS is 0.
        model_SGD_momentum.eval()
        with torch.no_grad(): # save memory during inference
            accuracy_score_SGD_momentum.append(compute_accuracy(model_SGD_momentum, validation_loader))
            hyperparameters_SGD_momentum.append([learning_rate, momentum])
            print(accuracy_score_SGD_momentum)
            print(hyperparameters_SGD_momentum)

# Plain Python floats, convenient for plotting.
list_acc_SGD_momentum = torch.FloatTensor(accuracy_score_SGD_momentum).tolist()

#Plots

SGD

In [None]:
# Validation accuracy of plain SGD as a function of the learning rate.
plt.plot(LEARNING_RATE_list, list_acc_SGD, label='SGD')
plt.ylabel('Accuracy (%)')
# Raw string: '\g' is not a valid escape sequence, so a normal string literal
# triggers an invalid-escape warning; the resulting label text is unchanged.
plt.xlabel(r'Learning rate $\gamma$')
plt.legend(loc='lower center')
plt.grid()
plt.show()

SGD with momentum

In [None]:
# Validation accuracy against momentum, one curve per learning rate.
# list_acc_SGD_momentum is laid out learning-rate-major / momentum-minor, so
# the results for the i-th learning rate occupy one contiguous chunk of
# len(MOMENTUM_list) entries. Deriving the slices and labels from the
# hyperparameter lists keeps the plot correct if either list is edited.
momenta_per_lr = len(MOMENTUM_list)
for lr_index, learning_rate in enumerate(LEARNING_RATE_list):
    chunk_start = lr_index * momenta_per_lr
    chunk_end = chunk_start + momenta_per_lr
    plt.plot(MOMENTUM_list,
             list_acc_SGD_momentum[chunk_start:chunk_end],
             label=r'$\gamma = %g$' % learning_rate)
plt.ylabel('Accuracy (%)')
plt.xlabel(r'Momentum $ \beta $')
plt.legend(loc='center')
plt.grid()
plt.show()