# Lab 2.2.5 Solution: Vision Transformer

**Module:** 2.2 - Computer Vision  
**Type:** Solution Notebook

---

This notebook contains solutions for Vision Transformer exercises, including patch size analysis and DeiT implementation.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Optional

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Device: {device}")

## Exercise Solution: Patch Embedding

The foundation of Vision Transformers: converting images into sequences of patch embeddings.

In [None]:
class PatchEmbedding(nn.Module):
    """
    Turn an image batch into a sequence of patch embedding vectors.

    A convolution whose kernel size and stride both equal the patch size
    cuts the image into non-overlapping patches and applies the same linear
    projection to each of them in one operation.

    Args:
        img_size: Side length of the (square) input image
        patch_size: Side length of each (square) patch
        in_channels: Channel count of the input image
        embed_dim: Length of every patch embedding vector
    """

    def __init__(self, img_size: int, patch_size: int, in_channels: int, embed_dim: int):
        super().__init__()
        patches_per_side = img_size // patch_size

        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = patches_per_side * patches_per_side

        # kernel == stride, so every patch is visited exactly once
        self.projection = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Image batch [B, C, H, W]

        Returns:
            Patch embeddings [B, num_patches, embed_dim]
        """
        feature_map = self.projection(x)  # [B, embed_dim, H/P, W/P]
        tokens = torch.flatten(feature_map, start_dim=2)  # [B, embed_dim, num_patches]
        return tokens.permute(0, 2, 1)  # [B, num_patches, embed_dim]


# Smoke test: one 32x32 RGB image cut into 4x4 patches
patch_embed = PatchEmbedding(img_size=32, patch_size=4, in_channels=3, embed_dim=256)
x = torch.randn(1, 3, 32, 32)
patches = patch_embed(x)

# Single print call; sep="\n" yields the same lines as separate prints
print(
    "Patch Embedding",
    "=" * 50,
    f"Input shape: {x.shape}",
    f"Output shape: {patches.shape}",
    f"Number of patches: {patch_embed.num_patches}",
    sep="\n",
)

## Exercise Solution: Comparing Different Patch Sizes

Analyzing the trade-offs between different patch sizes.

In [None]:
def compare_patch_sizes(
    img_size: int = 32,
    embed_dim: int = 256,
    patch_sizes: Optional[Tuple[int, ...]] = None,
):
    """
    Compare different patch sizes for ViT.

    Prints one table row per candidate patch size and returns the same
    numbers as a list of dicts.

    Trade-offs:
    - Smaller patches = more tokens = more compute but potentially better accuracy
    - Larger patches = fewer tokens = less compute but may lose fine details

    Args:
        img_size: Input image size
        embed_dim: Embedding dimension
        patch_sizes: Candidate patch sizes to tabulate. Defaults to
            (2, 4, 8, 16). Sizes that do not divide img_size evenly are skipped.

    Returns:
        List of dicts with keys 'patch_size', 'num_patches', 'seq_length',
        'params' (parameter count of the patch embedding layer) and
        'flops_attention' (rough attention cost, seq_length^2 * embed_dim).
    """
    if patch_sizes is None:
        patch_sizes = (2, 4, 8, 16)

    print(f"Patch Size Comparison for {img_size}x{img_size} images")
    print("="*70)
    print(f"{'Patch Size':<12} {'Num Patches':<15} {'Seq Length':<15} {'Params':<15} {'FLOPs Est.':<15}")
    print("-"*70)

    results = []

    for patch_size in patch_sizes:
        # Only patch sizes that tile the image exactly are meaningful here
        if img_size % patch_size != 0:
            continue

        patch_embed = PatchEmbedding(img_size, patch_size, 3, embed_dim)
        num_patches = patch_embed.num_patches
        seq_length = num_patches + 1  # +1 for CLS token
        params = sum(p.numel() for p in patch_embed.parameters())

        # Rough FLOPs estimate for attention: O(seq_length^2 * embed_dim)
        flops_attention = seq_length ** 2 * embed_dim

        # Fill every column announced in the header, including the FLOPs estimate
        print(
            f"{patch_size:<12} {num_patches:<15} {seq_length:<15} "
            f"{params:<15,} {flops_attention:<15,}"
        )

        results.append({
            'patch_size': patch_size,
            'num_patches': num_patches,
            'seq_length': seq_length,
            'params': params,
            'flops_attention': flops_attention,
        })

    return results


# Patch-size table at CIFAR-10 resolution (32x32)
results = compare_patch_sizes(img_size=32)

# Commentary on the table above, emitted with one print call
print(
    "\n" + "=" * 70,
    "Analysis:",
    "-" * 70,
    "- Patch size 2: 256 patches - Very detailed but computationally expensive",
    "- Patch size 4: 64 patches - Good balance for small images like CIFAR",
    "- Patch size 8: 16 patches - Fast but may miss fine details",
    "- Patch size 16: 4 patches - Too few for meaningful attention patterns",
    sep="\n",
)

In [None]:
# Compare for ImageNet sized images (224x224)
# NOTE: the table printed by compare_patch_sizes covers patch sizes 2, 4, 8
# and 16 (all of which divide 224), whereas the notes below discuss 14, 16
# and 32 -- only patch size 16 appears in both.
print("\n" + "="*70)
results_imagenet = compare_patch_sizes(img_size=224)

# Patch counts quoted below are (224 // patch_size) ** 2
print("\nFor 224x224 images (ImageNet):")
print("- Patch size 14: 256 patches - Common choice (ViT-B/14)")
print("- Patch size 16: 196 patches - Standard choice (ViT-B/16)")
print("- Patch size 32: 49 patches - Fast but coarse (ViT-B/32)")

## Exercise Solution: Simple Vision Transformer

A basic ViT implementation from scratch.

In [None]:
class VisionTransformer(nn.Module):
    """
    Minimal Vision Transformer (ViT) classifier.

    Pipeline:
    1. Patch embedding of the input image
    2. Learnable CLS token prepended, learnable positional embedding added
    3. Stack of pre-LayerNorm Transformer encoder layers
    4. Final LayerNorm and a linear head applied to the CLS token

    Args:
        img_size: Input image size (assumed square)
        patch_size: Size of each patch (assumed square)
        in_channels: Number of input channels
        num_classes: Number of output classes
        embed_dim: Token embedding dimension
        depth: Number of encoder layers
        num_heads: Attention heads per encoder layer
        mlp_ratio: Feed-forward hidden size as a multiple of embed_dim
        dropout: Dropout probability (embeddings and encoder layers)
    """

    def __init__(
        self,
        img_size: int = 32,
        patch_size: int = 4,
        in_channels: int = 3,
        num_classes: int = 10,
        embed_dim: int = 256,
        depth: int = 6,
        num_heads: int = 8,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1
    ):
        super().__init__()

        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side * patches_per_side

        # Image -> sequence of patch tokens
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)

        # Learnable CLS token plus one positional vector per token (CLS included)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(dropout)

        # Pre-LayerNorm encoder stack
        feedforward_dim = int(embed_dim * mlp_ratio)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                dim_feedforward=feedforward_dim,
                dropout=dropout,
                activation='gelu',
                batch_first=True,
                norm_first=True,
            ),
            num_layers=depth,
        )

        # Final normalisation followed by the classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        self._init_weights()

    def _init_weights(self):
        """Truncated-normal init for tokens/positions/head weight; zero head bias."""
        for tensor in (self.cls_token, self.pos_embed, self.head.weight):
            nn.init.trunc_normal_(tensor, std=0.02)
        nn.init.zeros_(self.head.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input image [B, C, H, W]

        Returns:
            Class logits [B, num_classes]
        """
        batch_size = x.size(0)

        tokens = self.patch_embed(x)  # [B, num_patches, embed_dim]

        # One CLS token per sample goes in front of the patch tokens
        cls = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls, tokens), dim=1)  # [B, num_patches + 1, embed_dim]

        # Inject position information, then regularise
        tokens = self.pos_drop(tokens + self.pos_embed)

        encoded = self.norm(self.encoder(tokens))

        # Only the CLS position feeds the classifier
        return self.head(encoded[:, 0])


