In [32]:
import os
import sys

import torch
from torchinfo import summary

from pathlib import Path

super_directory = os.path.abspath('..')
sys.path.append(super_directory)

from data_setup import get_dataloaders
from vit import ViT

In [33]:
# Device agnostic code
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [34]:
# Hyperparameters from the paper (except batch size - computational bottleneck)

BATCH_SIZE = 32

# Patches
PATCH_SIZE = (7, 7)
NUM_PATCHES = int((28 / 7) ** 2)

# Patches to Embeddings
EMBED_DIMS = 48

# Number of Attention heads
NUM_ATTENTION_HEADS = 12

# Number of hidden layers in MLP block
RATIO_HIDDEN_MLP = 4

# Number of encoder blocks
NUM_ENC_BLOCKS = 12

# Data

In [35]:
# Data paths
data_path = Path('../data/')

In [36]:
# Getting the dataloaders
train_dataloader, test_dataloader, class_labels = get_dataloaders(batch_size=BATCH_SIZE)

In [37]:
# Get X and y from the first batch
batch_X, batch_y = next(iter(train_dataloader))

print(f"Batched images shape: {batch_X.shape} -> (batch_dim, color_channels, image_height, image_width)")

Batched images shape: torch.Size([32, 1, 28, 28]) -> (batch_dim, color_channels, image_height, image_width)


# ViT

In [38]:
# Initializing the custom created ViT model
model = ViT(in_channels=1,
            out_dims=len(class_labels),
            patch_size=PATCH_SIZE,
            num_patches=NUM_PATCHES,
            embed_dims=EMBED_DIMS,
            num_attn_heads=NUM_ATTENTION_HEADS,
            ratio_hidden_mlp=RATIO_HIDDEN_MLP,
            num_encoder_blocks=NUM_ENC_BLOCKS)

In [39]:
# Model summary
summary(model,
        input_size=(1, 1, 28, 28),                                                # Batch dim, color channels, height, width
        col_names=['input_size', 'output_size', 'num_params', 'trainable'],
        col_width=15,
        row_settings=['var_names'])

Layer (type (var_name))                                 Input Shape     Output Shape    Param #         Trainable
ViT (ViT)                                               [1, 1, 28, 28]  [1, 10]         --              True
├─DataEmbeddings (data_embeddings)                      [1, 1, 28, 28]  [1, 17, 48]     864             True
│    └─Conv2d (conv_layer)                              [1, 1, 28, 28]  [1, 48, 4, 4]   2,400           True
│    └─Flatten (flatten)                                [1, 48, 4, 4]   [1, 48, 16]     --              --
├─Sequential (encoder_blocks)                           [1, 17, 48]     [1, 17, 48]     --              True
│    └─EncoderBlock (0)                                 [1, 17, 48]     [1, 17, 48]     --              True
│    │    └─LayerNorm (layer_norm)                      [1, 17, 48]     [1, 17, 48]     96              True
│    │    └─MultiheadAttention (multi_head_attn)        --              [1, 17, 48]     9,408           True
│    │    └─Laye

In [40]:
# Output of ViT
model = model.to(device)
batch_X = batch_X.to(device)

vit_out = model(batch_X)
print(f"Model output shape: {vit_out.shape} -> (batch_dim, num_classes)")

Model output shape: torch.Size([32, 10]) -> (batch_dim, num_classes)


In [41]:
# Calculating probabilities of each class for the first batch
softmax = torch.nn.Softmax(dim=1)
prob_out = softmax(vit_out)
prob_out

tensor([[0.0611, 0.0810, 0.0973, 0.1126, 0.0414, 0.1122, 0.1472, 0.1249, 0.1191,
         0.1030],
        [0.0665, 0.0845, 0.1014, 0.1094, 0.0426, 0.1035, 0.1502, 0.1241, 0.1191,
         0.0987],
        [0.0657, 0.0869, 0.1065, 0.1038, 0.0449, 0.1034, 0.1560, 0.1210, 0.1162,
         0.0955],
        [0.0598, 0.0860, 0.0936, 0.1097, 0.0417, 0.1183, 0.1442, 0.1293, 0.1141,
         0.1033],
        [0.0633, 0.0810, 0.0912, 0.1148, 0.0395, 0.1141, 0.1448, 0.1274, 0.1208,
         0.1031],
        [0.0638, 0.0835, 0.1001, 0.1099, 0.0415, 0.1059, 0.1531, 0.1217, 0.1215,
         0.0990],
        [0.0631, 0.0842, 0.1029, 0.1077, 0.0453, 0.1124, 0.1437, 0.1288, 0.1123,
         0.0996],
        [0.0625, 0.0815, 0.1018, 0.1089, 0.0438, 0.1157, 0.1406, 0.1260, 0.1174,
         0.1018],
        [0.0630, 0.0831, 0.0998, 0.1124, 0.0419, 0.1068, 0.1478, 0.1248, 0.1190,
         0.1013],
        [0.0637, 0.0791, 0.0949, 0.1156, 0.0409, 0.1130, 0.1415, 0.1274, 0.1208,
         0.1030],
        [0

In [42]:
# Predicted classes from the un-trained model
pred_class_idx = torch.argmax(prob_out, dim=1)
pred_class_idx

tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
        6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')

In [43]:
# Actual classes idx
batch_y

tensor([7, 3, 6, 8, 5, 2, 8, 5, 6, 5, 7, 4, 7, 4, 2, 5, 5, 5, 9, 6, 6, 5, 4, 9,
        2, 4, 6, 1, 5, 3, 2, 9])