<a href="https://colab.research.google.com/github/Lcocks/DS6050-DeepLearning/blob/main/fashionMNIST_via_ViT_ResNet_DeiT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Note:** If you are running on Google Colab, you can skip the Rivanna setup below and go straight to the code.
# Fashion-MNIST Vision Transformer Training on Rivanna

This guide walks you through training ResNet18, Vision Transformer (ViT), and DeiT models on Fashion-MNIST using Rivanna's GPU cluster.


---

## One-Time Virtual Environment Setup

You only need to do this once. The virtual environment will contain all Python dependencies.

1. **Request an interactive session** (optional, but recommended for testing):
   ```bash
   ijob -c 4 -p standard -A your_allocation
   ```

2. **Create and activate a virtual environment:**
   ```bash
   mkdir -p ~/venvs
   python -m venv ~/venvs/dl-course
   source ~/venvs/dl-course/bin/activate
   ```

3. **Install dependencies:**
   ```bash
   python -m pip install --upgrade pip wheel
   python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
   python -m pip install numpy pillow tqdm matplotlib
   ```

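4. **Verify installation (optional):**
   ```bash
   python -c "import torch; print(f'PyTorch version: {torch.__version__}'); print(f'CUDA available: {torch.cuda.is_available()}')"
   ```
   On a CPU-only node, such as the `standard` partition from step 1, this prints `CUDA available: False`, which is expected. It should print `True` inside a job on the `gpu` partition.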

5. **Deactivate when done:**
   ```bash
   deactivate
   ```

---

## SLURM Submission Script

Create or edit `~/DL_course/ViT/run_training.sh` with the following content:

```bash
#!/bin/bash
#SBATCH --job-name=fmnist-vit
#SBATCH --partition=gpu
#SBATCH --gres=gpu:1
#SBATCH --time=12:00:00
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=8
#SBATCH --mem=32G
#SBATCH --output=/home/%u/DL_course/ViT/logs/fmnist-vit-%j.out
#SBATCH --account=YOUR_ALLOCATION_HERE

# Exit on error
set -euo pipefail

# Load CUDA modules
module purge
module load cuda/12.8.0
module load cudnn/9.8.0-CUDA-12.8.0

# Determine project directory
if [[ -n "${SLURM_SUBMIT_DIR:-}" ]]; then
    cd "${SLURM_SUBMIT_DIR}" || exit 1
else
    cd "$(dirname "${BASH_SOURCE[0]}")/.." || exit 1
fi

PROJECT_ROOT=$(pwd)

# Locate training script
if [[ -f "${PROJECT_ROOT}/ViT/train_fashionmnist.py" ]]; then
    SCRIPT_DIR="${PROJECT_ROOT}/ViT"
elif [[ -f "${PROJECT_ROOT}/train_fashionmnist.py" ]]; then
    SCRIPT_DIR="${PROJECT_ROOT}"
else
    echo "ERROR: Could not locate train_fashionmnist.py"
    exit 1
fi

PYTHON_SCRIPT="${SCRIPT_DIR}/train_fashionmnist.py"

# Activate virtual environment
VENV_PATH=${VENV_PATH:-$HOME/venvs/dl-course}
if [[ ! -d "${VENV_PATH}" ]]; then
    echo "ERROR: Virtual environment not found at ${VENV_PATH}"
    echo "Please create it following the setup instructions."
    exit 1
fi
source "${VENV_PATH}/bin/activate"

echo "Starting training..."
echo "Python script: ${PYTHON_SCRIPT}"
echo "Virtual environment: ${VENV_PATH}"

# Run training script with arguments
python "${PYTHON_SCRIPT}" "$@"

echo "Training complete!"
```

**Important:** Replace `YOUR_ALLOCATION_HERE` with your actual allocation name (e.g., `hpc_build`).
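The job requests `--cpus-per-task=8`, while `get_dataloaders` in the training code defaults to `num_workers=2`. Below is a minimal sketch that assumes you want the DataLoader worker count to follow the allocation. SLURM sets `SLURM_CPUS_PER_TASK` inside jobs submitted with `--cpus-per-task`, so the Python side can read it. `slurm_num_workers` is a hypothetical helper and is not part of `train_fashionmnist.py`.

```python
import os


def slurm_num_workers(default: int = 2, reserve: int = 1) -> int:
    """Choose a DataLoader worker count from the SLURM CPU allocation.

    Hypothetical helper: returns `default` outside a SLURM job (or if the
    variable is malformed). Otherwise it leaves `reserve` CPUs for the main
    training process and never goes below 1 worker.
    """
    cpus = os.environ.get("SLURM_CPUS_PER_TASK", "")
    if not cpus.isdigit():
        return default
    return max(1, int(cpus) - reserve)


print(slurm_num_workers())
```

You could then call `get_dataloaders(..., num_workers=slurm_num_workers())`. Keeping one CPU in reserve for the main process is a judgment call, not a requirement.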

---

## Running Your Training Job

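1. **Navigate to the project directory and create the log directory:**
   ```bash
   cd ~/DL_course
   mkdir -p ViT/logs
   ```
   SLURM does not create the directory named in `#SBATCH --output`. If `ViT/logs` is missing, the job fails without writing a log file.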

2. **Make the script executable:**
   ```bash
   chmod +x ViT/run_training.sh
   ```

3. **Submit a training job:**
   ```bash
   sbatch ViT/run_training.sh --plot-results --epochs 30
   ```

4. **Submit with custom virtual environment path (if different):**
   ```bash
   sbatch --export=ALL,VENV_PATH=$HOME/my-custom-venv ViT/run_training.sh --plot-results --epochs 30
   ```
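In both commands, `sbatch` reads the options that come before the script path. Everything after the path is passed to `run_training.sh`, which forwards it to Python via `"$@"`. The sketch below builds such a command from Python, for example to script several runs. `build_sbatch_command` is a hypothetical helper; it only builds the argument list and submits nothing.

```python
import shlex
from pathlib import Path


def build_sbatch_command(script, train_args, venv_path=None):
    """Assemble an sbatch command as a list (hypothetical helper).

    sbatch parses options that appear before the script path; the remaining
    items are passed to the script, which forwards them to Python.
    """
    cmd = ["sbatch"]
    if venv_path is not None:
        # Keep the submitting shell's environment and add VENV_PATH for run_training.sh.
        cmd.append(f"--export=ALL,VENV_PATH={venv_path}")
    cmd.append(str(script))
    cmd.extend(str(arg) for arg in train_args)
    return cmd


cmd = build_sbatch_command(
    "ViT/run_training.sh",
    ["--plot-results", "--epochs", 30],
    venv_path=Path.home() / "my-custom-venv",
)
print(shlex.join(cmd))  # Shell-quoted form for inspection only.
```

To actually submit, pass the list to `subprocess.run(cmd, check=True)` from a login node.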

---

## Monitoring Your Job

- **Check job status:**
  ```bash
  squeue -u $USER
  ```

- **View job details:**
  ```bash
  scontrol show job <JOBID>
  ```

- **Monitor output in real-time:**
  ```bash
  tail -f ~/DL_course/ViT/logs/fmnist-vit-<JOBID>.out
  ```

- **Check all your recent jobs:**
  ```bash
  sacct -u $USER --format=JobID,JobName,Partition,State,Elapsed,ExitCode
  ```
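For each epoch, `train_model` prints one line of the form `Epoch [e/N] Train Loss: ... Train Acc: ...% Test Loss: ... Test Acc: ...% Time: ...s LR: ...`. The sketch below pulls the accuracies out of a log file, assuming that format is unchanged. `parse_epochs` is a hypothetical helper. One job can train several models in sequence, so the epoch counter restarts at 1 for each model.

```python
import re

# Matches the per-epoch line printed by train_model.
# Lines with nan/inf losses do not match this pattern and are skipped.
EPOCH_RE = re.compile(
    r"Epoch \[(\d+)/(\d+)\] "
    r"Train Loss: ([\d.]+) Train Acc: ([\d.]+)% "
    r"Test Loss: ([\d.]+) Test Acc: ([\d.]+)%"
)


def parse_epochs(lines):
    """Yield (epoch, total_epochs, train_acc, test_acc) for each epoch line; other lines are skipped."""
    for line in lines:
        match = EPOCH_RE.search(line)
        if match:
            yield (
                int(match.group(1)),
                int(match.group(2)),
                float(match.group(4)),
                float(match.group(6)),
            )


# Illustrative line in the log's format (values are made up, not real results).
sample = (
    "Epoch [3/30] Train Loss: 0.4123 Train Acc: 85.10% "
    "Test Loss: 0.3900 Test Acc: 86.02% Time: 12.34s LR: 0.000060"
)
print(list(parse_epochs([sample, "=" * 60])))
```

On a real log, pass the file's lines, e.g. `parse_epochs(log_path.read_text().splitlines())`.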

---

## Understanding Outputs

After training completes, you'll find:

- **Logs:** `~/DL_course/ViT/logs/fmnist-vit-<JOBID>.out`
- **Model checkpoints:** `~/DL_course/ViT/saved_models/*.pth`
- **Training metrics:** `~/DL_course/ViT/metrics/*.json`
- **Plots:** `~/DL_course/ViT/plots/*.png` (if `--plot-results` was used)
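`save_metrics` writes each metrics file as `{"history": {...}, "best_acc": ...}`. The `history` field holds per-epoch `train_loss`, `train_acc`, `test_loss`, and `test_acc` lists. The sketch below summarizes these files after a job finishes. `summarize_metrics` is a hypothetical helper. Each file stem is whatever `model_key` the training script passed to `save_metrics`, which is not shown in this excerpt.

```python
import json
from pathlib import Path


def summarize_metrics(metrics_dir):
    """Return {model_key: (best_acc, final_test_acc, num_epochs)} for each metrics JSON file."""
    metrics_dir = Path(metrics_dir)
    summary = {}
    if not metrics_dir.is_dir():
        return summary
    for path in sorted(metrics_dir.glob("*.json")):
        data = json.loads(path.read_text(encoding="utf-8"))
        test_acc = data["history"]["test_acc"]
        if not test_acc:
            continue  # Skip files with an empty history.
        summary[path.stem] = (data["best_acc"], test_acc[-1], len(test_acc))
    return summary


metrics_dir = Path.home() / "DL_course" / "ViT" / "metrics"
for key, (best, final, n) in summarize_metrics(metrics_dir).items():
    print(f"{key}: best {best:.2f}%, final {final:.2f}% after {n} epochs")
```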

---

## Command-Line Arguments

You can customize training by passing arguments to the Python script:

```bash
sbatch ViT/run_training.sh \
  --epochs 50 \
  --batch-size 128 \
  --learning-rate 0.001 \
  --plot-results \
  --phases 1 2 3
```
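`--learning-rate` sets the peak rate, assuming the script passes it to `train_model` as `lr`; the argument parsing is not shown in this excerpt. `train_model` warms up linearly over the first 5 epochs, then follows a cosine decay, stepping once per epoch. The sketch below reproduces that `lr_lambda` in plain Python so you can see which rate a given `--epochs`/`--learning-rate` pair produces. `lr_at_epoch` is a hypothetical name.

```python
import math


def lr_at_epoch(epoch, base_lr, num_epochs, warmup=5):
    """LR used during 0-indexed `epoch`, mirroring lr_lambda in train_model (which hard-codes warmup=5)."""
    if epoch < warmup:
        factor = (epoch + 1) / warmup  # Linear warmup: 1/5, 2/5, ..., 5/5 of base_lr.
    else:
        # Cosine decay from base_lr toward 0 over the remaining epochs.
        progress = (epoch - warmup) / max(num_epochs - warmup, 1)
        factor = 0.5 * (1 + math.cos(math.pi * progress))
    return base_lr * factor


# Schedule for the example above: --epochs 50 --learning-rate 0.001
for epoch in (0, 4, 5, 25, 49):
    print(f"epoch {epoch + 1:2d}: lr = {lr_at_epoch(epoch, 1e-3, 50):.6f}")
```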

Common options:
- `--epochs N`: Number of training epochs (default: 30)
- `--batch-size N`: Batch size (default: 64)
- `--learning-rate X`: Learning rate (default: 0.0001)
- `--plot-results`: Generate plots after training
- `--phases 1 2 3`: Run specific training phases only
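The script's argument parser is not part of this excerpt. The hypothetical sketch below parses the options listed above, with defaults taken from that list. The real `train_fashionmnist.py` may define more flags or different defaults.

```python
import argparse


def build_parser():
    """Hypothetical parser matching the documented options."""
    p = argparse.ArgumentParser(description="Train models on FashionMNIST")
    p.add_argument("--epochs", type=int, default=30)
    p.add_argument("--batch-size", type=int, default=64)
    p.add_argument("--learning-rate", type=float, default=1e-4)
    p.add_argument("--data-dir", default="./data")
    p.add_argument("--plot-results", action="store_true")
    # nargs="+" collects "--phases 1 2 3" into [1, 2, 3]; None is assumed to mean "all phases".
    p.add_argument("--phases", type=int, nargs="+", default=None)
    return p


args = build_parser().parse_args(["--epochs", "50", "--phases", "1", "3"])
print(args.epochs, args.batch_size, args.phases)
```

argparse exposes dashed flags with underscores, so `--batch-size` becomes `args.batch_size`.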

---

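## Rerunning Training

Before training each model, the script looks for `saved_models/<model>_final.pth`. If that file exists, the model is loaded instead of retrained. A model that has only a `<model>_best.pth` checkpoint is not resumed; this can happen, for example, when a job hits its time limit.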

**To retrain a specific model:**
```bash
rm ~/DL_course/ViT/saved_models/DeiT_Distilled_*.pth
sbatch ViT/run_training.sh --plot-results --epochs 30
```

**To start completely fresh:**
```bash
rm -rf ~/DL_course/ViT/saved_models/*
rm -rf ~/DL_course/ViT/metrics/*
sbatch ViT/run_training.sh --plot-results --epochs 30
```
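Before deleting anything, you can check which models a rerun would skip. Below is a minimal sketch that assumes the default `saved_models` location listed above. `checkpoint_status` is a hypothetical helper that only reads file names. Models reported as `final` are loaded rather than retrained.

```python
from pathlib import Path


def checkpoint_status(save_dir):
    """Map model name -> "final" or "best-only" from the *.pth files in save_dir.

    Mirrors the naming in train_model: <name>_best.pth and <name>_final.pth,
    where spaces in the model name become underscores.
    """
    save_dir = Path(save_dir)
    status = {}
    if not save_dir.is_dir():
        return status
    for path in save_dir.glob("*.pth"):
        stem = path.stem
        if stem.endswith("_final"):
            status[stem[: -len("_final")]] = "final"
        elif stem.endswith("_best"):
            # Don't overwrite "final" if that file was seen first.
            status.setdefault(stem[: -len("_best")], "best-only")
    return status


print(checkpoint_status(Path.home() / "DL_course" / "ViT" / "saved_models"))
```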

---

## Troubleshooting

### Job fails with "Virtual environment not found"
Make sure you created the venv and the path matches:
```bash
ls -la ~/venvs/dl-course/bin/activate
```

### Job fails with "module not found"
Check available CUDA versions and update the script:
```bash
module avail cuda
```

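### Out of memory errors
If the log shows `CUDA out of memory`, the GPU itself ran out of memory. Reduce `--batch-size`; `--mem` only controls host (CPU) RAM and will not help. If SLURM instead ends the job with state `OUT_OF_MEMORY` (visible in `sacct`), request more host memory in the SLURM script:
```bash
#SBATCH --mem=64G
```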

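### Job stays pending
The `NODELIST(REASON)` column in `squeue` shows why the job is waiting (for example, `Priority` or `Resources`). `sacctmgr` lists the accounts your user is associated with, so you can check the name passed to `--account`: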
```bash
squeue -u $USER
sacctmgr show associations user=$USER
```

---

## Cleanup

**Cancel a running job:**
```bash
scancel <JOBID>
```

**Remove virtual environment** (only if you're completely done):
```bash
rm -rf ~/venvs/dl-course
```

**Clean up old logs:**
```bash
rm ~/DL_course/ViT/logs/*.out
```

---

## Quick Reference

| Task | Command |
|------|---------|
| Submit job | `sbatch ViT/run_training.sh --plot-results --epochs 30` |
| Check status | `squeue -u $USER` |
| View output | `tail -f ~/DL_course/ViT/logs/fmnist-vit-<JOBID>.out` |
| Cancel job | `scancel <JOBID>` |
| Retrain model | `rm ViT/saved_models/<model>_*.pth && sbatch ...` |

---

**Questions?** Contact your instructor or consult the [Rivanna Documentation](https://www.rc.virginia.edu/userinfo/rivanna/overview/).

#!/usr/bin/env python
"""
Train ResNet18, Vision Transformer, and DeiT models on FashionMNIST with heavy
augmentation options. Intended for use on UVA Rivanna or similar HPC clusters.

Example:
    python train_fashionmnist.py --batch-size 128 --epochs 30 --data-dir ./data
"""

from __future__ import annotations

import argparse
import random
import time
from pathlib import Path
from typing import Any, Dict, Tuple
import json

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


# ==================== Utility Helpers ====================

def set_seed(seed: int) -> None:
    """Set random seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


# ==================== Data Augmentation ====================

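class RandAugment:
    """Simplified RandAugment-style policy for FashionMNIST.

    Applies ``n`` geometric ops chosen uniformly at random (with replacement).
    ``m`` (magnitude) is stored only for API compatibility; the op strengths
    below are fixed.
    """

    def __init__(self, n: int = 2, m: int = 9) -> None:
        self.n = n  # Number of ops applied per image
        self.m = m  # Magnitude (unused placeholder to match the RandAugment API)
        # Build the candidate ops once instead of on every call.
        self.ops = [
            transforms.RandomRotation(30),
            transforms.RandomAffine(0, translate=(0.1, 0.1)),
            transforms.RandomAffine(0, shear=15),
        ]

    def __call__(self, img):
        for _ in range(self.n):
            op = random.choice(self.ops)
            img = op(img)

        return img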


class RandomErasing:
    """Random erasing augmentation."""

    def __init__(self, p: float = 0.5, scale: Tuple[float, float] = (0.02, 0.33)) -> None:
        self.p = p
        self.scale = scale

    def __call__(self, img):
        if np.random.rand() > self.p:
            return img

        h, w = img.shape[-2:]
        area = h * w
        target_area = np.random.uniform(*self.scale) * area
        aspect_ratio = np.random.uniform(0.3, 3.3)

        h_erase = int(round(np.sqrt(target_area * aspect_ratio)))
        w_erase = int(round(np.sqrt(target_area / aspect_ratio)))

        if h_erase < h and w_erase < w:
            # randint's upper bound is exclusive; +1 lets the patch reach the bottom/right edge.
            i = np.random.randint(0, h - h_erase + 1)
            j = np.random.randint(0, w - w_erase + 1)
            img[:, i : i + h_erase, j : j + w_erase] = 0

        return img


def mixup_data(x, y, alpha: float = 0.2):
    """Mixup augmentation."""
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0

    batch_size = x.size(0)
    index = torch.randperm(batch_size, device=x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]

    return mixed_x, y_a, y_b, lam


def cutmix_data(x, y, alpha: float = 1.0):
    """CutMix augmentation for a (B, C, H, W) batch; modifies ``x`` in place."""
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0

    batch_size = x.size(0)
    index = torch.randperm(batch_size, device=x.device)

    H = x.size(2)
    W = x.size(3)
    cut_rat = np.sqrt(1.0 - lam)
    cut_h = int(H * cut_rat)
    cut_w = int(W * cut_rat)

    # Random box center; the box is clipped to the image bounds.
    cy = np.random.randint(H)
    cx = np.random.randint(W)

    top = np.clip(cy - cut_h // 2, 0, H)
    bottom = np.clip(cy + cut_h // 2, 0, H)
    left = np.clip(cx - cut_w // 2, 0, W)
    right = np.clip(cx + cut_w // 2, 0, W)

    # Paste the box from a shuffled copy of the batch (the indexed RHS is a copy).
    x[:, :, top:bottom, left:right] = x[index, :, top:bottom, left:right]

    # Recompute lam from the clipped box so the label weights match the pasted area.
    lam = 1.0 - ((bottom - top) * (right - left) / (H * W))
    y_a, y_b = y, y[index]

    return x, y_a, y_b, float(lam)


# ==================== Vision Transformer Implementation ====================


class PatchEmbedding(nn.Module):
    """Split image into patches and embed them."""

    def __init__(self, img_size: int = 32, patch_size: int = 4, in_channels: int = 1, embed_dim: int = 384) -> None:
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2

        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)
        x = x.flatten(2)
        x = x.transpose(1, 2)
        return x


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention mechanism."""

    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.1) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        if embed_dim % num_heads != 0:
            raise ValueError("Embedding dimension must be divisible by number of heads.")

        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, N, C = x.shape

        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        scale = self.head_dim ** -0.5
        attn = (q @ k.transpose(-2, -1)) * scale
        attn = attn.softmax(dim=-1)
        attn = self.dropout(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.dropout(x)

        return x


class MLP(nn.Module):
    """MLP block with GELU activation."""

    def __init__(self, in_features: int, hidden_features: int, dropout: float = 0.1) -> None:
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class TransformerBlock(nn.Module):
    """Transformer encoder block with pre-normalization."""

    def __init__(self, embed_dim: int, num_heads: int, mlp_ratio: float = 4, dropout: float = 0.1) -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, int(embed_dim * mlp_ratio), dropout)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


class VisionTransformer(nn.Module):
    """Vision Transformer for image classification."""

    def __init__(
        self,
        img_size: int = 32,
        patch_size: int = 4,
        in_channels: int = 1,
        num_classes: int = 10,
        embed_dim: int = 384,
        num_heads: int = 6,
        num_blocks: int = 6,
        mlp_ratio: float = 4,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()

        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        n_patches = self.patch_embed.n_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
        self.dropout = nn.Dropout(dropout)

        self.blocks = nn.ModuleList(
            [TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout) for _ in range(num_blocks)]
        )

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        B = x.shape[0]

        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = x + self.pos_embed
        x = self.dropout(x)

        for block in self.blocks:
            x = block(x)

        x = self.norm(x[:, 0])
        x = self.head(x)

        return x


# ==================== DeiT Implementation ====================


class DeiT(nn.Module):
    """Data-efficient Image Transformer with distillation token."""

    def __init__(
        self,
        img_size: int = 32,
        patch_size: int = 4,
        in_channels: int = 1,
        num_classes: int = 10,
        embed_dim: int = 384,
        num_heads: int = 6,
        num_blocks: int = 6,
        mlp_ratio: float = 4,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()

        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        n_patches = self.patch_embed.n_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 2, embed_dim))
        self.dropout = nn.Dropout(dropout)

        self.blocks = nn.ModuleList(
            [TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout) for _ in range(num_blocks)]
        )

        self.norm = nn.LayerNorm(embed_dim)

        self.head = nn.Linear(embed_dim, num_classes)
        self.head_dist = nn.Linear(embed_dim, num_classes)

        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.dist_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        B = x.shape[0]

        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        dist_tokens = self.dist_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, dist_tokens, x], dim=1)

        x = x + self.pos_embed
        x = self.dropout(x)

        for block in self.blocks:
            x = block(x)

        x = self.norm(x)

        x_cls = self.head(x[:, 0])
        x_dist = self.head_dist(x[:, 1])

        return x_cls, x_dist


# ==================== ResNet18 Implementation ====================


class ResidualBlock(nn.Module):
    """Basic residual block for ResNet."""

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = self.shortcut(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out


class ResNet18(nn.Module):
    """ResNet18 architecture."""

    def __init__(self, in_channels: int = 1, num_classes: int = 10) -> None:
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(64, 64, 2, stride=1)
        self.layer2 = self._make_layer(64, 128, 2, stride=2)
        self.layer3 = self._make_layer(128, 256, 2, stride=2)
        self.layer4 = self._make_layer(256, 512, 2, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

        self.apply(self._init_weights)

    def _make_layer(self, in_channels: int, out_channels: int, num_blocks: int, stride: int):
        layers = [ResidualBlock(in_channels, out_channels, stride)]
        for _ in range(1, num_blocks):
            layers.append(ResidualBlock(out_channels, out_channels, stride=1))
        return nn.Sequential(*layers)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.01)
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x


# ==================== Data Loading ====================


def get_dataloaders(
    batch_size: int = 128,
    data_dir: Path | str = "./data",
    use_augmentation: bool = False,
    num_workers: int = 2,
) -> Tuple[DataLoader, DataLoader, DataLoader, DataLoader]:
    """Create data loaders with optional heavy augmentation."""

    if use_augmentation:
        transform_train = transforms.Compose(
            [
                transforms.Resize(32),
                transforms.RandomHorizontalFlip(),
                RandAugment(n=2, m=9),
                transforms.ToTensor(),
                RandomErasing(p=0.25),
                transforms.Normalize((0.5,), (0.5,)),
            ]
        )
    else:
        transform_train = transforms.Compose(
            [
                transforms.Resize(32),
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
            ]
        )

    transform_test = transforms.Compose(
        [
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]
    )

    transform_resnet_train = transforms.Compose(
        [
            transforms.Resize(96),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]
    )

    transform_resnet_test = transforms.Compose(
        [
            transforms.Resize(96),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]
    )

    data_dir = Path(data_dir)
    data_dir.mkdir(parents=True, exist_ok=True)

    train_dataset = datasets.FashionMNIST(root=data_dir, train=True, download=True, transform=transform_train)
    test_dataset = datasets.FashionMNIST(root=data_dir, train=False, download=True, transform=transform_test)

    train_dataset_resnet = datasets.FashionMNIST(
        root=data_dir, train=True, download=True, transform=transform_resnet_train
    )
    test_dataset_resnet = datasets.FashionMNIST(
        root=data_dir, train=False, download=True, transform=transform_resnet_test
    )

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True
    )
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True
    )

    train_loader_resnet = DataLoader(
        train_dataset_resnet, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True
    )
    test_loader_resnet = DataLoader(
        test_dataset_resnet, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True
    )

    return train_loader, test_loader, train_loader_resnet, test_loader_resnet


# ==================== Training Functions ====================


def count_parameters(model: nn.Module) -> int:
    """Count the number of trainable parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def save_model(model: nn.Module, path: Path, epoch: int, optimizer: optim.Optimizer, acc: float) -> None:
    """Save model checkpoint."""
    path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "accuracy": acc,
        },
        path,
    )
    print(f"Model saved to {path}")


def load_model_if_exists(model: nn.Module, model_name: str, device: torch.device, save_dir: Path):
    """Load model if checkpoint exists."""
    final_path = save_dir / f"{model_name}_final.pth"
    best_path = save_dir / f"{model_name}_best.pth"

    if final_path.exists():
        print(f"Loading existing model from {final_path}")
        checkpoint = torch.load(final_path, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint["model_state_dict"])
        model = model.to(device)
        print(
            f"Loaded model - Epoch: {checkpoint['epoch']}, Accuracy: {checkpoint['accuracy']:.2f}%"
        )
        return model, checkpoint["accuracy"], True

    if best_path.exists():
        print(f"Best checkpoint found at {best_path} (no final checkpoint)")

    return model, 0.0, False


def save_metrics(metrics_dir: Path, model_key: str, history: Dict[str, list], best_acc: float) -> None:
    """Persist training history and metadata for later analysis."""
    metrics_dir.mkdir(parents=True, exist_ok=True)
    payload = {"history": history, "best_acc": best_acc}
    target = metrics_dir / f"{model_key}.json"
    with target.open("w", encoding="utf-8") as fp:
        json.dump(payload, fp, indent=2)
    print(f"Metrics saved to {target}")


def load_metrics(metrics_dir: Path, model_key: str) -> Dict[str, Any] | None:
    """Load previously saved metrics if available."""
    target = metrics_dir / f"{model_key}.json"
    if target.exists():
        with target.open("r", encoding="utf-8") as fp:
            data = json.load(fp)
        return data
    return None


def train_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    device: torch.device,
    use_augment: bool = False,
) -> Tuple[float, float]:
    """Train for one epoch."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, targets in loader:
        inputs, targets = inputs.to(device, non_blocking=True), targets.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        if use_augment and np.random.rand() < 0.5:
            # Mixed batches use Mixup or CutMix with equal probability.
            mix_fn = mixup_data if np.random.rand() < 0.5 else cutmix_data
            inputs, targets_a, targets_b, lam = mix_fn(inputs, targets)
            outputs = model(inputs)
            # Weight each target's loss by its share of the mixed image.
            loss = lam * criterion(outputs, targets_a) + (1 - lam) * criterion(outputs, targets_b)
        else:
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        # On mixed batches this compares against the original labels, so train accuracy is approximate.
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    epoch_loss = running_loss / total
    epoch_acc = 100.0 * correct / total

    return epoch_loss, epoch_acc


def train_epoch_deit(
    model: DeiT,
    loader: DataLoader,
    teacher: nn.Module,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    device: torch.device,
    alpha: float = 0.5,
    beta: float = 0.5,
    temperature: float = 3.0,
    use_augment: bool = True,
) -> Tuple[float, float]:
    """Train DeiT for one epoch with distillation."""
    model.train()
    teacher.eval()

    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, targets in loader:
        inputs, targets = inputs.to(device, non_blocking=True), targets.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        mixed = use_augment and np.random.rand() < 0.5
        if mixed:
            # Mixed batches use Mixup or CutMix with equal probability.
            mix_fn = mixup_data if np.random.rand() < 0.5 else cutmix_data
            inputs, targets_a, targets_b, lam = mix_fn(inputs, targets)

        # The ResNet18 teacher was trained on 96x96 inputs, so upsample the 32x32 batch.
        inputs_teacher = F.interpolate(inputs, size=(96, 96), mode="bilinear", align_corners=False)
        with torch.no_grad():
            teacher_outputs = teacher(inputs_teacher)

        outputs_cls, outputs_dist = model(inputs)

        if mixed:
            loss_cls = lam * criterion(outputs_cls, targets_a) + (1 - lam) * criterion(outputs_cls, targets_b)
            # The distillation head is also supervised with ground-truth labels here,
            # not with the teacher's argmax as in DeiT's hard-label distillation.
            loss_dist_hard = (
                lam * criterion(outputs_dist, targets_a) + (1 - lam) * criterion(outputs_dist, targets_b)
            )
        else:
            loss_cls = criterion(outputs_cls, targets)
            loss_dist_hard = criterion(outputs_dist, targets)

        # Soft distillation: KL divergence between temperature-scaled distributions,
        # multiplied by T^2 to keep gradient magnitudes comparable across temperatures.
        loss_kl = (
            F.kl_div(
                F.log_softmax(outputs_dist / temperature, dim=1),
                F.softmax(teacher_outputs / temperature, dim=1),
                reduction="batchmean",
            )
            * (temperature**2)
        )

        loss = (1 - alpha) * loss_cls + alpha * loss_dist_hard + beta * loss_kl

        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        _, predicted = outputs_cls.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    epoch_loss = running_loss / total
    epoch_acc = 100.0 * correct / total

    return epoch_loss, epoch_acc


def evaluate(model: nn.Module, loader: DataLoader, criterion: nn.Module, device: torch.device, is_deit: bool = False):
    """Evaluate the model."""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device, non_blocking=True), targets.to(device, non_blocking=True)

            if is_deit:
                outputs_cls, outputs_dist = model(inputs)
                outputs = (outputs_cls + outputs_dist) / 2
            else:
                outputs = model(inputs)

            loss = criterion(outputs, targets)

            running_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    epoch_loss = running_loss / total
    epoch_acc = 100.0 * correct / total

    return epoch_loss, epoch_acc


def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    test_loader: DataLoader,
    num_epochs: int,
    lr: float,
    device: torch.device,
    model_name: str,
    save_dir: Path,
    use_augment: bool = False,
    teacher: nn.Module | None = None,
    is_deit: bool = False,
) -> Tuple[Dict[str, list], float]:
    """Train and evaluate a model."""
    separator = "=" * 60
    print(f"\n{separator}")
    print(f"Training {model_name}")
    print(separator)
    print(f"Parameters: {count_parameters(model):,}")

    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)

    def lr_lambda(epoch: int) -> float:
        if epoch < 5:
            return (epoch + 1) / 5
        return 0.5 * (1 + np.cos(np.pi * (epoch - 5) / (max(num_epochs - 5, 1))))

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    best_acc = 0.0
    history: Dict[str, list] = {"train_loss": [], "train_acc": [], "test_loss": [], "test_acc": []}

    start_time = time.time()

    for epoch in range(num_epochs):
        epoch_start = time.time()

        if is_deit and teacher is not None:
            train_loss, train_acc = train_epoch_deit(
                model,
                train_loader,
                teacher,
                criterion,
                optimizer,
                device,
                use_augment=use_augment,
            )
        else:
            train_loss, train_acc = train_epoch(
                model, train_loader, criterion, optimizer, device, use_augment=use_augment
            )

        test_loss, test_acc = evaluate(model, test_loader, criterion, device, is_deit=is_deit)

        # Record the LR used for this epoch before the scheduler advances it
        current_lr = optimizer.param_groups[0]["lr"]
        scheduler.step()

        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["test_loss"].append(test_loss)
        history["test_acc"].append(test_acc)

        epoch_time = time.time() - epoch_start

        print(
            f"Epoch [{epoch + 1}/{num_epochs}] "
            f"Train Loss: {train_loss:.4f} Train Acc: {train_acc:.2f}% "
            f"Test Loss: {test_loss:.4f} Test Acc: {test_acc:.2f}% "
            f"Time: {epoch_time:.2f}s LR: {current_lr:.6f}"
        )

        if test_acc > best_acc:
            best_acc = test_acc
            save_path = save_dir / f"{model_name.replace(' ', '_')}_best.pth"
            save_model(model, save_path, epoch, optimizer, best_acc)

    total_time = time.time() - start_time

    save_path = save_dir / f"{model_name.replace(' ', '_')}_final.pth"
    save_model(model, save_path, num_epochs, optimizer, history["test_acc"][-1])

    print(f"\nTraining completed in {total_time:.2f}s")
    print(f"Best test accuracy: {best_acc:.2f}%")
    print(f"Final test accuracy: {history['test_acc'][-1]:.2f}%")

    return history, best_acc


def generate_plots(plots_dir: Path, metrics: Dict[str, Dict[str, Any]]) -> None:
    """Generate comparison plots for available metrics."""
    try:
        import matplotlib
        matplotlib.use("Agg", force=True)
        import matplotlib.pyplot as plt
    except ImportError:
        print("Matplotlib not installed; skipping plot generation.")
        return

    plots_dir.mkdir(parents=True, exist_ok=True)

    def _plot_series(key: str, ylabel: str, filename: str) -> None:
        plt.figure()
        has_data = False
        for model_name, payload in metrics.items():
            history = payload.get("history", {})
            series = history.get(key)
            if not series:
                continue
            epochs = range(1, len(series) + 1)
            label = model_name.replace("_", " ")
            plt.plot(epochs, series, label=label)
            has_data = True
        if has_data:
            plt.xlabel("Epoch")
            plt.ylabel(ylabel)
            plt.title(f"{ylabel} by Epoch")
            plt.grid(True, linestyle="--", alpha=0.3)
            plt.legend()
            plt.tight_layout()
            outfile = plots_dir / filename
            plt.savefig(outfile)
            print(f"Saved plot: {outfile}")
        plt.close()

    _plot_series("train_loss", "Train Loss", "train_loss.png")
    _plot_series("test_loss", "Test Loss", "test_loss.png")
    _plot_series("train_acc", "Train Accuracy (%)", "train_accuracy.png")
    _plot_series("test_acc", "Test Accuracy (%)", "test_accuracy.png")

    # Bar chart summarizing best test accuracy
    best_entries = [
        (model_name.replace("_", " "), payload.get("best_acc"))
        for model_name, payload in metrics.items()
        if payload.get("best_acc") is not None
    ]
    if best_entries:
        labels, values = zip(*best_entries)
        plt.figure()
        plt.bar(labels, values)
        plt.ylabel("Best Test Accuracy (%)")
        plt.title("Best Test Accuracy Comparison")
        plt.xticks(rotation=20, ha="right")
        plt.tight_layout()
        outfile = plots_dir / "best_accuracy.png"
        plt.savefig(outfile)
        plt.close()
        print(f"Saved plot: {outfile}")


# ==================== Main Execution ====================


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Train ViT/DeiT models on FashionMNIST.")
    parser.add_argument("--batch-size", type=int, default=128, help="Mini-batch size for training.")
    parser.add_argument("--epochs", type=int, default=30, help="Number of training epochs.")
    parser.add_argument("--lr-vit", type=float, default=1e-3, help="Learning rate for Vision Transformer.")
    parser.add_argument("--lr-resnet", type=float, default=1e-2, help="Learning rate for ResNet18 teacher.")
    parser.add_argument("--lr-deit", type=float, default=1e-3, help="Learning rate for DeiT student.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
    parser.add_argument("--data-dir", type=Path, default=Path("./data"), help="Dataset download root.")
    parser.add_argument("--save-dir", type=Path, default=Path("./saved_models"), help="Directory for checkpoints.")
    parser.add_argument(
        "--num-workers",
        type=int,
        default=4,
        help="Number of worker processes for dataloaders.",
    )
    parser.add_argument(
        "--skip-resnet",
        action="store_true",
        help="Skip training the ResNet teacher if a checkpoint is already available.",
    )
    parser.add_argument(
        "--skip-vit",
        action="store_true",
        help="Skip training the ViT baseline if a checkpoint is already available.",
    )
    parser.add_argument(
        "--skip-deit",
        action="store_true",
        help="Skip training the DeiT student if a checkpoint is already available.",
    )
    parser.add_argument(
        "--metrics-dir",
        type=Path,
        default=Path("./metrics"),
        help="Directory for saving/loading training metrics.",
    )
    parser.add_argument(
        "--plots-dir",
        type=Path,
        default=Path("./plots"),
        help="Directory for saving generated plots.",
    )
    parser.add_argument(
        "--no-save-metrics",
        action="store_true",
        help="Disable saving metrics to disk.",
    )
    parser.add_argument(
        "--plot-results",
        action="store_true",
        help="Generate plots comparing training curves (requires matplotlib).",
    )
    return parser.parse_args()


def main() -> None:
    args = parse_args()

    set_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    args.save_dir.mkdir(parents=True, exist_ok=True)
    args.metrics_dir = args.metrics_dir.resolve()
    args.plots_dir = args.plots_dir.resolve()
    args.save_metrics = not args.no_save_metrics

    if args.plot_results and not args.metrics_dir.exists():
        print(f"Metrics directory {args.metrics_dir} does not exist yet; plots may be incomplete.")

    print("Loading FashionMNIST dataset...")
    (
        train_loader_vit,
        test_loader_vit,
        train_loader_resnet,
        test_loader_resnet,
    ) = get_dataloaders(
        batch_size=args.batch_size,
        data_dir=args.data_dir,
        use_augmentation=False,
        num_workers=args.num_workers,
    )

    train_loader_deit, test_loader_deit, _, _ = get_dataloaders(
        batch_size=args.batch_size,
        data_dir=args.data_dir,
        use_augmentation=True,
        num_workers=args.num_workers,
    )

    metrics_data: Dict[str, Dict[str, Any]] = {}

    # ========== Train ResNet18 (Teacher) ==========
    print("\n" + "=" * 60)
    print("PHASE 1: Training ResNet18 (Teacher)")
    print("=" * 60)

    resnet_model = ResNet18(in_channels=1, num_classes=10)
    resnet_model, resnet_best_acc, resnet_loaded = load_model_if_exists(
        resnet_model, "ResNet18_Teacher", device, args.save_dir
    )

    metrics_key_resnet = "ResNet18_Teacher"
    resnet_metrics = load_metrics(args.metrics_dir, metrics_key_resnet) if args.metrics_dir.exists() else None

    if not resnet_loaded and not args.skip_resnet:
        resnet_history, resnet_best_acc = train_model(
            resnet_model,
            train_loader_resnet,
            test_loader_resnet,
            args.epochs,
            args.lr_resnet,
            device,
            "ResNet18_Teacher",
            args.save_dir,
        )
        resnet_metrics = {"history": resnet_history, "best_acc": resnet_best_acc}
    else:
        print("Skipping training (checkpoint loaded or --skip-resnet set)")
        if resnet_metrics is None:
            resnet_metrics = {"history": {"test_acc": [resnet_best_acc]}, "best_acc": resnet_best_acc}

    if args.save_metrics:
        save_metrics(args.metrics_dir, metrics_key_resnet, resnet_metrics["history"], resnet_metrics["best_acc"])

    resnet_history = resnet_metrics["history"]
    resnet_best_acc = resnet_metrics.get("best_acc", resnet_best_acc)
    metrics_data[metrics_key_resnet] = resnet_metrics

    # ========== Train ViT (Baseline) ==========
    print("\n" + "=" * 60)
    print("PHASE 2: Training Vision Transformer (Baseline)")
    print("=" * 60)

    vit_model = VisionTransformer(
        img_size=32,
        patch_size=4,
        in_channels=1,
        num_classes=10,
        embed_dim=384,
        num_heads=6,
        num_blocks=6,
        mlp_ratio=4,
        dropout=0.1,
    )

    vit_model, vit_best_acc, vit_loaded = load_model_if_exists(
        vit_model, "Vision_Transformer", device, args.save_dir
    )

    metrics_key_vit = "Vision_Transformer"
    vit_metrics = load_metrics(args.metrics_dir, metrics_key_vit) if args.metrics_dir.exists() else None

    if not vit_loaded and not args.skip_vit:
        vit_history, vit_best_acc = train_model(
            vit_model,
            train_loader_vit,
            test_loader_vit,
            args.epochs,
            args.lr_vit,
            device,
            "Vision_Transformer",
            args.save_dir,
        )
        vit_metrics = {"history": vit_history, "best_acc": vit_best_acc}
    else:
        print("Skipping training (checkpoint loaded or --skip-vit set)")
        if vit_metrics is None:
            vit_metrics = {"history": {"test_acc": [vit_best_acc]}, "best_acc": vit_best_acc}

    if args.save_metrics:
        save_metrics(args.metrics_dir, metrics_key_vit, vit_metrics["history"], vit_metrics["best_acc"])

    vit_history = vit_metrics["history"]
    vit_best_acc = vit_metrics.get("best_acc", vit_best_acc)
    metrics_data[metrics_key_vit] = vit_metrics

    # ========== Train DeiT with Distillation ==========
    print("\n" + "=" * 60)
    print("PHASE 3: Training DeiT with Knowledge Distillation")
    print("=" * 60)

    deit_model = DeiT(
        img_size=32,
        patch_size=4,
        in_channels=1,
        num_classes=10,
        embed_dim=384,
        num_heads=6,
        num_blocks=6,
        mlp_ratio=4,
        dropout=0.1,
    )

    deit_model, deit_best_acc, deit_loaded = load_model_if_exists(
        deit_model, "DeiT_Distilled", device, args.save_dir
    )

    metrics_key_deit = "DeiT_Distilled"
    deit_metrics = load_metrics(args.metrics_dir, metrics_key_deit) if args.metrics_dir.exists() else None

    if not deit_loaded and not args.skip_deit:
        if not resnet_loaded:
            teacher_ckpt = args.save_dir / "ResNet18_Teacher_best.pth"
            if teacher_ckpt.exists():
                checkpoint = torch.load(teacher_ckpt, map_location=device, weights_only=False)
                resnet_model.load_state_dict(checkpoint["model_state_dict"])
                print(f"Loaded teacher weights from {teacher_ckpt}")
            else:
                raise FileNotFoundError(
                    f"Teacher checkpoint {teacher_ckpt} not found. "
                    "Please train the ResNet teacher first or provide a checkpoint."
                )

        resnet_model = resnet_model.to(device)
        resnet_model.eval()

        deit_history, deit_best_acc = train_model(
            deit_model,
            train_loader_deit,
            test_loader_deit,
            args.epochs,
            args.lr_deit,
            device,
            "DeiT_Distilled",
            args.save_dir,
            use_augment=True,
            teacher=resnet_model,
            is_deit=True,
        )
        deit_metrics = {"history": deit_history, "best_acc": deit_best_acc}
    else:
        print("Skipping training (checkpoint loaded or --skip-deit set)")
        if deit_metrics is None:
            deit_metrics = {"history": {"test_acc": [deit_best_acc]}, "best_acc": deit_best_acc}

    if args.save_metrics:
        save_metrics(args.metrics_dir, metrics_key_deit, deit_metrics["history"], deit_metrics["best_acc"])

    deit_history = deit_metrics["history"]
    deit_best_acc = deit_metrics.get("best_acc", deit_best_acc)
    metrics_data[metrics_key_deit] = deit_metrics

    # ========== Final Comparison ==========
    print(f"\n{'=' * 60}")
    print("FINAL COMPARISON")
    print(f"{'=' * 60}")

    print(f"\nResNet18 (Teacher):")
    print(f"  Parameters: {count_parameters(resnet_model):,}")
    print(f"  Best Test Acc: {resnet_best_acc:.2f}%")
    if not (resnet_loaded or args.skip_resnet):
        print(f"  Final Test Acc: {resnet_history['test_acc'][-1]:.2f}%")

    print(f"\nVision Transformer (Baseline):")
    print(f"  Parameters: {count_parameters(vit_model):,}")
    print(f"  Best Test Acc: {vit_best_acc:.2f}%")
    if not (vit_loaded or args.skip_vit):
        print(f"  Final Test Acc: {vit_history['test_acc'][-1]:.2f}%")

    print(f"\nDeiT (Student with Distillation):")
    print(f"  Parameters: {count_parameters(deit_model):,}")
    print(f"  Best Test Acc: {deit_best_acc:.2f}%")
    if not (deit_loaded or args.skip_deit):
        print(f"  Final Test Acc: {deit_history['test_acc'][-1]:.2f}%")

    print(f"  Improvement over ViT: {deit_best_acc - vit_best_acc:+.2f} percentage points")

    print(f"\n{'=' * 60}")
    print(f"All models saved in '{args.save_dir}' directory")
    print(f"{'=' * 60}")

    if args.plot_results:
        generate_plots(args.plots_dir, metrics_data)
        print(f"Plots saved under '{args.plots_dir}'.")


if __name__ == "__main__":
    main()
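The `lr_lambda` passed to `LambdaLR` in `train_model` warms the learning rate up linearly over the first 5 epochs and then follows a cosine decay. A small standalone sketch that mirrors it (using `math.cos` instead of `np.cos`; `lr_multiplier` and `warmup_epochs` are illustrative names, not part of the script) makes the per-epoch values easy to inspect:

```python
import math


def lr_multiplier(epoch: int, num_epochs: int, warmup_epochs: int = 5) -> float:
    """Mirror of lr_lambda in train_model: linear warmup, then cosine decay."""
    if epoch < warmup_epochs:
        # Epoch indices 0..4 ramp the LR up as 0.2, 0.4, ..., 1.0 of the base value
        return (epoch + 1) / warmup_epochs
    # Fraction of the post-warmup phase completed (0.0 at the first cosine epoch)
    progress = (epoch - warmup_epochs) / max(num_epochs - warmup_epochs, 1)
    return 0.5 * (1 + math.cos(math.pi * progress))


if __name__ == "__main__":
    base_lr = 1e-3   # default --lr-vit
    num_epochs = 30  # default --epochs
    for epoch in range(num_epochs):
        lr = base_lr * lr_multiplier(epoch, num_epochs)
        print(f"epoch {epoch + 1:2d}: lr = {lr:.6f}")
```

Because `scheduler.step()` runs once per epoch, epoch index `e` (0-based) trains at `base_lr * lr_multiplier(e, num_epochs)`: the multiplier climbs 0.2, 0.4, ..., 1.0 during warmup, starts the cosine phase at 1.0, and decays toward (but does not reach) 0 on the last epoch.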

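The final comparison in `main` prints `count_parameters` for each model. For the default configuration those totals can be cross-checked by hand. The sketch below tallies learnable parameters layer by layer, assuming the architectures exactly as defined in this notebook (`embed_dim=384`, `num_blocks=6`, `mlp_ratio=4`, 4x4 patches on 32x32 single-channel inputs, and the ResNet18 with a 7x7 stem, `in_channels=1`, `num_classes=10`). The helper functions are hypothetical names introduced only for this illustration.

```python
def linear(n_in: int, n_out: int) -> int:
    """nn.Linear with bias: weight (n_out x n_in) plus bias (n_out)."""
    return n_in * n_out + n_out


def layer_norm(dim: int) -> int:
    """nn.LayerNorm with elementwise affine: weight plus bias."""
    return 2 * dim


def transformer_block(dim: int, mlp_ratio: int) -> int:
    """Pre-norm block: norm1, qkv, proj, norm2, fc1, fc2 (head count does not matter)."""
    hidden = dim * mlp_ratio
    attention = linear(dim, 3 * dim) + linear(dim, dim)
    mlp = linear(dim, hidden) + linear(hidden, dim)
    return 2 * layer_norm(dim) + attention + mlp


def vit_params(distilled: bool = False, img_size: int = 32, patch_size: int = 4,
               in_channels: int = 1, num_classes: int = 10, dim: int = 384,
               num_blocks: int = 6, mlp_ratio: int = 4) -> int:
    """VisionTransformer (distilled=False) or DeiT (distilled=True) parameter count."""
    n_patches = (img_size // patch_size) ** 2
    n_tokens = 2 if distilled else 1                   # CLS (+ distillation) token
    total = in_channels * dim * patch_size ** 2 + dim  # Conv2d patch projection
    total += n_tokens * dim                            # learned token(s)
    total += (n_patches + n_tokens) * dim              # positional embedding
    total += num_blocks * transformer_block(dim, mlp_ratio)
    total += layer_norm(dim)                           # final LayerNorm
    total += n_tokens * linear(dim, num_classes)       # head (+ head_dist)
    return total


def conv(c_in: int, c_out: int, k: int) -> int:
    """nn.Conv2d with bias=False."""
    return c_in * c_out * k * k


def batch_norm(c: int) -> int:
    """nn.BatchNorm2d weight plus bias (running stats are buffers, not parameters)."""
    return 2 * c


def residual_block(c_in: int, c_out: int, stride: int) -> int:
    total = conv(c_in, c_out, 3) + conv(c_out, c_out, 3) + 2 * batch_norm(c_out)
    if stride != 1 or c_in != c_out:
        total += conv(c_in, c_out, 1) + batch_norm(c_out)  # projection shortcut
    return total


def resnet18_params(in_channels: int = 1, num_classes: int = 10) -> int:
    total = conv(in_channels, 64, 7) + batch_norm(64)  # 7x7 stem
    c_in = 64
    for c_out, stride in [(64, 1), (128, 2), (256, 2), (512, 2)]:
        total += residual_block(c_in, c_out, stride) + residual_block(c_out, c_out, 1)
        c_in = c_out
    return total + linear(512, num_classes)


if __name__ == "__main__":
    print(f"ViT:      {vit_params():,}")
    print(f"DeiT:     {vit_params(distilled=True):,}")
    print(f"ResNet18: {resnet18_params():,}")
```

By this count the ViT baseline has 10,683,274 parameters, the DeiT student 10,687,892 (only 4,618 more: one distillation token, one extra positional-embedding row, and the second head), and the ResNet18 teacher 11,175,370. `num_heads` does not appear because splitting attention into heads does not change the size of the `qkv` and `proj` layers.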

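In `train_epoch_deit`, the soft distillation term is `F.kl_div(F.log_softmax(outputs_dist / T), F.softmax(teacher_outputs / T), reduction='batchmean') * T**2`. With log-probabilities as `input` and probabilities as `target`, `F.kl_div` computes KL(teacher || student), summed over classes and divided by the batch size under `'batchmean'`. The dependency-free sketch below reproduces that arithmetic on plain lists to make the formula concrete; `softmax` and `distillation_kl` are illustrative names, not functions from the notebook.

```python
import math
from typing import List


def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Temperature-softened softmax (max-subtracted for numerical stability)."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_kl(student_logits: List[List[float]],
                    teacher_logits: List[List[float]],
                    temperature: float = 3.0) -> float:
    """Batch-mean KL(teacher || student) on softened distributions, times T**2.

    Intended to mirror F.kl_div(F.log_softmax(s / T, dim=1), F.softmax(t / T, dim=1),
    reduction='batchmean') * T**2 as used in train_epoch_deit.
    """
    total_kl = 0.0
    for s_row, t_row in zip(student_logits, teacher_logits):
        q = softmax(s_row, temperature)  # student distribution (the `input`, in log space)
        p = softmax(t_row, temperature)  # teacher distribution (the `target`)
        # sum_c p_c * (log p_c - log q_c); terms with p_c == 0 contribute 0
        total_kl += sum(p_c * (math.log(p_c) - math.log(q_c))
                        for p_c, q_c in zip(p, q) if p_c > 0)
    # 'batchmean' divides the summed KL by the batch size
    return total_kl / len(student_logits) * temperature ** 2


if __name__ == "__main__":
    teacher = [[4.0, 1.0, 0.0], [0.5, 2.5, 0.0]]
    print(distillation_kl(teacher, teacher))                # identical logits: 0.0
    print(distillation_kl([[0.0, 0.0, 0.0]] * 2, teacher))  # uniform student: positive
```

The `T**2` factor is the usual knowledge-distillation convention: softening by `T` shrinks the soft-target gradients by roughly `1/T**2`, so multiplying back keeps this term on a scale comparable to the hard-label losses as `temperature` changes.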
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import time
import numpy as np
import os

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Create directory for saved models
os.makedirs('saved_models', exist_ok=True)

# ==================== Data Augmentation ====================

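class RandAugment:
    """Simplified RandAugment-style policy for FashionMNIST.

    Applies `n` randomly chosen geometric ops per image. The magnitude `m` is kept
    in the signature but is not used: op strengths are fixed below. (torchvision
    also provides a full implementation as `transforms.RandAugment`.)
    """
    def __init__(self, n=2, m=9):
        self.n = n  # Number of augmentations applied per image
        self.m = m  # Magnitude (currently unused)
        self.ops = [
            transforms.RandomRotation(30),
            transforms.RandomAffine(0, translate=(0.1, 0.1)),
            transforms.RandomAffine(0, shear=15),
        ]

    def __call__(self, img):
        for _ in range(self.n):
            # Pick an op by index (sampling with replacement)
            op = self.ops[np.random.randint(len(self.ops))]
            img = op(img)

        return img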

class RandomErasing:
    """Random erasing augmentation."""
    def __init__(self, p=0.5, scale=(0.02, 0.33)):
        self.p = p
        self.scale = scale

    def __call__(self, img):
        if np.random.rand() > self.p:
            return img

        h, w = img.shape[-2:]
        area = h * w
        target_area = np.random.uniform(*self.scale) * area
        aspect_ratio = np.random.uniform(0.3, 3.3)

        h_erase = int(round(np.sqrt(target_area * aspect_ratio)))
        w_erase = int(round(np.sqrt(target_area / aspect_ratio)))

        if h_erase < h and w_erase < w:
            i = np.random.randint(0, h - h_erase)
            j = np.random.randint(0, w - w_erase)
            img[:, i:i+h_erase, j:j+w_erase] = 0

        return img

def mixup_data(x, y, alpha=0.2):
    """Mixup augmentation."""
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]

    return mixed_x, y_a, y_b, lam

def cutmix_data(x, y, alpha=1.0):
    """CutMix augmentation."""
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(x.device)

    # Get random box (x is laid out as N x C x H x W)
    H = x.size(2)
    W = x.size(3)
    cut_rat = np.sqrt(1. - lam)
    cut_h = int(H * cut_rat)
    cut_w = int(W * cut_rat)

    # Random box center: row first, then column
    cy = np.random.randint(H)
    cx = np.random.randint(W)

    top = np.clip(cy - cut_h // 2, 0, H)
    bottom = np.clip(cy + cut_h // 2, 0, H)
    left = np.clip(cx - cut_w // 2, 0, W)
    right = np.clip(cx + cut_w // 2, 0, W)

    # Paste the same region from a shuffled copy of the batch
    x[:, :, top:bottom, left:right] = x[index, :, top:bottom, left:right]

    # Adjust lambda to the exact pasted-area ratio (the box may be clipped at the border)
    lam = 1 - ((bottom - top) * (right - left) / (H * W))
    y_a, y_b = y, y[index]

    return x, y_a, y_b, lam

# ==================== Vision Transformer Implementation ====================

class PatchEmbedding(nn.Module):
    """Split image into patches and embed them."""
    def __init__(self, img_size=32, patch_size=4, in_channels=1, embed_dim=384):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2

        self.proj = nn.Conv2d(in_channels, embed_dim,
                             kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)
        x = x.flatten(2)
        x = x.transpose(1, 2)
        return x

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention mechanism."""
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        assert embed_dim % num_heads == 0

        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, N, C = x.shape

        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        scale = self.head_dim ** -0.5
        attn = (q @ k.transpose(-2, -1)) * scale
        attn = attn.softmax(dim=-1)
        attn = self.dropout(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.dropout(x)

        return x

class MLP(nn.Module):
    """MLP block with GELU activation."""
    def __init__(self, in_features, hidden_features, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

class TransformerBlock(nn.Module):
    """Transformer encoder block with pre-normalization."""
    def __init__(self, embed_dim, num_heads, mlp_ratio=4, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, int(embed_dim * mlp_ratio), dropout)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class VisionTransformer(nn.Module):
    """Vision Transformer for image classification."""
    def __init__(self, img_size=32, patch_size=4, in_channels=1,
                 num_classes=10, embed_dim=384, num_heads=6,
                 num_blocks=6, mlp_ratio=4, dropout=0.1):
        super().__init__()

        self.patch_embed = PatchEmbedding(img_size, patch_size,
                                         in_channels, embed_dim)
        n_patches = self.patch_embed.n_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
        self.dropout = nn.Dropout(dropout)

        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(num_blocks)
        ])

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        B = x.shape[0]

        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = x + self.pos_embed
        x = self.dropout(x)

        for block in self.blocks:
            x = block(x)

        x = self.norm(x[:, 0])
        x = self.head(x)

        return x

# ==================== DeiT Implementation ====================

class DeiT(nn.Module):
    """Data-efficient Image Transformer with distillation token."""
    def __init__(self, img_size=32, patch_size=4, in_channels=1,
                 num_classes=10, embed_dim=384, num_heads=6,
                 num_blocks=6, mlp_ratio=4, dropout=0.1):
        super().__init__()

        self.patch_embed = PatchEmbedding(img_size, patch_size,
                                         in_channels, embed_dim)
        n_patches = self.patch_embed.n_patches

        # CLS and distillation tokens
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 2, embed_dim))
        self.dropout = nn.Dropout(dropout)

        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(num_blocks)
        ])

        self.norm = nn.LayerNorm(embed_dim)

        # Two heads: one for cls token, one for dist token
        self.head = nn.Linear(embed_dim, num_classes)
        self.head_dist = nn.Linear(embed_dim, num_classes)

        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.dist_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        B = x.shape[0]

        x = self.patch_embed(x)

        # Add both cls and dist tokens
        cls_tokens = self.cls_token.expand(B, -1, -1)
        dist_tokens = self.dist_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, dist_tokens, x], dim=1)

        x = x + self.pos_embed
        x = self.dropout(x)

        for block in self.blocks:
            x = block(x)

        x = self.norm(x)

        # Separate outputs from cls and dist tokens
        x_cls = self.head(x[:, 0])
        x_dist = self.head_dist(x[:, 1])

        return x_cls, x_dist

# ==================== ResNet18 Implementation ====================

class ResidualBlock(nn.Module):
    """Basic residual block for ResNet."""
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                              kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels,
                              kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels,
                         kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = self.shortcut(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out

class ResNet18(nn.Module):
    """ResNet18 architecture."""
    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels, 64,
                              kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(64, 64, 2, stride=1)
        self.layer2 = self._make_layer(64, 128, 2, stride=2)
        self.layer3 = self._make_layer(128, 256, 2, stride=2)
        self.layer4 = self._make_layer(256, 512, 2, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

        self.apply(self._init_weights)

    def _make_layer(self, in_channels, out_channels, num_blocks, stride):
        layers = []
        layers.append(ResidualBlock(in_channels, out_channels, stride))
        for _ in range(1, num_blocks):
            layers.append(ResidualBlock(out_channels, out_channels, stride=1))
        return nn.Sequential(*layers)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.01)
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

# ==================== Data Loading ====================

def get_dataloaders(batch_size=128, use_augmentation=False):
    """Create data loaders with optional heavy augmentation."""

    if use_augmentation:
        # Heavy augmentation for DeiT
        transform_train = transforms.Compose([
            transforms.Resize(32),
            transforms.RandomHorizontalFlip(),
            RandAugment(n=2, m=9),
            transforms.ToTensor(),
            RandomErasing(p=0.25),
            transforms.Normalize((0.5,), (0.5,))
        ])
    else:
        # Standard transform for ViT
        transform_train = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

    transform_test = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    # Transform for ResNet (96x96)
    transform_resnet_train = transforms.Compose([
        transforms.Resize(96),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    transform_resnet_test = transforms.Compose([
        transforms.Resize(96),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    train_dataset = datasets.FashionMNIST(
        root='./data', train=True, download=True, transform=transform_train)
    test_dataset = datasets.FashionMNIST(
        root='./data', train=False, download=True, transform=transform_test)

    train_dataset_resnet = datasets.FashionMNIST(
        root='./data', train=True, download=True, transform=transform_resnet_train)
    test_dataset_resnet = datasets.FashionMNIST(
        root='./data', train=False, download=True, transform=transform_resnet_test)

    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                             shuffle=True, num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size,
                            shuffle=False, num_workers=2, pin_memory=True)

    train_loader_resnet = DataLoader(train_dataset_resnet, batch_size=batch_size,
                                    shuffle=True, num_workers=2, pin_memory=True)
    test_loader_resnet = DataLoader(test_dataset_resnet, batch_size=batch_size,
                                   shuffle=False, num_workers=2, pin_memory=True)

    return train_loader, test_loader, train_loader_resnet, test_loader_resnet

# ==================== Training Functions ====================

def count_parameters(model):
    """Count the number of trainable parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def save_model(model, path, epoch, optimizer, acc):
    """Save model checkpoint."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'accuracy': acc,
    }, path)
    print(f"Model saved to {path}")

def load_model_if_exists(model, model_name, device):
    """Load model if checkpoint exists."""
    final_path = f'saved_models/{model_name}_final.pth'
    best_path = f'saved_models/{model_name}_best.pth'

    if os.path.exists(final_path):
        print(f"Loading existing model from {final_path}")
        checkpoint = torch.load(final_path, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])
        model = model.to(device)
        print(f"Loaded model - Epoch: {checkpoint['epoch']}, Accuracy: {checkpoint['accuracy']:.2f}%")
        return model, checkpoint['accuracy'], True

    return model, 0.0, False

def train_epoch(model, loader, criterion, optimizer, device, use_augment=False):
    """Train for one epoch."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()

        # With probability 0.5, apply mixup or cutmix (each equally likely)
        if use_augment and np.random.rand() < 0.5:
            mix_fn = mixup_data if np.random.rand() < 0.5 else cutmix_data
            inputs, targets_a, targets_b, lam = mix_fn(inputs, targets)
            outputs = model(inputs)
            # Mixed-label loss: blend the cross-entropy against both label sets by lam
            loss = lam * criterion(outputs, targets_a) + (1 - lam) * criterion(outputs, targets_b)
        else:
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    epoch_loss = running_loss / total
    epoch_acc = 100. * correct / total

    return epoch_loss, epoch_acc

def train_epoch_deit(model, loader, teacher, criterion, optimizer, device,
                     alpha=0.5, beta=0.5, temperature=3.0, use_augment=True):
    """Train DeiT for one epoch with distillation."""
    model.train()
    teacher.eval()

    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)

        # With probability 0.5, mix the batch using mixup or CutMix (chosen 50/50)
        mixed = use_augment and np.random.rand() < 0.5
        if mixed:
            mix_fn = mixup_data if np.random.rand() < 0.5 else cutmix_data
            inputs, targets_a, targets_b, lam = mix_fn(inputs, targets)

        # The teacher sees the same (possibly mixed) images, resized to 96x96
        inputs_teacher = F.interpolate(inputs, size=(96, 96), mode='bilinear', align_corners=False)
        with torch.no_grad():
            teacher_outputs = teacher(inputs_teacher)

        optimizer.zero_grad()
        outputs_cls, outputs_dist = model(inputs)

        # Ground-truth (hard-label) losses on both the class head and the distillation head
        if mixed:
            loss_cls = lam * criterion(outputs_cls, targets_a) + (1 - lam) * criterion(outputs_cls, targets_b)
            loss_dist_hard = lam * criterion(outputs_dist, targets_a) + (1 - lam) * criterion(outputs_dist, targets_b)
        else:
            loss_cls = criterion(outputs_cls, targets)
            loss_dist_hard = criterion(outputs_dist, targets)

        # Soft-label distillation: KL divergence between the temperature-softened teacher and
        # student (distillation head) distributions, scaled by T^2 so that gradient magnitudes
        # stay comparable to the hard-label terms
        loss_kl = F.kl_div(
            F.log_softmax(outputs_dist / temperature, dim=1),
            F.softmax(teacher_outputs / temperature, dim=1),
            reduction='batchmean'
        ) * (temperature ** 2)

        loss = (1 - alpha) * loss_cls + alpha * loss_dist_hard + beta * loss_kl

        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        _, predicted = outputs_cls.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    epoch_loss = running_loss / total
    epoch_acc = 100. * correct / total

    return epoch_loss, epoch_acc
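# Sketch (illustrative, not the original code): the soft-distillation term and the loss
# weighting used in train_epoch_deit(), isolated so their properties are easy to check.
# Both logit vectors are softened by dividing by the temperature T. The KL divergence from
# teacher to student is then multiplied by T**2, which keeps its gradients on roughly the
# same scale as the cross-entropy terms as T changes. The function names are hypothetical.
import torch
import torch.nn.functional as F


def soft_distillation_loss(student_logits, teacher_logits, temperature=3.0):
    """KL(teacher || student) on temperature-softened distributions, scaled by T**2."""
    log_p_student = F.log_softmax(student_logits / temperature, dim=1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=1)
    return F.kl_div(log_p_student, p_teacher, reduction='batchmean') * temperature ** 2


def deit_total_loss(loss_cls, loss_dist_hard, loss_kl, alpha=0.5, beta=0.5):
    """Weighted sum used in train_epoch_deit(): (1-alpha)*cls + alpha*dist_hard + beta*kl."""
    return (1 - alpha) * loss_cls + alpha * loss_dist_hard + beta * loss_kl
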

def evaluate(model, loader, criterion, device, is_deit=False):
    """Evaluate the model."""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)

            if is_deit:
                outputs_cls, outputs_dist = model(inputs)
                # Average predictions from both tokens
                outputs = (outputs_cls + outputs_dist) / 2
            else:
                outputs = model(inputs)

            loss = criterion(outputs, targets)

            running_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    epoch_loss = running_loss / total
    epoch_acc = 100. * correct / total

    return epoch_loss, epoch_acc
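# Sketch (illustrative): at test time, evaluate() fuses DeiT's two heads by averaging the
# class-token and distillation-token logits before taking the argmax. Averaging softmax
# probabilities instead is a common alternative and can occasionally change the argmax.
# Both helper names are hypothetical.
import torch


def fuse_deit_heads(outputs_cls, outputs_dist):
    """Average the class-token and distillation-token logits, as evaluate() does."""
    return (outputs_cls + outputs_dist) / 2


def count_correct(logits, targets):
    """Number of samples whose argmax prediction matches the target label."""
    return logits.argmax(dim=1).eq(targets).sum().item()
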

def train_model(model, train_loader, test_loader, num_epochs,
                lr, device, model_name, use_augment=False,
                teacher=None, is_deit=False):
    """Train and evaluate a model."""
    print(f"\n{'='*60}")
    print(f"Training {model_name}")
    print(f"{'='*60}")
    print(f"Parameters: {count_parameters(model):,}")

    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)

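    # Learning-rate schedule: linear warmup over the first 5 epochs, then cosine decay
    # toward 0 over the remaining epochs (max(1, ...) avoids division by zero when
    # num_epochs <= 5)
    warmup_epochs = 5

    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            return (epoch + 1) / warmup_epochs
        decay_epochs = max(1, num_epochs - warmup_epochs)
        return 0.5 * (1 + np.cos(np.pi * (epoch - warmup_epochs) / decay_epochs))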

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    best_acc = 0.0
    history = {'train_loss': [], 'train_acc': [],
               'test_loss': [], 'test_acc': []}

    start_time = time.time()

    for epoch in range(num_epochs):
        epoch_start = time.time()

        if is_deit and teacher is not None:
            train_loss, train_acc = train_epoch_deit(
                model, train_loader, teacher, criterion,
                optimizer, device, use_augment=use_augment
            )
        else:
            train_loss, train_acc = train_epoch(
                model, train_loader, criterion, optimizer,
                device, use_augment=use_augment
            )

        test_loss, test_acc = evaluate(model, test_loader, criterion,
                                      device, is_deit=is_deit)

        # Record the LR used during this epoch before the scheduler advances it
        current_lr = optimizer.param_groups[0]['lr']
        scheduler.step()

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['test_loss'].append(test_loss)
        history['test_acc'].append(test_acc)

        epoch_time = time.time() - epoch_start

        print(f"Epoch [{epoch+1}/{num_epochs}] "
              f"Train Loss: {train_loss:.4f} Train Acc: {train_acc:.2f}% "
              f"Test Loss: {test_loss:.4f} Test Acc: {test_acc:.2f}% "
              f"Time: {epoch_time:.2f}s LR: {current_lr:.6f}")

        if test_acc > best_acc:
            best_acc = test_acc
            # Save best model
            save_path = f'saved_models/{model_name.replace(" ", "_")}_best.pth'
            save_model(model, save_path, epoch + 1, optimizer, best_acc)

    total_time = time.time() - start_time

    # Save final model
    save_path = f'saved_models/{model_name.replace(" ", "_")}_final.pth'
    save_model(model, save_path, num_epochs, optimizer, test_acc)

    print(f"\nTraining completed in {total_time:.2f}s")
    print(f"Best test accuracy: {best_acc:.2f}%")
    print(f"Final test accuracy: {history['test_acc'][-1]:.2f}%")

    return history, best_acc
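# Sketch (illustrative): the warmup + cosine schedule from train_model() as a standalone
# multiplier, wired into torch.optim.lr_scheduler.LambdaLR so that the per-epoch learning
# rates can be inspected without training. The function names and the dummy parameter are
# hypothetical. warmup_epochs=5 mirrors the value hard-coded in train_model().
import numpy as np
import torch


def warmup_cosine_factor(epoch, num_epochs, warmup_epochs=5):
    """LR multiplier: linear warmup up to 1.0, then cosine decay toward 0."""
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs
    decay_epochs = max(1, num_epochs - warmup_epochs)
    return 0.5 * (1 + np.cos(np.pi * (epoch - warmup_epochs) / decay_epochs))


def lr_trace(base_lr, num_epochs, warmup_epochs=5):
    """Learning rate the optimizer uses in each epoch under LambdaLR."""
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.AdamW([param], lr=base_lr)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda e: warmup_cosine_factor(e, num_epochs, warmup_epochs))
    lrs = []
    for _ in range(num_epochs):
        lrs.append(optimizer.param_groups[0]['lr'])
        optimizer.step()  # optimizer before scheduler, as PyTorch expects
        scheduler.step()
    return lrs
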

# ==================== Main Execution ====================

if __name__ == '__main__':
    # Hyperparameters
    batch_size = 128
    num_epochs = 30
    lr_vit = 0.001
    lr_resnet = 0.01
    lr_deit = 0.001

    # Get data loaders
    print("Loading FashionMNIST dataset...")
    train_loader_vit, test_loader_vit, train_loader_resnet, test_loader_resnet = \
        get_dataloaders(batch_size, use_augmentation=False)

    # Get augmented loaders for DeiT
    train_loader_deit, test_loader_deit, _, _ = \
        get_dataloaders(batch_size, use_augmentation=True)

    # ========== Train ResNet18 (Teacher) ==========
    print("\n" + "="*60)
    print("PHASE 1: Training ResNet18 (Teacher)")
    print("="*60)

    resnet_model = ResNet18(in_channels=1, num_classes=10)
    resnet_model, resnet_best_acc, resnet_loaded = load_model_if_exists(
        resnet_model, "ResNet18_Teacher", device
    )

    if not resnet_loaded:
        resnet_history, resnet_best_acc = train_model(
            resnet_model, train_loader_resnet, test_loader_resnet,
            num_epochs, lr_resnet, device, "ResNet18_Teacher"
        )
    else:
        print("Skipping training - using loaded model")
        resnet_history = {'test_acc': [resnet_best_acc]}

    # ========== Train ViT (Baseline) ==========
    print("\n" + "="*60)
    print("PHASE 2: Training Vision Transformer (Baseline)")
    print("="*60)

    vit_model = VisionTransformer(
        img_size=32,
        patch_size=4,
        in_channels=1,
        num_classes=10,
        embed_dim=384,
        num_heads=6,
        num_blocks=6,
        mlp_ratio=4,
        dropout=0.1
    )

    vit_model, vit_best_acc, vit_loaded = load_model_if_exists(
        vit_model, "Vision_Transformer", device
    )

    if not vit_loaded:
        vit_history, vit_best_acc = train_model(
            vit_model, train_loader_vit, test_loader_vit,
            num_epochs, lr_vit, device, "Vision_Transformer"
        )
    else:
        print("Skipping training - using loaded model")
        vit_history = {'test_acc': [vit_best_acc]}

    # ========== Train DeiT with Distillation ==========
    print("\n" + "="*60)
    print("PHASE 3: Training DeiT with Knowledge Distillation")
    print("="*60)

    deit_model = DeiT(
        img_size=32,
        patch_size=4,
        in_channels=1,
        num_classes=10,
        embed_dim=384,
        num_heads=6,
        num_blocks=6,
        mlp_ratio=4,
        dropout=0.1
    )

    deit_model, deit_best_acc, deit_loaded = load_model_if_exists(
        deit_model, "DeiT_Distilled", device
    )

    if not deit_loaded:
        # Load best ResNet as teacher
        if not resnet_loaded:
            checkpoint = torch.load('saved_models/ResNet18_Teacher_best.pth',
                                   map_location=device, weights_only=False)
            resnet_model.load_state_dict(checkpoint['model_state_dict'])

        resnet_model = resnet_model.to(device)
        resnet_model.eval()

        deit_history, deit_best_acc = train_model(
            deit_model, train_loader_deit, test_loader_deit,
            num_epochs, lr_deit, device, "DeiT_Distilled",
            use_augment=True, teacher=resnet_model, is_deit=True
        )
    else:
        print("Skipping training - using loaded model")
        deit_history = {'test_acc': [deit_best_acc]}

    # ========== Final Comparison ==========
    print(f"\n{'='*60}")
    print("FINAL COMPARISON")
    print(f"{'='*60}")
    print(f"\nResNet18 (Teacher):")
    print(f"  Parameters: {count_parameters(resnet_model):,}")
    print(f"  Best Test Acc: {resnet_best_acc:.2f}%")
    if not resnet_loaded:
        print(f"  Final Test Acc: {resnet_history['test_acc'][-1]:.2f}%")

    print(f"\nVision Transformer (Baseline):")
    print(f"  Parameters: {count_parameters(vit_model):,}")
    print(f"  Best Test Acc: {vit_best_acc:.2f}%")
    if not vit_loaded:
        print(f"  Final Test Acc: {vit_history['test_acc'][-1]:.2f}%")

    print(f"\nDeiT (Student with Distillation):")
    print(f"  Parameters: {count_parameters(deit_model):,}")
    print(f"  Best Test Acc: {deit_best_acc:.2f}%")
    if not deit_loaded:
        print(f"  Final Test Acc: {deit_history['test_acc'][-1]:.2f}%")
    print(f"  Improvement over ViT: {deit_best_acc - vit_best_acc:+.2f} percentage points")

    print(f"\n{'='*60}")
    print("All models saved in 'saved_models/' directory")
    print(f"{'='*60}")
