## Importing Libraries

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch as tc
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor
from sklearn.metrics import accuracy_score
import numpy as np
from tqdm import tqdm, trange
from torch.utils.data import DataLoader


## Q1 Image CNN with Attention

Attention Layer

In [3]:
class Att_layer(tc.nn.Module):
  def __init__(self, channel_num, input_dim, output_channel_len = 0):
     super(Att_layer, self).__init__()
     self.weights_query = nn.Parameter(tc.Tensor(channel_num, channel_num))
     nn.init.xavier_uniform_(self.weights_query)
     self.weights_key = nn.Parameter(tc.Tensor(channel_num, channel_num))
     nn.init.xavier_uniform_(self.weights_key)
     self.weights_value = nn.Parameter(tc.Tensor(channel_num, channel_num))
     nn.init.xavier_uniform_(self.weights_value)
     self.d_k = channel_num ** 2
     self.height = input_dim[0]
     self.width = input_dim[1]
     self.channel_num = channel_num
     self.scale = nn.Parameter(tc.tensor(1).float())
     self.bias = nn.Parameter(tc.tensor(0).float())

  def forward(self, input_images):
      batch_size = input_images.shape[0]

      h = self.height
      w = self.width
      c = self.channel_num

      input_images = input_images.permute(0, 2, 3, 1)

      queries = tc.matmul(input_images, self.weights_query.t())
      keys = tc.matmul(input_images, self.weights_key.t())
      values = tc.matmul(input_images, self.weights_value.t())

      queries_flat = queries.reshape(-1, c)
      keys_flat = keys.reshape(-1, c).t()
      values_flat = values.reshape(-1, c)



      attention_scores = tc.matmul(queries_flat, keys_flat) / np.sqrt(c)

      attention_weights = F.softmax(attention_scores, dim=-1)

      attended_values = tc.matmul(attention_weights, values_flat)

      attended_values = attended_values.reshape(batch_size, h, w, c)

      temp = input_images + attended_values
      temp = temp.reshape(batch_size, h, w, c)

      means = temp.mean(dim=(1, 2), keepdim=True)
      stds = temp.std(dim=(1, 2), keepdim=True)
      stds = tc.clamp(stds, 1, 100)

      norm_images = (temp - means) / stds

      output = self.scale * norm_images + self.bias.unsqueeze(0)
      input_images.detach()
      queries.detach()
      keys.detach()
      values.detach()
      queries_flat.detach()
      keys_flat.detach()
      values_flat.detach()
      attention_weights.detach()
      attention_scores.detach()
      temp.detach()
      attended_values.detach()
      norm_images.detach()
      return output.permute(0, 3, 1, 2)


CNN with attention layers

In [4]:
class Att_CNN(nn.Module):
    def __init__(self):
        super(Att_CNN, self).__init__()
        # Define layers
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        nn.init.xavier_uniform_(self.conv1.weight)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        nn.init.xavier_uniform_(self.conv2.weight)
        self.sa1 = Att_layer(32, (16, 16))
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        nn.init.xavier_uniform_(self.conv3.weight)
        self.sa2 = Att_layer(64, (8, 8))
        self.conv4 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        nn.init.xavier_uniform_(self.conv4.weight)
        self.sa3 = Att_layer(32, (4, 4))
        self.conv5 = nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1)
        nn.init.xavier_uniform_(self.conv5.weight)
        self.sa4 = Att_layer(16, (4, 4))
        self.conv6 = nn.Conv2d(16, 10, kernel_size=3, stride=1, padding=1)
        nn.init.xavier_uniform_(self.conv6.weight)

    def forward(self, x):
        x = F.relu(self.conv1(x))

        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = self.sa1.forward(x)
        x = F.relu(self.conv3(x))
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = self.sa2.forward(x)
        x = F.relu(self.conv4(x))
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = self.sa3.forward(x)
        x = F.relu(self.conv5(x))
        x = self.sa4.forward(x)
        x = F.relu(self.conv6(x))
        GAP = tc.mean(x, dim=(2, 3))

        return GAP


In [5]:
def main():
    # Loading data
    transform = ToTensor()

    train_set = CIFAR10(root="data", train=False, download=True, transform=ToTensor())

    test_set = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
    train_loader = DataLoader(train_set, shuffle=True, batch_size=64)
    test_loader = DataLoader(test_set, shuffle=False, batch_size=64)

    device = tc.device("cuda" if tc.cuda.is_available() else "cpu")
    print("Using device: ", device, f"({tc.cuda.get_device_name(device)})" if tc.cuda.is_available() else "")
    model = Att_CNN().to(device)
    N_EPOCHS = 10
    LR = 0.001

    # Training loop
    optimizer = optim.Adam(model.parameters(), lr=LR)
    criterion = nn.CrossEntropyLoss()
    for epoch in trange(N_EPOCHS, desc="Training"):
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} in training", leave=False):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = model(x)

            loss = criterion(y_hat, y)

            train_loss += loss.detach().cpu().item() / len(train_loader)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss:.2f}")

    # Test loop
    with tc.no_grad():
        correct, total = 0, 0
        test_loss = 0.0
        for batch in tqdm(test_loader, desc="Testing"):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)
            test_loss += loss.detach().cpu().item() / len(test_loader)

            correct += tc.sum(tc.argmax(y_hat, dim=1) == y).detach().cpu().item()
            total += len(x)
        print(f"Test loss: {test_loss:.2f}")
        print(f"Test accuracy: {correct / total * 100:.2f}%")

