
## Import libraries

In [4]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from collections import OrderedDict
%matplotlib inline
import matplotlib.pyplot as plt

from torch.optim.lr_scheduler import StepLR, ExponentialLR, OneCycleLR, LambdaLR


In [5]:
class Net(nn.Module):
    def __init__(self,norm_type,dropout_value = 0.05):
        super(Net, self).__init__()
        self.conv1 = self.conv2d(1, 8, 3,norm_type,dropout_value,2)
        self.conv2 = self.conv2d(8, 16, 3,norm_type,dropout_value,4) 
        
        #Transition Block
        self.trans1 = nn.Sequential(
            
            nn.MaxPool2d(2, 2), #  Input 24x24 output 12x12 RF : 6x6
            nn.Conv2d(in_channels=16, out_channels=8, kernel_size=(1, 1), padding=0, bias=False)  # Input 12x12 output 12x12 RF : 6x6
        )
        
        self.conv3 = self.conv2d(8, 16, 3,norm_type,dropout_value,4) 
        self.conv4 = self.conv2d(16, 24, 3,norm_type,dropout_value,4)

        self.avgpool2d = nn.AvgPool2d(kernel_size=6)

        self.conv5 = self.conv2d(24, 32, 1,norm_type,dropout_value,4) 
        self.conv6 =  self.conv2d(32, 16, 1,norm_type,dropout_value,4) 
        self.conv7 = nn.Conv2d(in_channels=16, out_channels=10, kernel_size=(1, 1), padding=0, bias=False) 

    def conv2d(self, in_channels, out_channels, kernel_size, norm_type, dropout,num_of_groups):
        if norm_type == "BN":
         conv = nn.Sequential(OrderedDict([
                ('conv2d',nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,padding=0, bias=False)),
                ('Relu',nn.ReLU()),
                ('BatchNorm',nn.BatchNorm2d(out_channels)),
                ('Dropout',nn.Dropout(dropout))
         ]))
        elif norm_type == "LN":
            conv = nn.Sequential(OrderedDict([
                ('conv2d',nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=0, bias=False)),
                ('Relu',nn.ReLU()),
                ## When number of groups is 1, its layernorm
                ('GroupNorm',nn.GroupNorm(1,out_channels)),
                ('Dropout',nn.Dropout(dropout))
            ]))
        elif norm_type == "GN":
            conv = nn.Sequential(OrderedDict([
                ('conv2d',nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=0, bias=False)),
                ('Relu',nn.ReLU()),
                ('GroupNorm',nn.GroupNorm(num_of_groups,out_channels)),
                ('Dropout',nn.Dropout(dropout))
            ]))
        else:
            conv = nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=0, bias=False),
                nn.ReLU(),
                nn.Dropout(dropout)
            )
      
        return conv

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.trans1(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.avgpool2d(x)
        x = self.conv5(x)     
        x = self.conv6(x)
        x = self.conv7(x)

        x = x.view(-1, 10)
        return F.log_softmax(x, dim=-1)

In [8]:
!pip install torchsummary
from torchsummary import summary
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(device)


cuda


## Specify Data Transformations

In [9]:
# Train Phase transformations
train_transforms = transforms.Compose([
                                       transforms.RandomRotation((-7.0, 7.0), fill=(1,)),
                                       transforms.RandomAffine(degrees=7, shear=10, translate=(0.1, 0.1), scale=(0.8, 1.2)),
                                       transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.40, hue=0.1),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.1307,), (0.3081,))
                                       ])

# Test Phase transformations
test_transforms = transforms.Compose([
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.1307,), (0.3081,))
                                       ])

## Download Dataset & Specify Transformations

In [10]:
train = datasets.MNIST('./data', train=True, download=True, transform=train_transforms)
test = datasets.MNIST('./data', train=False, download=True, transform=test_transforms)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


## Device & Dataloader Specifications

In [11]:
SEED = 1

