# Lab 2.2.5: Vision Transformer (ViT)

**Module:** 2.2 - Computer Vision  
**Time:** 3 hours  
**Difficulty:** ⭐⭐⭐⭐⭐

---

## 🎯 Learning Objectives

By the end of this notebook, you will:
- [ ] Understand how Vision Transformers adapt NLP techniques for images
- [ ] Implement ViT from scratch, including patch embedding and self-attention
- [ ] Train ViT on CIFAR-10
- [ ] Compare ViT with CNN architectures

---

## 📚 Prerequisites

- Completed: Labs 2.2.1-2.2.4
- Knowledge of: Transformers (from NLP), self-attention, CNNs

---

## 🌍 Real-World Context

**Vision Transformers are revolutionizing computer vision:**

- 🖼️ **State-of-the-art**: ViT and its variants (Swin, DeiT) rank among the top ImageNet classifiers
- 🤖 **Unified architecture**: The same transformer design can process text, images, and audio
- 🧠 **Foundation models**: CLIP uses a ViT image encoder, and DALL-E and Stable Diffusion build on transformer and attention components
- 📊 **Scalability**: ViT scales better than CNNs as data and compute grow

---

## 🧒 ELI5: What is a Vision Transformer?

> **Imagine you're reading a book with pictures...**
>
> When you read, you don't look at every letter individually. You see words and sentences as chunks.
>
> **Vision Transformers treat images the same way:**
> 1. **Cut the image into patches** (like words in a sentence)
> 2. **Arrange patches in a sequence** (like words in order)
> 3. **Let patches "talk" to each other** through attention (like understanding context)
>
> The magic: Instead of looking at local neighborhoods (like CNNs), every patch can directly attend to every other patch!

### From CNN to ViT

```
CNN approach:                           ViT approach:
┌─────────────┐                        ┌─────────────┐
│ 🔍 → 🔍 → 🔍│  Local filters         │ ▢ ▢ ▢ ▢     │  Cut into patches
│ 🔍 → 🔍 → 🔍│  slide across          │ ▢ ▢ ▢ ▢     │
│ 🔍 → 🔍 → 🔍│  the image             │ ▢ ▢ ▢ ▢     │
└─────────────┘                        │ ▢ ▢ ▢ ▢     │
                                       └─────────────┘
         │                                    │
         ▼                                    ▼
   Hierarchical                        [CLS] P1 P2 P3 ... P16
   feature maps                              ↓
         │                             Transformer Layers
         ▼                                   ↓
      Output                           Classify from [CLS]
```

In [None]:
# Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from typing import Tuple, Optional, Dict, List
from tqdm.auto import tqdm
import time
import math

# DGX Spark optimizations
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision('high')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🖥️  Device: {device}")
if torch.cuda.is_available():
    print(f"🎮 GPU: {torch.cuda.get_device_name(0)}")
    print(f"💾 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

---

## Part 1: Patch Embedding

The first step is converting an image into a sequence of patch embeddings.

In [None]:
class PatchEmbedding(nn.Module):
    """
    Convert image into sequence of patch embeddings.
    
    Input:  [B, C, H, W]  - Batch of images
    Output: [B, N, D]     - Sequence of N patch embeddings of dimension D
    
    where N = (H * W) / (patch_size^2)
    """
    
    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        in_channels: int = 3,
        embed_dim: int = 768
    ):
        super(PatchEmbedding, self).__init__()
        
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        
        # Linear projection of flattened patches
        # Using Conv2d with kernel_size=patch_size is equivalent to:
        # 1. Split image into patches
        # 2. Flatten each patch
        # 3. Apply linear projection
        self.projection = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, C, H, W]
        Returns:
            [B, num_patches, embed_dim]
        """
        # Project and reshape: [B, embed_dim, H/P, W/P] -> [B, embed_dim, num_patches]
        x = self.projection(x)
        x = x.flatten(2)  # [B, embed_dim, num_patches]
        x = x.transpose(1, 2)  # [B, num_patches, embed_dim]
        return x


# Test patch embedding
patch_embed = PatchEmbedding(img_size=32, patch_size=4, embed_dim=256)
dummy_img = torch.randn(1, 3, 32, 32)
patches = patch_embed(dummy_img)

print(f"📊 Patch Embedding:")
print(f"   Input image:    {dummy_img.shape}")
print(f"   Patch size:     {patch_embed.patch_size}")
print(f"   Number patches: {patch_embed.num_patches} ({32//4} × {32//4})")
print(f"   Output:         {patches.shape}")
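
The comment in `PatchEmbedding` states that a `Conv2d` whose kernel size and stride both equal the patch size is equivalent to splitting the image into patches, flattening each one, and applying a linear projection. The following is a minimal sketch that checks this numerically. The small `embed_dim` and the variable names are illustrative choices, not part of the lab's code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

patch_size, embed_dim = 4, 16  # illustrative values, smaller than the lab's
conv = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
images = torch.randn(2, 3, 32, 32)

# Path 1: strided convolution, the same steps as PatchEmbedding.forward
conv_tokens = conv(images).flatten(2).transpose(1, 2)  # [2, 64, 16]

# Path 2: explicit split -> flatten -> linear projection
flat_patches = F.unfold(images, kernel_size=patch_size, stride=patch_size)  # [2, 48, 64]
flat_patches = flat_patches.transpose(1, 2)  # [2, 64, 48]
weight = conv.weight.reshape(embed_dim, -1)  # [16, 48], reusing the conv weights
linear_tokens = flat_patches @ weight.T + conv.bias  # [2, 64, 16]

max_diff = (conv_tokens - linear_tokens).abs().max().item()
print(f"Shapes: {tuple(conv_tokens.shape)} vs {tuple(linear_tokens.shape)}")
print(f"Max absolute difference: {max_diff:.2e}")
```

The two paths should agree up to floating-point rounding, which is why the convolution is a convenient way to express the patch projection.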

In [None]:
def visualize_patches(img: torch.Tensor, patch_size: int = 4):
    """
    Visualize how an image is split into patches.
    """
    img_np = img.squeeze().permute(1, 2, 0).numpy()
    H, W = img_np.shape[:2]
    num_patches_h = H // patch_size
    num_patches_w = W // patch_size
    
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    
    # Original image with grid
    axes[0].imshow(img_np.clip(0, 1))
    for i in range(num_patches_h + 1):
        axes[0].axhline(y=i*patch_size - 0.5, color='r', linewidth=0.5)
    for j in range(num_patches_w + 1):
        axes[0].axvline(x=j*patch_size - 0.5, color='r', linewidth=0.5)
    axes[0].set_title(f'Image with {num_patches_h}×{num_patches_w} patch grid')
    axes[0].axis('off')
    
    # Individual patches
    num_show = min(16, num_patches_h * num_patches_w)
    patches_grid = np.zeros((4 * patch_size, 4 * patch_size, 3))
    
    for idx in range(num_show):
        i = idx // num_patches_w
        j = idx % num_patches_w
        patch = img_np[i*patch_size:(i+1)*patch_size, j*patch_size:(j+1)*patch_size]
        
        grid_i = idx // 4
        grid_j = idx % 4
        patches_grid[grid_i*patch_size:(grid_i+1)*patch_size,
                    grid_j*patch_size:(grid_j+1)*patch_size] = patch
    
    axes[1].imshow(patches_grid.clip(0, 1))
    axes[1].set_title(f'First {num_show} patches (each {patch_size}×{patch_size})')
    axes[1].axis('off')
    
    plt.suptitle('🧩 Image to Patch Conversion', fontsize=14)
    plt.tight_layout()
    plt.show()

# Load a sample image
transform = transforms.ToTensor()
dataset = torchvision.datasets.CIFAR10(root='../data', train=True, download=True, transform=transform)
sample_img, label = dataset[0]

visualize_patches(sample_img.unsqueeze(0), patch_size=4)

---

## Part 2: Multi-Head Self-Attention

### 🧒 ELI5: Self-Attention

> **Imagine you're in a classroom and need to answer a question...**
>
> You look around and decide who to pay attention to:
> - **Query (Q)**: "I need information about X"
> - **Key (K)**: Each classmate holds up a sign: "I know about Y"
> - **Value (V)**: The actual information each classmate has
>
> You compare your Query to everyone's Keys, then take a weighted average of their Values!
>
> **Multi-head**: Instead of asking one question, you ask 8 different questions in parallel (like having 8 TAs helping you). Each head learns to look for different things.
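
To make the analogy concrete, here is a minimal single-head sketch of "compare your Query to everyone's Keys, then take a weighted average of their Values". The toy sizes and random tensors are assumptions chosen for readability. In a real model Q, K, and V come from learned projections of the patch embeddings.

```python
import torch

torch.manual_seed(0)

num_tokens, d_k = 4, 8  # toy sizes chosen for readability
q = torch.randn(num_tokens, d_k)  # one query per token
k = torch.randn(num_tokens, d_k)  # one key per token
v = torch.randn(num_tokens, d_k)  # one value per token

scores = q @ k.T / d_k ** 0.5     # [4, 4]: how well each query matches each key
weights = scores.softmax(dim=-1)  # each row becomes a distribution over tokens
attended = weights @ v            # [4, 8]: weighted average of the values

print(weights.sum(dim=-1))        # every row sums to 1 (up to float error)
print(attended.shape)
```

Multi-head attention runs several of these in parallel on smaller slices of the embedding and concatenates the results.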

In [None]:
class MultiHeadSelfAttention(nn.Module):
    """
    Multi-Head Self-Attention mechanism.
    
    Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
    """
    
    def __init__(
        self,
        embed_dim: int = 768,
        num_heads: int = 12,
        dropout: float = 0.0
    ):
        super(MultiHeadSelfAttention, self).__init__()
        
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5  # 1/sqrt(d_k)
        
        # Combined QKV projection (more efficient)
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: [B, N, D] - Input sequence
        Returns:
            output: [B, N, D] - Attended sequence
            attention: [B, H, N, N] - Attention weights
        """
        B, N, D = x.shape
        
        # Compute Q, K, V
        qkv = self.qkv(x)  # [B, N, 3*D]
        qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, B, H, N, head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]  # Each: [B, H, N, head_dim]
        
        # Attention scores: [B, H, N, N]
        attention = (q @ k.transpose(-2, -1)) * self.scale
        attention = attention.softmax(dim=-1)
        attention = self.dropout(attention)
        
        # Apply attention to values
        out = attention @ v  # [B, H, N, head_dim]
        out = out.transpose(1, 2).reshape(B, N, D)  # [B, N, D]
        out = self.proj(out)
        
        return out, attention


# Test attention
mhsa = MultiHeadSelfAttention(embed_dim=256, num_heads=8)
dummy_seq = torch.randn(1, 64, 256)  # [B, N, D]
out, attn = mhsa(dummy_seq)

print(f"📊 Multi-Head Self-Attention:")
print(f"   Input:    {dummy_seq.shape}")
print(f"   Output:   {out.shape}")
print(f"   Attention weights: {attn.shape}")

In [None]:
def visualize_attention(attn_weights: torch.Tensor, num_heads: int = 4):
    """
    Visualize attention patterns from different heads.
    """
    fig, axes = plt.subplots(2, num_heads//2, figsize=(12, 8))
    
    for idx, ax in enumerate(axes.flat):
        if idx >= num_heads:
            break
        
        attn = attn_weights[0, idx].detach().cpu().numpy()  # [N, N]; .cpu() keeps this working for GPU tensors
        im = ax.imshow(attn, cmap='viridis')
        ax.set_title(f'Head {idx+1}', fontsize=10)
        ax.set_xlabel('Key position')
        ax.set_ylabel('Query position')
        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    
    plt.suptitle('🔍 Attention Patterns Across Heads', fontsize=14)
    plt.tight_layout()
    plt.show()

visualize_attention(attn, num_heads=8)

---

## Part 3: Transformer Block

A Transformer block combines self-attention with a feedforward network.
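
As a rough guide to where a block's capacity sits, the sketch below counts the parameters of one block and compares the count with a closed-form expression. It uses plain `nn.Linear` and `nn.LayerNorm` stand-ins rather than the lab's classes. The layer names mirror the ones this lab uses, and the formula assumes an MLP ratio of 4.

```python
import torch.nn as nn

embed_dim, mlp_ratio = 256, 4  # the values this lab uses for CIFAR-10
hidden_dim = embed_dim * mlp_ratio

# Stand-ins for the layers inside one block (not the lab's classes themselves)
layers = nn.ModuleDict({
    "norm1": nn.LayerNorm(embed_dim),
    "qkv": nn.Linear(embed_dim, 3 * embed_dim),
    "proj": nn.Linear(embed_dim, embed_dim),
    "norm2": nn.LayerNorm(embed_dim),
    "fc1": nn.Linear(embed_dim, hidden_dim),
    "fc2": nn.Linear(hidden_dim, embed_dim),
})

counted = sum(p.numel() for p in layers.parameters())

attention_params = 4 * embed_dim**2 + 4 * embed_dim  # qkv + proj, with biases
mlp_params = 8 * embed_dim**2 + 5 * embed_dim        # fc1 + fc2, with biases (ratio 4)
norm_params = 4 * embed_dim                          # two LayerNorms, weight + bias
formula = attention_params + mlp_params + norm_params

print(f"Counted: {counted:,}  Formula: {formula:,}")
```

With these settings the MLP holds roughly twice as many parameters as the attention layers, and one block comes to about 0.79M parameters.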

In [None]:
class MLP(nn.Module):
    """MLP (Feed-Forward Network) in Transformer block."""
    
    def __init__(
        self,
        embed_dim: int = 768,
        hidden_dim: int = 3072,
        dropout: float = 0.0
    ):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.act = nn.GELU()  # ViT uses GELU activation
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class TransformerBlock(nn.Module):
    """
    Transformer encoder block.
    
    Structure:
        x → LayerNorm → MHSA → + → LayerNorm → MLP → +
            └────────────────┘   └─────────────────┘
                (residual)              (residual)
    """
    
    def __init__(
        self,
        embed_dim: int = 768,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0
    ):
        super(TransformerBlock, self).__init__()
        
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, int(embed_dim * mlp_ratio), dropout)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Self-attention with residual
        attn_out, _ = self.attn(self.norm1(x))
        x = x + attn_out
        
        # MLP with residual
        x = x + self.mlp(self.norm2(x))
        
        return x

---

## Part 4: Complete Vision Transformer

In [None]:
class VisionTransformer(nn.Module):
    """
    Vision Transformer (ViT) for image classification.
    
    Original paper: "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale"
    by Alexey Dosovitskiy et al., 2020
    
    Architecture:
        Image → Patch Embed → [CLS] + Patches + Pos Embed → Transformer Layers → [CLS] → MLP Head
    
    Args:
        img_size: Input image size
        patch_size: Size of each patch
        in_channels: Number of input channels
        num_classes: Number of output classes
        embed_dim: Embedding dimension
        depth: Number of transformer layers
        num_heads: Number of attention heads
        mlp_ratio: MLP hidden dim = embed_dim * mlp_ratio
        dropout: Dropout rate
    """
    
    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        in_channels: int = 3,
        num_classes: int = 1000,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0
    ):
        super(VisionTransformer, self).__init__()
        
        self.num_patches = (img_size // patch_size) ** 2
        self.embed_dim = embed_dim
        
        # Patch embedding
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        
        # Learnable [CLS] token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        
        # Learnable positional embeddings
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))
        
        # Dropout after embedding
        self.pos_drop = nn.Dropout(dropout)
        
        # Transformer encoder layers
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])
        
        # Final layer norm
        self.norm = nn.LayerNorm(embed_dim)
        
        # Classification head
        self.head = nn.Linear(embed_dim, num_classes)
        
        # Initialize weights
        self._init_weights()
    
    def _init_weights(self):
        """Initialize weights like in the original ViT paper."""
        # Initialize cls_token and pos_embed
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        
        # Initialize other layers
        self.apply(self._init_module_weights)
    
    def _init_module_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.zeros_(m.bias)
            nn.init.ones_(m.weight)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, C, H, W] - Input images
        Returns:
            [B, num_classes] - Class logits
        """
        B = x.shape[0]
        
        # Patch embedding: [B, num_patches, embed_dim]
        x = self.patch_embed(x)
        
        # Prepend [CLS] token
        cls_tokens = self.cls_token.expand(B, -1, -1)  # [B, 1, embed_dim]
        x = torch.cat([cls_tokens, x], dim=1)  # [B, num_patches + 1, embed_dim]
        
        # Add positional embedding
        x = x + self.pos_embed
        x = self.pos_drop(x)
        
        # Transformer blocks
        for block in self.blocks:
            x = block(x)
        
        # Final norm
        x = self.norm(x)
        
        # Classification from [CLS] token
        cls_output = x[:, 0]  # [B, embed_dim]
        logits = self.head(cls_output)  # [B, num_classes]
        
        return logits


# Test the model
model = VisionTransformer(
    img_size=32,
    patch_size=4,
    num_classes=10,
    embed_dim=256,
    depth=6,
    num_heads=8
)

dummy_img = torch.randn(1, 3, 32, 32)
output = model(dummy_img)

print(f"📊 Vision Transformer (small ViT for CIFAR-10):")
print(f"   Input shape:  {dummy_img.shape}")
print(f"   Output shape: {output.shape}")
print(f"   Parameters:   {sum(p.numel() for p in model.parameters()):,}")
print(f"\n   Configuration:")
print(f"   - Patch size:    4×4")
print(f"   - Num patches:   {model.num_patches} ({32//4}×{32//4})")
print(f"   - Embed dim:     256")
print(f"   - Num layers:    6")
print(f"   - Num heads:     8")

---

## Part 5: Training ViT on CIFAR-10

Training a ViT from scratch is tricky: these models typically need lots of data or strong regularization.

In [None]:
# Data loading with strong augmentation (important for ViT!)
# NOTE: When using num_workers > 0 in Docker, use --ipc=host flag
# Example: docker run --gpus all --ipc=host ...

def get_cifar10_loaders_vit(batch_size: int = 128) -> Tuple[DataLoader, DataLoader]:
    """
    Create CIFAR-10 loaders with strong augmentation for ViT.
    
    ViT is data-hungry, so we use aggressive augmentation.
    
    Args:
        batch_size: Batch size. DGX Spark can handle 256+ due to 128GB memory.
    """
    # Strong augmentation for training
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandAugment(num_ops=2, magnitude=9),  # Strong augmentation
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        transforms.RandomErasing(p=0.25),  # Cutout-like
    ])
    
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    
    train_dataset = torchvision.datasets.CIFAR10(
        root='../data', train=True, download=True, transform=train_transform
    )
    test_dataset = torchvision.datasets.CIFAR10(
        root='../data', train=False, download=True, transform=test_transform
    )
    
    # num_workers=4 requires --ipc=host when running in Docker
    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              shuffle=True, num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size,
                             shuffle=False, num_workers=4, pin_memory=True)
    
    return train_loader, test_loader

train_loader, test_loader = get_cifar10_loaders_vit(batch_size=128)
print(f"📊 Dataset loaded with strong augmentation")
print(f"   Training:   {len(train_loader.dataset):,} images")
print(f"   Test:       {len(test_loader.dataset):,} images")

In [None]:
def train_vit(
    model: nn.Module,
    train_loader: DataLoader,
    test_loader: DataLoader,
    epochs: int = 50,
    lr: float = 0.001,
    weight_decay: float = 0.1,
    warmup_epochs: int = 5,
    device: torch.device = device
) -> Dict[str, List[float]]:
    """
    Train Vision Transformer with warmup and cosine annealing.
    """
    model = model.to(device)
    
    # Use label smoothing (helps regularization)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    
    # AdamW optimizer (important for transformers!)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    
    # Warmup + cosine annealing scheduler
    warmup_scheduler = optim.lr_scheduler.LinearLR(
        optimizer, start_factor=0.01, total_iters=warmup_epochs * len(train_loader)
    )
    cosine_scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=(epochs - warmup_epochs) * len(train_loader)
    )
    scheduler = optim.lr_scheduler.SequentialLR(
        optimizer, [warmup_scheduler, cosine_scheduler],
        milestones=[warmup_epochs * len(train_loader)]
    )
    
    history = {'train_loss': [], 'train_acc': [], 'test_loss': [], 'test_acc': []}
    
    for epoch in range(epochs):
        # Training
        model.train()
        train_loss, correct, total = 0, 0, 0
        
        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}')
        for step, (images, labels) in enumerate(pbar, start=1):
            images, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            
            # Gradient clipping (important for transformers!)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            scheduler.step()
            
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            # loss.item() is a per-batch mean, so average over the batches seen so far
            pbar.set_postfix({
                'loss': f'{train_loss/step:.4f}',
                'acc': f'{100.*correct/total:.1f}%',
                'lr': f'{scheduler.get_last_lr()[0]:.6f}'
            })
        
        history['train_loss'].append(train_loss / len(train_loader))
        history['train_acc'].append(100. * correct / total)
        
        # Evaluation
        model.eval()
        test_loss, correct, total = 0, 0, 0
        
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                
                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
        
        history['test_loss'].append(test_loss / len(test_loader))
        history['test_acc'].append(100. * correct / total)
        
        print(f"   Test: Loss={history['test_loss'][-1]:.4f}, Acc={history['test_acc'][-1]:.1f}%")
    
    return history

In [None]:
# Train ViT
print("🏋️ Training Vision Transformer on CIFAR-10...")
print("="*50)
print("This may take a while - ViT needs careful training!")
print("="*50)

vit_model = VisionTransformer(
    img_size=32,
    patch_size=4,
    num_classes=10,
    embed_dim=256,
    depth=6,
    num_heads=8,
    mlp_ratio=4.0,
    dropout=0.1
)

start_time = time.time()

vit_history = train_vit(
    vit_model,
    train_loader,
    test_loader,
    epochs=20,  # Use more epochs (50+) for better results
    lr=0.001,
    weight_decay=0.1,
    warmup_epochs=2
)

vit_time = time.time() - start_time
print(f"\n✅ Training complete in {vit_time/60:.1f} minutes")
print(f"   Best test accuracy: {max(vit_history['test_acc']):.1f}%")

In [None]:
# Plot training history
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].plot(vit_history['train_loss'], label='Train', linewidth=2)
axes[0].plot(vit_history['test_loss'], label='Test', linewidth=2)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('📉 Training and Test Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

axes[1].plot(vit_history['train_acc'], label='Train', linewidth=2)
axes[1].plot(vit_history['test_acc'], label='Test', linewidth=2)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy (%)')
axes[1].set_title('📈 Training and Test Accuracy')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.suptitle('Vision Transformer Training on CIFAR-10', fontsize=14)
plt.tight_layout()
plt.show()

---

## Part 6: Visualizing What ViT Learns

In [None]:
def visualize_positional_embeddings(model: VisionTransformer):
    """
    Visualize the learned positional embeddings.
    
    Each patch position learns a unique embedding that encodes spatial information.
    """
    # Check for sklearn dependency
    try:
        from sklearn.decomposition import PCA
        has_sklearn = True
    except ImportError:
        print("⚠️ scikit-learn not installed. PCA visualization will be skipped.")
        print("   Install with: pip install scikit-learn")
        has_sklearn = False
    
    pos_embed = model.pos_embed[0, 1:].detach().cpu()  # Exclude [CLS]
    
    # Compute similarity between positional embeddings
    pos_embed_norm = F.normalize(pos_embed, dim=-1)
    similarity = pos_embed_norm @ pos_embed_norm.T
    
    num_patches = int(math.sqrt(pos_embed.shape[0]))
    
    # Adjust figure layout based on sklearn availability
    num_cols = 3 if has_sklearn else 2
    fig, axes = plt.subplots(1, num_cols, figsize=(5 * num_cols, 4))
    
    # Similarity matrix
    im0 = axes[0].imshow(similarity.numpy(), cmap='viridis')
    axes[0].set_title('Positional Embedding Similarity')
    axes[0].set_xlabel('Patch position')
    axes[0].set_ylabel('Patch position')
    plt.colorbar(im0, ax=axes[0])
    
    # Show similarity to center patch
    center_idx = (num_patches // 2) * num_patches + num_patches // 2  # row-major index of the middle patch
    center_sim = similarity[center_idx].reshape(num_patches, num_patches).numpy()
    im1 = axes[1].imshow(center_sim, cmap='RdBu_r', vmin=-1, vmax=1)
    axes[1].set_title(f'Similarity to center patch (idx={center_idx})')
    plt.colorbar(im1, ax=axes[1])
    
    # PCA visualization of embeddings (if sklearn available)
    if has_sklearn:
        pca = PCA(n_components=3)
        pos_pca = pca.fit_transform(pos_embed.numpy())
        pos_pca_img = pos_pca.reshape(num_patches, num_patches, 3)
        pos_pca_img = (pos_pca_img - pos_pca_img.min()) / (pos_pca_img.max() - pos_pca_img.min())
        
        axes[2].imshow(pos_pca_img)
        axes[2].set_title('Position Embeddings (PCA → RGB)')
        axes[2].axis('off')
    
    axes[0].set_xticks([])
    axes[0].set_yticks([])
    axes[1].axis('off')
    
    plt.suptitle('🧠 What ViT Learns: Positional Embeddings', fontsize=14)
    plt.tight_layout()
    plt.show()

# Visualize positional embeddings
visualize_positional_embeddings(vit_model)

In [None]:
def visualize_attention_maps(model: VisionTransformer, image: torch.Tensor):
    """
    Visualize attention maps from the last transformer layer.
    """
    model.eval()
    
    # Hook to capture attention
    attention_maps = []
    
    def hook_fn(module, input, output):
        _, attn = output
        attention_maps.append(attn.detach().cpu())
    
    # Register hook on last attention layer
    hook = model.blocks[-1].attn.register_forward_hook(hook_fn)
    
    with torch.no_grad():
        _ = model(image.unsqueeze(0).to(device))
    
    hook.remove()
    
    if not attention_maps:
        print("No attention captured")
        return
    
    # Get attention from [CLS] token to patches
    attn = attention_maps[0][0]  # [num_heads, num_tokens, num_tokens]
    cls_attn = attn[:, 0, 1:]  # [num_heads, num_patches]
    
    num_patches = int(math.sqrt(cls_attn.shape[1]))
    
    fig, axes = plt.subplots(2, 5, figsize=(15, 6))
    
    # Original image
    img_np = image.permute(1, 2, 0).numpy()
    img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min())
    
    axes[0, 0].imshow(img_np)
    axes[0, 0].set_title('Input Image')
    axes[0, 0].axis('off')
    
    # Attention heads
    for idx, ax in enumerate(axes.flat[1:]):
        if idx >= cls_attn.shape[0]:
            ax.axis('off')
            continue
        
        attn_map = cls_attn[idx].reshape(num_patches, num_patches).numpy()
        im = ax.imshow(attn_map, cmap='hot')
        ax.set_title(f'Head {idx+1}')
        ax.axis('off')
    
    plt.suptitle('👁️ Attention from [CLS] Token to Image Patches', fontsize=14)
    plt.tight_layout()
    plt.show()

# Visualize on a sample
sample_img, _ = test_loader.dataset[0]
visualize_attention_maps(vit_model, sample_img)

---

## ✋ Try It Yourself

1. **Try different patch sizes**: 2×2, 4×4, 8×8 - how does this affect accuracy and speed?
2. **Experiment with depth**: 4 vs 8 vs 12 layers
3. **Compare with ResNet**: Train a ResNet-18 and compare

<details>
<summary>💡 Hint</summary>

Smaller patches = more tokens = more compute but potentially better accuracy:

```python
# Patch size 2 (256 patches for 32×32 image)
model_p2 = VisionTransformer(patch_size=2, ...)

# Patch size 8 (16 patches for 32×32 image)
model_p8 = VisionTransformer(patch_size=8, ...)
```

</details>
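
Before training anything, it can help to estimate how patch size changes the sequence length. The sketch below is plain arithmetic for a 32×32 image. It counts the `[CLS]` token and uses the number of attention-matrix entries per head as a rough proxy for attention cost.

```python
img_size = 32  # CIFAR-10 resolution

rows = []
for patch_size in (2, 4, 8):
    num_patches = (img_size // patch_size) ** 2
    num_tokens = num_patches + 1    # +1 for the [CLS] token
    attn_entries = num_tokens ** 2  # attention matrix entries per head, per layer
    rows.append((patch_size, num_patches, num_tokens, attn_entries))
    print(f"patch {patch_size}x{patch_size}: {num_patches:>3} patches, "
          f"{num_tokens:>3} tokens, {attn_entries:>6,} attention entries per head")
```

Halving the patch size roughly quadruples the token count and multiplies the attention matrix by about sixteen, which is why small patches get expensive quickly.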

In [None]:
# YOUR CODE HERE



---

## ⚠️ Common Mistakes

### Mistake 1: Training without warmup

```python
# ❌ Wrong: No warmup (training can diverge)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# ✅ Right: Linear warmup for transformers
warmup_scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=warmup_steps)
```
**Why:** Transformers are sensitive to learning rate at the start. Warmup stabilizes training.
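
To see the shape of a warmup-plus-cosine schedule without training a model, the sketch below chains `LinearLR` and `CosineAnnealingLR` with `SequentialLR` on a single dummy parameter and records the learning rate at every step. The step counts and the dummy parameter are assumptions made for illustration. In the lab the counts come from the number of epochs and batches.

```python
import torch
import torch.optim as optim

base_lr, warmup_steps, total_steps = 1e-3, 10, 100  # toy step counts

param = torch.nn.Parameter(torch.zeros(1))  # dummy parameter, never trained
optimizer = optim.AdamW([param], lr=base_lr)

warmup = optim.lr_scheduler.LinearLR(
    optimizer, start_factor=0.01, total_iters=warmup_steps
)
cosine = optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=total_steps - warmup_steps
)
scheduler = optim.lr_scheduler.SequentialLR(
    optimizer, [warmup, cosine], milestones=[warmup_steps]
)

lrs = []
for step in range(total_steps):
    lrs.append(optimizer.param_groups[0]["lr"])  # LR in effect for this step
    optimizer.step()   # no gradients exist, so the parameter is left unchanged
    scheduler.step()

print(f"first step: {lrs[0]:.2e}")
print(f"peak:       {max(lrs):.2e} at step {lrs.index(max(lrs))}")
print(f"last step:  {lrs[-1]:.2e}")
```

The learning rate should climb from 1% of the base value during warmup, reach the base value, and then decay toward zero. Plotting `lrs` with matplotlib shows the full curve.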

### Mistake 2: Not using gradient clipping

```python
# ❌ Wrong: No gradient clipping
loss.backward()
optimizer.step()

# ✅ Right: Clip gradients
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```
**Why:** Attention can have large gradients, especially early in training.
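
The effect of clipping is easy to inspect on a toy layer. In this sketch the loss is deliberately scaled up to produce large gradients. The layer, the scaling factor, and the variable names are illustrative assumptions rather than part of the lab's training loop.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

toy_layer = nn.Linear(16, 4)
inputs = torch.randn(8, 16)
loss = (toy_layer(inputs) * 100.0).pow(2).mean()  # scaled up to force large gradients
loss.backward()

# clip_grad_norm_ returns the total gradient norm measured before clipping
norm_before = torch.nn.utils.clip_grad_norm_(toy_layer.parameters(), max_norm=1.0)

# Recompute the total norm from the rescaled gradients
norm_after = torch.sqrt(sum(p.grad.pow(2).sum() for p in toy_layer.parameters()))

print(f"grad norm before clipping: {norm_before.item():.2f}")
print(f"grad norm after clipping:  {norm_after.item():.4f}")
```

Clipping rescales all gradients by one shared factor, so the update direction is preserved and only its length is capped at `max_norm`.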

### Mistake 3: Insufficient data/augmentation

```python
# ❌ Wrong: Weak augmentation
transform = transforms.ToTensor()

# ✅ Right: Strong augmentation for ViT
transform = transforms.Compose([
    transforms.RandAugment(num_ops=2, magnitude=9),
    transforms.ToTensor(),
    transforms.RandomErasing(p=0.25),
])
```
**Why:** ViT lacks the inductive biases of CNNs, such as locality and translation equivariance, so it needs more data or stronger augmentation.

---

## 🎉 Checkpoint

You've learned:
- ✅ How ViT converts images into sequences via patch embedding
- ✅ Multi-head self-attention mechanism
- ✅ Complete ViT architecture with [CLS] token and positional embeddings
- ✅ Training techniques for transformers (warmup, gradient clipping, strong augmentation)
- ✅ Visualizing attention patterns

---

## 🚀 Challenge (Optional)

**Implement DeiT (Data-efficient Image Transformer) improvements:**

DeiT adds a "distillation token" that learns from a CNN teacher:

1. Train a ResNet teacher
2. Add a distillation token to ViT (similar to [CLS])
3. Train ViT to match both:
   - True labels (cross-entropy loss)
   - Teacher predictions (distillation loss)

<details>
<summary>💡 Starting Code</summary>

```python
class DeiT(VisionTransformer):
    def __init__(self, ...):
        super().__init__(...)
        # Add distillation token
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # New position embedding for +1 token
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, embed_dim))
        # Distillation head
        self.dist_head = nn.Linear(embed_dim, num_classes)
    
    def forward(self, x):
        # ... similar to ViT but includes dist_token
        cls_output = x[:, 0]
        dist_output = x[:, 1]
        return self.head(cls_output), self.dist_head(dist_output)
```

</details>
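
For step 3, one possible way to combine the two objectives is a hard-label distillation loss: the `[CLS]` head is trained against the true labels and the distillation head against the teacher's predicted class. The sketch below is a hedged illustration, not the reference implementation. The function name `hard_distillation_loss`, the equal 0.5/0.5 weighting, and the random stand-in logits are all assumptions of this example.

```python
import torch
import torch.nn.functional as F


def hard_distillation_loss(cls_logits, dist_logits, teacher_logits, labels):
    """Average label CE on the [CLS] head with teacher-argmax CE on the distillation head."""
    teacher_labels = teacher_logits.argmax(dim=1)             # teacher's hard predictions
    loss_cls = F.cross_entropy(cls_logits, labels)            # match the true labels
    loss_dist = F.cross_entropy(dist_logits, teacher_labels)  # match the teacher
    return 0.5 * loss_cls + 0.5 * loss_dist                   # assumed equal weighting


torch.manual_seed(0)
batch, num_classes = 8, 10
cls_logits = torch.randn(batch, num_classes, requires_grad=True)
dist_logits = torch.randn(batch, num_classes, requires_grad=True)
teacher_logits = torch.randn(batch, num_classes)  # stand-in for a frozen teacher's output
labels = torch.randint(0, num_classes, (batch,))

loss = hard_distillation_loss(cls_logits, dist_logits, teacher_logits, labels)
loss.backward()
print(f"combined loss: {loss.item():.4f}")
```

In a real training loop the teacher would run under `torch.no_grad()` so that only the student receives gradients.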

In [None]:
# YOUR CHALLENGE CODE HERE



---

## 📖 Further Reading

- [ViT Paper](https://arxiv.org/abs/2010.11929) - Original Vision Transformer
- [DeiT Paper](https://arxiv.org/abs/2012.12877) - Data-efficient training
- [Swin Transformer](https://arxiv.org/abs/2103.14030) - Hierarchical ViT
- [The Illustrated Transformer](http://jalammar.github.io/illustrated-transformer/) - Great visual explanation

---

## 🧹 Cleanup

In [None]:
# Clear GPU memory
import gc

del vit_model
gc.collect()              # drop unreachable Python objects first
torch.cuda.empty_cache()  # then release cached GPU memory blocks

print("✅ Cleanup complete!")
if torch.cuda.is_available():
    print(f"💾 GPU Memory Free: {torch.cuda.mem_get_info()[0] / 1e9:.1f} GB")