In [None]:
import os
import numpy as np
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import MNIST
from PIL import Image
from torch.utils.data import Subset
from torch.optim import Optimizer
from sklearn.model_selection import train_test_split
from torchvision.models import vgg16
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

In [None]:
SEED = 42 # For reproducibility
# Seed the global RNGs too: the generator passed to random_split below only
# fixes the train/val split, while weight initialisation, dropout and the
# shuffling done by the training DataLoader draw from the global generators.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Image transformer: convert to a tensor, then normalise the single channel
# with mean 0.5 and std 0.5.
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
# Batch size for training and testing
TRAIN_BATCHSIZE = 128
TEST_BATCHSIZE = 64

# Download dataset (a no-op when ./data already holds the files)
full_train_dataset = MNIST('./data', transform=img_transform, download=True, train=True)
test_dataset = MNIST('./data', transform=img_transform, download=True, train=False)

# Dataset length
num_train_total = len(full_train_dataset)
num_test = len(test_dataset)
print(f"Num. training samples: {num_train_total}")
print(f"Num. test samples:     {num_test}")

# Fraction of the original train set that we want to use as validation set
val_frac = 0.2
# Number of samples of the validation set; the remainder stays in the train set
num_val = int(num_train_total * val_frac)
num_train = num_train_total - num_val

print(f"{num_train} samples used as train dataset")
print(f"{num_val}  samples used as val dataset")

# Split the full training set into training and validation subsets. The
# dedicated generator keeps the split identical across runs.
train_dataset, val_dataset = torch.utils.data.random_split(
    full_train_dataset, [num_train, num_val],
    generator=torch.Generator().manual_seed(SEED))
# Build dataloaders (only the training loader shuffles)
train_loader = DataLoader(train_dataset, batch_size=TRAIN_BATCHSIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=TEST_BATCHSIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=TEST_BATCHSIZE, shuffle=False)

