# Vision Transformer (ViT) Implementation Demo
## COMP3314 Group 16 - Reproducing "An Image is Worth 16x16 Words"

This notebook demonstrates:
1. **Training Visualization** - Analyze CIFAR-100 training progress
2. **Attention Visualization** - Understand what the model learns
3. **Transfer Learning** - Fine-tune CIFAR-100 model on CIFAR-10

---


## 📦 Part 1: Setup & Imports


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
from tqdm import tqdm
import os
import json
import pickle
from datetime import datetime

# Import project modules
import sys
sys.path.append('./src')
from model import VisionTransformer, vit_small_patch16_224
from train import CIFAR100Dataset, get_transforms, TrainingHistory
from utils import load_model, CIFAR100_CLASSES

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")


---
## 📊 Part 2: Training Visualization (CIFAR-100)

Plot the loss, accuracy, learning-rate, and train/val-gap curves recorded while training our ViT-Small model on CIFAR-100.


### 2.1 Define Visualizer Class


In [None]:
class ViTVisualizer:
    """Visualizer for Vision Transformer training progress"""
    
    def __init__(self, checkpoint_dir='./src/checkpoints'):
        self.checkpoint_dir = checkpoint_dir
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.history = None
    
    def load_history_from_json(self, json_path=None):
        """Load training history from JSON file"""
        if json_path is None:
            json_path = os.path.join(self.checkpoint_dir, 'training_history.json')
        
        print(f"Loading history from: {json_path}")
        history_obj = TrainingHistory()
        history_obj.load(json_path)
        
        self.history = {
            'epochs': history_obj.epochs,
            'train_losses': history_obj.train_losses,
            'train_accs': history_obj.train_accs,
            'val_losses': history_obj.val_losses,
            'val_accs': history_obj.val_accs,
            'learning_rates': history_obj.learning_rates
        }
        return self.history
    
    def plot_training_progress(self, save_path=None):
        """Plot training progress with 4 subplots"""
        if not self.history or not self.history.get('epochs'):
            print("No training history available")
            return
        
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
        epochs = self.history['epochs']
        
        # 1. Loss curves
        ax1.plot(epochs, self.history['train_losses'], 'b-', label='Train Loss', linewidth=2)
        ax1.plot(epochs, self.history['val_losses'], 'r-', label='Val Loss', linewidth=2)
        ax1.set_xlabel('Epoch', fontsize=12)
        ax1.set_ylabel('Loss', fontsize=12)
        ax1.set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
        ax1.legend()
        ax1.grid(True, alpha=0.3)
        
        # 2. Accuracy curves
        ax2.plot(epochs, self.history['train_accs'], 'b-', label='Train Acc', linewidth=2)
        ax2.plot(epochs, self.history['val_accs'], 'r-', label='Val Acc', linewidth=2)
        ax2.set_xlabel('Epoch', fontsize=12)
        ax2.set_ylabel('Accuracy (%)', fontsize=12)
        ax2.set_title('Training and Validation Accuracy', fontsize=14, fontweight='bold')
        ax2.legend()
        ax2.grid(True, alpha=0.3)
        
        # 3. Learning rate schedule
        ax3.plot(epochs, self.history['learning_rates'], 'g-', linewidth=2)
        ax3.set_xlabel('Epoch', fontsize=12)
        ax3.set_ylabel('Learning Rate', fontsize=12)
        ax3.set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
        ax3.set_yscale('log')
        ax3.grid(True, alpha=0.3)
        
        # 4. Overfitting analysis
        overfit_gap = [train - val for train, val in 
                      zip(self.history['train_accs'], self.history['val_accs'])]
        ax4.plot(epochs, overfit_gap, 'orange', linewidth=2)
        ax4.set_xlabel('Epoch', fontsize=12)
        ax4.set_ylabel('Train-Val Gap (%)', fontsize=12)
        ax4.set_title('Overfitting Analysis', fontsize=14, fontweight='bold')
        ax4.grid(True, alpha=0.3)
        ax4.axhline(y=np.mean(overfit_gap), color='red', linestyle='--', 
                   label=f'Avg: {np.mean(overfit_gap):.2f}%')
        ax4.legend()
        
        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Saved: {save_path}")
        plt.show()
        return fig
    
    def print_summary(self):
        """Print training summary"""
        if not self.history or not self.history.get('epochs'):
            print("No training history available")
            return
        
        train_accs = self.history['train_accs']
        val_accs = self.history['val_accs']
        
        print("="*60)
        print("CIFAR-100 TRAINING SUMMARY")
        print("="*60)
        print(f"Final Train Accuracy:  {train_accs[-1]:.2f}%")
        print(f"Final Val Accuracy:    {val_accs[-1]:.2f}%")
        best_idx = val_accs.index(max(val_accs))
        print(f"Best Val Accuracy:     {val_accs[best_idx]:.2f}% (Epoch {self.history['epochs'][best_idx]})")
        print(f"Overfitting Gap:       {train_accs[-1] - val_accs[-1]:.2f}%")
        print("="*60)

print("ViTVisualizer class defined successfully!")

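The summary and the Overfitting Analysis panel reduce the history to a few numbers. Below is a minimal NumPy sketch of those reductions. It assumes the history is stored as parallel per-epoch lists, as in `ViTVisualizer.history`, and it reads the best epoch from the logged `epochs` list rather than from a list position. The helper `summarize_history` and the toy accuracies are hypothetical and for illustration only.

```python
import numpy as np

def summarize_history(epochs, train_accs, val_accs):
    """Best val accuracy, the epoch it was logged at, and the train-val gap statistics."""
    val = np.asarray(val_accs, dtype=float)
    best_i = int(np.argmax(val))                    # first occurrence of the maximum
    gaps = np.asarray(train_accs, dtype=float) - val  # same quantity as the Overfitting panel
    return {
        "best_val_acc": float(val[best_i]),
        "best_epoch": epochs[best_i],               # the logged epoch label, not index + 1
        "final_gap": float(gaps[-1]),
        "mean_gap": float(gaps.mean()),
    }

# Toy values for illustration only (not our training logs)
summary = summarize_history(epochs=[1, 2, 3, 4],
                            train_accs=[20.0, 40.0, 60.0, 80.0],
                            val_accs=[18.0, 35.0, 50.0, 48.0])
print(summary)
```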

### 2.2 Load and Visualize CIFAR-100 Training History


In [None]:
# Create visualizer and load history
visualizer = ViTVisualizer(checkpoint_dir='./src/checkpoints')
visualizer.load_history_from_json('./src/checkpoints/training_history.json')

# Print summary
visualizer.print_summary()

# Plot training progress
visualizer.plot_training_progress(save_path='cifar100_training_progress.png')


---
## 👁️ Part 3: Attention Visualization

Inspect which image regions the Vision Transformer's CLS token attends to, layer by layer, when the model makes a prediction.


### 3.1 Attention Extraction Function


In [None]:
def get_attention_maps(model, image, device='cpu'):
    """Extract attention maps from all transformer blocks"""
    model.eval()
    attention_maps = []
    
    with torch.no_grad():
        x = image.to(device)
        B = x.shape[0]
        
        # Patch embedding
        x = model.patch_embed(x)
        cls_tokens = model.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = x + model.pos_embed
        x = model.pos_dropout(x)
        
        # Extract attention from each block
        for block in model.blocks:
            x_norm = block.norm1(x)
            B, N, C = x_norm.shape
            attn_module = block.attn
            
            # Compute attention weights manually
            qkv = attn_module.qkv(x_norm).reshape(B, N, 3, attn_module.num_heads, attn_module.head_dim)
            qkv = qkv.permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]
            
            attn = (q @ k.transpose(-2, -1)) * attn_module.scale
            attn = attn.softmax(dim=-1)
            attention_maps.append(attn.detach().cpu())
            
            # Complete block forward
            attn_dropped = attn_module.attn_dropout(attn)
            x_attn = (attn_dropped @ v).transpose(1, 2).reshape(B, N, C)
            x_attn = attn_module.proj(x_attn)
            x_attn = attn_module.proj_dropout(x_attn)
            x = x + x_attn
            x = x + block.mlp(block.norm2(x))
    
    return attention_maps

print("Attention extraction function defined!")

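For each block, `get_attention_maps` recomputes the attention weights as softmax(Q Kᵀ · scale) from the block's own `qkv` projection. The following is a self-contained NumPy sketch of that step for one image. It assumes `scale = head_dim ** -0.5` (the standard scaled dot-product choice) and, for brevity, a QKV projection without bias. The function name `attention_weights`, the random inputs, and the 384-wide, 6-head shape are assumptions, not values read from `src/model.py`.

```python
import numpy as np

def attention_weights(x, w_qkv, num_heads):
    """Per-head softmax(Q K^T / sqrt(head_dim)) for tokens x of shape (N, C)."""
    N, C = x.shape
    head_dim = C // num_heads
    qkv = x @ w_qkv                                                      # (N, 3C), no bias (assumption)
    qkv = qkv.reshape(N, 3, num_heads, head_dim).transpose(1, 2, 0, 3)   # (3, H, N, d)
    q, k = qkv[0], qkv[1]
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)                # (H, N, N)
    scores -= scores.max(axis=-1, keepdims=True)                         # numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)                 # each row sums to 1

rng = np.random.default_rng(0)
N, C, H = 197, 384, 6   # 1 CLS + 14*14 patch tokens; ViT-Small width and heads (assumed)
x = rng.standard_normal((N, C))
w_qkv = rng.standard_normal((C, 3 * C)) * 0.02
attn = attention_weights(x, w_qkv, H)
print(attn.shape)
```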

### 3.2 Visualize Layer-wise Attention


In [None]:
def visualize_attention(model, image_tensor, label, img_size=224, patch_size=16, save_path=None):
    """Visualize attention maps across layers"""
    device = next(model.parameters()).device
    model.eval()  # disable dropout so the prediction reflects inference behaviour
    image_batch = image_tensor.unsqueeze(0)
    
    # Get prediction
    with torch.no_grad():
        output = model(image_batch.to(device))
        probs = F.softmax(output, dim=1)
        pred_idx = torch.argmax(probs, dim=1).item()
        confidence = probs[0, pred_idx].item()
    
    # Denormalize image
    img_display = image_tensor.cpu().numpy().transpose(1, 2, 0)
    mean = np.array([0.5070751592371323, 0.48654887331495095, 0.4409178433670343])
    std = np.array([0.2673342858792401, 0.2564384629170883, 0.27615047132568404])
    img_display = img_display * std + mean
    img_display = np.clip(img_display, 0, 1)
    
    # Get attention maps
    attention_maps = get_attention_maps(model, image_batch, device)
    num_patches = (img_size // patch_size) ** 2
    selected_layers = [0, len(attention_maps)//2, len(attention_maps)-1]
    
    fig, axes = plt.subplots(2, len(selected_layers) + 1, figsize=(15, 8))
    
    # Show original image
    axes[0, 0].imshow(img_display)
    axes[0, 0].set_title(f'Original\nTrue: {CIFAR100_CLASSES[label]}', fontsize=10)
    axes[0, 0].axis('off')
    
    axes[1, 0].imshow(img_display)
    pred_color = 'green' if pred_idx == label else 'red'
    axes[1, 0].set_title(f'Prediction\n{CIFAR100_CLASSES[pred_idx]} ({confidence*100:.1f}%)', 
                         fontsize=10, color=pred_color)
    axes[1, 0].axis('off')
    
    # Visualize selected layers
    for idx, layer_idx in enumerate(selected_layers, start=1):
        attn = attention_maps[layer_idx][0]
        attn_mean = attn.mean(dim=0)
        cls_attn = attn_mean[0, 1:]
        
        grid_size = int(np.sqrt(num_patches))
        cls_attn_map = cls_attn.reshape(grid_size, grid_size).numpy()
        cls_attn_upsampled = np.kron(cls_attn_map, np.ones((patch_size, patch_size)))
        
        # Heatmap
        im1 = axes[0, idx].imshow(cls_attn_upsampled, cmap='hot', interpolation='bilinear')
        axes[0, idx].set_title(f'Layer {layer_idx+1}\nHeatmap', fontsize=10)
        axes[0, idx].axis('off')
        plt.colorbar(im1, ax=axes[0, idx], fraction=0.046)
        
        # Overlay
        axes[1, idx].imshow(img_display)
        axes[1, idx].imshow(cls_attn_upsampled, cmap='hot', alpha=0.6, interpolation='bilinear')
        axes[1, idx].set_title(f'Layer {layer_idx+1}\nOverlay', fontsize=10)
        axes[1, idx].axis('off')
    
    plt.suptitle('ViT Attention Visualization - Different Layers Focus on Different Regions', 
                 fontsize=14, y=0.98)
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")
    plt.show()

print("Layer-wise attention visualization function defined!")

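`visualize_attention` turns one layer's weights into a heatmap in four steps. It averages over heads, takes the CLS row (how strongly the CLS token attends to each patch), reshapes the 196 patch scores into a 14x14 grid, and repeats each score over a 16x16 pixel block with `np.kron`. The NumPy sketch below runs that pipeline on random weights. It assumes patch tokens follow the CLS token in row-major image order, which holds for a Conv2d patch embedding flattened over its spatial dimensions; that is an assumption about `model.patch_embed`. `cls_attention_heatmap` is a hypothetical helper.

```python
import numpy as np

def cls_attention_heatmap(attn, patch_size=16):
    """attn: (heads, N, N) weights for one image, token 0 = CLS. Returns a pixel-level map."""
    attn_mean = attn.mean(axis=0)              # average over heads -> (N, N)
    cls_to_patches = attn_mean[0, 1:]          # CLS row, patch columns only
    grid = int(round(np.sqrt(cls_to_patches.size)))
    assert grid * grid == cls_to_patches.size, "patch tokens must form a square grid"
    patch_map = cls_to_patches.reshape(grid, grid)   # row-major patch order (assumed)
    # Nearest-neighbour upsampling: each patch score fills a patch_size x patch_size block
    return np.kron(patch_map, np.ones((patch_size, patch_size)))

rng = np.random.default_rng(0)
logits = rng.standard_normal((6, 197, 197))
attn = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)  # random softmax weights
heat = cls_attention_heatmap(attn)
print(heat.shape)
```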

### 3.3 Run Attention Visualization on Sample Images


In [None]:
# Load trained model
model, checkpoint = load_model('./src/checkpoints/best_model.pth', device=device)
print(f"Loaded CIFAR-100 model - Best Acc: {checkpoint['best_acc']:.2f}%")

# Load test dataset
test_dataset = CIFAR100Dataset(
    './cifar-100-python/cifar-100-python',
    train=False,
    transform=get_transforms(224, train=False)
)

# Visualize attention for 2 random samples
for i in range(2):
    idx = np.random.randint(len(test_dataset))
    image, label = test_dataset[idx]
    print(f"\nExample {i+1}: {CIFAR100_CLASSES[label]}")
    visualize_attention(model, image, label, save_path=f'attention_example_{i+1}.png')

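The displayed images depend on two preprocessing steps: a rescale of uint8 pixels to [0, 1] and a per-channel standardization with the CIFAR-100 mean and std. `visualize_attention` undoes the second step with `img * std + mean` before plotting. The NumPy sketch below mimics both directions on a channel-last image. It assumes the test transform rescales by 255 the way `transforms.ConvertImageDtype` does for uint8 input (that transform leaves float tensors unchanged). `skipped_scaling` shows how far off-scale inputs land if the rescale is skipped.

```python
import numpy as np

# CIFAR-100 channel statistics, copied from visualize_attention
MEAN = np.array([0.5070751592371323, 0.48654887331495095, 0.4409178433670343])
STD = np.array([0.2673342858792401, 0.2564384629170883, 0.27615047132568404])

def normalize(img_uint8_hwc):
    """Rescale uint8 pixels to [0, 1], then standardize each channel."""
    x = img_uint8_hwc.astype(np.float32) / 255.0
    return (x - MEAN) / STD            # broadcasts over the trailing channel axis

def denormalize(x_hwc):
    """Invert the standardization and clip to the displayable [0, 1] range."""
    return np.clip(x_hwc * STD + MEAN, 0.0, 1.0)

rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)  # random stand-in image

restored = denormalize(normalize(img))
print(np.abs(restored - img / 255.0).max())   # only float32 rounding error remains

# Forgetting the /255 rescale leaves inputs hundreds of standard deviations off-scale
skipped_scaling = (img.astype(np.float32) - MEAN) / STD
print(skipped_scaling.max())
```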

---
## 🔄 Part 4: Transfer Learning - Fine-tuning on CIFAR-10

Transfer the CIFAR-100-trained model to CIFAR-10 using Option A: unfreeze the last 4 transformer blocks, together with the final norm layer and a new 10-class head, and keep everything else frozen.


### 4.1 CIFAR-10 Dataset and Helper Functions


In [None]:
class CIFAR10Dataset(torch.utils.data.Dataset):
    """CIFAR-10 Dataset"""
    def __init__(self, data_dir, train=True, transform=None):
        self.transform = transform
        self.data = []
        self.labels = []
        
        if train:
            for i in range(1, 6):
                with open(os.path.join(data_dir, f'data_batch_{i}'), 'rb') as f:
                    batch = pickle.load(f, encoding='bytes')
                    self.data.append(batch[b'data'])
                    self.labels.extend(batch[b'labels'])
            self.data = np.concatenate(self.data)
        else:
            with open(os.path.join(data_dir, 'test_batch'), 'rb') as f:
                batch = pickle.load(f, encoding='bytes')
                self.data = batch[b'data']
                self.labels = batch[b'labels']
        
        self.data = self.data.reshape(-1, 3, 32, 32)
        print(f"Loaded {'train' if train else 'test'} data: {len(self.data)} images")
    
    def __len__(self):
        return len(self.data)
    
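    def __getitem__(self, idx):
        # Keep the raw uint8 tensor: ConvertImageDtype(torch.float32) rescales integer
        # images to [0, 1] but passes float tensors through unchanged (still 0-255)
        image = torch.from_numpy(self.data[idx])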
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label


def get_cifar10_transforms(img_size=224, train=True):
    """CRITICAL: Use CIFAR-100 normalization (backbone expects this!)"""
    mean = [0.5070751592371323, 0.48654887331495095, 0.4409178433670343]
    std = [0.2673342858792401, 0.2564384629170883, 0.27615047132568404]
    
    if train:
        return transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ConvertImageDtype(torch.float32),
            transforms.Resize((img_size, img_size)),
            transforms.Normalize(mean=mean, std=std),
        ])
    else:
        return transforms.Compose([
            transforms.ConvertImageDtype(torch.float32),
            transforms.Resize((img_size, img_size)),
            transforms.Normalize(mean=mean, std=std),
        ])


def load_cifar100_backbone(checkpoint_path, device):
    """Load CIFAR-100 model and adapt for CIFAR-10"""
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    
    model_c100 = vit_small_patch16_224(num_classes=100)
    model_c100.load_state_dict(checkpoint['model_state_dict'])
    
    model_c10 = vit_small_patch16_224(num_classes=10)
    
    # Copy all weights except classification head
    with torch.no_grad():
        model_c10.patch_embed.load_state_dict(model_c100.patch_embed.state_dict())
        model_c10.cls_token.copy_(model_c100.cls_token)
        model_c10.pos_embed.copy_(model_c100.pos_embed)
        for i in range(len(model_c10.blocks)):
            model_c10.blocks[i].load_state_dict(model_c100.blocks[i].state_dict())
        model_c10.norm.load_state_dict(model_c100.norm.state_dict())
    
    print("[OK] Loaded CIFAR-100 backbone, initialized 10-class head")
    return model_c10.to(device)


def freeze_layers(model, unfreeze_last_n_blocks=4):
    """Freeze all except last N blocks and head"""
    # Freeze everything
    for param in model.parameters():
        param.requires_grad = False
    
    # Unfreeze last N blocks
    for i in range(12 - unfreeze_last_n_blocks, 12):
        for param in model.blocks[i].parameters():
            param.requires_grad = True
    
    # Unfreeze norm and head
    for param in model.norm.parameters():
        param.requires_grad = True
    for param in model.head.parameters():
        param.requires_grad = True
    
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)")
    return model

print("Transfer learning functions defined!")

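`CIFAR10Dataset` relies on the layout of the CIFAR python pickles. Each row of `data` holds 3072 uint8 values: the 1024 red pixels first, then green, then blue, each a 32x32 plane in row-major order. The small NumPy check below shows why `reshape(-1, 3, 32, 32)` yields channel-first images. It also shows why a channel-last transpose (`transpose(1, 2, 0)` for a single image, as in `visualize_attention`) is needed before `imshow`. The index-valued array is a made-up stand-in for real pixel bytes, so every value can be traced back to its offset.

```python
import numpy as np

# Stand-in for batch[b'data'] holding 2 images: each row has 3072 values laid out as
# 1024 red, then 1024 green, then 1024 blue pixels. Index values replace real bytes
# so every entry shows its original offset within the row.
rows = np.arange(2 * 3072).reshape(2, 3072)

chw = rows.reshape(-1, 3, 32, 32)   # (N, C, H, W): channel-first, as the dataset stores it
hwc = chw.transpose(0, 2, 3, 1)     # (N, H, W, C): channel-last, as matplotlib's imshow expects

# The red plane of image 0 is exactly the first 1024 values of its row
assert np.array_equal(chw[0, 0].ravel(), rows[0, :1024])
# Pixel (row 0, column 1) of image 0 gathers offsets 1, 1025, and 2049 as R, G, B
print(hwc[0, 0, 1])
```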

### 4.2 Load Model and Data (run this cell before fine-tuning)


In [None]:
# Load CIFAR-10 data
train_dataset = CIFAR10Dataset('./cifar-10-python/cifar-10-batches-py', 
                               train=True, transform=get_cifar10_transforms(224, True))
test_dataset = CIFAR10Dataset('./cifar-10-python/cifar-10-batches-py', 
                              train=False, transform=get_cifar10_transforms(224, False))

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2, pin_memory=True)

# Load pretrained CIFAR-100 backbone
ft_model = load_cifar100_backbone('./src/checkpoints/best_model.pth', device)
ft_model = freeze_layers(ft_model, unfreeze_last_n_blocks=4)

print("\nModel ready for fine-tuning!")

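`freeze_layers` prints the trainable fraction. The same numbers can be derived by hand, assuming `vit_small_patch16_224` follows the standard ViT-Small/16 configuration: width 384, 12 blocks, MLP ratio 4, a biased QKV projection, LayerNorm, and a learnable position embedding for 197 tokens (196 patches plus CLS). Under those assumptions each block contributes 12d² + 13d parameters, so with d = 384 the four unfrozen blocks dominate the trainable set. This arithmetic sketch is hypothetical. If `src/model.py` differs, trust the count printed by `freeze_layers`. `vit_param_counts` is a hypothetical helper.

```python
def vit_param_counts(embed_dim=384, depth=12, patch=16, img=224, in_chans=3,
                     num_classes=10, mlp_ratio=4, unfrozen_blocks=4):
    """Parameter counts for a standard ViT (assumed config), returned as (per_block, trainable, total)."""
    d, h = embed_dim, embed_dim * mlp_ratio
    per_block = (
        2 * d                  # norm1 (weight + bias)
        + d * 3 * d + 3 * d    # qkv projection with bias
        + d * d + d            # attention output projection
        + 2 * d                # norm2
        + d * h + h            # mlp fc1
        + h * d + d            # mlp fc2
    )
    num_tokens = (img // patch) ** 2 + 1                 # patches + CLS
    patch_embed = in_chans * patch * patch * d + d       # Conv2d weight + bias
    tokens = d + num_tokens * d                          # cls_token + pos_embed
    norm = 2 * d                                         # final norm
    head = d * num_classes + num_classes                 # linear classification head
    total = patch_embed + tokens + depth * per_block + norm + head
    trainable = unfrozen_blocks * per_block + norm + head  # what freeze_layers leaves trainable
    return per_block, trainable, total

per_block, trainable, total = vit_param_counts()
print(f"{trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")
```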

### 4.3 Fine-tune (Optional - Skip if model already trained)


In [None]:
# Uncomment to train (takes ~2 hours on GPU)
# criterion = nn.CrossEntropyLoss()
# optimizer = optim.AdamW(filter(lambda p: p.requires_grad, ft_model.parameters()), 
#                         lr=1e-4, weight_decay=0.01)
# scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-6)
# 
# best_acc = 0.0
# for epoch in range(1, 51):
#     # Train
#     ft_model.train()
#     for images, labels in tqdm(train_loader, desc=f"Epoch {epoch}"):
#         images, labels = images.to(device), labels.to(device)
#         optimizer.zero_grad()
#         outputs = ft_model(images)
#         loss = criterion(outputs, labels)
#         loss.backward()
#         optimizer.step()
#     
#     # Validate
#     ft_model.eval()
#     correct, total = 0, 0
#     with torch.no_grad():
#         for images, labels in test_loader:
#             images, labels = images.to(device), labels.to(device)
#             outputs = ft_model(images)
#             _, predicted = outputs.max(1)
#             total += labels.size(0)
#             correct += predicted.eq(labels).sum().item()
#     
#     val_acc = 100. * correct / total
#     print(f"Epoch {epoch}: Val Acc = {val_acc:.2f}%")
#     
#     if val_acc > best_acc:
#         best_acc = val_acc
#         os.makedirs('./src/cifar10_finetuned_optionA', exist_ok=True)
#         torch.save({'model_state_dict': ft_model.state_dict(), 'best_acc': best_acc},
#                    './src/cifar10_finetuned_optionA/best_model.pth')
#     
#     scheduler.step()

print("To train: uncomment the code above and run this cell")

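The commented loop calls `scheduler.step()` once per epoch, so epoch `e` (1-based) trains at the cosine-annealed rate for `T_cur = e - 1`. PyTorch documents the closed form of `CosineAnnealingLR` as η_t = η_min + ½(η_max - η_min)(1 + cos(π·T_cur/T_max)). Below is a plain-Python sketch using the values from the loop above (`lr=1e-4`, `eta_min=1e-6`, `T_max=50`). `cosine_annealing_lr` is a hypothetical helper.

```python
import math

def cosine_annealing_lr(t_cur, base_lr=1e-4, eta_min=1e-6, t_max=50):
    """Closed-form cosine annealing: eta_min + 0.5*(base_lr - eta_min)*(1 + cos(pi*t_cur/t_max))."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * t_cur / t_max))

# Learning rate used by epochs 1, 26, and 50 (T_cur = 0, 25, 49)
for epoch in (1, 26, 50):
    print(epoch, f"{cosine_annealing_lr(epoch - 1):.3e}")
```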

---
## 📊 Part 5: Results Summary

### Key Results from Our Implementation:

| Model | Dataset | Method | Accuracy | Notes |
|-------|---------|--------|----------|-------|
| ViT-Small | CIFAR-100 | From Scratch | **66.65%** | 300 epochs, cosine LR |
| ViT-Small | CIFAR-10 | Transfer (Frozen) | 12.21% | Only head unfrozen ❌ |
| ViT-Small | CIFAR-10 | Transfer (Option A) | **63.66%** | Last 4 blocks unfrozen ✓ |

### Key Insights:

1. **Training from Scratch**
   - Successfully trained ViT-Small on CIFAR-100
   - Reached 66.65% accuracy without any pre-training
   - Used warmup followed by a cosine LR schedule, RandAugment, and gradient clipping

2. **Attention Visualization**
   - Early layers: Broad, distributed attention
   - Middle layers: Feature aggregation  
   - Late layers: Focused on discriminative regions

3. **Transfer Learning**
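   - Training only the head (all blocks frozen): **Failed** (12.21%, close to the 10% chance level for 10 classes)
   - Also unfreezing the last 4 blocks (Option A): **Success** (63.66%)
   - Top-5 accuracy with Option A: 96.75%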
   - **Lesson**: Strategic layer unfreezing is crucial!

4. **Implementation Details**
   - Normalize CIFAR-10 inputs with the CIFAR-100 mean and std the backbone was trained on
   - Use a lower learning rate for fine-tuning (1e-4 vs 1e-3)
   - Train for fewer epochs (50 vs 300)

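Top-5 accuracy counts a prediction as correct when the true class is among the five highest-scoring classes. Below is a NumPy sketch with toy logits (not model outputs). In a PyTorch evaluation loop the same indices come from `outputs.topk(k, dim=1).indices`. The helper `topk_accuracy` is hypothetical.

```python
import numpy as np

def topk_accuracy(logits, labels, k=5):
    """Percentage of samples whose true label is among the k highest-scoring classes."""
    logits = np.asarray(logits)
    labels = np.asarray(labels)
    topk = np.argsort(-logits, axis=1)[:, :k]        # indices of the k largest logits per row
    hits = (topk == labels[:, None]).any(axis=1)     # True where the label is in the top k
    return 100.0 * hits.mean()

# Toy logits for 3 samples over 4 classes (illustrative, not model outputs)
logits = np.array([[0.1, 2.0, 0.3, 0.0],
                   [1.5, 0.2, 0.1, 0.9],
                   [0.0, 0.1, 0.2, 3.0]])
labels = np.array([1, 3, 0])
print(topk_accuracy(logits, labels, k=1), topk_accuracy(logits, labels, k=2))
```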
---

### 📚 References:
- **Paper**: [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929)
- **Dataset**: [CIFAR-10/100](https://www.cs.toronto.edu/~kriz/cifar.html)

---

**COMP3314 2025-2026 Group 16**

*This notebook demonstrates the full pipeline: training, visualization, and transfer learning with Vision Transformers*
