# Import

In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split

import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
from typing import Tuple, NoReturn

# Dataloader

In [2]:
#Transforms from https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

# Convert PIL images to tensors, then normalise every channel with
# mean 0.5 / std 0.5 (the values used in the tutorial linked above).
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Seed for the train/validation split so re-running the notebook yields the
# same partition (and therefore comparable validation numbers).
SPLIT_SEED = 42

full_train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)

# 90/10 train/validation split. The validation size is derived from the
# training size so both always add up to the dataset length, whatever it is
# (two independent round() calls are not guaranteed to).
num_train = round(0.9 * len(full_train_dataset))
num_val = len(full_train_dataset) - num_train
train_dataset, validation_dataset = random_split(
    full_train_dataset,
    [num_train, num_val],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
BATCH_SIZE = 128
SHUFFLE = True      # applies to the training loader only
NUM_WORKERS = 4     # worker processes per DataLoader

# Only the training data is shuffled. Evaluation loaders keep a fixed order so
# that batch composition (including the short final batch) and therefore the
# batch-averaged loss are identical from one evaluation to the next.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=SHUFFLE)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False)
val_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False)


## Visualizing Data

# Model

In [4]:
class SCNN(nn.Module):
    """Small convolutional classifier mapping (N, 3, 32, 32) images to 10 logits.

    Layer shapes (N = batch size):
        conv 5x5, padding 2 : (N, 3, 32, 32)   -> (N, 32, 32, 32)
        conv 5x5            : (N, 32, 32, 32)  -> (N, 64, 28, 28)
        max-pool 2x2        : (N, 64, 28, 28)  -> (N, 64, 14, 14)
        conv 3x3            : (N, 64, 14, 14)  -> (N, 128, 12, 12)
        max-pool 2x2        : (N, 128, 12, 12) -> (N, 128, 6, 6)
        flatten + linear    : 4608 -> 1024 -> 10

    Args:
        activation: Non-linearity used after every conv layer and after the
            hidden linear layer: 'relu', 'lrelu' (LeakyReLU) or 'elu'.
        batch_norm: 'bnorm' inserts BatchNorm2d and 'gnorm' inserts
            GroupNorm (8 groups) at the end of each conv block. Any other
            value (e.g. 'none') builds the network without normalisation.
        dropout: If True, a Dropout(p=0.5) layer is placed between the hidden
            linear layer's activation and the output layer.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """

    # Supported activation names -> module constructors.
    _ACTIVATIONS = {
        'relu': nn.ReLU,
        'lrelu': nn.LeakyReLU,
        'elu': nn.ELU,
    }

    def __init__(self, activation: str, batch_norm: str, dropout: bool) -> None:
        super().__init__()

        # Fail at construction time: nn.Sequential accepts None entries and
        # would otherwise only blow up on the first forward pass.
        if activation not in self._ACTIVATIONS:
            raise ValueError(
                f"Unknown activation {activation!r}; "
                f"expected one of {sorted(self._ACTIVATIONS)}")
        # The activations are stateless, so one instance is reused everywhere.
        activation_layer = self._ACTIVATIONS[activation]()

        # Conv block 1: (Nx3x32x32) -> (Nx32x32x32)
        conv1 = nn.Conv2d(in_channels=3, out_channels=32,
                          kernel_size=5, padding=2)

        # Conv block 2: (Nx32x32x32) -> (Nx64x28x28) -> (Nx64x14x14)
        conv2 = nn.Conv2d(in_channels=32, out_channels=64,
                          kernel_size=5, stride=1)
        pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Conv block 3: (Nx64x14x14) -> (Nx128x12x12) -> (Nx128x6x6)
        conv3 = nn.Conv2d(in_channels=64, out_channels=128,
                          kernel_size=3, padding=0, stride=1)
        pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Classifier head on the flattened feature map.
        in_dim = 128*6*6
        fc1 = nn.Linear(in_dim, 1024)
        fc2 = nn.Linear(1024, 10)

        # One normalisation layer per conv block, or none at all.
        block_channels = (32, 64, 128)
        if batch_norm == 'bnorm':
            norms = [nn.BatchNorm2d(c) for c in block_channels]
        elif batch_norm == 'gnorm':
            norms = [nn.GroupNorm(8, c) for c in block_channels]
        else:
            norms = [None, None, None]

        # Assemble a single layer list for every configuration. The relative
        # order of layers is the same as in the individual variants, so the
        # Sequential indices (and state_dict keys) are unchanged whenever
        # dropout is disabled.
        blocks = (
            ([conv1, activation_layer], norms[0]),
            ([conv2, activation_layer, pool2], norms[1]),
            ([conv3, activation_layer, pool3], norms[2]),
        )
        layers = []
        for block_layers, norm in blocks:
            layers.extend(block_layers)
            if norm is not None:
                layers.append(norm)

        layers.extend([nn.Flatten(start_dim=1), fc1, activation_layer])
        if dropout:
            layers.append(nn.Dropout(p=0.5))
        layers.append(fc2)

        self.conv_net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the (N, 10) class logits for a batch ``x`` of shape (N, 3, 32, 32)."""
        return self.conv_net(x)

def count_model_params(model):
    """ Return the total number of trainable scalar weights held by a nn.Module """
    trainable = 0
    for weight in model.parameters():
        # Frozen tensors (requires_grad=False) do not count as learnable.
        if weight.requires_grad:
            trainable += weight.numel()
    return trainable

In [5]:
# Baseline configuration: ReLU activations, no normalisation layers, no dropout.
cnn = SCNN(activation='relu',batch_norm='none',dropout=False)
# Number of parameters with requires_grad=True.
params = count_model_params(cnn)
# Print the layer-by-layer structure followed by the parameter count.
print(cnn)
print(f"Model has {params} learnable parameters")


SCNN(
  (conv_net): Sequential(
    (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
    (6): ReLU()
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Flatten(start_dim=1, end_dim=-1)
    (9): Linear(in_features=4608, out_features=1024, bias=True)
    (10): ReLU()
    (11): Linear(in_features=1024, out_features=10, bias=True)
  )
)
Model has 4857418 learnable parameters


# Training

## Todo

-> Visualization of images and plots for eval and train loss <br />
-> Early stopping criterion <br />
-> Accuracy above 85% by modifying parameters <br />
-> Visualizing activations of different layers <br />

## Parameters for training

In [6]:
# Learning rate handed to the optimizer.
LR = 3e-4
# Number of epochs the training loop runs for.
EPOCHS = 50
# The validation set is evaluated on epochs where epoch % EVAL_FREQ == 0.
EVAL_FREQ = 10
# A checkpoint is written on epochs where epoch % SAVE_FREQ == 0.
SAVE_FREQ = 10

In [7]:
# Per-epoch training history; the training loop appends one entry to every
# list per epoch, and the whole dict is stored inside each checkpoint.
stats = {
    "epoch": [],
    "train_loss": [],
    "valid_loss": [],
    "accuracy": []
}
# NOTE(review): presumably the epoch to resume from after load_model(); the
# training loop in this notebook iterates range(EPOCHS) and does not read it.
init_epoch = 0

In [8]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model's parameters and buffers to the selected device.
cnn = cnn.to(device)

In [9]:
# Cross-entropy over the raw logits returned by the model.
criterion = nn.CrossEntropyLoss().to(device)
# Adam over all model parameters with the learning rate LR.
optimizer = torch.optim.Adam(params=cnn.parameters(), lr=LR)

In [10]:
@torch.no_grad()
def eval_model(model, loader=None):
    """ Evaluate a classifier on a data loader without tracking gradients.

    The model is switched to evaluation mode for the duration of the call so
    that layers such as dropout and batch normalisation use their inference
    behaviour, and its previous training/eval mode is restored afterwards.

    Args:
        model: nn.Module returning class logits of shape (batch, classes).
        loader: Iterable of (images, labels) batches. Defaults to the
            notebook-level ``val_loader``.

    Returns:
        (accuracy, loss): accuracy in percent over all samples, and the mean
        of the per-batch losses computed with the notebook-level ``criterion``.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    if loader is None:
        loader = val_loader

    correct = 0
    total = 0
    loss_list = []

    was_training = model.training
    model.eval()
    try:
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass only to get logits/output
            outputs = model(images)
            loss_list.append(criterion(outputs, labels).item())

            # Predicted class = index of the largest logit
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == labels).sum().item()
            total += len(labels)
    finally:
        # Hand the model back in the mode the caller had it in, so a
        # surrounding training loop keeps training in train mode.
        model.train(was_training)

    # Total correct predictions and loss
    accuracy = correct / total * 100
    # Plain Python float keeps the history trivially serialisable.
    loss = float(np.mean(loss_list))
    return accuracy, loss


