# ViT Fine-tuning on CIFAR-10

Training the `mini_vit` architecture on the CIFAR-10 dataset.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import numpy as np
import random
from tqdm import tqdm
import os

# Record the installed PyTorch version in the cell output next to the results.
print(f"PyTorch version: {torch.__version__}")

PyTorch version: 2.9.1


## Hyperparameters

In [8]:
# --- Reproducibility and optimisation schedule ---
SEED = 0
BATCH_SIZE = 256
EPOCHS = 150
WARMUP_EPOCHS = 15      # epochs of linear learning-rate warmup
LEARNING_RATE = 0.001   # peak learning rate, reached at the end of warmup
WEIGHT_DECAY = 0.05
LABEL_SMOOTHING = 0.1

# --- Model geometry (forwarded to the MiniViT constructor) ---
PATCH_SIZE = 4
IMG_SIZE = 32
NUM_CLASSES = 10
VIT_MLP_RATIO = 2       # MLP hidden width = CHANNEL * VIT_MLP_RATIO
CHANNEL = 192           # embedding dimension
DEPTH = 12
HEADS = 12

# --- Augmentation strength ---
MIXUP_ALPHA = 1.0       # Beta(alpha, alpha) parameter for MixUp
CUTMIX_BETA = 1.0       # Beta(beta, beta) parameter for CutMix
MIX_PROB = 0.5          # probability that a batch is mixed at all
RANDOM_ERASING_PROB = 0.25
REPEATED_AUG = 3        # times each training index is visited per epoch

# Seed every RNG the notebook draws from: torch (permutations, model init),
# NumPy (mixing coefficients, box centres) and `random` (sampler shuffling).
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs (slowly) instead of failing at the first `.to(DEVICE)`.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {DEVICE}")

Device: cuda


## Data Loading

### Data Augmentation Techniques

- **RandomHorizontalFlip** and **RandomCrop**: Basic spatial augmentations

- **AutoAugment (CIFAR10)**: Learned augmentation policy optimized for CIFAR-10

- **RandomErasing**: Randomly masks rectangular regions to improve robustness

- **Normalization**: Using CIFAR-10 dataset statistics

### Repeated Augmentation

Repeating each sample 3x with different augmentations per epoch to increase training diversity.

In [3]:
from torchvision.transforms import AutoAugment, AutoAugmentPolicy, RandomErasing
from torch.utils.data import Sampler

# Per-channel mean / std passed to Normalize in both the train and test pipelines.
cifar_mean = (0.4914, 0.4822, 0.4465)
cifar_std = (0.2023, 0.1994, 0.2010)

# Training pipeline. Order matters: flip, crop and AutoAugment run before
# ToTensor; Normalize and RandomErasing run on the tensor produced by it.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(IMG_SIZE, padding=4),
    AutoAugment(policy=AutoAugmentPolicy.CIFAR10),
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar_mean, std=cifar_std),
    # Applied after normalization, with probability RANDOM_ERASING_PROB per image.
    RandomErasing(p=RANDOM_ERASING_PROB, scale=(0.02, 0.4), ratio=(0.3, 3.3), value='random'),
])

# Evaluation pipeline: no augmentation, only tensor conversion and the same normalization.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar_mean, std=cifar_std),
])

# Both splits are stored under ./data; download=True requests a download into that directory.
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)

from torch.utils.data import random_split

class RepeatedAugmentationSampler(Sampler):
    """Sampler that yields every dataset index ``num_repeats`` times per epoch.

    Because the dataset applies its random transform on each access, every
    repeated visit of an index produces a differently augmented sample.

    Args:
        dataset: any object supporting ``len()``; only its length is used.
        num_repeats: how many times each index appears in one epoch.
        shuffle: when true, the index list is shuffled with the global
            ``random`` module, once before and once after it is repeated,
            so the repeats of an index are scattered across the epoch.
    """

    def __init__(self, dataset, num_repeats=1, shuffle=True):
        self.dataset = dataset
        self.num_repeats = num_repeats
        self.shuffle = shuffle
        # Epoch length as seen by the DataLoader.
        self.num_samples = len(dataset) * num_repeats

    def __iter__(self):
        base_order = list(range(len(self.dataset)))

        if not self.shuffle:
            # Deterministic order: 0..n-1 repeated back to back.
            return iter(base_order * self.num_repeats)

        random.shuffle(base_order)
        repeated_order = base_order * self.num_repeats
        random.shuffle(repeated_order)
        return iter(repeated_order)

    def __len__(self):
        return self.num_samples