# CUDA?
cuda = torch.cuda.is_available()
print("CUDA Available?", cuda)

# For reproducibility
torch.manual_seed(SEED)

if cuda:
    torch.cuda.manual_seed(SEED)

# dataloader arguments - something you'll fetch these from cmdprmt
dataloader_args = dict(shuffle=True, batch_size=128, num_workers=4, pin_memory=True) if cuda else dict(shuffle=True, batch_size=64)

# train dataloader
train_loader = torch.utils.data.DataLoader(train, **dataloader_args)

# test dataloader
test_loader = torch.utils.data.DataLoader(test, **dataloader_args)

CUDA Available? True


  cpuset_checked))


## Define Train & Test Functions

In [14]:
from tqdm import tqdm

train_losses = []
test_losses = []
train_acc = []
test_acc = []


## Training
def train(model, device, train_loader, optimizer, epoch, L1=L1):
  model.train()
  pbar = tqdm(train_loader)
  correct = 0
  processed = 0
  for batch_idx, (data, target) in enumerate(pbar):
    # get samples
    data, target = data.to(device), target.to(device)

    # Init
    optimizer.zero_grad()

    # Predict
    y_pred = model(data)

    # Calculate loss
    loss = F.nll_loss(y_pred, target)
    train_losses.append(loss)
    # if using L1 regularization
    l1 = 1 if L1 else 0

    if l1:
      for p in model.parameters():
        l1 += torch.norm(p)
    loss += lambda_1 * l1
    # Backpropagation
    loss.backward()
    optimizer.step()

    # Update pbar-tqdm
    
    pred = y_pred.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
    correct += pred.eq(target.view_as(pred)).sum().item()
    processed += len(data)

    pbar.set_description(desc= f'Loss={loss.item()} Batch_id={batch_idx} Accuracy={100*correct/processed:0.2f}')
    train_acc.append(100*correct/processed)

## Testing
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    
    test_acc.append(100. * correct / len(test_loader.dataset))

## Misclassified images
def wrong_predictions(test_loader,model,device):
  wrong_images=[]
  wrong_label=[]
  correct_label=[]
  with torch.no_grad():
    for data, target in test_loader:
      data, target = data.to(device), target.to(device)
      output = model(data)        
      pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability

      wrong_pred = (pred.eq(target.view_as(pred)) == False)
      wrong_images.append(data[wrong_pred])
      wrong_label.append(pred[wrong_pred])
      correct_label.append(target.view_as(pred)[wrong_pred])  
   
    return list(zip(torch.cat(wrong_images),torch.cat(wrong_label),torch.cat(correct_label)))

In [None]:
norm_options = ["BN", "GN", "LN"]

metrics = {}

for norm in norm_options:
  model = Net(norm).to(device)
  print(f"Norm Type - { norm }")

  # print summary for norm type
  summary(model, input_size=(1, 28, 28))

  train_losses = []
  test_losses = []
  train_acc = []
  test_acc = []

  lambda_1 = 0.01

  EPOCHS = 20

  model =  Net(norm).to(device)
  optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
  scheduler = StepLR(optimizer, step_size=8, gamma=0.1)

  
  for epoch in range(EPOCHS):
      print("EPOCH:", epoch)

      # use L1 if norm is BN
      L1 = 1 if norm == "BN" else 0

      train(model, device, train_loader, optimizer, epoch, L1=L1)
      scheduler.step()
      test(model, device, test_loader)

Norm Type - BN
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 8, 26, 26]              72
              ReLU-2            [-1, 8, 26, 26]               0
       BatchNorm2d-3            [-1, 8, 26, 26]              16
           Dropout-4            [-1, 8, 26, 26]               0
            Conv2d-5           [-1, 16, 24, 24]           1,152
              ReLU-6           [-1, 16, 24, 24]               0
       BatchNorm2d-7           [-1, 16, 24, 24]              32
           Dropout-8           [-1, 16, 24, 24]               0
         MaxPool2d-9           [-1, 16, 12, 12]               0
           Conv2d-10            [-1, 8, 12, 12]             128
           Conv2d-11           [-1, 16, 10, 10]           1,152
             ReLU-12           [-1, 16, 10, 10]               0
      BatchNorm2d-13           [-1, 16, 10, 10]              32
          Dropout-14    

  cpuset_checked))
