In [None]:
# ResNet-18 + ViT hybrid model
import os
import math
import random
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models


# Hybrid ResNet-18 + ViT Model

class HybridResNetViT(nn.Module):
    """Pretrained (torchvision) ResNet-18 trunk followed by a Transformer encoder.

    The ResNet-18 layers up to ``layer4`` (avgpool and fc removed) produce a
    512-channel feature map. A 1x1 conv projects it to ``embed_dim`` and each
    spatial position becomes a token. A learnable CLS token is prepended, a
    learned positional embedding is added, and the CLS output of a pre-norm
    Transformer encoder is classified.

    Args:
        num_classes: number of output logits.
        embed_dim: token width (``nn.MultiheadAttention`` requires it to be
            divisible by ``num_heads``).
        depth: number of Transformer encoder layers.
        num_heads: attention heads per layer.
        img_size: expected input resolution, an int or ``(H, W)``. It is used
            to size ``pos_embed`` at construction time. Inputs of another size
            still work: the positional embedding is bilinearly resized to the
            new feature grid.
    """

    def __init__(self, num_classes=10, embed_dim=256, depth=4, num_heads=4, img_size=32):
        super().__init__()

        resnet = models.resnet18(pretrained=True)
        # Keep everything up to layer4; drop the global avgpool and fc head.
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])
        self.conv_proj = nn.Conv2d(512, embed_dim, kernel_size=1)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Register pos_embed here rather than lazily in forward(). A parameter
        # created after the optimizer has been built is never updated by it,
        # and a lazily created one is missing from a fresh model's state_dict,
        # so saved checkpoints could not be strictly reloaded.
        self.grid_size = self._feature_grid(img_size)
        num_tokens = 1 + self.grid_size[0] * self.grid_size[1]
        self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, dim_feedforward=embed_dim*4,
            dropout=0.1, activation="gelu", batch_first=True, norm_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)

        self.norm = nn.LayerNorm(embed_dim)
        self.fc = nn.Linear(embed_dim, num_classes)

    @staticmethod
    def _feature_grid(img_size):
        """Return the (H, W) of ResNet-18's layer4 output for an ``img_size`` input."""
        h, w = (img_size, img_size) if isinstance(img_size, int) else img_size
        # conv1, maxpool and the first block of layer2/3/4 each use stride 2
        # with padding that maps a side length n -> (n - 1) // 2 + 1.
        for _ in range(5):
            h, w = (h - 1) // 2 + 1, (w - 1) // 2 + 1
        return h, w

    def _pos_embed_for(self, grid_h, grid_w):
        """Return ``pos_embed`` matching a ``grid_h`` x ``grid_w`` token grid.

        The CLS entry is kept as is. The patch entries are bilinearly resized
        when the grid differs from the one computed at construction time.
        """
        if (grid_h, grid_w) == tuple(self.grid_size):
            return self.pos_embed
        base_h, base_w = self.grid_size
        cls_pos, patch_pos = self.pos_embed[:, :1], self.pos_embed[:, 1:]
        patch_pos = patch_pos.reshape(1, base_h, base_w, -1).permute(0, 3, 1, 2)
        patch_pos = nn.functional.interpolate(
            patch_pos, size=(grid_h, grid_w), mode="bilinear", align_corners=False)
        patch_pos = patch_pos.flatten(2).transpose(1, 2)
        return torch.cat((cls_pos, patch_pos), dim=1)

    def forward(self, x):
        """Return class logits ``(B, num_classes)`` for RGB images ``x`` of shape ``(B, 3, H, W)``."""
        B = x.size(0)
        feats = self.conv_proj(self.backbone(x))        # (B, embed_dim, h, w)
        grid_h, grid_w = feats.shape[-2:]
        tokens = feats.flatten(2).transpose(1, 2)       # (B, h*w, embed_dim)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        tokens = torch.cat((cls_tokens, tokens), dim=1)
        tokens = tokens + self._pos_embed_for(grid_h, grid_w)

        out = self.transformer(tokens)
        cls_out = self.norm(out[:, 0])
        return self.fc(cls_out)


# Dataloaders with Augmentations

def get_dataloaders(batch_size=64, img_size=32, num_workers=2):
    """Build the CIFAR-10 training and test DataLoaders.

    Training images are randomly cropped (4 px padding), horizontally flipped,
    colour-jittered, normalized and randomly erased. Test images are only
    converted to tensors and normalized. Both splits are downloaded to
    ``./data`` when missing.

    Returns:
        ``(trainloader, testloader)``. Only the training loader shuffles, and
        both use pinned memory.
    """
    # Per-channel CIFAR-10 mean and std used for normalization.
    cifar_mean = (0.4914, 0.4822, 0.4465)
    cifar_std = (0.2470, 0.2435, 0.2616)

    train_pipeline = transforms.Compose([
        transforms.RandomCrop(img_size, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
        transforms.RandomErasing(p=0.25),
    ])
    eval_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
    ])

    train_split = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_pipeline)
    test_split = datasets.CIFAR10(root='./data', train=False, download=True, transform=eval_pipeline)

    def make_loader(dataset, shuffle):
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                          num_workers=num_workers, pin_memory=True)

    return make_loader(train_split, True), make_loader(test_split, False)


