# Training


Importing libraries and configuring the environment

In [None]:
# Importing and setups
!pip install ptflops # need to install everytime either cpu or gpu
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
import pandas as pd
import seaborn as sns
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, classification_report, roc_curve, auc
import csv
from datetime import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from ptflops import get_model_complexity_info

# Seed every random number generator used below so runs are repeatable
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)  # the CUDA generator is seeded explicitly as well
np.random.seed(seed)
torch.backends.cudnn.deterministic = True  # ask cuDNN for deterministic algorithms

# Pick the compute device: GPU when one is available, CPU otherwise
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
if device.type == 'cpu':
    print("WARNING: Training will be very slow without GPU!")

Collecting ptflops
  Downloading ptflops-0.7.4-py3-none-any.whl.metadata (9.4 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0->ptflops)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.0->ptflops)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.0->ptflops)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.0->ptflops)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=2.0->ptflops)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch>=2.0->ptflops)
  Downloading nvidia_

Preparing dataset

In [None]:
# Data Preparation with Augmentation
class CIFAR10DataModule:
    """Bundle the CIFAR-10 datasets, their transforms and their data loaders.

    Call ``setup()`` once to download/load the datasets, then use
    ``train_dataloader()`` and ``val_dataloader()`` to obtain the loaders.

    Args:
        batch_size: Number of samples per batch for both loaders.
        num_workers: Number of worker processes used by both loaders.
    """

    def __init__(self, batch_size=128, num_workers=4):
        self.batch_size = batch_size
        self.num_workers = num_workers

        # CIFAR10 normalization values - DON'T CHANGE
        self.mean = (0.4914, 0.4822, 0.4465)
        self.std = (0.2470, 0.2435, 0.2616)

        # Training pipeline: random crop/flip + RandAugment, then normalize
        self.train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),  # standard augmentation
            transforms.RandAugment(num_ops=2, magnitude=9),  # tried 3 ops but too aggressive
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std),
        ])

        # Evaluation pipeline: no augmentation, only tensor conversion + normalize
        self.test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std),
        ])

    def setup(self):
        """Download (if needed) and load the train and validation datasets.

        The validation set is the CIFAR-10 ``train=False`` split.
        """
        print("Setting up datasets...")
        self.train_dataset = datasets.CIFAR10(
            root='./data',
            train=True,
            download=True,
            transform=self.train_transform,
        )

        self.val_dataset = datasets.CIFAR10(
            root='./data',
            train=False,
            download=True,
            transform=self.test_transform,
        )
        print(f"Loaded {len(self.train_dataset)} training and {len(self.val_dataset)} validation samples")

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader with the settings shared by both splits."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            # Pinned host memory only matters for host-to-GPU copies, so it is
            # requested only when CUDA is actually available.
            pin_memory=torch.cuda.is_available(),
        )

    def train_dataloader(self):
        """Return a shuffled loader over the training dataset."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return an in-order (unshuffled) loader over the validation dataset."""
        return self._make_loader(self.val_dataset, shuffle=False)

patch embedding

In [None]:
# Patch Embedding Layer
class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    A convolution whose kernel size and stride both equal ``patch_size``
    projects every patch to an ``embed_dim``-dimensional vector.

    Args:
        img_size: Expected height and width of the (square) input image.
        patch_size: Side length of each square patch.
        in_channels: Number of input image channels.
        embed_dim: Size of the embedding produced for each patch.
    """

    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=192):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side ** 2

        # kernel == stride, so each output location sees exactly one patch
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Map images ``[B, C, H, W]`` to patch tokens ``[B, n_patches, embed_dim]``."""
        _, _, H, W = x.shape
        size_matches = (H, W) == (self.img_size, self.img_size)
        assert size_matches, \
            f"Input image size ({H}*{W}) doesn't match expected size ({self.img_size}*{self.img_size})"

        feature_map = self.proj(x)            # [B, E, H/P, W/P]
        tokens = feature_map.flatten(2)       # [B, E, (H/P)*(W/P)]
        return tokens.transpose(1, 2)         # [B, (H/P)*(W/P), E]

Multi Head Attention (MHA)

