# Q1: Vision Transformer on CIFAR-10

Implementation of Vision Transformer (ViT) for CIFAR-10 classification based on "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" (Dosovitskiy et al., ICLR 2021).

## Setup and Dependencies

In [None]:
!pip install torch torchvision matplotlib numpy tqdm
!pip install timm  # For comparison and some utilities

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import math
import random

# Fix the seed of every random number generator used in this notebook
# (Python, NumPy, PyTorch) so repeated runs start from the same state.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Use the GPU when one is visible (seeding its generator as well),
# otherwise run everything on the CPU.
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

## Data Loading and Preprocessing

In [None]:
# Channel statistics handed to Normalize in both pipelines (the values are the
# widely used ImageNet mean/std).
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]
# Spatial size every image is resized to before it reaches the model.
_INPUT_HW = (224, 224)

# Training pipeline: resize, random augmentations, tensor conversion and
# normalisation, then random erasing as the last step.
_train_steps = [
    transforms.Resize(_INPUT_HW),  # Resize to standard ViT input size
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    transforms.RandomErasing(p=0.1),  # Additional augmentation
]
train_transform = transforms.Compose(_train_steps)

# Evaluation pipeline: deterministic, no augmentation.
_test_steps = [
    transforms.Resize(_INPUT_HW),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]
test_transform = transforms.Compose(_test_steps)

# CIFAR-10 train/test splits, downloaded into ./data when missing.
_cifar_kwargs = {'root': './data', 'download': True}
train_dataset = torchvision.datasets.CIFAR10(train=True, transform=train_transform, **_cifar_kwargs)
test_dataset = torchvision.datasets.CIFAR10(train=False, transform=test_transform, **_cifar_kwargs)

# Mini-batch loaders; only the training loader shuffles.
batch_size = 64
_loader_kwargs = {'batch_size': batch_size, 'num_workers': 2}
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

# Quick summary of what was loaded.
print(
    f'Training samples: {len(train_dataset)}',
    f'Test samples: {len(test_dataset)}',
    f'Number of classes: {len(train_dataset.classes)}',
    f'Classes: {train_dataset.classes}',
    sep='\n',
)

## Vision Transformer Implementation

In [None]:
class PatchEmbedding(nn.Module):
    """Turn an image batch into a sequence of patch embeddings.

    A convolution whose kernel size and stride both equal ``patch_size``
    projects every non-overlapping patch to an ``embed_dim``-sized vector.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size, self.patch_size = img_size, patch_size

        # Square grid: the same number of patches along height and width.
        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side * patches_per_side

        # Patch extraction and linear embedding in a single strided conv.
        self.projection = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Map ``(batch, channels, height, width)`` to ``(batch, n_patches, embed_dim)``."""
        feature_map = self.projection(x)  # (batch, embed_dim, rows, cols)
        tokens = feature_map.flatten(start_dim=2)  # (batch, embed_dim, n_patches)
        return tokens.transpose(1, 2)

In [None]:
class MultiHeadSelfAttention(nn.Module):
    """Scaled dot-product self-attention computed over several heads in parallel."""

    def __init__(self, embed_dim=768, num_heads=12, dropout=0.1):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # A single fused linear layer produces queries, keys and values.
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        # Output projection applied after the heads are merged back together.
        self.projection = nn.Linear(embed_dim, embed_dim)
        # Dropout applied to the attention weights.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over the sequence axis of ``x`` with shape ``(batch, seq_len, embed_dim)``.

        Returns a tensor of the same shape.
        """
        n_batch, n_tokens, width = x.shape

        # Split the fused projection into per-head q/k/v, each of shape
        # (batch, num_heads, seq_len, head_dim).
        fused = self.qkv(x).reshape(n_batch, n_tokens, 3, self.num_heads, self.head_dim)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(dim=0)

        # Similarity logits scaled by sqrt(head_dim), normalised over the keys.
        logits = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        weights = self.dropout(F.softmax(logits, dim=-1))

        # Weighted sum of the values, then merge the heads back into embed_dim.
        merged = (weights @ v).transpose(1, 2).reshape(n_batch, n_tokens, width)
        return self.projection(merged)

In [None]:
class MLP(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Linear, with dropout after each stage."""

    def __init__(self, embed_dim=768, mlp_dim=3072, dropout=0.1):
        super().__init__()
        # Expand to the hidden width, then project back to the embedding width.
        self.fc1 = nn.Linear(embed_dim, mlp_dim)
        self.fc2 = nn.Linear(mlp_dim, embed_dim)
        # One dropout module, reused after both linear stages.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the feed-forward block along the last dimension of ``x``."""
        hidden = self.dropout(F.gelu(self.fc1(x)))
        return self.dropout(self.fc2(hidden))

