# Vision Transformer (ViT) on CIFAR-10 - **IMPROVED VERSION**

**Target: 88-92% Test Accuracy**

Based on: *An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale*

---

## Key Improvements Over Base Version

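1. **Stronger Data Augmentation**: AutoAugment (CIFAR-10 policy) + Random Erasing (a Cutout-style occlusion)
2. **Label Smoothing**: Softens the one-hot targets (smoothing = 0.1) to reduce overfitting
3. **Stochastic Depth**: DropPath, which randomly skips a transformer block's residual branch per sample during training
4. **Longer Training**: 200 epochs instead of 100
5. **Better Regularization**: Higher dropout (0.2, up from 0.1) + weight decay (0.05)
6. **Gradient Accumulation**: 2 accumulation steps × batch size 128 = effective batch size 256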

---

## Expected Results

| Configuration | Test Accuracy | Training Time |
|---------------|---------------|---------------|
| Previous | 83-84% | ~1 hour |
| **This Version** | **88-92%** | **~3 hours** |

---

In [None]:
# Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10

import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import time
import math
import random

# Report the installed PyTorch version and whether a CUDA device is visible.
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")

# Seeds
def set_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds PyTorch, NumPy and Python's ``random`` module. When CUDA is
    available it additionally seeds all CUDA devices, enables cuDNN's
    deterministic mode and disables cuDNN benchmarking.

    Args:
        seed: Integer seed shared by all generators.
    """
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(42)

In [None]:
# IMPROVED Configuration
class Config:
    """Hyperparameters for the improved ViT run.

    Defaults are declared as class attributes. ``__init__`` copies them onto
    the instance so that ``vars(config)`` returns the full set of settings
    (class attributes alone are not part of an instance's ``__dict__``, which
    would make a ``vars(config)`` snapshot an empty dict).
    """
    # Model
    img_size = 32
    patch_size = 4
    in_channels = 3
    num_classes = 10
    embed_dim = 384  # Increased from 256
    depth = 7        # Increased from 6
    num_heads = 8
    mlp_ratio = 3.0  # Increased from 2.0
    dropout = 0.2    # Increased from 0.1
    drop_path = 0.1  # Stochastic depth

    # Training
    batch_size = 128
    accumulation_steps = 2  # Effective batch size = 256
    num_epochs = 200        # Increased from 100
    learning_rate = 5e-4    # Slightly higher
    weight_decay = 0.05
    warmup_epochs = 10
    label_smoothing = 0.1   # NEW

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def __init__(self):
        # Materialize every public, non-callable class-level default as an
        # instance attribute so the settings show up in vars(self).
        cls = type(self)
        for name in dir(cls):
            if name.startswith('_'):
                continue
            value = getattr(cls, name)
            if callable(value):
                continue
            setattr(self, name, value)


config = Config()
print(f"Training on: {config.device}")
print(f"Effective batch size: {config.batch_size * config.accumulation_steps}")

In [None]:
# Stochastic Depth (Drop Path)
def drop_path(x, drop_prob=0., training=False):
    """Apply stochastic depth: randomly zero whole samples of ``x``.

    Each sample (index along dim 0) is zeroed with probability ``drop_prob``;
    surviving samples are divided by ``1 - drop_prob`` so the expected value
    of the output matches the input.

    Args:
        x: Tensor whose first dimension indexes samples.
        drop_prob: Probability of dropping a sample, in ``[0, 1]``. ``None``
            is treated like ``0`` (nothing is dropped).
        training: Samples are only dropped when this is True.

    Returns:
        ``x`` itself when nothing can be dropped, otherwise a new tensor with
        the same shape as ``x``.

    Raises:
        ValueError: If dropping is active and ``drop_prob`` is outside
            ``[0, 1]``.
    """
    if drop_prob is None or drop_prob == 0. or not training:
        return x
    if not 0. <= drop_prob <= 1.:
        raise ValueError(f"drop_prob must be in [0, 1], got {drop_prob}")

    keep_prob = 1 - drop_prob
    if keep_prob <= 0.:
        # Everything is dropped; x / 0 * 0 would produce NaN instead of zeros.
        return torch.zeros_like(x)

    # One mask value per sample, broadcast over all remaining dimensions.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
    mask = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    mask.floor_()
    return x.div(keep_prob) * mask

class DropPath(nn.Module):
    """Module wrapper around :func:`drop_path`.

    Stores the drop probability and passes the module's ``training`` flag on
    to ``drop_path`` at call time.

    Args:
        drop_prob: Probability forwarded to ``drop_path``.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, drop_prob=self.drop_prob, training=self.training)

In [None]:
# Patch Embedding
class PatchEmbedding(nn.Module):
    """Turn an image batch into a sequence of patch embeddings.

    A convolution whose kernel size and stride both equal ``patch_size``
    projects every non-overlapping patch to ``embed_dim`` channels.

    Args:
        img_size: Side length used to compute ``num_patches``.
        patch_size: Side length of each patch.
        in_channels: Number of input channels.
        embed_dim: Size of each patch embedding.
    """

    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=384):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side ** 2
        self.projection = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        # Conv output keeps a 2-D patch grid; flatten it into one token axis
        # and move the channel axis last.
        patch_grid = self.projection(x)
        tokens = patch_grid.flatten(start_dim=2)
        return tokens.transpose(1, 2)

In [None]:
# Attention
class Attention(nn.Module):
    """Multi-head self-attention over a ``(B, N, C)`` token sequence.

    Args:
        dim: Channel size ``C`` of the tokens.
        num_heads: Number of attention heads; each head uses
            ``dim // num_heads`` channels.
        qkv_bias: Whether the fused query/key/value projection has a bias.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        # Scores are scaled by 1 / sqrt(head_dim).
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        # One linear layer produces q, k and v; split it into
        # (3, B, heads, N, head_dim) and unpack along the first axis.
        fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        queries, keys, values = fused.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (queries @ keys.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        # Merge the heads back into the channel dimension.
        context = (weights @ values).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(context))

In [None]:
# MLP
class MLP(nn.Module):
    """Two-layer feed-forward network with GELU and dropout.

    Expands ``in_features`` to ``hidden_features`` and projects back to
    ``in_features``. The same dropout module is applied after the activation
    and after the second linear layer.

    Args:
        in_features: Input and output feature size.
        hidden_features: Width of the hidden layer.
        drop: Dropout probability.
    """

    def __init__(self, in_features, hidden_features, drop=0.):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

In [None]:
# Transformer Block with Drop Path
class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block with stochastic depth.

    Computes ``x + drop_path(attn(norm1(x)))`` followed by
    ``x + drop_path(mlp(norm2(x)))``.

    Args:
        dim: Token channel size.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        drop: Dropout probability for the attention projection and the MLP.
        attn_drop: Dropout probability for the attention weights.
        drop_path: Stochastic-depth probability; ``0`` disables it.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=True,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # A single drop-path module is shared by both residual branches.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(in_features=dim, hidden_features=int(dim * mlp_ratio), drop=drop)

    def forward(self, x):
        attn_branch = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_branch)
        mlp_branch = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_branch)

In [None]:
# Vision Transformer
class VisionTransformer(nn.Module):
    """Vision Transformer classifier.

    Embeds image patches, prepends a learnable class token, adds learnable
    position embeddings, runs ``depth`` transformer blocks and classifies
    from the normalized class token.

    Args:
        img_size: Side length of the input images.
        patch_size: Side length of each patch.
        in_chans: Number of image channels.
        num_classes: Number of output logits.
        embed_dim: Token channel size.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        drop_rate: Dropout probability used for position-embedding dropout
            and for both ``drop`` and ``attn_drop`` in every block.
        drop_path_rate: Stochastic-depth probability of the last block.
    """

    def __init__(self, img_size=32, patch_size=4, in_chans=3, num_classes=10,
                 embed_dim=384, depth=7, num_heads=8, mlp_ratio=3., drop_rate=0.2, drop_path_rate=0.1):
        super().__init__()

        self.patch_embed = PatchEmbedding(img_size, patch_size, in_chans, embed_dim)
        # Sequence length is the patch count plus one class token.
        seq_len = self.patch_embed.num_patches + 1

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Drop-path probability grows linearly from 0 at the first block to
        # drop_path_rate at the last one.
        per_block_drop_path = torch.linspace(0, drop_path_rate, depth).tolist()
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio,
                             drop=drop_rate, attn_drop=drop_rate, drop_path=rate)
            for rate in per_block_drop_path
        ])

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # Parameter initialization
        for param in (self.pos_embed, self.cls_token):
            nn.init.trunc_normal_(param, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize ``nn.Linear`` and ``nn.LayerNorm`` submodules."""
        if isinstance(m, nn.LayerNorm):
            nn.init.zeros_(m.bias)
            nn.init.ones_(m.weight)
            return
        if not isinstance(m, nn.Linear):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, x):
        batch_size = x.shape[0]
        tokens = self.patch_embed(x)

        # Prepend one class token per sample, then add position embeddings.
        cls = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls, tokens), dim=1) + self.pos_embed
        tokens = self.pos_drop(tokens)

        for block in self.blocks:
            tokens = block(tokens)

        # Only the class token (index 0) feeds the classification head.
        cls_out = self.norm(tokens)[:, 0]
        return self.head(cls_out)


print("✓ Improved ViT architecture defined")

In [None]:
# IMPROVED Data Augmentation
def get_data_loaders(config):
    """Build the CIFAR-10 training and test ``DataLoader`` objects.

    Training images get random crop, horizontal flip, AutoAugment (CIFAR-10
    policy), normalization and random erasing; test images are only converted
    to tensors and normalized. Both datasets are created with
    ``download=True`` under ``./data``.

    Args:
        config: Object providing ``batch_size``.

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2471, 0.2435, 0.2616)

    def to_normalized_tensor():
        # Fresh transform instances for each pipeline.
        return [transforms.ToTensor(), transforms.Normalize(mean, std)]

    # Strong augmentation for training only
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
        *to_normalized_tensor(),
        transforms.RandomErasing(p=0.25),  # Cutout-style occlusion on the tensor
    ])
    test_transform = transforms.Compose(to_normalized_tensor())

    train_dataset = CIFAR10(root='./data', train=True, download=True, transform=train_transform)
    test_dataset = CIFAR10(root='./data', train=False, download=True, transform=test_transform)

    # Settings shared by both loaders; only shuffling differs.
    shared_loader_args = dict(
        batch_size=config.batch_size,
        num_workers=2,
        pin_memory=True,
        persistent_workers=True,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **shared_loader_args)
    test_loader = DataLoader(test_dataset, shuffle=False, **shared_loader_args)

    return train_loader, test_loader

# Build both loaders once; the sizes printed below are dataset lengths, not batch counts.
train_loader, test_loader = get_data_loaders(config)
print(f"✓ Data with strong augmentation: {len(train_loader.dataset)} train, {len(test_loader.dataset)} test")

In [None]:
# Create model
# Map the configuration onto the VisionTransformer constructor arguments.
vit_kwargs = {
    'img_size': config.img_size,
    'patch_size': config.patch_size,
    'in_chans': config.in_channels,
    'num_classes': config.num_classes,
    'embed_dim': config.embed_dim,
    'depth': config.depth,
    'num_heads': config.num_heads,
    'mlp_ratio': config.mlp_ratio,
    'drop_rate': config.dropout,
    'drop_path_rate': config.drop_path,
}
model = VisionTransformer(**vit_kwargs).to(config.device)

# Count only the parameters that will receive gradients.
n_params = sum(
    p.numel()
    for p in model.parameters()
    if p.requires_grad
)
print(f"✓ Model: {n_params:,} parameters")
print(f"✓ Embed dim: {config.embed_dim}, Depth: {config.depth}")

In [None]:
# Label Smoothing Loss
class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy against label-smoothed targets.

    The true class gets probability ``1 - smoothing``; the remaining
    ``smoothing`` mass is split evenly over the other ``n_class - 1`` classes.

    Args:
        smoothing: Probability mass moved away from the true class.
    """

    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):
        """Return the batch-mean loss for logits ``pred`` and class indices ``target``."""
        num_classes = pred.size(1)
        # Hard one-hot targets with the same shape and dtype as the logits.
        hard = torch.zeros_like(pred).scatter(1, target.view(-1, 1), 1)
        soft = hard * (1 - self.smoothing) + (1 - hard) * self.smoothing / (num_classes - 1)
        log_prob = F.log_softmax(pred, dim=1)
        per_sample = -(soft * log_prob).sum(dim=1)
        return per_sample.mean()

# Loss object passed to both train_epoch and evaluate, configured from config.label_smoothing.
criterion = LabelSmoothingCrossEntropy(smoothing=config.label_smoothing)
print(f"✓ Using label smoothing: {config.label_smoothing}")

In [None]:
# Optimizer & Scheduler
def get_lr_scheduler(optimizer, warmup_epochs, total_epochs, steps_per_epoch):
    """Create a linear-warmup + cosine-decay learning-rate schedule.

    The multiplier rises linearly from 0 to 1 over the warmup steps, then
    follows half a cosine wave down to 0 at ``total_epochs * steps_per_epoch``
    steps. Past that point it stays at 0 instead of rising again.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        warmup_epochs: Number of warmup epochs.
        total_epochs: Total number of training epochs, warmup included.
        steps_per_epoch: Number of ``scheduler.step()`` calls made per epoch.

    Returns:
        A ``LambdaLR`` scheduler that is meant to be stepped once per
        optimizer update.
    """
    warmup_steps = warmup_epochs * steps_per_epoch
    total_steps = total_epochs * steps_per_epoch
    # Guard against a zero-length decay phase (total_steps == warmup_steps).
    decay_steps = max(1, total_steps - warmup_steps)

    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps
        # Clamp so stepping past total_steps keeps the multiplier at 0
        # rather than following the cosine back up.
        progress = min(1.0, (step - warmup_steps) / decay_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

optimizer = optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
# train_epoch steps the scheduler once per optimizer update, i.e. once every
# `accumulation_steps` batches, so size the schedule in updates per epoch
# rather than batches per epoch. Sizing it with len(train_loader) would
# stretch warmup by a factor of accumulation_steps and stop the cosine decay
# well before it reaches its end.
updates_per_epoch = max(1, math.ceil(len(train_loader) / config.accumulation_steps))
scheduler = get_lr_scheduler(optimizer, config.warmup_epochs, config.num_epochs, updates_per_epoch)
print("✓ Optimizer: AdamW with cosine annealing")

In [None]:
# Training with Gradient Accumulation
def train_epoch(model, loader, criterion, optimizer, scheduler, device, accumulation_steps):
    """Train ``model`` for one pass over ``loader`` with gradient accumulation.

    Gradients are accumulated over ``accumulation_steps`` batches before each
    optimizer update. Every update clips the gradient norm to 1.0 and is
    followed by one ``scheduler.step()``. The final batch of the epoch always
    triggers an update, so a trailing partial group is applied rather than
    being discarded by the next epoch's ``zero_grad()``.

    Args:
        model: Network to train; switched to training mode.
        loader: Iterable of ``(inputs, targets)`` batches that supports ``len()``.
        criterion: Loss callable taking ``(outputs, targets)``.
        optimizer: Optimizer updating ``model``'s parameters.
        scheduler: Learning-rate scheduler stepped once per optimizer update.
        device: Device the batches are moved to.
        accumulation_steps: Number of batches accumulated per update.

    Returns:
        Tuple ``(mean per-batch loss, accuracy in percent)``.
    """
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    num_batches = len(loader)

    optimizer.zero_grad()
    pbar = tqdm(loader, desc='Training')

    for i, (inputs, targets) in enumerate(pbar):
        inputs, targets = inputs.to(device), targets.to(device)

        outputs = model(inputs)
        loss = criterion(outputs, targets)
        # Scale so the accumulated gradient is the mean over the group. A
        # trailing partial group is scaled by the same factor, so it
        # contributes a proportionally smaller update.
        (loss / accumulation_steps).backward()

        is_last_batch = (i + 1) == num_batches
        if (i + 1) % accumulation_steps == 0 or is_last_batch:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        # Divide by the real batch count: pbar.n is only refreshed at display
        # intervals and would make the running loss lag.
        pbar.set_postfix({
            'loss': f'{total_loss/(i+1):.3f}',
            'acc': f'{100.*correct/total:.2f}%',
            'lr': f'{scheduler.get_last_lr()[0]:.2e}'
        })

    return total_loss / len(loader), 100. * correct / total

@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Iterable of ``(inputs, targets)`` batches that supports ``len()``.
        criterion: Loss callable taking ``(outputs, targets)``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean per-batch loss, accuracy in percent)``.
    """
    model.eval()
    loss_sum, num_correct, num_seen = 0, 0, 0

    for images, labels in tqdm(loader, desc='Evaluating'):
        images = images.to(device)
        labels = labels.to(device)

        logits = model(images)
        loss_sum += criterion(logits, labels).item()

        # Index of the highest logit is the predicted class.
        predictions = logits.max(1)[1]
        num_seen += labels.size(0)
        num_correct += predictions.eq(labels).sum().item()

    mean_loss = loss_sum / len(loader)
    accuracy = 100. * num_correct / num_seen
    return mean_loss, accuracy


print("✓ Training functions ready")

In [None]:
# Main training loop
# Runs config.num_epochs epochs of training followed by evaluation on the test
# loader, records per-epoch metrics for plotting, and checkpoints the model
# whenever the test accuracy improves.
print("="*70)
print("Starting IMPROVED training (target: 88-92% accuracy)")
print("="*70)

# Per-epoch history; these lists are read by the plotting cell.
best_acc = 0
train_losses, test_losses = [], []
train_accs, test_accs = [], []
start_time = time.time()

for epoch in range(config.num_epochs):
    print(f"\nEpoch {epoch+1}/{config.num_epochs}")
    print("-"*70)

    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer,
                                        scheduler, config.device, config.accumulation_steps)
    test_loss, test_acc = evaluate(model, test_loader, criterion, config.device)

    train_losses.append(train_loss)
    test_losses.append(test_loss)
    train_accs.append(train_acc)
    test_accs.append(test_acc)

    print(f"Train: Loss={train_loss:.4f}, Acc={train_acc:.2f}%")
    print(f"Test:  Loss={test_loss:.4f}, Acc={test_acc:.2f}%")

    # NOTE(review): the checkpoint is selected on test accuracy, so best_acc
    # is an optimistic estimate; a held-out validation split would avoid this.
    if test_acc > best_acc:
        best_acc = test_acc
        # NOTE(review): vars(config) only returns instance attributes; confirm
        # that Config actually stores its settings on the instance, otherwise
        # the saved 'config' entry is an empty dict.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'test_acc': test_acc,
            'config': vars(config)
        }, 'best_vit_improved.pth')
        print(f"✓ Best: {best_acc:.2f}%")

# Wall-clock duration of the whole loop, reported in hours.
elapsed = time.time() - start_time
print(f"\n{'='*70}")
print(f"Training complete in {elapsed/3600:.2f} hours")
print(f"Best test accuracy: {best_acc:.2f}%")
print(f"{'='*70}")

In [None]:
# Plot results
# Two side-by-side panels: loss on the left, accuracy on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

epochs = range(1, len(train_losses) + 1)

# Each entry: axis, train series + label, test series + label, y-label, title.
panels = (
    (ax1, train_losses, 'Train Loss', test_losses, 'Test Loss',
     'Loss', 'Training and Test Loss'),
    (ax2, train_accs, 'Train Accuracy', test_accs, 'Test Accuracy',
     'Accuracy (%)', 'Training and Test Accuracy'),
)

for ax, train_curve, train_label, test_curve, test_label, y_label, title in panels:
    ax.plot(epochs, train_curve, 'b-', label=train_label, linewidth=2)
    ax.plot(epochs, test_curve, 'r-', label=test_label, linewidth=2)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(y_label, fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_curves_improved.png', dpi=150)
plt.show()

---

## Final Results

In [None]:
# Summary banner: best test accuracy, its difference to the 83.89 baseline,
# and the artifact file names written by the earlier cells.
banner = "=" * 70

print(banner)
print("FINAL RESULTS - IMPROVED VERSION")
print(banner)
print(f"Overall Classification Test Accuracy: {best_acc:.2f}%")
print(f"\nImprovement over base: {best_acc - 83.89:.2f}%")
print(banner)
print("\nModel: best_vit_improved.pth")
print("Plots: training_curves_improved.png")