In [None]:
# Multi-Head Self-Attention
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    Args:
        embed_dim: Token embedding size; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights and to
            the output projection.
    """

    def __init__(self, embed_dim=192, num_heads=8, dropout=0.1):  # 192/8 = 24 per head
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # The heads must tile the embedding exactly
        heads_cover_embedding = self.head_dim * num_heads == embed_dim
        assert heads_cover_embedding, \
            f"embed_dim {embed_dim} must be divisible by num_heads {num_heads}"

        # One linear layer produces Q, K and V together
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.attn_dropout = nn.Dropout(dropout)
        self.proj_dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over ``x`` of shape ``[B, N, E]`` and return the same shape."""
        batch, tokens, width = x.shape
        heads, per_head = self.num_heads, self.head_dim

        # [B, N, 3E] -> [B, N, 3, H, D] -> [3, B, H, N, D], then split along dim 0
        packed = self.qkv(x).reshape(batch, tokens, 3, heads, per_head)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)  # each [B, H, N, D]

        # Scale the logits by 1/sqrt(head_dim) before the softmax
        scale = 1.0 / np.sqrt(per_head)
        scores = (q @ k.transpose(-2, -1)) * scale  # [B, H, N, N]
        weights = self.attn_dropout(F.softmax(scores, dim=-1))

        # Weighted sum of values, then merge the heads back into the embedding
        context = (weights @ v).transpose(1, 2).reshape(batch, tokens, width)  # [B, N, E]
        return self.proj_dropout(self.proj(context))

MLP

In [None]:
# MLP Block
class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        in_features: Size of the input features.
        hidden_features: Size of the hidden layer.
        out_features: Size of the output features.
        dropout: Dropout probability; the same dropout module is applied
            after the activation and after the second linear layer.
    """

    def __init__(self, in_features, hidden_features, out_features, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the feed-forward network to the last dimension of ``x``."""
        hidden = self.dropout(self.act(self.fc1(x)))
        return self.dropout(self.fc2(hidden))

Transformer Encoder Block

