In [1]:
!pip install torchinfo



In [None]:
import os
import cv2
import math
import json
import argparse
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime

import torch
import torch.nn as nn
import torchvision
from tqdm import tqdm
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn.functional as F
from torch.utils.data import DataLoader, SubsetRandomSampler
import torchvision.transforms as transforms
import torchvision.utils as utils
import torchinfo

from time import time
from sklearn.metrics import accuracy_score

# Question 1

**Self-Attention for Object Recognition with CNNs**: Implement a sample CNN with one or more self-attention layer(s) for performing object recognition over the CIFAR-10 dataset. You have to implement the self-attention layer yourself and use it in the forward function defined by you. All other layers (fully connected, nonlinearity, conv layer, etc.) can be built-in implementations. The network can be a simpler one.

In [3]:
def prepare_data(batch_size, num_workers):
    """Create the CIFAR-10 training and test data loaders.

    Training images are converted to tensors, resized to 32x32, randomly
    flipped horizontally, randomly resize-cropped back to 32x32 and finally
    normalised with mean 0.5 / std 0.5 per channel. Test images only get the
    tensor conversion, the resize and the same normalisation.

    Args:
        batch_size: number of samples per batch for both loaders.
        num_workers: number of worker processes used by both loaders.

    Returns:
        A ``(trainloader, testloader)`` tuple. The training loader shuffles,
        the test loader does not. The dataset is stored under ``./data`` and
        downloaded when it is not already there.
    """
    data_root = './data'
    image_hw = (32, 32)
    # Normalize holds no state, so one instance can serve both pipelines.
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    train_steps = [
        transforms.ToTensor(),
        transforms.Resize(image_hw),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomResizedCrop(image_hw, scale=(0.8, 1.0), ratio=(0.75, 1.3333333333333333), interpolation=2),
        normalize,
    ]
    test_steps = [
        transforms.ToTensor(),
        transforms.Resize(image_hw),
        normalize,
    ]

    def make_loader(is_train, steps):
        # The train split is shuffled; the test split keeps its fixed order.
        dataset = torchvision.datasets.CIFAR10(
            root=data_root,
            train=is_train,
            download=True,
            transform=transforms.Compose(steps),
        )
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=is_train,
            num_workers=num_workers,
        )

    trainloader = make_loader(True, train_steps)
    testloader = make_loader(False, test_steps)
    return trainloader, testloader

In [4]:
# Hyper-parameters for the CNN + attention experiment
epochs = 20
batch_size = 32
lr = 1e-4
weight_decay = 1e-4

# Per-epoch checkpoints are saved into the current working directory
save_path = os.getcwd()

# Use the GPU when PyTorch can see one, otherwise stay on the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


In [5]:

"""
Attention blocks
Reference: Learn To Pay Attention Research Paper

Idea: To use the architecture as used in above paper with reduced parameters so as to achieve suitable performance.
"""