Loss=0.7559993267059326 Batch_id=468 Accuracy=80.56: 100%|██████████| 469/469 [00:43<00:00, 10.83it/s]



Test set: Average loss: 0.1094, Accuracy: 9667/10000 (96.67%)

EPOCH: 1


Loss=0.6898101568222046 Batch_id=468 Accuracy=92.80: 100%|██████████| 469/469 [00:43<00:00, 10.74it/s]



Test set: Average loss: 0.0679, Accuracy: 9791/10000 (97.91%)

EPOCH: 2


Loss=0.5308824777603149 Batch_id=468 Accuracy=94.24: 100%|██████████| 469/469 [00:43<00:00, 10.79it/s]



Test set: Average loss: 0.0522, Accuracy: 9852/10000 (98.52%)

EPOCH: 3


Loss=0.5339038372039795 Batch_id=468 Accuracy=94.95: 100%|██████████| 469/469 [00:43<00:00, 10.80it/s]



Test set: Average loss: 0.0461, Accuracy: 9864/10000 (98.64%)

EPOCH: 4


Loss=0.42464759945869446 Batch_id=468 Accuracy=95.31: 100%|██████████| 469/469 [00:43<00:00, 10.81it/s]



Test set: Average loss: 0.0448, Accuracy: 9873/10000 (98.73%)

EPOCH: 5


Loss=0.43808770179748535 Batch_id=468 Accuracy=95.25: 100%|██████████| 469/469 [00:43<00:00, 10.89it/s]



Test set: Average loss: 0.0454, Accuracy: 9875/10000 (98.75%)

EPOCH: 6


Loss=0.4544174373149872 Batch_id=468 Accuracy=95.23: 100%|██████████| 469/469 [00:43<00:00, 10.88it/s]



Test set: Average loss: 0.0426, Accuracy: 9888/10000 (98.88%)

EPOCH: 7


Loss=0.5506535768508911 Batch_id=468 Accuracy=95.42: 100%|██████████| 469/469 [00:43<00:00, 10.84it/s]



Test set: Average loss: 0.0629, Accuracy: 9822/10000 (98.22%)

EPOCH: 8


Loss=0.38480666279792786 Batch_id=468 Accuracy=96.65: 100%|██████████| 469/469 [00:43<00:00, 10.87it/s]



Test set: Average loss: 0.0292, Accuracy: 9926/10000 (99.26%)

EPOCH: 9


Loss=0.45790666341781616 Batch_id=468 Accuracy=97.03: 100%|██████████| 469/469 [00:43<00:00, 10.76it/s]



Test set: Average loss: 0.0282, Accuracy: 9920/10000 (99.20%)

EPOCH: 10


Loss=0.318374902009964 Batch_id=468 Accuracy=97.17: 100%|██████████| 469/469 [00:43<00:00, 10.81it/s]



Test set: Average loss: 0.0253, Accuracy: 9928/10000 (99.28%)

EPOCH: 11


Loss=0.3942092955112457 Batch_id=468 Accuracy=97.28: 100%|██████████| 469/469 [00:42<00:00, 10.92it/s]



Test set: Average loss: 0.0256, Accuracy: 9925/10000 (99.25%)

EPOCH: 12


Loss=0.4375170171260834 Batch_id=468 Accuracy=97.13: 100%|██████████| 469/469 [00:42<00:00, 10.92it/s]



Test set: Average loss: 0.0254, Accuracy: 9935/10000 (99.35%)

EPOCH: 13


Loss=0.30982303619384766 Batch_id=468 Accuracy=97.24: 100%|██████████| 469/469 [00:42<00:00, 10.95it/s]



