In [1]:
from IPython.display import clear_output

In [2]:
# Download the required libraries (needed when running outside colab where the environment doesn't come pre-loaded with libraries)

%pip install torch
%pip install torchvision
%pip install matplotlib

clear_output()

In [3]:
import torch
import torch.nn as nn

from torchvision.datasets import CIFAR10
from torchvision.transforms.functional import to_tensor, to_pil_image, resize

from torch.utils.data import DataLoader
from torch.optim import Adam

import matplotlib.pyplot as plt

# Contents

1. We'll build a classifier for the CIFAR-10 dataset in PyTorch using a convolutional neural network (CNN).

About CIFAR10:

The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.

![CIFAR-10 image](https://production-media.paperswithcode.com/datasets/4fdf2b82-2bc3-4f97-ba51-400322b228b1.png)


You need to know:

1. **torch** (for the implementation)
2. a little bit of **matplotlib** (for visualization)


Good to have knowledge of:

1. PyTorch's `Dataset` and `DataLoader` classes

# Downloading the dataset

In [None]:
dataset_root = 'data/'

train_dataset = CIFAR10(root=dataset_root, train=True, download=True, transform=to_tensor)
# Todo: Create a validation dataset
...

In [None]:
print('Length of train_dataset is', len(train_dataset))
# Todo: Print the length of the validation dataset


In [6]:
batch_size = 64

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Todo: Create a validation data loader


## Let's visualize an image and its channels

In [None]:
# Pick a random image from anywhere in the training split (not just the first 1000).
random_img_idx = torch.randint(0, len(train_dataset), (1,)).item()

test_image = train_dataset[random_img_idx][0]  # 0 for image part in (image, label) tuple.
# Upscale the small image so it is easy to inspect. antialias=True replaces the
# legacy antialias=None setting (see the torchvision resize documentation).
test_image = resize(test_image, (250, 250), antialias=True)
print(test_image.shape)
print('Number of channels in test_image: ', test_image.shape[0])
to_pil_image(test_image)

In [8]:
# Unpack the three colour planes of the [C, H, W] image and make a black plane of the same size.
tred, tgreen, tblue = test_image
empty_channel = torch.zeros_like(tred)

# For each colour keep its own plane and blank the other two, giving the
# R00 / 0G0 / 00B layouts. The enumerate(...) tuple is evaluated before the
# names are rebound, so it still refers to the original planes.
tred, tgreen, tblue = (
    [plane if position == keep else empty_channel for position in range(3)]
    for keep, plane in enumerate((tred, tgreen, tblue))
)

# Turn each list of three [H, W] planes into a single [3, H, W] image tensor.
channels = list(map(torch.stack, (tred, tgreen, tblue)))

In [None]:
# Lay the red-only, green-only and blue-only images side by side (dim 2 is the width axis of [C, H, W]).
to_pil_image(torch.cat(tensors=channels, dim=2))

In [10]:
class Cifar10Classifier(nn.Module):
  """Small CNN classifier for 32x32 RGB CIFAR-10 images.

  Architecture: three (3x3 conv -> ReLU -> 2x2 max-pool) stages followed by two
  fully connected layers. ``forward`` returns raw class scores (logits), which is
  what ``nn.CrossEntropyLoss`` expects; use ``predict_proba`` for probabilities.
  """

  def __init__(self):

    super().__init__()
    # A 3x3 kernel with padding=1 keeps H and W unchanged, so only pooling shrinks them.
    self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
    self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
    self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)

    self.relu = nn.ReLU()  # ReLU has no learnable parameters, so one instance can be reused for every layer
    self.pool = nn.MaxPool2d(2, 2)  # pooling has no learnable parameters either; one instance suffices unless the window size changes

    # Three 2x2 poolings take 32x32 -> 16x16 -> 8x8 -> 4x4, with 128 channels after conv3.
    self.fc1 = nn.Linear(128 * 4 * 4, 512)
    self.fc2 = nn.Linear(512, 10)

    # Used only by predict_proba. forward deliberately returns logits: CrossEntropyLoss
    # applies log-softmax itself, and feeding it probabilities would apply softmax twice.
    self.softmax = nn.Softmax(dim=-1)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Compute class logits.

    Args:
      x: a batch ``[B, 3, 32, 32]`` or a single image ``[3, 32, 32]``.

    Returns:
      Logits of shape ``[B, 10]`` for a batch, or ``[10]`` for a single image.
      These are unnormalised scores, not probabilities; ``argmax`` over the last
      dimension still gives the predicted class.
    """
    # 3 dimensions mean [C, H, W] instead of [B, C, H, W], i.e. a single image.
    single_input = x.ndim == 3
    if single_input:
      x = x.unsqueeze(dim=0)  # [C, H, W] -> [1, C, H, W]; the 1 acts as batch size

    x = self.pool(self.relu(self.conv1(x)))  # [B, 32, 16, 16]
    x = self.pool(self.relu(self.conv2(x)))  # [B, 64, 8, 8]
    x = self.pool(self.relu(self.conv3(x)))  # [B, 128, 4, 4]
    x = torch.flatten(x, start_dim=1)        # [B, 128 * 4 * 4]
    x = self.relu(self.fc1(x))               # [B, 512]
    x = self.fc2(x)                          # [B, 10] logits

    if single_input:
      # Drop the batch dimension again so the output format matches the input format.
      x = x.squeeze(dim=0)  # or x = x[0]

    return x

  def predict_proba(self, x: torch.Tensor) -> torch.Tensor:
    """Return class probabilities: softmax over the logits produced by ``forward``."""
    return self.softmax(self(x))


In [11]:
# Seed the global torch RNG so weight initialisation and DataLoader shuffling repeat
# between runs (some CUDA kernels may still be nondeterministic).
torch.manual_seed(42)

device = 'cuda' if torch.cuda.is_available() else 'cpu'  # checks if machine supports cuda and if it does, we use that, otherwise cpu
model = Cifar10Classifier().to(device)

In [None]:
# Number of full passes over the training set.
num_epochs = 10
# Learning rate for Adam; 1e-3 is also torch.optim.Adam's default.
lr = 1e-3

train_losses = []
val_losses = []

model.to(device)  # we need to send all input tensors as well as our model to this device. by default they are on cpu

# The optimizer updates the model's parameters in place using their gradients.
optimizer = Adam(model.parameters(), lr=lr)
# CrossEntropyLoss takes raw logits [N, 10] and integer class labels [N]; it applies
# log-softmax internally and averages over the batch (reduction='mean' by default).
criterion = nn.CrossEntropyLoss()

print(f'Using device {device}')

In [None]:
%%time
for epoch_no in range(num_epochs):

  # TODO: Set the model to train mode and iterate through the training data
  ... 

  epoch_weighted_loss = 0

  for batch_X, batch_y in train_loader:

    batch_X = batch_X.to(device)
    batch_y = batch_y.to(device)

    # TODO: Perform the forward pass and get the predictions
    batch_y_probs = ...  # outputs [N, 10] where each [:, 10] is probabilities for class (0-9)
    # TODO: Calculate the loss using the predictions and the actual labels
    loss = ...

    # TODO: Perform the backward pass and update the weights
    ...

    epoch_weighted_loss += (len(batch_y)*loss.item())

  epoch_loss = epoch_weighted_loss/len(train_loader.dataset)
  train_losses.append(epoch_loss)


  # validation time

  # TODO: Set the model to evaluation mode and iterate through the validation data
  ...
  correctly_labelled = 0

  with torch.no_grad():

    val_epoch_weighted_loss = 0

    for val_batch_X, val_batch_y in val_loader:

      val_batch_X = val_batch_X.to(device)
      val_batch_y = val_batch_y.to(device)
      # TODO: Perform the forward pass and get the predictions
      val_batch_y_probs = ... 
      # TODO: Calculate the loss using the predictions and the actual labels
      loss = ... 
      val_epoch_weighted_loss += (len(val_batch_y)*loss.item())

      val_batch_y_pred = val_batch_y_probs.argmax(dim=1)  # convert probailities to labels by picking the label (index) with the highest prob

      correctly_labelled += (val_batch_y_pred == val_batch_y).sum().item()  # item converts tensor to float/int/list

  val_epoch_loss = val_epoch_weighted_loss/len(val_loader.dataset)
  val_losses.append(val_epoch_loss)

  print(f'Epoch: {epoch_no}, train_loss={epoch_loss}, val_loss={val_epoch_loss}. labelled {correctly_labelled}/{len(val_loader.dataset)} correctly ({correctly_labelled/len(val_loader.dataset)*100}% accuracy)')

print(f'Training complete on device {device}.')

In [None]:
# Training vs. validation loss per epoch; a widening gap suggests overfitting.
fig, ax = plt.subplots()
ax.plot(train_losses, label='Train Loss')
ax.plot(val_losses, label='Val Loss')

ax.set_title('CIFAR-10 classifier: loss per epoch')
ax.set_ylabel('Loss (CCE)')
ax.set_xlabel('Epoch')

ax.legend()
plt.show()