In [14]:
# Imports: standard library, then third-party, then local project modules.
import os
import sys
from pathlib import Path

import torch
from torchinfo import summary

# The local modules imported below (data_setup, vit) are expected to be
# importable from the parent directory. os.path.abspath resolves '..' against
# the current working directory, so this assumes the kernel was started from
# the notebook's own folder.
super_directory = os.path.abspath('..')
# Guard so that re-running this cell does not keep appending the same entry.
if super_directory not in sys.path:
    sys.path.append(super_directory)

from data_setup import get_dataloaders
from vit import ViT

In [15]:
# Device-agnostic setup: use the GPU when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [16]:
# Hyperparameters for ViT-Base/16 from the paper (Dosovitskiy et al., Table 1),
# except the batch size, which is reduced because of limited compute.

# Reproducibility: seed torch's global RNG (CPU and CUDA) so weight
# initialisation and any torch-driven shuffling are repeatable.
SEED = 42
torch.manual_seed(SEED)

BATCH_SIZE = 32

# Input image size (height, width); the dataloaders in this notebook yield 224x224 images
IMAGE_SIZE = (224, 224)

# Patches
PATCH_SIZE = (16, 16)
assert IMAGE_SIZE[0] % PATCH_SIZE[0] == 0 and IMAGE_SIZE[1] % PATCH_SIZE[1] == 0, \
    "image size must be divisible by the patch size"
