In [None]:
# Cell 1: Setup and Installations
!pip install -q timm pandas

print("Setup complete. Libraries installed.")

In [None]:
# Cell 2: Configuration for Fine-Tuning
from dataclasses import dataclass
import torch

@dataclass
class CFG:
    """Hyper-parameters for fine-tuning a pre-trained CIFAR-10 ViT at a higher resolution."""

    # --- Data ---
    dataset: str = "CIFAR10"
    # Kaggle copy of the python batches ("/kaggle/input/cifar-10/cifar-10-batches-py").
    # For Colab, you might need to download it first.
    # NOTE(review): get_datasets() downloads CIFAR-10 into /kaggle/working/ and does
    # not read this path.
    data_root: str = "/kaggle/input/cifar-10/cifar-10-batches-py"
    num_classes: int = 10

    # --- Fine-Tuning Resolution ---
    # The checkpoint was trained on 48x48; fine-tuning at 64x64 gives the model a finer
    # patch grid (16x16 instead of 12x12 with patch_size=4), forcing it to use finer detail.
    img_size: int = 64

    # --- Pre-trained Model Path ---
    # Path to your best model file.
    pretrained_path: str = "/kaggle/input/cifar_vit_v0/pytorch/default/1/best_vit_cifar10.pth"
    # Resolution the checkpoint was trained at (for reference). The position-embedding
    # interpolation infers the old patch grid from the checkpoint's pos_embed shape and
    # does not read this value; keep it in sync with the checkpoint for documentation.
    original_img_size: int = 48

    # --- Model Architecture (must match the pre-trained model) ---
    patch_size: int = 4
    embed_dim: int = 512
    depth: int = 8
    num_heads: int = 8
    mlp_ratio: float = 4.0
    qkv_bias: bool = True
    drop_rate: float = 0.1
    attn_drop_rate: float = 0.0
    drop_path_rate: float = 0.2

    # --- Fine-Tuning Schedule ---
    # Fine-tuning requires fewer epochs and a smaller learning rate.
    epochs: int = 50
    batch_size: int = 128
    # "adamw" or "sgd". The original ViT paper fine-tunes with SGD + momentum; AdamW is used here.
    optimizer: str = "adamw"
    base_lr: float = 5e-4  # Smaller LR for fine-tuning
    min_lr: float = 1e-6
    weight_decay: float = 0.05
    warmup_epochs: int = 5
    mixup_freeze_epochs: int = 10  # disable MixUp/CutMix/LS for first N epochs

    # --- Augmentations & Tricks ---
    # We can slightly reduce strong regularization during fine-tuning.
    label_smoothing: float = 0.1
    mixup_alpha: float = 0.2
    cutmix_alpha: float = 1.0
    random_erasing_p: float = 0.1

    # --- Training Niceties ---
    amp: bool = True
    compile: bool = True
    # EMA decay: each update keeps `decay` of the running average, so 0.9999 averages
    # over roughly 1 / (1 - decay) = 10k optimizer steps. CIFAR-10 (50k images) at
    # batch 128 is ~390 steps/epoch, i.e. the horizon is ~25 epochs.
    ema_decay: float = 0.9999
    # Gradient accumulation factor (1 = optimizer step every batch).
    # NOTE(review): only effective if the training loop honours it.
    grad_accum_steps: int = 1
    out_dir: str = "./outputs"
    seed: int = 42

    # --- Runtime ---
    # Declared as real fields so they appear in repr()/asdict(); the device default is
    # evaluated once, when the class is defined.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    num_workers: int = 2  # Or 4 if your machine supports it

cfg = CFG()

# Runtime settings: accelerator and DataLoader worker count.
cfg.device = ("cuda" if torch.cuda.is_available() else "cpu")
cfg.num_workers = 2  # Raise to 4 if the machine has spare CPU cores

In [None]:
# Cell 3: Imports
import os
import time
import math
import random
import numpy as np
import pandas as pd
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.data.mixup import Mixup
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.scheduler.cosine_lr import CosineLRScheduler

# For EMA
from copy import deepcopy

# For plots
import matplotlib.pyplot as plt

print("Imports complete.")

In [None]:
# Cell 4: Utilities