# Smoke test: CIFAR-sized ViT on a random batch of two images
vit = VisionTransformer(img_size=32, patch_size=4, num_classes=10,
                        embed_dim=256, depth=6, num_heads=8)

x = torch.randn(2, 3, 32, 32)
output = vit(x)

# Single print call; sep="\n" yields the same lines as separate prints
print(
    "Vision Transformer",
    "=" * 50,
    f"Input shape: {x.shape}",
    f"Output shape: {output.shape}",
    f"Parameters: {sum(p.numel() for p in vit.parameters()):,}",
    sep="\n",
)

## Exercise Solution: DeiT with Distillation Token

Data-efficient Image Transformer (DeiT) adds a distillation token that learns from a CNN teacher.

**Key insight**: The distillation token provides a second "view" of the classification that can learn from a CNN's inductive biases.

In [None]:
class DeiT(nn.Module):
    """
    Data-efficient Image Transformer (DeiT).

    Adds a distillation token that learns from a CNN teacher.
    Paper: "Training data-efficient image transformers" (Touvron et al., 2021)

    Key differences from ViT:
    1. Distillation token alongside CLS token
    2. Two classification heads (one for each token)
    3. During inference, average both predictions

    The encoder itself is the same pre-LayerNorm Transformer stack used by
    the ViT in this notebook.

    Args:
        img_size: Input image size (assumed square)
        patch_size: Size of each patch (assumed square)
        in_channels: Number of input channels
        num_classes: Number of output classes
        embed_dim: Token embedding dimension
        depth: Number of encoder layers
        num_heads: Attention heads per encoder layer
        mlp_ratio: Feed-forward hidden size as a multiple of embed_dim
        dropout: Dropout probability (embeddings and encoder layers)
    """

    def __init__(
        self,
        img_size: int = 32,
        patch_size: int = 4,
        in_channels: int = 3,
        num_classes: int = 10,
        embed_dim: int = 256,
        depth: int = 6,
        num_heads: int = 8,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1
    ):
        super().__init__()

        num_patches = (img_size // patch_size) ** 2

        # Patch embedding (kernel == stride -> non-overlapping patches)
        self.patch_embed = nn.Conv2d(in_channels, embed_dim, patch_size, stride=patch_size)

        # CLS token and DISTILLATION token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Position embedding for patches + cls + dist tokens
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, embed_dim))

        self.pos_drop = nn.Dropout(dropout)

        # Transformer encoder. norm_first=True (Pre-LayerNorm) matches the ViT
        # encoder; with the post-LN default the last layer's LayerNorm would be
        # followed directly by self.norm, i.e. two LayerNorms back to back.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=int(embed_dim * mlp_ratio),
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)

        # Final layer norm
        self.norm = nn.LayerNorm(embed_dim)

        # Two classification heads
        self.head = nn.Linear(embed_dim, num_classes)  # For CLS token
        self.dist_head = nn.Linear(embed_dim, num_classes)  # For distillation token

        # Initialize
        self._init_weights()

    def _init_weights(self):
        """Truncated-normal init for tokens/positions/head weights; zero head biases."""
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.dist_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        # Both heads get the same treatment as the ViT head: small weights, zero bias
        for head in (self.head, self.dist_head):
            nn.init.trunc_normal_(head.weight, std=0.02)
            nn.init.zeros_(head.bias)

    def forward(self, x: torch.Tensor):
        """
        Args:
            x: Input image [B, C, H, W]

        Returns:
            Training mode: tuple (cls_logits, dist_logits), each [B, num_classes].
            Eval mode: a single tensor [B, num_classes], the mean of both heads.
        """
        B = x.shape[0]

        # Patch embedding
        x = self.patch_embed(x).flatten(2).transpose(1, 2)  # [B, N, D]

        # Prepend CLS and DIST tokens
        cls_tokens = self.cls_token.expand(B, -1, -1)
        dist_tokens = self.dist_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, dist_tokens, x], dim=1)  # [B, N+2, D]

        # Add positional embedding
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # Transformer
        x = self.encoder(x)
        x = self.norm(x)

        # Two outputs
        cls_output = self.head(x[:, 0])  # From CLS token
        dist_output = self.dist_head(x[:, 1])  # From DIST token

        # During inference, average both predictions
        if not self.training:
            return (cls_output + dist_output) / 2

        return cls_output, dist_output