In [None]:
class TransformerBlock(nn.Module):
    """One encoder layer: self-attention and an MLP, each wrapped in a residual connection.

    LayerNorm is applied to the input of each sub-layer (pre-norm), not to
    its output.
    """

    def __init__(self, embed_dim=768, num_heads=12, mlp_dim=3072, dropout=0.1):
        super().__init__()
        # Attention sub-layer and its input normalisation.
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attention = MultiHeadSelfAttention(embed_dim, num_heads, dropout)
        # Feed-forward sub-layer and its input normalisation.
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, mlp_dim, dropout)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates (shape unchanged)."""
        attended = self.attention(self.norm1(x))
        x = x + attended
        fed_forward = self.mlp(self.norm2(x))
        return x + fed_forward

In [None]:
class VisionTransformer(nn.Module):
    """Vision Transformer classifier.

    The image is cut into patches, each patch is embedded, a learnable class
    token is prepended, learnable positional embeddings are added, and the
    sequence is passed through a stack of transformer blocks. The logits are
    computed from the final, layer-normalised class token.

    Args:
        img_size: Side length of the (square) input images.
        patch_size: Side length of each (square) patch.
        in_channels: Number of input image channels.
        num_classes: Size of the output logit vector.
        embed_dim: Width of the token embeddings.
        num_layers: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_dim: Hidden width of the MLP inside each block.
        dropout: Dropout probability used after the embeddings and inside
            each block.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=10,
                 embed_dim=768, num_layers=12, num_heads=12, mlp_dim=3072, dropout=0.1):
        super().__init__()

        # Patch embedding
        self.patch_embedding = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)

        # CLS token and positional embedding. Both are allocated as zeros and
        # filled with small random values at the end of __init__.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embedding = nn.Parameter(
            torch.zeros(1, self.patch_embedding.n_patches + 1, embed_dim)
        )
        self.dropout = nn.Dropout(dropout)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_dim, dropout)
            for _ in range(num_layers)
        ])

        # Classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # Initialize weights. `apply` only visits sub-modules, so the two raw
        # parameters above are initialised explicitly. A small std (0.02)
        # keeps them on a scale comparable to the patch embeddings; the
        # previous unit-variance values dominated the token content at the
        # start of training.
        self.apply(self._init_weights)
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.normal_(self.pos_embedding, std=0.02)

    def _init_weights(self, m):
        """Per-module initialiser used with ``self.apply``."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.normal_(m.bias, std=1e-6)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Tensor of shape ``(batch_size, in_channels, img_size, img_size)``.

        Returns:
            Logits of shape ``(batch_size, num_classes)``.
        """
        batch_size = x.shape[0]

        # Patch embedding
        x = self.patch_embedding(x)  # (batch_size, n_patches, embed_dim)

        # Add CLS token (one copy per sample, shared parameter)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (batch_size, n_patches + 1, embed_dim)

        # Add positional embedding (broadcast over the batch)
        x = x + self.pos_embedding
        x = self.dropout(x)

        # Pass through transformer blocks
        for block in self.blocks:
            x = block(x)

        # Classification from CLS token
        x = self.norm(x)
        cls_token_final = x[:, 0]  # Take CLS token

        return self.head(cls_token_final)

## Model Configuration and Training Setup

In [None]:
# Model configuration - optimized for CIFAR-10
config = {
    'img_size': 224,
    'patch_size': 16,
    'in_channels': 3,
    'num_classes': 10,
    'embed_dim': 384,  # Smaller than standard ViT for efficiency
    'num_layers': 6,   # Fewer layers for CIFAR-10
    'num_heads': 6,
    'mlp_dim': 1536,   # 4 * embed_dim
    'dropout': 0.1
}

# Create model
model = VisionTransformer(**config).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Total parameters: {total_params:,}')
print(f'Trainable parameters: {trainable_params:,}')

# Training configuration
num_epochs = 100
learning_rate = 3e-4
weight_decay = 0.1


def _split_weight_decay_params(module, skip_names=('cls_token', 'pos_embedding')):
    """Partition trainable parameters into (decayed, exempt) lists.

    Parameters with at most one dimension (biases, LayerNorm gains) and the
    parameters named in ``skip_names`` are exempt from weight decay; all
    other trainable parameters are decayed.
    """
    decayed, exempt = [], []
    for name, param in module.named_parameters():
        if not param.requires_grad:
            continue
        if param.ndim <= 1 or name in skip_names:
            exempt.append(param)
        else:
            decayed.append(param)
    return decayed, exempt