def save_model(model, optimizer, epoch, stats, name):
    """ Save a training checkpoint to models/{name}_checkpoint_epoch_{epoch}.pth

    The checkpoint is a dict with the keys 'epoch', 'model_state_dict',
    'optimizer_state_dict' and 'stats'. The ``models`` directory is created
    relative to the current working directory if it does not exist yet.

    Args:
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        epoch: Epoch number; stored and used in the file name.
        stats: Training history object stored as-is.
        name: Prefix of the checkpoint file name.
    """

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs("models", exist_ok=True)
    savepath = f"models/{name}_checkpoint_epoch_{epoch}.pth"

    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'stats': stats
    }, savepath)
    return


def load_model(model, optimizer, savepath, map_location=None):
    """ Restore model and optimizer state from a checkpoint file.

    Args:
        model: Module that receives the stored 'model_state_dict'.
        optimizer: Optimizer that receives the stored 'optimizer_state_dict'.
        savepath: Path of a checkpoint containing the keys 'epoch',
            'model_state_dict', 'optimizer_state_dict' and 'stats'.
        map_location: Forwarded to ``torch.load``. When None, the device of
            the model's first parameter is used, so a checkpoint written on a
            GPU can still be loaded on a CPU-only machine.

    Returns:
        (model, optimizer, epoch, stats): the same model and optimizer objects
        with their state loaded, plus the stored epoch and stats.
    """

    if map_location is None:
        first_param = next(model.parameters(), None)
        if first_param is not None:
            map_location = first_param.device

    # NOTE(review): torch.load relies on pickle; only load checkpoints that
    # come from a trusted source.
    checkpoint = torch.load(savepath, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint["epoch"]
    stats = checkpoint["stats"]

    return model, optimizer, epoch, stats


In [11]:
# Mean training loss of every epoch, in order.
loss_hist = []

for epoch in range(EPOCHS):
    # Make sure layers such as dropout / batch norm are in training mode.
    cnn.train()

    # Losses of all batches in this epoch. This must be created once per
    # epoch: resetting it per batch would make the "epoch loss" below equal
    # to the loss of the final batch only.
    loss_list = []

    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
    for i,(images,labels) in progress_bar:

        x = images.to(device)
        y_train = labels.to(device)

        # Clear gradients w.r.t. parameters
        optimizer.zero_grad()

        # Forward pass to get output/logits
        outputs = cnn(x)

        # Calculate Loss: softmax --> cross entropy loss
        loss = criterion(outputs, y_train)

        # Getting gradients w.r.t. parameters
        loss.backward()
        loss_list.append(loss.item())

        # Updating parameters
        optimizer.step()
        progress_bar.set_description(f"Epoch {epoch+1} Iter {i+1}: loss {loss.item():.5f}. ")

    # Plain float so the history stored in checkpoints holds no numpy scalars.
    loss_hist.append(float(np.mean(loss_list)))
    stats["epoch"].append(epoch)
    stats["train_loss"].append(loss_hist[-1])


    if epoch % EVAL_FREQ == 0:
        accuracy, valid_loss = eval_model(cnn)
        print(f"Accuracy at epoch {epoch}: {round(accuracy, 2)}%")
    else:
        # -1 marks epochs on which no validation was run.
        accuracy, valid_loss = -1, -1

    stats["accuracy"].append(accuracy)
    stats["valid_loss"].append(valid_loss)

    if epoch % SAVE_FREQ == 0:
        save_model(model=cnn, optimizer=optimizer, epoch=epoch, stats=stats,name="simple_cnn")


Epoch 1 Iter 352: loss 1.03062. : 100%|██████████| 352/352 [00:04<00:00, 79.68it/s]


Accuracy at epoch 0: 52.22%


Epoch 2 Iter 352: loss 1.01289. : 100%|██████████| 352/352 [00:04<00:00, 87.24it/s]
Epoch 3 Iter 352: loss 0.87133. : 100%|██████████| 352/352 [00:04<00:00, 87.29it/s]
Epoch 4 Iter 352: loss 0.71679. : 100%|██████████| 352/352 [00:04<00:00, 86.54it/s]
Epoch 5 Iter 352: loss 0.66915. : 100%|██████████| 352/352 [00:04<00:00, 87.10it/s]
Epoch 6 Iter 352: loss 0.59965. : 100%|██████████| 352/352 [00:03<00:00, 91.67it/s]
Epoch 7 Iter 352: loss 0.35192. : 100%|██████████| 352/352 [00:03<00:00, 93.21it/s]
Epoch 8 Iter 352: loss 0.39618. : 100%|██████████| 352/352 [00:03<00:00, 92.64it/s]
Epoch 9 Iter 352: loss 0.24611. : 100%|██████████| 352/352 [00:03<00:00, 93.11it/s]
Epoch 10 Iter 352: loss 0.19168. : 100%|██████████| 352/352 [00:03<00:00, 89.91it/s]
Epoch 11 Iter 352: loss 0.08061. : 100%|██████████| 352/352 [00:03<00:00, 90.89it/s]


Accuracy at epoch 10: 72.96%


Epoch 12 Iter 352: loss 0.04582. : 100%|██████████| 352/352 [00:04<00:00, 83.94it/s]
Epoch 13 Iter 352: loss 0.06938. : 100%|██████████| 352/352 [00:04<00:00, 83.11it/s]
Epoch 14 Iter 352: loss 0.04552. : 100%|██████████| 352/352 [00:04<00:00, 85.54it/s]
Epoch 15 Iter 352: loss 0.06960. : 100%|██████████| 352/352 [00:04<00:00, 85.81it/s]
Epoch 16 Iter 352: loss 0.04192. : 100%|██████████| 352/352 [00:04<00:00, 86.42it/s]
Epoch 17 Iter 352: loss 0.02081. : 100%|██████████| 352/352 [00:04<00:00, 84.16it/s]
Epoch 18 Iter 352: loss 0.09414. : 100%|██████████| 352/352 [00:04<00:00, 85.47it/s]
Epoch 19 Iter 352: loss 0.02152. : 100%|██████████| 352/352 [00:04<00:00, 86.74it/s]
Epoch 20 Iter 352: loss 0.01540. : 100%|██████████| 352/352 [00:04<00:00, 86.61it/s]
Epoch 21 Iter 352: loss 0.04900. : 100%|██████████| 352/352 [00:04<00:00, 85.82it/s]


Accuracy at epoch 20: 72.3%


Epoch 22 Iter 352: loss 0.00115. : 100%|██████████| 352/352 [00:04<00:00, 85.85it/s]
Epoch 23 Iter 352: loss 0.04362. : 100%|██████████| 352/352 [00:04<00:00, 85.33it/s]
Epoch 24 Iter 352: loss 0.01702. : 100%|██████████| 352/352 [00:04<00:00, 84.91it/s]
Epoch 25 Iter 352: loss 0.01348. : 100%|██████████| 352/352 [00:04<00:00, 84.49it/s]
Epoch 26 Iter 352: loss 0.05426. : 100%|██████████| 352/352 [00:04<00:00, 84.55it/s]
Epoch 27 Iter 352: loss 0.04327. : 100%|██████████| 352/352 [00:04<00:00, 85.41it/s]
Epoch 28 Iter 352: loss 0.01068. : 100%|██████████| 352/352 [00:04<00:00, 84.48it/s]
Epoch 29 Iter 352: loss 0.05826. : 100%|██████████| 352/352 [00:04<00:00, 84.68it/s]
Epoch 30 Iter 352: loss 0.00570. : 100%|██████████| 352/352 [00:04<00:00, 84.96it/s]