In [6]:
tc.manual_seed(50)
main()

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:01<00:00, 95779225.47it/s] 


Extracting data/cifar-10-python.tar.gz to data
Files already downloaded and verified
Using device:  cuda (Tesla T4)


Training:   0%|          | 0/10 [00:00<?, ?it/s]
Epoch 1 in training:   0%|          | 0/157 [00:00<?, ?it/s][A
Epoch 1 in training:   1%|          | 1/157 [00:02<06:12,  2.39s/it][A
Epoch 1 in training:   2%|▏         | 3/157 [00:02<01:47,  1.44it/s][A
Epoch 1 in training:   3%|▎         | 4/157 [00:02<01:16,  2.00it/s][A
Epoch 1 in training:   3%|▎         | 5/157 [00:02<00:56,  2.68it/s][A
Epoch 1 in training:   4%|▍         | 6/157 [00:02<00:43,  3.45it/s][A
Epoch 1 in training:   4%|▍         | 7/157 [00:03<00:35,  4.25it/s][A
Epoch 1 in training:   5%|▌         | 8/157 [00:03<00:29,  5.04it/s][A
Epoch 1 in training:   6%|▌         | 9/157 [00:03<00:25,  5.78it/s][A
Epoch 1 in training:   6%|▋         | 10/157 [00:03<00:22,  6.42it/s][A
Epoch 1 in training:   7%|▋         | 11/157 [00:03<00:20,  6.97it/s][A
Epoch 1 in training:   8%|▊         | 12/157 [00:03<00:19,  7.41it/s][A
Epoch 1 in training:   8%|▊         | 13/157 [00:03<00:18,  7.74it/s][A
Epoch 1 in training

Epoch 1/10 loss: 2.14



Epoch 2 in training:   0%|          | 0/157 [00:00<?, ?it/s][A
Epoch 2 in training:   1%|▏         | 2/157 [00:00<00:14, 10.52it/s][A
Epoch 2 in training:   3%|▎         | 4/157 [00:00<00:16,  9.32it/s][A
Epoch 2 in training:   3%|▎         | 5/157 [00:00<00:16,  9.09it/s][A
Epoch 2 in training:   4%|▍         | 6/157 [00:00<00:16,  8.95it/s][A
Epoch 2 in training:   4%|▍         | 7/157 [00:00<00:16,  8.84it/s][A
Epoch 2 in training:   5%|▌         | 8/157 [00:00<00:17,  8.76it/s][A
Epoch 2 in training:   6%|▌         | 9/157 [00:01<00:16,  8.71it/s][A
Epoch 2 in training:   6%|▋         | 10/157 [00:01<00:16,  8.67it/s][A
Epoch 2 in training:   7%|▋         | 11/157 [00:01<00:16,  8.67it/s][A
Epoch 2 in training:   8%|▊         | 12/157 [00:01<00:16,  8.65it/s][A
Epoch 2 in training:   8%|▊         | 13/157 [00:01<00:16,  8.63it/s][A
Epoch 2 in training:   9%|▉         | 14/157 [00:01<00:16,  8.63it/s][A
Epoch 2 in training:  10%|▉         | 15/157 [00:01<00:16,  8.56it

Epoch 2/10 loss: 1.77



Epoch 3 in training:   0%|          | 0/157 [00:00<?, ?it/s][A
Epoch 3 in training:   1%|▏         | 2/157 [00:00<00:14, 10.95it/s][A
Epoch 3 in training:   3%|▎         | 4/157 [00:00<00:16,  9.43it/s][A
Epoch 3 in training:   3%|▎         | 5/157 [00:00<00:16,  9.18it/s][A
Epoch 3 in training:   4%|▍         | 6/157 [00:00<00:16,  9.01it/s][A
Epoch 3 in training:   4%|▍         | 7/157 [00:00<00:16,  8.90it/s][A
Epoch 3 in training:   5%|▌         | 8/157 [00:00<00:17,  8.74it/s][A
Epoch 3 in training:   6%|▌         | 9/157 [00:00<00:16,  8.75it/s][A
Epoch 3 in training:   6%|▋         | 10/157 [00:01<00:16,  8.72it/s][A
Epoch 3 in training:   7%|▋         | 11/157 [00:01<00:16,  8.69it/s][A
Epoch 3 in training:   8%|▊         | 12/157 [00:01<00:16,  8.67it/s][A
Epoch 3 in training:   8%|▊         | 13/157 [00:01<00:16,  8.66it/s][A
Epoch 3 in training:   9%|▉         | 14/157 [00:01<00:16,  8.65it/s][A
Epoch 3 in training:  10%|▉         | 15/157 [00:01<00:16,  8.62it

Epoch 3/10 loss: 1.58



Epoch 4 in training:   0%|          | 0/157 [00:00<?, ?it/s][A
Epoch 4 in training:   1%|▏         | 2/157 [00:00<00:14, 10.71it/s][A
Epoch 4 in training:   3%|▎         | 4/157 [00:00<00:16,  9.37it/s][A
Epoch 4 in training:   3%|▎         | 5/157 [00:00<00:16,  9.12it/s][A
Epoch 4 in training:   4%|▍         | 6/157 [00:00<00:16,  8.96it/s][A
Epoch 4 in training:   4%|▍         | 7/157 [00:00<00:16,  8.85it/s][A
Epoch 4 in training:   5%|▌         | 8/157 [00:00<00:16,  8.79it/s][A
Epoch 4 in training:   6%|▌         | 9/157 [00:01<00:16,  8.74it/s][A
Epoch 4 in training:   6%|▋         | 10/157 [00:01<00:16,  8.71it/s][A
Epoch 4 in training:   7%|▋         | 11/157 [00:01<00:17,  8.55it/s][A
Epoch 4 in training:   8%|▊         | 12/157 [00:01<00:16,  8.62it/s][A
Epoch 4 in training:   8%|▊         | 13/157 [00:01<00:16,  8.63it/s][A
Epoch 4 in training:   9%|▉         | 14/157 [00:01<00:16,  8.64it/s][A
Epoch 4 in training:  10%|▉         | 15/157 [00:01<00:16,  8.64it

Epoch 4/10 loss: 1.44



Epoch 5 in training:   0%|          | 0/157 [00:00<?, ?it/s][A
Epoch 5 in training:   1%|▏         | 2/157 [00:00<00:14, 10.81it/s][A
Epoch 5 in training:   3%|▎         | 4/157 [00:00<00:16,  9.45it/s][A
Epoch 5 in training:   3%|▎         | 5/157 [00:00<00:16,  9.07it/s][A
Epoch 5 in training:   4%|▍         | 6/157 [00:00<00:17,  8.85it/s][A
Epoch 5 in training:   4%|▍         | 7/157 [00:00<00:17,  8.58it/s][A
Epoch 5 in training:   5%|▌         | 8/157 [00:00<00:17,  8.62it/s][A
Epoch 5 in training:   6%|▌         | 9/157 [00:01<00:17,  8.62it/s][A
Epoch 5 in training:   6%|▋         | 10/157 [00:01<00:17,  8.64it/s][A
Epoch 5 in training:   7%|▋         | 11/157 [00:01<00:16,  8.61it/s][A
Epoch 5 in training:   8%|▊         | 12/157 [00:01<00:16,  8.62it/s][A
Epoch 5 in training:   8%|▊         | 13/157 [00:01<00:16,  8.61it/s][A
Epoch 5 in training:   9%|▉         | 14/157 [00:01<00:16,  8.59it/s][A
Epoch 5 in training:  10%|▉         | 15/157 [00:01<00:16,  8.60it

Epoch 5/10 loss: 1.37



Epoch 6 in training:   0%|          | 0/157 [00:00<?, ?it/s][A
Epoch 6 in training:   1%|▏         | 2/157 [00:00<00:14, 10.78it/s][A
Epoch 6 in training:   3%|▎         | 4/157 [00:00<00:16,  9.24it/s][A
Epoch 6 in training:   3%|▎         | 5/157 [00:00<00:16,  9.11it/s][A
Epoch 6 in training:   4%|▍         | 6/157 [00:00<00:16,  8.97it/s][A
Epoch 6 in training:   4%|▍         | 7/157 [00:00<00:16,  8.85it/s][A
Epoch 6 in training:   5%|▌         | 8/157 [00:00<00:16,  8.78it/s][A
Epoch 6 in training:   6%|▌         | 9/157 [00:01<00:16,  8.74it/s][A
Epoch 6 in training:   6%|▋         | 10/157 [00:01<00:16,  8.71it/s][A
Epoch 6 in training:   7%|▋         | 11/157 [00:01<00:16,  8.69it/s][A
Epoch 6 in training:   8%|▊         | 12/157 [00:01<00:16,  8.68it/s][A
Epoch 6 in training:   8%|▊         | 13/157 [00:01<00:16,  8.68it/s][A
Epoch 6 in training:   9%|▉         | 14/157 [00:01<00:16,  8.64it/s][A
Epoch 6 in training:  10%|▉         | 15/157 [00:01<00:16,  8.63it

Epoch 6/10 loss: 1.28



Epoch 7 in training:   0%|          | 0/157 [00:00<?, ?it/s][A
Epoch 7 in training:   1%|▏         | 2/157 [00:00<00:14, 10.75it/s][A
Epoch 7 in training:   3%|▎         | 4/157 [00:00<00:16,  9.06it/s][A
Epoch 7 in training:   3%|▎         | 5/157 [00:00<00:17,  8.84it/s][A
Epoch 7 in training:   4%|▍         | 6/157 [00:00<00:17,  8.85it/s][A
Epoch 7 in training:   4%|▍         | 7/157 [00:00<00:17,  8.79it/s][A
Epoch 7 in training:   5%|▌         | 8/157 [00:00<00:17,  8.74it/s][A
Epoch 7 in training:   6%|▌         | 9/157 [00:01<00:17,  8.70it/s][A
Epoch 7 in training:   6%|▋         | 10/157 [00:01<00:16,  8.67it/s][A
Epoch 7 in training:   7%|▋         | 11/157 [00:01<00:16,  8.66it/s][A
Epoch 7 in training:   8%|▊         | 12/157 [00:01<00:16,  8.65it/s][A
Epoch 7 in training:   8%|▊         | 13/157 [00:01<00:16,  8.49it/s][A
Epoch 7 in training:   9%|▉         | 14/157 [00:01<00:16,  8.46it/s][A
Epoch 7 in training:  10%|▉         | 15/157 [00:01<00:17,  7.99it

Epoch 7/10 loss: 1.19



Epoch 8 in training:   0%|          | 0/157 [00:00<?, ?it/s][A
Epoch 8 in training:   1%|▏         | 2/157 [00:00<00:14, 10.60it/s][A
Epoch 8 in training:   3%|▎         | 4/157 [00:00<00:16,  9.35it/s][A
Epoch 8 in training:   3%|▎         | 5/157 [00:00<00:16,  9.11it/s][A
Epoch 8 in training:   4%|▍         | 6/157 [00:00<00:16,  8.97it/s][A
Epoch 8 in training:   4%|▍         | 7/157 [00:00<00:16,  8.83it/s][A
Epoch 8 in training:   5%|▌         | 8/157 [00:00<00:16,  8.78it/s][A
Epoch 8 in training:   6%|▌         | 9/157 [00:01<00:16,  8.73it/s][A
Epoch 8 in training:   6%|▋         | 10/157 [00:01<00:16,  8.71it/s][A
Epoch 8 in training:   7%|▋         | 11/157 [00:01<00:16,  8.69it/s][A
Epoch 8 in training:   8%|▊         | 12/157 [00:01<00:16,  8.66it/s][A
Epoch 8 in training:   8%|▊         | 13/157 [00:01<00:16,  8.65it/s][A
Epoch 8 in training:   9%|▉         | 14/157 [00:01<00:16,  8.64it/s][A
Epoch 8 in training:  10%|▉         | 15/157 [00:01<00:16,  8.64it

Epoch 8/10 loss: 1.12



Epoch 9 in training:   0%|          | 0/157 [00:00<?, ?it/s][A
Epoch 9 in training:   1%|▏         | 2/157 [00:00<00:16,  9.41it/s][A
Epoch 9 in training:   2%|▏         | 3/157 [00:00<00:16,  9.29it/s][A
Epoch 9 in training:   3%|▎         | 4/157 [00:00<00:16,  9.01it/s][A
Epoch 9 in training:   3%|▎         | 5/157 [00:00<00:17,  8.88it/s][A
Epoch 9 in training:   4%|▍         | 6/157 [00:00<00:17,  8.68it/s][A
Epoch 9 in training:   4%|▍         | 7/157 [00:00<00:18,  7.98it/s][A
Epoch 9 in training:   5%|▌         | 8/157 [00:00<00:18,  8.04it/s][A
Epoch 9 in training:   6%|▌         | 9/157 [00:01<00:18,  8.00it/s][A
Epoch 9 in training:   6%|▋         | 10/157 [00:01<00:17,  8.36it/s][A
Epoch 9 in training:   7%|▋         | 11/157 [00:01<00:17,  8.43it/s][A
Epoch 9 in training:   8%|▊         | 12/157 [00:01<00:17,  8.48it/s][A
Epoch 9 in training:   8%|▊         | 13/157 [00:01<00:16,  8.52it/s][A
Epoch 9 in training:   9%|▉         | 14/157 [00:01<00:16,  8.54it/

Epoch 9/10 loss: 1.05



Epoch 10 in training:   0%|          | 0/157 [00:00<?, ?it/s][A
Epoch 10 in training:   1%|▏         | 2/157 [00:00<00:14, 10.49it/s][A
Epoch 10 in training:   3%|▎         | 4/157 [00:00<00:16,  9.16it/s][A
Epoch 10 in training:   3%|▎         | 5/157 [00:00<00:16,  9.02it/s][A
Epoch 10 in training:   4%|▍         | 6/157 [00:00<00:16,  8.90it/s][A
Epoch 10 in training:   4%|▍         | 7/157 [00:00<00:17,  8.82it/s][A
Epoch 10 in training:   5%|▌         | 8/157 [00:00<00:17,  8.75it/s][A
Epoch 10 in training:   6%|▌         | 9/157 [00:01<00:16,  8.72it/s][A
Epoch 10 in training:   6%|▋         | 10/157 [00:01<00:16,  8.69it/s][A
Epoch 10 in training:   7%|▋         | 11/157 [00:01<00:16,  8.67it/s][A
Epoch 10 in training:   8%|▊         | 12/157 [00:01<00:16,  8.65it/s][A
Epoch 10 in training:   8%|▊         | 13/157 [00:01<00:16,  8.62it/s][A
Epoch 10 in training:   9%|▉         | 14/157 [00:01<00:16,  8.62it/s][A
Epoch 10 in training:  10%|▉         | 15/157 [00:01<

Epoch 10/10 loss: 0.98


Testing: 100%|██████████| 782/782 [00:43<00:00, 17.99it/s]

Test loss: 1.27
Test accuracy: 54.97%





## Q2 Vision Transformer

Function to form patches out of  the images

In [7]:
def make_patches(images, n_patches):
    n, c, h, w = images.shape

    assert h == w, "Method is implemented for square images only"

    patches = tc.zeros(n, n_patches ** 2, h * w * c // n_patches ** 2)
    patch_size = h // n_patches

    for idx, image in enumerate(images):
        for i in range(n_patches):
            for j in range(n_patches):
                patch = image[:, i * patch_size: (i + 1) * patch_size, j * patch_size: (j + 1) * patch_size]
                patches[idx, i * n_patches + j] = patch.flatten()
    return patches

Multi head attention layer

In [8]:

class MSA(nn.Module):
    def __init__(self, d, n_heads=2):
        super(MSA, self).__init__()
        self.d = d
        self.n_heads = n_heads

        assert d % n_heads == 0, f"Can't divide dimension {d} into {n_heads} heads"

        d_head = int(d / n_heads)

        # Creating head number of linear layers for queries, keys and values
        self.q_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        for linear_layer in self.q_mappings:
          nn.init.xavier_uniform_(linear_layer.weight)
        self.k_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        for linear_layer in self.q_mappings:
          nn.init.xavier_uniform_(linear_layer.weight)
        self.v_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        for linear_layer in self.q_mappings:
          nn.init.xavier_uniform_(linear_layer.weight)
        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        N, seq_length, token_dim = sequences.size()
        seq_heads = sequences.view(N, seq_length, self.n_heads, -1).transpose(1, 2)  # (N, n_heads, seq_length, d_head)

        q_heads = [q_mapping(seq_heads[:, i]) for i, q_mapping in enumerate(self.q_mappings)]  # List of (N, seq_length, d_head)
        k_heads = [k_mapping(seq_heads[:, i]) for i, k_mapping in enumerate(self.k_mappings)]  # List of (N, seq_length, d_head)
        v_heads = [v_mapping(seq_heads[:, i]) for i, v_mapping in enumerate(self.v_mappings)]  # List of (N, seq_length, d_head)

        q = tc.cat([q.unsqueeze(2) for q in q_heads], dim=2)  # (N, seq_length, n_heads, d_head)
        k = tc.cat([k.unsqueeze(2) for k in k_heads], dim=2)  # (N, seq_length, n_heads, d_head)
        v = tc.cat([v.unsqueeze(2) for v in v_heads], dim=2)  # (N, seq_length, n_heads, d_head)

        attention = self.softmax(tc.matmul(q, k.transpose(-1, -2)) / (self.d ** 0.5))  # (N, n_heads, seq_length, seq_length)

        attended_values = tc.matmul(attention, v)  # (N, n_heads, seq_length, d_head)

        result = attended_values.transpose(1, 2).reshape(N, seq_length, -1)  # (N, seq_length, d)
        return result

Positional encodings

In [9]:
def get_positional_embeddings(sequence_length, d):
    result = tc.ones(sequence_length, d)
    for i in range(sequence_length):
        for j in range(d):
            result[i][j] = np.sin(i / (10000 ** (j / d))) if j % 2 == 0 else np.cos(i / (10000 ** ((j - 1) / d)))
    return result

Encoder Block

In [10]:
class ViTBlock(nn.Module):
    def __init__(self, hidden_d, n_heads, mlp_ratio=4):
        super(ViTBlock, self).__init__()
        self.hidden_d = hidden_d
        self.n_heads = n_heads

        self.norm1 = nn.LayerNorm(hidden_d)
        self.msa = MSA(hidden_d, n_heads)
        self.norm2 = nn.LayerNorm(hidden_d)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_d, mlp_ratio * hidden_d),
            nn.GELU(),
            nn.Linear(mlp_ratio * hidden_d, hidden_d)
        )
        for idx, module in enumerate(self.mlp):
          if isinstance(module, nn.Linear):
              nn.init.xavier_uniform_(module.weight)


    def forward(self, x):
        out = x + self.msa(self.norm1(x))
        out = out + self.mlp(self.norm2(out))
        return out

Encoder transformer

In [11]:
class ViT(nn.Module):
    def __init__(self, chw, n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10):
        super(ViT, self).__init__()

        self.chw = chw # ( C , H , W )
        self.n_patches = n_patches
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.hidden_d = hidden_d

        assert chw[1] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        assert chw[2] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        self.patch_size = (chw[1] / n_patches, chw[2] / n_patches)

        # Linear layer to map from patches to hidden dim
        self.input_d = int(chw[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)

        # Classification token added as a param
        self.class_token = nn.Parameter(tc.rand(1, self.hidden_d))

        # Adding the constant positional embedding
        self.register_buffer('positional_embeddings', get_positional_embeddings(n_patches ** 2 + 1, hidden_d), persistent=False)

        # Setting number of encoder blocks
        self.blocks = nn.ModuleList([ViTBlock(hidden_d, n_heads) for _ in range(n_blocks)])

        # Mlp at final layer for encoded classification token
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_d, out_d),
            nn.Softmax(dim=-1)
        )

    def forward(self, images):
        n, c, h, w = images.shape
        patches = make_patches(images, self.n_patches).to(self.positional_embeddings.device)

        tokens = self.linear_mapper(patches)

        tokens = tc.cat((self.class_token.expand(n, 1, -1), tokens), dim=1)

        out = tokens + self.positional_embeddings.repeat(n, 1, 1)

        # Passing encoding as input to next layers
        for block in self.blocks:
            out = block(out)

        out = out[:, 0]

        return self.mlp(out)

In [12]:
def main():
    transform = ToTensor()

    train_set = CIFAR10(root="data", train=False, download=True, transform=ToTensor())

    test_set = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
    train_loader = DataLoader(train_set, shuffle=True, batch_size=256)
    test_loader = DataLoader(test_set, shuffle=False, batch_size=256)

    device = tc.device("cuda" if tc.cuda.is_available() else "cpu")
    print("Using device: ", device, f"({tc.cuda.get_device_name(device)})" if tc.cuda.is_available() else "")
    model = ViT((3, 32, 32), n_patches=4, n_blocks=6, hidden_d=32, n_heads=4, out_d=10).to(device)

    N_EPOCHS = 10
    LR = 0.001

    optimizer = optim.Adam(model.parameters(), lr=LR)
    criterion = nn.CrossEntropyLoss()
    for epoch in trange(N_EPOCHS, desc="Training"):
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} in training", leave=False):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = model(x)

            loss = criterion(y_hat, y)

            train_loss += loss.detach().cpu().item() / len(train_loader)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss:.2f}")

    # Test loop
    with tc.no_grad():
        correct, total = 0, 0
        test_loss = 0.0
        for batch in tqdm(test_loader, desc="Testing"):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)
            test_loss += loss.detach().cpu().item() / len(test_loader)

            correct += tc.sum(tc.argmax(y_hat, dim=1) == y).detach().cpu().item()
            total += len(x)
        print(f"Test loss: {test_loss:.2f}")
        print(f"Test accuracy: {correct / total * 100:.2f}%")

In [13]:
tc.manual_seed(80)
main()

Files already downloaded and verified
Files already downloaded and verified
Using device:  cuda (Tesla T4)


Training:   0%|          | 0/10 [00:00<?, ?it/s]
Epoch 1 in training:   0%|          | 0/40 [00:00<?, ?it/s][A
Epoch 1 in training:   2%|▎         | 1/40 [00:00<00:13,  2.89it/s][A
Epoch 1 in training:   5%|▌         | 2/40 [00:00<00:14,  2.68it/s][A
Epoch 1 in training:   8%|▊         | 3/40 [00:01<00:13,  2.72it/s][A
Epoch 1 in training:  10%|█         | 4/40 [00:01<00:12,  2.81it/s][A
Epoch 1 in training:  12%|█▎        | 5/40 [00:01<00:12,  2.69it/s][A
Epoch 1 in training:  15%|█▌        | 6/40 [00:02<00:13,  2.56it/s][A
Epoch 1 in training:  18%|█▊        | 7/40 [00:02<00:11,  2.83it/s][A
Epoch 1 in training:  20%|██        | 8/40 [00:02<00:10,  3.06it/s][A
Epoch 1 in training:  22%|██▎       | 9/40 [00:03<00:09,  3.24it/s][A
Epoch 1 in training:  25%|██▌       | 10/40 [00:03<00:08,  3.41it/s][A
Epoch 1 in training:  28%|██▊       | 11/40 [00:03<00:08,  3.49it/s][A
Epoch 1 in training:  30%|███       | 12/40 [00:03<00:08,  3.50it/s][A
Epoch 1 in training:  32%|███▎   

Epoch 1/10 loss: 2.26



Epoch 2 in training:   0%|          | 0/40 [00:00<?, ?it/s][A
Epoch 2 in training:   2%|▎         | 1/40 [00:00<00:09,  3.93it/s][A
Epoch 2 in training:   5%|▌         | 2/40 [00:00<00:09,  3.85it/s][A
Epoch 2 in training:   8%|▊         | 3/40 [00:00<00:09,  3.80it/s][A
Epoch 2 in training:  10%|█         | 4/40 [00:01<00:09,  3.77it/s][A
Epoch 2 in training:  12%|█▎        | 5/40 [00:01<00:10,  3.28it/s][A
Epoch 2 in training:  15%|█▌        | 6/40 [00:01<00:11,  3.06it/s][A
Epoch 2 in training:  18%|█▊        | 7/40 [00:02<00:11,  2.96it/s][A
Epoch 2 in training:  20%|██        | 8/40 [00:02<00:11,  2.86it/s][A
Epoch 2 in training:  22%|██▎       | 9/40 [00:02<00:11,  2.69it/s][A
Epoch 2 in training:  25%|██▌       | 10/40 [00:03<00:10,  2.77it/s][A
Epoch 2 in training:  28%|██▊       | 11/40 [00:03<00:09,  3.01it/s][A
Epoch 2 in training:  30%|███       | 12/40 [00:03<00:08,  3.17it/s][A
Epoch 2 in training:  32%|███▎      | 13/40 [00:04<00:08,  3.31it/s][A
Epoch 2 i

Epoch 2/10 loss: 2.20



Epoch 3 in training:   0%|          | 0/40 [00:00<?, ?it/s][A
Epoch 3 in training:   2%|▎         | 1/40 [00:00<00:10,  3.89it/s][A
Epoch 3 in training:   5%|▌         | 2/40 [00:00<00:09,  3.83it/s][A
Epoch 3 in training:   8%|▊         | 3/40 [00:00<00:09,  3.83it/s][A
Epoch 3 in training:  10%|█         | 4/40 [00:01<00:09,  3.81it/s][A
Epoch 3 in training:  12%|█▎        | 5/40 [00:01<00:09,  3.81it/s][A
Epoch 3 in training:  15%|█▌        | 6/40 [00:01<00:08,  3.79it/s][A
Epoch 3 in training:  18%|█▊        | 7/40 [00:01<00:08,  3.81it/s][A
Epoch 3 in training:  20%|██        | 8/40 [00:02<00:08,  3.59it/s][A
Epoch 3 in training:  22%|██▎       | 9/40 [00:02<00:09,  3.19it/s][A
Epoch 3 in training:  25%|██▌       | 10/40 [00:02<00:09,  3.07it/s][A
Epoch 3 in training:  28%|██▊       | 11/40 [00:03<00:09,  2.94it/s][A
Epoch 3 in training:  30%|███       | 12/40 [00:03<00:10,  2.77it/s][A
Epoch 3 in training:  32%|███▎      | 13/40 [00:04<00:10,  2.62it/s][A
Epoch 3 i

Epoch 3/10 loss: 2.15



Epoch 4 in training:   0%|          | 0/40 [00:00<?, ?it/s][A
Epoch 4 in training:   2%|▎         | 1/40 [00:00<00:10,  3.65it/s][A
Epoch 4 in training:   5%|▌         | 2/40 [00:00<00:10,  3.71it/s][A
Epoch 4 in training:   8%|▊         | 3/40 [00:00<00:09,  3.78it/s][A
Epoch 4 in training:  10%|█         | 4/40 [00:01<00:09,  3.82it/s][A
Epoch 4 in training:  12%|█▎        | 5/40 [00:01<00:09,  3.75it/s][A
Epoch 4 in training:  15%|█▌        | 6/40 [00:01<00:09,  3.77it/s][A
Epoch 4 in training:  18%|█▊        | 7/40 [00:01<00:08,  3.73it/s][A
Epoch 4 in training:  20%|██        | 8/40 [00:02<00:08,  3.76it/s][A
Epoch 4 in training:  22%|██▎       | 9/40 [00:02<00:08,  3.74it/s][A
Epoch 4 in training:  25%|██▌       | 10/40 [00:02<00:08,  3.74it/s][A
Epoch 4 in training:  28%|██▊       | 11/40 [00:02<00:07,  3.70it/s][A
Epoch 4 in training:  30%|███       | 12/40 [00:03<00:08,  3.31it/s][A
Epoch 4 in training:  32%|███▎      | 13/40 [00:03<00:08,  3.13it/s][A
Epoch 4 i

Epoch 4/10 loss: 2.13



Epoch 5 in training:   0%|          | 0/40 [00:00<?, ?it/s][A
Epoch 5 in training:   2%|▎         | 1/40 [00:00<00:10,  3.70it/s][A
Epoch 5 in training:   5%|▌         | 2/40 [00:00<00:10,  3.68it/s][A
Epoch 5 in training:   8%|▊         | 3/40 [00:00<00:09,  3.78it/s][A
Epoch 5 in training:  10%|█         | 4/40 [00:01<00:09,  3.75it/s][A
Epoch 5 in training:  12%|█▎        | 5/40 [00:01<00:09,  3.73it/s][A
Epoch 5 in training:  15%|█▌        | 6/40 [00:01<00:09,  3.77it/s][A
Epoch 5 in training:  18%|█▊        | 7/40 [00:01<00:08,  3.82it/s][A
Epoch 5 in training:  20%|██        | 8/40 [00:02<00:08,  3.84it/s][A
Epoch 5 in training:  22%|██▎       | 9/40 [00:02<00:08,  3.77it/s][A
Epoch 5 in training:  25%|██▌       | 10/40 [00:02<00:07,  3.81it/s][A
Epoch 5 in training:  28%|██▊       | 11/40 [00:02<00:07,  3.81it/s][A
Epoch 5 in training:  30%|███       | 12/40 [00:03<00:07,  3.85it/s][A
Epoch 5 in training:  32%|███▎      | 13/40 [00:03<00:07,  3.79it/s][A
Epoch 5 i

Epoch 5/10 loss: 2.12



Epoch 6 in training:   0%|          | 0/40 [00:00<?, ?it/s][A
Epoch 6 in training:   2%|▎         | 1/40 [00:00<00:10,  3.79it/s][A
Epoch 6 in training:   5%|▌         | 2/40 [00:00<00:10,  3.66it/s][A
Epoch 6 in training:   8%|▊         | 3/40 [00:00<00:09,  3.73it/s][A
Epoch 6 in training:  10%|█         | 4/40 [00:01<00:09,  3.79it/s][A
Epoch 6 in training:  12%|█▎        | 5/40 [00:01<00:09,  3.68it/s][A
Epoch 6 in training:  15%|█▌        | 6/40 [00:01<00:09,  3.68it/s][A
Epoch 6 in training:  18%|█▊        | 7/40 [00:01<00:08,  3.74it/s][A
Epoch 6 in training:  20%|██        | 8/40 [00:02<00:08,  3.78it/s][A
Epoch 6 in training:  22%|██▎       | 9/40 [00:02<00:08,  3.79it/s][A
Epoch 6 in training:  25%|██▌       | 10/40 [00:02<00:07,  3.76it/s][A
Epoch 6 in training:  28%|██▊       | 11/40 [00:02<00:07,  3.81it/s][A
Epoch 6 in training:  30%|███       | 12/40 [00:03<00:07,  3.78it/s][A
Epoch 6 in training:  32%|███▎      | 13/40 [00:03<00:07,  3.81it/s][A
Epoch 6 i

Epoch 6/10 loss: 2.11



Epoch 7 in training:   0%|          | 0/40 [00:00<?, ?it/s][A
Epoch 7 in training:   2%|▎         | 1/40 [00:00<00:10,  3.67it/s][A
Epoch 7 in training:   5%|▌         | 2/40 [00:00<00:10,  3.76it/s][A
Epoch 7 in training:   8%|▊         | 3/40 [00:00<00:09,  3.73it/s][A
Epoch 7 in training:  10%|█         | 4/40 [00:01<00:09,  3.76it/s][A
Epoch 7 in training:  12%|█▎        | 5/40 [00:01<00:09,  3.76it/s][A
Epoch 7 in training:  15%|█▌        | 6/40 [00:01<00:09,  3.69it/s][A
Epoch 7 in training:  18%|█▊        | 7/40 [00:01<00:09,  3.59it/s][A
Epoch 7 in training:  20%|██        | 8/40 [00:02<00:08,  3.61it/s][A
Epoch 7 in training:  22%|██▎       | 9/40 [00:02<00:08,  3.62it/s][A
Epoch 7 in training:  25%|██▌       | 10/40 [00:02<00:08,  3.63it/s][A
Epoch 7 in training:  28%|██▊       | 11/40 [00:03<00:07,  3.63it/s][A
Epoch 7 in training:  30%|███       | 12/40 [00:03<00:07,  3.65it/s][A
Epoch 7 in training:  32%|███▎      | 13/40 [00:03<00:07,  3.69it/s][A
Epoch 7 i

Epoch 7/10 loss: 2.08



Epoch 8 in training:   0%|          | 0/40 [00:00<?, ?it/s][A
Epoch 8 in training:   2%|▎         | 1/40 [00:00<00:10,  3.80it/s][A
Epoch 8 in training:   5%|▌         | 2/40 [00:00<00:09,  3.84it/s][A
Epoch 8 in training:   8%|▊         | 3/40 [00:00<00:09,  3.87it/s][A
Epoch 8 in training:  10%|█         | 4/40 [00:01<00:09,  3.79it/s][A
Epoch 8 in training:  12%|█▎        | 5/40 [00:01<00:09,  3.79it/s][A
Epoch 8 in training:  15%|█▌        | 6/40 [00:01<00:08,  3.80it/s][A
Epoch 8 in training:  18%|█▊        | 7/40 [00:01<00:08,  3.82it/s][A
Epoch 8 in training:  20%|██        | 8/40 [00:02<00:08,  3.71it/s][A
Epoch 8 in training:  22%|██▎       | 9/40 [00:02<00:08,  3.75it/s][A
Epoch 8 in training:  25%|██▌       | 10/40 [00:02<00:07,  3.76it/s][A
Epoch 8 in training:  28%|██▊       | 11/40 [00:02<00:07,  3.77it/s][A
Epoch 8 in training:  30%|███       | 12/40 [00:03<00:07,  3.73it/s][A
Epoch 8 in training:  32%|███▎      | 13/40 [00:03<00:07,  3.78it/s][A
Epoch 8 i

Epoch 8/10 loss: 2.08



Epoch 9 in training:   0%|          | 0/40 [00:00<?, ?it/s][A
Epoch 9 in training:   2%|▎         | 1/40 [00:00<00:10,  3.65it/s][A
Epoch 9 in training:   5%|▌         | 2/40 [00:00<00:10,  3.77it/s][A
Epoch 9 in training:   8%|▊         | 3/40 [00:00<00:09,  3.80it/s][A
Epoch 9 in training:  10%|█         | 4/40 [00:01<00:09,  3.77it/s][A
Epoch 9 in training:  12%|█▎        | 5/40 [00:01<00:09,  3.74it/s][A
Epoch 9 in training:  15%|█▌        | 6/40 [00:01<00:09,  3.75it/s][A
Epoch 9 in training:  18%|█▊        | 7/40 [00:01<00:08,  3.78it/s][A
Epoch 9 in training:  20%|██        | 8/40 [00:02<00:08,  3.69it/s][A
Epoch 9 in training:  22%|██▎       | 9/40 [00:02<00:08,  3.71it/s][A
Epoch 9 in training:  25%|██▌       | 10/40 [00:02<00:07,  3.75it/s][A
Epoch 9 in training:  28%|██▊       | 11/40 [00:02<00:07,  3.77it/s][A
Epoch 9 in training:  30%|███       | 12/40 [00:03<00:07,  3.76it/s][A
Epoch 9 in training:  32%|███▎      | 13/40 [00:03<00:07,  3.78it/s][A
Epoch 9 i

Epoch 9/10 loss: 2.08



Epoch 10 in training:   0%|          | 0/40 [00:00<?, ?it/s][A
Epoch 10 in training:   2%|▎         | 1/40 [00:00<00:10,  3.68it/s][A
Epoch 10 in training:   5%|▌         | 2/40 [00:00<00:09,  3.80it/s][A
Epoch 10 in training:   8%|▊         | 3/40 [00:00<00:09,  3.84it/s][A
Epoch 10 in training:  10%|█         | 4/40 [00:01<00:09,  3.82it/s][A
Epoch 10 in training:  12%|█▎        | 5/40 [00:01<00:09,  3.80it/s][A
Epoch 10 in training:  15%|█▌        | 6/40 [00:01<00:08,  3.84it/s][A
Epoch 10 in training:  18%|█▊        | 7/40 [00:01<00:08,  3.82it/s][A
Epoch 10 in training:  20%|██        | 8/40 [00:02<00:08,  3.83it/s][A
Epoch 10 in training:  22%|██▎       | 9/40 [00:02<00:08,  3.79it/s][A
Epoch 10 in training:  25%|██▌       | 10/40 [00:02<00:07,  3.82it/s][A
Epoch 10 in training:  28%|██▊       | 11/40 [00:02<00:07,  3.84it/s][A
Epoch 10 in training:  30%|███       | 12/40 [00:03<00:07,  3.82it/s][A
Epoch 10 in training:  32%|███▎      | 13/40 [00:03<00:07,  3.80it/s

Epoch 10/10 loss: 2.06


Testing: 100%|██████████| 196/196 [00:48<00:00,  4.01it/s]

Test loss: 2.08
Test accuracy: 37.95%





### Number of Parameters in each model

In [14]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)



In [19]:
model = ViT((3, 32, 32), n_patches=4, n_blocks=6, hidden_d=32, n_heads=4, out_d=10)
total_params = count_parameters(model)
print("Number of parameters:", total_params)

Number of parameters: 62602


In [17]:
model = Att_CNN()
total_params = count_parameters(model)
print("Number of parameters:", total_params)

Number of parameters: 67330


#### The CNN with attention is giving a better accuracy, this may be due to the low resolution of the input images. VIT would mostl likely perform better when image resolution is high as it would make intuitive sense to attend to portions of the image rather than every single pixel.