In [None]:
class VGGBlock(nn.Module):
    """Two 3x3 convolutions followed by a 2x2 max-pooling step.

    Each convolution uses stride 1 and padding 1 (so it keeps the spatial
    size) and is followed by an optional batch normalisation and a ReLU.
    The closing max-pool uses a 2x2 window with stride 2.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by both convolutions.
        batch_norm: when True, a ``BatchNorm2d`` follows each convolution;
            otherwise an ``nn.Identity`` is used in its place.
    """

    def __init__(self, in_channels, out_channels, batch_norm=False):
        super().__init__()
        conv_params = {
            'kernel_size': (3, 3),
            'stride': (1, 1),
            'padding': 1,
        }

        self._batch_norm = batch_norm

        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, **conv_params)
        # nn.Identity (rather than a bare lambda) keeps the placeholder a real
        # module: it shows up in print(model) and can be pickled. It owns no
        # parameters, so state-dict keys are unaffected.
        self.bn1 = nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity()

        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, **conv_params)
        self.bn2 = nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity()

        self.max_pooling = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

    @property
    def batch_norm(self):
        """Whether this block was built with batch normalisation."""
        return self._batch_norm

    def forward(self, x):
        """Apply conv -> (bn) -> ReLU twice, then max-pool, and return the result."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        return self.max_pooling(x)
    
class VGG16(nn.Module):
    """VGG-style classifier: four ``VGGBlock`` stages and an MLP head.

    The stages widen the feature map to 64, 128, 256 and 512 channels. The
    head is three linear layers (256, 128, ``num_classes``) with ReLU and
    dropout (p=0.65) between them.

    Args:
        config: object exposing ``input_size`` (a 3-tuple unpacked here as
            channels, width, height), ``batch_norm`` and ``num_classes``.
            It is kept on ``self.config``.

    Raises:
        ValueError: if either spatial dimension is smaller than 16, because
            the four 2x2/stride-2 poolings would leave no features.
    """

    def __init__(self, config):
        super(VGG16, self).__init__()
        self.config = config
        self.in_channels, self.in_width, self.in_height = config.input_size

        self.block_1 = VGGBlock(self.in_channels, 64, batch_norm=config.batch_norm)
        self.block_2 = VGGBlock(64, 128, batch_norm=config.batch_norm)
        self.block_3 = VGGBlock(128, 256, batch_norm=config.batch_norm)
        self.block_4 = VGGBlock(256, 512, batch_norm=config.batch_norm)

        # Each block keeps the spatial size in its convolutions and halves it
        # (rounding down) in its pooling, so four blocks divide it by 16.
        # For the 28x28 input used in this notebook this yields 512 * 1 * 1,
        # i.e. the previously hard-coded width.
        pooled_w, pooled_h = self.in_width // 16, self.in_height // 16
        if pooled_w < 1 or pooled_h < 1:
            raise ValueError(
                f"input_size {tuple(config.input_size)} is too small: "
                "both spatial dimensions must be at least 16"
            )
        flat_features = 512 * pooled_w * pooled_h

        self.classifier = nn.Sequential(
                nn.Linear(flat_features, 256),
                nn.ReLU(True),
                nn.Dropout(p=0.65),
                nn.Linear(256, 128),
                nn.ReLU(True),
                nn.Dropout(p=0.65),
                nn.Linear(128, config.num_classes)
            )

    @property
    def input_size(self):
        """The configured input size, in the same order as ``config.input_size``."""
        return self.in_channels, self.in_width, self.in_height

    def forward(self, x):
        """Run the four blocks, flatten per sample and return the classifier output."""
        x = self.block_1(x)
        x = self.block_2(x)
        x = self.block_3(x)
        x = self.block_4(x)
        # Flatten everything except the batch dimension
        x = x.view(x.size(0), -1)
        x = self.classifier(x)

        return x

    @staticmethod
    def _init_model_with_state_dict(state):
        """Rebuild a model from a dict holding 'model_config' and 'model_state_dict'."""
        model = VGG16(state['model_config'])
        model.load_state_dict(state['model_state_dict'])
        return model

    @classmethod
    def load(cls, path, map_location=None):
        r"""
        Loads a model with data fields and pretrained model parameters.
        Args:
            path (str):
                - path of the .pt checkpoint to load. The file must hold a
                  dict with the keys 'model_config' and 'model_state_dict'.
            map_location:
                - forwarded to ``torch.load``; pass e.g. ``'cpu'`` to open a
                  checkpoint saved on a GPU on a machine without one.
                  The default (None) keeps ``torch.load``'s own behaviour.
        Returns:
            the reconstructed model with the stored weights loaded.
        Raises:
            FileNotFoundError: if ``path`` does not exist.
        Examples:
            >>> # model = VGG16.load('./tmp/resources/<model_name>.pt')
        """
        # Imported here because the module-level imports do not include it.
        import errno

        if not os.path.exists(path):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), path)
        # NOTE(review): torch.load unpickles the file (the checkpoint stores the
        # config object), so only load checkpoints from a trusted source.
        state = torch.load(path, map_location=map_location)
        return cls._init_model_with_state_dict(state)

In [None]:

class AdamOptimizer(Optimizer):
    """Adam optimizer with optional L2 weight decay.

    Args:
        params: iterable of parameters (or parameter groups) to optimise.
        lr: learning rate.
        betas: decay rates of the running averages of the gradient and of
            the squared gradient, in that order.
        eps: constant added to the square root of the second-moment average
            before dividing, for numerical stability.
        weight_decay: L2 penalty coefficient; ``weight_decay * p`` is added
            to the gradient when it is non-zero.
    """
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(AdamOptimizer, self).__init__(params, defaults)

    def step(self, closure=None):
        """
        Performs a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss; it is called once before the update.

        Returns:
            the value returned by ``closure``, or None when none was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            b1, b2 = group['betas']

            for p in group['params']:
                # Parameters that took no part in the backward pass have no
                # gradient and are left untouched.
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Momentum (exponential moving average of gradients)
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # RMSProp component (exponential moving average of
                    # squared gradients); forms the denominator.
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                state['step'] += 1

                # L2 penalty: fold weight_decay * p into the gradient. This is
                # out-of-place so p.grad itself is not modified.
                if group['weight_decay'] != 0:
                    grad = grad.add(p.data, alpha=group['weight_decay'])

                # Update both moving averages IN PLACE so the tensors stored
                # in self.state carry over to the next step.
                exp_avg.mul_(b1).add_(grad, alpha=1 - b1)
                exp_avg_sq.mul_(b2).addcmul_(grad, grad, value=1 - b2)

                denom = exp_avg_sq.sqrt().add_(group['eps'])

                # Bias corrections for the zero-initialised averages, folded
                # into the step size.
                bias_correction1 = 1 - b1 ** state['step']
                bias_correction2 = 1 - b2 ** state['step']
                step_size = group['lr'] * (bias_correction2 ** 0.5) / bias_correction1

                p.data.addcdiv_(exp_avg, denom, value=-step_size)

        return loss

