In [4]:
import sys
import os, os.path as osp
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import pandas as pd

## Globals

In [2]:
root_dir = "~/prep/data/"


## Data Loader

In [5]:
# Define transformations (resize to 32x32 as ViT generally takes larger inputs)
transform = transforms.Compose([
    transforms.Resize((32, 32)),  # Resize to 32x32 for ViT input
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize MNIST data: The MNIST dataset has pixel values in the range [0, 1]
])

## The data will now be in the range [-1, 1]

train_dataset = datasets.MNIST(root=root_dir, train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root=root_dir, train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False) # No shuffle as test set. 



Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to /home/ayush.sharma/prep/data/MNIST/raw/train-images-idx3-ubyte.gz


100.0%


Extracting /home/ayush.sharma/prep/data/MNIST/raw/train-images-idx3-ubyte.gz to /home/ayush.sharma/prep/data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to /home/ayush.sharma/prep/data/MNIST/raw/train-labels-idx1-ubyte.gz


100.0%


Extracting /home/ayush.sharma/prep/data/MNIST/raw/train-labels-idx1-ubyte.gz to /home/ayush.sharma/prep/data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to /home/ayush.sharma/prep/data/MNIST/raw/t10k-images-idx3-ubyte.gz


100.0%


Extracting /home/ayush.sharma/prep/data/MNIST/raw/t10k-images-idx3-ubyte.gz to /home/ayush.sharma/prep/data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to /home/ayush.sharma/prep/data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100.0%

Extracting /home/ayush.sharma/prep/data/MNIST/raw/t10k-labels-idx1-ubyte.gz to /home/ayush.sharma/prep/data/MNIST/raw






### Sanity check

In [6]:
dataiter = iter(train_loader)
images, labels = dataiter.next()
images.shape, labels.shape

(torch.Size([64, 1, 32, 32]), torch.Size([64]))

## Config

In [60]:
from dataclasses import dataclass