# Optimizer and scheduler.
# Weight decay is applied to the weight matrices/kernels only: shrinking
# biases, normalisation gains and the class/position embeddings towards zero
# regularises nothing useful and can hurt accuracy.
_decay_params, _no_decay_params = _split_weight_decay_params(model)
optimizer = optim.AdamW(
    [
        {'params': _decay_params, 'weight_decay': weight_decay},
        {'params': _no_decay_params, 'weight_decay': 0.0},
    ],
    lr=learning_rate,
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)

# Loss function
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

print(f'\nTraining configuration:')
print(f'Epochs: {num_epochs}')
print(f'Learning rate: {learning_rate}')
print(f'Weight decay: {weight_decay}')
print(f'Batch size: {batch_size}')

## Training Loop

In [None]:
def train_epoch(model, train_loader, optimizer, criterion, device):
    """Run one optimisation pass over ``train_loader``.

    Each batch is moved to ``device``, the loss is back-propagated, gradients
    are clipped to a total norm of 1.0 and the optimizer takes a step. A
    progress bar shows the running loss and accuracy.

    Returns:
        Tuple ``(mean loss per batch, accuracy in percent)`` for the epoch.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    progress = tqdm(train_loader, desc='Training')
    for step, (inputs, labels) in enumerate(progress, start=1):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Forward / backward pass.
        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()

        # Clip before stepping so a single large gradient cannot derail the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Running statistics for the progress bar and the return value.
        loss_sum += batch_loss.item()
        predictions = logits.max(1)[1]
        n_seen += labels.size(0)
        n_correct += predictions.eq(labels).sum().item()

        progress.set_postfix({
            'Loss': f'{loss_sum/step:.4f}',
            'Acc': f'{100.*n_correct/n_seen:.2f}%'
        })

    mean_loss = loss_sum / len(train_loader)
    accuracy = 100. * n_correct / n_seen
    return mean_loss, accuracy

def evaluate(model, test_loader, criterion, device):
    """Measure loss and accuracy of ``model`` on ``test_loader`` without updating it.

    The model is put in eval mode and gradients are disabled for the pass.

    Returns:
        Tuple ``(mean loss per batch, accuracy in percent)``.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    progress = tqdm(test_loader, desc='Evaluating')
    with torch.no_grad():
        for inputs, labels in progress:
            inputs = inputs.to(device)
            labels = labels.to(device)

            logits = model(inputs)
            loss_sum += criterion(logits, labels).item()

            predictions = logits.max(1)[1]
            n_seen += labels.size(0)
            n_correct += predictions.eq(labels).sum().item()

            progress.set_postfix({
                'Acc': f'{100.*n_correct/n_seen:.2f}%'
            })

    mean_loss = loss_sum / len(test_loader)
    accuracy = 100. * n_correct / n_seen
    return mean_loss, accuracy

In [None]:
# Per-epoch history of losses and accuracies, plus the best test accuracy so far.
train_losses, train_accuracies = [], []
test_losses, test_accuracies = [], []
best_test_acc = 0.0

print('Starting training...')
for epoch in range(num_epochs):
    current_lr = optimizer.param_groups[0]["lr"]
    print(f'\nEpoch {epoch+1}/{num_epochs}')
    print(f'Learning rate: {current_lr:.6f}')

    # One pass over the training data, then a pass over the test data.
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

    # Record this epoch's metrics.
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    test_losses.append(test_loss)
    test_accuracies.append(test_acc)

    # Checkpoint whenever the test accuracy improves on the best seen so far.
    if test_acc > best_test_acc:
        best_test_acc = test_acc
        best_state = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'test_acc': test_acc,
            'config': config
        }
        torch.save(best_state, 'best_vit_cifar10.pth')
        print(f'New best test accuracy: {best_test_acc:.2f}%')

    print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
    print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')

print(f'\nTraining completed!')
print(f'Best test accuracy: {best_test_acc:.2f}%')

## Results Visualization

In [None]:
# Plot training curves: loss, accuracy and the learning-rate schedule side by side.
plt.figure(figsize=(15, 5))