# Each training index is visited REPEATED_AUG times per epoch, in shuffled order.
train_sampler = RepeatedAugmentationSampler(train_dataset, num_repeats=REPEATED_AUG, shuffle=True)

# A custom sampler is supplied, so `shuffle` is not passed here; drop_last=True
# discards the final incomplete batch so every training batch has BATCH_SIZE samples.
train_loader = DataLoader(
    train_dataset, 
    batch_size=BATCH_SIZE, 
    sampler=train_sampler,
    num_workers=2,
    drop_last=True
)
# Evaluation loader: fixed order, last partial batch kept.
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# len(train_loader) counts batches over the repeated index list, not over the raw dataset.
print(f"Test dataset size: {len(test_dataset)}")
print(f"Train batches: {len(train_loader)}")
print(f"Test batches: {len(test_loader)}")

  entry = pickle.load(f, encoding="latin1")


Test dataset size: 10000
Train batches: 585
Test batches: 40


## Model

In [4]:
from src.models.mini_vit import MiniViT

# Build the network from the hyperparameter cell and move it to the target device.
model = MiniViT(
    img_size=IMG_SIZE,
    patch_size=PATCH_SIZE,
    n_classes=NUM_CLASSES,
    emb_dim=CHANNEL,
    n_layers=DEPTH,
    n_heads=HEADS,
    mlp_hidden_dim=CHANNEL * VIT_MLP_RATIO,
)
model.to(DEVICE)

# Only parameters that receive gradients are counted.
num_params = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)

# One summary line per setting; emitted in a single call, one item per line.
print(
    "Model: MiniViT",
    f"Trainable parameters: {num_params:,}",
    f"Image size: {IMG_SIZE}x{IMG_SIZE}",
    f"Patch size: {PATCH_SIZE}",
    f"Dim: {CHANNEL}",
    f"Depth: {DEPTH}",
    f"Heads: {HEADS}",
    f"MLP ratio: {VIT_MLP_RATIO}",
    sep="\n",
)

Model: MiniViT
Trainable parameters: 3,562,378
Image size: 32x32
Patch size: 4
Dim: 192
Depth: 12
Heads: 12
MLP ratio: 2


## Loss, Optimizer, Scheduler

### Training Regularization & Optimization

- **Label Smoothing** (0.1): Prevents overconfident predictions

- **Weight Decay** (0.05): Decoupled weight decay applied by the AdamW optimizer (not an L2 penalty added to the loss)

- **Cosine Annealing with Warmup**: 
  - Linear warmup for 15 epochs to stabilize training
  - Cosine decay schedule for remaining epochs to fine-tune convergence

In [5]:
# Cross-entropy against label-smoothed targets; the same criterion is used for
# training and for validation.
criterion = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTHING)

# A single parameter group: every parameter gets the same lr and weight decay.
optimizer = optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

# The schedule is advanced once per batch, so it is sized in optimizer steps
# (batches per epoch * epochs) rather than in epochs.
total_steps = len(train_loader) * EPOCHS
warmup_steps = len(train_loader) * WARMUP_EPOCHS

