In [1]:
import os
import sys

import torch

from pathlib import Path

super_directory = os.path.abspath('..')
sys.path.append(super_directory)

from data_setup import get_dataloaders
from vit import DataEmbeddings, EncoderBlock

In [2]:
# Device agnostic code
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
# Hyperparameters
BATCH_SIZE = 32

# Patches
PATCH_SIZE = (7, 7)
NUM_PATCHES = int((28 / 7) ** 2)

# Patches to Embeddings
EMBED_DIMS = 48

# Number of Attention heads
NUM_ATTENTION_HEADS = 12

# Number of hidden layers in MLP block
RATIO_HIDDEN_MLP = 4

# Data

In [4]:
# Data paths
data_path = Path('../data/')

In [5]:
# Getting the dataloaders
train_dataloader, test_dataloader, class_labels = get_dataloaders(batch_size=BATCH_SIZE)

In [6]:
# Get X and y from the first batch
batch_X, batch_y = next(iter(train_dataloader))

print(f"Batched images shape: {batch_X.shape} -> (batch_dim, color_channels, image_height, image_width)")

Batched images shape: torch.Size([32, 1, 28, 28]) -> (batch_dim, color_channels, image_height, image_width)


# Encoder Input

In [7]:
# Module to process data embeddings
data_embed_module = DataEmbeddings(in_channels=1,
                                   patch_size=PATCH_SIZE,
                                   num_patches=NUM_PATCHES,
                                   embed_dims=EMBED_DIMS)

In [8]:
# Data embeddings and its shape
data_embeddings = data_embed_module(batch_X)
print(f"Data embeddings shape: {data_embeddings.shape} -> (batch_dim, num_patches + class_embedding, embedding_dims)")

Data embeddings shape: torch.Size([32, 17, 48]) -> (batch_dim, num_patches + class_embedding, embedding_dims)


# Encoder

In [9]:
# Transformer encoder module
encoder_block = EncoderBlock(embed_dims=EMBED_DIMS,
                             num_attn_heads=NUM_ATTENTION_HEADS,
                             ratio_hidden_mlp=RATIO_HIDDEN_MLP,
                             batch_first=True)

In [10]:
# Transformer encoder output
encoder_output = encoder_block(data_embeddings)
print(f"Encoder output shape: {encoder_output.shape} -> (batch_dim, num_patches + class_embedding, embedding_dims)")

Encoder output shape: torch.Size([32, 17, 48]) -> (batch_dim, num_patches + class_embedding, embedding_dims)


# Classifier

In [11]:
# Selecting the learnable embedding from the embeddings (include all batches and dimensions)
classifier_input = encoder_output[:, 0, :]

# Classifier which has number of classes as output
classifier = torch.nn.Linear(in_features=EMBED_DIMS,
                             out_features=len(class_labels))

In [12]:
# Output of ViT
vit_out = classifier(classifier_input)
print(f"Model output shape: {vit_out.shape} -> (batch_dim, num_classes)")

Model output shape: torch.Size([32, 10]) -> (batch_dim, num_classes)


In [13]:
# Calculating probabilities of each class for the first batch
softmax = torch.nn.Softmax(dim=1)
prob_out = softmax(vit_out)
prob_out

tensor([[0.0592, 0.1015, 0.0424, 0.0327, 0.0894, 0.1513, 0.0891, 0.0377, 0.1189,
         0.2778],
        [0.0601, 0.1020, 0.0424, 0.0324, 0.0903, 0.1511, 0.0880, 0.0378, 0.1146,
         0.2813],
        [0.0593, 0.1019, 0.0430, 0.0324, 0.0898, 0.1496, 0.0887, 0.0369, 0.1181,
         0.2803],
        [0.0559, 0.1046, 0.0444, 0.0341, 0.0894, 0.1437, 0.0910, 0.0410, 0.1225,
         0.2733],
        [0.0574, 0.1012, 0.0427, 0.0325, 0.0902, 0.1514, 0.0887, 0.0367, 0.1166,
         0.2825],
        [0.0575, 0.1026, 0.0420, 0.0337, 0.0900, 0.1501, 0.0891, 0.0358, 0.1181,
         0.2812],
        [0.0596, 0.1006, 0.0423, 0.0326, 0.0900, 0.1536, 0.0891, 0.0365, 0.1141,
         0.2816],
        [0.0578, 0.1030, 0.0424, 0.0332, 0.0907, 0.1479, 0.0904, 0.0388, 0.1209,
         0.2749],
        [0.0592, 0.1018, 0.0420, 0.0329, 0.0902, 0.1510, 0.0883, 0.0372, 0.1191,
         0.2784],
        [0.0601, 0.1005, 0.0420, 0.0317, 0.0896, 0.1532, 0.0870, 0.0359, 0.1169,
         0.2831],
        [0

In [14]:
# Predicted classes from the un-trained model
pred_class_idx = torch.argmax(prob_out, dim=1)
pred_class_idx

tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
        9, 9, 9, 9, 9, 9, 9, 9])

In [15]:
# Actual classes idx
batch_y

tensor([9, 1, 7, 2, 4, 1, 3, 8, 8, 7, 5, 7, 6, 2, 2, 2, 0, 1, 1, 2, 7, 1, 3, 1,
        7, 9, 5, 1, 6, 9, 1, 1])