# Vision Transformers (ViT) for Image Classification

In this notebook, we'll explore **Vision Transformers**, an architecture that applies the transformer mechanism (originally designed for NLP) to computer vision tasks.

## Key Concepts:
1. **Patch Embedding**: Images are split into fixed-size patches and linearly embedded
2. **Position Embeddings**: Since transformers have no inherent notion of order, we add learnable position embeddings
3. **Self-Attention**: The core mechanism that allows patches to "attend" to other patches
4. **Classification Token**: A special learnable token prepended to the sequence for classification

## Why ViT?
- **Global context**: Unlike CNNs with local receptive fields, ViT can capture global dependencies from the start
- **Scalability**: Transformers scale well with data and model size
- **Transfer learning**: Pre-trained ViTs work exceptionally well for downstream tasks

## 1. Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

# Choose where tensors and the model live: a CUDA GPU when one is visible,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

## 2. Understanding the ViT Architecture

Let's build a Vision Transformer from scratch to understand each component.

### 2.1 Patch Embedding

The first step is to divide the image into patches and embed them into vectors.

In [None]:
class PatchEmbedding(nn.Module):
    """
    Turn an image batch into a sequence of patch embeddings.

    A Conv2d whose kernel size and stride both equal ``patch_size`` is the
    same operation as cutting the image into non-overlapping
    patch_size x patch_size tiles, flattening each tile and applying one
    shared linear projection.

    Example: a 32x32 RGB image with 4x4 patches gives an 8x8 grid, i.e.
    64 patches of 4*4*3 = 48 values, each projected to ``embed_dim``.
    """
    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=128):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side * patches_per_side
        self.projection = nn.Conv2d(
            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        # (B, C, H, W) -> (B, embed_dim, H/p, W/p) -> (B, embed_dim, N) -> (B, N, embed_dim)
        grid = self.projection(x)
        return grid.flatten(start_dim=2).transpose(1, 2)

# Sanity check: two random 32x32 RGB images should become 64 patch tokens
# of width 128 each, i.e. a tensor of shape (2, 64, 128).
patch_embed = PatchEmbedding(img_size=32, patch_size=4, embed_dim=128)
dummy_img = torch.randn(2, 3, 32, 32)
patches = patch_embed(dummy_img)

print(f"Input shape: {dummy_img.shape}")
print(f"Patches shape: {patches.shape}")
print(f"Number of patches: {patch_embed.n_patches}")

### 2.2 Multi-Head Self-Attention

The core of the transformer architecture. It allows patches to attend to all other patches.

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values by a single fused
    linear layer and split into ``num_heads`` heads of size
    ``embed_dim // num_heads``. Each head attends independently; dropout is
    applied to the attention weights. The heads are then concatenated and
    mixed by an output projection.
    """
    def __init__(self, embed_dim=128, num_heads=4, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.qkv = nn.Linear(embed_dim, embed_dim * 3)     # fused Q/K/V projection
        self.projection = nn.Linear(embed_dim, embed_dim)  # output projection
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        bsz, seq_len, width = x.shape

        # (B, T, 3*D) -> (3, B, heads, T, head_dim), then split into Q, K, V
        q, k, v = (
            self.qkv(x)
            .reshape(bsz, seq_len, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
            .unbind(0)
        )

        # softmax(Q K^T / sqrt(head_dim)) over the key axis
        attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = self.dropout(attn.softmax(dim=-1))

        # weighted sum of values, then merge heads back into (B, T, D)
        context = (attn @ v).transpose(1, 2).reshape(bsz, seq_len, width)
        return self.projection(context)

# Feed the patch tokens from the previous cell through one attention layer.
# Self-attention keeps the sequence shape, so expect (2, 64, 128).
attention = MultiHeadAttention(embed_dim=128, num_heads=4)
out = attention(patches)
print(f"After attention shape: {out.shape}")

### 2.3 MLP (Feed-Forward) Block

After attention, we apply a feed-forward network to each token independently.

In [None]:
class MLP(nn.Module):
    """
    Position-wise feed-forward network applied to every token:

        Linear(embed_dim -> hidden_dim) -> GELU -> Dropout
        -> Linear(hidden_dim -> embed_dim) -> Dropout

    One Dropout module (one rate) is reused after both stages.
    """
    def __init__(self, embed_dim=128, hidden_dim=512, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # GELU is the activation used in ViT's feed-forward blocks.
        hidden = self.dropout(F.gelu(self.fc1(x)))
        return self.dropout(self.fc2(hidden))

### 2.4 Transformer Block

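Combines attention and MLP with residual connections and layer normalization. This is the *pre-norm* layout used by ViT: each sub-layer receives `LayerNorm(x)`, and its output is added back to the un-normalized `x`.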

In [None]:
class TransformerBlock(nn.Module):
    """
    One pre-norm transformer encoder layer:

        x = x + Attention(LayerNorm(x))
        x = x + MLP(LayerNorm(x))

    The MLP hidden width is ``embed_dim * mlp_ratio``; ``dropout`` is passed
    to both the attention and the MLP sub-layers.
    """
    def __init__(self, embed_dim=128, num_heads=4, mlp_ratio=4, dropout=0.1):
        super().__init__()
        hidden = embed_dim * mlp_ratio
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attention = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, hidden, dropout)

    def forward(self, x):
        attended = x + self.attention(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))

### 2.5 Complete Vision Transformer

Now we put everything together into a complete ViT model.

In [None]:
class VisionTransformer(nn.Module):
    """
    Vision Transformer classifier.

    Pipeline: patch embedding -> prepend a learnable [CLS] token -> add
    learnable position embeddings (+ dropout) -> ``depth`` transformer
    blocks -> LayerNorm -> linear head applied to the [CLS] token.

    Args:
        img_size: Input image size (assumes square images)
        patch_size: Size of each patch
        in_channels: Number of input channels (3 for RGB)
        num_classes: Number of output classes
        embed_dim: Embedding dimension
        depth: Number of transformer blocks
        num_heads: Number of attention heads
        mlp_ratio: Ratio of mlp hidden dim to embedding dim
        dropout: Dropout rate
    """
    def __init__(self, img_size=32, patch_size=4, in_channels=3, num_classes=10,
                 embed_dim=128, depth=6, num_heads=4, mlp_ratio=4, dropout=0.1):
        super().__init__()

        # Image -> (B, n_patches, embed_dim)
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        seq_len = self.patch_embed.n_patches + 1  # +1 slot for the class token

        # Learnable [CLS] token and one learnable position embedding per slot
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        self.pos_dropout = nn.Dropout(dropout)

        self.blocks = nn.ModuleList(
            [TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout) for _ in range(depth)]
        )

        # Final norm + linear classifier on the [CLS] token
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # Small truncated-normal init for the embeddings, then per-module init.
        for param in (self.pos_embed, self.cls_token):
            nn.init.trunc_normal_(param, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Linear: trunc-normal(std=0.02) weights, zero bias.
        # LayerNorm: identity affine transform (weight 1, bias 0).
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.zeros_(m.bias)
            nn.init.ones_(m.weight)

    def forward(self, x):
        tokens = self.patch_embed(x)  # (B, n_patches, embed_dim)
        cls = self.cls_token.expand(tokens.shape[0], -1, -1)

        # [CLS] goes first; position embeddings cover all n_patches + 1 slots.
        tokens = torch.cat((cls, tokens), dim=1) + self.pos_embed
        tokens = self.pos_dropout(tokens)

        for block in self.blocks:
            tokens = block(tokens)

        # Normalize the whole sequence, then classify from the [CLS] slot.
        return self.head(self.norm(tokens)[:, 0])

# A small ViT sized for CIFAR-10 (32x32 RGB inputs, 10 classes).
model = VisionTransformer(
    img_size=32, patch_size=4, in_channels=3, num_classes=10,
    embed_dim=128, depth=3, num_heads=4, mlp_ratio=4, dropout=0.1,
).to(device)

# Count all parameters, and separately those that receive gradients.
total_params = 0
trainable_params = 0
for param in model.parameters():
    total_params += param.numel()
    if param.requires_grad:
        trainable_params += param.numel()
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Smoke-test one forward pass: 4 random images -> logits of shape (4, 10).
test_input = torch.randn(4, 3, 32, 32).to(device)
output = model(test_input)
print(f"\nInput shape: {test_input.shape}")
print(f"Output shape: {output.shape}")

## 3. Load CIFAR-10 Dataset

We'll use CIFAR-10 because it has small 32x32 images, making training computationally efficient.

In [None]:
# Training images get light augmentation (random 32px crop of a 4px-padded
# image, random horizontal flip). Both splits are then scaled from [0, 1]
# to [-1, 1] with mean 0.5 / std 0.5 per channel.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Both splits are stored under ./data (downloaded on first use).
train_dataset, test_dataset = (
    torchvision.datasets.CIFAR10(root='./data', train=is_train,
                                 download=True, transform=tfm)
    for is_train, tfm in ((True, transform_train), (False, transform_test))
)

# Only the training loader shuffles; evaluation order stays fixed.
batch_size = 128
train_loader, test_loader = (
    DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=2)
    for ds, shuffle in ((train_dataset, True), (test_dataset, False))
)

# CIFAR-10 label names, indexed by class id
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print(f"Number of classes: {len(classes)}")

### Visualize some training examples

In [None]:
def imshow(img):
    """Draw a normalized (C, H, W) image tensor with matplotlib, axes hidden."""
    # Undo Normalize(mean=0.5, std=0.5): values in [-1, 1] -> [0, 1],
    # then reorder channels-first to the (H, W, C) layout matplotlib expects.
    hwc = (img / 2 + 0.5).numpy().transpose(1, 2, 0)
    plt.imshow(hwc)
    plt.axis('off')

# Take the first (shuffled) training batch and show 8 of its images in a grid.
dataiter = iter(train_loader)
images, labels = next(dataiter)
imshow(torchvision.utils.make_grid(images[:8]))

# Matching class names, each left-aligned in a 5-character column.
names = [classes[labels[j]] for j in range(8)]
print('Labels: ', ' '.join(f'{name:5s}' for name in names))

## 4. Training the Vision Transformer

Let's define our training and evaluation functions.

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: network to train; it is switched to train mode here.
        train_loader: iterable of ``(inputs, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``. Assumed to
            return the batch *mean* (the ``nn.CrossEntropyLoss`` default);
            it is re-weighted by batch size below.
        optimizer: optimizer stepped once per batch.
        device: device every batch is moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: per-sample mean loss over the epoch and
        accuracy in percent.
    """
    model.train()
    running_loss = 0.0  # sum of per-sample losses seen so far
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc='Training')
    for inputs, labels in pbar:
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward, backward, update
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Statistics. Weight the batch-mean loss by batch size so a smaller
        # final batch does not count as much as a full one.
        n = labels.size(0)
        running_loss += loss.item() * n
        _, predicted = outputs.max(1)
        total += n
        correct += predicted.eq(labels).sum().item()

        # Use our own counters rather than pbar.n: tqdm only syncs .n on
        # display refreshes, so it lags the number of processed batches.
        pbar.set_postfix({
            'loss': running_loss / total,
            'acc': 100. * correct / total
        })

    epoch_loss = running_loss / total
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc

def evaluate(model, test_loader, criterion, device):
    """Evaluate ``model`` on every batch of ``test_loader`` without gradients.

    Args:
        model: network to evaluate; it is switched to eval mode here.
        test_loader: iterable of ``(inputs, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``. Assumed to
            return the batch *mean* (the ``nn.CrossEntropyLoss`` default);
            it is re-weighted by batch size below.
        device: device every batch is moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: per-sample mean loss over the loader and
        accuracy in percent.
    """
    model.eval()
    running_loss = 0.0  # sum of per-sample losses
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc='Evaluating'):
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Weight the batch-mean loss by batch size so a smaller final
            # batch is not over-counted in the average.
            n = labels.size(0)
            running_loss += loss.item() * n
            _, predicted = outputs.max(1)
            total += n
            correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / total
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc

### Setup training configuration

In [None]:
# Fresh, untrained model for the actual training run. This is much smaller
# than ViT-Base from the original paper (embed_dim=768, depth=12, num_heads=12).
model = VisionTransformer(
    img_size=32,
    patch_size=4,
    in_channels=3,
    num_classes=10,
    embed_dim=128,
    depth=3,
    num_heads=4,
    mlp_ratio=4,
    dropout=0.1,
).to(device)

criterion = nn.CrossEntropyLoss()

# AdamW (decoupled weight decay) is the usual optimizer for transformers.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)

# Cosine-annealed learning rate over num_epochs scheduler steps.
num_epochs = 5
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

print(f"Model has {sum(p.numel() for p in model.parameters()):,} parameters")
print(f"Training for {num_epochs} epochs")

### Run training loop

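**Note**: Training from scratch takes time. `num_epochs` is set to 5 above for a quick demo; ViTs trained from scratch on CIFAR-10 need many more epochs for good accuracy, so increase it (e.g. to 20 or more) if you have the compute.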

In [None]:
from pathlib import Path

# Per-epoch metrics, used by the plotting and summary cells afterwards.
history = {
    'train_loss': [],
    'train_acc': [],
    'test_loss': [],
    'test_acc': []
}

best_acc = 0.0

# Where the best checkpoint is written. torch.save does not create missing
# directories, so make sure 'res/' exists before the first save.
checkpoint_path = Path('res') / 'vit_cifar10_best.pth'
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print("-" * 60)

    # Learning rate actually used for this epoch: read it before the
    # scheduler advances to the next epoch's value.
    epoch_lr = scheduler.get_last_lr()[0]

    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

    # Evaluate
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)

    # The scheduler is stepped once per epoch.
    scheduler.step()

    # Save history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['test_loss'].append(test_loss)
    history['test_acc'].append(test_acc)

    # Print epoch summary
    print(f"\nTrain Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%")
    print(f"Learning Rate: {epoch_lr:.6f}")

    # Keep only the checkpoint with the highest test accuracy so far
    if test_acc > best_acc:
        best_acc = test_acc
        torch.save(model.state_dict(), checkpoint_path)
        print(f"✓ New best model saved! (Acc: {best_acc:.2f}%)")

print("\n" + "="*60)
print(f"Training completed! Best test accuracy: {best_acc:.2f}%")

### Visualize training progress

In [None]:
# Plot the per-epoch curves recorded in `history`. Epochs are numbered from 1
# on the x-axis so they match the "Epoch k/N" lines printed during training.
epochs = range(1, len(history['train_loss']) + 1)

fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# One panel per metric: (history key suffix, legend noun, y-axis label, title)
panels = [
    ('loss', 'Loss', 'Loss', 'Training and Test Loss'),
    ('acc', 'Accuracy', 'Accuracy (%)', 'Training and Test Accuracy'),
]
for ax, (key, noun, ylabel, title) in zip(axes, panels):
    ax.plot(epochs, history[f'train_{key}'], label=f'Train {noun}', linewidth=2)
    ax.plot(epochs, history[f'test_{key}'], label=f'Test {noun}', linewidth=2)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Final Test Accuracy: {history['test_acc'][-1]:.2f}%")
print(f"Best Test Accuracy: {max(history['test_acc']):.2f}%")

## 5. Testing and Visualization

Let's test our model on some examples and visualize the predictions.

In [None]:
# Load the best checkpoint saved during training.
# - map_location puts the tensors directly on the current device, so this also
#   works if the checkpoint was written on a different device.
# - weights_only=True restricts unpickling to tensors and plain containers,
#   which is all a state_dict needs, and avoids running arbitrary pickled code.
checkpoint = torch.load('res/vit_cifar10_best.pth', map_location=device, weights_only=True)
model.load_state_dict(checkpoint)
model.eval()

# Get a batch of test images
dataiter = iter(test_loader)
images, labels = next(dataiter)
images, labels = images.to(device), labels.to(device)

# Make predictions
with torch.no_grad():
    outputs = model(images)
    _, predicted = outputs.max(1)
    probs = F.softmax(outputs, dim=1)

# Visualize predictions on a 3x4 grid (green title = correct, red = wrong)
fig, axes = plt.subplots(3, 4, figsize=(15, 11))
for idx, ax in enumerate(axes.flat):
    ax.axis('off')
    # If the batch holds fewer images than grid cells, leave the rest blank.
    if idx >= len(images):
        continue

    # Display image, undoing Normalize(0.5, 0.5): [-1, 1] -> [0, 1]
    img = images[idx].cpu()
    img = img / 2 + 0.5
    npimg = img.numpy()
    ax.imshow(np.transpose(npimg, (1, 2, 0)))

    # Get prediction info
    true_label = classes[labels[idx]]
    pred_label = classes[predicted[idx]]
    confidence = probs[idx][predicted[idx]].item() * 100

    # Set title with color coding
    color = 'green' if predicted[idx] == labels[idx] else 'red'
    title = f'True: {true_label}\nPred: {pred_label} ({confidence:.1f}%)'
    ax.set_title(title, color=color, fontweight='bold')

plt.tight_layout()
plt.show()

# Calculate accuracy for this batch
correct = (predicted == labels).sum().item()
accuracy = 100. * correct / labels.size(0)
print(f"\nBatch accuracy: {accuracy:.2f}% ({correct}/{labels.size(0)})")

### Per-class accuracy analysis

In [None]:
# Calculate per-class accuracy on the whole test set.
num_classes = len(classes)
class_correct = torch.zeros(num_classes, dtype=torch.long)
class_total = torch.zeros(num_classes, dtype=torch.long)

model.eval()
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = outputs.max(1)

        # Histogram the labels of all samples and of the correctly classified
        # ones with one vectorized call each. A per-sample Python loop would
        # force a device sync for every element.
        labels_cpu = labels.cpu()
        hits_cpu = predicted.eq(labels).cpu()
        class_total += torch.bincount(labels_cpu, minlength=num_classes)
        class_correct += torch.bincount(labels_cpu[hits_cpu], minlength=num_classes)

class_correct = class_correct.tolist()
class_total = class_total.tolist()

# Print and plot results
print("Per-class accuracy:\n" + "="*40)
accuracies = []
for i, name in enumerate(classes):
    acc = 100 * class_correct[i] / class_total[i] if class_total[i] > 0 else 0
    accuracies.append(acc)
    print(f"{name:10s}: {acc:.2f}% ({class_correct[i]}/{class_total[i]})")

# Visualize per-class accuracy
mean_acc = np.mean(accuracies)
plt.figure(figsize=(12, 6))
bars = plt.bar(classes, accuracies, color='steelblue', alpha=0.8)
plt.axhline(y=mean_acc, color='r', linestyle='--',
            label=f'Mean: {mean_acc:.2f}%', linewidth=2)
plt.xlabel('Class', fontsize=12)
plt.ylabel('Accuracy (%)', fontsize=12)
plt.title('Per-Class Accuracy of Vision Transformer on CIFAR-10',
          fontsize=14, fontweight='bold')
plt.ylim(0, 100)
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3, axis='y')

# Add value labels on bars
for bar, acc in zip(bars, accuracies):
    height = bar.get_height()
    plt.text(bar.get_x() + bar.get_width()/2., height,
             f'{acc:.1f}%', ha='center', va='bottom', fontsize=10)

plt.tight_layout()
plt.show()