## Basic Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

import torchvision
from torchvision import datasets, transforms
from torchvision.transforms import v2
import random
import numpy as np

## Seed setup for reproducibility and device selection

In [2]:
SEED = 1

# Accelerator detection: NVIDIA CUDA first, then Apple Metal (MPS), otherwise CPU.
use_cuda = torch.cuda.is_available()
# torch.backends.mps.is_available() is the public MPS availability check and exists on
# every build (it simply reports False on hosts without Metal support).
use_mps = torch.backends.mps.is_available()

# Seed for repeatability: torch.manual_seed seeds the default generator for all devices.
torch.manual_seed(SEED)

# Device-specific seeding/determinism, and device selection, in one decision chain so the
# chosen `device` can never disagree with what was printed/seeded.
if use_cuda:
    print(f"Use CUDA?:{use_cuda}")
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    # Prefer deterministic cuDNN kernels over autotuned (potentially faster) ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    device = torch.device("cuda")
elif use_mps:
    print(f"Use MPS?: {use_mps}")
    torch.mps.manual_seed(SEED)
    device = torch.device("mps")
else:
    print("Using CPU")
    device = torch.device("cpu")

# Seed the Python and NumPy global RNGs as well.
random.seed(SEED)
np.random.seed(SEED)

Use MPS?: True


## Transforms, datasets and dataloaders

In [3]:
# Train Phase transformations
# MNIST mean/std. Normalize expects sequences (one value per channel), hence the 1-tuples.
train_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Test Phase transformations (currently identical to the training ones: no augmentation yet)
test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Named *_dataset so that the later `def train(...)` / `def test(...)` cell does not
# shadow them.
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=train_transforms)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=test_transforms)

# Loader settings shared by both splits. `use_cuda` comes from the device-setup cell and is
# only read here, never reassigned, so later cells still see the real device flags.
# pin_memory targets host->CUDA transfers, so it is enabled on CUDA only (not on MPS/CPU).
dataloader_args = dict(shuffle=True,
                       batch_size=128,
                       num_workers=0,
                       pin_memory=True) \
                                        if use_cuda else \
                    dict(shuffle=True,
                         batch_size=128)

train_loader = torch.utils.data.DataLoader(train_dataset,
                                           **dataloader_args)
# Evaluation needs no random order; iterate the test set in a fixed order.
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          **{**dataloader_args, 'shuffle': False})


## Network

### Network 1

```Target```:
- Get the set-up right for basic NN training
    - Set Transforms
    - Set Data Loader
    - Set Basic Working Code
    - Set Basic Training  & Test Loop

  AND

- Reduce params to < 8k
      - Use GAP instead of a fully connected (FC) layer
- Use batchnorm for efficiency


```Results```:
- Parameters: 6922
- Best Training Accuracy: 98.83
- Best Test Accuracy: 98.79 (epoch 11; final epoch: 98.67)

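```Analysis```: Mild overfitting and needs regularisation. The best train accuracy (98.83) is only slightly above the best test accuracy (98.79). However, test accuracy plateaus and fluctuates in the last epochs (98.63–98.79) while train accuracy keeps rising.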



In [4]:

class MyNetwork1(nn.Module):
    """Small all-convolutional MNIST classifier (6,922 parameters).

    Eight unpadded 3x3 convolutions shrink a 1x28x28 input to 10x12x12. The first seven
    are each followed by BatchNorm + ReLU; the last maps straight to 10 class channels.
    Global average pooling then collapses each channel to a single score (no fully
    connected layer), and ``forward`` returns log-probabilities of shape (N, 10), to be
    used with ``F.nll_loss``.

    Comment legend: in/out = channels, size = output spatial size, rf = receptive field,
    j = jump (cumulative stride).
    """

    @staticmethod
    def _conv_bn_relu(in_ch, out_ch):
        """Unpadded 3x3 conv -> BatchNorm -> ReLU, as an nn.Sequential (indices 0/1/2)."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )

    def __init__(self):
        super().__init__()
        # Layers are created in the same order and with the same Sequential layout as the
        # hand-written version, so the seeded initial weights and state_dict keys
        # (e.g. convblock1.0.weight) are unchanged.

        # INPUT BLOCK
        self.convblock1 = self._conv_bn_relu(1, 8)    # in=1,  out=8,  size=26, rf=3,  j=1

        # CONV BLOCK 1
        self.convblock2 = self._conv_bn_relu(8, 8)    # in=8,  out=8,  size=24, rf=5,  j=1
        self.convblock3 = self._conv_bn_relu(8, 8)    # in=8,  out=8,  size=22, rf=7,  j=1
        self.convblock4 = self._conv_bn_relu(8, 16)   # in=8,  out=16, size=20, rf=9,  j=1

        # TRANSITION BLOCK (channel reduction 16 -> 8)
        self.convblock5 = self._conv_bn_relu(16, 8)   # in=16, out=8,  size=18, rf=11, j=1

        # CONV BLOCK 2
        self.convblock6 = self._conv_bn_relu(8, 8)    # in=8,  out=8,  size=16, rf=13, j=1
        self.convblock7 = self._conv_bn_relu(8, 16)   # in=8,  out=16, size=14, rf=15, j=1

        # OUTPUT BLOCK: no BatchNorm/ReLU, so the 10 channels stay raw class scores.
        self.convblock8 = nn.Sequential(
            nn.Conv2d(16, 10, 3),                     # in=16, out=10, size=12, rf=17, j=1
        )

        # Global average pooling replaces a linear head: (N, 10, 12, 12) -> (N, 10, 1, 1).
        self.gap = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Map an (N, 1, 28, 28) batch to (N, 10) log-probabilities."""
        # BLOCK 1
        x = self.convblock1(x)
        x = self.convblock2(x)
        x = self.convblock3(x)
        x = self.convblock4(x)
        # TRANSITION
        x = self.convblock5(x)
        # BLOCK 2
        x = self.convblock6(x)
        x = self.convblock7(x)
        # OUTPUT: (N, 10, 12, 12)
        x = self.convblock8(x)
        # GAP: (N, 10, 12, 12) -> (N, 10, 1, 1)
        x = self.gap(x)
        # Drop the two singleton spatial dims: (N, 10, 1, 1) -> (N, 10).
        x = x.view(x.size(0), -1)

        return F.log_softmax(x, dim=1)



### Model summary

In [5]:
# torchsummary is needed by the summary cell below. If it is missing, install it into this
# kernel's environment (use %pip, not !pip) with:  %pip install torchsummary
model_1 = MyNetwork1()  # untrained instance; moved to `device` later, in the training cell

In [6]:
from torchsummary import summary

# torchsummary's `device` argument defaults to "cuda". On a CUDA host it would build CUDA
# dummy inputs, while model_1 is still on the CPU at this point (it is moved to `device`
# only in the training cell). Fix the dummy input to the CPU so this cell runs on CPU,
# CUDA and MPS hosts alike.
summary(model_1, input_size=(1, 28, 28), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 8, 26, 26]              80
       BatchNorm2d-2            [-1, 8, 26, 26]              16
              ReLU-3            [-1, 8, 26, 26]               0
            Conv2d-4            [-1, 8, 24, 24]             584
       BatchNorm2d-5            [-1, 8, 24, 24]              16
              ReLU-6            [-1, 8, 24, 24]               0
            Conv2d-7            [-1, 8, 22, 22]             584
       BatchNorm2d-8            [-1, 8, 22, 22]              16
              ReLU-9            [-1, 8, 22, 22]               0
           Conv2d-10           [-1, 16, 20, 20]           1,168
      BatchNorm2d-11           [-1, 16, 20, 20]              32
             ReLU-12           [-1, 16, 20, 20]               0
           Conv2d-13            [-1, 8, 18, 18]           1,160
      BatchNorm2d-14            [-1, 8,

## Train and Test loops

In [7]:
from tqdm import tqdm

# Metric histories shared across the training run; train() and test() append to them.
train_losses, train_acc = [], []
test_losses, test_acc = [], []


def train(model, device, train_loader, optimizer):
    """Train ``model`` for one epoch over ``train_loader``.

    The loss is ``F.nll_loss``, so ``model`` is expected to return log-probabilities.

    Side effects:
        * appends every batch loss to the global ``train_losses`` as a detached scalar
          tensor;
        * appends the running accuracy in percent (correct / processed so far in this
          epoch) to the global ``train_acc`` after every batch.

    Args:
        model: network to train; switched to train mode.
        device: device each batch is moved to.
        train_loader: iterable yielding ``(data, target)`` batches.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
    """
    model.train()

    # Progress bar wrapping the loader.
    pbar = tqdm(train_loader)

    correct = 0
    processed = 0

    for batch_idx, (data, target) in enumerate(pbar):
        data, target = data.to(device), target.to(device)

        # Gradients accumulate across backward() calls, so clear them for each batch.
        optimizer.zero_grad()

        y_pred = model(data)

        batch_loss = F.nll_loss(y_pred, target)
        # Detach before storing: keeping the raw loss tensor would also keep its autograd
        # node (grad_fn) alive for every batch of every epoch. A detached tensor still
        # supports .item()/.cpu() for later plotting.
        train_losses.append(batch_loss.detach())

        batch_loss.backward()
        optimizer.step()

        # Correct predictions in this batch (predictions from before the update step).
        pred_labels = y_pred.argmax(dim=1, keepdim=True)
        correct += pred_labels.eq(target.view_as(pred_labels)).sum().item()
        processed += len(data)

        running_acc = 100*correct/processed
        pbar.set_description(desc=f'batch Loss = {batch_loss.item()} batch_id = {batch_idx} accuracy = {running_acc:.02f}'
                            )
        train_acc.append(running_acc)


def test(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` without gradient tracking.

    Side effects: appends the epoch's average per-sample NLL loss to the global
    ``test_losses`` (exactly one entry per call) and the accuracy in percent to the
    global ``test_acc``, and prints both.

    Returns:
        ``(avg_loss, accuracy_percent_rounded_to_1_decimal)``.
    """
    model.eval()

    correct = 0
    processed = 0
    test_loss = 0.0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            y_pred = model(data)

            # Sum (not mean) over the batch so dividing by the sample count below yields
            # a true per-sample average even when the last batch is smaller.
            test_loss += F.nll_loss(y_pred, target, reduction='sum').item()

            pred_labels = y_pred.argmax(dim=1, keepdim=True)
            correct += pred_labels.eq(target.view_as(pred_labels)).sum().item()
            processed += len(data)

    test_loss /= processed
    accuracy = 100*correct/processed

    # Only the epoch-level average is recorded; mixing per-batch summed losses into the
    # same list made its entries incomparable.
    test_losses.append(test_loss)
    test_acc.append(accuracy)

    # Built as one literal: a backslash continuation inside the f-string used to leak the
    # source indentation into the printed line.
    print(f'\n Test set avg loss: {test_loss:.4f} Accuracy: {correct}/{processed}, {accuracy:.2f}')

    return test_loss, round(accuracy, 1)
        

    

## Train for n Epochs

In [8]:
# Train for EPOCHS epochs with Adam plus L2 weight decay (no LR scheduler in this run).

# nn.Module.to moves the parameters in place and returns the same module, so `model` and
# `model_1` refer to one and the same network.
model = model_1.to(device)

optimizer = optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-4,
)

EPOCHS = 14  # full passes over the training set

for epoch in range(EPOCHS):
    print(f'EPOCH: {epoch}')
    train(model=model, device=device, train_loader=train_loader, optimizer=optimizer)
    test(model=model, device=device, test_loader=test_loader)


EPOCH: 0



atch Loss = 0.1570795327425003 batch_id = 468 accuracy = 82.58: 100%|██| 469/469 [00:09<00:00, 51.44it/s]


 Test set avg loss: 0.1699                 Accuracy: 9565/10000, 95.65
EPOCH: 1



atch Loss = 0.1131078377366066 batch_id = 468 accuracy = 96.06: 100%|██| 469/469 [00:08<00:00, 55.95it/s]


 Test set avg loss: 0.1111                 Accuracy: 9677/10000, 96.77
EPOCH: 2



atch Loss = 0.12369922548532486 batch_id = 468 accuracy = 97.24: 100%|█| 469/469 [00:09<00:00, 51.44it/s]


 Test set avg loss: 0.0889                 Accuracy: 9741/10000, 97.41
EPOCH: 3



atch Loss = 0.09919320791959763 batch_id = 468 accuracy = 97.73: 100%|█| 469/469 [00:08<00:00, 54.44it/s]


 Test set avg loss: 0.0896                 Accuracy: 9715/10000, 97.15
EPOCH: 4



atch Loss = 0.06370348483324051 batch_id = 468 accuracy = 97.94: 100%|█| 469/469 [00:08<00:00, 54.63it/s]


 Test set avg loss: 0.0761                 Accuracy: 9781/10000, 97.81
EPOCH: 5



atch Loss = 0.06324979662895203 batch_id = 468 accuracy = 98.22: 100%|█| 469/469 [00:08<00:00, 56.00it/s]


 Test set avg loss: 0.0574                 Accuracy: 9818/10000, 98.18
EPOCH: 6



atch Loss = 0.02165728621184826 batch_id = 468 accuracy = 98.33: 100%|█| 469/469 [00:08<00:00, 56.69it/s]


 Test set avg loss: 0.0479                 Accuracy: 9857/10000, 98.57
EPOCH: 7



atch Loss = 0.0949222519993782 batch_id = 468 accuracy = 98.40: 100%|██| 469/469 [00:08<00:00, 55.60it/s]


 Test set avg loss: 0.0604                 Accuracy: 9815/10000, 98.15
EPOCH: 8



atch Loss = 0.03512514755129814 batch_id = 468 accuracy = 98.55: 100%|█| 469/469 [00:08<00:00, 52.87it/s]


 Test set avg loss: 0.0487                 Accuracy: 9840/10000, 98.40
EPOCH: 9



atch Loss = 0.09983039647340775 batch_id = 468 accuracy = 98.54: 100%|█| 469/469 [00:09<00:00, 50.10it/s]


 Test set avg loss: 0.0404                 Accuracy: 9870/10000, 98.70
EPOCH: 10



atch Loss = 0.02062891609966755 batch_id = 468 accuracy = 98.69: 100%|█| 469/469 [00:09<00:00, 50.83it/s]


 Test set avg loss: 0.0393                 Accuracy: 9877/10000, 98.77
EPOCH: 11



atch Loss = 0.020275773480534554 batch_id = 468 accuracy = 98.73: 100%|█| 469/469 [00:09<00:00, 50.49it/s


 Test set avg loss: 0.0394                 Accuracy: 9879/10000, 98.79
EPOCH: 12



atch Loss = 0.11575008183717728 batch_id = 468 accuracy = 98.83: 100%|█| 469/469 [00:11<00:00, 41.34it/s]


 Test set avg loss: 0.0459                 Accuracy: 9863/10000, 98.63
EPOCH: 13



atch Loss = 0.05142282322049141 batch_id = 468 accuracy = 98.83: 100%|█| 469/469 [00:11<00:00, 40.96it/s]


 Test set avg loss: 0.0448                 Accuracy: 9867/10000, 98.67
