In [None]:
import os
import sys

import torch

from pathlib import Path

super_directory = os.path.abspath('..')
sys.path.append(super_directory)

from data_setup import get_dataloaders
from vit import DataEmbeddings

In [13]:
# Device-agnostic setup: use the GPU when PyTorch reports CUDA as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [14]:
# Hyperparameters
BATCH_SIZE = 32

# Input image resolution as (height, width)
IMAGE_SIZE = (224, 224)

# Patches
PATCH_SIZE = (16, 16)
# Number of non-overlapping patches that tile one image. Derived from IMAGE_SIZE
# and PATCH_SIZE with integer arithmetic so it stays in sync if either changes.
NUM_PATCHES = (IMAGE_SIZE[0] // PATCH_SIZE[0]) * (IMAGE_SIZE[1] // PATCH_SIZE[1])

# Patches to Embeddings
EMBED_DIMS = 768

# Number of Attention heads
NUM_ATTENTION_HEADS = 4

# Width of the MLP hidden layer, expressed as a multiple of EMBED_DIMS
RATIO_HIDDEN_MLP = 4

# Data

In [15]:
# Dataset root (relative to this notebook) and its train/test sub-directories
data_path = Path('../data/desserts')
train_path = data_path.joinpath('train')
test_path = data_path.joinpath('test')

In [16]:
# Build the train/test dataloaders; the helper also returns the class labels
train_dataloader, test_dataloader, class_labels = get_dataloaders(
    train_path=train_path,
    test_path=test_path,
    batch_size=BATCH_SIZE,
)

In [17]:
# Take the first batch the training dataloader yields and unpack it into X and y
first_batch = next(iter(train_dataloader))
batch_X, batch_y = first_batch

# Report the shape of the batched images
print(f"Batched images shape: {batch_X.shape} -> (batch_dim, color_channels, image_height, image_width)")

Batched images shape: torch.Size([32, 3, 224, 224]) -> (batch_dim, color_channels, image_height, image_width)


# Encoder Input

In [18]:
# Embedding module for the encoder input, configured from the patch
# hyperparameters; in_channels=3 matches the colour-channel axis of the batch.
data_embed_module = DataEmbeddings(
    in_channels=3,
    patch_size=PATCH_SIZE,
    num_patches=NUM_PATCHES,
    embed_dims=EMBED_DIMS,
)

In [19]:
# Run the image batch through the embedding module
data_embeddings = data_embed_module(batch_X)

# Report the shape of the resulting embeddings
print(f"Data embeddings shape: {data_embeddings.shape} -> (batch_dim, num_patches + class_embedding, embedding_dims)")

Data embeddings shape: torch.Size([32, 197, 768]) -> (batch_dim, num_patches + class_embedding, embedding_dims)


# Encoder Blocks

In [20]:
# Building blocks of a single encoder layer.
# Construction order is kept (norm -> attention -> MLP) so seeded weight
# initialisation consumes the random number generator in the same sequence.

# Layer normalisation over the embedding dimension
layer_norm = torch.nn.LayerNorm(normalized_shape=EMBED_DIMS)

# Multi-head attention operating on (batch, sequence, embedding) inputs
multi_head_attention = torch.nn.MultiheadAttention(
    embed_dim=EMBED_DIMS,
    num_heads=NUM_ATTENTION_HEADS,
    batch_first=True,
)

# MLP: expand to the hidden width, apply GELU, project back to EMBED_DIMS
mlp_hidden_dims = int(EMBED_DIMS * RATIO_HIDDEN_MLP)
multi_layer_perceptron = torch.nn.Sequential(
    torch.nn.Linear(in_features=EMBED_DIMS, out_features=mlp_hidden_dims),
    torch.nn.GELU(),
    torch.nn.Linear(in_features=mlp_hidden_dims, out_features=EMBED_DIMS),
)

In [21]:
# Multi-head attention sub-block: normalise, attend, then add the skip connection
norm_out = layer_norm(data_embeddings)

# Self-attention: the normalised embeddings serve as query, key and value
attn_output, attn_output_weights = multi_head_attention(norm_out, norm_out, norm_out)

# Residual connection around the attention step
residual_out = data_embeddings + attn_output

In [22]:
# Multi-layer perceptron sub-block: normalise, apply the MLP, then add the skip connection.
# A dedicated LayerNorm is used here rather than reusing `layer_norm` from the
# attention sub-block: LayerNorm has learnable scale/shift parameters, so sharing
# one instance would tie the two normalisations together once the block is trained.
# At default initialisation the result is numerically the same as before.
mlp_layer_norm = torch.nn.LayerNorm(normalized_shape=EMBED_DIMS)

norm_out = mlp_layer_norm(residual_out)
mlp_out = multi_layer_perceptron(norm_out)

# Residual connection around the MLP step
enc_out = mlp_out + residual_out

In [23]:
# Report the shape of the encoder block's output
print(
    f"Encoder output shape: {enc_out.shape} -> (batch_dim, num_patches + class_embedding, embedding_dims)"
)

Encoder output shape: torch.Size([32, 197, 768]) -> (batch_dim, num_patches + class_embedding, embedding_dims)