class SgdOptimizer(Optimizer):
    """Stochastic gradient descent with optional momentum and weight decay.

    Args:
        params: iterable of parameters (or parameter groups) to optimise.
        lr: learning rate.
        momentum: momentum factor; 0 disables the momentum buffer.
        dampening: dampening applied to the gradient when it is added to an
            existing momentum buffer (not on the first step).
        weight_decay: L2 penalty coefficient; ``weight_decay * p`` is added
            to the gradient when it is non-zero.
        nesterov: when True (and momentum is non-zero), the update uses
            ``grad + momentum * buffer`` instead of the buffer itself.
    """
    def __init__(self,
        params,
        lr=1e-3,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False):

        defaults = dict(lr=lr, momentum=momentum,
                        dampening=dampening,
                        weight_decay=weight_decay,
                        nesterov=nesterov)

        super(SgdOptimizer, self).__init__(params, defaults)

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss; it is called once before the update.

        Returns:
            the value returned by ``closure``, or None when none was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                if weight_decay != 0:
                    # Out-of-place so the stored gradient is not altered.
                    d_p = d_p.add(p.data, alpha=weight_decay)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        # First step: the buffer starts as a copy of the
                        # (decayed) gradient, without dampening.
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                p.data.add_(d_p, alpha=-group['lr'])

        return loss