In [None]:
# Transformer Encoder Block
class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block: attention and MLP, each with a residual.

    Each sub-layer is computed as ``x + sublayer(norm(x))``.

    Args:
        embed_dim: Token embedding size.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden size of the MLP as a multiple of ``embed_dim``.
        dropout: Dropout probability passed to the attention and MLP sub-layers.
    """

    def __init__(self, embed_dim=192, num_heads=8, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(
            in_features=embed_dim,
            hidden_features=hidden_dim,
            out_features=embed_dim,
            dropout=dropout,
        )

    def forward(self, x):
        """Run the attention and MLP sub-layers on ``x``, each added back residually."""
        attended = self.attn(self.norm1(x))
        x = x + attended
        transformed = self.mlp(self.norm2(x))
        return x + transformed

Complete Vision Transformer Model

In [None]:
# Complete Vision Transformer Model
class VisionTransformer(nn.Module):
    """Vision Transformer classifier built from patch embeddings and encoder blocks.

    The image is split into patches, a learnable class token is prepended,
    learned position embeddings are added, the sequence passes through
    ``depth`` transformer blocks, and the normalized class token feeds a
    linear classification head.

    Args:
        img_size: Height/width of the square input image.
        patch_size: Side of each square patch (32x32 with 4x4 patches -> 64 tokens).
        in_channels: Number of image channels.
        num_classes: Number of output logits.
        embed_dim: Token embedding size.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block; must divide ``embed_dim``.
        mlp_ratio: MLP hidden size as a multiple of ``embed_dim``.
        dropout: Dropout probability used inside every transformer block.
        embed_dropout: Dropout probability applied to the embedded sequence.
    """

    def __init__(
        self,
        img_size=32,
        patch_size=4,
        in_channels=3,
        num_classes=10,
        embed_dim=192,
        depth=9,
        num_heads=8,
        mlp_ratio=4.0,
        dropout=0.1,
        embed_dropout=0.1
    ):
        super().__init__()
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.num_tokens = (img_size // patch_size) ** 2

        # Image -> sequence of patch tokens
        self.patch_embed = PatchEmbedding(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
        )

        # Learnable class token plus one learned position vector per token
        # (num_tokens patches + 1 class token)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens + 1, embed_dim))
        for param in (self.pos_embed, self.cls_token):
            nn.init.trunc_normal_(param, std=0.02)

        self.dropout = nn.Dropout(embed_dropout)

        # Stack of identical encoder blocks
        encoder_layers = [
            TransformerBlock(
                embed_dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
            )
            for _ in range(depth)
        ]
        self.blocks = nn.ModuleList(encoder_layers)

        # Final LayerNorm followed by a single linear classification layer
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # Re-initialize every Linear and LayerNorm submodule
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize Linear and LayerNorm modules; other module types are left as-is."""
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if not isinstance(m, nn.Linear):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Classify images ``[B, C, H, W]`` and return logits ``[B, num_classes]``."""
        batch = x.shape[0]

        patches = self.patch_embed(x)  # [B, N, E]

        # Prepend the class token to every sequence in the batch
        cls = self.cls_token.expand(batch, -1, -1)  # [B, 1, E]
        tokens = torch.cat((cls, patches), dim=1)  # [B, N+1, E]

        # Position embeddings broadcast over the batch dimension
        tokens = self.dropout(tokens + self.pos_embed)

        for block in self.blocks:
            tokens = block(tokens)

        tokens = self.norm(tokens)

        # Only the class token (index 0) is used for classification
        return self.head(tokens[:, 0])

training and eval

In [None]:
# Training and Evaluation Utilities
def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, device):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: Module to train; put into training mode here.
        train_loader: Iterable yielding ``(data, target)`` batches.
        criterion: Loss function called as ``criterion(output, target)``.
        optimizer: Optimizer stepped once per batch.
        scheduler: Optional LR scheduler stepped once per batch (``None`` to skip).
        device: Device the batches are moved to.

    Returns:
        Tuple ``(avg_loss, accuracy_percent, epoch_time_s, avg_batch_time_s)``.
        Loss and accuracy are averaged over the samples actually processed.
        If the loader yields no batches, the loss, accuracy and batch time
        are reported as ``0.0``.
    """
    model.train()  # set model to training mode
    total_loss = 0.0
    correct = 0
    total = 0
    batch_time = 0.0
    num_batches = 0

    # Progress bar
    pbar = tqdm(train_loader, desc="Training")
    start_time = time.time()

    for data, target in pbar:
        batch_start = time.time()
        data, target = data.to(device), target.to(device)

        # Forward pass
        optimizer.zero_grad()  # clear gradients first
        output = model(data)
        loss = criterion(output, target)

        # Backward pass (no gradient clipping is applied)
        loss.backward()
        optimizer.step()

        # Update learning rate - using per-step scheduler
        if scheduler is not None:
            scheduler.step()

        # Track metrics; weight the batch-mean loss by the batch size so the
        # epoch average is per-sample even when the last batch is smaller
        batch_loss = loss.item()
        batch_size = target.size(0)
        total_loss += batch_loss * batch_size
        _, predicted = output.max(1)  # get predicted class
        total += batch_size
        correct += predicted.eq(target).sum().item()

        # Track batch time
        batch_time += time.time() - batch_start
        num_batches += 1

        # Update progress bar - helps to see how training is going
        pbar.set_postfix({
            "loss": f"{batch_loss:.4f}",
            "acc": f"{100. * correct / total:.1f}%",
        })

    epoch_time = time.time() - start_time

    # Nothing was processed: report zeros instead of dividing by zero
    if total == 0:
        return 0.0, 0.0, epoch_time, 0.0

    # Normalize by the samples seen rather than the dataset length so the
    # average stays correct if the loader covers only part of the dataset
    return total_loss / total, 100. * correct / total, epoch_time, batch_time / num_batches


