# 0. Imports

In [1]:
# imports

import torch
from torch import nn
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch.utils.data import random_split, DataLoader
from torchsummary import summary
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from tqdm.auto import tqdm


In [2]:
# for reproducibility
torch.manual_seed(42)

<torch._C.Generator at 0x2002fb64390>

In [3]:
# setup device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device)
device

device(type='cuda')

# 1. Dataset and Dataloader

In [12]:
# get cifar10
dataset = CIFAR10(root='data', transform=transforms.ToTensor(), download=True)
len(dataset)

TypeError: CIFAR10.__init__() got an unexpected keyword argument 'shuffle'

In [5]:
# generate the train val split

# set the params
test_ratio = 0.10
batch_size = 1024

# find the sizes
val_size = int(test_ratio * len(dataset))
train_size = len(dataset) - val_size

# split the dataset
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# create dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)

# 2. Create a CNN

In [6]:
class CifarClassifier(nn.Module):
    
    def __init__(self, activation: str):
        super().__init__()
        assert activation in ["relu", "sigmoid", "tanh"] , "select activation from relu, sigmoid, tanh"
        
        self.activation_layer = None
        if activation == "relu":
            self.activation_layer = nn.ReLU()
        elif activation == "sigmoid":
            self.activation_layer = nn.Sigmoid()
        elif activation == "tanh":
            self.activation_layer = nn.Tanh()
            
        self.model = nn.Sequential(
            
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding="same"),
            self.activation_layer, 
            # nn.BatchNorm2d(32),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding="same"),
            self.activation_layer, 
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2),
            nn.Dropout(0.24),
            
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding="same"),
            self.activation_layer, 
            # nn.BatchNorm2d(64),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding="same"),
            self.activation_layer, 
#             nn.BatchNorm2d(64),
            nn.MaxPool2d(2),
            nn.Dropout(0.24),
            
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding="same"),
            self.activation_layer, 
            # nn.BatchNorm2d(128),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding="same"),
            self.activation_layer, 
#             nn.BatchNorm2d(128),
            nn.MaxPool2d(2),
            nn.Dropout(0.24),
            
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding="same"),
            self.activation_layer, 
            # nn.BatchNorm2d(128),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding="same"),
            self.activation_layer, 
#             nn.BatchNorm2d(256),
            nn.MaxPool2d(2),
            nn.Dropout(0.24),
            
            nn.Flatten(),
            nn.Linear(in_features=4096, out_features=256),
            self.activation_layer,
            nn.Linear(in_features=256, out_features=10)
        )
        
    def forward(self, x):
        return self.model(x)

In [7]:
model = CifarClassifier("relu").to(device)
summary(model, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           1,792
              ReLU-2           [-1, 64, 32, 32]               0
              ReLU-3           [-1, 64, 32, 32]               0
            Conv2d-4           [-1, 64, 32, 32]          36,928
              ReLU-5           [-1, 64, 32, 32]               0
              ReLU-6           [-1, 64, 32, 32]               0
         MaxPool2d-7           [-1, 64, 16, 16]               0
           Dropout-8           [-1, 64, 16, 16]               0
            Conv2d-9          [-1, 128, 16, 16]          73,856
             ReLU-10          [-1, 128, 16, 16]               0
             ReLU-11          [-1, 128, 16, 16]               0
           Conv2d-12          [-1, 128, 16, 16]         147,584
             ReLU-13          [-1, 128, 16, 16]               0
             ReLU-14          [-1, 128,

# 3. Training Loop

In [8]:
# create loss_fn
loss_fn = CrossEntropyLoss()

In [9]:
# create optimizer

lr = 3.2 * (10**-4)
optimizer = Adam(model.parameters(), lr=lr)

In [10]:
# run the training loop
from utils import train_step

epochs = 64

for epoch in tqdm(range(epochs)):
        tres = train_step(model, train_dataloader, loss_fn, optimizer, device)
        print(f"epoch: {epoch}")
        print(f"avg_batch_loss: {tres['avg_batch_loss']}")
        print(f"time: {tres['time']}")   
        print("")

  0%|          | 0/64 [00:00<?, ?it/s]

epoch: 0
avg_batch_loss: 2.132908582687378
time: 27758.018732070923

epoch: 1
avg_batch_loss: 1.8292808532714844
time: 27241.322994232178

epoch: 2
avg_batch_loss: 1.6173372268676758
time: 27199.805736541748

epoch: 3
avg_batch_loss: 1.5109660625457764
time: 27285.918474197388

epoch: 4
avg_batch_loss: 1.4247936010360718
time: 27415.27223587036

epoch: 5
avg_batch_loss: 1.3605000972747803
time: 27372.832775115967

epoch: 6
avg_batch_loss: 1.3015085458755493
time: 27382.23886489868

epoch: 7
avg_batch_loss: 1.2535165548324585
time: 27410.678148269653

epoch: 8
avg_batch_loss: 1.205361008644104
time: 27383.683919906616

epoch: 9
avg_batch_loss: 1.1534439325332642
time: 27382.514238357544

epoch: 10
avg_batch_loss: 1.1074875593185425
time: 27406.677961349487

epoch: 11
avg_batch_loss: 1.0709277391433716
time: 27415.08412361145

epoch: 12
avg_batch_loss: 1.0279995203018188
time: 27326.87520980835

epoch: 13
avg_batch_loss: 0.9931640625
time: 27350.841999053955

epoch: 14
avg_batch_loss: 0.

In [11]:
# validation
from utils import valid_step

vres = valid_step(model, val_dataloader, device)
print(f"accuracy: {vres['accuracy']}")
print(f"confusion_matrix: \n{vres['confusion_matrix']}")

accuracy: 0.7286
confusion_matrix: 
[[0.79508197 0.01434426 0.02254098 0.01639344 0.01639344 0.00409836
  0.00409836 0.02254098 0.06967213 0.03483607]
 [0.01953125 0.89257812 0.00195312 0.00195312 0.00195312 0.00585938
  0.         0.         0.0234375  0.05273438]
 [0.07330827 0.0075188  0.58646617 0.07894737 0.07894737 0.05075188
  0.05075188 0.04511278 0.02255639 0.0056391 ]
 [0.03609342 0.01698514 0.04670913 0.53078556 0.01910828 0.21019108
  0.03397028 0.0403397  0.03609342 0.02972399]
 [0.02972399 0.         0.06794055 0.06369427 0.61146497 0.04458599
  0.05095541 0.10191083 0.01698514 0.01273885]
 [0.01361868 0.00583658 0.03307393 0.12062257 0.04085603 0.65564202
  0.04669261 0.05447471 0.02140078 0.0077821 ]
 [0.01380671 0.00394477 0.02761341 0.07100592 0.04536489 0.02366864
  0.78106509 0.00197239 0.01577909 0.01577909]
 [0.024      0.012      0.02       0.034      0.042      0.062
  0.014      0.766      0.008      0.018     ]
 [0.04563492 0.02579365 0.00396825 0.00595238 0.0