
# VIT-TINY_Training_VIA_ViT_SMALL_KD+CL

*Configuration*  
- Input resolution: *12-16-20-24-28-32*  
- Patch size: *2*

---

### Models & Parameters

- *Student (ViT‑Tiny)*  
  - Parameters: 5M  
  - Curriculum: 10 epochs per resolution stage  
  - 🎯 *Accuracy:* 79.92%

In [2]:
import os
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from torchvision.transforms import AutoAugment, AutoAugmentPolicy, RandomErasing
from transformers import (
    DeiTConfig,
    DeiTForImageClassification,
    ViTConfig,
    ViTPreTrainedModel,
    ViTModel
)
from tqdm import tqdm


use_dp = False  

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(42)
np.random.seed(42)

mean = (0.4914, 0.4822, 0.4465)
std  = (0.2023, 0.1994, 0.2010)

alpha, temp = 0.5, 4.0
stages = [(r, 10) for r in [12, 16, 20, 24, 28, 32]]


class ViTWithDistillation(ViTPreTrainedModel):
    def __init__(self, config: ViTConfig):
        super().__init__(config)
        self.vit = ViTModel(config)
        self.distill_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.distiller = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    def forward(self, pixel_values, labels=None, teacher_logits=None,
                alpha=0.5, temperature=1.0):
        B = pixel_values.size(0)
        embeds = self.vit.embeddings(pixel_values, interpolate_pos_encoding=True)
        cls_emb, patch_emb = embeds[:, :1, :], embeds[:, 1:, :]
        dist_tok = self.distill_token.expand(B, -1, -1)
        x = torch.cat([cls_emb, dist_tok, patch_emb], dim=1)
        x = self.vit.encoder(x)[0]
        cls_out, dist_out = x[:, 0], x[:, 1]
        logits = self.classifier(cls_out)
        dist_logits = self.distiller(dist_out)
        output = {"logits": logits, "distill_logits": dist_logits}

        if labels is not None and teacher_logits is not None:
            loss_ce = F.cross_entropy(logits, labels)
            kd = F.kl_div(
                F.log_softmax(dist_logits / temperature, dim=1),
                F.softmax(teacher_logits / temperature, dim=1),
                reduction='batchmean'
            ) * (temperature ** 2)
            output["loss"] = (1 - alpha) * loss_ce + alpha * kd

        return output


teacher_config = DeiTConfig(
    image_size=32,
    patch_size=2,
    num_labels=10,
    hidden_size=384,
    num_hidden_layers=12,
    num_attention_heads=6,
    intermediate_size=1536,
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    stochastic_depth_prob=0.1
)
teacher = DeiTForImageClassification(teacher_config).to(device)
if use_dp:
    teacher = nn.DataParallel(teacher)

ckpt = torch.load(
    "/kaggle/input/best-teacher/pytorch/default/1/best_teacher.pth",
    map_location=device
)
state = {k.replace('module.', ''): v for k, v in ckpt.items()}
teacher.load_state_dict(state, strict=True)
teacher.eval()
for p in teacher.parameters():
    p.requires_grad_(False)


tr_tf = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    AutoAugment(policy=AutoAugmentPolicy.CIFAR10),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
    RandomErasing(p=0.2, scale=(0.02,0.2), ratio=(0.3,3.3))
])
val_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

full_train = datasets.CIFAR10('./data', train=True, download=True, transform=tr_tf)
num_val = int(0.1 * len(full_train))
num_train = len(full_train) - num_val
train_ds, val_idx_ds = torch.utils.data.random_split(
    full_train, [num_train, num_val], generator=torch.Generator().manual_seed(42)
)
full_for_val = datasets.CIFAR10('./data', train=True, download=False, transform=val_tf)
val_ds = Subset(full_for_val, val_idx_ds.indices)