def evaluate(model, val_loader, criterion, device, classes=None, full_metrics=False):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Args:
        model: Module to evaluate; put into eval mode here.
        val_loader: Iterable yielding ``(data, target)`` batches.
        criterion: Loss function called as ``criterion(output, target)``.
        device: Device the batches are moved to.
        classes: Optional sequence of class names; required for detailed metrics.
        full_metrics: When true (and ``classes`` is non-empty), also compute the
            confusion matrix and per-class precision/recall/F1/support.

    Returns:
        Dict with ``'loss'``, ``'accuracy'`` (percent) and ``'inference_time_ms'``
        (mean forward-pass time per batch). With detailed metrics it also holds
        ``'confusion_matrix'``, ``'per_class'``, ``'classes'``, ``'targets'`` and
        ``'predictions'``. If the loader yields no batches, the three aggregate
        values are ``0.0``.
    """
    model.eval()  # set model to evaluation mode
    total_loss = 0.0
    correct = 0
    total = 0
    inference_times = []

    # For confusion matrix and per-class metrics
    all_targets = []
    all_predictions = []

    # CUDA kernels run asynchronously, so the forward pass must be bracketed
    # with synchronization for the wall-clock measurement to be meaningful
    torch_device = torch.device(device)
    on_cuda = torch_device.type == 'cuda'

    with torch.no_grad():  # no need to track gradients during evaluation
        for data, target in tqdm(val_loader, desc="Evaluation"):
            data, target = data.to(device), target.to(device)

            # Measure inference time of the forward pass only
            if on_cuda:
                torch.cuda.synchronize(torch_device)
            start_time = time.perf_counter()
            output = model(data)
            if on_cuda:
                torch.cuda.synchronize(torch_device)
            inference_times.append(time.perf_counter() - start_time)

            loss = criterion(output, target)

            # Track metrics (batch-mean loss weighted by batch size)
            batch_size = target.size(0)
            total_loss += loss.item() * batch_size
            _, predicted = output.max(1)
            total += batch_size
            correct += predicted.eq(target).sum().item()

            # Store targets and predictions for additional metrics
            all_targets.extend(target.cpu().tolist())
            all_predictions.extend(predicted.cpu().tolist())

    # Compute aggregate metrics over the samples actually seen
    avg_loss = total_loss / total if total else 0.0
    accuracy = 100. * correct / total if total else 0.0
    avg_inference_time = (
        sum(inference_times) / len(inference_times) if inference_times else 0.0
    )

    results = {
        'loss': avg_loss,
        'accuracy': accuracy,
        'inference_time_ms': avg_inference_time * 1000  # Convert to ms
    }

    # Add detailed metrics if requested
    if full_metrics and classes:
        label_ids = list(range(len(classes)))

        # Per-class precision, recall, f1-score and support
        precision, recall, f1, support = precision_recall_fscore_support(
            all_targets, all_predictions, labels=label_ids, average=None
        )

        # Confusion matrix over the same label set
        cm = confusion_matrix(all_targets, all_predictions, labels=label_ids)

        results['confusion_matrix'] = cm
        results['per_class'] = {
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'support': support
        }
        results['classes'] = classes
        results['targets'] = all_targets
        results['predictions'] = all_predictions

    return results

Metrics, model-complexity and CSV-export functions

In [None]:
# Metrics and Visualization Functions
def calculate_and_plot_metrics(model, val_loader, criterion, device, classes):
    """Compute detailed validation metrics and save three diagnostic plots.

    Saves ``vit_confusion_matrix.png``, ``vit_per_class_metrics.png`` and
    ``vit_roc_curves.png`` in the working directory. The figures are saved but
    not closed by this function.

    Args:
        model: Trained model to evaluate.
        val_loader: Loader yielding ``(data, target)`` batches.
        criterion: Loss function passed through to ``evaluate``.
        device: Device used for inference.
        classes: Sequence of class names, indexed by class id.

    Returns:
        Dict with ``'accuracy'``, ``'loss'``, ``'inference_time_ms'``,
        ``'per_class_precision'``, ``'per_class_recall'``, ``'per_class_f1'``
        and ``'mean_auc'`` (unweighted mean of the one-vs-rest AUCs).
    """
    print("Calculating detailed metrics...")
    results = evaluate(model, val_loader, criterion, device, classes, full_metrics=True)

    cm = results['confusion_matrix']
    per_class = results['per_class']

    # 1. Plot confusion matrix
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.tight_layout()
    plt.savefig('vit_confusion_matrix.png', dpi=200)

    # 2. Plot per-class metrics as grouped bars
    plt.figure(figsize=(12, 6))
    x = np.arange(len(classes))
    width = 0.2  # width of bars

    plt.bar(x - width, per_class['precision'], width, label='Precision')
    plt.bar(x, per_class['recall'], width, label='Recall')
    plt.bar(x + width, per_class['f1'], width, label='F1-Score')

    plt.xlabel('Classes')
    plt.ylabel('Score')
    plt.title('Per-Class Performance Metrics')
    plt.xticks(x, classes, rotation=45)
    plt.legend()
    plt.tight_layout()
    plt.savefig('vit_per_class_metrics.png')

    # 3. Compute and plot ROC curves (one-vs-rest)
    # The model is run again to obtain class probabilities. Targets are
    # collected in the SAME pass so each probability row is paired with its
    # own label regardless of the order in which the loader yields samples.
    prob_chunks = []
    target_chunks = []
    model.eval()
    with torch.no_grad():
        for data, target in val_loader:
            outputs = model(data.to(device))
            prob_chunks.append(F.softmax(outputs, dim=1).cpu().numpy())
            target_chunks.append(target.cpu().numpy())

    all_probs = np.vstack(prob_chunks)
    roc_targets = np.concatenate(target_chunks)

    plt.figure(figsize=(12, 10))

    # One curve per class: that class is the positive label, all others negative
    mean_auc = 0
    for i, cls in enumerate(classes):
        is_class = (roc_targets == i).astype(int)
        fpr, tpr, _ = roc_curve(is_class, all_probs[:, i])
        roc_auc = auc(fpr, tpr)
        mean_auc += roc_auc
        plt.plot(fpr, tpr, lw=2, label=f'{cls} (AUC = {roc_auc:.2f})')

    mean_auc /= len(classes)

    # Add diagonal line (random classifier)
    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'ROC Curves (Mean AUC = {mean_auc:.2f})')
    plt.legend(loc="lower right")
    plt.tight_layout()
    plt.savefig('vit_roc_curves.png')

    # Return metrics for CSV export
    return {
        'accuracy': results['accuracy'],
        'loss': results['loss'],
        'inference_time_ms': results['inference_time_ms'],
        'per_class_precision': per_class['precision'],
        'per_class_recall': per_class['recall'],
        'per_class_f1': per_class['f1'],
        'mean_auc': mean_auc
    }


# Calculate model complexity
def calculate_model_complexity(model, input_size=(3, 32, 32)):
    """Report the parameter count and estimated FLOPs of ``model``.

    Args:
        model: Model handed to ``get_model_complexity_info``.
        input_size: Input shape handed to ``get_model_complexity_info``.

    Returns:
        Dict with raw ``'params'`` and ``'flops'`` plus the scaled
        ``'params_millions'`` and ``'flops_billions'``.
    """
    print("Calculating model complexity...")
    mac_count, param_count = get_model_complexity_info(
        model,
        input_size,
        as_strings=False,
        print_per_layer_stat=False,
    )

    # ptflops returns MACs, but papers usually report FLOPs (FLOPs ≈ 2 * MACs)
    flop_count = mac_count * 2

    return {
        'params': param_count,
        'flops': flop_count,
        'params_millions': param_count / 1e6,
        'flops_billions': flop_count / 1e9,
    }


# Export metrics to CSV
def export_metrics_to_csv(metrics, model_name='ViT', filename='model_metrics.csv'):
    """Write ``metrics`` to a timestamped CSV and append a row to the comparison CSV.

    Two files are written under ``metrics/``:

    * ``metrics/{model_name}_{timestamp}.csv`` - one ``Metric,Value`` row per
      flattened metric. Nested dicts and numpy arrays are flattened
      recursively into ``parent_child`` / ``name_index`` keys.
    * ``metrics/model_comparison.csv`` - one summary row per call; the header
      is written when the file does not exist yet.

    Args:
        metrics: Dict of metrics. The summary row requires the keys
            ``accuracy``, ``loss``, ``inference_time_ms``, ``mean_auc``,
            ``training_time`` and ``complexity`` (with ``params_millions``
            and ``flops_billions``).
        model_name: Used in the per-run file name and the summary row.
        filename: Currently unused; kept so existing callers keep working.

    Raises:
        KeyError: If a key needed for the summary row is missing. The per-run
            file has already been written at that point, but the comparison
            file is left untouched.
    """
    # Create directory if it doesn't exist
    os.makedirs('metrics', exist_ok=True)

    # Prepare CSV file path with timestamp
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    filepath = f'metrics/{model_name}_{timestamp}.csv'

    flat_metrics = _flatten_metrics(metrics)

    with open(filepath, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['Metric', 'Value'])
        writer.writerows(flat_metrics.items())

    print(f"Metrics exported to {filepath}")

    # Also append a one-line summary for comparing models across runs
    summary_path = 'metrics/model_comparison.csv'

    # Build the row before opening the file so a missing key cannot leave a
    # header-only comparison file behind
    summary_row = [
        model_name,
        metrics['accuracy'],
        metrics['loss'],
        metrics['complexity']['params_millions'],
        metrics['complexity']['flops_billions'],
        metrics['inference_time_ms'],
        metrics['mean_auc'],
        metrics['training_time'],
    ]

    # Write the header only when the summary file is being created
    file_exists = os.path.isfile(summary_path)
    with open(summary_path, 'a', newline='') as csvfile:
        writer = csv.writer(csvfile)

        if not file_exists:
            writer.writerow([
                'Model', 'Accuracy', 'Loss', 'Params (M)', 'FLOPs (G)',
                'Inference Time (ms)', 'Mean AUC', 'Training Time (s)'
            ])

        writer.writerow(summary_row)

    print(f"Summary metrics added to {summary_path}")


def _flatten_metrics(metrics, prefix=''):
    """Flatten nested dicts and numpy arrays into a single ``{name: scalar}`` dict."""
    flat = {}
    for key, value in metrics.items():
        name = f'{prefix}{key}'
        if isinstance(value, dict):
            flat.update(_flatten_metrics(value, prefix=f'{name}_'))
        elif isinstance(value, np.ndarray):
            # One entry per element: name_i for 1-D, name_i_j for 2-D, and the
            # bare name for a 0-d array
            for index, element in np.ndenumerate(value):
                flat['_'.join([name, *map(str, index)])] = element
        else:
            flat[name] = value
    return flat

Main training function

In [None]:
# Main Training Function
def train_vit_cifar10(epochs=100, batch_size=128, lr=1e-3, warmup_epochs=5, model_name='ViT'):
    """Train a Vision Transformer on CIFAR-10, then evaluate and export metrics.

    The model is trained with AdamW and a per-step linear-warmup + cosine-decay
    learning-rate schedule. The checkpoint with the best validation accuracy is
    saved to ``vit_cifar10_best.pth``, reloaded after training, evaluated in
    detail, and the resulting metrics are exported to CSV. Training stops early
    once at least 51 epochs have run and the best accuracy reached 90%.

    Uses the module-level ``device``.

    Args:
        epochs: Maximum number of training epochs (must be >= 1).
        batch_size: Batch size for the train and validation loaders.
        lr: Peak learning rate (reached at the end of warmup).
        warmup_epochs: Number of epochs over which the LR ramps up linearly.
        model_name: Label used in printed output and exported metric files.

    Returns:
        Tuple ``(model, best_acc, final_metrics)``: the model with the best
        checkpoint loaded, its validation accuracy in percent, and the dict
        passed to ``export_metrics_to_csv``.

    Raises:
        ValueError: If ``epochs`` is smaller than 1.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be at least 1, got {epochs}")

    checkpoint_path = "vit_cifar10_best.pth"

    # Setup data
    print(f"\n=== Setting up {model_name} training ===")
    print(f"Epochs: {epochs}, Batch size: {batch_size}, LR: {lr}")

    data_module = CIFAR10DataModule(batch_size=batch_size)
    data_module.setup()
    train_loader = data_module.train_dataloader()
    val_loader = data_module.val_dataloader()

    # CIFAR-10 classes
    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')

    # Create model
    model = VisionTransformer(
        img_size=32,
        patch_size=4,  # 4x4 patches, so 8x8=64 patches total
        in_channels=3,
        num_classes=10,
        embed_dim=192,
        depth=9,
        num_heads=8,  # 192 / 8 = 24 dim per head
        mlp_ratio=4.0,
        dropout=0.1,
        embed_dropout=0.1
    ).to(device)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}")

    # Calculate model complexity
    complexity = calculate_model_complexity(model)
    print(f"FLOPs: {complexity['flops_billions']:.2f} G")
    print(f"Parameters: {complexity['params_millions']:.2f} M")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)

    # Per-step schedule: linear warmup followed by cosine decay
    steps_per_epoch = len(train_loader)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        _warmup_cosine_schedule(
            warmup_steps=steps_per_epoch * warmup_epochs,
            total_steps=steps_per_epoch * epochs,
        ),
    )

    # Training loop
    print("\n=== Starting training ===")
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint; the file loaded afterwards is then guaranteed to be from
    # this run rather than missing or left over from an earlier one
    best_acc = float('-inf')
    train_losses, train_accs = [], []
    val_losses, val_accs = [], []
    epoch_times, batch_times = [], []
    total_training_time = 0
    lr_history = []

    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")

        # Log the learning rate in effect at the start of the epoch
        current_lr = optimizer.param_groups[0]['lr']
        lr_history.append(current_lr)

        # Train
        train_loss, train_acc, epoch_time, avg_batch_time = train_one_epoch(
            model, train_loader, criterion, optimizer, scheduler, device
        )
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        epoch_times.append(epoch_time)
        batch_times.append(avg_batch_time)
        total_training_time += epoch_time

        # Evaluate
        val_results = evaluate(model, val_loader, criterion, device)
        val_loss = val_results['loss']
        val_acc = val_results['accuracy']
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        # Print metrics
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
        print(f"Epoch Time: {epoch_time:.2f}s, Avg Batch Time: {avg_batch_time*1000:.2f}ms")
        print(f"Current LR: {current_lr:.6f}")

        # Save best model
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)
            print(f"New best validation accuracy: {best_acc:.2f}%!")

        # Early stopping once the target is met after a reasonable number of epochs
        if epoch >= 50 and best_acc >= 90.0:
            print("Reached target accuracy of 90%. Stopping early!")
            break

    print("\n=== Training complete ===")
    print(f"Total training time: {total_training_time:.2f}s")
    print(f"Best validation accuracy: {best_acc:.2f}%")

    print("\nGenerating final training plots...")
    _plot_training_history(
        train_losses, val_losses, train_accs, val_accs,
        epoch_times, batch_times, lr_history,
    )

    # Load best model for final evaluation; map_location keeps the tensors on
    # the device the model lives on
    print("\nLoading best model for final evaluation...")
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    # Calculate detailed metrics
    detailed_metrics = calculate_and_plot_metrics(model, val_loader, criterion, device, classes)

    # Prepare metrics for export. Accuracy and loss both describe the best
    # checkpoint (the loss comes from its final evaluation, not from the last
    # training epoch).
    final_metrics = {
        'accuracy': best_acc,
        'loss': detailed_metrics['loss'],
        'inference_time_ms': detailed_metrics['inference_time_ms'],
        'training_time': total_training_time,
        'epochs': len(train_losses),
        'avg_epoch_time': sum(epoch_times) / len(epoch_times),
        'avg_batch_time': sum(batch_times) / len(batch_times),
        'complexity': complexity,
        'mean_auc': detailed_metrics['mean_auc'],
        'per_class': {
            'precision': detailed_metrics['per_class_precision'],
            'recall': detailed_metrics['per_class_recall'],
            'f1': detailed_metrics['per_class_f1']
        }
    }

    # Export metrics to CSV
    export_metrics_to_csv(final_metrics, model_name)

    print(f"Best validation accuracy: {best_acc:.2f}%")
    return model, best_acc, final_metrics