Test set: Average loss: 0.0267, Accuracy: 9930/10000 (99.30%)

EPOCH: 14


Loss=0.3725537359714508 Batch_id=468 Accuracy=97.34: 100%|██████████| 469/469 [00:43<00:00, 10.87it/s]



Test set: Average loss: 0.0258, Accuracy: 9925/10000 (99.25%)

EPOCH: 15


Loss=0.33952513337135315 Batch_id=468 Accuracy=97.31: 100%|██████████| 469/469 [00:43<00:00, 10.91it/s]



Test set: Average loss: 0.0249, Accuracy: 9937/10000 (99.37%)

EPOCH: 16


Loss=0.3486785292625427 Batch_id=468 Accuracy=97.40: 100%|██████████| 469/469 [00:43<00:00, 10.90it/s]



Test set: Average loss: 0.0233, Accuracy: 9941/10000 (99.41%)

EPOCH: 17


Loss=0.285395085811615 Batch_id=468 Accuracy=97.53: 100%|██████████| 469/469 [00:42<00:00, 10.91it/s]



Test set: Average loss: 0.0227, Accuracy: 9943/10000 (99.43%)

EPOCH: 18


Loss=0.3175352215766907 Batch_id=468 Accuracy=97.53: 100%|██████████| 469/469 [00:43<00:00, 10.85it/s]



Test set: Average loss: 0.0236, Accuracy: 9937/10000 (99.37%)

EPOCH: 19


Loss=0.3938078284263611 Batch_id=468 Accuracy=97.50: 100%|██████████| 469/469 [00:42<00:00, 10.91it/s]



Test set: Average loss: 0.0246, Accuracy: 9934/10000 (99.34%)

Norm Type - GN
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 8, 26, 26]              72
              ReLU-2            [-1, 8, 26, 26]               0
         GroupNorm-3            [-1, 8, 26, 26]              16
           Dropout-4            [-1, 8, 26, 26]               0
            Conv2d-5           [-1, 16, 24, 24]           1,152
              ReLU-6           [-1, 16, 24, 24]               0
         GroupNorm-7           [-1, 16, 24, 24]              32
           Dropout-8           [-1, 16, 24, 24]               0
         MaxPool2d-9           [-1, 16, 12, 12]               0
           Conv2d-10            [-1, 8, 12, 12]             128
           Conv2d-11           [-1, 16, 10, 10]           1,152
             ReLU-12           [-1, 16, 10, 10]               0
        GroupNorm-13    

Loss=0.5256937146186829 Batch_id=468 Accuracy=58.31: 100%|██████████| 469/469 [00:41<00:00, 11.33it/s]



Test set: Average loss: 0.3428, Accuracy: 9083/10000 (90.83%)

EPOCH: 1


Loss=0.34741151332855225 Batch_id=468 Accuracy=87.36: 100%|██████████| 469/469 [00:41<00:00, 11.30it/s]



Test set: Average loss: 0.1191, Accuracy: 9713/10000 (97.13%)

EPOCH: 2


Loss=0.07528015226125717 Batch_id=468 Accuracy=90.82: 100%|██████████| 469/469 [00:41<00:00, 11.38it/s]



Test set: Average loss: 0.0906, Accuracy: 9768/10000 (97.68%)

EPOCH: 3


Loss=0.1558024138212204 Batch_id=468 Accuracy=93.42: 100%|██████████| 469/469 [00:41<00:00, 11.37it/s]



Test set: Average loss: 0.0701, Accuracy: 9810/10000 (98.10%)

EPOCH: 4


Loss=0.2572553753852844 Batch_id=468 Accuracy=94.49: 100%|██████████| 469/469 [00:41<00:00, 11.39it/s]



Test set: Average loss: 0.0715, Accuracy: 9803/10000 (98.03%)

EPOCH: 5