stu_cfg = ViTConfig(
    image_size=32,
    patch_size=2,
    num_labels=10,
    hidden_size=192,
    num_hidden_layers=12,
    num_attention_heads=3,
    intermediate_size=768,
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    stochastic_depth_prob=0.1
)
student = ViTWithDistillation(config=stu_cfg).to(device)
opt = optim.AdamW(student.parameters(), lr=3e-4, weight_decay=1e-4)


for res, epochs in stages:
    print(f"\n--- Curriculum Stage: Crop {res}px for {epochs} epochs ---")
    train_stage_tf = transforms.Compose([
        transforms.RandomCrop(res, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    val_stage_tf = transforms.Compose([
        transforms.CenterCrop(res),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    base_train = datasets.CIFAR10('./data', train=True, download=False, transform=train_stage_tf)
    base_val = datasets.CIFAR10('./data', train=True, download=False, transform=val_stage_tf)
    tr_loader = DataLoader(Subset(base_train, train_ds.indices), batch_size=128, shuffle=True, num_workers=4)
    vl_loader = DataLoader(Subset(base_val, val_idx_ds.indices), batch_size=128, shuffle=False, num_workers=4)

    for ep in range(1, epochs + 1):
        # Training with progress bar
        student.train()
        for x, y in tqdm(tr_loader, desc=f"Stage {res}px Ep{ep}/{epochs}"):
            x, y = x.to(device), y.to(device)
            with torch.no_grad():
                tlog = teacher(pixel_values=x, interpolate_pos_encoding=True).logits
            out = student(pixel_values=x, labels=y, teacher_logits=tlog, alpha=alpha, temperature=temp)
            loss = out['loss'].mean()
            opt.zero_grad()
            loss.backward()
            opt.step()

        # Validation accuracy
        student.eval()
        correct, total = 0, 0
        for xb, yb in tqdm(vl_loader, desc=f"Val Stage {res}px Ep{ep}/{epochs}"):
            xb, yb = xb.to(device), yb.to(device)
            with torch.no_grad():
                preds = student(pixel_values=xb)['logits'].argmax(1)
            correct += (preds == yb).sum().item()
            total += yb.size(0)
        val_acc = 100 * correct / total
        print(f"Stage {res}px Ep{ep} val acc: {val_acc:.2f}%")


student.eval()
test_ds = datasets.CIFAR10('./data', train=False, transform=val_tf)
tloader = DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=4)
correct = 0
for xb, yb in tqdm(tloader, desc="Final Test"):
    xb, yb = xb.to(device), yb.to(device)
    with torch.no_grad():
        preds = student(pixel_values=xb)['logits'].argmax(1)
    correct += (preds == yb).sum().item()
student_acc = 100 * correct / len(test_ds)
print(f"Student Acc: {student_acc:.2f}%")


  ckpt = torch.load(


Files already downloaded and verified

--- Curriculum Stage: Crop 12px for 10 epochs ---


Stage 12px Ep1/10: 100%|██████████| 352/352 [00:33<00:00, 10.58it/s]
Val Stage 12px Ep1/10: 100%|██████████| 40/40 [00:00<00:00, 42.29it/s]


Stage 12px Ep1 val acc: 26.86%


Stage 12px Ep2/10: 100%|██████████| 352/352 [00:33<00:00, 10.60it/s]
Val Stage 12px Ep2/10: 100%|██████████| 40/40 [00:01<00:00, 38.48it/s]


Stage 12px Ep2 val acc: 32.48%


Stage 12px Ep3/10: 100%|██████████| 352/352 [00:33<00:00, 10.60it/s]
Val Stage 12px Ep3/10: 100%|██████████| 40/40 [00:00<00:00, 42.35it/s]


Stage 12px Ep3 val acc: 33.96%


Stage 12px Ep4/10: 100%|██████████| 352/352 [00:33<00:00, 10.60it/s]
Val Stage 12px Ep4/10: 100%|██████████| 40/40 [00:00<00:00, 42.29it/s]


Stage 12px Ep4 val acc: 34.38%


Stage 12px Ep5/10: 100%|██████████| 352/352 [00:33<00:00, 10.60it/s]
Val Stage 12px Ep5/10: 100%|██████████| 40/40 [00:00<00:00, 41.95it/s]


Stage 12px Ep5 val acc: 37.22%


Stage 12px Ep6/10: 100%|██████████| 352/352 [00:33<00:00, 10.59it/s]
Val Stage 12px Ep6/10: 100%|██████████| 40/40 [00:00<00:00, 41.94it/s]


Stage 12px Ep6 val acc: 36.66%


Stage 12px Ep7/10: 100%|██████████| 352/352 [00:33<00:00, 10.59it/s]
Val Stage 12px Ep7/10: 100%|██████████| 40/40 [00:00<00:00, 41.42it/s]


Stage 12px Ep7 val acc: 37.74%


Stage 12px Ep8/10: 100%|██████████| 352/352 [00:33<00:00, 10.58it/s]
Val Stage 12px Ep8/10: 100%|██████████| 40/40 [00:00<00:00, 42.56it/s]


Stage 12px Ep8 val acc: 38.26%


Stage 12px Ep9/10: 100%|██████████| 352/352 [00:33<00:00, 10.59it/s]
Val Stage 12px Ep9/10: 100%|██████████| 40/40 [00:00<00:00, 41.18it/s]


Stage 12px Ep9 val acc: 38.72%


Stage 12px Ep10/10: 100%|██████████| 352/352 [00:33<00:00, 10.58it/s]
Val Stage 12px Ep10/10: 100%|██████████| 40/40 [00:01<00:00, 34.83it/s]


Stage 12px Ep10 val acc: 40.74%

--- Curriculum Stage: Crop 16px for 10 epochs ---


Stage 16px Ep1/10: 100%|██████████| 352/352 [01:00<00:00,  5.80it/s]
Val Stage 16px Ep1/10: 100%|██████████| 40/40 [00:01<00:00, 29.00it/s]


Stage 16px Ep1 val acc: 48.12%


Stage 16px Ep2/10: 100%|██████████| 352/352 [01:00<00:00,  5.81it/s]
Val Stage 16px Ep2/10: 100%|██████████| 40/40 [00:01<00:00, 29.73it/s]


Stage 16px Ep2 val acc: 51.64%


Stage 16px Ep3/10: 100%|██████████| 352/352 [01:00<00:00,  5.81it/s]
Val Stage 16px Ep3/10: 100%|██████████| 40/40 [00:01<00:00, 29.69it/s]


Stage 16px Ep3 val acc: 51.46%


Stage 16px Ep4/10: 100%|██████████| 352/352 [01:00<00:00,  5.81it/s]
Val Stage 16px Ep4/10: 100%|██████████| 40/40 [00:01<00:00, 29.49it/s]


Stage 16px Ep4 val acc: 53.58%


Stage 16px Ep5/10: 100%|██████████| 352/352 [01:00<00:00,  5.81it/s]
Val Stage 16px Ep5/10: 100%|██████████| 40/40 [00:01<00:00, 29.81it/s]


Stage 16px Ep5 val acc: 53.62%


Stage 16px Ep6/10: 100%|██████████| 352/352 [01:00<00:00,  5.81it/s]
Val Stage 16px Ep6/10: 100%|██████████| 40/40 [00:01<00:00, 29.51it/s]


Stage 16px Ep6 val acc: 55.32%


Stage 16px Ep7/10: 100%|██████████| 352/352 [01:00<00:00,  5.81it/s]
Val Stage 16px Ep7/10: 100%|██████████| 40/40 [00:01<00:00, 29.52it/s]


Stage 16px Ep7 val acc: 54.40%


Stage 16px Ep8/10: 100%|██████████| 352/352 [01:00<00:00,  5.81it/s]
Val Stage 16px Ep8/10: 100%|██████████| 40/40 [00:01<00:00, 29.55it/s]


Stage 16px Ep8 val acc: 55.06%


Stage 16px Ep9/10: 100%|██████████| 352/352 [01:00<00:00,  5.81it/s]
Val Stage 16px Ep9/10: 100%|██████████| 40/40 [00:01<00:00, 29.67it/s]


Stage 16px Ep9 val acc: 56.98%


Stage 16px Ep10/10: 100%|██████████| 352/352 [01:00<00:00,  5.81it/s]
Val Stage 16px Ep10/10: 100%|██████████| 40/40 [00:01<00:00, 29.45it/s]


Stage 16px Ep10 val acc: 56.96%

--- Curriculum Stage: Crop 20px for 10 epochs ---


Stage 20px Ep1/10: 100%|██████████| 352/352 [01:31<00:00,  3.84it/s]
Val Stage 20px Ep1/10: 100%|██████████| 40/40 [00:01<00:00, 22.21it/s]


Stage 20px Ep1 val acc: 63.94%


Stage 20px Ep2/10: 100%|██████████| 352/352 [01:31<00:00,  3.84it/s]
Val Stage 20px Ep2/10: 100%|██████████| 40/40 [00:01<00:00, 22.31it/s]


Stage 20px Ep2 val acc: 65.42%


Stage 20px Ep3/10: 100%|██████████| 352/352 [01:31<00:00,  3.85it/s]
Val Stage 20px Ep3/10: 100%|██████████| 40/40 [00:01<00:00, 22.40it/s]


Stage 20px Ep3 val acc: 65.44%


Stage 20px Ep4/10: 100%|██████████| 352/352 [01:31<00:00,  3.84it/s]
Val Stage 20px Ep4/10: 100%|██████████| 40/40 [00:01<00:00, 22.53it/s]


Stage 20px Ep4 val acc: 65.78%


Stage 20px Ep6/10: 100%|██████████| 352/352 [01:31<00:00,  3.84it/s]
Val Stage 20px Ep6/10: 100%|██████████| 40/40 [00:01<00:00, 22.21it/s]


Stage 20px Ep6 val acc: 66.00%


Stage 20px Ep7/10: 100%|██████████| 352/352 [01:31<00:00,  3.84it/s]
Val Stage 20px Ep7/10: 100%|██████████| 40/40 [00:01<00:00, 22.45it/s]


Stage 20px Ep7 val acc: 67.74%


Stage 20px Ep8/10: 100%|██████████| 352/352 [01:31<00:00,  3.84it/s]
Val Stage 20px Ep8/10: 100%|██████████| 40/40 [00:01<00:00, 22.26it/s]


Stage 20px Ep8 val acc: 67.40%


Stage 20px Ep9/10: 100%|██████████| 352/352 [01:31<00:00,  3.84it/s]
Val Stage 20px Ep9/10: 100%|██████████| 40/40 [00:01<00:00, 22.22it/s]


Stage 20px Ep9 val acc: 69.96%


Stage 20px Ep10/10: 100%|██████████| 352/352 [01:31<00:00,  3.84it/s]
Val Stage 20px Ep10/10: 100%|██████████| 40/40 [00:01<00:00, 22.47it/s]


Stage 20px Ep10 val acc: 69.22%

--- Curriculum Stage: Crop 24px for 10 epochs ---


Stage 24px Ep1/10: 100%|██████████| 352/352 [02:16<00:00,  2.58it/s]
Val Stage 24px Ep1/10: 100%|██████████| 40/40 [00:02<00:00, 16.16it/s]


Stage 24px Ep1 val acc: 73.98%


Stage 24px Ep2/10: 100%|██████████| 352/352 [02:16<00:00,  2.58it/s]
Val Stage 24px Ep2/10: 100%|██████████| 40/40 [00:02<00:00, 16.13it/s]


Stage 24px Ep2 val acc: 71.88%


Stage 24px Ep3/10: 100%|██████████| 352/352 [02:16<00:00,  2.58it/s]
Val Stage 24px Ep3/10: 100%|██████████| 40/40 [00:02<00:00, 16.19it/s]


Stage 24px Ep3 val acc: 73.26%


Stage 24px Ep4/10: 100%|██████████| 352/352 [02:16<00:00,  2.58it/s]
Val Stage 24px Ep4/10: 100%|██████████| 40/40 [00:02<00:00, 15.94it/s]


Stage 24px Ep4 val acc: 76.14%


Stage 24px Ep5/10: 100%|██████████| 352/352 [02:16<00:00,  2.58it/s]
Val Stage 24px Ep5/10: 100%|██████████| 40/40 [00:02<00:00, 15.91it/s]


Stage 24px Ep5 val acc: 75.46%


Stage 24px Ep6/10: 100%|██████████| 352/352 [02:16<00:00,  2.58it/s]
Val Stage 24px Ep6/10: 100%|██████████| 40/40 [00:02<00:00, 15.78it/s]


Stage 24px Ep6 val acc: 75.10%


Stage 24px Ep7/10: 100%|██████████| 352/352 [02:16<00:00,  2.58it/s]
Val Stage 24px Ep7/10: 100%|██████████| 40/40 [00:02<00:00, 15.77it/s]


Stage 24px Ep7 val acc: 75.62%


Stage 24px Ep8/10: 100%|██████████| 352/352 [02:16<00:00,  2.58it/s]
Val Stage 24px Ep8/10: 100%|██████████| 40/40 [00:02<00:00, 16.17it/s]


Stage 24px Ep8 val acc: 75.94%


Stage 24px Ep9/10: 100%|██████████| 352/352 [02:16<00:00,  2.58it/s]
Val Stage 24px Ep9/10: 100%|██████████| 40/40 [00:02<00:00, 16.03it/s]


Stage 24px Ep9 val acc: 76.54%


Stage 24px Ep10/10: 100%|██████████| 352/352 [02:16<00:00,  2.58it/s]
Val Stage 24px Ep10/10: 100%|██████████| 40/40 [00:02<00:00, 16.13it/s]


Stage 24px Ep10 val acc: 76.38%

--- Curriculum Stage: Crop 28px for 10 epochs ---


Stage 28px Ep1/10: 100%|██████████| 352/352 [03:22<00:00,  1.74it/s]
Val Stage 28px Ep1/10: 100%|██████████| 40/40 [00:03<00:00, 11.74it/s]


Stage 28px Ep1 val acc: 77.86%


Stage 28px Ep2/10: 100%|██████████| 352/352 [03:22<00:00,  1.74it/s]
Val Stage 28px Ep2/10: 100%|██████████| 40/40 [00:03<00:00, 11.74it/s]


Stage 28px Ep2 val acc: 78.74%


Stage 28px Ep3/10: 100%|██████████| 352/352 [03:22<00:00,  1.74it/s]
Val Stage 28px Ep3/10: 100%|██████████| 40/40 [00:03<00:00, 11.77it/s]


Stage 28px Ep3 val acc: 79.68%


Stage 28px Ep4/10: 100%|██████████| 352/352 [03:22<00:00,  1.74it/s]
Val Stage 28px Ep4/10: 100%|██████████| 40/40 [00:03<00:00, 11.73it/s]


Stage 28px Ep4 val acc: 79.56%


Stage 28px Ep5/10: 100%|██████████| 352/352 [03:22<00:00,  1.74it/s]
Val Stage 28px Ep5/10: 100%|██████████| 40/40 [00:03<00:00, 11.74it/s]


Stage 28px Ep5 val acc: 80.22%


Stage 28px Ep6/10: 100%|██████████| 352/352 [03:22<00:00,  1.74it/s]
Val Stage 28px Ep6/10: 100%|██████████| 40/40 [00:03<00:00, 11.75it/s]


Stage 28px Ep6 val acc: 80.30%


Stage 28px Ep7/10: 100%|██████████| 352/352 [03:22<00:00,  1.74it/s]
Val Stage 28px Ep7/10: 100%|██████████| 40/40 [00:03<00:00, 11.75it/s]


Stage 28px Ep7 val acc: 79.32%


Stage 28px Ep8/10: 100%|██████████| 352/352 [03:22<00:00,  1.74it/s]
Val Stage 28px Ep8/10: 100%|██████████| 40/40 [00:03<00:00, 11.68it/s]


Stage 28px Ep8 val acc: 80.56%


Stage 28px Ep9/10: 100%|██████████| 352/352 [03:22<00:00,  1.74it/s]
Val Stage 28px Ep9/10: 100%|██████████| 40/40 [00:03<00:00, 11.74it/s]


Stage 28px Ep9 val acc: 79.62%


Stage 28px Ep10/10: 100%|██████████| 352/352 [03:22<00:00,  1.74it/s]
Val Stage 28px Ep10/10: 100%|██████████| 40/40 [00:03<00:00, 11.65it/s]


Stage 28px Ep10 val acc: 79.60%

--- Curriculum Stage: Crop 32px for 10 epochs ---


Stage 32px Ep1/10: 100%|██████████| 352/352 [04:42<00:00,  1.25it/s]
Val Stage 32px Ep1/10: 100%|██████████| 40/40 [00:04<00:00,  8.92it/s]


Stage 32px Ep1 val acc: 80.98%


Stage 32px Ep2/10: 100%|██████████| 352/352 [04:42<00:00,  1.25it/s]
Val Stage 32px Ep2/10: 100%|██████████| 40/40 [00:04<00:00,  8.90it/s]


Stage 32px Ep2 val acc: 80.84%


Stage 32px Ep3/10: 100%|██████████| 352/352 [04:42<00:00,  1.25it/s]
Val Stage 32px Ep3/10: 100%|██████████| 40/40 [00:04<00:00,  8.87it/s]


Stage 32px Ep3 val acc: 80.42%


Stage 32px Ep4/10: 100%|██████████| 352/352 [04:42<00:00,  1.25it/s]
Val Stage 32px Ep4/10: 100%|██████████| 40/40 [00:04<00:00,  8.87it/s]


Stage 32px Ep4 val acc: 80.32%


Stage 32px Ep5/10: 100%|██████████| 352/352 [04:42<00:00,  1.25it/s]
Val Stage 32px Ep5/10: 100%|██████████| 40/40 [00:04<00:00,  8.88it/s]


Stage 32px Ep5 val acc: 81.56%


Stage 32px Ep6/10: 100%|██████████| 352/352 [04:42<00:00,  1.25it/s]
Val Stage 32px Ep6/10: 100%|██████████| 40/40 [00:04<00:00,  8.86it/s]


Stage 32px Ep6 val acc: 79.14%


Stage 32px Ep7/10: 100%|██████████| 352/352 [04:42<00:00,  1.25it/s]
Val Stage 32px Ep7/10: 100%|██████████| 40/40 [00:04<00:00,  8.87it/s]


Stage 32px Ep7 val acc: 80.16%


Stage 32px Ep8/10:   1%|          | 3/352 [00:02<05:06,  1.14it/s]


KeyboardInterrupt: 

Early stopping done.

In [3]:
student.eval()
test_ds = datasets.CIFAR10('./data', train=False, transform=val_tf)
tloader = DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=4)
correct = 0
for xb, yb in tqdm(tloader, desc="Final Test"):
    xb, yb = xb.to(device), yb.to(device)
    with torch.no_grad():
        preds = student(pixel_values=xb)['logits'].argmax(1)
    correct += (preds == yb).sum().item()
student_acc = 100 * correct / len(test_ds)
print(f"Student Acc: {student_acc:.2f}%")


Final Test: 100%|██████████| 79/79 [00:08<00:00,  9.27it/s]

Student Acc: 79.92%