class ProjectorBlock(nn.Module):
    """Bias-free 1x1 convolution that maps ``in_features`` channels to ``out_features``.

    Because the kernel is 1x1 with no padding, the spatial size of the input
    is preserved; only the channel count changes.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        # Positional arguments are (in_channels, out_channels, kernel_size).
        self.op = nn.Conv2d(in_features, out_features, 1, padding=0, bias=False)

    def forward(self, x):
        """Project ``x`` (N, in_features, H, W) to (N, out_features, H, W)."""
        projected = self.op(x)
        return projected


class AttentionBlock(nn.Module):
    """Spatial attention over a local feature map, conditioned on a global feature.

    A bias-free 1x1 convolution turns ``l + g`` into a single-channel score
    map. The scores are converted to weights either with a softmax over all
    spatial positions (``normalize_attn=True``) or with an element-wise
    sigmoid, and the weights are used to pool ``l`` into one vector per sample.
    """

    def __init__(self, in_features, normalize_attn=True):
        super().__init__()
        self.normalize_attn = normalize_attn
        self.op = nn.Conv2d(in_features, 1, kernel_size=1, padding=0, bias=False)

    def forward(self, l, g):
        """Compute attention scores and the attention-pooled descriptor.

        Args:
            l: local features of shape (N, C, H, W).
            g: global features; added to ``l``, so it must broadcast against it.

        Returns:
            ``(scores, pooled)`` where ``scores`` is the raw (pre-softmax /
            pre-sigmoid) score map of shape (N, 1, H, W) and ``pooled`` has
            shape (N, C).
        """
        batch, channels, height, width = l.size()
        scores = self.op(l + g)

        if self.normalize_attn:
            # Softmax over the flattened spatial grid, then a weighted sum.
            weights = F.softmax(scores.view(batch, 1, -1), dim=2).view(batch, 1, height, width)
            weighted = torch.mul(weights.expand_as(l), l)
            pooled = weighted.view(batch, channels, -1).sum(dim=2)  # (batch_size,C)
        else:
            # Independent sigmoid gates, then a plain spatial average.
            weights = torch.sigmoid(scores)
            weighted = torch.mul(weights.expand_as(l), l)
            pooled = F.adaptive_avg_pool2d(weighted, (1, 1)).view(batch, channels)

        return scores.view(batch, 1, height, width), pooled


In [6]:

"""
Attention Block
"""
class CnnAttention(nn.Module):
    """Small CNN classifier with optional attention heads on three feature levels.

    The backbone is five conv blocks. Three ``max_pool2d`` calls in ``forward``
    plus the two pooling layers inside ``conv5`` halve the resolution five
    times (a factor of 32), after which ``dense`` (kernel ``sample_size/32``)
    produces the global feature ``g``.

    With ``attention=True`` the outputs of ``conv2`` (projected from 32 to 64
    channels), ``conv3`` and ``conv4`` are each attention-pooled against ``g``
    and the three 64-d descriptors are concatenated for classification.
    With ``attention=False`` the classifier is applied to ``g`` directly.

    Args:
        sample_size: side length of the (square) input images.
        num_classes: number of output classes.
        attention: whether to build and use the attention heads.
        normalize_attn: forwarded to every ``AttentionBlock``.
        init_weights: when True, ``_initialize_weights`` is applied to all
            conv, batch-norm and linear layers after construction.
    """

    def __init__(self, sample_size, num_classes, attention=True, normalize_attn=True, init_weights=True):
        super(CnnAttention, self).__init__()
        # conv blocks
        self.conv1 = self._make_layer(3, 16, 1)
        self.conv2 = self._make_layer(16, 32, 1)
        self.conv3 = self._make_layer(32, 64, 1)
        self.conv4 = self._make_layer(64, 64, 1)
        self.conv5 = self._make_layer(64, 64, 2, pool=True)
        self.dense = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=int(sample_size/32), padding=0, bias=True)
        # attention blocks and final classification layer
        self.attention = attention
        if self.attention:
            # conv2 emits 32 channels; project to 64 so it can be added to g.
            self.projector = ProjectorBlock(32, 64)
            self.attn1 = AttentionBlock(in_features=64, normalize_attn=normalize_attn)
            self.attn2 = AttentionBlock(in_features=64, normalize_attn=normalize_attn)
            self.attn3 = AttentionBlock(in_features=64, normalize_attn=normalize_attn)
            self.classify = nn.Linear(in_features=64*3, out_features=num_classes, bias=True)
        else:
            self.classify = nn.Linear(in_features=64, out_features=num_classes, bias=True)
        # Previously this flag was accepted but silently ignored.
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Classify a batch of images.

        Returns:
            A list ``[logits, c1, c2, c3]``. ``logits`` has shape
            (N, num_classes). ``c1``..``c3`` are the raw attention score maps
            of the three heads, or ``None`` when attention is disabled.
        """
        x = self.conv1(x)
        l1 = self.conv2(x)
        x = F.max_pool2d(l1, kernel_size=2, stride=2, padding=0)
        l2 = self.conv3(x)
        x = F.max_pool2d(l2, kernel_size=2, stride=2, padding=0)
        l3 = self.conv4(x)
        x = F.max_pool2d(l3, kernel_size=2, stride=2, padding=0)
        x = self.conv5(x)
        g = self.dense(x)
        # attention
        if self.attention:
            c1, g1 = self.attn1(self.projector(l1), g)
            c2, g2 = self.attn2(l2, g)
            c3, g3 = self.attn3(l3, g)
            g = torch.cat((g1,g2,g3), dim=1) # batch_sizex3C
            # classification layer
            x = self.classify(g) # batch_sizexnum_classes
        else:
            c1, c2, c3 = None, None, None
            # Flatten everything except the batch axis. A bare squeeze() would
            # also drop the batch axis for a single-sample batch.
            x = self.classify(g.view(g.size(0), -1))
        return [x, c1, c2, c3]

    def _make_layer(self, in_features, out_features, blocks, pool=False):
        """Stack ``blocks`` x (3x3 conv -> batch norm -> ReLU [-> 2x2 max pool])."""
        layers = []
        for _ in range(blocks):
            conv2d = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=3, padding=1, bias=False)
            layers += [conv2d, nn.BatchNorm2d(out_features), nn.ReLU(inplace=True)]
            in_features = out_features
            if pool:
                # One pooling layer after every block, not once per stack.
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        return nn.Sequential(*layers)

    def _initialize_weights(self):
        """Kaiming-normal convs, unit/zero batch norms, N(0, 0.01) linears, zero biases."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


In [7]:
class AttentionTrainer():
    """Runs one training or validation epoch for a classification model.

    The model may return either the logits directly or a list/tuple whose
    first element is the logits. Labels are expected to be class indices.
    Inputs are moved to the device the model's parameters live on, so the
    trainer does not depend on any notebook-level ``device`` variable.
    """

    @staticmethod
    def _model_device(model):
        """Return the device of the model's first parameter (CPU if it has none)."""
        try:
            return next(model.parameters()).device
        except StopIteration:
            return torch.device("cpu")

    @staticmethod
    def _logits(outputs):
        """Extract the logits from a raw model output."""
        if isinstance(outputs, (list, tuple)):
            return outputs[0]
        return outputs

    @staticmethod
    def _summarize(losses, label_batches, pred_batches):
        """Return (mean of per-batch losses, overall accuracy as a fraction)."""
        if not losses:
            raise ValueError("dataloader yielded no batches")
        mean_loss = sum(losses) / len(losses)
        labels = torch.cat(label_batches)
        preds = torch.cat(pred_batches)
        accuracy = (preds == labels).sum().item() / labels.numel()
        return mean_loss, accuracy

    def train_epoch(self,model, loss_function, optimizer, dataloader,epoch):
        """Train ``model`` for one pass over ``dataloader``.

        Args:
            model: network to optimise.
            loss_function: callable ``(logits, labels) -> scalar loss``.
            optimizer: optimiser already bound to the model's parameters.
            dataloader: yields ``(inputs, labels)`` batches; must support ``len``.
            epoch: zero-based epoch index, used only for the progress-bar label.

        Returns:
            ``(training_loss, training_acc)``: the mean of the per-batch losses
            and the accuracy (fraction in [0, 1]) over all samples seen.

        Raises:
            ValueError: if the dataloader yields no batches.
        """
        model.train()
        target_device = self._model_device(model)
        losses = []
        label_batches = []
        pred_batches = []

        # The context manager closes the bar even if a batch raises.
        with tqdm(total=len(dataloader), desc=f'Epoch {epoch+1}', unit='batch') as pbar:
            for inputs, labels in dataloader:
                inputs = inputs.to(target_device)
                # reshape(-1) keeps a 1-D label vector even for a batch of one,
                # where squeeze() would collapse it to a 0-d tensor.
                labels = labels.to(target_device).reshape(-1)

                optimizer.zero_grad()
                logits = self._logits(model(inputs))
                loss = loss_function(logits, labels)
                loss.backward()
                optimizer.step()

                losses.append(loss.item())
                label_batches.append(labels.detach())
                pred_batches.append(logits.detach().argmax(dim=1))
                pbar.update(1)

        training_loss, training_acc = self._summarize(losses, label_batches, pred_batches)
        print("Training Loss: {:.6f} | Training Accuracy: {:.2f}%".format(training_loss, training_acc*100))
        return training_loss, training_acc

    def val_epoch(self,model, loss_function, dataloader,epoch):
        """Evaluate ``model`` on ``dataloader`` without updating it.

        Args:
            model: network to evaluate (switched to eval mode).
            loss_function: callable ``(logits, labels) -> scalar loss``.
            dataloader: yields ``(inputs, labels)`` batches.
            epoch: unused; kept so the signature mirrors ``train_epoch``.

        Returns:
            ``(val_loss, val_acc)``: the mean of the per-batch losses and the
            accuracy (fraction in [0, 1]) over all samples seen.

        Raises:
            ValueError: if the dataloader yields no batches.
        """
        model.eval()
        target_device = self._model_device(model)
        losses = []
        label_batches = []
        pred_batches = []

        with torch.no_grad():
            for inputs, labels in dataloader:
                inputs = inputs.to(target_device)
                labels = labels.to(target_device).reshape(-1)
                logits = self._logits(model(inputs))
                losses.append(loss_function(logits, labels).item())
                label_batches.append(labels)
                pred_batches.append(logits.argmax(dim=1))

        val_loss, val_acc = self._summarize(losses, label_batches, pred_batches)
        print("Validation Loss: {:.6f} | Validation Accuracy: {:.2f}%".format(val_loss, val_acc*100))
        return val_loss, val_acc