# Training + Evaluation Helpers

def rand_bbox(W, H, lam):
    """Sample a random CutMix box inside a ``W`` x ``H`` image.

    Each box side is ``sqrt(1 - lam)`` times the image side. The box is centred
    on a uniformly random pixel and clipped to the image borders, so its actual
    area can be smaller than ``(1 - lam) * W * H``.

    Returns:
        ``(x1, y1, x2, y2)`` with ``0 <= x1 <= x2 <= W`` and ``0 <= y1 <= y2 <= H``.
    """
    ratio = math.sqrt(1. - lam)
    half_w = int(W * ratio) // 2
    half_h = int(H * ratio) // 2
    # The x centre is drawn before the y centre (keeps the random stream order).
    centre_x = random.randint(0, W - 1)
    centre_y = random.randint(0, H - 1)
    x1, x2 = max(0, centre_x - half_w), min(W, centre_x + half_w)
    y1, y2 = max(0, centre_y - half_h), min(H, centre_y + half_h)
    return x1, y1, x2, y2

def evaluate(model, testloader, device):
    """Return the top-1 accuracy of ``model`` on ``testloader``, in percent.

    The model is switched to eval mode (and left there). Inference runs without
    gradients and, when ``device`` is CUDA, under autocast to match training.

    Raises:
        ValueError: if ``testloader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    use_amp = device.type == 'cuda'
    with torch.no_grad():
        for images, labels in testloader:
            images = images.to(device)
            labels = labels.to(device)
            # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
            with torch.autocast(device_type="cuda", enabled=use_amp):
                outputs = model(images)
            pred = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += pred.eq(labels).sum().item()
    if total == 0:
        raise ValueError("evaluate(): testloader yielded no samples")
    return 100.0 * correct / total


# Main Training Loop

def train(model, trainloader, testloader, device,
          epochs=60, lr=3e-4, weight_decay=0.02,
          accumulation_steps=2, warmup_epochs=3,
          mixup_alpha=0.8, cutmix_prob=0.5, label_smoothing=0.05):
    """Train ``model`` and evaluate it on ``testloader`` after every epoch.

    Setup: AdamW, a linear-warmup + cosine LR schedule, gradient accumulation
    over ``accumulation_steps`` batches, and mixed precision on CUDA. When
    ``mixup_alpha > 0``, each batch gets CutMix (probability ``cutmix_prob``)
    or otherwise Mixup.

    ``label_smoothing`` only applies on the un-mixed path (``mixup_alpha <= 0``).
    The mixed-target loss uses plain cross-entropy.

    The weights with the best test accuracy so far are saved to
    ``checkpoints/hybrid_resnet_vit_best.pth``, and the last-epoch weights to
    ``checkpoints/hybrid_resnet_vit_final.pth``. "Best" is selected on the test
    set itself; there is no separate validation split.
    """
    use_amp = device.type == 'cuda'
    criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    try:
        # torch.amp.GradScaler replaces the deprecated torch.cuda.amp.GradScaler.
        scaler = torch.amp.GradScaler("cuda", enabled=use_amp)
    except AttributeError:  # torch < 2.3
        scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    # scheduler.step() runs once per *optimizer* step, i.e. every
    # accumulation_steps batches plus once for a trailing partial group.
    # The schedule length must therefore be counted in optimizer steps, not
    # batches; otherwise warmup is stretched and the cosine never reaches zero.
    num_batches = len(trainloader)
    steps_per_epoch = math.ceil(num_batches / accumulation_steps)
    total_steps = steps_per_epoch * epochs
    warmup_steps = steps_per_epoch * warmup_epochs

    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    best_acc = 0.0
    os.makedirs("checkpoints", exist_ok=True)

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        optimizer.zero_grad()
        for i, (images, labels) in enumerate(tqdm(trainloader, desc=f"Epoch {epoch+1}/{epochs}")):
            images = images.to(device)
            labels = labels.to(device)

            # Mixup / CutMix: blend each sample with a random partner from the batch.
            target_b = None
            lam = 1.0
            if mixup_alpha > 0:
                lam = np.random.beta(mixup_alpha, mixup_alpha)
                index = torch.randperm(images.size(0)).to(device)
                if random.random() < cutmix_prob:
                    bbx1, bby1, bbx2, bby2 = rand_bbox(images.size(3), images.size(2), lam)
                    images[:, :, bby1:bby2, bbx1:bbx2] = images[index, :, bby1:bby2, bbx1:bbx2]
                    # The box may be clipped at the border, so re-derive lam from the pasted area.
                    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (images.size(2) * images.size(3)))
                else:
                    images = images * lam + images[index] * (1 - lam)
                target_b = labels[index]

            with torch.autocast(device_type="cuda", enabled=use_amp):
                outputs = model(images)
                if target_b is not None:
                    loss = lam * nn.functional.cross_entropy(outputs, labels, label_smoothing=0.0) + \
                           (1 - lam) * nn.functional.cross_entropy(outputs, target_b, label_smoothing=0.0)
                else:
                    loss = criterion(outputs, labels)
                # Average gradients over the accumulation group. A trailing
                # partial group at the end of an epoch is still divided by the
                # full accumulation_steps, so its update is slightly down-weighted.
                loss = loss / accumulation_steps

            scaler.scale(loss).backward()

            if (i + 1) % accumulation_steps == 0 or (i + 1) == num_batches:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()

            running_loss += loss.item() * accumulation_steps

        avg_loss = running_loss / num_batches
        acc = evaluate(model, testloader, device)
        print(f"[Epoch {epoch+1:02d}/{epochs}] Loss: {avg_loss:.4f}  Test Acc: {acc:.2f}%")

        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), "checkpoints/hybrid_resnet_vit_best.pth")

    torch.save(model.state_dict(), "checkpoints/hybrid_resnet_vit_final.pth")
    print(f"Training done. Best Acc: {best_acc:.2f}%")


# Entry Point

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    # Run configuration for the hybrid ResNet-18 + ViT experiment.
    data_cfg = dict(batch_size=64, img_size=32, num_workers=2)
    model_cfg = dict(num_classes=10, embed_dim=256, depth=4, num_heads=4)
    train_cfg = dict(
        epochs=60,
        lr=3e-4,
        weight_decay=0.02,
        accumulation_steps=2,
        warmup_epochs=5,
        mixup_alpha=0.8,
        cutmix_prob=0.5,
        label_smoothing=0.05,
    )

    trainloader, testloader = get_dataloaders(**data_cfg)
    model = HybridResNetViT(**model_cfg).to(device)
    train(model, trainloader, testloader, device, **train_cfg)


Using device: cuda


100%|██████████| 170M/170M [00:03<00:00, 43.6MB/s]


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 179MB/s]
  scaler = torch.cuda.amp.GradScaler(enabled=(device.type=='cuda'))
  with torch.cuda.amp.autocast(enabled=(device.type=='cuda')):
Epoch 1/60: 100%|██████████| 782/782 [01:05<00:00, 12.01it/s]
  with torch.cuda.amp.autocast(enabled=(device.type=='cuda')):


[Epoch 01/60] Loss: 2.1996  Test Acc: 44.26%


Epoch 2/60: 100%|██████████| 782/782 [01:00<00:00, 12.91it/s]


[Epoch 02/60] Loss: 1.8674  Test Acc: 60.83%


Epoch 3/60: 100%|██████████| 782/782 [01:01<00:00, 12.74it/s]


[Epoch 03/60] Loss: 1.7153  Test Acc: 68.56%


Epoch 4/60: 100%|██████████| 782/782 [00:59<00:00, 13.06it/s]


[Epoch 04/60] Loss: 1.6391  Test Acc: 72.00%


Epoch 5/60: 100%|██████████| 782/782 [00:58<00:00, 13.30it/s]


[Epoch 05/60] Loss: 1.5864  Test Acc: 75.53%


Epoch 6/60: 100%|██████████| 782/782 [00:59<00:00, 13.13it/s]


[Epoch 06/60] Loss: 1.5466  Test Acc: 76.50%


Epoch 7/60: 100%|██████████| 782/782 [01:00<00:00, 13.00it/s]


[Epoch 07/60] Loss: 1.5252  Test Acc: 77.31%


Epoch 8/60: 100%|██████████| 782/782 [00:58<00:00, 13.27it/s]


[Epoch 08/60] Loss: 1.5173  Test Acc: 77.76%


Epoch 9/60: 100%|██████████| 782/782 [00:59<00:00, 13.18it/s]


[Epoch 09/60] Loss: 1.4949  Test Acc: 76.40%


Epoch 10/60: 100%|██████████| 782/782 [01:01<00:00, 12.74it/s]


[Epoch 10/60] Loss: 1.4968  Test Acc: 78.24%


Epoch 11/60: 100%|██████████| 782/782 [01:00<00:00, 12.97it/s]


[Epoch 11/60] Loss: 1.4733  Test Acc: 78.61%


Epoch 12/60: 100%|██████████| 782/782 [01:00<00:00, 12.84it/s]


[Epoch 12/60] Loss: 1.4544  Test Acc: 79.76%


Epoch 13/60: 100%|██████████| 782/782 [01:00<00:00, 12.85it/s]


[Epoch 13/60] Loss: 1.4407  Test Acc: 80.30%


Epoch 14/60: 100%|██████████| 782/782 [01:02<00:00, 12.60it/s]


[Epoch 14/60] Loss: 1.4189  Test Acc: 80.63%


Epoch 15/60: 100%|██████████| 782/782 [01:00<00:00, 12.83it/s]


[Epoch 15/60] Loss: 1.4221  Test Acc: 81.40%


Epoch 16/60: 100%|██████████| 782/782 [01:00<00:00, 12.93it/s]


[Epoch 16/60] Loss: 1.4212  Test Acc: 81.32%


Epoch 17/60: 100%|██████████| 782/782 [01:01<00:00, 12.76it/s]


[Epoch 17/60] Loss: 1.4207  Test Acc: 82.04%


Epoch 18/60: 100%|██████████| 782/782 [00:59<00:00, 13.14it/s]


[Epoch 18/60] Loss: 1.3878  Test Acc: 81.74%


Epoch 19/60: 100%|██████████| 782/782 [01:00<00:00, 13.00it/s]


[Epoch 19/60] Loss: 1.3620  Test Acc: 83.25%


Epoch 20/60: 100%|██████████| 782/782 [00:59<00:00, 13.10it/s]


[Epoch 20/60] Loss: 1.3561  Test Acc: 82.95%


Epoch 21/60: 100%|██████████| 782/782 [01:01<00:00, 12.71it/s]


[Epoch 21/60] Loss: 1.3750  Test Acc: 82.91%


Epoch 22/60: 100%|██████████| 782/782 [01:00<00:00, 12.86it/s]


[Epoch 22/60] Loss: 1.3541  Test Acc: 83.37%


Epoch 23/60: 100%|██████████| 782/782 [01:00<00:00, 12.95it/s]


[Epoch 23/60] Loss: 1.3558  Test Acc: 83.96%


Epoch 24/60: 100%|██████████| 782/782 [00:59<00:00, 13.20it/s]


[Epoch 24/60] Loss: 1.3542  Test Acc: 83.18%


Epoch 25/60: 100%|██████████| 782/782 [01:00<00:00, 13.02it/s]


[Epoch 25/60] Loss: 1.3500  Test Acc: 82.96%


Epoch 26/60: 100%|██████████| 782/782 [00:57<00:00, 13.51it/s]


[Epoch 26/60] Loss: 1.3519  Test Acc: 83.65%


Epoch 27/60: 100%|██████████| 782/782 [00:58<00:00, 13.27it/s]


[Epoch 27/60] Loss: 1.3375  Test Acc: 83.08%


Epoch 28/60: 100%|██████████| 782/782 [01:01<00:00, 12.77it/s]


[Epoch 28/60] Loss: 1.3585  Test Acc: 83.99%


Epoch 29/60: 100%|██████████| 782/782 [01:00<00:00, 13.01it/s]


[Epoch 29/60] Loss: 1.3317  Test Acc: 83.19%


Epoch 30/60: 100%|██████████| 782/782 [00:59<00:00, 13.04it/s]


[Epoch 30/60] Loss: 1.3341  Test Acc: 83.39%


Epoch 31/60: 100%|██████████| 782/782 [00:59<00:00, 13.05it/s]


[Epoch 31/60] Loss: 1.3095  Test Acc: 81.95%


Epoch 32/60: 100%|██████████| 782/782 [01:00<00:00, 13.02it/s]


[Epoch 32/60] Loss: 1.3173  Test Acc: 84.18%


Epoch 33/60: 100%|██████████| 782/782 [00:58<00:00, 13.27it/s]


[Epoch 33/60] Loss: 1.2949  Test Acc: 84.85%


Epoch 34/60: 100%|██████████| 782/782 [00:58<00:00, 13.29it/s]


[Epoch 34/60] Loss: 1.2712  Test Acc: 85.24%


Epoch 35/60: 100%|██████████| 782/782 [00:59<00:00, 13.22it/s]


[Epoch 35/60] Loss: 1.2925  Test Acc: 85.01%


Epoch 36/60: 100%|██████████| 782/782 [01:00<00:00, 12.90it/s]


[Epoch 36/60] Loss: 1.2944  Test Acc: 85.48%


Epoch 37/60: 100%|██████████| 782/782 [00:58<00:00, 13.40it/s]


[Epoch 37/60] Loss: 1.2880  Test Acc: 85.08%


Epoch 38/60: 100%|██████████| 782/782 [00:58<00:00, 13.44it/s]


[Epoch 38/60] Loss: 1.2801  Test Acc: 84.48%


Epoch 39/60: 100%|██████████| 782/782 [00:57<00:00, 13.72it/s]


[Epoch 39/60] Loss: 1.3098  Test Acc: 85.39%


Epoch 40/60: 100%|██████████| 782/782 [00:58<00:00, 13.37it/s]


[Epoch 40/60] Loss: 1.2729  Test Acc: 85.13%


Epoch 41/60: 100%|██████████| 782/782 [00:57<00:00, 13.53it/s]


[Epoch 41/60] Loss: 1.2748  Test Acc: 85.34%


Epoch 42/60: 100%|██████████| 782/782 [00:59<00:00, 13.12it/s]


[Epoch 42/60] Loss: 1.2527  Test Acc: 86.10%


Epoch 43/60: 100%|██████████| 782/782 [00:58<00:00, 13.37it/s]


[Epoch 43/60] Loss: 1.2622  Test Acc: 85.76%


Epoch 44/60: 100%|██████████| 782/782 [00:58<00:00, 13.36it/s]


[Epoch 44/60] Loss: 1.2848  Test Acc: 86.17%


Epoch 45/60: 100%|██████████| 782/782 [00:58<00:00, 13.30it/s]


[Epoch 45/60] Loss: 1.2792  Test Acc: 85.08%


Epoch 46/60: 100%|██████████| 782/782 [00:58<00:00, 13.31it/s]


[Epoch 46/60] Loss: 1.2628  Test Acc: 86.26%


Epoch 47/60: 100%|██████████| 782/782 [01:00<00:00, 13.03it/s]


[Epoch 47/60] Loss: 1.2638  Test Acc: 85.98%


Epoch 48/60: 100%|██████████| 782/782 [00:59<00:00, 13.16it/s]


[Epoch 48/60] Loss: 1.2557  Test Acc: 85.63%


Epoch 49/60: 100%|██████████| 782/782 [00:59<00:00, 13.12it/s]


[Epoch 49/60] Loss: 1.2382  Test Acc: 86.32%


Epoch 50/60: 100%|██████████| 782/782 [00:59<00:00, 13.14it/s]


[Epoch 50/60] Loss: 1.2310  Test Acc: 86.71%


Epoch 51/60: 100%|██████████| 782/782 [01:00<00:00, 12.86it/s]


[Epoch 51/60] Loss: 1.2668  Test Acc: 86.46%


Epoch 52/60: 100%|██████████| 782/782 [00:59<00:00, 13.14it/s]


[Epoch 52/60] Loss: 1.2556  Test Acc: 86.80%


Epoch 53/60: 100%|██████████| 782/782 [00:59<00:00, 13.20it/s]


[Epoch 53/60] Loss: 1.2371  Test Acc: 86.65%


Epoch 54/60: 100%|██████████| 782/782 [00:58<00:00, 13.39it/s]


[Epoch 54/60] Loss: 1.2178  Test Acc: 86.55%


Epoch 55/60: 100%|██████████| 782/782 [01:00<00:00, 12.97it/s]


[Epoch 55/60] Loss: 1.2069  Test Acc: 86.41%


Epoch 56/60: 100%|██████████| 782/782 [00:59<00:00, 13.12it/s]


[Epoch 56/60] Loss: 1.2209  Test Acc: 87.04%


Epoch 57/60: 100%|██████████| 782/782 [00:59<00:00, 13.06it/s]


[Epoch 57/60] Loss: 1.1948  Test Acc: 86.68%


Epoch 58/60: 100%|██████████| 782/782 [00:59<00:00, 13.20it/s]


[Epoch 58/60] Loss: 1.2120  Test Acc: 86.76%


Epoch 59/60: 100%|██████████| 782/782 [01:00<00:00, 13.02it/s]


[Epoch 59/60] Loss: 1.2137  Test Acc: 86.51%


Epoch 60/60: 100%|██████████| 782/782 [00:59<00:00, 13.24it/s]


[Epoch 60/60] Loss: 1.2217  Test Acc: 86.99%
Training done. Best Acc: 87.04%


In [None]:
# ViT + knowledge distillation on CIFAR-10 (attained 86% accuracy, but the model could not be saved)

import os
import math
import random
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

import numpy as np


# Utilities

def trunc_normal_(tensor, mean=0., std=1.):
    """Fill ``tensor`` in place from N(mean, std**2) truncated to mean ± 2·std.

    Works on plain tensors and on ``nn.Parameter`` objects (no autograd is
    recorded). Returns ``tensor``.
    """
    with torch.no_grad():
        # Inverse-CDF sampling always lands inside the bounds. A fixed number of
        # rejection-sampling candidates can run out and leak an out-of-range
        # value. nn.init.trunc_normal_ takes absolute bounds, so the ±2 std
        # window is converted here.
        return nn.init.trunc_normal_(tensor, mean=mean, std=std,
                                     a=mean - 2 * std, b=mean + 2 * std)

def rand_bbox(W, H, lam):
    """Return a random CutMix box ``(x1, y1, x2, y2)`` for a ``W`` x ``H`` image.

    The box sides are ``sqrt(1 - lam)`` of the image sides. The box is centred
    on a uniformly drawn pixel (x first, then y) and clipped to the image.
    """
    side_scale = math.sqrt(1. - lam)
    cut_sizes = (int(W * side_scale), int(H * side_scale))
    centre = (random.randint(0, W - 1), random.randint(0, H - 1))
    limits = (W, H)
    lower = [max(0, c - s // 2) for c, s in zip(centre, cut_sizes)]
    upper = [min(lim, c + s // 2) for c, s, lim in zip(centre, cut_sizes, limits)]
    return lower[0], lower[1], upper[0], upper[1]


# DropPath

class DropPath(nn.Module):
    """Stochastic depth: drop an entire residual branch per sample.

    In training mode each sample (index along dim 0) is kept with probability
    ``1 - drop_prob``. Kept samples are scaled by ``1 / (1 - drop_prob)`` so the
    expected output equals the input. In eval mode, or when ``drop_prob == 0``,
    the input is returned unchanged. With ``drop_prob == 1`` every sample is
    dropped and the output is all zeros.
    """

    def __init__(self, drop_prob: float = 0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        if keep_prob <= 0.:
            # Every path is dropped. The rescaling formula below would compute
            # x / 0 * 0, which is NaN.
            return torch.zeros_like(x)
        # One Bernoulli(keep_prob) draw per sample, broadcast over the remaining dims:
        # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = (keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)).floor_()
        return x.div(keep_prob) * mask


# Patch Embedding

class PatchEmbed(nn.Module):
    """Overlapping convolutional patch embedding.

    A Conv2d with kernel ``patch_size + 2``, stride ``patch_size`` and padding 1
    turns an ``img_size`` x ``img_size`` image into an
    ``(img_size // patch_size)``-square grid. Each grid cell becomes one token
    of width ``embed_dim``. ``img_size`` must be a multiple of ``patch_size``.

    ``forward`` maps ``(B, in_channels, img_size, img_size)`` to
    ``(B, n_patches, embed_dim)``.
    """

    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=224):
        super().__init__()
        assert img_size % patch_size == 0
        self.img_size = img_size
        self.patch_size = patch_size
        grid = img_size // patch_size
        self.n_patches = grid * grid
        # The kernel is 2 px wider than the stride, so neighbouring patches overlap.
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size + 2, stride=patch_size, padding=1)

    def forward(self, x):
        grid_features = self.proj(x)                      # (B, embed_dim, grid, grid)
        return grid_features.flatten(start_dim=2).transpose(1, 2)


# Attention, MLP, TransformerBlock

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention with a fused QKV projection.

    Args:
        dim: token width D; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        qkv_bias: whether the fused QKV linear layer has a bias.
        dropout: dropout probability for both the attention weights and the
            output projection.

    ``forward`` maps ``(B, N, D)`` tokens to ``(B, N, D)``.
    """

    def __init__(self, dim, num_heads=7, qkv_bias=True, dropout=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim, remainder = divmod(dim, num_heads)
        assert remainder == 0
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, x):
        batch, seq_len, width = x.shape
        per_head = width // self.num_heads
        # (B, N, 3*D) -> (3, B, heads, N, head_dim), then split into Q, K, V.
        packed = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, per_head)
        query, key, value = packed.permute(2, 0, 3, 1, 4).unbind(0)
        scores = (query @ key.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))
        context = (weights @ value).transpose(1, 2).reshape(batch, seq_len, width)
        return self.proj_drop(self.proj(context))