class CosineAnnealingWithWarmup:
    """Per-step learning-rate schedule: linear warmup, then cosine decay to zero.

    For step ``s`` (``s`` is incremented by each call to ``step()``):

    * ``s < warmup_steps``: ``lr * s / warmup_steps``
    * otherwise: ``lr * 0.5 * (1 + cos(pi * progress))`` where ``progress``
      runs from 0 at ``warmup_steps`` to 1 at ``total_steps`` and is held at
      1 afterwards, so the rate stays at zero instead of rising again.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts with an
            ``'lr'`` entry); every group receives the same rate.
        warmup_steps: number of warmup steps (0 disables warmup).
        total_steps: step at which the cosine decay reaches zero.
        lr: peak learning rate.
    """

    def __init__(self, optimizer, warmup_steps, total_steps, lr):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.lr = lr
        self.step_count = 0
        # Put the optimizer on the schedule immediately. Without this the
        # first optimizer step (taken before the first `step()` call) would
        # run at whatever rate the optimizer was constructed with, i.e. the
        # peak rate, which defeats the warmup.
        self._apply(self._lr_at(0))

    def _lr_at(self, step):
        """Return the scheduled rate for ``step`` as a plain Python float."""
        if step < self.warmup_steps:
            return float(self.lr * (step / self.warmup_steps))

        # max(1, ...) guards the division when total_steps <= warmup_steps.
        decay_steps = max(1, self.total_steps - self.warmup_steps)
        progress = min(1.0, (step - self.warmup_steps) / decay_steps)
        # float(): np.cos yields a NumPy scalar; keeping it out of
        # param_groups keeps optimizer.state_dict() made of plain Python
        # numbers, which torch.load's weights_only unpickler accepts without
        # extra allow-listing.
        return float(self.lr * 0.5 * (1 + np.cos(np.pi * progress)))

    def _apply(self, lr):
        """Write ``lr`` into every parameter group of the optimizer."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def step(self):
        """Advance one step, update the optimizer and return the new rate."""
        self.step_count += 1
        lr = self._lr_at(self.step_count)
        self._apply(lr)
        return lr

# Instantiate the per-step schedule; the training loop advances it once per batch.
scheduler = CosineAnnealingWithWarmup(optimizer, warmup_steps, total_steps, LEARNING_RATE)

# Echo the optimisation setup so it is recorded with the cell output.
print(f"Criterion: CrossEntropyLoss with label smoothing={LABEL_SMOOTHING}")
print(f"Optimizer: AdamW(lr={LEARNING_RATE}, wd={WEIGHT_DECAY})")
print(f"Scheduler: CosineAnnealingWithWarmup")
print(f"Total steps: {total_steps}")
print(f"Warmup steps: {warmup_steps}")

Criterion: CrossEntropyLoss with label smoothing=0.1
Optimizer: AdamW(lr=0.001, wd=0.05)
Scheduler: CosineAnnealingWithWarmup
Total steps: 87750
Warmup steps: 8775


## Training Functions

### MixUp & CutMix

**MixUp**: Blends two images and their labels with random ratio  

**CutMix**: Cuts and pastes patches between images, mixing labels proportionally

**Mix Probability**: Each training batch has a 50% chance of being mixed; a mixed batch uses CutMix or MixUp with equal probability

In [6]:
def mixup_data(x, y, alpha=1.0):
    """Blend every sample in the batch with a randomly chosen partner (MixUp).

    Args:
        x: batch tensor whose first dimension indexes samples.
        y: labels aligned with ``x`` along the first dimension.
        alpha: both parameters of the Beta distribution the mixing weight is
            drawn from (NumPy global RNG); ``alpha <= 0`` disables mixing by
            fixing the weight at 1.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where
        ``mixed_x = lam * x + (1 - lam) * x[perm]``, ``y_a`` is ``y``,
        ``y_b`` is ``y[perm]`` and ``lam`` is the weight of the original sample.
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    # The permutation is drawn from the torch RNG, then moved to x's device.
    partner = torch.randperm(x.size(0)).to(x.device)

    blended = lam * x + (1 - lam) * x[partner, :]
    return blended, y, y[partner], lam

