In [92]:
import os
import sys

import torch

from pathlib import Path

# Put the parent directory on the module search path before the project-local
# imports that follow (presumably where `data_setup` and `vit` live - confirm).
super_directory = os.path.abspath('..')
# Guard the append so re-running this cell does not pile up duplicate entries.
if super_directory not in sys.path:
    sys.path.append(super_directory)

from data_setup import get_dataloaders
from vit import DataEmbeddings, EncoderBlock

In [93]:
# Pick the compute device: the GPU when CUDA is available, the CPU otherwise.
# NOTE(review): none of the cells visible here move the modules or the batch
# to `device`, so everything below runs wherever the tensors were created.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [94]:
# Hyperparameters
BATCH_SIZE = 32

# Input image size as (height, width) in pixels
IMAGE_SIZE = (224, 224)

# Patch size as (height, width) in pixels
PATCH_SIZE = (16, 16)

# The image must tile exactly into patches, otherwise the patch count below is wrong
assert IMAGE_SIZE[0] % PATCH_SIZE[0] == 0 and IMAGE_SIZE[1] % PATCH_SIZE[1] == 0, \
    "IMAGE_SIZE must be divisible by PATCH_SIZE"

# Number of patches per image, derived from the two sizes above so that it
# stays in sync when either one is changed (14 * 14 = 196 for these values)
NUM_PATCHES = (IMAGE_SIZE[0] // PATCH_SIZE[0]) * (IMAGE_SIZE[1] // PATCH_SIZE[1])

# Patches to Embeddings
EMBED_DIMS = 768

# Number of Attention heads
NUM_ATTENTION_HEADS = 4

# MLP ratio handed to EncoderBlock as `ratio_hidden_mlp`
# (presumably hidden width = ratio * EMBED_DIMS - confirm in vit.py)
RATIO_HIDDEN_MLP = 4

# Data

In [95]:
# Dataset root (relative to the notebook) and its train / test sub-folders
data_path = Path('..') / 'data' / 'desserts'
train_path = data_path.joinpath('train')
test_path = data_path.joinpath('test')

In [96]:
# Build the train and test dataloaders; the call also returns the class labels
train_dataloader, test_dataloader, class_labels = get_dataloaders(
    train_path=train_path,
    test_path=test_path,
    batch_size=BATCH_SIZE,
)

In [97]:
# Pull one batch from the training dataloader and split it into images (X)
# and labels (y)
first_batch = next(iter(train_dataloader))
batch_X, batch_y = first_batch

print(f"Batched images shape: {batch_X.shape} -> (batch_dim, color_channels, image_height, image_width)")

Batched images shape: torch.Size([32, 3, 224, 224]) -> (batch_dim, color_channels, image_height, image_width)


# Encoder Input

In [98]:
# Embedding module that turns a batch of images into the encoder's input sequence
data_embed_module = DataEmbeddings(
    in_channels=3,
    patch_size=PATCH_SIZE,
    num_patches=NUM_PATCHES,
    embed_dims=EMBED_DIMS,
)

In [99]:
# Run the image batch through the embedding module and report the result's shape.
# The recorded output is (32, 197, 768): per the label printed below, the middle
# axis is the number of patches plus the class embedding.
data_embeddings = data_embed_module(batch_X)
print(f"Data embeddings shape: {data_embeddings.shape} -> (batch_dim, num_patches + class_embedding, embedding_dims)")

Data embeddings shape: torch.Size([32, 197, 768]) -> (batch_dim, num_patches + class_embedding, embedding_dims)


# Encoder

In [100]:
# A single transformer encoder block configured from the hyperparameters above
encoder_block = EncoderBlock(
    embed_dims=EMBED_DIMS,
    num_attn_heads=NUM_ATTENTION_HEADS,
    ratio_hidden_mlp=RATIO_HIDDEN_MLP,
    batch_first=True,
)

In [101]:
# Pass the embedded sequence through the encoder block and report the output shape.
# The recorded output has the same shape as the input embeddings, (32, 197, 768).
encoder_output = encoder_block(data_embeddings)
print(f"Encoder output shape: {encoder_output.shape} -> (batch_dim, num_patches + class_embedding, embedding_dims)")

Encoder output shape: torch.Size([32, 197, 768]) -> (batch_dim, num_patches + class_embedding, embedding_dims)


# Classifier

In [102]:
# Keep only the token at sequence position 0 (the learnable class embedding)
# for every item in the batch, with all embedding dimensions
classifier_input = encoder_output[:, 0]

# Linear head mapping an embedding vector to one score per class
num_classes = len(class_labels)
classifier = torch.nn.Linear(in_features=EMBED_DIMS, out_features=num_classes)

In [103]:
# Output of ViT: apply the linear classifier to the class-token embeddings.
# These are unnormalized per-class scores (softmax is applied in a later cell);
# the recorded shape is (32, 5).
vit_out = classifier(classifier_input)
print(f"Model output shape: {vit_out.shape} -> (batch_dim, num_classes)")

Model output shape: torch.Size([32, 5]) -> (batch_dim, num_classes)


In [104]:
# Calculating probabilities of each class for the first batch.
# Softmax is taken over dim=1, the class axis of the (batch_dim, num_classes)
# scores, so every row of `prob_out` sums to 1.
softmax = torch.nn.Softmax(dim=1)
prob_out = softmax(vit_out)
# Bare last expression: the notebook displays the whole probability tensor
prob_out

tensor([[0.2159, 0.3264, 0.0294, 0.1021, 0.3262],
        [0.2192, 0.3253, 0.0287, 0.1010, 0.3257],
        [0.2189, 0.3253, 0.0286, 0.1011, 0.3262],
        [0.2295, 0.3234, 0.0277, 0.0982, 0.3211],
        [0.2206, 0.3250, 0.0286, 0.1002, 0.3256],
        [0.2231, 0.3260, 0.0284, 0.0995, 0.3230],
        [0.2176, 0.3256, 0.0289, 0.1014, 0.3264],
        [0.2134, 0.3275, 0.0297, 0.1021, 0.3274],
        [0.2163, 0.3265, 0.0289, 0.1015, 0.3268],
        [0.2137, 0.3264, 0.0295, 0.1022, 0.3283],
        [0.2222, 0.3260, 0.0280, 0.0998, 0.3241],
        [0.2143, 0.3262, 0.0291, 0.1024, 0.3280],
        [0.2201, 0.3248, 0.0283, 0.1008, 0.3261],
        [0.2120, 0.3274, 0.0297, 0.1024, 0.3285],
        [0.2122, 0.3267, 0.0297, 0.1027, 0.3287],
        [0.2217, 0.3255, 0.0279, 0.0999, 0.3250],
        [0.2094, 0.3270, 0.0299, 0.1035, 0.3302],
        [0.2219, 0.3240, 0.0283, 0.1000, 0.3258],
        [0.2256, 0.3239, 0.0278, 0.0994, 0.3232],
        [0.2251, 0.3246, 0.0281, 0.0990, 0.3232],


In [105]:
# Index of the highest-probability class per sample; the model is un-trained,
# so these predictions are not expected to match the labels
pred_class_idx = prob_out.argmax(dim=1)
pred_class_idx

tensor([1, 4, 4, 1, 4, 1, 4, 1, 4, 4, 1, 4, 4, 4, 4, 1, 4, 4, 1, 1, 4, 1, 1, 1,
        1, 1, 4, 1, 1, 1, 4, 1])

In [106]:
# Ground-truth class indices for the same batch, shown for comparison with the
# predicted indices in `pred_class_idx`
batch_y

tensor([0, 3, 0, 0, 0, 1, 0, 1, 1, 0, 4, 3, 2, 4, 3, 2, 1, 3, 1, 3, 3, 4, 1, 1,
        2, 0, 1, 2, 0, 1, 0, 2])