class MLP(nn.Module):
    """Transformer feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    The same ``nn.Dropout`` module (probability ``dropout``) is applied after
    the activation and after the output projection.
    """

    def __init__(self, in_features, hidden_features, out_features, dropout=0.):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

class TransformerEncoderBlock(nn.Module):
    """Pre-norm Transformer block with LayerScale and optional stochastic depth.

    x -> x + DropPath(gamma1 * Attention(LayerNorm(x)))
      -> x + DropPath(gamma2 * MLP(LayerNorm(x)))

    ``gamma1``/``gamma2`` are learnable per-channel residual gains initialised
    to 1e-4. DropPath is only inserted when ``drop_path > 0``; otherwise it is
    ``nn.Identity``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=3.0, qkv_bias=True, dropout=0., drop_path=0.):
        super().__init__()
        use_drop_path = drop_path > 0
        # Submodules/parameters are created in a fixed order so random init is reproducible.
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, dropout=dropout)
        self.drop_path1 = DropPath(drop_path) if use_drop_path else nn.Identity()
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.mlp = MLP(dim, int(dim * mlp_ratio), dim, dropout=dropout)
        self.drop_path2 = DropPath(drop_path) if use_drop_path else nn.Identity()
        self.gamma1 = nn.Parameter(torch.full((dim,), 1e-4))
        self.gamma2 = nn.Parameter(torch.full((dim,), 1e-4))

    def forward(self, x):
        attn_branch = self.gamma1 * self.attn(self.norm1(x))
        x = x + self.drop_path1(attn_branch)
        mlp_branch = self.gamma2 * self.mlp(self.norm2(x))
        return x + self.drop_path2(mlp_branch)

# Vision Transformer

class VisionTransformer(nn.Module):
    """Compact ViT classifier for small images (32x32 by default).

    Pipeline: overlapping conv patch embedding -> prepend CLS token -> add a
    learned positional embedding -> dropout -> ``depth`` pre-norm encoder
    blocks (drop-path rate rising linearly from 0 to ``drop_path_rate``) ->
    LayerNorm -> take the CLS token -> LayerNorm + Linear pre-logits -> head.

    ``pos_embed`` has a fixed length of ``n_patches + 1``, so inputs must be
    ``img_size`` x ``img_size``.
    """

    def __init__(self,
                 img_size=32,
                 patch_size=4,
                 in_channels=3,
                 num_classes=10,
                 embed_dim=224,
                 depth=7,
                 num_heads=7,
                 mlp_ratio=3.0,
                 dropout=0.0,
                 drop_path_rate=0.05):
        super().__init__()
        # Submodules are registered in a fixed order so random init is reproducible.
        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size,
                                      in_channels=in_channels, embed_dim=embed_dim)
        seq_len = self.patch_embed.n_patches + 1  # patch tokens + CLS

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        self.pos_drop = nn.Dropout(dropout)

        drop_rates = torch.linspace(0, drop_path_rate, depth).tolist()
        self.blocks = nn.ModuleList(
            TransformerEncoderBlock(embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                                    qkv_bias=True, dropout=dropout, drop_path=rate)
            for rate in drop_rates
        )

        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)
        self.pre_logits = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, embed_dim))
        self.head = nn.Linear(embed_dim, num_classes)

        # Small random init for the learned embeddings, then re-initialise
        # every Linear / Conv2d / LayerNorm submodule.
        trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform Linear, Kaiming-normal (fan_out) Conv2d, unit LayerNorm; zero biases."""
        if isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
            return
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out')
        else:
            return
        if m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, x):
        tokens = self.patch_embed(x)
        cls = self.cls_token.expand(tokens.shape[0], -1, -1)
        tokens = self.pos_drop(torch.cat((cls, tokens), dim=1) + self.pos_embed)
        for block in self.blocks:
            tokens = block(tokens)
        cls_out = self.norm(tokens)[:, 0]
        return self.head(self.pre_logits(cls_out))