def cutmix_data(x, y, beta=1.0):
    """Paste a random rectangle from a partner sample into every sample (CutMix).

    Args:
        x: 4-D batch tensor indexed as (sample, channel, row, column). It is
            not modified; the mixed batch is returned as a new tensor.
        y: labels aligned with ``x`` along the first dimension.
        beta: both parameters of the Beta distribution the initial mixing
            weight is drawn from (NumPy global RNG); ``beta <= 0`` fixes it
            at 1, which yields an empty box and an unchanged batch.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where ``y_a`` is ``y``, ``y_b`` is
        ``y[perm]`` and ``lam`` (a Python float) is the fraction of each
        image that still comes from the original sample, recomputed from the
        clipped box so it matches the pixels actually pasted.
    """
    lam = np.random.beta(beta, beta) if beta > 0 else 1
    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)

    # Dim 2 is sliced with the y-range and dim 3 with the x-range below, so
    # dim 2 is the height and dim 3 the width. (They were read the other way
    # round before, which only happened to work for square images.)
    H = x.size(2)
    W = x.size(3)

    # Box side lengths scale with sqrt(1 - lam) so its area is ~(1 - lam) * H * W.
    cut_ratio = np.sqrt(1.0 - lam)
    cut_h = int(H * cut_ratio)
    cut_w = int(W * cut_ratio)

    # Box centre, uniform over the image; cx is drawn before cy.
    cx = np.random.randint(0, W)
    cy = np.random.randint(0, H)

    # Clip the box to the image; int() keeps NumPy scalars out of the result.
    bbx1 = int(np.clip(cx - cut_w // 2, 0, W))
    bby1 = int(np.clip(cy - cut_h // 2, 0, H))
    bbx2 = int(np.clip(cx + cut_w // 2, 0, W))
    bby2 = int(np.clip(cy + cut_h // 2, 0, H))

    # Work on a copy so the caller's batch is left intact (as in mixup_data).
    mixed_x = x.clone()
    mixed_x[..., bby1:bby2, bbx1:bbx2] = x[index, :, bby1:bby2, bbx1:bbx2]

    # Clipping may have shrunk the box, so derive lam from the real area.
    lam = 1.0 - ((bbx2 - bbx1) * (bby2 - bby1)) / (H * W)
    y_a, y_b = y, y[index]

    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Combine the losses against both label sets of a mixed batch.

    Args:
        criterion: callable ``criterion(pred, target)`` returning a loss.
        pred: model outputs for the mixed batch.
        y_a: labels of the original samples (weight ``lam``).
        y_b: labels of the partner samples (weight ``1 - lam``).
        lam: mixing weight returned by ``mixup_data`` / ``cutmix_data``.

    Returns:
        ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

def train_one_epoch(epoch):
    """Run one pass over ``train_loader``, updating the global model in place.

    For each batch a first NumPy draw decides (probability ``MIX_PROB``)
    whether the batch is mixed; a second draw picks CutMix (< 0.5) or MixUp.
    The optimizer and the per-step scheduler are advanced once per batch.

    Args:
        epoch: zero-based epoch index, used only for the progress-bar label.

    Returns:
        ``(avg_loss, avg_acc)``: sample-weighted mean training loss and
        accuracy in percent. Accuracy always compares predictions with the
        unmixed ``targets``, even for mixed batches, so it understates how
        well the model fits the mixed inputs.
    """
    model.train()
    loss_sum = 0
    correct_sum = 0
    seen = 0

    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]", leave=False)

    for images, targets in progress:
        images, targets = images.to(DEVICE), targets.to(DEVICE)

        if np.random.rand() < MIX_PROB:
            # Mixed batch: second draw selects the mixing flavour.
            if np.random.rand() < 0.5:
                images, targets_a, targets_b, lam = cutmix_data(images, targets, CUTMIX_BETA)
            else:
                images, targets_a, targets_b, lam = mixup_data(images, targets, MIXUP_ALPHA)

            outputs = model(images)
            loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
        else:
            # Plain batch: ordinary forward pass and loss.
            outputs = model(images)
            loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr = scheduler.step()

        # Per-batch statistics for the running totals and the progress bar.
        batch_size = targets.size(0)
        batch_loss = loss.item()
        with torch.no_grad():
            batch_correct = outputs.max(1)[1].eq(targets).sum().item()

        correct_sum += batch_correct
        seen += batch_size
        loss_sum += batch_loss * batch_size

        progress.set_postfix({
            'loss': f'{batch_loss:.4f}',
            'acc': f'{100.0 * batch_correct / batch_size:.2f}%',
            'lr': f'{lr:.6f}',
        })

    return loss_sum / seen, 100.0 * correct_sum / seen

def validate(loader):
    """Evaluate the global model on ``loader`` without gradient tracking.

    Args:
        loader: iterable of ``(images, targets)`` batches.

    Returns:
        ``(avg_loss, avg_acc)``: sample-weighted mean of ``criterion`` and
        top-1 accuracy in percent over all batches.
    """
    model.eval()
    loss_sum = 0
    correct_sum = 0
    seen = 0

    with torch.no_grad():
        for images, targets in loader:
            images, targets = images.to(DEVICE), targets.to(DEVICE)
            batch_size = targets.size(0)

            outputs = model(images)
            batch_loss = criterion(outputs, targets).item()
            batch_correct = outputs.max(1)[1].eq(targets).sum().item()

            # Weight the batch mean by its size so a short last batch counts correctly.
            loss_sum += batch_loss * batch_size
            correct_sum += batch_correct
            seen += batch_size

    return loss_sum / seen, 100.0 * correct_sum / seen

# Marker output confirming that the definitions in this cell were executed.
print("Training functions ready")

Training functions ready


## Training Loop

In [None]:
os.makedirs('checkpoints', exist_ok=True)

best_acc = 0
eval_interval = 5  # report and checkpoint every N epochs, plus the last epoch

for epoch in range(EPOCHS):
    train_loss, train_acc = train_one_epoch(epoch)

    # NOTE(review): the checkpoint is selected on test_loader, so the "best
    # accuracy" is an optimistic estimate; hold out a validation split if an
    # unbiased test number is needed. Evaluation runs every epoch although
    # its result is only consumed on reporting epochs; kept as in the
    # original so reruns stay comparable with the recorded one.
    val_loss, val_acc = validate(test_loader)
    
    if (epoch + 1) % eval_interval == 0 or epoch == EPOCHS - 1:
        print(f"Epoch {epoch+1:3d}/{EPOCHS} | "
            f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | "
            f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
        
        if val_acc > best_acc:
            # Update the running best BEFORE saving so the checkpoint and the
            # log line record the new value (a misspelt name here previously
            # left best_acc at 0, so every evaluation overwrote the file).
            best_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_acc': best_acc,
            }, 'checkpoints/best.pt')
            print(f"Best model saved! Accuracy: {best_acc:.2f}%")

print(f"\nTraining completed! Best accuracy: {best_acc:.2f}%")

# Final evaluation of the last-epoch weights (not the best checkpoint) on the test set
test_loss, test_acc = validate(test_loader)
print(f"Final Test Accuracy: {test_acc:.2f}%")

Epoch 1/150 [Train]:   0%|          | 0/585 [00:00<?, ?it/s]

Attention shapes - Q: torch.Size([256, 12, 65, 16]), K: torch.Size([256, 12, 65, 16]), V: torch.Size([256, 12, 65, 16]), A: torch.Size([256, 12, 65, 65])
Attention output shape: torch.Size([256, 65, 192])
Attention output shape: torch.Size([256, 65, 192])
Attention output shape: torch.Size([256, 65, 192])
Attention output shape: torch.Size([256, 65, 192])
Attention output shape: torch.Size([256, 65, 192])
Attention output shape: torch.Size([256, 65, 192])
Attention output shape: torch.Size([256, 65, 192])
Attention output shape: torch.Size([256, 65, 192])
Attention output shape: torch.Size([256, 65, 192])
Attention output shape: torch.Size([256, 65, 192])
Attention output shape: torch.Size([256, 65, 192])
Attention output shape: torch.Size([256, 65, 192])


                                                                                                            

Epoch   5/150 | Train Loss: 1.7534 | Train Acc: 42.42% | Val Loss: 1.2864 | Val Acc: 64.35%
Best model saved! Accuracy: 64.35%


                                                                                                             

Epoch  10/150 | Train Loss: 1.5832 | Train Acc: 50.68% | Val Loss: 1.1000 | Val Acc: 73.11%
Best model saved! Accuracy: 73.11%


                                                                                                             

Epoch  15/150 | Train Loss: 1.5049 | Train Acc: 54.41% | Val Loss: 0.9624 | Val Acc: 79.79%
Best model saved! Accuracy: 79.79%


                                                                                                             

Epoch  20/150 | Train Loss: 1.4157 | Train Acc: 58.37% | Val Loss: 0.9193 | Val Acc: 81.42%
Best model saved! Accuracy: 81.42%


                                                                                                             

Epoch  25/150 | Train Loss: 1.3258 | Train Acc: 63.25% | Val Loss: 0.8505 | Val Acc: 84.82%
Best model saved! Accuracy: 84.82%


                                                                                                             

Epoch  30/150 | Train Loss: 1.2980 | Train Acc: 64.44% | Val Loss: 0.7975 | Val Acc: 87.06%
Best model saved! Accuracy: 87.06%


                                                                                                             

Epoch  35/150 | Train Loss: 1.3006 | Train Acc: 64.56% | Val Loss: 0.7673 | Val Acc: 88.53%
Best model saved! Accuracy: 88.53%


                                                                                                             

Epoch  40/150 | Train Loss: 1.2581 | Train Acc: 64.61% | Val Loss: 0.7450 | Val Acc: 89.44%
Best model saved! Accuracy: 89.44%


                                                                                                             

Epoch  45/150 | Train Loss: 1.2286 | Train Acc: 66.89% | Val Loss: 0.7368 | Val Acc: 90.09%
Best model saved! Accuracy: 90.09%


                                                                                                             

Epoch  50/150 | Train Loss: 1.2130 | Train Acc: 67.08% | Val Loss: 0.7248 | Val Acc: 90.39%
Best model saved! Accuracy: 90.39%


                                                                                                             

Epoch  55/150 | Train Loss: 1.1970 | Train Acc: 68.15% | Val Loss: 0.7253 | Val Acc: 90.57%
Best model saved! Accuracy: 90.57%


                                                                                                             

Epoch  60/150 | Train Loss: 1.1462 | Train Acc: 71.08% | Val Loss: 0.7107 | Val Acc: 91.69%
Best model saved! Accuracy: 91.69%


                                                                                                             

Epoch  65/150 | Train Loss: 1.1536 | Train Acc: 69.93% | Val Loss: 0.7167 | Val Acc: 91.32%


                                                                                                             

Epoch  70/150 | Train Loss: 1.0991 | Train Acc: 72.79% | Val Loss: 0.7076 | Val Acc: 91.46%


                                                                                                             

Epoch  75/150 | Train Loss: 1.1147 | Train Acc: 72.26% | Val Loss: 0.6852 | Val Acc: 92.00%
Best model saved! Accuracy: 92.00%


                                                                                                             

Epoch  80/150 | Train Loss: 1.1044 | Train Acc: 73.07% | Val Loss: 0.6831 | Val Acc: 92.62%
Best model saved! Accuracy: 92.62%


                                                                                                             

Epoch  85/150 | Train Loss: 1.0830 | Train Acc: 73.47% | Val Loss: 0.6797 | Val Acc: 92.97%
Best model saved! Accuracy: 92.97%


                                                                                                             

Epoch  90/150 | Train Loss: 1.0787 | Train Acc: 73.90% | Val Loss: 0.6873 | Val Acc: 92.42%


                                                                                                             

Epoch  95/150 | Train Loss: 1.0522 | Train Acc: 75.34% | Val Loss: 0.6780 | Val Acc: 93.25%
Best model saved! Accuracy: 93.25%


                                                                                                              

Epoch 100/150 | Train Loss: 1.0531 | Train Acc: 75.06% | Val Loss: 0.6722 | Val Acc: 93.66%
Best model saved! Accuracy: 93.66%


                                                                                                              

Epoch 105/150 | Train Loss: 1.0630 | Train Acc: 74.84% | Val Loss: 0.6621 | Val Acc: 93.61%


                                                                                                              

Epoch 110/150 | Train Loss: 1.0127 | Train Acc: 76.65% | Val Loss: 0.6664 | Val Acc: 93.74%
Best model saved! Accuracy: 93.74%


                                                                                                              

Epoch 115/150 | Train Loss: 1.0370 | Train Acc: 75.91% | Val Loss: 0.6503 | Val Acc: 94.37%
Best model saved! Accuracy: 94.37%


                                                                                                              

Epoch 120/150 | Train Loss: 1.0298 | Train Acc: 74.78% | Val Loss: 0.6437 | Val Acc: 94.83%
Best model saved! Accuracy: 94.83%


                                                                                                              

Epoch 125/150 | Train Loss: 1.0106 | Train Acc: 75.51% | Val Loss: 0.6463 | Val Acc: 94.50%


                                                                                                              

Epoch 130/150 | Train Loss: 0.9818 | Train Acc: 77.56% | Val Loss: 0.6439 | Val Acc: 94.51%


                                                                                                              

Epoch 135/150 | Train Loss: 0.9855 | Train Acc: 76.86% | Val Loss: 0.6440 | Val Acc: 94.55%


                                                                                                              

Epoch 140/150 | Train Loss: 1.0057 | Train Acc: 75.62% | Val Loss: 0.6432 | Val Acc: 94.65%


                                                                                                              

Epoch 145/150 | Train Loss: 0.9973 | Train Acc: 77.33% | Val Loss: 0.6431 | Val Acc: 94.66%


                                                                                                              

Epoch 150/150 | Train Loss: 0.9805 | Train Acc: 77.74% | Val Loss: 0.6429 | Val Acc: 94.71%

Training completed! Best accuracy: 94.83%
Final Test Accuracy: 94.71%