plt.subplot(1, 3, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(test_losses, label='Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Test Loss')
plt.legend()
plt.grid(True)

plt.subplot(1, 3, 2)
plt.plot(train_accuracies, label='Train Accuracy')
plt.plot(test_accuracies, label='Test Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.title('Training and Test Accuracy')
plt.legend()
plt.grid(True)

# Learning-rate panel. Training used cosine annealing, so rebuild the value
# that was in effect during each completed epoch from the scheduler's own
# settings using its closed form:
#   lr(t) = eta_min + 0.5 * (base_lr - eta_min) * (1 + cos(pi * t / T_max))
# (The previous plot drew an unrelated step decay starting from the final LR.)
lr_history = [
    scheduler.eta_min
    + 0.5 * (scheduler.base_lrs[0] - scheduler.eta_min)
    * (1 + math.cos(math.pi * epoch / scheduler.T_max))
    for epoch in range(len(test_accuracies))
]
plt.subplot(1, 3, 3)
plt.plot(lr_history)
plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.title('Learning Rate Schedule')
plt.grid(True)

plt.tight_layout()
plt.show()

print(f'Final Results:')
print(f'Best Test Accuracy: {best_test_acc:.2f}%')
print(f'Final Train Accuracy: {train_accuracies[-1]:.2f}%')
print(f'Final Test Accuracy: {test_accuracies[-1]:.2f}%')

## Model Analysis and Evaluation

In [None]:
# Load best model for final evaluation.
# map_location keeps the load working when the checkpoint was written on a
# different device than the one in use now (e.g. saved on GPU, loaded on CPU).
checkpoint = torch.load('best_vit_cifar10.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

# Final evaluation
final_test_loss, final_test_acc = evaluate(model, test_loader, criterion, device)
print(f'Final Test Accuracy with Best Model: {final_test_acc:.2f}%')

# Per-class accuracy: tally correct predictions and sample counts per label.
classes = train_dataset.classes
class_correct = [0.0] * len(classes)
class_total = [0.0] * len(classes)

model.eval()
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        outputs = model(data)
        _, predicted = torch.max(outputs, 1)
        # Keep the comparison 1-D (no squeeze) so a batch holding a single
        # sample is still indexable, and convert to Python values once per
        # batch instead of once per element.
        hits = (predicted == target).tolist()
        for label, hit in zip(target.tolist(), hits):
            class_correct[label] += hit
            class_total[label] += 1

print('\nPer-class accuracy:')
for i, class_name in enumerate(classes):
    if class_total[i] == 0:
        # No test sample of this class was seen; avoid dividing by zero.
        print(f'{class_name}: n/a')
        continue
    print(f'{class_name}: {100 * class_correct[i] / class_total[i]:.2f}%')

## Sample Predictions Visualization

In [None]:
# Visualize some predictions
def visualize_predictions(model, test_loader, classes, num_images=8):
    """Show the first ``num_images`` test images with true and predicted labels.

    Images are taken in loader order, de-normalised with the same mean/std
    used by the input pipeline, and drawn on a grid of up to four columns.

    Args:
        model: Classifier returning per-class scores for a batch of images.
        test_loader: Iterable yielding ``(images, labels)`` batches.
        classes: Sequence mapping a label index to its class name.
        num_images: How many images to display; nothing is drawn when it is
            not positive.

    Note:
        Batches are moved to the module-level ``device``.
    """
    if num_images <= 0:
        return

    model.eval()
    images_shown = 0

    # Grid layout: at most four columns and as many rows as needed. The
    # default of 8 images gives the 2x4 grid and 15x8 figure used previously;
    # larger values no longer overflow a fixed 2x4 grid.
    n_cols = min(4, num_images)
    n_rows = -(-num_images // n_cols)  # ceiling division
    plt.figure(figsize=(15, 4 * n_rows))

    # Constants for undoing the input normalisation, built once.
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs, 1)

            true_labels = target.tolist()
            pred_labels = predicted.tolist()

            for i in range(data.size(0)):
                if images_shown >= num_images:
                    break

                # Denormalize image for display
                img = data[i].cpu() * std + mean
                img = torch.clamp(img, 0, 1)

                plt.subplot(n_rows, n_cols, images_shown + 1)
                plt.imshow(img.permute(1, 2, 0))
                plt.title(f'True: {classes[true_labels[i]]}, Pred: {classes[pred_labels[i]]}')
                plt.axis('off')

                images_shown += 1

            if images_shown >= num_images:
                break

    plt.tight_layout()
    plt.show()

visualize_predictions(model, test_loader, classes)

## Summary

This implementation includes:
1. **Patch Embedding**: Images are resized to 224x224, split into 16x16 patches (14x14 = 196 patches per image), and linearly embedded
2. **Positional Embeddings**: Learnable positional embeddings added to patch embeddings
3. **CLS Token**: Special classification token prepended to sequence
4. **Transformer Blocks**: Multi-head self-attention + MLP with residual connections and layer normalization
5. **Classification Head**: Final prediction from CLS token

**Optimizations for CIFAR-10**:
- Smaller model (384 embed_dim, 6 layers) for efficiency
- Strong data augmentation (horizontal flip, rotation, color jitter, random erasing)
- Label smoothing and gradient clipping
- Cosine annealing learning rate schedule
- AdamW optimizer with weight decay