# DataLoaders + Augmentations

def get_dataloaders(batch_size=16, img_size=32, num_workers=2):
    """Create CIFAR-10 train/test DataLoaders (data cached under ``./data``).

    Training images: random crop with 4 px padding, horizontal flip, colour
    jitter, tensor conversion, normalization, random erasing (p=0.25).
    Test images: tensor conversion and normalization only.

    Returns ``(trainloader, testloader)``. Only the training loader shuffles.
    """
    stats = ((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))  # CIFAR-10 channel mean, std

    augment = [
        transforms.RandomCrop(img_size, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
    ]
    to_normalized = [transforms.ToTensor(), transforms.Normalize(*stats)]
    transform_train = transforms.Compose(augment + to_normalized + [transforms.RandomErasing(p=0.25)])
    transform_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize(*stats)])

    loaders = []
    for is_train, tf in ((True, transform_train), (False, transform_test)):
        split = datasets.CIFAR10(root='./data', train=is_train, download=True, transform=tf)
        loaders.append(DataLoader(split, batch_size=batch_size, shuffle=is_train,
                                  num_workers=num_workers, pin_memory=True))
    return loaders[0], loaders[1]


# Knowledge Distillation Loss

class DistillationLoss(nn.Module):
    """Knowledge-distillation objective mixing hard-label CE and soft-label KL.

        loss = (1 - alpha) * CE(student, labels)
               + alpha * T*T * KL(softmax(teacher / T) || softmax(student / T))

    The KL term is averaged over the batch (``reduction='batchmean'``).

    Args:
        alpha: weight of the soft (teacher) term; ``1 - alpha`` weighs the CE term.
        temperature: softmax temperature ``T`` applied to both sets of logits.
    """

    def __init__(self, alpha=0.5, temperature=4.0):
        super().__init__()
        self.alpha = alpha
        self.temperature = temperature
        self.ce = nn.CrossEntropyLoss()
        self.kl = nn.KLDivLoss(reduction='batchmean')

    def _soft_loss(self, student_logits, teacher_logits, T):
        """Temperature-softened KL(teacher || student), multiplied by T*T."""
        log_q = nn.functional.log_softmax(student_logits / T, dim=1)
        p = nn.functional.softmax(teacher_logits / T, dim=1)
        return self.kl(log_q, p) * (T * T)

    def forward(self, student_logits, labels, teacher_logits):
        hard_weight, soft_weight = 1 - self.alpha, self.alpha
        hard_term = self.ce(student_logits, labels)
        soft_term = self._soft_loss(student_logits, teacher_logits, self.temperature)
        return hard_weight * hard_term + soft_weight * soft_term