def seed_everything(seed: int):
    """Seed every RNG this notebook touches so runs are repeatable.

    Seeds Python's `random`, NumPy's global RNG and PyTorch (CPU plus the
    current CUDA device), and forces cuDNN into deterministic,
    non-benchmarking mode.

    Note: exporting PYTHONHASHSEED here only takes effect for Python
    interpreters launched afterwards, not for string hashing in the
    already-running kernel.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

class ModelEmaV2(nn.Module):
    """Exponential moving average of a model's weights (ModelEmaV2, from timm).

    Holds a deep copy of `model` in eval mode. Every `update` blends each
    state_dict entry in place as `decay * ema + (1 - decay) * current`;
    `set` overwrites the copy with the current weights.

    Args:
        model: network whose weights are tracked; deep-copied at construction.
        decay: fraction of the running average kept on each update.
        device: optional device for the copy; when given, the incoming
            weights are moved there before blending.
    """

    def __init__(self, model, decay=0.9999, device=None):
        super().__init__()
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        self.device = device
        if device is not None:
            self.module.to(device=device)

    def _update(self, model, update_fn):
        # state_dict() tensors share storage with the parameters, so copy_ writes
        # straight into the EMA module's weights.
        with torch.no_grad():
            pairs = zip(self.module.state_dict().values(), model.state_dict().values())
            for ema_t, live_t in pairs:
                if self.device is not None:
                    live_t = live_t.to(device=self.device)
                ema_t.copy_(update_fn(ema_t, live_t))

    def update(self, model):
        def blend(ema_t, live_t):
            return self.decay * ema_t + (1. - self.decay) * live_t
        self._update(model, update_fn=blend)

    def set(self, model):
        self._update(model, update_fn=lambda ema_t, live_t: live_t)

seed_everything(seed=cfg.seed)  # seed before any model/data randomness is drawn
print("Utilities defined and seed set.")

In [None]:
# Cell 5: Data Loading and Augmentations

def get_datasets(img_size, root="/kaggle/working/", random_erasing_p=None):
    """Build CIFAR-10 train/val datasets resized to `img_size` x `img_size`.

    Training images get bicubic resize, padded random crop, horizontal flip,
    TrivialAugmentWide, normalisation and RandomErasing; validation images are
    only resized and normalised. torchvision downloads the data into `root`
    if it is not already there.

    Args:
        img_size: side length (pixels) images are resized to.
        root: writable directory torchvision reads/downloads CIFAR-10 from.
            Defaults to the Kaggle working dir; note `cfg.data_root` is not used.
        random_erasing_p: RandomErasing probability; defaults to
            `cfg.random_erasing_p` when None.

    Returns:
        (train_dataset, val_dataset)
    """
    if random_erasing_p is None:
        random_erasing_p = cfg.random_erasing_p

    # Mean and std for CIFAR-10
    mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)

    # Strong augmentations for training
    train_transform = transforms.Compose([
        transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.RandomCrop(img_size, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.TrivialAugmentWide(),  # A good auto-augmentation policy
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
        # value=0 after Normalize erases to the dataset mean colour.
        transforms.RandomErasing(p=random_erasing_p, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False),
    ])

    # Simple transforms for validation
    val_transform = transforms.Compose([
        transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_dataset = datasets.CIFAR10(root=root, train=True, download=True, transform=train_transform)
    val_dataset = datasets.CIFAR10(root=root, train=False, download=True, transform=val_transform)

    return train_dataset, val_dataset

train_dataset, val_dataset = get_datasets(cfg.img_size)

# Pinned host memory only speeds up host->GPU copies; on CPU-only runs it just warns.
pin = str(cfg.device).startswith("cuda")
# Keep worker processes alive across epochs instead of re-forking them every epoch.
persist = cfg.num_workers > 0

# Create DataLoaders.
# drop_last=True guarantees every training batch has exactly cfg.batch_size samples
# (timm's Mixup requires an even batch length) and avoids a small, noisy final step.
train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.num_workers,
                          pin_memory=pin, persistent_workers=persist, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers,
                        pin_memory=pin, persistent_workers=persist)

print(f"Data loaded. Training on {len(train_dataset)} images, validating on {len(val_dataset)} images.")
print(f"Image size for fine-tuning: {cfg.img_size}x{cfg.img_size}")

In [None]:
# Cell 6: Vision Transformer (ViT) Model Definition

class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and linearly embed each one.

    A Conv2d whose kernel and stride both equal `patch_size` maps
    (B, C, H, W) to (B, num_patches, embed_dim). H and W must equal the
    configured `img_size` (checked with an assert).
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        rows = self.img_size[0] // self.patch_size[0]
        cols = self.img_size[1] // self.patch_size[1]
        self.num_patches = cols * rows
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)

    def forward(self, x):
        _, _, height, width = x.shape
        want_h, want_w = self.img_size
        assert height == want_h and width == want_w, \
            f"Input image size ({height}*{width}) doesn't match model ({want_h}*{want_w})."
        patches = self.proj(x)  # (B, embed_dim, rows, cols)
        return patches.flatten(2).transpose(1, 2)

class Attention(nn.Module):
    """Multi-head self-attention.

    Input and output are (B, N, C) with C == dim, split evenly across
    `num_heads`. Computes softmax(q @ k^T * head_dim**-0.5) @ v per head, with
    dropout on the attention weights while training, then an output projection.

    When available (PyTorch >= 2.0) the fused `F.scaled_dot_product_attention`
    kernel is used, which can avoid materialising the N x N attention matrix;
    otherwise the explicit matmul/softmax path runs. Parameters and
    state_dict keys are identical in both cases.
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.fused_attn = hasattr(F, "scaled_dot_product_attention")

    def forward(self, x):
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        if self.fused_attn:
            # SDPA's default scale is q.size(-1) ** -0.5, which equals self.scale.
            x = F.scaled_dot_product_attention(
                q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0
            )
        else:
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class Mlp(nn.Module):
    """Two-layer feed-forward network used inside each transformer block.

    Linear(in -> hidden) -> activation -> dropout -> Linear(hidden -> out) -> dropout.
    `hidden_features` and `out_features` fall back to `in_features` when falsy;
    the same Dropout module is applied after both layers.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        width = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, width)
        self.act = act_layer()
        self.fc2 = nn.Linear(width, out_features or in_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

class Block(nn.Module):
    """Pre-norm transformer encoder block.

    x = x + DropPath(Attn(LN1(x))); x = x + DropPath(MLP(LN2(x))).
    Both residual branches share one DropPath module, which is an identity
    when `drop_path` is not positive. MLP width is int(dim * mlp_ratio).
    """
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        # Submodules are registered in the same order as before so state_dict
        # key order (and parameter-init RNG consumption) is unchanged.
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

    def forward(self, x):
        attn_out = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_out)
        mlp_out = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_out)

class VisionTransformer(nn.Module):
    """Minimal ViT classifier.

    Patch embedding + learned [CLS] token + learned positional embedding,
    `depth` pre-norm transformer blocks, a final norm and a linear head applied
    to the [CLS] token. Stochastic-depth rates rise linearly from 0 to
    `drop_path_rate` across the blocks.

    The positional embedding is sized for `img_size`, and PatchEmbed asserts
    the input matches it, so changing resolution requires resizing `pos_embed`
    (see `interpolate_pos_encoding`).
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim
        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # +1 slot for the [CLS] token, which is prepended in forward().
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                  drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Linear: truncated-normal(std=.02) weight, zero bias. LayerNorm: weight 1, bias 0."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # elementwise_affine=False LayerNorms have no weight/bias to initialise.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)                      # (B, num_patches, embed_dim)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)        # (B, 1 + num_patches, embed_dim)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return self.head(x[:, 0])                    # classify from the [CLS] token