Loss=0.13928858935832977 Batch_id=468 Accuracy=94.66: 100%|██████████| 469/469 [00:41<00:00, 11.34it/s]



Test set: Average loss: 0.0505, Accuracy: 9869/10000 (98.69%)

EPOCH: 6


Loss=0.13077740371227264 Batch_id=468 Accuracy=95.54: 100%|██████████| 469/469 [00:41<00:00, 11.34it/s]



Test set: Average loss: 0.0466, Accuracy: 9880/10000 (98.80%)

EPOCH: 7


Loss=0.0770954042673111 Batch_id=468 Accuracy=94.68: 100%|██████████| 469/469 [00:41<00:00, 11.41it/s]



Test set: Average loss: 0.0457, Accuracy: 9884/10000 (98.84%)

EPOCH: 8


Loss=0.21264636516571045 Batch_id=468 Accuracy=96.64: 100%|██████████| 469/469 [00:41<00:00, 11.39it/s]



Test set: Average loss: 0.0403, Accuracy: 9887/10000 (98.87%)

EPOCH: 9


Loss=0.15296092629432678 Batch_id=468 Accuracy=96.67: 100%|██████████| 469/469 [00:41<00:00, 11.37it/s]



Test set: Average loss: 0.0392, Accuracy: 9890/10000 (98.90%)

EPOCH: 10


Loss=0.07085851579904556 Batch_id=468 Accuracy=96.80: 100%|██████████| 469/469 [00:40<00:00, 11.52it/s]



Test set: Average loss: 0.0395, Accuracy: 9890/10000 (98.90%)

EPOCH: 11


Loss=0.16519315540790558 Batch_id=468 Accuracy=96.93: 100%|██████████| 469/469 [00:40<00:00, 11.50it/s]



Test set: Average loss: 0.0359, Accuracy: 9895/10000 (98.95%)

EPOCH: 12


Loss=0.11059419065713882 Batch_id=468 Accuracy=96.97: 100%|██████████| 469/469 [00:41<00:00, 11.41it/s]



Test set: Average loss: 0.0358, Accuracy: 9898/10000 (98.98%)

EPOCH: 13


Loss=0.04329981282353401 Batch_id=468 Accuracy=97.05: 100%|██████████| 469/469 [00:41<00:00, 11.42it/s]



Test set: Average loss: 0.0359, Accuracy: 9898/10000 (98.98%)

EPOCH: 14


Loss=0.027033500373363495 Batch_id=468 Accuracy=97.07: 100%|██████████| 469/469 [00:41<00:00, 11.41it/s]



Test set: Average loss: 0.0383, Accuracy: 9889/10000 (98.89%)

EPOCH: 15


Loss=0.08284071087837219 Batch_id=468 Accuracy=97.06: 100%|██████████| 469/469 [00:41<00:00, 11.33it/s]



Test set: Average loss: 0.0357, Accuracy: 9899/10000 (98.99%)

EPOCH: 16


Loss=0.15429849922657013 Batch_id=468 Accuracy=97.14: 100%|██████████| 469/469 [00:41<00:00, 11.39it/s]



Test set: Average loss: 0.0346, Accuracy: 9900/10000 (99.00%)

EPOCH: 17


Loss=0.06691044569015503 Batch_id=468 Accuracy=97.02: 100%|██████████| 469/469 [00:40<00:00, 11.49it/s]



Test set: Average loss: 0.0349, Accuracy: 9894/10000 (98.94%)

EPOCH: 18


Loss=0.08110811561346054 Batch_id=468 Accuracy=97.19: 100%|██████████| 469/469 [00:41<00:00, 11.42it/s]



Test set: Average loss: 0.0351, Accuracy: 9895/10000 (98.95%)

EPOCH: 19


Loss=0.015284097753465176 Batch_id=468 Accuracy=97.27: 100%|██████████| 469/469 [00:41<00:00, 11.38it/s]



Test set: Average loss: 0.0343, Accuracy: 9900/10000 (99.00%)