# Evaluation

def evaluate(model, testloader, device):
    """Return the top-1 accuracy of ``model`` on ``testloader``, in percent.

    The model is switched to eval mode (and left there), and inference runs
    without gradients in full precision.

    Raises:
        ValueError: if ``testloader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images = images.to(device)
            labels = labels.to(device)
            pred = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += pred.eq(labels).sum().item()
    if total == 0:
        raise ValueError("evaluate(): testloader yielded no samples")
    return 100.0 * correct / total


# Training Loop with Distillation
def train_distill(student, teacher, trainloader, testloader, device,
                   epochs=90, lr=3e-4, weight_decay=0.02, accumulation_steps=2,
                   warmup_epochs=5, mixup_alpha=0.8, cutmix_prob=0.5):
    """Train ``student`` by knowledge distillation from ``teacher``.

    Only the student is optimised (AdamW with a linear-warmup + cosine LR
    schedule, gradient accumulation, and mixed precision on CUDA). The teacher
    is run in eval mode under ``no_grad`` on the same (possibly mixed) images.
    The loss is ``DistillationLoss(alpha=0.5, temperature=4.0)``. With
    Mixup/CutMix it is evaluated against both targets and weighted by ``lam``
    and ``1 - lam``, so the teacher KL term carries a total weight of 1.

    After each epoch the student is evaluated on ``testloader``. The best
    weights go to ``checkpoints/vit_student_best.pth`` and the final weights
    to ``checkpoints/vit_student_final.pth``.
    """
    use_amp = device.type == 'cuda'
    criterion = DistillationLoss(alpha=0.5, temperature=4.0)
    optimizer = optim.AdamW(student.parameters(), lr=lr, weight_decay=weight_decay)
    try:
        # torch.amp.GradScaler replaces the deprecated torch.cuda.amp.GradScaler.
        scaler = torch.amp.GradScaler("cuda", enabled=use_amp)
    except AttributeError:  # torch < 2.3
        scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    # scheduler.step() runs once per *optimizer* step (every accumulation_steps
    # batches, plus a trailing partial group), so count the schedule in those units.
    # Counting in batches stretches warmup and stops the cosine well above zero.
    num_batches = len(trainloader)
    steps_per_epoch = math.ceil(num_batches / accumulation_steps)
    total_steps = steps_per_epoch * epochs
    warmup_steps = steps_per_epoch * warmup_epochs

    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return float(current_step) / max(1, warmup_steps)
        progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    best_acc = 0.0
    os.makedirs("checkpoints", exist_ok=True)

    teacher.eval()  # teacher is frozen: only ever run for inference below
    for epoch in range(epochs):
        student.train()
        running_loss = 0.0
        optimizer.zero_grad()
        for i, (images, labels) in enumerate(tqdm(trainloader, desc=f"Epoch {epoch+1}/{epochs}")):
            images = images.to(device)
            labels = labels.to(device)

            # Mixup / CutMix: blend each sample with a random partner from the batch.
            target_b = None
            lam = 1.0
            if mixup_alpha > 0:
                lam = np.random.beta(mixup_alpha, mixup_alpha)
                index = torch.randperm(images.size(0)).to(images.device)
                if random.random() < cutmix_prob:
                    bbx1, bby1, bbx2, bby2 = rand_bbox(images.size(3), images.size(2), lam)
                    images[:, :, bby1:bby2, bbx1:bbx2] = images[index, :, bby1:bby2, bbx1:bbx2]
                    # The box may be clipped at the border, so re-derive lam from the pasted area.
                    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (images.size(2) * images.size(3)))
                else:
                    images = images * lam + images[index] * (1 - lam)
                target_b = labels[index]

            with torch.autocast(device_type="cuda", enabled=use_amp):
                student_logits = student(images)
                with torch.no_grad():
                    teacher_logits = teacher(images)
                if target_b is not None:
                    loss = lam * criterion(student_logits, labels, teacher_logits) + \
                           (1 - lam) * criterion(student_logits, target_b, teacher_logits)
                else:
                    loss = criterion(student_logits, labels, teacher_logits)
                # Average gradients over the accumulation group. A trailing
                # partial group is still divided by the full accumulation_steps.
                loss = loss / accumulation_steps

            scaler.scale(loss).backward()

            if (i + 1) % accumulation_steps == 0 or (i + 1) == num_batches:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()

            running_loss += loss.item() * accumulation_steps

        avg_loss = running_loss / num_batches
        acc = evaluate(student, testloader, device)
        print(f"[Epoch {epoch+1}/{epochs}] Loss: {avg_loss:.4f}  Test Acc: {acc:.2f}%")

        if acc > best_acc:
            best_acc = acc
            torch.save(student.state_dict(), "checkpoints/vit_student_best.pth")

    torch.save(student.state_dict(), "checkpoints/vit_student_final.pth")
    print(f"Training done. Best Acc: {best_acc:.2f}%. Final model saved.")


# Main

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    BATCH_SIZE = 16
    ACCUM_STEPS = 2
    EPOCHS = 90
    WARMUP = 5
    # State dict of a torchvision ResNet-18 (10-class head) already trained on CIFAR-10.
    TEACHER_CKPT = "resnet18_cifar10.pth"

    trainloader, testloader = get_dataloaders(batch_size=BATCH_SIZE, img_size=32, num_workers=2)

    # Teacher: ResNet-18 with a 10-class head. Without trained weights it is
    # randomly initialised, and the KL term would only distil noise.
    teacher = models.resnet18(num_classes=10)
    if os.path.exists(TEACHER_CKPT):
        # torch.load unpickles the file; only load checkpoints you produced or trust.
        teacher.load_state_dict(torch.load(TEACHER_CKPT, map_location="cpu"))
        print(f"Loaded teacher weights from {TEACHER_CKPT}")
    else:
        print(f"WARNING: {TEACHER_CKPT} not found; the teacher is randomly initialized, "
              "so distillation will not transfer useful knowledge.")
    teacher = teacher.to(device)
    teacher.requires_grad_(False)  # frozen: never optimised, only run under no_grad

    # Student: ViT
    student = VisionTransformer(
        img_size=32,
        patch_size=4,
        embed_dim=224,
        depth=7,
        num_heads=7,
        mlp_ratio=3.0,
        dropout=0.1,
        drop_path_rate=0.05
    ).to(device)

    train_distill(student, teacher, trainloader, testloader, device,
                   epochs=EPOCHS,
                   lr=3e-4,
                   weight_decay=0.02,
                   accumulation_steps=ACCUM_STEPS,
                   warmup_epochs=WARMUP,
                   mixup_alpha=0.8,
                   cutmix_prob=0.5)