@dataclass
class Config():
    num_epochs: int = 10
    batch_size: int = 64
    channels: int = 1
    image_size: int = 32
    patch_size: int = 16 # should divide image size
    num_patches: int = (image_size // patch_size) ** 2
    embed_dim: int = patch_size * patch_size
    num_transformer_blocks: int = 6 # Paper uses 12 blocks. But since MNIST is small, I use 6
    num_heads: int = 4
    num_classes: int=10 # 10 classes in MNIST
    device: str = "cuda" 
   
    
    
    


## Auxiliary function

In [54]:
def patchify_image(data: torch.Tensor, patch_size: int) -> torch.Tensor:
    batch_size, channels, H, W = data.shape
    assert H % patch_size == 0 and W % patch_size == 0
    patches = data.unfold(dimension=2, size=patch_size, step=patch_size).unfold(dimension=3, size=patch_size, step=patch_size)
    
    patches = patches.contiguous().view(batch_size, -1, patch_size * patch_size * channels)
    # Patches has shape [num_batch, num_patch, patch_size * patch_size]. Note seq_len is also called num_patch 
    return patches
    
# patches = patchify_image(images, Config.patch_size)
# patches.shape
    
    


## Model architecture

### Multi head attention block

<img src="../assets/mha.png" alt="Multi head attention"> <br>


In [63]:
class SelfAttentionBlock(nn.Module):
    
    def __init__(self, num_heads: int, input_dim: int, embed_dim: int):
        super(SelfAttentionBlock, self).__init__()
        self.num_heads = num_heads
        self.input_dim = input_dim
        self.embed_dim = embed_dim
        
        # Linear projections for queries, keys, and values
        self.queries = nn.Linear(input_dim, embed_dim)
        self.keys = nn.Linear(input_dim, embed_dim)
        self.values = nn.Linear(input_dim, embed_dim)
        
        # Now we need to combine the output. A layer for that
        self.out = nn.Linear(output_dim, output_dim)
        
    def forward(self, x: torch.tensor):
        batch_size, seq_len, input_dim = x.shape   
        assert input_dim == self.input_dim # the one provided in cofig and the one in input should match
        per_head_dim = self.embed_dim // self.num_heads
                            
                             
        # Projecting the input to find queries, keys and values
        Q = self.queries(x)
        K = self.keys(x)
        V = self.values(x)
                                 
        # split the Q, K, V across individual heads 
        # Note that we want the output to be [batch_size, num_heads, seq_len, per_head_dim]
        Q = Q.view(x.size(0), x.size(1), self.num_heads, per_head_dim).transpose(1, 2)
        K = K.view(x.size(0), x.size(1), self.num_heads, per_head_dim).transpose(1, 2)
        V = V.view(x.size(0), x.size(1), self.num_heads, per_head_dim).transpose(1, 2)
        
        scale = torch.sqrt(torch.tensor(per_head_dim, dtype=torch.float32))
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / scale # shape: [batch_size, num_heads, seq_len, seq_len]
        attention_weights = torch.softmax(attention_scores, dim=-1)
        
        # Multiply by V to get the attention output. 
        # Note we don't use * as its element wise multiplicaton. We want matmul usch as:
        # [batch_size, num_heads, seq_len, seq_len] <dot> [batch_size, num_heads, seq_len, per_head_dim]
        attention_output = torch.matmul(attention_weights, V) # shape: [batch_size, num_heads, seq_len, per_head_dim]
        
        attention_output_og_shape = attention_output.transpose(1,2).contiguous() # transpose makes a tensor not-contiguous. So , make a new contigous copy
        attention_output_og_shape = attention_output_og_shape.view(batch_size, seq_len, self.embed_dim)
        
        # Note that we cold have also done it using reshape
        # attention_output_og_shape = attention_output.transpose(1, 2).reshape(x.size(0), x.size(1), self.embed_dim)
        
        return attention_output_og_shape

        
        


### Questions <br>


#### 1. What is sequence length here?
Answer: THe sequence length here is the number of patches of the image. The ViT authors make the analogy being each grid is a word and `[patch_size x patch_size` is the dimension of the embedding of that word.

#### 2. Why do we use transpose in this part of the code? <br>
```
Q = Q.view(x.size(0), x.size(1), self.num_heads, per_head_dim).transpose(1, 2)
K = K.view(x.size(0), x.size(1), self.num_heads, per_head_dim).transpose(1, 2)
V = V.view(x.size(0), x.size(1), self.num_heads, per_head_dim).transpose(1, 2)
```

Answer:
 - After reshaping the tensor into [batch_size, seq_len, num_heads, head_dim], we need to change the order of the dimensions so that the num_heads dimension comes before the sequence length (seq_len). This is necessary because we want to compute attention independently for each head.
 - By transposing, we rearrange the tensor into:
   `[batch_size, num_heads, seq_len, per_head_dim]`
   we can now apply attention on the last 2 dimensions which is `sequence_length`, `per_head_dim`


#### 3. (Important) Given word embeddings of dimension 128 for a sequence of 10 words (so, shape [batch_size, 10, 128]), if we linearly project these into queries of shape [batch_size, 10, 128], and then apply multi-head attention:
 - Why does multi-head attention split the word embedding into chunks?
 - **Doesn't this splitting lose the meaning of the word embedding**?
 
Answer: <br>
- Let's take an input example <br>
    Input Example: Word embedding has shape [batch_size, 10, 128], where:
    - 10 is the number of words in the sequence.
    - 128 is the word embedding dimension.

- Linear Projection:
  Each word embedding (128-dimensional) is linearly projected into queries, keys, and values before splitting into heads.
  
- Splitting Across Heads:
  After projection, multi-head attention splits the transformed representations, not the raw word embeddings.
  For 8 heads, the projected queries are split into 8 parts of size 16 (128 / 8 = 16).
  
- The linear projections allow each head to attend to a **different transformed representation of the word embedding**, enabling multiple diverse attention patterns.

- Each head focuses on a portion of the transformed embedding, but not directly splitting the original embedding.

- The core reason we don’t directly project into a [batch_size, seq_len, num_heads, head_dim] space (where num_heads × head_dim = input_dim) is **primarily for flexibility and efficient parameter sharing**



<img src="../assets/transformer_model_architecture.png" alt="ViT architecture"> <br>


In [65]:
class TransformerBlock(nn.Module):
    
    def __init__(self, num_heads: int, 
                 input_dim: int, 
                 output_dim: int):
        super(TransformerBlock, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.norm1 = nn.LayerNorm(input_dim) ## Normalize before attention
        self.mha = SelfAttentionBlock(num_heads, input_dim, output_dim)
        self.norm2 = nn.LayerNorm(output_dim) # layer norm after the transformer MHA block
        self.ffn = nn.Sequential(
                      nn.Linear(output_dim, output_dim * 4), 
                      nn.ReLU(),
                      nn.Linear(output_dim * 4, output_dim)
        )
        
    def forward(self, x: torch.tensor):
        batch_size, seq_len, input_dim = x.shape
        assert input_dim == self.input_dim, "The shapes of input dim specified in constructor and the data are not same"
        x = self.norm1(x)
        # The size of tensor a (64) must match the size of tensor b (256) at non-singleton dimension 2
        attn_output = self.mha(x)
        print("atoutput", attn_output.shape)
        x = attn_output + x     # Residual connection
        x = self.norm2(x)
        ffn_output = self.ffn(x)
        output = ffn_output + x # Residual connection
        return output
        
        

In [67]:
cfg = Config()
print(cfg)

Config(num_epochs=10, batch_size=64, channels=1, image_size=32, patch_size=16, num_patches=4, embed_dim=256, output_dim=64, num_transformer_blocks=6, num_heads=4, num_classes=10, device='cuda')


In [61]:
class ViT(nn.Module):
    
    def __init__(self, cfg: Config):
        super(ViT, self).__init__()
        self.cfg = cfg
        num_patches = cfg.num_patches
        embed_dim = cfg.embed_dim
        input_dim = cfg.patch_size ** 2
        
        self.input_projection = None
        if input_dim != embed_dim:
            self.input_projection = nn.Linear(input_dim, embed_dim)
        
        # parameters for positional embedding
        # we use a learable positional embed as used in ViT paper
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, embed_dim))
        self.transformers = nn.Sequential(
                                *[TransformerBlock(cfg.num_heads, cfg.embed_dim, cfg.output_dim) for _ in range(cfg.num_transformer_blocks)]
                            )
        
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Linear(cfg.output_dim, cfg.num_classes)  # For 10 MNIST classes

        
    def forward(self, x: torch.tensor):
        patches = patchify_image(x, self.cfg.patch_size)
        assert patches.shape[1] == self.cfg.num_patches, \
            f"Number of patches {patches.shape[1]} should match {self.cfg.num_patches}"
        if self.input_projection:
            patches = self.input_projection(patches)
        patches = patches + self.pos_embed # [batch_size,  seq_len, embed_dim] Note embed_dim is the input_dim
        x = patches.contiguous()
        x = self.transformers(x) # [batch_size,  seq_len, output_dim]
        x = self.global_avg_pool(x) # Apply adaptive pooling to get [batch_size, output_dim, 1]
        x = x.squeeze(-1)
        out = self.classifier(x)
        return out

## Training loop

In [66]:
import torch.optim as optim



def train(cfg: Config,
          model: nn.Module,
          train_loader: torch.utils.data.DataLoader,
          criterion, 
         ):
    num_epochs = cfg.num_epochs
    device = cfg.device
    model.to(device)
    running_loss = 0

    for idx, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        # print(images.shape, labels.shape)
        optimizer.zero_grad()
        
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        loss.backward()
        opimizer.step()
        
        running_loss += loss
        print(outputs.shape, labels.shape)
        break
        

cfg = Config()
model = ViT(cfg)
criterion = nn.CrossEntropyLoss()  
optimizer = optim.Adam(model.parameters(), lr=3e-4)
train(cfg, model, train_loader, criterion)
        

atoutput torch.Size([64, 4, 64])


RuntimeError: The size of tensor a (64) must match the size of tensor b (256) at non-singleton dimension 2