Norm Type - LN
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 8, 26, 26]              72
              ReLU-2            [-1, 8, 26, 26]               0
         GroupNorm-3            [-1, 8, 26, 26]              16
           Dropout-4            [-1, 8, 26, 26]               0
            Conv2d-5           [-1, 16, 24, 24]           1,152
              ReLU-6           [-1, 16, 24, 24]               0
         GroupNorm-7           [-1, 16, 24, 24]              32
           Dropout-8           [-1, 16, 24, 24]               0
         MaxPool2d-9           [-1, 16, 12, 12]               0
           Conv2d-10            [-1, 8, 12, 12]             128
           Conv2d-11           [-1, 16, 10, 10]           1,152
             ReLU-12           [-1, 16, 10, 10]               0
        GroupNorm-13    

Loss=0.4073697626590729 Batch_id=468 Accuracy=75.64: 100%|██████████| 469/469 [00:41<00:00, 11.36it/s]



Test set: Average loss: 0.0988, Accuracy: 9754/10000 (97.54%)

EPOCH: 1


Loss=0.2997186481952667 Batch_id=468 Accuracy=92.88: 100%|██████████| 469/469 [00:40<00:00, 11.48it/s]



Test set: Average loss: 0.0693, Accuracy: 9803/10000 (98.03%)

EPOCH: 2


Loss=0.09108268469572067 Batch_id=468 Accuracy=94.52: 100%|██████████| 469/469 [00:40<00:00, 11.45it/s]



Test set: Average loss: 0.0527, Accuracy: 9857/10000 (98.57%)

EPOCH: 3


Loss=0.19153840839862823 Batch_id=468 Accuracy=95.36: 100%|██████████| 469/469 [00:41<00:00, 11.43it/s]



Test set: Average loss: 0.0400, Accuracy: 9882/10000 (98.82%)

EPOCH: 4


Loss=0.027813786640763283 Batch_id=468 Accuracy=95.83: 100%|██████████| 469/469 [00:40<00:00, 11.55it/s]



Test set: Average loss: 0.0427, Accuracy: 9859/10000 (98.59%)

EPOCH: 5


Loss=0.13210612535476685 Batch_id=468 Accuracy=96.19: 100%|██████████| 469/469 [00:40<00:00, 11.46it/s]



Test set: Average loss: 0.0438, Accuracy: 9874/10000 (98.74%)

EPOCH: 6


Loss=0.18267010152339935 Batch_id=468 Accuracy=96.36: 100%|██████████| 469/469 [00:41<00:00, 11.41it/s]



Test set: Average loss: 0.0375, Accuracy: 9892/10000 (98.92%)

EPOCH: 7


Loss=0.06655258685350418 Batch_id=468 Accuracy=96.54: 100%|██████████| 469/469 [00:40<00:00, 11.46it/s]



Test set: Average loss: 0.0336, Accuracy: 9907/10000 (99.07%)

EPOCH: 8


Loss=0.09012022614479065 Batch_id=468 Accuracy=97.27: 100%|██████████| 469/469 [00:41<00:00, 11.39it/s]



Test set: Average loss: 0.0285, Accuracy: 9924/10000 (99.24%)

EPOCH: 9


Loss=0.05565478280186653 Batch_id=468 Accuracy=97.39: 100%|██████████| 469/469 [00:41<00:00, 11.23it/s]



Test set: Average loss: 0.0279, Accuracy: 9924/10000 (99.24%)

EPOCH: 10


Loss=0.05331813916563988 Batch_id=468 Accuracy=97.30: 100%|██████████| 469/469 [00:41<00:00, 11.31it/s]



Test set: Average loss: 0.0277, Accuracy: 9922/10000 (99.22%)

EPOCH: 11


Loss=0.10458730906248093 Batch_id=468 Accuracy=97.43: 100%|██████████| 469/469 [00:41<00:00, 11.32it/s]