In [8]:

# CIFAR-10 loaders for the CNN experiment (uses the batch_size currently in scope)
trainloader, testloader = prepare_data(batch_size=batch_size,num_workers=2)

# Instantiate the attention CNN for 32x32 inputs and move it to the chosen device
model_cnn_attention = CnnAttention(sample_size=32, num_classes=10)
model_cnn_attention = model_cnn_attention.to(device)

# Layer-by-layer summary for a single 3x32x32 image
cnn_summary = torchinfo.summary(model_cnn_attention, (1, 3, 32, 32))
print(cnn_summary)

# Count only the parameters that will receive gradients
total_params = 0
for p in model_cnn_attention.parameters():
    if p.requires_grad:
        total_params += p.numel()
print("Total number of parameters in CNN with Attention Model is :", total_params)


Files already downloaded and verified
Files already downloaded and verified
Layer (type:depth-idx)                   Output Shape              Param #
CnnAttention                             [1, 10]                   --
├─Sequential: 1-1                        [1, 16, 32, 32]           --
│    └─Conv2d: 2-1                       [1, 16, 32, 32]           432
│    └─BatchNorm2d: 2-2                  [1, 16, 32, 32]           32
│    └─ReLU: 2-3                         [1, 16, 32, 32]           --
├─Sequential: 1-2                        [1, 32, 32, 32]           --
│    └─Conv2d: 2-4                       [1, 32, 32, 32]           4,608
│    └─BatchNorm2d: 2-5                  [1, 32, 32, 32]           64
│    └─ReLU: 2-6                         [1, 32, 32, 32]           --
├─Sequential: 1-3                        [1, 64, 16, 16]           --
│    └─Conv2d: 2-7                       [1, 64, 16, 16]           18,432
│    └─BatchNorm2d: 2-8                  [1, 64, 16, 16]           128


In [9]:
# Objective and optimiser for the CNN + attention model
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_cnn_attention.parameters(), lr=lr, weight_decay=weight_decay)

cnn_attention_trainer = AttentionTrainer()
for epoch in range(epochs):
    # One pass over the training set, followed by one pass over the test set
    cnn_attention_trainer.train_epoch(model_cnn_attention, loss_function, optimizer, trainloader, epoch)
    cnn_attention_trainer.val_epoch(model_cnn_attention, loss_function, testloader, epoch)

    # Checkpoint after every epoch; file names use the 1-based, zero-padded epoch
    checkpoint_name = "cnn_epoch{:03d}.pth".format(epoch+1)
    checkpoint_path = os.path.join(save_path, checkpoint_name)
    torch.save(model_cnn_attention.state_dict(), checkpoint_path)
    print("Saving Model of Epoch {}".format(epoch+1))


  self.pid = os.fork()
Epoch 1: 100%|██████████| 1563/1563 [00:49<00:00, 31.37batch/s]


Training Loss: 1.679620 | Training Accuracy: 39.95%


  self.pid = os.fork()


Validation Loss: 1.443048 | Validation Accuracy: 47.39%
Saving Model of Epoch 1


  self.pid = os.fork()
  self.pid = os.fork()
Epoch 2: 100%|██████████| 1563/1563 [00:42<00:00, 36.48batch/s]


Training Loss: 1.355358 | Training Accuracy: 51.88%


  self.pid = os.fork()
  self.pid = os.fork()


Validation Loss: 1.287838 | Validation Accuracy: 53.61%
Saving Model of Epoch 2


  self.pid = os.fork()
  self.pid = os.fork()
Epoch 3: 100%|██████████| 1563/1563 [00:41<00:00, 37.37batch/s]


Training Loss: 1.227489 | Training Accuracy: 56.77%


  self.pid = os.fork()


Validation Loss: 1.170680 | Validation Accuracy: 58.18%
Saving Model of Epoch 3


  self.pid = os.fork()
Epoch 4: 100%|██████████| 1563/1563 [00:42<00:00, 36.99batch/s]


Training Loss: 1.148179 | Training Accuracy: 59.60%


  self.pid = os.fork()
  self.pid = os.fork()