# Integer arithmetic, derived from the constants above: (224 // 16) * (224 // 16) = 196
NUM_PATCHES = (IMAGE_SIZE[0] // PATCH_SIZE[0]) * (IMAGE_SIZE[1] // PATCH_SIZE[1])

# Patches to Embeddings (hidden size D)
EMBED_DIMS = 768

# Number of attention heads (ViT-Base uses 12 -> 64 dims per head)
NUM_ATTENTION_HEADS = 12
assert EMBED_DIMS % NUM_ATTENTION_HEADS == 0, "embedding dim must split evenly across heads"

# Ratio of the MLP hidden width to the embedding dim (768 * 4 = 3072 in ViT-Base)
RATIO_HIDDEN_MLP = 4

# Number of encoder blocks (transformer layers)
NUM_ENC_BLOCKS = 12

# Data

In [17]:
# Dataset layout: ../data/desserts/{train,test}
data_path = Path('..', 'data', 'desserts')
train_path, test_path = (data_path / split for split in ('train', 'test'))

In [18]:
# Build the train/test dataloaders from the split folders; the helper also
# returns the class labels, whose length sets the model's output size later.
train_dataloader, test_dataloader, class_labels = get_dataloaders(
    train_path=train_path,
    test_path=test_path,
    batch_size=BATCH_SIZE,
)

In [19]:
# Pull a single (images, labels) batch from the training loader to check the
# tensor layout the model will receive.
batch_X, batch_y = next(iter(train_dataloader))

print(
    f"Batched images shape: {batch_X.shape} -> (batch_dim, color_channels, image_height, image_width)"
)

Batched images shape: torch.Size([32, 3, 224, 224]) -> (batch_dim, color_channels, image_height, image_width)


# ViT

In [20]:
# Instantiate the from-scratch ViT using the hyperparameters defined above.
model = ViT(
    in_channels=3,                          # colour channels of the input images
    out_dims=len(class_labels),             # one output per class label
    patch_size=PATCH_SIZE,
    num_patches=NUM_PATCHES,
    embed_dims=EMBED_DIMS,
    num_attn_heads=NUM_ATTENTION_HEADS,
    ratio_hidden_mlp=RATIO_HIDDEN_MLP,
    num_encoder_blocks=NUM_ENC_BLOCKS,
)

In [21]:
# Model summary for a single image with the same (channels, height, width) as
# the data batch, so the summary stays in sync with the dataloader transforms
# instead of relying on a hard-coded 224x224 size.
summary(model,
        input_size=(1, *batch_X.shape[1:]),          # (batch_dim=1, color_channels, height, width)
        col_names=['input_size', 'output_size', 'num_params', 'trainable'],
        col_width=20,
        row_settings=['var_names'])

Layer (type (var_name))                                 Input Shape          Output Shape         Param #              Trainable
ViT (ViT)                                               [1, 3, 224, 224]     [1, 5]               --                   True
├─DataEmbeddings (data_embeddings)                      [1, 3, 224, 224]     [1, 197, 768]        152,064              True
│    └─Conv2d (conv_layer)                              [1, 3, 224, 224]     [1, 768, 14, 14]     590,592              True
│    └─Flatten (flatten)                                [1, 768, 14, 14]     [1, 768, 196]        --                   --
├─Sequential (encoder_blocks)                           [1, 197, 768]        [1, 197, 768]        --                   True
│    └─EncoderBlock (0)                                 [1, 197, 768]        [1, 197, 768]        --                   True
│    │    └─LayerNorm (layer_norm)                      [1, 197, 768]        [1, 197, 768]        1,536                True
│    

In [22]:
# Forward pass of the untrained model on the first batch, as a shape sanity check.
# No gradients are needed here, so run under no_grad to avoid building (and
# holding in memory) an autograd graph for the whole batch.
model = model.to(device)
batch_X = batch_X.to(device)

with torch.no_grad():
    vit_out = model(batch_X)

print(f"Model output shape: {vit_out.shape} -> (batch_dim, num_classes)")

Model output shape: torch.Size([32, 5]) -> (batch_dim, num_classes)


In [23]:
# Softmax over dim=1 (the class axis) turns each row of model outputs into a
# probability distribution over the classes for that image.
softmax = torch.nn.Softmax(dim=1)
prob_out = softmax(vit_out)

prob_out

tensor([[0.1518, 0.0820, 0.4683, 0.1372, 0.1607],
        [0.1531, 0.0779, 0.4715, 0.1337, 0.1637],
        [0.1568, 0.0798, 0.4683, 0.1334, 0.1617],
        [0.1559, 0.0843, 0.4532, 0.1444, 0.1622],
        [0.1586, 0.0842, 0.4471, 0.1477, 0.1623],
        [0.1421, 0.0826, 0.4810, 0.1357, 0.1586],
        [0.1529, 0.0863, 0.4492, 0.1514, 0.1601],
        [0.1488, 0.0906, 0.4423, 0.1579, 0.1605],
        [0.1283, 0.0906, 0.4905, 0.1365, 0.1541],
        [0.1470, 0.0805, 0.4788, 0.1332, 0.1606],
        [0.1577, 0.0816, 0.4573, 0.1410, 0.1623],
        [0.1366, 0.0889, 0.4789, 0.1395, 0.1561],
        [0.1473, 0.0866, 0.4650, 0.1430, 0.1582],
        [0.1387, 0.0852, 0.4831, 0.1366, 0.1564],
        [0.1339, 0.0931, 0.4711, 0.1456, 0.1563],
        [0.1430, 0.0861, 0.4633, 0.1488, 0.1587],
        [0.1509, 0.0795, 0.4782, 0.1300, 0.1613],
        [0.1469, 0.0831, 0.4740, 0.1367, 0.1594],
        [0.1653, 0.0798, 0.4536, 0.1363, 0.1649],
        [0.1570, 0.0837, 0.4572, 0.1380, 0.1641],


In [24]:
# Untrained model's prediction: index of the highest-probability class per image
pred_class_idx = prob_out.argmax(dim=1)

pred_class_idx

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')

In [25]:
# Actual (ground-truth) class indices for the same batch.
# NOTE: batch_y is never moved to `device`, so it stays on the CPU while
# pred_class_idx lives on `device`; move one of them before comparing.
batch_y

tensor([3, 2, 3, 0, 4, 0, 4, 4, 0, 2, 0, 1, 1, 1, 3, 1, 2, 1, 3, 3, 4, 0, 4, 1,
        2, 0, 1, 0, 4, 4, 1, 4])