Test set: Average loss: 0.0267, Accuracy: 9931/10000 (99.31%)

EPOCH: 12


Loss=0.0824398621916771 Batch_id=468 Accuracy=97.40: 100%|██████████| 469/469 [00:41<00:00, 11.35it/s]



Test set: Average loss: 0.0274, Accuracy: 9927/10000 (99.27%)

EPOCH: 13


Loss=0.09071478992700577 Batch_id=468 Accuracy=97.48: 100%|██████████| 469/469 [00:41<00:00, 11.37it/s]



Test set: Average loss: 0.0265, Accuracy: 9929/10000 (99.29%)

EPOCH: 14


Loss=0.14151692390441895 Batch_id=468 Accuracy=97.45: 100%|██████████| 469/469 [00:41<00:00, 11.37it/s]



Test set: Average loss: 0.0266, Accuracy: 9932/10000 (99.32%)

EPOCH: 15


Loss=0.05423042178153992 Batch_id=468 Accuracy=97.65: 100%|██████████| 469/469 [00:41<00:00, 11.36it/s]



Test set: Average loss: 0.0260, Accuracy: 9930/10000 (99.30%)

EPOCH: 16


Loss=0.07406426221132278 Batch_id=468 Accuracy=97.62: 100%|██████████| 469/469 [00:41<00:00, 11.32it/s]



Test set: Average loss: 0.0256, Accuracy: 9930/10000 (99.30%)

EPOCH: 17


Loss=0.11070188879966736 Batch_id=399 Accuracy=97.55:  85%|████████▌ | 400/469 [00:35<00:05, 13.25it/s]

## Accuracy and loss plots

In [None]:
import numpy as np
import matplotlib.pyplot as plt

plt.rcParams['figure.figsize'] = [15, 6]

for norm in norm_options:
  train_losses = metrics[norm][0]
  test_losses = metrics[norm][1]
  train_acc = metrics[norm][2]
  test_acc = metrics[norm][3]

  plt.subplot(2,2, 1)

  plt.plot(test_acc)
  plt.legend(norm_options)
  plt.title("Test Accuracy")

  plt.xlabel('Epoch')
  plt.ylabel('Accuracy')

  plt.subplot(1,2,2)
  plt.plot(test_losses)
  plt.legend(norm_options)
  plt.title('Test Loss')
  plt.xlabel('Epoch')
  plt.ylabel('Loss')

## Misclassified Images

In [None]:
classes = ('0','1','2','3','4','5','6','7','8','9')

for norm in norm_options:

  print('-'*15,end=' ')
  print(norm,end=' ')
  print('-'*15,end = ' ')

  model = Net(norm).to(device)

  model.eval()

  figure = plt.figure(figsize=(20, 20))
  num_of_images = 10
  index = 1

  misclass_img_list = []
  untrans_img=[]

  with torch.no_grad():

      for data, target in test_loader:
          data, target = data.to(
              device), target.to(device)
          output = model(data)
          pred = output.argmax(dim=1, keepdim=True)
          act = target.view_as(pred)
          # since most of the bool vec is true (good problem to have) and switch (flip) the true to false and vice versa
          bool_vec = ~pred.eq(act)

          # now extract the index number from the tensor which has 'true'
          idx = list(
              np.where(bool_vec.cpu().numpy())[0])

          if idx:  # if not a blank list
              idx_list = idx
              # print(data[idx_list[0]].shape)
              if index < num_of_images+1:
                  plt.subplot(5, 2, index)
                  plt.axis('off')
                  titl = 'act/pred : ' + \
                      str(classes[target[idx[0]].cpu().item(
                      )]) + '/' + str(classes[pred[idx[0]].cpu().item()])
                  # prints the 1st index of each batch.
              
                  img = data[idx[0]].cpu()
                  untrans_img.append(img.squeeze_(0))
                  image = plt.imshow(img)
                  misclass_img_list.append(image)
                                    


                  plt.title(titl)
                  index += 1