Validation Loss: 1.116957 | Validation Accuracy: 60.31%
Saving Model of Epoch 4


  self.pid = os.fork()
Epoch 5: 100%|██████████| 1563/1563 [00:40<00:00, 38.36batch/s]


Training Loss: 1.086726 | Training Accuracy: 62.25%


  self.pid = os.fork()
  self.pid = os.fork()


Validation Loss: 1.052005 | Validation Accuracy: 62.46%
Saving Model of Epoch 5


  self.pid = os.fork()
  self.pid = os.fork()
Epoch 6: 100%|██████████| 1563/1563 [00:41<00:00, 37.80batch/s]


Training Loss: 1.036030 | Training Accuracy: 63.86%


  self.pid = os.fork()


Validation Loss: 1.000039 | Validation Accuracy: 64.47%
Saving Model of Epoch 6


  self.pid = os.fork()
Epoch 7: 100%|██████████| 1563/1563 [00:40<00:00, 38.29batch/s]


Training Loss: 0.992370 | Training Accuracy: 65.27%


  self.pid = os.fork()
  self.pid = os.fork()


Validation Loss: 0.962337 | Validation Accuracy: 65.34%
Saving Model of Epoch 7


  self.pid = os.fork()
  self.pid = os.fork()
Epoch 8: 100%|██████████| 1563/1563 [00:40<00:00, 38.17batch/s]


Training Loss: 0.960051 | Training Accuracy: 66.44%


  self.pid = os.fork()


Validation Loss: 0.980813 | Validation Accuracy: 65.11%
Saving Model of Epoch 8


  self.pid = os.fork()
  self.pid = os.fork()
Epoch 9: 100%|██████████| 1563/1563 [00:40<00:00, 38.69batch/s]


Training Loss: 0.923124 | Training Accuracy: 67.80%


  self.pid = os.fork()
  self.pid = os.fork()


Validation Loss: 0.922379 | Validation Accuracy: 67.07%
Saving Model of Epoch 9


  self.pid = os.fork()
  self.pid = os.fork()
Epoch 10: 100%|██████████| 1563/1563 [00:40<00:00, 38.36batch/s]


Training Loss: 0.900278 | Training Accuracy: 68.46%


  self.pid = os.fork()
  self.pid = os.fork()


Validation Loss: 0.914685 | Validation Accuracy: 67.64%
Saving Model of Epoch 10


  self.pid = os.fork()
  self.pid = os.fork()
Epoch 11: 100%|██████████| 1563/1563 [00:41<00:00, 37.73batch/s]


Training Loss: 0.876151 | Training Accuracy: 69.54%


  self.pid = os.fork()
  self.pid = os.fork()


Validation Loss: 0.881864 | Validation Accuracy: 68.84%
Saving Model of Epoch 11


  self.pid = os.fork()
  self.pid = os.fork()
Epoch 12: 100%|██████████| 1563/1563 [00:40<00:00, 38.40batch/s]


Training Loss: 0.851608 | Training Accuracy: 70.61%


  self.pid = os.fork()


Validation Loss: 0.862305 | Validation Accuracy: 68.97%
Saving Model of Epoch 12


  self.pid = os.fork()
  self.pid = os.fork()
Epoch 13: 100%|██████████| 1563/1563 [00:41<00:00, 37.75batch/s]


Training Loss: 0.834642 | Training Accuracy: 71.09%


  self.pid = os.fork()


Validation Loss: 0.853907 | Validation Accuracy: 70.11%
Saving Model of Epoch 13


  self.pid = os.fork()
Epoch 14: 100%|██████████| 1563/1563 [00:40<00:00, 38.42batch/s]


Training Loss: 0.811061 | Training Accuracy: 72.02%


  self.pid = os.fork()
  self.pid = os.fork()


Validation Loss: 0.810909 | Validation Accuracy: 71.38%
Saving Model of Epoch 14


  self.pid = os.fork()
Epoch 15: 100%|██████████| 1563/1563 [00:40<00:00, 38.38batch/s]


Training Loss: 0.798187 | Training Accuracy: 72.40%


  self.pid = os.fork()
  self.pid = os.fork()


Validation Loss: 0.809096 | Validation Accuracy: 71.91%
Saving Model of Epoch 15


  self.pid = os.fork()
  self.pid = os.fork()
Epoch 16: 100%|██████████| 1563/1563 [00:42<00:00, 36.90batch/s]


Training Loss: 0.783240 | Training Accuracy: 72.89%


  self.pid = os.fork()
  self.pid = os.fork()


Validation Loss: 0.804517 | Validation Accuracy: 71.81%
Saving Model of Epoch 16


  self.pid = os.fork()
  self.pid = os.fork()
Epoch 17: 100%|██████████| 1563/1563 [00:42<00:00, 36.95batch/s]


Training Loss: 0.767560 | Training Accuracy: 73.47%


  self.pid = os.fork()
  self.pid = os.fork()


Validation Loss: 0.775583 | Validation Accuracy: 73.11%
Saving Model of Epoch 17


  self.pid = os.fork()
  self.pid = os.fork()
Epoch 18: 100%|██████████| 1563/1563 [00:42<00:00, 36.91batch/s]


Training Loss: 0.755890 | Training Accuracy: 73.90%


  self.pid = os.fork()


Validation Loss: 0.771936 | Validation Accuracy: 73.23%
Saving Model of Epoch 18


  self.pid = os.fork()
  self.pid = os.fork()
Epoch 19: 100%|██████████| 1563/1563 [00:44<00:00, 35.49batch/s]


Training Loss: 0.740731 | Training Accuracy: 74.44%


  self.pid = os.fork()
  self.pid = os.fork()


Validation Loss: 0.761120 | Validation Accuracy: 73.72%
Saving Model of Epoch 19


  self.pid = os.fork()
  self.pid = os.fork()
