<a href="https://colab.research.google.com/github/ankitgoelcmu/DeepLearning/blob/main/Transformer_for_Image_Recognition_VIT_Paper_Implementation_Exercise.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# In this notebook, I am replicating the ViT paper: [AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE](https://arxiv.org/pdf/2010.11929)

In this notebook, I will build the Vision Transformer (ViT) model as laid out in the paper.

# In this notebook, I am implementing 2 iterations of this ViT model

## First Iteration: I use PyTorch's out-of-the-box `TransformerEncoderLayer` class https://docs.pytorch.org/docs/stable/generated/torch.nn.TransformerEncoderLayer.html

## Second Iteration: I implement the Multi Head Block and the MLP Block, represented by Equations 2 and 3 of the paper, and then use those 2 building blocks to create the Transformer Encoder Layer.
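For reference, Equations 1 to 4 of the paper describe the whole forward pass. Here $N = HW/P^2$ is the number of patches, $P$ the patch size, $\mathbf{E}$ the patch-embedding projection, $\mathbf{E}_{pos}$ the position embeddings and $L$ the number of encoder layers. Equations 2 and 3 are the Multi Head Block and MLP Block built in the second iteration:

$$
\begin{aligned}
\mathbf{z}_0 &= [\mathbf{x}_\text{class};\ \mathbf{x}_p^1\mathbf{E};\ \mathbf{x}_p^2\mathbf{E};\ \cdots;\ \mathbf{x}_p^N\mathbf{E}] + \mathbf{E}_{pos} && (1)\\
\mathbf{z}'_\ell &= \mathrm{MSA}(\mathrm{LN}(\mathbf{z}_{\ell-1})) + \mathbf{z}_{\ell-1}, \quad \ell = 1 \ldots L && (2)\\
\mathbf{z}_\ell &= \mathrm{MLP}(\mathrm{LN}(\mathbf{z}'_\ell)) + \mathbf{z}'_\ell, \quad \ell = 1 \ldots L && (3)\\
\mathbf{y} &= \mathrm{LN}(\mathbf{z}_L^0) && (4)
\end{aligned}
$$

For ViT-B/16 at a 224x224 resolution, $N = 224 \cdot 224 / 16^2 = 196$ and $D = 768$, so each $\mathbf{z}_\ell$ is a $197 \times 768$ matrix (196 patch tokens plus the class token).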

### Get various imports and helper functions

In [None]:
# For this notebook to run with updated APIs, we need torch 1.12+ and torchvision 0.13+
try:
    import torch
    import torchvision
    # Compare (major, minor) tuples so that e.g. torch 2.1 correctly passes the 1.12+ check
    torch_version = tuple(int(v) for v in torch.__version__.split(".")[:2])
    torchvision_version = tuple(int(v) for v in torchvision.__version__.split(".")[:2])
    assert torch_version >= (1, 12), "torch version should be 1.12+"
    assert torchvision_version >= (0, 13), "torchvision version should be 0.13+"
    print(f"torch version: {torch.__version__}")
    print(f"torchvision version: {torchvision.__version__}")
except (ImportError, AssertionError):
    print("[INFO] torch/torchvision versions not as required, installing nightly versions.")
    !pip3 install -U --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu113
    import torch
    import torchvision
    print(f"torch version: {torch.__version__}")
    print(f"torchvision version: {torchvision.__version__}")


In [None]:
# Continue with regular imports
import matplotlib.pyplot as plt
import torch
import torchvision

from torch import nn
from torchvision import transforms

# Try to get torchinfo, install it if it doesn't work
try:
    from torchinfo import summary
except ImportError:
    print("[INFO] Couldn't find torchinfo... installing it.")
    !pip install -q torchinfo
    from torchinfo import summary

# Try to import the going_modular directory, download it from GitHub if it doesn't work
try:
    from going_modular.going_modular import data_setup, engine
    from helper_functions import download_data, set_seeds, plot_loss_curves
except ImportError:
    # Get the going_modular scripts
    print("[INFO] Couldn't find going_modular or helper_functions scripts... downloading them from GitHub.")
    !git clone https://github.com/mrdbourke/pytorch-deep-learning
    !mv pytorch-deep-learning/going_modular .
    !mv pytorch-deep-learning/helper_functions.py . # get the helper_functions.py script
    !rm -rf pytorch-deep-learning
    from going_modular.going_modular import data_setup, engine
    from helper_functions import download_data, set_seeds, plot_loss_curves

# Check the `device`

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

# Get data

We download the pizza, steak and sushi images used in the PyTorch Paper Replicating chapter: https://www.learnpytorch.io/08_pytorch_paper_replicating/#1-get-data

In [None]:
# Download pizza, steak, sushi images from GitHub
image_path = download_data(source="https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip",
                           destination="pizza_steak_sushi")
image_path

In [None]:
# Setup directory paths to train and test images
train_dir = image_path / "train"
test_dir = image_path / "test"

### Preprocess data

Turn the images into tensors. The Table 3 caption in the paper gives the training resolution (`Training resolution is 224.`), so every image is resized to 224x224 before being converted to a tensor.

In [None]:
# Set the image size (from Table 3 in the ViT paper)
IMG_SIZE = 224

# Create transform pipeline manually
manual_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
])
print(f"Manually created transforms: {manual_transforms}")
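With a 224x224 input and the 16x16 patches of ViT-B/16, the length of the token sequence the encoder sees can be worked out directly. A small worked example (the variable names are illustrative and not part of the notebook's pipeline):

```python
img_size = 224        # training resolution from Table 3
patch_size = 16       # the "16" in ViT-B/16
color_channels = 3    # RGB input

patches_per_side = img_size // patch_size               # 224 / 16 = 14
num_patches = patches_per_side ** 2                     # 14 * 14 = 196 patches (N)
seq_len = num_patches + 1                               # + 1 learnable [class] token = 197
values_per_patch = patch_size * patch_size * color_channels  # 16 * 16 * 3 = 768 raw values per patch

print(num_patches, seq_len, values_per_patch)  # 196 197 768
```

Each flattened RGB patch happens to hold 768 values, the same as ViT-Base's embedding dimension D, but that is a coincidence of these sizes: the patch embedding is a learned projection from P²·C to D either way.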

In [None]:
# Set the batch size
# Note: the ViT paper mentions (in the Table 3 caption) that "All models are trained with a batch size of 4096."
# Our dataset is far smaller, so we start small with a batch size of 32
BATCH_SIZE = 32

# Create Data Loaders. The DataLoader provides features like automatic batching, shuffling, and multiprocessing for efficient data loading.
train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(
    train_dir=train_dir,
    test_dir=test_dir,
    transform=manual_transforms, # use manually created transforms
    batch_size=BATCH_SIZE
)

train_dataloader, test_dataloader, class_names

In [None]:
# Get a batch of images
image_batch, label_batch = next(iter(train_dataloader))

# Get a single image from the batch
image, label = image_batch[0], label_batch[0]

# View the single image's shape and its label
image.shape, label

In [None]:
# Plot image with matplotlib
plt.imshow(image.permute(1, 2, 0)) # rearrange image dimensions to suit matplotlib [color_channels, height, width] -> [height, width, color_channels]
plt.title(class_names[label])
plt.axis(False);
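The ViT models in this notebook embed patches with a strided `nn.Conv2d`, which is equivalent to cutting the image into non-overlapping 16x16 patches, flattening each one and applying a shared linear projection. To make the sequence of flattened patches $\mathbf{x}_p$ from Equation 1 explicit, here is a minimal sketch of the split using `Tensor.unfold` (the helper `image_to_patches` is hypothetical, and a random tensor stands in for a real image):

```python
import torch

def image_to_patches(img: torch.Tensor, patch_size: int = 16) -> torch.Tensor:
    """Split a [C, H, W] image into a [num_patches, C * P * P] matrix of flattened patches."""
    c, h, w = img.shape
    assert h % patch_size == 0 and w % patch_size == 0, "H and W must be divisible by patch_size"
    # unfold height, then width: [C, H/P, W/P, P, P]
    patches = img.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
    # reorder to [H/P, W/P, C, P, P] so patches are listed row by row, then flatten each patch
    return patches.permute(1, 2, 0, 3, 4).reshape(-1, c * patch_size * patch_size)

demo_image = torch.randn(3, 224, 224)  # stand-in for one [color_channels, height, width] image
patch_matrix = image_to_patches(demo_image)
print(patch_matrix.shape)  # torch.Size([196, 768])
```

Row 0 is the top-left 16x16 patch and row 1 the patch to its right, matching the row-major order in which `nn.Flatten(start_dim=2, end_dim=3)` lays out the Conv2d feature map.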

## 1. For this ViT architecture, I will use the built-in [PyTorch transformer layers](https://pytorch.org/docs/stable/nn.html#transformer-layers).

* [`torch.nn.TransformerEncoderLayer()`](https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoderLayer.html#torch.nn.TransformerEncoderLayer) contains the same layers as the paper's Transformer encoder block: multi-head self-attention and an MLP, each preceded by LayerNorm (with `norm_first=True`) and wrapped in a residual connection.
* You can stack `torch.nn.TransformerEncoderLayer()`'s on top of each other with [`torch.nn.TransformerEncoder()`](https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoder.html#torch.nn.TransformerEncoder).
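A small sketch of how the two classes in the bullets above fit together, using the ViT-Base sizes from Table 1 of the paper (D = 768, 12 heads, MLP size 3072). It checks that the stack keeps the `[batch, sequence, embedding]` shape, and that `nn.TransformerEncoder` holds `num_layers` independent copies of the layer it is given, so the parameter count grows linearly with depth (the template layer object itself is not called in the forward pass):

```python
import torch
from torch import nn

def param_count(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

# One pre-norm encoder layer with ViT-Base sizes; batch_first=True means [batch, seq, dim]
layer = nn.TransformerEncoderLayer(d_model=768, nhead=12, dim_feedforward=3072,
                                   dropout=0.1, activation="gelu",
                                   batch_first=True, norm_first=True)
# TransformerEncoder stacks independent deep copies of `layer`
encoder = nn.TransformerEncoder(layer, num_layers=2)

tokens = torch.randn(4, 197, 768)  # [batch, 1 class token + 196 patches, embedding_dim]
out = encoder(tokens)

print(param_count(layer))    # 7087872 parameters in one ViT-Base encoder layer
print(param_count(encoder))  # twice that for num_layers=2
print(out.shape)             # torch.Size([4, 197, 768])
```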

In [None]:
class ViT(nn.Module):
  """ViT-Base/16 built on PyTorch's nn.TransformerEncoderLayer."""
  def __init__(self,
               img_size: int=224,
               patch_size: int=16,
               in_channels: int=3,
               mlp_size: int=3072,
               num_heads: int=12,
               num_encoder_layers: int=12,
               embedding_dim: int=768,
               mlp_dropout: float=0.1,
               embedding_dropout: float=0.1,
               num_out_class: int=1000):
    super().__init__()

    assert img_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
    self.num_patches = (img_size // patch_size) ** 2
    self.patch_size = patch_size

    # Learnable [class] token and position embeddings (Equation 1 of the paper)
    self.embedding_class = nn.Parameter(torch.zeros(1, 1, embedding_dim))
    self.embedding_position = nn.Parameter(torch.zeros(1, 1 + self.num_patches, embedding_dim))
    self.embedding_dropout = nn.Dropout(embedding_dropout)

    # A Conv2d with kernel_size == stride == patch_size splits the image into
    # non-overlapping patches and linearly projects each one to embedding_dim
    self.patch = nn.Conv2d(in_channels=in_channels,
                           out_channels=embedding_dim,
                           kernel_size=patch_size,
                           stride=patch_size)
    # [batch, embedding_dim, H/P, W/P] -> [batch, embedding_dim, num_patches]
    self.flatten_patches = nn.Flatten(start_dim=2, end_dim=3)

    # batch_first=True: tensors are (batch, seq, feature) = (images, patch sequence, embedding features)
    # norm_first=True: LayerNorm is applied before attention and MLP, as in Equations 2 and 3
    # nn.TransformerEncoderLayer uses a single dropout rate for the attention, MLP and residual paths
    # Kept as a local variable: TransformerEncoder deep-copies it, so registering it as an
    # attribute would add an extra, unused layer's parameters to the model
    encoder_layer = nn.TransformerEncoderLayer(d_model=embedding_dim,
                                               nhead=num_heads,
                                               dim_feedforward=mlp_size,
                                               dropout=mlp_dropout,
                                               activation='gelu',
                                               batch_first=True,
                                               norm_first=True)
    self.encoder = nn.TransformerEncoder(encoder_layer=encoder_layer,
                                         num_layers=num_encoder_layers)

    # Classification head applied to the [class] token (Equation 4)
    self.classifier = nn.Sequential(nn.LayerNorm(normalized_shape=embedding_dim),
                                    nn.Linear(in_features=embedding_dim, out_features=num_out_class))

  def forward(self, x):
    batch_size = x.shape[0]
    x = self.patch(x)                              # [B, D, H/P, W/P]
    x = self.flatten_patches(x)                    # [B, D, N]
    x = x.permute(0, 2, 1)                         # [B, N, D]
    class_embedding = self.embedding_class.expand(batch_size, -1, -1)
    x = torch.cat([class_embedding, x], dim=1)     # [B, N + 1, D]
    x = self.embedding_position + x
    x = self.embedding_dropout(x)
    x = self.encoder(x)
    x = self.classifier(x[:, 0])                   # logits from the [class] token
    return x

In [None]:
random_image_tensor = torch.randn(1, 3, 224, 224) # (batch_size, color_channels, height, width)
vit = ViT(num_out_class=3)
vit(random_image_tensor)

In [None]:
from torchinfo import summary

# Print a summary of our ViT model using torchinfo
summary(model=vit,
        input_size=(32, 3, 224, 224), # (batch_size, color_channels, height, width)
        # col_names=["input_size"], # uncomment for smaller output
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
)
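The total reported by `torchinfo` can be cross-checked by hand against the 86M parameters that Table 1 of the paper lists for ViT-Base. A sketch of that arithmetic, assuming the layout used in this notebook (Conv2d patch embedding, class token, position embeddings, pre-norm encoder layers, and a LayerNorm + Linear head); the function name `vit_param_count` is hypothetical:

```python
def vit_param_count(num_classes: int,
                    img_size: int = 224, patch_size: int = 16, in_channels: int = 3,
                    embedding_dim: int = 768, mlp_size: int = 3072, num_layers: int = 12) -> int:
    """Count trainable parameters of a ViT laid out like the models in this notebook."""
    D = embedding_dim
    num_patches = (img_size // patch_size) ** 2
    patch_embed = in_channels * patch_size * patch_size * D + D  # Conv2d weight + bias
    class_token = D
    position_embed = (num_patches + 1) * D
    attention = (3 * D * D + 3 * D) + (D * D + D)                # packed q/k/v projection + output projection
    mlp = (D * mlp_size + mlp_size) + (mlp_size * D + D)         # expand and contract Linear layers
    layer_norms = 2 * (2 * D)                                    # two LayerNorms per layer (weight + bias)
    encoder = num_layers * (attention + mlp + layer_norms)
    head = 2 * D + (D * num_classes + num_classes)               # final LayerNorm + Linear
    return patch_embed + class_token + position_embed + encoder + head

print(f"{vit_param_count(num_classes=1000):,}")  # 86,567,656 (about 86M, as in Table 1)
print(f"{vit_param_count(num_classes=3):,}")     # 85,800,963 for our 3-class head
```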

In [None]:
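# Sanity-check the patch-embedding shapes on a single random image
random_image_tensor = torch.randn(1, 3, 224, 224) # (batch_size, color_channels, height, width)
embedding_class = nn.Parameter(torch.zeros(1, 1, 768))
patch_embedding = nn.Conv2d(in_channels=3,
                            out_channels=768,
                            kernel_size=16,
                            stride=16)
patch_features = patch_embedding(random_image_tensor)
print(f"Class token shape: {embedding_class.shape}")        # torch.Size([1, 1, 768])
print(f"Patch feature map shape: {patch_features.shape}")   # torch.Size([1, 768, 14, 14])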


In [None]:
from going_modular.going_modular import engine

optimizer = torch.optim.Adam(params=vit.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()

results = engine.train(model=vit,
                       train_dataloader=train_dataloader,
                       test_dataloader=test_dataloader,
                       optimizer=optimizer,
                       loss_fn=loss_fn,
                       epochs=10,
                       device=device)

In [None]:
from helper_functions import plot_loss_curves

plot_loss_curves(results)

### So far we have used PyTorch's out-of-the-box `TransformerEncoderLayer` to implement the encoder layer of the ViT.


# In this step we will implement the Encoder Layer from a Multi Head Block and an MLP Block.

### `Multi Head Block = LayerNorm(embedded patches) → Multi-Head Attention → Dropout → + embedded patches (residual connection)`

#### **Residual Connection** - Residual connections help stabilize neural network training by providing a direct pathway for information and gradient flow during the forward pass and backpropagation.
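A minimal sketch of why the residual path helps, assuming a pre-norm wrapper of the form `x + f(LayerNorm(x))` as in Equations 2 and 3 (the `Residual` class is a hypothetical helper, not part of the notebook). If the sub-layer `f` outputs nothing useful, here a Linear layer with all-zero weights, the block still passes its input through unchanged, and the gradient with respect to the input is exactly the identity:

```python
import torch
from torch import nn

class Residual(nn.Module):
    """Hypothetical helper: computes x + sublayer(LayerNorm(x)) (pre-norm residual)."""
    def __init__(self, dim: int, sublayer: nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.sublayer = sublayer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.sublayer(self.norm(x))

# A sub-layer that contributes nothing: all-zero weights and bias
dead_layer = nn.Linear(8, 8)
nn.init.zeros_(dead_layer.weight)
nn.init.zeros_(dead_layer.bias)
block = Residual(8, dead_layer)

x = torch.randn(2, 5, 8, requires_grad=True)  # [batch, seq, dim]
y = block(x)
y.sum().backward()

# Output equals input, and d(sum y)/dx is all ones: the skip path alone carries the gradient
print(torch.allclose(y, x), torch.allclose(x.grad, torch.ones_like(x)))  # True True
```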

In [None]:
from torch import nn

# Implementation of the MHA block (Equation 2): x + Dropout(MSA(LayerNorm(x)))
class MultiHeadAttention(nn.Module):
  def __init__(self,
               embedding_dim: int=768,
               num_heads: int=12,
               dropout: float=0.1,
               attn_dropout: float=0.0):
    super().__init__()
    # dropout: applied to the attention output, before the residual add
    # attn_dropout: applied to the attention weights; 0 by default because the ViT paper doesn't mention it
    self.embedding_dim = embedding_dim
    self.num_heads = num_heads

    # Layer Normalization normalizes activations across the feature dimension of each token,
    # so no single feature's scale dominates; this stabilizes training and improves gradient flow.
    self.norm = nn.LayerNorm(embedding_dim)
    # batch_first=True: tensors are (batch, seq, feature) = (batch_size, num_patches + 1, embedding_dim)
    self.attention = nn.MultiheadAttention(embed_dim=embedding_dim,
                                           num_heads=num_heads,
                                           dropout=attn_dropout,
                                           batch_first=True)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    residual_input = x
    x = self.norm(x)
    # Self-attention: query, key and value are all the normalized input.
    # need_weights=False skips returning the attention map, which we don't use.
    x, _ = self.attention(x, x, x, need_weights=False)
    x = self.dropout(x)
    return x + residual_input

In [None]:
import torch
from torch import nn

patched_position_embedded_tensor = torch.randn(1, 197, 768) # (batch_size, patch_seq_length, embedding_dim)

MHA_instance = MultiHeadAttention()
output_after_MHA_block = MHA_instance(patched_position_embedded_tensor)
print(f"Input shape of MHA block: {patched_position_embedded_tensor.shape}")
print(f"Output shape of MHA block: {output_after_MHA_block.shape}")
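To see what `nn.MultiheadAttention` computes inside the block, the sketch below recomputes multi-head scaled dot-product attention by hand, softmax(QKᵀ/√d_head)V per head, reusing the module's own projection weights (`in_proj_weight`, `in_proj_bias`, `out_proj`), and compares the result with the module's output. Small sizes are used so it runs quickly; dropout is off because the module is in eval mode:

```python
import math
import torch
from torch import nn

torch.manual_seed(0)
embed_dim, num_heads = 16, 4
head_dim = embed_dim // num_heads
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
mha.eval()

x = torch.randn(2, 5, embed_dim)  # [batch, seq, embed_dim]

def manual_mha(x: torch.Tensor) -> torch.Tensor:
    B, N, _ = x.shape
    # in_proj_weight stacks W_q, W_k, W_v into one [3 * E, E] matrix
    q, k, v = nn.functional.linear(x, mha.in_proj_weight, mha.in_proj_bias).chunk(3, dim=-1)
    # Split the embedding into heads: [B, num_heads, N, head_dim]
    q, k, v = (t.reshape(B, N, num_heads, head_dim).transpose(1, 2) for t in (q, k, v))
    scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)        # [B, heads, N, N]
    weights = scores.softmax(dim=-1)
    out = (weights @ v).transpose(1, 2).reshape(B, N, embed_dim)  # concatenate heads
    return mha.out_proj(out)                                      # final output projection

with torch.no_grad():
    manual = manual_mha(x)
    expected, _ = mha(x, x, x, need_weights=False)

print(torch.allclose(manual, expected, atol=1e-5))  # True
```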

# In this step we will implement the MLP Block. Paper quote: "The MLP contains two layers with a GELU non-linearity in between." Appendix B.1 of the paper adds that dropout, when used, is applied after every dense layer except for the qkv-projections, so we add dropout (rate 0.1) after each Linear layer.
## `MLP Block = LayerNorm → Linear (expand) → GELU → Dropout → Linear (contract) → Dropout → + output of Multi Head Block (residual connection)`

In [None]:
from torch import nn

# Implementation of the MLP block (Equation 3): x + Dropout(MLP(LayerNorm(x)))
class MLP(nn.Module):
  def __init__(self,
               embedding_dim: int=768,
               dropout: float=0.1,
               mlp_size: int=3072):
    super().__init__()
    # Layer Normalization normalizes activations across the feature dimension of each token,
    # so no single feature's scale dominates; this stabilizes training and improves gradient flow.
    self.norm = nn.LayerNorm(embedding_dim)
    # Expand embedding_dim -> mlp_size, apply GELU, then contract back to embedding_dim
    self.mlp_layer = nn.Sequential(nn.Linear(embedding_dim, mlp_size),
                                   nn.GELU(),
                                   nn.Dropout(dropout),
                                   nn.Linear(mlp_size, embedding_dim))
    # Dropout after the second Linear layer, before the residual add
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    residual_input = x
    x = self.norm(x)
    x = self.mlp_layer(x)
    x = self.dropout(x)
    return x + residual_input
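`nn.GELU()` defaults to the exact GELU, x·Φ(x), where Φ is the standard normal CDF. A quick sketch confirming that against the closed form 0.5·x·(1 + erf(x/√2)) (the function name `gelu_exact` is illustrative):

```python
import math
import torch
from torch import nn

def gelu_exact(x: torch.Tensor) -> torch.Tensor:
    # GELU(x) = x * Phi(x), with Phi(x) = 0.5 * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))

x = torch.linspace(-4, 4, steps=101)
print(torch.allclose(gelu_exact(x), nn.GELU()(x), atol=1e-6))  # True
```

Unlike ReLU, GELU is smooth and lets small negative inputs through with a small negative output, which is the non-linearity the paper names for the MLP.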


In [None]:
# Testing

import torch
from torch import nn

output_after_MHA_block = torch.randn(1, 197, 768) # (batch_size, patch_seq_length, embedding_dim)

MLP_instance = MLP()
output_after_MLP_block = MLP_instance(output_after_MHA_block)
print(f"Input shape of MLP block: {output_after_MHA_block.shape}")
print(f"Output shape of MLP block: {output_after_MLP_block.shape}")

## In this step we will create the Encoder Layer by combining our two blocks above, `MultiHeadAttention` and `MLP`

In [None]:
class EncoderLayer(nn.Module):
    def __init__(self,
                 embedding_dim: int=768,
                 num_heads: int=12,
                 mlp_size: int=3072,
                 dropout: float=0.1):
        super().__init__()
        self.mha_block = MultiHeadAttention(embedding_dim=embedding_dim,
                                            num_heads=num_heads,
                                            dropout=dropout)
        self.mlp_block = MLP(embedding_dim=embedding_dim,
                             dropout=dropout,
                             mlp_size=mlp_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Sequential: Multi Head Block (Equation 2), then MLP Block (Equation 3)
        x = self.mha_block(x)
        x = self.mlp_block(x)
        return x

In [None]:
##Testing

import torch
from torch import nn

patched_position_embedded_tensor = torch.randn(1, 197, 768) # (batch_size, patch_seq_length, embedding_dim)

encoder = EncoderLayer()
output_after_encoder_layer  = encoder(patched_position_embedded_tensor)
print(f"Input shape to Encoder Layer: {patched_position_embedded_tensor.shape}")
print(f"Output shape after Encoder Layer: {output_after_encoder_layer.shape}")
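Both iterations are meant to compute the same function. As a sanity check, this sketch recomputes Equations 2 and 3 by hand from the submodules of a small `nn.TransformerEncoderLayer` (`norm1`, `norm2`, `self_attn`, `linear1` and `linear2` are that class's attribute names) and compares the result with the layer's own forward pass. This is the same LayerNorm → attention → residual, LayerNorm → MLP → residual sequence that `EncoderLayer` implements:

```python
import torch
from torch import nn

torch.manual_seed(0)
layer = nn.TransformerEncoderLayer(d_model=64, nhead=4, dim_feedforward=256,
                                   activation="gelu", batch_first=True, norm_first=True)
layer.eval()  # disable dropout so the comparison is deterministic

x = torch.randn(2, 10, 64)  # [batch, seq, embedding_dim]

with torch.no_grad():
    # Equation 2: z' = MSA(LN(z)) + z
    h = layer.norm1(x)
    z_prime = layer.self_attn(h, h, h, need_weights=False)[0] + x
    # Equation 3: z = MLP(LN(z')) + z', with MLP = Linear -> GELU -> Linear
    z = layer.linear2(nn.functional.gelu(layer.linear1(layer.norm2(z_prime)))) + z_prime
    reference = layer(x)

print(torch.allclose(z, reference, atol=1e-4))  # True
```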

# Now let's re-implement ViT with our custom-made Encoder Layer

In [None]:
class ViT_01(nn.Module):
  """ViT built from our custom EncoderLayer (MultiHeadAttention + MLP blocks)."""
  def __init__(self,
               img_size: int=224,
               patch_size: int=16,
               in_channels: int=3,
               mlp_size: int=3072,
               num_heads: int=12,
               num_encoder_layers: int=12,
               embedding_dim: int=768,
               mlp_dropout: float=0.1,
               embedding_dropout: float=0.1,
               num_out_class: int=1000):
    super().__init__()

    assert img_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
    self.num_patches = (img_size // patch_size) ** 2
    self.patch_size = patch_size

    # Learnable [class] token and position embeddings (Equation 1)
    self.embedding_class = nn.Parameter(torch.zeros(1, 1, embedding_dim))
    self.embedding_position = nn.Parameter(torch.zeros(1, 1 + self.num_patches, embedding_dim))
    self.embedding_dropout = nn.Dropout(embedding_dropout)
    # Patch embedding: non-overlapping patches projected to embedding_dim
    self.patch = nn.Conv2d(in_channels=in_channels,
                           out_channels=embedding_dim,
                           kernel_size=patch_size,
                           stride=patch_size)
    self.flatten_patches = nn.Flatten(start_dim=2, end_dim=3)

    # Custom encoder stack; nn.ModuleList registers each layer's parameters with the model
    self.encoder_layers = nn.ModuleList([
        EncoderLayer(embedding_dim=embedding_dim,
                     num_heads=num_heads,
                     mlp_size=mlp_size,
                     dropout=mlp_dropout)
        for _ in range(num_encoder_layers)
    ])

    # Classification head applied to the [class] token (Equation 4)
    self.classifier = nn.Sequential(nn.LayerNorm(normalized_shape=embedding_dim),
                                    nn.Linear(in_features=embedding_dim, out_features=num_out_class))

  def forward(self, x):
    batch_size = x.shape[0]
    x = self.patch(x)                              # [B, D, H/P, W/P]
    x = self.flatten_patches(x)                    # [B, D, N]
    x = x.permute(0, 2, 1)                         # [B, N, D]
    class_embedding = self.embedding_class.expand(batch_size, -1, -1)
    x = torch.cat([class_embedding, x], dim=1)     # [B, N + 1, D]
    x = self.embedding_position + x
    x = self.embedding_dropout(x)
    # Pass the sequence through each encoder layer in turn
    for encoder_layer in self.encoder_layers:
      x = encoder_layer(x)
    x = self.classifier(x[:, 0])                   # logits from the [class] token
    return x
In [None]:
random_image_tensor = torch.randn(1, 3, 224, 224) # (batch_size, color_channels, height, width)
vit = ViT_01(num_out_class=3)
vit(random_image_tensor)
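`ViT_01` keeps its encoder layers in an `nn.ModuleList` rather than a plain Python list. The difference matters: only modules assigned to `nn.Module` attributes or held in containers such as `nn.ModuleList` are registered, so a plain list hides its layers from `.parameters()`, from the optimizer and from `.to(device)`. A minimal sketch with two hypothetical toy classes:

```python
import torch
from torch import nn

class WithList(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(4, 4) for _ in range(3)]  # plain list: NOT registered

class WithModuleList(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])  # registered

def param_count(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

# Each Linear(4, 4) has 16 weights + 4 biases = 20 parameters
print(param_count(WithList()), param_count(WithModuleList()))  # 0 60
```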

In [None]:
from torchinfo import summary

# Print a summary of our custom ViT model using torchinfo
summary(model=vit,
        input_size=(32, 3, 224, 224), # (batch_size, color_channels, height, width)
        # col_names=["input_size"], # uncomment for smaller output
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
)

In [None]:
from going_modular.going_modular import engine

optimizer = torch.optim.Adam(params=vit.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()


results = engine.train(model=vit,
                       train_dataloader=train_dataloader,
                       test_dataloader=test_dataloader,
                       optimizer=optimizer,
                       loss_fn=loss_fn,
                       epochs=10,
                       device=device)

In [None]:
from helper_functions import plot_loss_curves

plot_loss_curves(results)