# 06: Vision Transformer (ViT) from Scratch for CIFAR-10

A from-scratch PyTorch implementation of the Vision Transformer (ViT) paper, trained and evaluated on CIFAR-10.
1. **Patch Embedding**: Split image into fixed-size patches, flatten, and project to embedding dimension
1. Patch Embedding Implementation


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import math
import time
import copy

# Seed both random number generators used in this notebook for repeatable runs.
np.random.seed(42)
torch.manual_seed(42)

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


## 1. Data Loading

In [None]:
# Per-channel mean / std used to standardize CIFAR-10 images
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# Training pipeline: random crop (with 4px padding) and horizontal flip,
# followed by tensor conversion and normalization
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Evaluation pipeline: no augmentation, only tensor conversion + normalization
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Datasets: the train split is built first, then the test split; both are
# requested with download=True under ./data
train_dataset, test_dataset = (
    torchvision.datasets.CIFAR10(
        root='./data', train=is_train, download=True, transform=split_transform
    )
    for is_train, split_transform in ((True, train_transform), (False, test_transform))
)

# Dataloaders: only the training loader shuffles
BATCH_SIZE = 128
train_loader = DataLoader(train_dataset, BATCH_SIZE, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, BATCH_SIZE, shuffle=False, num_workers=2)

# Human-readable label names, indexed by class id
CLASSES = (
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

## 2. Patch Embedding

The patch embedding converts an image into a sequence of patch tokens:

```
Image: (B, C, H, W) = (B, 3, 32, 32)
                ↓
Split into P×P patches: (B, num_patches, patch_size²×C)
                ↓
Linear projection: (B, num_patches, embed_dim)
```

For a 32×32 image with patch size 4: num_patches = (32/4)² = 64

In [None]:
class PatchEmbedding(nn.Module):
    """Convert an image batch into a sequence of patch embeddings.

    A convolution whose kernel size equals its stride slices the image into
    non-overlapping patches and projects each patch to ``embed_dim`` in one op.

    Args:
        img_size: Side length of the (square) input image.
        patch_size: Side length of each square patch.
        in_channels: Number of channels of the input image.
        embed_dim: Dimension of every output patch token.
    """

    def __init__(self, img_size: int = 32, patch_size: int = 4, 
                 in_channels: int = 3, embed_dim: int = 192):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size

        # Patches per side, squared, gives the sequence length
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side ** 2

        # Patch extraction and linear projection fused into a strided conv
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(B, C, H, W)`` images to ``(B, num_patches, embed_dim)`` tokens."""
        feature_map = self.proj(x)                  # (B, embed_dim, H/P, W/P)
        tokens = feature_map.flatten(start_dim=2)   # (B, embed_dim, num_patches)
        return tokens.transpose(1, 2)               # (B, num_patches, embed_dim)


# Smoke test: embed a random batch and verify the resulting token shape
print("Testing PatchEmbedding...")
patch_embed = PatchEmbedding(img_size=32, patch_size=4, embed_dim=192)
x = torch.randn(2, 3, 32, 32)
patches = patch_embed(x)

print(f"Input: {x.shape}")
print(f"Output: {patches.shape}")
print(f"Number of patches: {patch_embed.num_patches}")

# 2 images, (32 / 4)^2 = 64 patches each, 192-dim embeddings
expected_shape = (2, 64, 192)
assert patches.shape == expected_shape, "Shape mismatch!"

In [None]:
# Visualize patching
def visualize_patches(img_tensor, patch_size=4):
    """Plot a normalized image next to the grid of patches it is split into.

    The first column of the figure repeats the full (un-normalized) image in
    every row; the remaining columns show the individual patches in row-major
    order, titled ``P0``, ``P1``, ...

    Args:
        img_tensor: Image tensor laid out as (C, H, W) and normalized with
            ``CIFAR10_MEAN`` / ``CIFAR10_STD``.
        patch_size: Side length of each square patch, in pixels.

    Raises:
        ValueError: If ``patch_size`` is not positive or is larger than the
            image, so that no complete patch fits.
    """
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")

    # detach().cpu() lets GPU tensors and tensors tracking gradients be
    # plotted too; it is a no-op for plain CPU tensors.
    img = img_tensor.detach().cpu().numpy().transpose(1, 2, 0)

    # Unnormalize back to the [0, 1] range expected by imshow
    mean = np.array(CIFAR10_MEAN)
    std = np.array(CIFAR10_STD)
    img = np.clip(std * img + mean, 0, 1)

    h, w = img.shape[:2]
    num_patches_h = h // patch_size
    num_patches_w = w // patch_size
    if num_patches_h == 0 or num_patches_w == 0:
        raise ValueError(
            f"patch_size {patch_size} is larger than the image ({h}x{w})"
        )

    # squeeze=False keeps `axes` two-dimensional even when there is a single
    # row of patches, so the axes[i, j] indexing below always works.
    _, axes = plt.subplots(num_patches_h, num_patches_w + 1,
                           figsize=(num_patches_w + 2, num_patches_h + 1),
                           squeeze=False)

    # Show original image in first column
    for i in range(num_patches_h):
        axes[i, 0].imshow(img)
        axes[i, 0].axis('off')
    axes[0, 0].set_title('Original', fontsize=8)

    # Show patches in row-major order
    for i in range(num_patches_h):
        for j in range(num_patches_w):
            patch = img[i*patch_size:(i+1)*patch_size,
                        j*patch_size:(j+1)*patch_size]
            patch_idx = i * num_patches_w + j
            axes[i, j+1].imshow(patch)
            axes[i, j+1].axis('off')
            axes[i, j+1].set_title(f'P{patch_idx}', fontsize=6)

    plt.suptitle(f'Image split into {patch_size}×{patch_size} patches', fontsize=10)
    plt.tight_layout()
    plt.show()

# Plot the patch grid for the first test image (its label is not needed)
sample_img = test_dataset[0][0]
visualize_patches(sample_img, patch_size=4)

## 3. Transformer Components

### Multi-Head Self-Attention

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

In [None]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        embed_dim: Token embedding dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights and to
            the output projection.
    """

    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # 1 / sqrt(d_k) factor from the attention formula
        self.scale = self.head_dim ** -0.5

        # A single linear layer produces queries, keys and values at once
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.attn_dropout = nn.Dropout(dropout)
        self.proj_dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to ``x`` of shape ``(B, N, D)``; output has the same shape."""
        batch, num_tokens, dim = x.shape

        # (B, N, 3D) -> (3, B, H, N, D/H), then split into q / k / v
        packed = self.qkv(x).reshape(batch, num_tokens, 3, self.num_heads, self.head_dim)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(dim=0)

        # Scaled similarity of every query with every key: (B, H, N, N)
        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_dropout(F.softmax(scores, dim=-1))

        # Weighted sum of values, then merge the heads back: (B, N, D)
        context = (weights @ v).transpose(1, 2).reshape(batch, num_tokens, dim)

        return self.proj_dropout(self.proj(context))


class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear.

    Args:
        embed_dim: Input and output feature dimension.
        mlp_ratio: Hidden width as a multiple of ``embed_dim``.
        dropout: Dropout probability applied after the activation and after
            the second linear layer.
    """

    def __init__(self, embed_dim: int, mlp_ratio: float = 4.0, dropout: float = 0.0):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        # One dropout module is reused at both dropout points
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Expand to the hidden width, apply GELU, and project back."""
        hidden = self.dropout(self.act(self.fc1(x)))
        return self.dropout(self.fc2(hidden))


class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block: attention and MLP, each with a residual.

    Args:
        embed_dim: Token embedding dimension.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``embed_dim``.
        dropout: Dropout probability passed to the attention and MLP sub-layers.
    """

    def __init__(self, embed_dim: int, num_heads: int, mlp_ratio: float = 4.0, 
                 dropout: float = 0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, mlp_ratio, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run both residual sub-layers; LayerNorm is applied before each one."""
        attended = self.attn(self.norm1(x))
        x = x + attended
        transformed = self.mlp(self.norm2(x))
        return x + transformed


## 4. Vision Transformer (ViT)

Full ViT architecture:
1. Patch embedding
2. Prepend [CLS] token
3. Add positional embeddings
4. Transformer encoder blocks
5. Final layer norm, then select the [CLS] token output
6. Classification head

In [None]:
class VisionTransformer(nn.Module):
    """Vision Transformer classifier.

    Pipeline: patch embedding -> prepend [CLS] token -> add learnable
    positional embeddings -> transformer blocks -> final LayerNorm ->
    linear head on the [CLS] token.

    Args:
        img_size: Side length of the (square) input image.
        patch_size: Side length of each square patch.
        in_channels: Number of input image channels.
        num_classes: Number of output classes.
        embed_dim: Token embedding dimension.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        dropout: Dropout probability used after the positional embedding and
            inside every block.
    """

    def __init__(self, img_size: int = 32, patch_size: int = 4, in_channels: int = 3,
                 num_classes: int = 10, embed_dim: int = 192, depth: int = 6,
                 num_heads: int = 6, mlp_ratio: float = 4.0, dropout: float = 0.1):
        super().__init__()

        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side ** 2

        # Image -> sequence of patch tokens
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)

        # Learnable [CLS] token and positional table (one extra slot for [CLS]);
        # both start at zero and are re-initialized in _init_weights
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))
        self.pos_dropout = nn.Dropout(dropout)

        # Encoder stack
        encoder_layers = []
        for _ in range(depth):
            encoder_layers.append(TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout))
        self.blocks = nn.ModuleList(encoder_layers)

        # Final normalization and classifier
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        self._init_weights()

    def _init_weights(self):
        """Apply truncated-normal init (std 0.02) to embeddings and Linear weights.

        Linear biases are zeroed and LayerNorm layers are reset to weight 1, bias 0.
        """
        for embedding in (self.pos_embed, self.cls_token):
            nn.init.trunc_normal_(embedding, std=0.02)

        for module in self.modules():
            if isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
                continue
            if not isinstance(module, nn.Linear):
                continue
            nn.init.trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape ``(B, num_classes)`` for images ``x``."""
        tokens = self.patch_embed(x)

        # Broadcast the single [CLS] token over the batch and put it first
        cls_tokens = self.cls_token.expand(tokens.size(0), -1, -1)
        tokens = torch.cat((cls_tokens, tokens), dim=1)

        # Inject position information, then regularize
        tokens = self.pos_dropout(tokens + self.pos_embed)

        for layer in self.blocks:
            tokens = layer(tokens)

        tokens = self.norm(tokens)

        # Only the [CLS] position feeds the classifier
        return self.head(tokens[:, 0])


def ViT_tiny(patch_size=4, **kwargs):
    """Build the small ViT configuration used in this notebook.

    Fixes img_size=32, embed_dim=192, depth=6, num_heads=6 and mlp_ratio=4.0.

    Args:
        patch_size: Side length of each square patch.
        **kwargs: Extra keyword arguments forwarded to ``VisionTransformer``
            (for example ``dropout`` or ``num_classes``).

    Returns:
        A freshly initialized ``VisionTransformer``.
    """
    tiny_config = dict(
        img_size=32,
        patch_size=patch_size,
        embed_dim=192,
        depth=6,
        num_heads=6,
        mlp_ratio=4.0,
    )
    return VisionTransformer(**tiny_config, **kwargs)


# Build the model on the target device and report its size
model = ViT_tiny(patch_size=4)
model = model.to(device)
total_params = sum(map(torch.numel, model.parameters()))
print(f"ViT-Tiny parameters: {total_params:,}")

# Forward a small random batch to confirm the output shape
x = torch.randn(2, 3, 32, 32).to(device)
y = model(x)
print(f"Input: {x.shape} -> Output: {y.shape}")

In [None]:
# Sanity check 1: the model's patch count matches the closed-form value
print("="*50)
print("SANITY CHECK 1: Patch Shapes")
print("="*50)

for patch_size in (4, 8):
    # Number of patches a 32x32 image yields for this patch size
    expected_patches = (32 // patch_size) ** 2

    model = ViT_tiny(patch_size=patch_size)
    num_patches = model.num_patches

    print(f"\nPatch size: {patch_size}×{patch_size}")
    print(f"  Expected patches: {expected_patches}")
    print(f"  Actual patches: {num_patches}")
    print(f"  Pos embed shape: {model.pos_embed.shape}")
    if num_patches == expected_patches:
        print(f"  ✓ Correct!")
    else:
        print("  ✗ Mismatch!")

In [None]:
# Sanity check 2: every row of the attention matrix should sum to 1
print("\n" + "="*50)
print("SANITY CHECK 2: Attention Weights")
print("="*50)

# Fresh model in eval mode so dropout is disabled during the check
model = ViT_tiny(patch_size=4).to(device).eval()

# Forward hook that records post-softmax attention weights
attn_weights = []
def hook_fn(module, input, output):
    """Recompute the softmax attention weights for the hooked layer's input.

    The weights are rebuilt from ``module.qkv`` and ``module.num_heads`` and
    appended (detached, on CPU) to the module-level ``attn_weights`` list;
    ``output`` is not used.
    """
    tokens = input[0]
    batch, seq_len, dim = tokens.shape
    heads = module.num_heads
    per_head = dim // heads

    # (B, N, 3D) -> (3, B, H, N, D/H); only queries and keys are needed here
    packed = module.qkv(tokens).reshape(batch, seq_len, 3, heads, per_head)
    q, k, _ = packed.permute(2, 0, 3, 1, 4).unbind(0)

    scores = (q @ k.transpose(-2, -1)) * (per_head ** -0.5)
    attn_weights.append(F.softmax(scores, dim=-1).detach().cpu())

# Attach the hook to the first block's attention layer
hook = model.blocks[0].attn.register_forward_hook(hook_fn)

# One forward pass is enough to trigger the hook; the logits are discarded
with torch.no_grad():
    x = torch.randn(1, 3, 32, 32).to(device)
    model(x)

hook.remove()

# Each query's weights over all keys should form a probability distribution
attn = attn_weights[0]  # (B, H, N, N)
attn_sums = attn.sum(dim=-1)
print(f"Attention shape: {attn.shape}")
print(f"Attention sums (should be 1.0): min={attn_sums.min():.6f}, max={attn_sums.max():.6f}")
if torch.allclose(attn_sums, torch.ones_like(attn_sums), atol=1e-5):
    print(f"✓ Attention weights sum to 1!")
else:
    print("✗ Problem!")

In [None]:
# Sanity check 3: Gradient flow
print("\n" + "="*50)
print("SANITY CHECK 3: Gradient Flow")
print("="*50)

model = ViT_tiny(patch_size=4).to(device)
# Create the input directly on `device` so it is a leaf tensor. Building it on
# the CPU with requires_grad=True and then calling .to(device) yields a
# non-leaf copy on CUDA whose .grad stays None, which made x.grad.norm() fail.
x = torch.randn(2, 3, 32, 32, device=device, requires_grad=True)
y = model(x)
# Any scalar works here; summing the logits sends a gradient to every output
loss = y.sum()
loss.backward()

print(f"Input gradient norm: {x.grad.norm().item():.4f}")
print(f"Patch embed proj gradient: {model.patch_embed.proj.weight.grad.norm().item():.4f}")
print(f"Pos embed gradient: {model.pos_embed.grad.norm().item():.4f}")
print(f"CLS token gradient: {model.cls_token.grad.norm().item():.4f}")
print(f"Block 0 attn qkv gradient: {model.blocks[0].attn.qkv.weight.grad.norm().item():.4f}")
print(f"Block 5 mlp fc2 gradient: {model.blocks[5].mlp.fc2.weight.grad.norm().item():.4f}")
print(f"Head gradient: {model.head.weight.grad.norm().item():.4f}")
print("✓ Gradients flowing through all layers!")

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimization pass over ``train_loader``.

    Gradients are clipped to a global norm of 1.0 before every optimizer step.

    Args:
        model: Network to train; switched to training mode.
        train_loader: Iterable of ``(inputs, targets)`` batches.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean loss per sample, accuracy in percent)``.
    """
    model.train()
    loss_sum, num_correct, num_seen = 0.0, 0, 0

    for inputs, targets in train_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()

        # Clip the global gradient norm before stepping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # The criterion result is weighted by batch size so the final
        # figure is a per-sample average
        loss_sum += batch_loss.item() * inputs.size(0)
        num_seen += targets.size(0)
        predicted = logits.max(1).indices
        num_correct += predicted.eq(targets).sum().item()

    return loss_sum / num_seen, 100. * num_correct / num_seen


def evaluate(model, test_loader, criterion, device):
    """Measure loss and accuracy of ``model`` on ``test_loader`` without gradients.

    Args:
        model: Network to evaluate; switched to eval mode.
        test_loader: Iterable of ``(inputs, targets)`` batches.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean loss per sample, accuracy in percent)``.
    """
    model.eval()
    loss_sum, num_correct, num_seen = 0.0, 0, 0

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = model(inputs)
            batch_loss = criterion(logits, targets)

            # Weight by batch size so the result is a per-sample average
            loss_sum += batch_loss.item() * inputs.size(0)
            num_seen += targets.size(0)
            predicted = logits.max(1).indices
            num_correct += predicted.eq(targets).sum().item()

    return loss_sum / num_seen, 100. * num_correct / num_seen


def train_vit(model, train_loader, test_loader, epochs, lr, device, verbose=True):
    """Train ``model`` with AdamW and a linear-warmup + cosine-decay LR schedule.

    After every epoch the model is evaluated on ``test_loader``; a copy of the
    weights with the highest test accuracy seen so far is kept.

    Args:
        model: Network to train (already placed on ``device``).
        train_loader: Iterable of training ``(inputs, targets)`` batches.
        test_loader: Iterable of evaluation ``(inputs, targets)`` batches.
        epochs: Total number of epochs to run.
        lr: Peak learning rate reached at the end of warmup.
        device: Device the batches are moved to.
        verbose: If true, print progress every 10 epochs and a final summary.

    Returns:
        Tuple ``(history, best_state, best_acc, total_time)`` where ``history``
        maps 'train_loss', 'train_acc', 'test_loss', 'test_acc' and 'lr' to
        per-epoch lists, ``best_state`` is a deep copy of the best state dict
        (``None`` if no epoch improved on 0% accuracy), ``best_acc`` is the
        best test accuracy in percent and ``total_time`` is wall time in seconds.
    """
    criterion = nn.CrossEntropyLoss()
    # AdamW is standard for transformers
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05, betas=(0.9, 0.999))

    # Warmup + cosine decay
    warmup_epochs = 10
    # Never let the decay span reach zero: with epochs == warmup_epochs the
    # cosine branch would otherwise divide by zero on the final scheduler step.
    decay_epochs = max(1, epochs - warmup_epochs)

    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            # (epoch + 1) so the very first epoch already trains with a
            # non-zero learning rate instead of being wasted at LR = 0.
            return (epoch + 1) / warmup_epochs
        progress = min(1.0, (epoch - warmup_epochs) / decay_epochs)
        return 0.5 * (1 + math.cos(math.pi * progress))

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    history = {'train_loss': [], 'train_acc': [], 'test_loss': [], 'test_acc': [], 'lr': []}
    best_acc = 0.0
    best_state = None

    start_time = time.time()

    for epoch in range(epochs):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        test_loss, test_acc = evaluate(model, test_loader, criterion, device)
        # Record the LR that was actually used this epoch, before stepping
        current_lr = optimizer.param_groups[0]['lr']
        scheduler.step()

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['test_loss'].append(test_loss)
        history['test_acc'].append(test_acc)
        history['lr'].append(current_lr)

        if test_acc > best_acc:
            best_acc = test_acc
            # Deep copy: state_dict() returns references to the live tensors
            best_state = copy.deepcopy(model.state_dict())

        if verbose and (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1:3d}/{epochs} | "
                  f"Train: {train_acc:.2f}% | Test: {test_acc:.2f}% | LR: {current_lr:.6f}")

    total_time = time.time() - start_time
    if verbose:
        print(f"\nTraining complete in {total_time:.1f}s. Best accuracy: {best_acc:.2f}%")

    return history, best_state, best_acc, total_time


## 7. Train ViT

In [None]:
# Hyperparameters for the main training run
NUM_EPOCHS = 100
LEARNING_RATE = 1e-3  # peak learning rate handed to train_vit

print(f"Training ViT-Tiny (patch_size=4) for {NUM_EPOCHS} epochs...")
print("="*70)

# Build the model on the target device, then train it and keep all results
vit_model = ViT_tiny(patch_size=4, dropout=0.1)
vit_model = vit_model.to(device)
vit_history, vit_best_state, vit_best_acc, vit_time = train_vit(
    vit_model,
    train_loader,
    test_loader,
    epochs=NUM_EPOCHS,
    lr=LEARNING_RATE,
    device=device,
)

In [None]:
# Plot training curves: loss, accuracy and learning rate side by side
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# The loss and accuracy panels share one layout, so they are drawn in a loop:
# (axis, train key, test key, y-axis label, title)
curve_panels = (
    (axes[0], 'train_loss', 'test_loss', 'Loss', 'Loss Curves'),
    (axes[1], 'train_acc', 'test_acc', 'Accuracy (%)', 'Accuracy Curves'),
)
for panel, train_key, test_key, y_label, panel_title in curve_panels:
    panel.plot(vit_history[train_key], label='Train')
    panel.plot(vit_history[test_key], label='Test')
    panel.set_xlabel('Epoch')
    panel.set_ylabel(y_label)
    panel.set_title(panel_title)
    panel.legend()
    panel.grid(True, alpha=0.3)

# Learning rate actually used in each epoch
lr_panel = axes[2]
lr_panel.plot(vit_history['lr'])
lr_panel.set_xlabel('Epoch')
lr_panel.set_ylabel('Learning Rate')
lr_panel.set_title('LR Schedule (Warmup + Cosine)')
lr_panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 8. Ablation: Patch Size 4 vs 8

In [None]:
def run_patch_size_ablation(patch_size, epochs=100):
    """Train a fresh ViT-Tiny with the given patch size and summarize the run.

    Training is silent (``verbose=False``) and uses a learning rate of 1e-3
    with dropout 0.1.

    Args:
        patch_size: Side length of each square patch.
        epochs: Number of training epochs.

    Returns:
        Dict with keys 'patch_size', 'num_patches', 'num_params', 'best_acc',
        'train_time' and 'history'.
    """
    model = ViT_tiny(patch_size=patch_size, dropout=0.1).to(device)

    # The best weights returned by train_vit are not needed for the comparison
    history, _, best_acc, train_time = train_vit(
        model,
        train_loader,
        test_loader,
        epochs=epochs,
        lr=1e-3,
        device=device,
        verbose=False,
    )

    summary = {
        'patch_size': patch_size,
        'num_patches': model.num_patches,
        'num_params': sum(p.numel() for p in model.parameters()),
        'best_acc': best_acc,
        'train_time': train_time,
        'history': history,
    }
    return summary

print("Running ablation study: Patch Size 4 vs 8")
print("="*60)

# Smaller patches: longer token sequence
print("\nTraining with patch_size=4...")
results_p4 = run_patch_size_ablation(4, epochs=100)
print(f"  Patches: {results_p4['num_patches']}, Best accuracy: {results_p4['best_acc']:.2f}%")

# Larger patches: shorter token sequence
print("\nTraining with patch_size=8...")
results_p8 = run_patch_size_ablation(8, epochs=100)
print(f"  Patches: {results_p8['num_patches']}, Best accuracy: {results_p8['best_acc']:.2f}%")

In [None]:
# Ablation results: one table row per patch size, then a short comparison
print("\n" + "="*70)
print("ABLATION STUDY: Patch Size Comparison")
print("="*70)
print(f"{'Configuration':<20} {'Patches':<12} {'Params':<15} {'Best Acc':<15} {'Time (s)':<12}")
print("-"*70)
print(f"{'Patch Size = 4':<20} {results_p4['num_patches']:<12} {results_p4['num_params']:>12,} {results_p4['best_acc']:>10.2f}% {results_p4['train_time']:>10.1f}")
print(f"{'Patch Size = 8':<20} {results_p8['num_patches']:<12} {results_p8['num_params']:>12,} {results_p8['best_acc']:>10.2f}% {results_p8['train_time']:>10.1f}")
print("="*70)

# Wall-clock ratio of the two runs (patch 4 time over patch 8 time)
speedup = results_p4['train_time'] / results_p8['train_time']
# Accuracy gap in percentage points (positive when patch size 4 is better)
diff = results_p4['best_acc'] - results_p8['best_acc']

print(f"\nPatch size 4 vs 8:")
print(f"  Accuracy difference: {diff:+.2f}%")
print(f"  Patch 8 speedup: {speedup:.2f}x faster")
print(f"  Patch 8 has {(results_p4['num_patches'] / results_p8['num_patches']):.1f}x fewer patches (shorter sequence)")

In [None]:
# Plot ablation: accuracy curves on the left, accuracy/time bars on the right
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Left panel: per-epoch test accuracy of both runs
acc_panel = axes[0]
labelled_runs = (
    (results_p4, 'Patch=4 (64 patches)'),
    (results_p8, 'Patch=8 (16 patches)'),
)
for run, curve_label in labelled_runs:
    acc_panel.plot(run['history']['test_acc'], label=curve_label, alpha=0.8)
acc_panel.set_xlabel('Epoch')
acc_panel.set_ylabel('Test Accuracy (%)')
acc_panel.set_title('Test Accuracy: Patch Size Comparison')
acc_panel.legend()
acc_panel.grid(True, alpha=0.3)

# Right panel: best accuracy and training time as paired bars
configs = ['Patch=4\n(64 patches)', 'Patch=8\n(16 patches)']
accs = [run['best_acc'] for run in (results_p4, results_p8)]
times = [run['train_time'] for run in (results_p4, results_p8)]

x = np.arange(len(configs))
width = 0.35

# Accuracy bars use the left y-axis
ax2 = axes[1]
bars1 = ax2.bar(x - width/2, accs, width, label='Accuracy (%)', color='#4ecdc4')
ax2.set_ylabel('Accuracy (%)', color='#4ecdc4')
ax2.tick_params(axis='y', labelcolor='#4ecdc4')

# Time bars get their own y-axis on the right, since the units differ
ax3 = ax2.twinx()
bars2 = ax3.bar(x + width/2, times, width, label='Time (s)', color='#ff6b6b')
ax3.set_ylabel('Training Time (s)', color='#ff6b6b')
ax3.tick_params(axis='y', labelcolor='#ff6b6b')

ax2.set_xticks(x)
ax2.set_xticklabels(configs)
ax2.set_title('Accuracy vs Training Time')

# Write each bar's value just above it, centered horizontally
for bar, val in zip(bars1, accs):
    x_center = bar.get_x() + bar.get_width()/2
    ax2.text(x_center, bar.get_height() + 1,
             f'{val:.1f}%', ha='center', fontsize=9, color='#4ecdc4')
for bar, val in zip(bars2, times):
    x_center = bar.get_x() + bar.get_width()/2
    ax3.text(x_center, bar.get_height() + 5,
             f'{val:.0f}s', ha='center', fontsize=9, color='#ff6b6b')

plt.tight_layout()
plt.show()