def interpolate_pos_encoding(model, checkpoint_model):
    """Resize a checkpoint's ViT positional embedding to the model's patch grid.

    The first token is treated as the [CLS] embedding and copied unchanged; the
    remaining N_old tokens are viewed as a square grid, bicubically resized to
    the model's square grid of N_new tokens, and flattened back.

    Args:
        model: object exposing `pos_embed` of shape (1, 1 + N_new, dim).
        checkpoint_model: mapping with a key containing "pos_embed" whose value
            has shape (1, 1 + N_old, dim). The first such key is used.

    Returns:
        The checkpoint tensor itself when N_old == N_new, otherwise a new tensor of
        shape (1, 1 + N_new, dim) with the checkpoint tensor's dtype.

    Raises:
        KeyError: if no "pos_embed" key is present.
        ValueError: if N_old or N_new is not a perfect square.
    """
    pos_keys = [k for k in checkpoint_model.keys() if "pos_embed" in k]
    if not pos_keys:
        raise KeyError("No 'pos_embed' key found in checkpoint!")
    previous_pos_embed = checkpoint_model[pos_keys[0]]

    num_patches_new = model.pos_embed.shape[1] - 1
    num_patches_old = previous_pos_embed.shape[1] - 1

    if num_patches_new == num_patches_old:
        print(" Positional embeddings match. No interpolation needed.")
        return previous_pos_embed

    old_size = math.isqrt(num_patches_old)
    new_size = math.isqrt(num_patches_new)
    if old_size * old_size != num_patches_old or new_size * new_size != num_patches_new:
        raise ValueError(
            f"Cannot interpolate non-square patch grids: checkpoint has {num_patches_old} patches, "
            f"model has {num_patches_new}."
        )

    print(f"Interpolating position embeddings from {num_patches_old} → {num_patches_new} patches")

    cls_pos_embed = previous_pos_embed[:, :1]         # (1, 1, dim)
    patch_pos_embed = previous_pos_embed[:, 1:]       # (1, N_old, dim)
    dim = patch_pos_embed.shape[-1]
    orig_dtype = patch_pos_embed.dtype

    # Interpolate in fp32: bicubic upsampling of fp16 tensors is not supported on every backend.
    grid = patch_pos_embed.reshape(1, old_size, old_size, dim).permute(0, 3, 1, 2).float()
    grid = F.interpolate(grid, size=(new_size, new_size), mode='bicubic', align_corners=False)
    patch_pos_embed = grid.permute(0, 2, 3, 1).reshape(1, -1, dim).to(orig_dtype)

    return torch.cat((cls_pos_embed, patch_pos_embed), dim=1)