# Smoke test: DeiT on a random batch of two CIFAR-sized images
deit = DeiT(img_size=32, patch_size=4, num_classes=10)
x = torch.randn(2, 3, 32, 32)

# Training mode yields one logit tensor per token (CLS, DIST)
deit.train()
cls_out, dist_out = deit(x)
print(
    "DeiT (Training Mode)",
    "=" * 50,
    f"CLS output shape: {cls_out.shape}",
    f"DIST output shape: {dist_out.shape}",
    sep="\n",
)

# Eval mode yields the average of both heads as a single tensor
deit.eval()
out = deit(x)
print(
    "\nDeiT (Inference Mode)",
    "=" * 50,
    f"Combined output shape: {out.shape}",
    f"Parameters: {sum(p.numel() for p in deit.parameters()):,}",
    sep="\n",
)

## Exercise Solution: Distillation Training

How to train DeiT with knowledge distillation from a CNN teacher.

In [None]:
class DistillationLoss(nn.Module):
    """
    Distillation loss for DeiT training.

    Combines:
    1. Hard target loss (cross-entropy with true labels)
    2. Soft target loss (KL divergence with teacher outputs)

    The distillation token learns from the teacher,
    while the CLS token learns from the ground truth.

    The teacher is treated as a fixed reference: its parameters are frozen
    and it is kept in eval mode even when this module (or a parent that
    contains it) is switched to training mode.

    Args:
        teacher_model: Teacher network mapping images to class logits
        temperature: Softmax temperature for the soft targets (must be > 0)
        alpha: Weight of the hard loss; (1 - alpha) weights the soft loss
            (must lie in [0, 1])

    Raises:
        ValueError: If temperature is not positive or alpha is outside [0, 1]
    """

    def __init__(
        self,
        teacher_model: nn.Module,
        temperature: float = 3.0,
        alpha: float = 0.5
    ):
        super().__init__()
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {alpha}")

        self.teacher = teacher_model
        self.teacher.eval()
        # The teacher is never optimised; freezing keeps its weights out of
        # any optimizer built from criterion.parameters().
        for param in self.teacher.parameters():
            param.requires_grad_(False)

        self.temperature = temperature
        self.alpha = alpha
        self.ce = nn.CrossEntropyLoss()
        self.kl = nn.KLDivLoss(reduction='batchmean')

    def train(self, mode: bool = True):
        """
        Set the training mode of this module while pinning the teacher to eval.

        nn.Module.train() recurses into every submodule, which would otherwise
        re-enable dropout / BatchNorm updates inside the teacher.
        """
        super().train(mode)
        self.teacher.eval()
        return self

    def forward(self, student_cls: torch.Tensor, student_dist: torch.Tensor,
                images: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Args:
            student_cls: CLS token output from student
            student_dist: DIST token output from student
            images: Input images (for teacher inference)
            labels: Ground truth labels

        Returns:
            Scalar loss: alpha * hard_loss + (1 - alpha) * soft_loss
        """
        # Get teacher predictions (no gradient flows into the teacher)
        with torch.no_grad():
            teacher_logits = self.teacher(images)

        # Hard loss: CLS token vs ground truth
        hard_loss = self.ce(student_cls, labels)

        # Soft loss: DIST token vs teacher (with temperature).
        # KLDivLoss expects log-probabilities as input and probabilities as target.
        soft_student = F.log_softmax(student_dist / self.temperature, dim=-1)
        soft_teacher = F.softmax(teacher_logits / self.temperature, dim=-1)
        # T^2 compensates for the 1/T^2 gradient scaling caused by the temperature
        soft_loss = self.kl(soft_student, soft_teacher) * (self.temperature ** 2)

        # Combined loss
        total_loss = self.alpha * hard_loss + (1 - self.alpha) * soft_loss

        return total_loss


# Confirmation message followed by a short usage sketch, emitted in one call
print(
    "DistillationLoss defined!",
    "\nUsage:",
    "  teacher = load_pretrained_resnet()  # CNN teacher",
    "  criterion = DistillationLoss(teacher, temperature=3.0, alpha=0.5)",
    "  ",
    "  # In training loop:",
    "  cls_out, dist_out = deit(images)",
    "  loss = criterion(cls_out, dist_out, images, labels)",
    sep="\n",
)

## Exercise Solution: Position Embedding Visualization

Visualizing what position embeddings learn.

In [None]:
def visualize_position_embeddings(
    pos_embed: torch.Tensor,
    grid_size: int = 8,
    num_prefix_tokens: int = 1,
):
    """
    Visualize position embedding similarities.

    Shows how similar each position's embedding is to every other position.
    Ideally, nearby positions should have similar embeddings.

    Draws three panels (full similarity matrix, similarity to the center
    patch, similarity to the top-left patch) and displays them with
    plt.show(). Nothing is returned.

    Args:
        pos_embed: Position embedding [1, num_prefix_tokens + num_patches, embed_dim]
        grid_size: Number of patches per image side; num_patches must equal
            grid_size ** 2
        num_prefix_tokens: Number of non-patch tokens at the start of the
            sequence to skip (1 for ViT's CLS token, 2 for DeiT's CLS + DIST)

    Raises:
        ValueError: If the number of patch positions is not grid_size ** 2
    """
    # Drop the leading non-patch tokens (CLS, and DIST if present)
    patch_pos_embed = pos_embed[0, num_prefix_tokens:, :].detach()  # [num_patches, embed_dim]

    num_patches = patch_pos_embed.shape[0]
    if num_patches != grid_size * grid_size:
        raise ValueError(
            f"Expected {grid_size * grid_size} patch positions for a "
            f"{grid_size}x{grid_size} grid, but pos_embed has {num_patches} "
            f"after skipping {num_prefix_tokens} prefix token(s)"
        )

    # Compute cosine similarity matrix
    patch_pos_embed = F.normalize(patch_pos_embed, dim=-1)
    similarity = torch.mm(patch_pos_embed, patch_pos_embed.T)  # [N, N]
    similarity_np = similarity.cpu().numpy()

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Full similarity matrix
    ax = axes[0]
    im = ax.imshow(similarity_np, cmap='viridis')
    ax.set_title('Position Embedding Similarity Matrix')
    ax.set_xlabel('Position')
    ax.set_ylabel('Position')
    plt.colorbar(im, ax=ax)

    # Similarity from center patch. Patches are stored row-major, so the
    # flat index of (row, col) is row * grid_size + col; this lands on the
    # true center for odd grids as well as even ones.
    center_row = center_col = grid_size // 2
    center_idx = center_row * grid_size + center_col
    center_sim = similarity_np[center_idx].reshape(grid_size, grid_size)

    ax = axes[1]
    im = ax.imshow(center_sim, cmap='viridis')
    ax.set_title(f'Similarity to Center Patch (idx={center_idx})')
    plt.colorbar(im, ax=ax)

    # Similarity from corner patch
    corner_sim = similarity_np[0].reshape(grid_size, grid_size)

    ax = axes[2]
    im = ax.imshow(corner_sim, cmap='viridis')
    ax.set_title('Similarity to Top-Left Corner (idx=0)')
    plt.colorbar(im, ax=ax)

    plt.tight_layout()
    plt.show()


# Create a ViT and visualize its initial position embeddings
# (rebinds the notebook-level name `vit` to a freshly initialised model)
vit = VisionTransformer(img_size=32, patch_size=4, embed_dim=256)
print("Visualizing initial (random) position embeddings:")
print("(After training, nearby positions should show higher similarity)")
# 32 // 4 = 8 patches per side, hence grid_size=8
visualize_position_embeddings(vit.pos_embed, grid_size=8)

## Summary

Key concepts covered:

1. **Patch Embedding**: Converting images to sequences via non-overlapping patches
2. **Patch Size Trade-offs**: Smaller patches = more detail but more compute
3. **Vision Transformer**: Complete ViT implementation with:
   - CLS token for classification
   - Learnable position embeddings
   - Transformer encoder
4. **DeiT**: Data-efficient ViT with distillation token
5. **Position Embeddings**: Learnable spatial information for patches

Recommended patch sizes:
- CIFAR (32x32): patch_size=4 (64 patches)
- ImageNet (224x224): patch_size=16 (196 patches)

In [None]:
# Cleanup
import gc
# Run a garbage collection pass so unreachable Python objects are freed
gc.collect()
# Only touch the CUDA allocator when a GPU is actually available
if torch.cuda.is_available():
    torch.cuda.empty_cache()
print("Cleanup complete!")