In [None]:
def count_parameters(model):
    """Return the total number of elements in the trainable parameters of ``model``.

    Parameters whose ``requires_grad`` flag is False are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

class TrainingArgs:
    """Hyper-parameters shared by the model, the optimizer and the training loop.

    Every value can be overridden by keyword; the defaults are the settings
    used in this notebook.

    Args:
        lr: learning rate passed to the optimizer.
        num_epochs: number of passes over the training set.
        input_size: 3-tuple describing one input sample; the model unpacks it
            as (channels, width, height).
        num_classes: number of output classes of the classifier.
        optimizer: 'adam' selects AdamOptimizer; any other value selects
            SgdOptimizer.
        batch_norm: whether the VGG blocks use batch normalisation.
        nesterov: whether SGD uses Nesterov momentum.
        momentum: SGD momentum factor.
        weight_decay: L2 penalty coefficient passed to the optimizer.
    """
    def __init__(self, lr=1e-3, num_epochs=50, input_size=(1, 28, 28),
                 num_classes=10, optimizer='sgd', batch_norm=True,
                 nesterov=True, momentum=0.96, weight_decay=1e-5):
        self.lr = lr
        self.num_epochs = num_epochs
        self.input_size = input_size
        self.num_classes = num_classes
        self.optimizer = optimizer
        self.batch_norm = batch_norm
        self.nesterov = nesterov
        self.momentum = momentum
        self.weight_decay = weight_decay

args = TrainingArgs()

In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

# Build the network from the shared hyper-parameters and move it to the device.
model = VGG16(args)
model = model.to(device)

# Report the layer layout and the number of trainable parameters.
print(">> Model's Architecture: ")
print(model)
n_trainable = count_parameters(model)
print(f">> Total parameters: {n_trainable}")

In [None]:
# Cross-entropy loss for the multi-class classification task.
criterion = nn.CrossEntropyLoss()

# Both optimizers take the learning rate and the weight decay; the SGD variant
# additionally takes the momentum settings. Anything other than 'adam'
# selects SGD.
shared_optim_kwargs = dict(lr=args.lr, weight_decay=args.weight_decay)
if args.optimizer != 'adam':
    optim = SgdOptimizer(
        model.parameters(),
        momentum=args.momentum,
        nesterov=args.nesterov,
        **shared_optim_kwargs,
    )
else:
    optim = AdamOptimizer(model.parameters(), **shared_optim_kwargs)

In [None]:
# Per-epoch accuracy and loss for the training and validation splits.
history = {'acc': {'train': [], 'val': []},
    'loss': {'train': [], 'val': []}}
# Lowest validation loss seen so far; a checkpoint is written when it improves.
min_val_loss = np.inf

for epoch in range(args.num_epochs):
    # Training
    train_iterator = tqdm(train_loader, leave=True)
    # 1-based epoch label, matching the summary line printed below.
    train_iterator.set_description('(Train) Epoch [{}/{}]'.format(epoch + 1, args.num_epochs))
    # Sums are weighted by batch size so a smaller final batch does not skew
    # the epoch averages.
    train_loss_sum, train_correct, train_seen = 0.0, 0, 0
    model.train()
    for images, labels in train_iterator:
        images = images.to(device)
        labels = labels.to(device)
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        batch_size = labels.size(0)
        batch_correct = (outputs.argmax(dim=1) == labels).sum().item()
        # Backward pass
        optim.zero_grad()
        loss.backward()
        optim.step()

        # Gather training loss and acc
        batch_loss = loss.item()
        train_loss_sum += batch_loss * batch_size
        train_correct += batch_correct
        train_seen += batch_size

        train_iterator.set_postfix(train_loss=batch_loss, train_acc=batch_correct / batch_size)

    epoch_train_loss = train_loss_sum / max(train_seen, 1)
    epoch_train_acc = train_correct / max(train_seen, 1)

    history['acc']['train'].append(epoch_train_acc)
    history['loss']['train'].append(epoch_train_loss)

    # Evaluation
    val_iterator = tqdm(val_loader, leave=True)
    val_iterator.set_description('(Val) Epoch [{}/{}]'.format(epoch + 1, args.num_epochs))
    val_loss_sum, val_correct, val_seen = 0.0, 0, 0
    model.eval()
    # No gradients are needed while evaluating.
    with torch.no_grad():
        for val_images, val_labels in val_iterator:
            val_images = val_images.to(device)
            val_labels = val_labels.to(device)

            val_outputs = model(val_images)
            batch_loss = criterion(val_outputs, val_labels).item()
            batch_size = val_labels.size(0)
            batch_correct = (val_outputs.argmax(dim=1) == val_labels).sum().item()

            val_loss_sum += batch_loss * batch_size
            val_correct += batch_correct
            val_seen += batch_size

            val_iterator.set_postfix(val_loss=batch_loss, val_acc=batch_correct / batch_size)

    epoch_val_acc = val_correct / max(val_seen, 1)
    epoch_val_loss = val_loss_sum / max(val_seen, 1)

    history['acc']['val'].append(epoch_val_acc)
    history['loss']['val'].append(epoch_val_loss)

    print(f'>> Epoch [{epoch+1}/{args.num_epochs}]:\tTrain loss = {epoch_train_loss:.5f} | Val loss = {epoch_val_loss:.5f},\
                \t Train Acc = {epoch_train_acc:.5f} | Val Acc = {epoch_val_acc:.5f}')
    # Keep only the checkpoint with the lowest validation loss so far.
    if epoch_val_loss < min_val_loss:
        min_val_loss = epoch_val_loss
        print(">> Saving The Model Checkpoint")
        torch.save(
            {
                'model_config': model.config,
                'model_state_dict': model.state_dict(),
                'optim_state_dict': optim.state_dict(),
                'history': history
            }, './vgg16-sgd-nesterov-mnist.pt'
        )

In [None]:
import matplotlib.pyplot as plt

# One panel per metric, side by side: accuracy on the left, loss on the right.
fig, axs = plt.subplots(1, 2, figsize=(20, 6))

# (axis, key in `history`, panel title, y-axis label, legend entries)
panels = [
    (axs[0], 'acc', 'Model Accuracy', 'Accuracy', ['train-acc', 'val-acc']),
    (axs[1], 'loss', 'Model Loss', 'Loss', ['train-loss', 'val-loss']),
]

for ax, metric, title, ylabel, legend_labels in panels:
    # Train curve first, then validation, to match the legend order.
    for split in ('train', 'val'):
        values = history[metric][split]
        ax.plot(np.arange(len(values)), values)
    ax.set_title(title, fontsize=16)
    ax.set_ylabel(ylabel, fontsize=14)
    ax.set_xlabel('Epoch', fontsize=14)
    ax.legend(legend_labels, loc='upper left', fontsize=14)

plt.tight_layout()
plt.show()

In [None]:
# Rebuild the model from the checkpoint file written by the training loop
# (saved whenever the validation loss improved), then move it to the device.
best_model = VGG16.load('./vgg16-sgd-nesterov-mnist.pt')
best_model = best_model.to(device)

In [None]:
# Evaluate the restored checkpoint on the held-out test set.
best_model.eval()
test_iterator = tqdm(test_loader, leave=True)
# Label and update the TEST progress bar (not the validation bar left over
# from the training cell).
test_iterator.set_description('(Test)')
# Sums are weighted by batch size so the smaller final batch does not skew
# the reported averages.
test_loss_sum, test_correct, test_seen = 0.0, 0, 0
with torch.no_grad():
    for test_images, test_labels in test_iterator:
        test_images = test_images.to(device)
        test_labels = test_labels.to(device)

        test_outputs = best_model(test_images)
        batch_loss = criterion(test_outputs, test_labels).item()
        batch_size = test_labels.size(0)
        batch_correct = (test_outputs.argmax(dim=1) == test_labels).sum().item()

        test_loss_sum += batch_loss * batch_size
        test_correct += batch_correct
        test_seen += batch_size

        test_iterator.set_postfix(test_loss=batch_loss, test_acc=batch_correct / batch_size)

total_test_acc = test_correct / max(test_seen, 1)
total_test_loss = test_loss_sum / max(test_seen, 1)

print(f'>> Result:\tTest loss = {total_test_loss:.5f} \t Test Acc = {total_test_acc:.5f}')