Epoch 20: 100%|██████████| 1563/1563 [00:43<00:00, 36.35batch/s]


Training Loss: 0.729246 | Training Accuracy: 74.88%


  self.pid = os.fork()
  self.pid = os.fork()


Validation Loss: 0.733077 | Validation Accuracy: 74.47%
Saving Model of Epoch 20


# Question-2 ViT Implementation

**Object Recognition with Vision Transformer**: Implement and train an encoder-only Transformer (ViT-like) for the above object recognition task. In other words, implement multi-headed self-attention for image classification (i.e., appending a < class > token to the image patches that are accepted as input tokens). Compare the performance of the two implementations (try to keep the number of parameters comparable and use the same amount of training and testing data).



In [10]:
class GELUActivation(nn.Module):
    """
    Gaussian Error Linear Unit, computed with the tanh approximation:

        0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
    """
    def forward(self, input):
        cubic_term = 0.044715 * torch.pow(input, 3.0)
        tanh_argument = math.sqrt(2.0 / math.pi) * (input + cubic_term)
        gate = 1.0 + torch.tanh(tanh_argument)
        return 0.5 * input * gate


class ImagePatchEmbeddings(nn.Module):
    """
    Split an image into non-overlapping square patches and embed each one.

    A convolution whose kernel size and stride both equal the patch size does
    the split and the linear projection in a single step.
    """
    def __init__(self, config):
        super().__init__()
        self.image_size = config["image_size"]
        self.patch_size = config["patch_size"]
        self.num_channels = config["num_channels"]
        self.hidden_size = config["hidden_size"]
        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side ** 2
        self.projection = nn.Conv2d(
            in_channels=self.num_channels,
            out_channels=self.hidden_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

    def forward(self, x):
        """Map (batch, channels, H, W) images to (batch, num_patches, hidden_size) tokens."""
        feature_map = self.projection(x)
        # Collapse the patch grid into one sequence axis, then put it before the channels.
        tokens = feature_map.flatten(2)
        return tokens.transpose(1, 2)


class PatchAndPositionEmbeddings(nn.Module):
    """
    Build the transformer input sequence: a learnable class token followed by
    the patch embeddings, with learnable position embeddings added to every
    token and dropout applied on top.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.patch_embeddings = ImagePatchEmbeddings(config)
        hidden = config["hidden_size"]
        # One extra sequence position is reserved for the class token.
        sequence_length = self.patch_embeddings.num_patches + 1
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden))
        self.position_embeddings = nn.Parameter(torch.randn(1, sequence_length, hidden))
        self.dropout = nn.Dropout(config["hidden_dropout_prob"])

    def forward(self, x):
        """Return tokens of shape (batch, num_patches + 1, hidden_size)."""
        patches = self.patch_embeddings(x)
        batch_size, _, _ = patches.size()
        # expand() repeats the single class token across the batch without copying it.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls_tokens, patches), dim=1) + self.position_embeddings
        return self.dropout(tokens)



In [11]:

class AttentionHead(nn.Module):
    """
    One scaled dot-product self-attention head.
    Building block of MultiHeadAttention.
    """
    def __init__(self, hidden_size, attention_head_size, dropout, bias=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.attention_head_size = attention_head_size
        # Separate linear maps produce the queries, keys and values.
        self.query = nn.Linear(hidden_size, attention_head_size, bias=bias)
        self.key = nn.Linear(hidden_size, attention_head_size, bias=bias)
        self.value = nn.Linear(hidden_size, attention_head_size, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over the token axis of ``x`` (batch, tokens, hidden_size).

        Returns:
            ``(attention_output, attention_probs)``. The probabilities are the
            softmax weights after dropout has been applied to them.
        """
        q, k, v = self.query(x), self.key(x), self.value(x)
        scale = math.sqrt(self.attention_head_size)
        scores = torch.matmul(q, k.transpose(-1, -2)) / scale
        probs = self.dropout(nn.functional.softmax(scores, dim=-1))
        context = torch.matmul(probs, v)
        return (context, probs)


class MultiHeadAttention(nn.Module):
    """
    Multi-head self-attention built from independent AttentionHead modules.

    Each head works in ``hidden_size // num_attention_heads`` dimensions. The
    head outputs are concatenated along the feature axis, projected back to
    ``hidden_size`` and passed through dropout.
    """
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config["hidden_size"]
        self.num_attention_heads = config["num_attention_heads"]
        self.attention_head_size = self.hidden_size // self.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.qkv_bias = config["qkv_bias"]
        self.heads = nn.ModuleList([
            AttentionHead(
                self.hidden_size,
                self.attention_head_size,
                config["attention_probs_dropout_prob"],
                self.qkv_bias,
            )
            for _ in range(self.num_attention_heads)
        ])
        self.output_projection = nn.Linear(self.all_head_size, self.hidden_size)
        self.output_dropout = nn.Dropout(config["hidden_dropout_prob"])

    def forward(self, x, output_attentions=False):
        """Run every head on ``x`` and merge the results.

        Returns:
            ``(attention_output, attention_probs)``. ``attention_probs`` stacks
            the per-head probabilities along dim 1 when ``output_attentions``
            is True, and is ``None`` otherwise.
        """
        per_head = [head(x) for head in self.heads]
        merged = torch.cat([head_output for head_output, _ in per_head], dim=-1)
        merged = self.output_dropout(self.output_projection(merged))

        stacked_probs = None
        if output_attentions:
            stacked_probs = torch.stack([head_probs for _, head_probs in per_head], dim=1)
        return (merged, stacked_probs)



In [12]:


class MLP(nn.Module):
    """
    Position-wise feed-forward network: linear -> GELU -> linear -> dropout.
    """
    def __init__(self, config):
        super().__init__()
        hidden, intermediate = config["hidden_size"], config["intermediate_size"]
        self.dense_1 = nn.Linear(hidden, intermediate)
        self.activation = GELUActivation()
        self.dense_2 = nn.Linear(intermediate, hidden)
        self.dropout = nn.Dropout(config["hidden_dropout_prob"])

    def forward(self, x):
        """Expand to ``intermediate_size``, activate, and project back to ``hidden_size``."""
        expanded = self.activation(self.dense_1(x))
        return self.dropout(self.dense_2(expanded))


class TransformerBlock(nn.Module):
    """
    Pre-norm transformer encoder block: layer norm -> attention -> residual,
    then layer norm -> MLP -> residual.
    """
    def __init__(self, config):
        super().__init__()
        hidden = config["hidden_size"]
        self.attention = MultiHeadAttention(config)
        self.layernorm_1 = nn.LayerNorm(hidden)
        self.mlp = MLP(config)
        self.layernorm_2 = nn.LayerNorm(hidden)

    def forward(self, x, output_attentions=False):
        """Return ``(hidden_states, attention_probs)``; the latter is ``None`` unless requested."""
        normed = self.layernorm_1(x)
        attention_output, attention_probs = self.attention(normed, output_attentions=output_attentions)
        hidden_states = x + attention_output
        hidden_states = hidden_states + self.mlp(self.layernorm_2(hidden_states))
        return (hidden_states, attention_probs if output_attentions else None)


class TransformerEncoder(nn.Module):
    """
    Stack of ``num_hidden_layers`` TransformerBlock modules applied in sequence.
    """
    def __init__(self, config):
        super().__init__()
        self.blocks = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config["num_hidden_layers"])]
        )

    def forward(self, x, output_attentions=False):
        """Return ``(hidden_states, attentions)``.

        ``attentions`` is a list with one entry per block when
        ``output_attentions`` is True, and ``None`` otherwise.
        """
        collected = []
        hidden_states = x
        for block in self.blocks:
            hidden_states, block_probs = block(hidden_states, output_attentions=output_attentions)
            if output_attentions:
                collected.append(block_probs)
        if output_attentions:
            return (hidden_states, collected)
        return (hidden_states, None)