def _warmup_cosine_schedule(warmup_steps, total_steps):
    """Return an LR multiplier function: linear warmup, then cosine decay to zero."""
    # Guard both denominators so degenerate settings (no warmup, or warmup
    # covering the whole run) cannot divide by zero
    warmup_span = max(1, warmup_steps)
    decay_span = max(1, total_steps - warmup_steps)

    def lr_lambda(step):
        if step < warmup_steps:
            return float(step) / float(warmup_span)
        progress = min(1.0, float(step - warmup_steps) / float(decay_span))
        return float(0.5 * (1.0 + np.cos(np.pi * progress)))

    return lr_lambda


def _plot_training_history(train_losses, val_losses, train_accs, val_accs,
                           epoch_times, batch_times, lr_history):
    """Save the per-epoch training curves to ``vit_training_metrics.png``."""
    plt.figure(figsize=(18, 12))

    # 1. Loss curves
    plt.subplot(2, 3, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Loss Curves')
    plt.legend()

    # 2. Accuracy curves
    plt.subplot(2, 3, 2)
    plt.plot(train_accs, label='Train Accuracy')
    plt.plot(val_accs, label='Val Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.title('Accuracy Curves')
    plt.legend()

    # 3. Epoch times
    plt.subplot(2, 3, 3)
    plt.plot(epoch_times)
    plt.xlabel('Epoch')
    plt.ylabel('Time (s)')
    plt.title('Epoch Training Time')

    # 4. Batch times
    plt.subplot(2, 3, 4)
    plt.plot(batch_times)
    plt.xlabel('Epoch')
    plt.ylabel('Time (s)')
    plt.title('Average Batch Processing Time')

    # 5. Learning rate (value at the start of each epoch)
    plt.subplot(2, 3, 5)
    plt.plot(lr_history)
    plt.xlabel('Epoch')
    plt.ylabel('Learning Rate')
    plt.title('Learning Rate Schedule')
    plt.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig('vit_training_metrics.png')
    plt.close()  # close to avoid display issues with multiple plots

Attention Visualization function

In [None]:
# Attention Visualization Function
def visualize_attention(model, dataloader, device, num_images=4):
    """Plot class-token attention of the last transformer block for a few images.

    For each image three panels are drawn (original, attention map, overlay)
    and the figure is saved to ``vit_attention_maps.png``.

    Args:
        model: Model exposing ``patch_embed``, ``cls_token``, ``pos_embed``,
            ``dropout`` and ``blocks`` as accessed below; set to eval mode here.
        dataloader: Loader whose first batch supplies the images and labels.
        device: Device used for inference.
        num_images: Maximum number of images to show; clamped to the batch size.

    Raises:
        ValueError: If the number of patch tokens is not a perfect square, so
            the attention cannot be laid out as a square grid.
    """
    # Take the first batch and keep at most num_images samples from it
    images, labels = next(iter(dataloader))
    num_images = min(num_images, images.shape[0])
    images = images[:num_images].to(device)
    labels = labels[:num_images]

    # Class names, indexed by label id
    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')

    # Normalization constants used to undo the input normalization for display
    mean = np.array([0.4914, 0.4822, 0.4465])
    std = np.array([0.2470, 0.2435, 0.2616])

    model.eval()

    def get_attention_maps(x):
        """Return the last block's attention weights, shape [B, heads, N, N]."""
        B = x.shape[0]
        x = model.patch_embed(x)
        cls_token = model.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = x + model.pos_embed
        x = model.dropout(x)

        # Pass through transformer blocks except the last one
        for block in model.blocks[:-1]:
            x = block(x)

        # Recompute the attention weights of the last block by hand
        last_block = model.blocks[-1]
        last_attn = last_block.attn
        x = last_block.norm1(x)  # apply LN first (pre-norm)
        qkv = last_attn.qkv(x).reshape(
            B, x.shape[1], 3, last_attn.num_heads, last_attn.head_dim
        ).permute(2, 0, 3, 1, 4)
        q, k = qkv[0], qkv[1]
        attn = (q @ k.transpose(-2, -1)) * (1.0 / np.sqrt(last_attn.head_dim))
        return F.softmax(attn, dim=-1)

    with torch.no_grad():
        # Get model predictions
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)

        # Attention from the CLS token (row 0) to every patch token (columns
        # 1..), averaged over heads
        attentions = get_attention_maps(images)  # shape: [B, H, N, N]
        cls_attentions = attentions[:, :, 0, 1:].mean(1)  # shape: [B, N-1]

    # Derive the patch grid from the data instead of hard-coding 8x8 / 4px
    num_patch_tokens = cls_attentions.shape[1]
    grid_size = int(round(num_patch_tokens ** 0.5))
    if grid_size * grid_size != num_patch_tokens:
        raise ValueError(
            f"Cannot arrange {num_patch_tokens} patch tokens into a square grid"
        )
    patch_size = images.shape[-1] // grid_size

    plt.figure(figsize=(16, 4 * num_images))

    for i in range(num_images):
        # Original image - need to denormalize
        img = images[i].cpu().permute(1, 2, 0).numpy()
        img = np.clip(img * std + mean, 0, 1)

        # Attention map laid out on the patch grid
        attn_map = cls_attentions[i].reshape(grid_size, grid_size).cpu().numpy()

        # Nearest-neighbor upsampling so each patch covers its pixels
        attn_map = np.repeat(np.repeat(attn_map, patch_size, axis=0), patch_size, axis=1)

        true_name = classes[int(labels[i])]
        predicted_name = classes[int(predicted[i])]

        plt.subplot(num_images, 3, i*3 + 1)
        plt.imshow(img)
        plt.title(f"Original: {true_name}\nPredicted: {predicted_name}")
        plt.axis('off')

        plt.subplot(num_images, 3, i*3 + 2)
        plt.imshow(attn_map)
        plt.title("Attention Map")
        plt.axis('off')

        plt.subplot(num_images, 3, i*3 + 3)
        plt.imshow(img)
        plt.imshow(attn_map, alpha=0.5, cmap='jet')  # overlay with some transparency
        plt.title("Overlay")
        plt.axis('off')

    plt.tight_layout()
    plt.savefig('vit_attention_maps.png')
    plt.close()  # close to avoid display issues

Main execution

In [None]:
# Main Execution
if __name__ == "__main__":
    # Make sure the output directory for exported metrics exists
    os.makedirs('metrics', exist_ok=True)

    # Train the model
    # You can customize these hyperparameters
    model, best_acc, metrics = train_vit_cifar10(
        epochs=100,      # max epochs
        batch_size=128,  # reduce if OOM
        lr=1e-3,         # tried 5e-4 and 3e-3, this works best
        warmup_epochs=5, # helps stabilize training
        model_name='ViT' # for saving metrics
    )

    # Build a small-batch validation loader for the attention visualization
    # (this runs unconditionally, whatever accuracy was reached)
    data_module = CIFAR10DataModule(batch_size=4)
    data_module.setup()
    val_loader = data_module.val_dataloader()

    print("Visualizing attention maps")
    visualize_attention(model, val_loader, device)

# Inference