print("ViT model definition complete.")

In [None]:
# Cell 7: Model Initialization and Weight Loading

# 1. Instantiate model
model = VisionTransformer(
    img_size=cfg.img_size,
    patch_size=cfg.patch_size,
    num_classes=cfg.num_classes,
    embed_dim=cfg.embed_dim,
    depth=cfg.depth,
    num_heads=cfg.num_heads,
    mlp_ratio=cfg.mlp_ratio,
    qkv_bias=cfg.qkv_bias,
    drop_rate=cfg.drop_rate,
    attn_drop_rate=cfg.attn_drop_rate,
    drop_path_rate=cfg.drop_path_rate
).to(cfg.device)

# Prefixes added by DataParallel / EMA wrappers ("module.") and torch.compile ("_orig_mod.").
_WRAPPER_PREFIXES = ("module.", "_orig_mod.")


def _strip_wrapper_prefixes(key):
    """Remove every leading wrapper prefix, in any order/combination (e.g. 'module._orig_mod.')."""
    changed = True
    while changed:
        changed = False
        for prefix in _WRAPPER_PREFIXES:
            if key.startswith(prefix):
                key = key[len(prefix):]
                changed = True
    return key


def _load_checkpoint(path, device):
    """Load a checkpoint, tolerating PyTorch >= 2.6's `weights_only=True` default.

    A tensors-only load is tried first. If the file also pickles Python objects
    (e.g. a config dataclass), it is retried with full unpickling. Only do this for
    checkpoints you produced yourself: full unpickling can execute arbitrary code.
    """
    import pickle
    try:
        return torch.load(path, map_location=device, weights_only=True)
    except pickle.UnpicklingError:
        print("Tensors-only load failed (checkpoint contains Python objects); retrying with weights_only=False.")
        return torch.load(path, map_location=device, weights_only=False)


# 2. Load pretrained weights
if os.path.exists(cfg.pretrained_path):
    print(f"Loading pre-trained weights from {cfg.pretrained_path}")
    checkpoint = _load_checkpoint(cfg.pretrained_path, cfg.device)

    # unwrap model or model_ema (or accept a bare state_dict)
    checkpoint_model = checkpoint.get("model") or checkpoint.get("model_ema") or checkpoint

    state_dict = OrderedDict((_strip_wrapper_prefixes(k), v) for k, v in checkpoint_model.items())

    # --- a. interpolate positional embeddings to the new patch grid ---
    pos_keys = [k for k in state_dict.keys() if "pos_embed" in k]
    if not pos_keys:
        raise KeyError("No 'pos_embed' key found in checkpoint!")
    pos_key = pos_keys[0]
    state_dict[pos_key] = interpolate_pos_encoding(model, {"pos_embed": state_dict[pos_key]})

    # --- b. drop only tensors whose shape no longer fits the model ---
    # load_state_dict(strict=False) still raises on size mismatches, so e.g. the
    # classification head is removed when num_classes changed. When it matches
    # (CIFAR-10 -> CIFAR-10) the already-trained head is kept.
    model_state = model.state_dict()
    for key in list(state_dict):
        ckpt_shape = getattr(state_dict[key], "shape", None)
        if key in model_state and ckpt_shape is not None and ckpt_shape != model_state[key].shape:
            print(f"Removed {key} from checkpoint "
                  f"(shape {tuple(ckpt_shape)} vs model {tuple(model_state[key].shape)}).")
            del state_dict[key]

    # --- c. load state dict with strict=False ---
    msg = model.load_state_dict(state_dict, strict=False)
    print(f"Weight loading message: {msg}")

    # --- d. re-init the head only if it was not restored from the checkpoint ---
    if any(k.startswith("head.") for k in msg.missing_keys):
        print("Re-initializing classification head.")
        with torch.no_grad():
            # Start from near-zero logits so the new head does not disturb early fine-tuning.
            model.head.weight.mul_(0.001)
            model.head.bias.zero_()
    else:
        print("Keeping pre-trained classification head (shapes match).")
else:
    print(f"No checkpoint found at {cfg.pretrained_path}. Training from scratch.")

# (Optional) Compile
if cfg.compile and hasattr(torch, "compile"):
    print("Compiling model...")
    model = torch.compile(model)

print("Model ready for fine-tuning.")


In [None]:
# Cell 8: Loss, Optimizer, and Scheduler

# Mixup function for data augmentation (also applies label smoothing to the mixed targets)
mixup_fn = Mixup(
    mixup_alpha=cfg.mixup_alpha,
    cutmix_alpha=cfg.cutmix_alpha,
    label_smoothing=cfg.label_smoothing,
    num_classes=cfg.num_classes
)

# Loss function.
# train_one_epoch always feeds dense (N, num_classes) float targets: one-hot during the
# MixUp-freeze epochs and mixed/smoothed targets from mixup_fn afterwards. Index-based
# losses such as LabelSmoothingCrossEntropy cannot consume those, so a soft-target loss
# is required regardless of the mixup/cutmix/label-smoothing settings.
criterion = SoftTargetCrossEntropy()


def _param_groups(module, weight_decay, no_decay_names=("pos_embed", "cls_token")):
    """Split trainable parameters into a decayed and a non-decayed group.

    1-D tensors (biases, LayerNorm weights) and the positional / class tokens are
    exempt from weight decay, the usual ViT recipe. Names are matched on their last
    component so torch.compile's "_orig_mod." prefix does not matter. Empty groups
    are omitted.
    """
    decay, no_decay = [], []
    for name, param in module.named_parameters():
        if not param.requires_grad:
            continue
        if param.ndim <= 1 or name.split(".")[-1] in no_decay_names:
            no_decay.append(param)
        else:
            decay.append(param)
    groups = [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
    return [g for g in groups if g["params"]]


# Optimizer (per-group weight_decay overrides the optimizer default)
param_groups = _param_groups(model, cfg.weight_decay)
if cfg.optimizer.lower() == 'sgd':
    optimizer = torch.optim.SGD(param_groups, lr=cfg.base_lr, momentum=0.9)
else:  # Default to AdamW
    optimizer = torch.optim.AdamW(param_groups, lr=cfg.base_lr)

# Learning Rate Scheduler.
# With warmup_prefix=True the cosine phase starts *after* warmup, so its length must be
# epochs - warmup_epochs for the LR to reach lr_min at the final epoch.
lr_scheduler = CosineLRScheduler(
    optimizer,
    t_initial=max(1, cfg.epochs - cfg.warmup_epochs),
    lr_min=cfg.min_lr,
    warmup_t=cfg.warmup_epochs,
    warmup_lr_init=1e-6,
    warmup_prefix=True
)

print(f"Using criterion: {type(criterion).__name__}")
print(f"Using optimizer: {type(optimizer).__name__}")

In [None]:
# Cell 9: Training and Validation Functions

def train_one_epoch(model, loader, optimizer, criterion, device, epoch, scaler, mixup_fn, model_ema,
                    max_grad_norm=1.0):
    """Train `model` for one epoch.

    MixUp/CutMix is disabled for the first `cfg.mixup_freeze_epochs` epochs; in that
    phase targets are plain one-hot vectors. Either way `criterion` receives dense
    (N, num_classes) float targets. Uses AMP autocast on CUDA when `cfg.amp` is set,
    accumulates gradients over `cfg.grad_accum_steps` batches, clips the global
    gradient norm to `max_grad_norm`, and updates `model_ema` after every optimizer step.

    Args:
        model, loader, optimizer, criterion: the usual training objects.
        device: torch device (or string) batches are moved to.
        epoch: 0-based epoch index (controls the MixUp freeze).
        scaler: GradScaler used for loss scaling / unscaling.
        mixup_fn: timm Mixup callable returning (samples, soft_targets).
        model_ema: ModelEmaV2 to update, or None.
        max_grad_norm: gradient clipping threshold.

    Returns:
        (avg_loss, accuracy): mean per-batch loss, and train accuracy (%) measured
        against the dominant label of each (possibly mixed) target.
    """
    model.train()
    device_type = torch.device(device).type
    # Like torch.cuda.amp.autocast, only autocast on CUDA; CPU runs stay fp32.
    use_amp = cfg.amp and device_type == "cuda"
    accum_steps = max(1, int(cfg.grad_accum_steps))
    num_batches = len(loader)

    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    start_time = time.time()
    optimizer.zero_grad(set_to_none=True)

    for i, (inputs, targets) in enumerate(loader):
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        # MixUp/CutMix warmup: disable for first cfg.mixup_freeze_epochs
        if epoch < cfg.mixup_freeze_epochs:
            samples = inputs
            targets = torch.nn.functional.one_hot(targets, num_classes=cfg.num_classes).float()
        else:
            samples, targets = mixup_fn(inputs, targets)

        with torch.autocast(device_type=device_type, enabled=use_amp):
            outputs = model(samples)
            loss = criterion(outputs, targets)

        # Divide so accumulated gradients average (rather than sum) over micro-batches.
        scaler.scale(loss / accum_steps).backward()

        if (i + 1) % accum_steps == 0 or (i + 1) == num_batches:
            # Grad clipping must see unscaled gradients: unscale, clip, then step.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

            if model_ema is not None:
                model_ema.update(model)

        total_loss += loss.item()

        # Accuracy for logging: compare with the dominant (un-mixed) label.
        predicted = outputs.argmax(dim=1)
        true_labels = targets.argmax(dim=1)
        total_correct += predicted.eq(true_labels).sum().item()
        total_samples += targets.size(0)

        if i % 50 == 0:
            print(f"  Batch {i}/{num_batches} | Loss: {loss.item():.4f}")

    epoch_time = time.time() - start_time
    avg_loss = total_loss / num_batches
    accuracy = 100. * total_correct / total_samples

    print(f"End of Epoch {epoch+1} | Train Time: {epoch_time:.2f}s | Avg Loss: {avg_loss:.4f} | Train Acc: {accuracy:.2f}%")
    return avg_loss, accuracy


@torch.no_grad()
def validate(model, loader, device):
    """Top-1 validation accuracy (%) with horizontal-flip test-time augmentation.

    Logits of each image and its mirror image are averaged (in fp32) before
    the argmax. Uses AMP autocast on CUDA when `cfg.amp` is set.
    """
    model.eval()
    device_type = torch.device(device).type
    use_amp = cfg.amp and device_type == "cuda"
    total_correct = 0
    total_samples = 0

    for inputs, targets in loader:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        with torch.autocast(device_type=device_type, enabled=use_amp):
            logits = model(inputs)
            logits_flip = model(torch.flip(inputs, dims=[-1]))
        outputs = (logits.float() + logits_flip.float()) / 2

        predicted = outputs.argmax(dim=1)
        total_correct += predicted.eq(targets).sum().item()
        total_samples += targets.size(0)

    accuracy = 100. * total_correct / total_samples
    print(f"Validation Accuracy: {accuracy:.2f}%")
    return accuracy

print("Training and validation functions defined.")

In [None]:
# Cell 10: Main Fine-Tuning Loop

print("Starting fine-tuning...")

# Create output directory
os.makedirs(cfg.out_dir, exist_ok=True)

# Loss scaling only matters for fp16 autocast on CUDA; on CPU it would warn and disable itself.
scaler = torch.cuda.amp.GradScaler(enabled=cfg.amp and str(cfg.device).startswith("cuda"))

# Setup EMA (initialised as a copy of the current weights)
model_ema = ModelEmaV2(model, decay=cfg.ema_decay, device=cfg.device)

best_acc = 0.0
best_source = None
history = []
save_path = os.path.join(cfg.out_dir, "best_finetuned_vit.pth")

start_total_time = time.time()

for epoch in range(cfg.epochs):
    print("-" * 50)
    current_lr = optimizer.param_groups[0]["lr"]
    print(f"Epoch {epoch+1}/{cfg.epochs} | LR: {current_lr:.6f}")

    # Train
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, cfg.device, epoch, scaler, mixup_fn, model_ema)

    # Update LR scheduler
    lr_scheduler.step(epoch + 1)

    # Evaluate both weight sets. With a decay close to 1 the EMA keeps a sizeable share of
    # its starting weights for thousands of steps, so over a short fine-tune the live
    # weights can be the better model; the best of the two is tracked.
    print("EMA weights:")
    val_acc = validate(model_ema.module, val_loader, cfg.device)
    print("Live weights:")
    val_acc_raw = validate(model, val_loader, cfg.device)

    history.append({'epoch': epoch, 'train_loss': train_loss, 'train_acc': train_acc,
                    'val_acc': val_acc, 'val_acc_raw': val_acc_raw, 'lr': current_lr})

    # Prefer the EMA weights on ties.
    if val_acc >= val_acc_raw:
        epoch_acc, epoch_source = val_acc, "ema"
    else:
        epoch_acc, epoch_source = val_acc_raw, "raw"

    if epoch_acc > best_acc:
        best_acc = epoch_acc
        best_source = epoch_source
        print(f"New best accuracy: {best_acc:.2f}% ({epoch_source} weights)! Saving model...")
        torch.save({
            'model_ema': model_ema.module.state_dict(),
            'model': model.state_dict(),
            'best_source': epoch_source,
            'val_acc': best_acc,
            # Plain dict (not the CFG object) so the file loads with
            # torch.load(weights_only=True), the default since PyTorch 2.6.
            'config': dict(vars(cfg)),
            'epoch': epoch,
        }, save_path)

total_training_time = time.time() - start_total_time
print(f"\nFine-tuning finished in {total_training_time/60:.2f} minutes.")
print(f"Best validation accuracy: {best_acc:.2f}% ({best_source} weights)")

# Create a DataFrame for easy viewing
history_df = pd.DataFrame(history)

In [None]:
# Cell 11: Results Visualization
# Left panel: train vs. validation accuracy per epoch; right panel: mean train loss.

fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 5))

ax_acc.plot(history_df['epoch'], history_df['train_acc'], label='Train Accuracy')
ax_acc.plot(history_df['epoch'], history_df['val_acc'], label='Validation Accuracy')
ax_acc.set_title('Model Accuracy')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy (%)')
ax_acc.legend()
ax_acc.grid(True)

ax_loss.plot(history_df['epoch'], history_df['train_loss'], label='Train Loss')
ax_loss.set_title('Model Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.legend()
ax_loss.grid(True)

fig.tight_layout()
plt.show()