In [13]:

class ViTForClassification(nn.Module):
    """
    Vision Transformer classifier: patch/position embeddings, a transformer
    encoder, and a linear head applied to the class-token output.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.image_size = config["image_size"]
        self.hidden_size = config["hidden_size"]
        self.num_classes = config["num_classes"]
        self.embedding = PatchAndPositionEmbeddings(config)
        self.encoder = TransformerEncoder(config)
        self.classifier = nn.Linear(self.hidden_size, self.num_classes)
        # Re-initialise every sub-module (and this module) with _init_weights.
        self.apply(self._init_weights)

    def forward(self, x, output_attentions=False):
        """Return ``(logits, attentions)``; ``attentions`` is ``None`` unless requested."""
        tokens = self.embedding(x)
        encoded, all_attentions = self.encoder(tokens, output_attentions=output_attentions)
        # The class token sits at sequence position 0.
        class_state = encoded[:, 0, :]
        logits = self.classifier(class_state)
        return (logits, all_attentions if output_attentions else None)

    def _init_weights(self, module):
        """Initialise one module according to its type.

        Linear and conv weights are drawn from N(0, initializer_range) with
        zero biases. Layer norms are reset to weight 1 / bias 0. The position
        embeddings and class token get a truncated normal with the same std.
        """
        init_std = self.config["initializer_range"]
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=init_std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, PatchAndPositionEmbeddings):
            # Position embeddings first, then the class token, so random
            # numbers are consumed in a fixed order.
            for parameter in (module.position_embeddings, module.cls_token):
                as_float32 = parameter.data.to(torch.float32)
                initialised = nn.init.trunc_normal_(as_float32, mean=0.0, std=init_std)
                parameter.data = initialised.to(parameter.dtype)

In [14]:
# Settings for the ViT experiment.
# NOTE: batch_size, epochs and lr rebind the names used earlier for the CNN run.
exp_name = 'vit-model'
batch_size = 256
epochs = 20
lr = 0.01

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Model configuration consumed by the ViT modules
config = dict(
    patch_size=4,
    hidden_size=48,
    num_hidden_layers=4,
    num_attention_heads=4,
    intermediate_size=4 * 48,
    hidden_dropout_prob=0.0,
    attention_probs_dropout_prob=0.0,
    initializer_range=0.02,
    image_size=32,
    num_classes=10,
    num_channels=3,
    qkv_bias=True,
)

In [15]:

class Trainer:
    """
    The simple trainer.

    Wraps a model whose forward pass returns a tuple with the logits as its
    first element, together with an optimizer and a loss function.
    """

    def __init__(self, model, optimizer, loss_fn, exp_name, device):
        self.model = model.to(device)
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.exp_name = exp_name
        self.device = device

    def train_epoch(self, trainloader,epoch_num):
        """
        Train the model for one epoch.

        Args:
            trainloader: yields ``(images, labels)`` batches; must support ``len``.
            epoch_num: epoch number shown in the progress-bar description.

        Returns:
            The sample-weighted mean training loss over the batches seen
            (0.0 if the loader yields nothing).
        """
        self.model.train()
        total_loss = 0.0
        samples_seen = 0

        progress_bar = tqdm(enumerate(trainloader), total=len(trainloader))
        # The description does not change during the epoch, so set it once.
        progress_bar.set_description(f'Epoch {epoch_num}')

        for batch_idx, batch in progress_bar:
            # Move the batch to the device
            images, labels = [t.to(self.device) for t in batch]
            # Zero the gradients
            self.optimizer.zero_grad()
            # Calculate the loss
            loss = self.loss_fn(self.model(images)[0], labels)
            # Backpropagate the loss
            loss.backward()
            # Update the model's parameters
            self.optimizer.step()
            total_loss += loss.item() * len(images)
            samples_seen += len(images)

        # Divide by the samples actually processed; len(loader.dataset) is
        # wrong when the loader uses a sampler or drop_last.
        return total_loss / max(samples_seen, 1)


    def train(self, trainloader, testloader, epochs):
        """
        Train the model for the specified number of epochs.

        After every epoch the model is evaluated (in eval mode) on both
        loaders and the results are printed.

        Returns:
            A dict of per-epoch lists with keys ``"running_train_loss"`` (loss
            accumulated while training), ``"train_loss"``, ``"test_loss"``,
            ``"train_accuracy"`` and ``"test_accuracy"`` (from ``evaluate``).
        """
        history = {
            "running_train_loss": [],
            "train_loss": [],
            "test_loss": [],
            "train_accuracy": [],
            "test_accuracy": [],
        }
        for i in range(epochs):
            start_time = time()
            running_train_loss = self.train_epoch(trainloader,i+1)

            train_accuracy, train_loss = self.evaluate(trainloader)
            test_accuracy, test_loss = self.evaluate(testloader)

            history["running_train_loss"].append(running_train_loss)
            history["train_loss"].append(train_loss)
            history["test_loss"].append(test_loss)
            history["train_accuracy"].append(train_accuracy)
            history["test_accuracy"].append(test_accuracy)

            end_time = time()

            print(f"Epoch: {i+1}, Train loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f} \
            Test loss: {test_loss:.4f},Test Accuracy: {test_accuracy:.4f}, Time Taken: {end_time-start_time}")
        return history

    @torch.no_grad()
    def evaluate(self, testloader):
        """
        Evaluate the model on ``testloader`` without tracking gradients.

        Returns:
            ``(accuracy, avg_loss)``: the fraction of correct predictions and
            the sample-weighted mean loss over the samples seen (both 0.0 if
            the loader yields nothing).
        """
        self.model.eval()
        total_loss = 0.0
        correct = 0
        samples_seen = 0
        for batch in testloader:
            # Move the batch to the device
            images, labels = [t.to(self.device) for t in batch]

            # Get predictions
            logits, _ = self.model(images)

            # Weight each batch loss by its size so a smaller final batch
            # does not skew the average.
            batch_samples = len(images)
            total_loss += self.loss_fn(logits, labels).item() * batch_samples

            # Count correct predictions
            predictions = torch.argmax(logits, dim=1)
            correct += torch.sum(predictions == labels).item()
            samples_seen += batch_samples

        denominator = max(samples_seen, 1)
        accuracy = correct / denominator
        avg_loss = total_loss / denominator
        return accuracy, avg_loss


def vit_main():
    """
    Build the ViT classifier, report its size and train it.

    Reads the notebook-level settings ``batch_size``, ``config``, ``lr``,
    ``exp_name``, ``device`` and ``epochs``; takes no arguments and returns
    nothing.
    """
    # Data loaders come from the shared CIFAR10 helper.
    trainloader, testloader = prepare_data(batch_size=batch_size,num_workers=10)

    # Instantiate the network and print a layer-by-layer summary.
    vit = ViTForClassification(config)
    print(torchinfo.summary(vit, (1, 3, 32, 32)))

    # Count only the parameters that will receive gradients.
    trainable_count = 0
    for param in vit.parameters():
        if param.requires_grad:
            trainable_count += param.numel()
    print("Total number of parameters:", trainable_count)

    # Wire optimizer and loss straight into the trainer, then run it.
    trainer = Trainer(
        vit,
        optim.AdamW(vit.parameters(), lr=lr, weight_decay=0.03),
        nn.CrossEntropyLoss(),
        exp_name,
        device=device,
    )
    trainer.train(trainloader, testloader, epochs)


# Run the ViT experiment defined by vit_main().
vit_main()


Files already downloaded and verified




Files already downloaded and verified
Layer (type:depth-idx)                             Output Shape              Param #
ViTForClassification                               [1, 10]                   --
├─PatchAndPositionEmbeddings: 1-1                  [1, 65, 48]               3,168
│    └─ImagePatchEmbeddings: 2-1                   [1, 64, 48]               --
│    │    └─Conv2d: 3-1                            [1, 48, 8, 8]             2,352
│    └─Dropout: 2-2                                [1, 65, 48]               --
├─TransformerEncoder: 1-2                          [1, 65, 48]               --
│    └─ModuleList: 2-3                             --                        --
│    │    └─TransformerBlock: 3-2                  [1, 65, 48]               28,272
│    │    └─TransformerBlock: 3-3                  [1, 65, 48]               28,272
│    │    └─TransformerBlock: 3-4                  [1, 65, 48]               28,272
│    │    └─TransformerBlock: 3-5                  [1, 65, 

  self.pid = os.fork()
  self.pid = os.fork()
Epoch 1: 100%|██████████| 196/196 [00:35<00:00,  5.56it/s]


Epoch: 1, Train loss: 1.6875, Train Accuracy: 0.3707             Test loss: 1.6813,Test Accuracy: 0.3680, Time Taken: 68.83290410041809


Epoch 2: 100%|██████████| 196/196 [00:35<00:00,  5.56it/s]


Epoch: 2, Train loss: 1.4890, Train Accuracy: 0.4550             Test loss: 1.4813,Test Accuracy: 0.4655, Time Taken: 68.7428731918335


Epoch 3: 100%|██████████| 196/196 [00:34<00:00,  5.63it/s]


Epoch: 3, Train loss: 1.4079, Train Accuracy: 0.4880             Test loss: 1.4180,Test Accuracy: 0.4870, Time Taken: 69.29568433761597


Epoch 4: 100%|██████████| 196/196 [00:33<00:00,  5.77it/s]


Epoch: 4, Train loss: 1.4366, Train Accuracy: 0.4690             Test loss: 1.4134,Test Accuracy: 0.4764, Time Taken: 68.3382019996643


Epoch 5: 100%|██████████| 196/196 [00:34<00:00,  5.68it/s]


Epoch: 5, Train loss: 1.2800, Train Accuracy: 0.5382             Test loss: 1.2980,Test Accuracy: 0.5342, Time Taken: 67.98331594467163


Epoch 6: 100%|██████████| 196/196 [00:34<00:00,  5.65it/s]


Epoch: 6, Train loss: 1.2287, Train Accuracy: 0.5543             Test loss: 1.2502,Test Accuracy: 0.5501, Time Taken: 67.30260753631592


Epoch 7: 100%|██████████| 196/196 [00:34<00:00,  5.69it/s]


Epoch: 7, Train loss: 1.2589, Train Accuracy: 0.5488             Test loss: 1.2801,Test Accuracy: 0.5441, Time Taken: 67.425119638443


Epoch 8: 100%|██████████| 196/196 [00:34<00:00,  5.74it/s]


Epoch: 8, Train loss: 1.2002, Train Accuracy: 0.5608             Test loss: 1.2008,Test Accuracy: 0.5652, Time Taken: 68.62580943107605


Epoch 9: 100%|██████████| 196/196 [00:33<00:00,  5.82it/s]


Epoch: 9, Train loss: 1.1651, Train Accuracy: 0.5810             Test loss: 1.1893,Test Accuracy: 0.5725, Time Taken: 67.5715742111206


Epoch 10: 100%|██████████| 196/196 [00:34<00:00,  5.76it/s]


Epoch: 10, Train loss: 1.1399, Train Accuracy: 0.5903             Test loss: 1.1615,Test Accuracy: 0.5764, Time Taken: 67.3302161693573


Epoch 11: 100%|██████████| 196/196 [00:34<00:00,  5.66it/s]


Epoch: 11, Train loss: 1.1386, Train Accuracy: 0.5922             Test loss: 1.1667,Test Accuracy: 0.5744, Time Taken: 67.43823504447937


Epoch 12: 100%|██████████| 196/196 [00:34<00:00,  5.65it/s]


Epoch: 12, Train loss: 1.1150, Train Accuracy: 0.6047             Test loss: 1.2042,Test Accuracy: 0.5774, Time Taken: 67.38134789466858


Epoch 13: 100%|██████████| 196/196 [00:34<00:00,  5.61it/s]


Epoch: 13, Train loss: 1.1119, Train Accuracy: 0.5987             Test loss: 1.1399,Test Accuracy: 0.5927, Time Taken: 68.9344527721405


Epoch 14: 100%|██████████| 196/196 [00:33<00:00,  5.81it/s]


Epoch: 14, Train loss: 1.0876, Train Accuracy: 0.6148             Test loss: 1.1054,Test Accuracy: 0.6087, Time Taken: 67.57394289970398


Epoch 15: 100%|██████████| 196/196 [00:33<00:00,  5.79it/s]


Epoch: 15, Train loss: 1.0656, Train Accuracy: 0.6184             Test loss: 1.1390,Test Accuracy: 0.5947, Time Taken: 67.3763952255249


Epoch 16: 100%|██████████| 196/196 [00:34<00:00,  5.70it/s]


Epoch: 16, Train loss: 1.0271, Train Accuracy: 0.6353             Test loss: 1.0755,Test Accuracy: 0.6173, Time Taken: 67.2056360244751


Epoch 17: 100%|██████████| 196/196 [00:34<00:00,  5.62it/s]


Epoch: 17, Train loss: 1.0110, Train Accuracy: 0.6408             Test loss: 1.0587,Test Accuracy: 0.6251, Time Taken: 67.86719822883606


Epoch 18: 100%|██████████| 196/196 [00:34<00:00,  5.64it/s]


Epoch: 18, Train loss: 0.9653, Train Accuracy: 0.6550             Test loss: 1.0348,Test Accuracy: 0.6274, Time Taken: 68.74163913726807


Epoch 19: 100%|██████████| 196/196 [00:33<00:00,  5.79it/s]


Epoch: 19, Train loss: 0.9554, Train Accuracy: 0.6602             Test loss: 1.0126,Test Accuracy: 0.6359, Time Taken: 67.89038300514221


Epoch 20: 100%|██████████| 196/196 [00:33<00:00,  5.82it/s]


Epoch: 20, Train loss: 0.9641, Train Accuracy: 0.6536             Test loss: 1.0199,Test Accuracy: 0.6315, Time Taken: 67.724360704422


# Observations

1. The CNN with self-attention performs better than the ViT on the CIFAR-10 dataset used above.
2. It should also be noted that ViT performance improves significantly on larger datasets, where it surpasses CNN results by a considerable margin.


| Model Name | Training Accuracy | Validation Accuracy |  Number of Epochs |  Number of Parameters |
|----------|----------|----------|----------|----------|
| CNN with Self Attention | 80.16 | 78.89 | 50 | 143002 |
| ViT | 75.45 | 71.76 | 50 | 119098 |


To keep the above notebook trainable on my local machine (Mac M1 Pro with 16 GB RAM), I limited training to 20 epochs so that results are available faster; the corresponding observations are below.



| Model Name | Training Accuracy | Validation Accuracy |  Number of Epochs |Number of Parameters |
|----------|----------|----------|----------|----------|
| CNN with Self Attention | 74.79 | 73.49 | 20 |143002 |
| ViT | 64.19 | 62.65 | 20 |119098 |


**Note**:
```
1. The metrics reported above are averaged over 5 complete runs.

```



# Reference Papers
**For Question 1**
1. https://arxiv.org/pdf/1804.02391.pdf Learn to pay attention
2. https://arxiv.org/pdf/2111.14556.pdf On the Integration of Self-Attention and Convolution

**For Question 2**
1. https://arxiv.org/pdf/2302.03751.pdf Understanding Why ViT Trains Badly on Small
Datasets: An Intuitive Perspective
2. https://arxiv.org/pdf/2010.11929.pdf AN IMAGE IS WORTH 16X16 WORDS:
TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE
