In [1]:
# Check GPU
!nvidia-smi

# Install required packages
!pip install timm einops torchsummary

Sat Oct  4 14:04:15 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   49C    P8             10W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.models import vit_b_16, ViT_B_16_Weights
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

# Configuration for maximum accuracy
CFG = {
    "epochs": 20,
    "batch_size": 256,  
    "lr": 1.5e-4,      
    "weight_decay": 0.01,
    "num_workers": 4,
    "seed": 42,
    "use_amp": True,
    "patience": 6,
    "label_smoothing": 0.1,
    "warmup_epochs": 3,
    "grad_clip": 2.0,   
}

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

# Enable cuDNN benchmark for fixed input sizes
if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

def set_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
    if device == "cuda":
        torch.cuda.manual_seed_all(seed)

set_seed(CFG["seed"])

# Precompiled transforms (cached, not rebuilt each call)
TRANSFORM_TRAIN = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomCrop(224, padding=32, padding_mode='reflect'),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    transforms.RandomErasing(p=0.3, scale=(0.02, 0.25)),
])

TRANSFORM_TEST = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# Load datasets
train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=TRANSFORM_TRAIN)
test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=TRANSFORM_TEST)

# Increased batch size + persistent workers
trainloader = DataLoader(
    train_dataset,
    batch_size=CFG["batch_size"],
    shuffle=True,
    num_workers=CFG["num_workers"],
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=2  # Prefetch batches for faster loading
)
testloader = DataLoader(
    test_dataset,
    batch_size=CFG["batch_size"],
    shuffle=False,
    num_workers=CFG["num_workers"],
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=2
)

print(f"Train: {len(train_dataset)}, Test: {len(test_dataset)}")
print(f"Batches per epoch: {len(trainloader)}")

# Simplified model loading
print("Loading ViT-B/16...")
try:
    model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
    print("Loaded IMAGENET1K_V1 weights")
except AttributeError:
    model = vit_b_16(weights='DEFAULT')
    print("Loaded DEFAULT weights")

in_features = model.heads.head.in_features
model.heads = nn.Sequential(nn.Dropout(0.1), nn.Linear(in_features, 10))

# Selective parameter freezing with parameter groups
backbone_params = []
head_params = []

for name, param in model.named_parameters():
    param.requires_grad = False
    # Unfreeze last 3 encoder blocks
    if any(f"encoder.layers.{i}." in name for i in [9, 10, 11]):
        param.requires_grad = True
        backbone_params.append(param)
    # Unfreeze head
    elif "heads" in name:
        param.requires_grad = True
        head_params.append(param)

model = model.to(device)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.1f}%)")

# Parameter groups with different learning rates
criterion = nn.CrossEntropyLoss(label_smoothing=CFG["label_smoothing"])

# Use of different LRs for backbone vs head (head learns faster)
optimizer = optim.AdamW([
    {'params': backbone_params, 'lr': CFG["lr"], 'weight_decay': CFG["weight_decay"]},
    {'params': head_params, 'lr': CFG["lr"] * 2, 'weight_decay': CFG["weight_decay"] * 0.5}
], betas=(0.9, 0.999), eps=1e-8)

scaler = torch.cuda.amp.GradScaler(enabled=CFG["use_amp"])

# Combined warmup + cosine annealing
def get_lr(epoch):
    if epoch < CFG["warmup_epochs"]:
        return (epoch + 1) / CFG["warmup_epochs"]
    progress = (epoch - CFG["warmup_epochs"]) / (CFG["epochs"] - CFG["warmup_epochs"])
    return max(0.001, 0.5 * (1.0 + np.cos(np.pi * progress)))

scheduler = optim.lr_scheduler.LambdaLR(optimizer, get_lr)

def train_epoch(epoch):
    model.train()
    total, correct, loss_sum = 0, 0, 0.0

    # Reduced logging overhead with leave=False
    pbar = tqdm(trainloader, desc=f"E{epoch+1}/{CFG['epochs']}", leave=False, ncols=100)

    for imgs, labels in pbar:
        imgs, labels = imgs.to(device, non_blocking=True), labels.to(device, non_blocking=True)

        # Zero grad with set_to_none for memory efficiency
        optimizer.zero_grad(set_to_none=True)

        # AMP for all forward/backward ops
        with torch.cuda.amp.autocast(enabled=CFG["use_amp"]):
            outputs = model(imgs)
            loss = criterion(outputs, labels)

        # Scaled backward pass
        scaler.scale(loss).backward()

        # Gradient clipping after unscaling
        scaler.unscale_(optimizer)
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(), CFG["grad_clip"])

        scaler.step(optimizer)
        scaler.update()

        # Metrics
        loss_sum += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        if total % 1000 == 0:
            pbar.set_postfix({
                'loss': f'{loss.item():.3f}',
                'acc': f'{100.*correct/total:.1f}%',
                'grad': f'{grad_norm:.2f}'
            })

    return loss_sum / len(trainloader), 100.*correct/total

def evaluate():
    model.eval()
    total, correct = 0, 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for imgs, labels in tqdm(testloader, desc="Eval", leave=False, ncols=100):
            imgs, labels = imgs.to(device, non_blocking=True), labels.to(device, non_blocking=True)

            with torch.cuda.amp.autocast(enabled=CFG["use_amp"]):
                outputs = model(imgs)

            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    return 100.*correct/total, np.array(all_preds), np.array(all_labels)

# Training loop
print("\nTraining...\n")
best_acc, patience = 0.0, 0
history = {'train_loss': [], 'train_acc': [], 'test_acc': [], 'lr': []}

for epoch in range(CFG["epochs"]):
    train_loss, train_acc = train_epoch(epoch)
    test_acc, preds, labels = evaluate()

    current_lr = optimizer.param_groups[0]['lr']
    scheduler.step()

    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['test_acc'].append(test_acc)
    history['lr'].append(current_lr)

    print(f"E{epoch+1:2d}: Loss={train_loss:.3f} | Train={train_acc:.2f}% | Test={test_acc:.2f}% | LR={current_lr:.2e}", end="")

    if test_acc > best_acc:
        best_acc = test_acc
        patience = 0
        # Saving only essential state
        torch.save({
            'model': model.state_dict(),
            'acc': test_acc,
            'epoch': epoch
        }, "best_vit_cifar10.pth")
        print(" <- BEST")
    else:
        patience += 1
        print(f" (patience {patience}/{CFG['patience']})")

    if patience >= CFG["patience"]:
        print(f"\nEarly stop at epoch {epoch+1}")
        break

print(f"\nBest Test Accuracy: {best_acc:.2f}%")

#  Conditional visualization (only at end)
print("\nGenerating results...")
checkpoint = torch.load("best_vit_cifar10.pth")
model.load_state_dict(checkpoint['model'])
final_acc, final_preds, final_labels = evaluate()

# Confusion matrix
cm = confusion_matrix(final_labels, final_preds)
class_names = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# 1. Confusion matrix
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names,
            yticklabels=class_names, ax=axes[0, 0], cbar_kws={'label': 'Count'})
axes[0, 0].set_title(f'Confusion Matrix (Acc: {best_acc:.2f}%)', fontweight='bold')
axes[0, 0].set_ylabel('True')
axes[0, 0].set_xlabel('Predicted')

# 2. Training curves
axes[0, 1].plot(history['train_acc'], label='Train', linewidth=2, marker='o', markersize=4)
axes[0, 1].plot(history['test_acc'], label='Test', linewidth=2, marker='s', markersize=4)
axes[0, 1].axhline(y=best_acc, color='green', linestyle='--', label=f'Best: {best_acc:.2f}%')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Accuracy (%)')
axes[0, 1].set_title('Accuracy Curves', fontweight='bold')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# 3. Loss curve
axes[1, 0].plot(history['train_loss'], linewidth=2, color='red', marker='o', markersize=4)
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Loss')
axes[1, 0].set_title('Training Loss', fontweight='bold')
axes[1, 0].grid(True, alpha=0.3)

# 4. Learning rate schedule
axes[1, 1].plot(history['lr'], linewidth=2, color='purple', marker='o', markersize=4)
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Learning Rate')
axes[1, 1].set_title('LR Schedule', fontweight='bold')
axes[1, 1].set_yscale('log')
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('vit_cifar10_results.png', dpi=300, bbox_inches='tight')
print("Saved: vit_cifar10_results.png")

# Per-class accuracy
class_correct = [0] * 10
class_total = [0] * 10
for pred, true in zip(final_preds, final_labels):
    class_total[true] += 1
    if pred == true:
        class_correct[true] += 1

print("\nPer-Class Accuracy:")
print("-" * 35)
for i, name in enumerate(class_names):
    acc = 100 * class_correct[i] / class_total[i]
    print(f"{name:8s}: {acc:5.2f}% ({class_correct[i]}/{class_total[i]})")
print("-" * 35)
print(f"Overall: {best_acc:.2f}%")

# Summary
print("\nOptimizations Applied:")
print("- Precompiled transforms (cached)")
print("- Batch size 256 (increased GPU utilization)")
print("- Parameter groups (different LRs for backbone/head)")
print("- cuDNN benchmark enabled")
print("- Prefetch factor 2 (faster data loading)")
print("- Gradient norm tracking")
print("- Reduced logging overhead")

Device: cuda
GPU: Tesla T4




Train: 50000, Test: 10000
Batches per epoch: 196
Loading ViT-B/16...
Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth


100%|██████████| 330M/330M [00:01<00:00, 181MB/s]
  scaler = torch.cuda.amp.GradScaler(enabled=CFG["use_amp"])


Loaded IMAGENET1K_V1 weights
Trainable: 7,690 / 85,806,346 (0.0%)

Training...



  with torch.cuda.amp.autocast(enabled=CFG["use_amp"]):
  with torch.cuda.amp.autocast(enabled=CFG["use_amp"]):


E 1: Loss=1.573 | Train=63.78% | Test=89.99% | LR=5.00e-05 <- BEST




E 2: Loss=0.942 | Train=84.95% | Test=92.50% | LR=1.00e-04 <- BEST




E 3: Loss=0.845 | Train=87.39% | Test=93.42% | LR=1.50e-04 <- BEST




E 4: Loss=0.816 | Train=88.50% | Test=94.12% | LR=1.50e-04 <- BEST




E 5: Loss=0.799 | Train=89.18% | Test=94.32% | LR=1.49e-04 <- BEST




E 6: Loss=0.789 | Train=89.59% | Test=94.38% | LR=1.45e-04 <- BEST




E 7: Loss=0.784 | Train=89.79% | Test=94.62% | LR=1.39e-04 <- BEST




E 8: Loss=0.777 | Train=90.10% | Test=94.75% | LR=1.30e-04 <- BEST




E 9: Loss=0.778 | Train=90.00% | Test=94.82% | LR=1.20e-04 <- BEST




E10: Loss=0.774 | Train=90.30% | Test=94.89% | LR=1.08e-04 <- BEST




E11: Loss=0.770 | Train=90.52% | Test=94.93% | LR=9.55e-05 <- BEST




E12: Loss=0.768 | Train=90.53% | Test=95.02% | LR=8.19e-05 <- BEST




E13: Loss=0.767 | Train=90.68% | Test=95.06% | LR=6.81e-05 <- BEST


E14/20:  94%|█████████████████▊ | 184/196 [05:49<00:17,  1.43s/it, loss=0.809, acc=90.8%, grad=0.48]