# 🔮 Vision Transformer (ViT) from Scratch

A minimal implementation of the **Vision Transformer** architecture in PyTorch, trained on MNIST.

---

## 📖 Introduction

The Vision Transformer (ViT) was introduced in the paper ["An Image is Worth 16x16 Words"](https://arxiv.org/abs/2010.11929) by Dosovitskiy et al. (2020). It applies the Transformer architecture, originally designed for NLP, directly to image classification.

### Key Idea
Instead of using convolutions, ViT:
1. Splits an image into fixed-size patches
2. Linearly embeds each patch
3. Adds positional embeddings
4. Feeds the sequence to a standard Transformer encoder
5. Uses a classification token ([CLS]) for the final prediction

---

## 🏗️ Architecture Overview

```
Input Image (28×28)
        │
        ▼
┌───────────────────┐
│  Patch Embedding  │  Split into 7×7 patches → 16 patches
│  (Conv2d)         │  Project to 64 dimensions
└───────────────────┘
        │
        ▼
┌───────────────────┐
│ + [CLS] Token     │  Prepend learnable class token
│ + Position Embed  │  Add positional information
└───────────────────┘
        │
        ▼
┌───────────────────┐
│ Transformer       │  4× Encoder blocks
│ Encoder Stack     │  (MHSA + MLP + LayerNorm)
└───────────────────┘
        │
        ▼
┌───────────────────┐
│    MLP Head       │  [CLS] token → 10 classes
└───────────────────┘
        │
        ▼
   Output (10 classes)
```

---

## 📦 Setup & Imports

In [None]:
import torch
import torchvision
import torch.utils.data as dataloader
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 📊 Data Preparation

We use the MNIST dataset - 70,000 grayscale images of handwritten digits (0-9).
- **Training set**: 60,000 images
- **Test set**: 10,000 images
- **Image size**: 28×28 pixels

In [None]:
# Transformation: Convert PIL images to tensors
transformation_operation = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

# Download and load datasets
train_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transformation_operation
)
val_dataset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transformation_operation
)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

### 👁️ Visualize Sample Images

In [None]:
# Visualize some training samples
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
for i, ax in enumerate(axes.flat):
    img, label = train_dataset[i]
    ax.imshow(img.squeeze(), cmap='gray')
    ax.set_title(f'Label: {label}', fontsize=12)
    ax.axis('off')
plt.suptitle('Sample MNIST Images', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## ⚙️ Hyperparameters

| Parameter | Value | Description |
|-----------|-------|-------------|
| `patch_size` | 7 | Each 28×28 image → 4×4 grid of patches |
| `num_patches` | 16 | Total patches per image |
| `embedding_dim` | 64 | Dimension of patch embeddings |
| `attention_heads` | 4 | Number of attention heads |
| `transformer_blocks` | 4 | Depth of transformer |
| `mlp_hidden_nodes` | 128 | Hidden layer size in MLP |

In [None]:
# Model hyperparameters
num_classes = 10
batch_size = 64
num_channels = 1
img_size = 28
patch_size = 7
num_patches = (img_size // patch_size) ** 2
embedding_dim = 64
attention_heads = 4
transformer_blocks = 4
mlp_hidden_nodes = 128
learning_rate = 0.001
epochs = 5

print(f"Image size: {img_size}×{img_size}")
print(f"Patch size: {patch_size}×{patch_size}")
print(f"Number of patches: {num_patches} ({img_size//patch_size}×{img_size//patch_size} grid)")
print(f"Sequence length: {num_patches + 1} (patches + [CLS] token)")

In [None]:
# Create data loaders
train_loader = dataloader.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = dataloader.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

## 🧩 Understanding Patch Embedding

The first step in ViT is to split the image into patches and embed them.

**Process:**
1. Use Conv2d with `kernel_size=patch_size` and `stride=patch_size`
2. This extracts non-overlapping patches and projects them to `embedding_dim`
3. Reshape from (B, D, H', W') to (B, num_patches, D)
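
To see why a strided convolution counts as a "linear embedding of patches", the following sketch compares it with the explicit route: cut out the patches, flatten each one, and multiply by the same weights. It assumes the notebook's hyperparameter values (restated inline) and uses `torch.nn.functional.unfold` for the explicit patch extraction, which the notebook itself does not use.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Values mirror the notebook's hyperparameters; restated so the block is self-contained.
num_channels, img_size, patch_size, embedding_dim = 1, 28, 7, 64

conv = nn.Conv2d(num_channels, embedding_dim, kernel_size=patch_size, stride=patch_size)
images = torch.randn(2, num_channels, img_size, img_size)

# Path 1: strided convolution, then flatten the 4x4 grid into a sequence.
via_conv = conv(images).flatten(2).transpose(1, 2)                       # (B, N, D)

# Path 2: extract non-overlapping patches explicitly, then apply the conv weights as a matrix.
patches = F.unfold(images, kernel_size=patch_size, stride=patch_size)    # (B, C*P*P, N)
patches = patches.transpose(1, 2)                                        # (B, N, C*P*P)
weight = conv.weight.reshape(embedding_dim, -1)                          # (D, C*P*P)
via_linear = patches @ weight.T + conv.bias                              # (B, N, D)

print(via_conv.shape)
print(torch.allclose(via_conv, via_linear, atol=1e-5))
```

Both paths should agree up to floating-point error, which is why the Conv2d form is a common shortcut for patch extraction plus projection.
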

In [None]:
# Demonstrate patch embedding
data_point, label = next(iter(train_loader))

print("=== Patch Embedding Demonstration ===")
print(f"Input shape: {data_point.shape}")
print(f"  → (batch_size, channels, height, width)")

# Simulate patch embedding with Conv2d
patch_embed = nn.Conv2d(num_channels, embedding_dim, kernel_size=patch_size, stride=patch_size)
patch_embed_output = patch_embed(data_point)
print(f"\nAfter Conv2d: {patch_embed_output.shape}")
print(f"  → (batch_size, embedding_dim, patches_h, patches_w)")

# Reshape to sequence
sequence = patch_embed_output.flatten(2).transpose(1, 2)
print(f"\nAfter reshape: {sequence.shape}")
print(f"  → (batch_size, num_patches, embedding_dim)")

### 🖼️ Visualize Patch Extraction

In [None]:
# Visualize how an image is split into patches
sample_img = data_point[0].squeeze().numpy()

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Original image with patch grid
axes[0].imshow(sample_img, cmap='gray')
axes[0].set_title('Original Image with Patch Grid', fontsize=12)
for i in range(1, img_size // patch_size):
    axes[0].axhline(y=i * patch_size - 0.5, color='red', linewidth=2)
    axes[0].axvline(x=i * patch_size - 0.5, color='red', linewidth=2)
axes[0].axis('off')

# Individual patches
ax_patches = axes[1]
patches_per_side = img_size // patch_size
patch_grid = np.zeros((patches_per_side * (patch_size + 1), patches_per_side * (patch_size + 1)))

for i in range(patches_per_side):
    for j in range(patches_per_side):
        patch = sample_img[i*patch_size:(i+1)*patch_size, j*patch_size:(j+1)*patch_size]
        y_start = i * (patch_size + 1)
        x_start = j * (patch_size + 1)
        patch_grid[y_start:y_start+patch_size, x_start:x_start+patch_size] = patch

ax_patches.imshow(patch_grid, cmap='gray')
ax_patches.set_title(f'Extracted Patches ({patches_per_side}×{patches_per_side} = {num_patches} patches)', fontsize=12)
ax_patches.axis('off')

plt.suptitle('Image Patchification Process', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

---

## 🔧 Model Components

Now let's build the ViT step by step.

---

### 1️⃣ Patch Embedding Layer

Converts the input image into a sequence of patch embeddings.

In [None]:
class PatchEmbedding(nn.Module):
    """
    Splits image into patches and projects to embedding dimension.
    
    Input:  (B, C, H, W) = (batch, channels, height, width)
    Output: (B, N, D)    = (batch, num_patches, embedding_dim)
    """
    def __init__(self):
        super().__init__()
        # Conv2d acts as both patch extraction and linear projection
        self.patch_embed = nn.Conv2d(
            num_channels, embedding_dim, 
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        # x: (B, C, H, W)
        x = self.patch_embed(x)  # (B, D, H', W')
        x = x.flatten(2)          # (B, D, N)
        x = x.transpose(1, 2)     # (B, N, D)
        return x

### 2️⃣ Transformer Encoder Block

The core building block with:
- **Multi-Head Self-Attention**: Allows each patch to attend to all other patches
- **MLP (Feed-Forward Network)**: Non-linear transformation
- **Layer Normalization**: Stabilizes training
- **Residual Connections**: Helps gradient flow
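
As a rough illustration of what self-attention computes, the sketch below implements a single attention head by hand. The helper `self_attention` and the projection matrices `w_q`, `w_k`, `w_v` are illustrative names introduced here; the model in this notebook relies on `nn.MultiheadAttention`, which additionally splits the embedding across several heads and applies an output projection.

```python
import math
import torch

torch.manual_seed(0)

def self_attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention over a (B, N, D) sequence."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v                          # each (B, N, D)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))     # (B, N, N) token-to-token similarity
    weights = scores.softmax(dim=-1)                             # each row is a distribution over tokens
    return weights @ v, weights                                  # weighted mix of values, plus the weights

B, N, D = 2, 17, 64                                              # 16 patches + [CLS], embedding_dim 64
x = torch.randn(B, N, D)
w_q, w_k, w_v = (torch.randn(D, D) / math.sqrt(D) for _ in range(3))

out, weights = self_attention(x, w_q, w_k, w_v)
print(out.shape, weights.shape)
```

Row `i` of `weights` describes how much token `i` draws from every token in the sequence, itself included, which is the sense in which each patch can attend to all the others.
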

We use **Pre-Norm** architecture (LayerNorm before attention/MLP).
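
For comparison, here is a minimal sketch of the two orderings. A single `nn.Linear` stands in for the attention or MLP sublayer, and `pre_norm` / `post_norm` are hypothetical helper names used only for this illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

D = 64
norm = nn.LayerNorm(D)
sublayer = nn.Linear(D, D)   # stand-in for attention or the MLP

def pre_norm(x):
    # Pre-Norm: normalize the sublayer input; the residual path carries x through unchanged.
    return x + sublayer(norm(x))

def post_norm(x):
    # Post-Norm (original Transformer ordering): normalize after adding the residual.
    return norm(x + sublayer(x))

x = torch.randn(2, 17, D) * 10   # deliberately large activations
with torch.no_grad():
    y_pre, y_post = pre_norm(x), post_norm(x)

print(y_pre.std().item(), y_post.std().item())
```

In the Pre-Norm form the input passes through the residual connection without being rescaled, which is often cited as the reason deep stacks train more stably with it.
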

In [None]:
class TransformerEncoder(nn.Module):
    """
    Single Transformer encoder block with Pre-LayerNorm.
    
    Structure:
        x → LayerNorm → MHSA → + → LayerNorm → MLP → +
        └────────────────────┘    └──────────────────┘
              (residual)              (residual)
    """
    def __init__(self):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(embedding_dim)
        self.layer_norm2 = nn.LayerNorm(embedding_dim)
        
        # Multi-Head Self-Attention
        self.multihead_attention = nn.MultiheadAttention(
            embedding_dim, attention_heads, batch_first=True
        )
        
        # Feed-Forward MLP
        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, mlp_hidden_nodes),
            nn.GELU(),  # Smooth activation function
            nn.Linear(mlp_hidden_nodes, embedding_dim),
        )

    def forward(self, x):
        # Self-attention block with residual
        residual1 = x
        x = self.layer_norm1(x)
        x = self.multihead_attention(x, x, x)[0]  # Self-attention: Q=K=V
        x = x + residual1
        
        # MLP block with residual
        residual2 = x
        x = self.layer_norm2(x)
        x = self.mlp(x)
        x = x + residual2
        
        return x

### 3️⃣ MLP Classification Head

Takes the [CLS] token representation and outputs class logits.

In [None]:
class MLP_head(nn.Module):
    """
    Classification head operating on the [CLS] token.
    
    Input:  (B, D)          = (batch, embedding_dim)
    Output: (B, num_classes) = (batch, 10)
    """
    def __init__(self):
        super().__init__()
        self.layer_norm = nn.LayerNorm(embedding_dim)
        self.mlp_head = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        x = self.layer_norm(x)
        x = self.mlp_head(x)
        return x

### 4️⃣ Complete Vision Transformer

Putting it all together:
1. **Patch Embedding**: Image → patch sequence
2. **Class Token**: Prepend learnable [CLS] token
3. **Position Embedding**: Add positional information
4. **Transformer Encoder**: Process with self-attention
5. **MLP Head**: Classify using [CLS] token
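
The tensor shapes through steps 2, 3 and 5 can be traced in isolation. This is a sketch with random stand-in tensors rather than the trained parameters; the [CLS] token is set to zeros here purely so the effect of the position embedding is easy to check.

```python
import torch

torch.manual_seed(0)

B, N, D = 8, 16, 64                       # batch, num_patches, embedding_dim
patch_tokens = torch.randn(B, N, D)       # stand-in for the patch embedding output
cls_token = torch.zeros(1, 1, D)          # one shared token (zeros only for this demo)
position_embedding = torch.randn(1, 1 + N, D)

# Step 2: expand the single [CLS] token across the batch and prepend it.
tokens = torch.cat((cls_token.expand(B, -1, -1), patch_tokens), dim=1)   # (B, 1+N, D)

# Step 3: the (1, 1+N, D) position embedding broadcasts over the batch dimension.
tokens = tokens + position_embedding                                      # (B, 1+N, D)

# Step 5: index 0 along the sequence axis selects the [CLS] slot.
cls_out = tokens[:, 0]                                                    # (B, D)

print(tokens.shape, cls_out.shape)
```

Because `expand` creates a view rather than a copy, every sample in the batch shares the same learnable [CLS] parameter.
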

In [None]:
class VisionTransformer(nn.Module):
    """
    Complete Vision Transformer for image classification.
    
    Input:  (B, C, H, W)    = (batch, 1, 28, 28)
    Output: (B, num_classes) = (batch, 10)
    """
    def __init__(self):
        super().__init__()
        
        # Patch embedding layer
        self.patch_embedding = PatchEmbedding()
        
        # Learnable [CLS] token for classification
        self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))
        
        # Learnable position embeddings (for CLS + patches)
        self.position_embedding = nn.Parameter(
            torch.randn(1, 1 + num_patches, embedding_dim)
        )
        
        # Stack of transformer encoder blocks
        self.transformer_blocks = nn.Sequential(
            *[TransformerEncoder() for _ in range(transformer_blocks)]
        )
        
        # Classification head
        self.mlp_head = MLP_head()

    def forward(self, x):
        # 1. Embed patches
        x = self.patch_embedding(x)  # (B, N, D)
        
        # 2. Prepend [CLS] token
        B = x.size(0)
        class_token = self.cls_token.expand(B, -1, -1)  # (B, 1, D)
        x = torch.cat((class_token, x), dim=1)  # (B, 1+N, D)
        
        # 3. Add position embeddings
        x = x + self.position_embedding
        
        # 4. Pass through transformer blocks
        x = self.transformer_blocks(x)
        
        # 5. Extract [CLS] token and classify
        x = x[:, 0]  # (B, D) - only the CLS token
        x = self.mlp_head(x)
        
        return x

---

## 🏋️ Training Setup

In [None]:
# Setup device, model, optimizer, and loss function
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

model = VisionTransformer().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
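
The printed parameter total can be cross-checked by hand. The sketch below assumes the layer layout defined earlier and PyTorch's defaults (biases enabled everywhere, and `nn.MultiheadAttention` using one packed Q/K/V input projection plus an output projection); `vit_param_count` is a hypothetical helper written for this check.

```python
def vit_param_count(num_channels=1, patch_size=7, num_patches=16, embedding_dim=64,
                    transformer_blocks=4, mlp_hidden_nodes=128, num_classes=10):
    """Analytic parameter count for the ViT defined in this notebook."""
    d = embedding_dim
    patch_embed = num_channels * patch_size ** 2 * d + d          # Conv2d weight + bias
    cls_and_pos = d + (1 + num_patches) * d                       # [CLS] token + position embeddings
    layer_norm = 2 * d                                            # scale + shift
    attention = (3 * d * d + 3 * d) + (d * d + d)                 # packed Q/K/V projection + output projection
    mlp = (d * mlp_hidden_nodes + mlp_hidden_nodes) + (mlp_hidden_nodes * d + d)
    block = 2 * layer_norm + attention + mlp                      # one encoder block
    head = layer_norm + (d * num_classes + num_classes)           # LayerNorm + Linear classifier
    return patch_embed + cls_and_pos + transformer_blocks * block + head

print(f"{vit_param_count():,}")  # 139,018
```

The four encoder blocks account for 133,888 of these parameters. The number of attention heads does not change the count, because the heads only partition the same projection matrices.
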

---

## 🚀 Training Loop

Train the model and track metrics.

In [None]:
# Store metrics for visualization
train_losses = []
train_accuracies = []

# Training loop
for epoch in range(epochs):
    model.train()
    total_loss = 0
    correct_epoch = 0
    total_epoch = 0

    print(f"\n{'='*60}")
    print(f"Epoch {epoch+1}/{epochs}")
    print(f"{'='*60}")

    for batch_idx, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        
        # Forward pass
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward pass
        loss.backward()
        optimizer.step()

        # Track metrics
        total_loss += loss.item()
        preds = outputs.argmax(dim=1)
        correct = (preds == labels).sum()
        accuracy = 100 * correct / labels.size(0)
        correct_epoch += correct
        total_epoch += labels.size(0)

        # Print progress every 100 batches
        if batch_idx % 100 == 0:
            print(f"  Batch {batch_idx+1:4d}/{len(train_loader)}: Loss = {loss:.4f}, Acc = {accuracy:.2f}%")

    # Epoch summary
    epoch_acc = 100.0 * correct_epoch / total_epoch
    train_losses.append(total_loss)
    train_accuracies.append(epoch_acc.item())
    
    print(f"\n📊 Epoch {epoch+1} Summary:")
    print(f"   Total Loss: {total_loss:.4f}")
    print(f"   Accuracy:   {epoch_acc:.2f}%")

---

## 📈 Training Visualization

In [None]:
# Plot training metrics
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Loss plot
ax1.plot(range(1, epochs+1), train_losses, 'b-o', linewidth=2, markersize=8)
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Total Loss', fontsize=12)
ax1.set_title('Training Loss', fontsize=14, fontweight='bold')
ax1.grid(True, alpha=0.3)
ax1.set_xticks(range(1, epochs+1))

# Accuracy plot
ax2.plot(range(1, epochs+1), train_accuracies, 'g-o', linewidth=2, markersize=8)
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('Accuracy (%)', fontsize=12)
ax2.set_title('Training Accuracy', fontsize=14, fontweight='bold')
ax2.grid(True, alpha=0.3)
ax2.set_xticks(range(1, epochs+1))
ax2.set_ylim([80, 100])

plt.tight_layout()
plt.show()

---

## 🧪 Validation

In [None]:
# Evaluate on validation set
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

val_accuracy = 100 * correct / total
print(f"\n{'='*60}")
print(f"🎯 Validation Accuracy: {val_accuracy:.2f}%")
print(f"{'='*60}")

---

## 🔍 Visualize Predictions

In [None]:
# Visualize some predictions
model.eval()
sample_images, sample_labels = next(iter(val_loader))
sample_images = sample_images[:10].to(device)
sample_labels = sample_labels[:10]

with torch.no_grad():
    outputs = model(sample_images)
    predictions = outputs.argmax(dim=1).cpu()

fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for i, ax in enumerate(axes.flat):
    img = sample_images[i].cpu().squeeze()
    ax.imshow(img, cmap='gray')
    
    pred = predictions[i].item()
    true = sample_labels[i].item()
    color = 'green' if pred == true else 'red'
    
    ax.set_title(f'Pred: {pred} | True: {true}', fontsize=12, color=color)
    ax.axis('off')

plt.suptitle('Model Predictions (Green=Correct, Red=Wrong)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

---

## 📚 Summary

In this notebook, we implemented a **Vision Transformer (ViT)** from scratch:

✅ **Patch Embedding**: Split images into patches and embed them  
✅ **Positional Encoding**: Added learnable position embeddings  
✅ **Transformer Encoder**: Implemented Multi-Head Self-Attention + MLP  
✅ **Classification**: Used [CLS] token for final prediction  

### Results
- Achieved **~98% accuracy** on MNIST
- Only **5 epochs** of training
- Minimal architecture (~139K parameters)

### Next Steps
- Try on larger datasets (CIFAR-10, ImageNet)
- Add dropout for regularization
- Implement attention visualization
- Experiment with different patch sizes
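
As a possible starting point for the attention visualization item, the sketch below pulls per-head attention weights out of a standalone `nn.MultiheadAttention` layer. It uses a freshly initialized layer and random tokens rather than the trained model, and the names `cls_to_patches` and `cls_maps` are introduced here for illustration. The `average_attn_weights` argument requires a reasonably recent PyTorch release.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

embedding_dim, attention_heads, seq_len = 64, 4, 17    # 16 patches + [CLS]
attn = nn.MultiheadAttention(embedding_dim, attention_heads, batch_first=True)
attn.eval()

tokens = torch.randn(2, seq_len, embedding_dim)        # stand-in for encoder inputs
with torch.no_grad():
    # average_attn_weights=False keeps one attention map per head instead of their mean.
    _, weights = attn(tokens, tokens, tokens,
                      need_weights=True, average_attn_weights=False)

# weights: (B, heads, N, N). Row 0 is the [CLS] query; columns 1: are the patch keys.
cls_to_patches = weights[:, :, 0, 1:]                              # (B, heads, 16)
cls_maps = cls_to_patches.reshape(2, attention_heads, 4, 4)        # back onto the 4x4 patch grid

print(weights.shape, cls_maps.shape)
```

Each 4×4 map in `cls_maps` could then be upsampled and overlaid on the 28×28 input with `imshow` to show which patches the [CLS] token draws from.
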

---