# Vision Transformer (ViT) Implementation from Scratch

This notebook provides a step-by-step implementation of the **Vision Transformer (ViT)** architecture using PyTorch, applied to the MNIST dataset. 

Unlike Convolutional Neural Networks (CNNs), which slide small local filters across the pixel grid, ViT treats an image as a sequence of patches, much as Transformers in NLP treat a sentence as a sequence of words (tokens).

### Key Architecture Steps:
1.  **Patch Embedding:** Split the image into fixed-size patches and flatten them.
2.  **Position Embedding:** Add learnable position vectors so the model knows the order of patches.
3.  **Transformer Encoder:** Process the sequence using Self-Attention and MLPs.
4.  **Classification Head:** Use a special "Class Token" to predict the final digit.
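
For reference, the four steps can be written as the standard ViT equations, here with this notebook's sizes: $N = 16$ patches of $P \times P = 7 \times 7$ pixels with $C = 1$ channel, embedding size $D = 64$, and $L = 4$ encoder blocks. $\mathbf{x}_p^i$ is the $i$-th flattened patch, $\mathbf{E}$ the patch projection (its bias is omitted for brevity), $\mathbf{z}_L^0$ the CLS-token output of the last block, and $\mathbf{W}_{\text{head}}, \mathbf{b}_{\text{head}}$ the final linear layer.

$$
\begin{aligned}
\mathbf{z}_0 &= [\,\mathbf{x}_{\text{class}};\ \mathbf{x}_p^1\mathbf{E};\ \dots;\ \mathbf{x}_p^N\mathbf{E}\,] + \mathbf{E}_{\text{pos}}, && \mathbf{E} \in \mathbb{R}^{(P^2 C)\times D},\ \mathbf{E}_{\text{pos}} \in \mathbb{R}^{(N+1)\times D} \\
\mathbf{z}'_\ell &= \operatorname{MSA}\big(\operatorname{LN}(\mathbf{z}_{\ell-1})\big) + \mathbf{z}_{\ell-1}, && \ell = 1, \dots, L \\
\mathbf{z}_\ell &= \operatorname{MLP}\big(\operatorname{LN}(\mathbf{z}'_\ell)\big) + \mathbf{z}'_\ell, && \ell = 1, \dots, L \\
\mathbf{y} &= \operatorname{LN}(\mathbf{z}_L^0)\,\mathbf{W}_{\text{head}} + \mathbf{b}_{\text{head}}, && \mathbf{W}_{\text{head}} \in \mathbb{R}^{D \times 10}
\end{aligned}
$$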

## 1. Imports and Setup

In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

# Set device configuration (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 2. Data Preparation
We use the **MNIST** dataset and only convert the images to tensors (`ToTensor` also rescales pixel values from [0, 255] to [0, 1]). No augmentation or normalization is applied, to keep the focus on the architecture.

In [None]:
# Transformation: Convert PIL image to Tensor
transform = transforms.Compose([transforms.ToTensor()])

# Load MNIST. train=False selects the official 10,000-image test split,
# which this notebook uses as its validation/evaluation set.
train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
val_dataset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)

## 3. Hyperparameters
Here we define the structural parameters of the ViT. 

* **Patch Size (7):** We split the 28x28 image into 7x7 squares. This results in $4 \times 4 = 16$ total patches.
* **Embedding Dim (64):** Each patch will be projected into a vector of size 64.
* **Attention Heads (4):** The Multi-Head Attention mechanism will split the 64 dimensions into 4 heads of 16 dimensions each.
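
To make the head split concrete, the sketch below recomputes `nn.MultiheadAttention` by hand: the 64-dim Q/K/V projections are reshaped into 4 heads of 16 dims, each head runs scaled dot-product attention, and the heads are concatenated back to 64 dims before the output projection. It reads PyTorch's packed-projection attributes (`in_proj_weight`, `in_proj_bias`, `out_proj`), which the module uses when query, key and value share one embedding size, as they do here. The helper name `manual_mha` is illustrative.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
embedding_dim, attention_heads = 64, 4
head_dim = embedding_dim // attention_heads  # 64 / 4 = 16 dims per head

mha = nn.MultiheadAttention(embedding_dim, attention_heads, batch_first=True)
x = torch.randn(2, 17, embedding_dim)  # (Batch, CLS + 16 patches, Embedding_Dim)

def manual_mha(x, mha, num_heads):
    B, S, E = x.shape
    d = E // num_heads
    # One packed projection produces Q, K and V, each of size E
    q, k, v = F.linear(x, mha.in_proj_weight, mha.in_proj_bias).chunk(3, dim=-1)
    # Split E = num_heads * d and move heads next to the batch: (B, heads, S, d)
    q, k, v = (t.view(B, S, num_heads, d).transpose(1, 2) for t in (q, k, v))
    # Scaled dot-product attention, computed independently per head
    weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d), dim=-1)
    heads = weights @ v  # (B, heads, S, d)
    # Concatenate the heads back to (B, S, E) and apply the output projection
    merged = heads.transpose(1, 2).reshape(B, S, E)
    return mha.out_proj(merged)

reference, _ = mha(x, x, x)
manual = manual_mha(x, mha, attention_heads)
print(head_dim, manual.shape)                         # 16 torch.Size([2, 17, 64])
print(torch.allclose(manual, reference, atol=1e-5))
```

Each head attends over all 17 tokens but sees only its own 16-dimensional slice of Q, K and V, so the total cost of the projections matches a single 64-dim head.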

In [None]:
# Configuration
num_classes = 10          # MNIST has digits 0-9
batch_size = 64
num_channels = 1          # Grayscale images
img_size = 28             # MNIST image size
patch_size = 7            # Size of each patch (must be a divisor of img_size)

# Calculated parameters
num_patches = (img_size // patch_size) ** 2  # Total patches: (28/7)^2 = 16

# Model parameters
embedding_dim = 64        # Dimension of the linear projection of patches
attention_heads = 4       # Number of heads in Multi-Head Attention
transformer_blocks = 4    # Number of Transformer Encoder layers
mlp_hidden_nodes = 128    # Hidden size in the MLP feed-forward network

# Training parameters
learning_rate = 0.001
num_epochs = 5

In [None]:
# Define DataLoaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
# Keep the last partial batch during evaluation so every test image is scored
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
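
With `drop_last=True`, a `DataLoader` discards the final partial batch. For a shuffled training loader that only skips a few random samples per epoch, but for an evaluation loader it means some test images are never scored. The sketch below uses a placeholder `TensorDataset` of 10,000 items (the size of the MNIST test split) with `batch_size=64` to count what each setting yields.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Placeholder stand-in for the 10,000-image MNIST test split (not real data)
dummy_test = TensorDataset(torch.zeros(10_000, 1), torch.zeros(10_000, dtype=torch.long))

results = {}
for drop_last in (True, False):
    loader = DataLoader(dummy_test, batch_size=64, shuffle=False, drop_last=drop_last)
    seen = sum(labels.size(0) for _, labels in loader)  # samples actually yielded
    results[drop_last] = (len(loader), seen)
    print(f"drop_last={drop_last}: {len(loader)} batches, {seen} samples")
# drop_last=True: 156 batches, 9984 samples
# drop_last=False: 157 batches, 10000 samples
```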

## 4. Patch Embedding Layer

**The Logic:** Instead of manually cropping the image into squares, we can use a **Convolutional Layer** (`nn.Conv2d`). 

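If we set `kernel_size` and `stride` both equal to `patch_size`, the convolution steps across the image in a non-overlapping grid, so each output position sees exactly one patch. At each position it takes a dot product between that patch's pixels and each of its `embedding_dim` filters, which is the same as flattening the patch and applying a shared linear layer. Cutting out the patches and projecting them to `embedding_dim` therefore happens in one step.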
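
A quick numerical check of this equivalence, as a sketch on random images: it cuts out the patches explicitly with `torch.nn.functional.unfold`, flattens each 7×7 patch into 49 values, and multiplies by the convolution's own weights reshaped into a `(64, 49)` matrix. Both routes should produce the same `(Batch, 16, 64)` tensor; the variable names are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_channels, embedding_dim, patch_size = 1, 64, 7
images = torch.randn(8, num_channels, 28, 28)

# Route 1: strided convolution, as in PatchEmbedding
conv = nn.Conv2d(num_channels, embedding_dim, kernel_size=patch_size, stride=patch_size)
via_conv = conv(images).flatten(2).transpose(1, 2)  # (8, 16, 64)

# Route 2: cut out non-overlapping patches, flatten each, then project linearly
patches = F.unfold(images, kernel_size=patch_size, stride=patch_size)  # (8, 1*7*7, 16)
patches = patches.transpose(1, 2)                                      # (8, 16, 49)
weight = conv.weight.reshape(embedding_dim, -1)                        # (64, 49)
via_linear = patches @ weight.T + conv.bias                            # (8, 16, 64)

print(via_conv.shape, torch.allclose(via_conv, via_linear, atol=1e-5))  # torch.Size([8, 16, 64]) True
```

Both `unfold` and the flattened convolution output enumerate patches in row-major order, which is why the two tensors line up position by position.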

In [None]:
class PatchEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        # Conv2d usage as a Patch Embedder:
        # Input: (Batch, 1, 28, 28) -> Output: (Batch, 64, 4, 4)
        self.patch_embed = nn.Conv2d(num_channels, embedding_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x shape: (Batch, 1, 28, 28)
        
        x = self.patch_embed(x) 
        # x shape: (Batch, 64, 4, 4)
        # 64 is the embedding_dim. 4x4 comes from 28/7 = 4.
        
        x = x.flatten(2) 
        # Flatten the spatial dimensions (4, 4) into a single sequence dimension (16)
        # x shape: (Batch, 64, 16)

        x = x.transpose(1, 2) 
        # nn.MultiheadAttention(batch_first=True) expects (Batch, Seq_Len, Embedding_Dim)
        # x shape: (Batch, 16, 64) -> (Batch, Num_Patches, Embedding_Dim)
        
        return x

## 5. Transformer Encoder Block

A ViT Transformer Encoder block uses the *pre-norm* layout (LayerNorm is applied before each sub-layer, not after the residual add) and consists of:
1.  **Layer Normalization**
2.  **Multi-Head Self Attention (MSA)**
3.  **Residual Connection** (Add)
4.  **Layer Normalization**
5.  **Multi-Layer Perceptron (MLP)**
6.  **Residual Connection** (Add)
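
PyTorch's built-in `nn.TransformerEncoderLayer` supports this LayerNorm-first ordering through `norm_first=True` (available since PyTorch 1.10). The sketch below assumes that, with `dropout=0.0` and `activation="gelu"`, the built-in layer computes the same function as a hand-written block, and checks this by copying weights from `PreNormBlock`, a hypothetical compact copy of the `TransformerEncoder` class defined in the next cell. Both modules are left in training mode so PyTorch's inference fast path is not used.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embedding_dim, attention_heads, mlp_hidden_nodes = 64, 4, 128

class PreNormBlock(nn.Module):
    """Hypothetical compact copy of the notebook's TransformerEncoder block."""
    def __init__(self):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(embedding_dim)
        self.layer_norm2 = nn.LayerNorm(embedding_dim)
        self.multihead_attention = nn.MultiheadAttention(embedding_dim, attention_heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(embedding_dim, mlp_hidden_nodes), nn.GELU(),
                                 nn.Linear(mlp_hidden_nodes, embedding_dim))

    def forward(self, x):
        h = self.layer_norm1(x)
        x = x + self.multihead_attention(h, h, h)[0]  # LN -> MSA -> add
        return x + self.mlp(self.layer_norm2(x))       # LN -> MLP -> add

custom = PreNormBlock()
builtin = nn.TransformerEncoderLayer(
    d_model=embedding_dim, nhead=attention_heads, dim_feedforward=mlp_hidden_nodes,
    dropout=0.0, activation="gelu", batch_first=True, norm_first=True,
)

# Copy parameters so both modules hold identical weights
builtin.self_attn.load_state_dict(custom.multihead_attention.state_dict())
builtin.linear1.load_state_dict(custom.mlp[0].state_dict())
builtin.linear2.load_state_dict(custom.mlp[2].state_dict())
builtin.norm1.load_state_dict(custom.layer_norm1.state_dict())
builtin.norm2.load_state_dict(custom.layer_norm2.state_dict())

x = torch.randn(2, 17, embedding_dim)  # (Batch, CLS + 16 patches, Embedding_Dim)
same = torch.allclose(custom(x), builtin(x), atol=1e-5)
print(same)
```

Writing the block by hand keeps every step visible; the built-in layer is a drop-in alternative once the structure is clear, but note that its default `dropout` is 0.1, not 0.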

In [None]:
class TransformerEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        # Normalization layers
        self.layer_norm1 = nn.LayerNorm(embedding_dim)
        self.layer_norm2 = nn.LayerNorm(embedding_dim)
        
        # Multi-Head Attention
        self.multihead_attention = nn.MultiheadAttention(embedding_dim, attention_heads, batch_first=True)
        
        # Feed Forward Network (MLP)
        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, mlp_hidden_nodes),
            nn.GELU(), # GELU is often used in Transformers instead of ReLU
            nn.Linear(mlp_hidden_nodes, embedding_dim)
        )

    def forward(self, x):
        # Block 1: Attention + Residual
        residual1 = x
        x = self.layer_norm1(x)
        # In self-attention, Query, Key, and Value are all the same input 'x'
        attn_output, _ = self.multihead_attention(x, x, x)
        x = attn_output + residual1

        # Block 2: MLP + Residual
        residual2 = x
        x = self.layer_norm2(x)
        x = self.mlp(x)
        x = x + residual2
        
        return x

## 6. The Classification Head (MLP Head)
This is the final layer that takes the representation of the **Class Token** and projects it to the 10 output classes (digits 0-9).
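
`MLPHead` returns unnormalized scores (logits) with no softmax. That is what `nn.CrossEntropyLoss` expects, because it applies log-softmax internally and then takes the negative log-likelihood. The sketch below uses random placeholder logits and labels to confirm this, and shows that `argmax` over logits (as used in the training and evaluation loops) gives the same prediction as `argmax` over softmax probabilities.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)          # placeholder for MLPHead output on a batch of 4
labels = torch.tensor([3, 0, 9, 1])  # placeholder targets

loss_builtin = nn.CrossEntropyLoss()(logits, labels)
loss_manual = F.nll_loss(F.log_softmax(logits, dim=1), labels)
probs = logits.softmax(dim=1)        # only needed if you want probabilities

print(torch.allclose(loss_builtin, loss_manual))        # True
print(torch.equal(probs.argmax(1), logits.argmax(1)))   # True: softmax preserves the argmax
```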

In [None]:
class MLPHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer_norm = nn.LayerNorm(embedding_dim)
        self.mlp_head = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        x = self.layer_norm(x)
        x = self.mlp_head(x)
        return x

## 7. Vision Transformer (Full Assembly)

**Crucial Concepts:**
1.  **CLS Token:** Since a Transformer outputs a sequence (one vector per patch), which one do we use for classification? We prepend a special learnable vector (CLS token) to the start of the sequence. Because every encoder block lets each token attend to all the others, the CLS token's output can gather information from every patch. We ignore the patch outputs and use only this CLS output for prediction.
2.  **Position Embedding:** Self-attention has no built-in notion of token order, so on its own the model cannot tell "up/down" from "left/right". We add a learnable position vector to every token in the sequence (including the CLS token) so the model can learn spatial structure.
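
Why the position embedding matters can be checked directly. Without it, self-attention is permutation-equivariant: shuffling the input patches only shuffles the outputs in the same way, so patch order carries no information. The sketch below uses a standalone `nn.MultiheadAttention` on random tokens and a randomly drawn position embedding (both placeholders, not the trained model's values); the helper `self_attend` is illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embedding_dim, attention_heads, num_patches = 64, 4, 16
attn = nn.MultiheadAttention(embedding_dim, attention_heads, batch_first=True)

def self_attend(tokens):
    # Self-attention: query, key and value are all the same tokens
    return attn(tokens, tokens, tokens)[0]

patches = torch.randn(1, num_patches, embedding_dim)
perm = torch.randperm(num_patches)  # a random shuffle of the patch order

# No position information: permuting the inputs just permutes the outputs
equivariant = torch.allclose(self_attend(patches[:, perm]), self_attend(patches)[:, perm], atol=1e-5)

# Add a position embedding in fixed slots after shuffling: the outputs now change
pos = torch.randn(1, num_patches, embedding_dim)
still_equivariant = torch.allclose(
    self_attend(patches[:, perm] + pos), self_attend(patches + pos)[:, perm], atol=1e-5
)
print(equivariant, still_equivariant)  # True False
```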

In [None]:
class VisionTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.patch_embedding = PatchEmbedding()
        
        # Learnable Class Token (1, 1, embedding_dim)
        self.cls_token = nn.Parameter(torch.rand(1, 1, embedding_dim))
        
        # Learnable Position Embedding (1, num_patches + 1, embedding_dim)
        # +1 because we are adding the CLS token to the sequence
        self.position_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embedding_dim))
        
        # Stack multiple Transformer Encoder blocks
        self.transformer_blocks = nn.Sequential(*[TransformerEncoder() for _ in range(transformer_blocks)])
        
        self.mlp_head = MLPHead()

    def forward(self, x):
        # 1. Embed Patches
        x = self.patch_embedding(x) # (Batch, 16, 64)
        
        # 2. Add CLS Token
        batch_size = x.size(0)
        class_token = self.cls_token.expand(batch_size, -1, -1) # Expand to match batch size
        x = torch.cat([class_token, x], dim=1) # (Batch, 17, 64)
        
        # 3. Add Position Embedding
        x = x + self.position_embedding
        
        # 4. Pass through Transformer Blocks
        x = self.transformer_blocks(x)
        
        # 5. Extract only the CLS token (index 0) for classification
        x = x[:, 0] # (Batch, 64)
        
        # 6. Final Classification
        x = self.mlp_head(x) # (Batch, 10)
        return x

## 8. Training Loop

In [None]:
# Initialize Model, Optimizer, and Loss Function
model = VisionTransformer().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

print("Start Training...")
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    correct_epoch = 0
    total_samples = 0
    
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    
    for batch_idx, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        
        # Forward pass
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Metrics
        total_loss += loss.item()
        preds = outputs.argmax(dim=1)
        correct_epoch += (preds == labels).sum().item()
        total_samples += labels.size(0)
        
        if (batch_idx + 1) % 100 == 0:
            accuracy = 100 * correct_epoch / total_samples
            print(f"  Batch {batch_idx+1}: Loss = {loss.item():.4f}, Running Acc = {accuracy:.2f}%")
    
    epoch_acc = 100 * correct_epoch / total_samples
    print(f"Epoch {epoch+1} Summary: Avg Loss = {total_loss/len(train_loader):.4f}, Accuracy = {epoch_acc:.2f}%")

## 9. Evaluation

In [None]:
model.eval() # Set model to evaluation mode
correct = 0
total = 0

with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

test_acc = 100 * correct / total
print(f"\nFinal Test Accuracy: {test_acc:.2f}%")