# Knowledge Distillation on CIFAR-10

Mục tiêu: Cài đặt nhiều phương pháp Knowledge Distillation (KD) khác nhau cho bài toán phân loại CIFAR-10.

Yêu cầu chính:
- Dataset: CIFAR-10 (train/test chuẩn của torchvision)
- Model teacher: pretrained trên CIFAR-10; model student nhỏ hơn, chưa train.
- Một biến chung `KD_EPOCHS` xác định số epoch train cho TẤT CẢ phương pháp KD.
- Trước mỗi phương pháp có một cell markdown ghi tên phương pháp.
- Mỗi phương pháp in ra: tổng thời gian train + accuracy trên train và test.

Các phần dưới đây cung cấp phần setup dùng chung (dataloader, model, util, teacher), sau đó là từng phương pháp KD.

In [1]:
# %% Shared Setup: dependencies, config, data, models, utils
import os, time, math, random, warnings
warnings.filterwarnings("ignore")
from dataclasses import dataclass
from typing import Tuple, Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset

import torchvision
import torchvision.transforms as T

from torch.cuda.amp import GradScaler, autocast
import copy

# Reproducibility
SEED = int(os.environ.get("SEED", 42))
random.seed(SEED)
torch.manual_seed(SEED)

# Device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

# Global epochs shared by all KD methods
KD_EPOCHS = int(os.environ.get("KD_EPOCHS", 10))  # chỉnh tại đây nếu muốn
BATCH_SIZE = 64
NUM_WORKERS = 4
NUM_CLASSES = 10
VAL_RATIO = float(os.environ.get("VAL_RATIO", 0.1))  # 10% train -> val split

# Teacher training epochs if training from scratch
TEACHER_EPOCHS = int(os.environ.get("TEACHER_EPOCHS", 100))

# Ensure checkpoints directory exists
CKPT_DIR = os.environ.get("CKPT_DIR", "./checkpoints")
os.makedirs(CKPT_DIR, exist_ok=True)

# Data: CIFAR-10
mean = (0.4914, 0.4822, 0.4465)
std = (0.2470, 0.2435, 0.2616)

train_tf = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean, std),
])

test_tf = T.Compose([
    T.ToTensor(),
    T.Normalize(mean, std),
])

# Build full train set twice to allow different transforms for train vs validation
train_full_aug = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=train_tf)
train_full_plain = torchvision.datasets.CIFAR10(root="./data", train=True, download=False, transform=test_tf)

# Create reproducible train/val split
N = len(train_full_aug)
val_size = max(1, int(VAL_RATIO * N))
train_size = N - val_size
gen = torch.Generator().manual_seed(SEED)
perm = torch.randperm(N, generator=gen)
val_idx = perm[:val_size].tolist()
train_idx = perm[val_size:].tolist()

train_set = Subset(train_full_aug, train_idx)
val_set = Subset(train_full_plain, val_idx)

test_set = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=test_tf)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

# Student model (smaller than ResNet18)
class SmallNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(128, num_classes)
        )
    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

# Teacher model: minimal ResNet18 for CIFAR-10 (3x3 stem, no initial maxpool)
class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes)
            )
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet18(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # layers: 2,2,2,2
        self.layer1 = self._make_layer(64, 2, stride=1)
        self.layer2 = self._make_layer(128, 2, stride=2)
        self.layer3 = self._make_layer(256, 2, stride=2)
        self.layer4 = self._make_layer(512, 2, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(512, num_classes)
    def _make_layer(self, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for s in strides:
            layers.append(BasicBlock(self.in_planes, planes, s))
            self.in_planes = planes * BasicBlock.expansion
        return nn.Sequential(*layers)
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out


def build_teacher(num_classes=10):
    # Build a CIFAR-10 style ResNet18 from scratch
    return ResNet18(num_classes)

# Point teacher checkpoint to checkpoints/kd_teacher.pth
TEACHER_CKPT = os.path.join(CKPT_DIR, "kd_teacher.pth")

# Train/eval utilities
@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, device: torch.device) -> Tuple[float, float]:
    model.eval()
    total, correct, loss_sum = 0, 0, 0.0
    criterion = nn.CrossEntropyLoss()
    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        out = model(x)
        logits = out[0] if isinstance(out, (tuple, list)) else out
        loss = criterion(logits, y)
        loss_sum += loss.item() * y.size(0)
        pred = logits.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.size(0)
    acc = correct / total
    avg_loss = loss_sum / total
    return acc, avg_loss

@dataclass
class TrainResult:
    train_acc: float
    test_acc: float
    train_time_sec: float


def top1_acc(logits: torch.Tensor, y: torch.Tensor) -> float:
    return (logits.argmax(dim=1) == y).float().mean().item()


def train_ce(model: nn.Module, loader: DataLoader, optimizer, device: torch.device, scaler: Optional[GradScaler] = None):
    model.train()
    criterion = nn.CrossEntropyLoss()
    total_correct, total_samples = 0, 0
    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        if scaler is not None:
            with autocast():
                out = model(x)
                logits = out[0] if isinstance(out, (tuple, list)) else out
                loss = criterion(logits, y)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            out = model(x)
            logits = out[0] if isinstance(out, (tuple, list)) else out
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()
        total_correct += (logits.argmax(1) == y).sum().item()
        total_samples += y.size(0)
    return total_correct / max(1, total_samples)

def flatten_features(x):
    return torch.flatten(x, 1)

# Hooks to grab intermediate features (for feature distillation methods)
class FeatureHook:
    def __init__(self, module: nn.Module):
        self.feat = None
        module.register_forward_hook(self.hook)
    def hook(self, module, input, output):
        self.feat = output

# Teacher build/load
teacher = build_teacher(NUM_CLASSES).to(DEVICE)

TEACHER_OPT_LR = float(os.environ.get("TEACHER_OPT_LR", 0.1))
TEACHER_WD = float(os.environ.get("TEACHER_WD", 5e-4))

if os.path.isfile(TEACHER_CKPT):
    teacher.load_state_dict(torch.load(TEACHER_CKPT, map_location=DEVICE))
    print(f"Loaded teacher weights from {TEACHER_CKPT}")
else:
    print("Teacher checkpoint not found. Training ResNet18 teacher from scratch on CIFAR-10 (select best by val acc).")
    optimizer_t = optim.SGD(teacher.parameters(), lr=TEACHER_OPT_LR, momentum=0.9, weight_decay=TEACHER_WD)
    scheduler_t = optim.lr_scheduler.CosineAnnealingLR(optimizer_t, T_max=TEACHER_EPOCHS)
    scaler_t = GradScaler(enabled=torch.cuda.is_available())
    best_val, best_state = 0.0, None
    start_t = time.time()
    for e in range(TEACHER_EPOCHS):
        acc_train = train_ce(teacher, train_loader, optimizer_t, DEVICE, scaler=scaler_t)
        acc_val, _ = evaluate(teacher, val_loader, DEVICE)
        if acc_val > best_val:
            best_val = acc_val
            best_state = copy.deepcopy(teacher.state_dict())
            torch.save(best_state, TEACHER_CKPT)
            tag = " (saved best)"
        else:
            tag = ""
        scheduler_t.step()
        print(f"Teacher epoch {e+1}/{TEACHER_EPOCHS} - train_acc: {acc_train:.4f} - val_acc: {acc_val:.4f}{tag}")
    elapsed_t = time.time() - start_t
    if best_state is not None:
        teacher.load_state_dict(best_state)
    print(f"Saved best teacher to {TEACHER_CKPT}. Best val_acc: {best_val:.4f}. Training time: {elapsed_t:.1f}s")

for p in teacher.parameters():
    p.requires_grad_(False)
teacher.eval()

acc_test, _ = evaluate(teacher, test_loader, DEVICE)
print(f"Teacher test acc: {acc_test:.4f}")

print("KD_EPOCHS =", KD_EPOCHS)

Device: cuda


100%|██████████| 170M/170M [00:02<00:00, 66.4MB/s]


Teacher checkpoint not found. Training ResNet18 teacher from scratch on CIFAR-10 (select best by val acc).
Teacher epoch 1/100 - train_acc: 0.2973 - val_acc: 0.4008 (saved best)
Teacher epoch 2/100 - train_acc: 0.4454 - val_acc: 0.4654 (saved best)
Teacher epoch 3/100 - train_acc: 0.5633 - val_acc: 0.5748 (saved best)
Teacher epoch 4/100 - train_acc: 0.6345 - val_acc: 0.5894 (saved best)
Teacher epoch 5/100 - train_acc: 0.6826 - val_acc: 0.6918 (saved best)
Teacher epoch 6/100 - train_acc: 0.7226 - val_acc: 0.7206 (saved best)
Teacher epoch 7/100 - train_acc: 0.7502 - val_acc: 0.7362 (saved best)
Teacher epoch 8/100 - train_acc: 0.7654 - val_acc: 0.6730
Teacher epoch 9/100 - train_acc: 0.7767 - val_acc: 0.7410 (saved best)
Teacher epoch 10/100 - train_acc: 0.7874 - val_acc: 0.7772 (saved best)
Teacher epoch 11/100 - train_acc: 0.7947 - val_acc: 0.7330
Teacher epoch 12/100 - train_acc: 0.7983 - val_acc: 0.7568
Teacher epoch 13/100 - train_acc: 0.8023 - val_acc: 0.7812 (saved best)
Teach

## So sánh mô hình teacher và student

In [2]:
  # %% So sánh teacher vs student: tham số, layer, FLOPs/MACs, Latency
from collections import defaultdict

# Choose the canonical student used by most methods
student_ref = SmallNet(NUM_CLASSES).to(DEVICE)

def count_params(m: nn.Module):
    total = sum(p.numel() for p in m.parameters())
    trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
    return total, trainable


def count_layers(m: nn.Module):
    # count conv + linear layers as "layers"
    return sum(isinstance(mod, (nn.Conv2d, nn.Linear)) for mod in m.modules())


# Lightweight FLOPs/MACs estimation using hooks (counts MACs ~ multiply-adds)
def estimate_macs(model: nn.Module, input_size=(1, 3, 32, 32)):
    macs = 0

    def conv_hook(self, inp, out):
        nonlocal macs
        # inp[0]: N,Cin,Hin,Win ; out: N,Cout,Hout,Wout
        x = inp[0]
        N, Cin, Hin, Win = x.shape
        Cout, Hout, Wout = out.shape[1:]
        kH, kW = self.kernel_size if isinstance(self.kernel_size, tuple) else (self.kernel_size, self.kernel_size)
        # MACs per output element: Cin * kH * kW
        macs += N * Cout * Hout * Wout * Cin * kH * kW

    def linear_hook(self, inp, out):
        nonlocal macs
        # inp[0]: N, in_features ; out: N, out_features
        N, in_f = inp[0].shape
        out_f = out.shape[1]
        macs += N * in_f * out_f

    hooks = []
    for mod in model.modules():
        if isinstance(mod, nn.Conv2d):
            hooks.append(mod.register_forward_hook(conv_hook))
        elif isinstance(mod, nn.Linear):
            hooks.append(mod.register_forward_hook(linear_hook))

    model.eval()
    with torch.no_grad():
        dummy = torch.randn(*input_size, device=DEVICE)
        _ = model(dummy)

    for h in hooks:
        h.remove()

    # FLOPs ~ 2 * MACs if counting MUL+ADD as two ops. Report both.
    flops = 2 * macs
    return macs, flops


def pretty(n):
    # format large numbers with units
    for unit in ["", "K", "M", "B", "T"]:
        if abs(n) < 1000:
            return f"{n:.2f}{unit}"
        n /= 1000
    return f"{n:.2f}P"


def measure_latency_ms_per_image(model: nn.Module, batch_size: int = 128, repeats: int = 30, warmup: int = 10):
    model.eval()
    x = torch.randn(batch_size, 3, 32, 32, device=DEVICE)
    # Warmup
    with torch.no_grad():
        for _ in range(warmup):
            _ = model(x)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    # Timed runs
    times = []
    with torch.no_grad():
        for _ in range(repeats):
            t0 = time.perf_counter()
            _ = model(x)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            t1 = time.perf_counter()
            times.append((t1 - t0))
    avg_s = sum(times) / len(times)
    ms_per_image = (avg_s / batch_size) * 1000.0
    return ms_per_image


def report_model_stats(name: str, m: nn.Module):
    total, trainable = count_params(m)
    layers = count_layers(m)
    macs, flops = estimate_macs(m)
    latency = measure_latency_ms_per_image(m)
    print(
        f"[{name}]\n"
        f"  Params: {total} (trainable: {trainable}) -> {pretty(total)} params\n"
        f"  Layers (Conv+FC): {layers}\n"
        f"  MACs (32x32): {pretty(macs)}  |  FLOPs~ {pretty(flops)}\n"
        f"  Latency: ~{latency:.3f} ms / image (batch=128, avg over repeats)\n"
    )

print("So sánh mô hình teacher và student (đầu vào 32x32):\n")
report_model_stats("Teacher (ResNet18)", teacher)
report_model_stats("Student (SmallNet)", student_ref)

# Cleanup
del student_ref
torch.cuda.empty_cache() if torch.cuda.is_available() else None

So sánh mô hình teacher và student (đầu vào 32x32):

[Teacher (ResNet18)]
  Params: 11173962 (trainable: 0) -> 11.17M params
  Layers (Conv+FC): 21
  MACs (32x32): 555.42M  |  FLOPs~ 1.11B
  Latency: ~0.162 ms / image (batch=128, avg over repeats)

[Student (SmallNet)]
  Params: 141354 (trainable: 141354) -> 141.35K params
  Layers (Conv+FC): 6
  MACs (32x32): 29.20M  |  FLOPs~ 58.40M
  Latency: ~0.016 ms / image (batch=128, avg over repeats)



## Baseline: Train-from-scratch cho student

In [3]:
# %% Train student WITHOUT teacher (baseline CE) and save kd_student_v0.pth
# Hyperparams
LR = 0.1
WEIGHT_DECAY = 5e-4

# Build student
student = SmallNet(NUM_CLASSES).to(DEVICE)
optimizer = optim.SGD(student.parameters(), lr=LR, momentum=0.9, weight_decay=WEIGHT_DECAY)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=KD_EPOCHS)
scaler = GradScaler(enabled=torch.cuda.is_available())

best_val, best_state = 0.0, None
start = time.time()
for epoch in range(KD_EPOCHS):
    # one epoch CE-only
    train_acc_epoch = train_ce(student, train_loader, optimizer, DEVICE, scaler=scaler)
    val_acc_epoch, _ = evaluate(student, val_loader, DEVICE)
    if val_acc_epoch > best_val:
        best_val = val_acc_epoch
        best_state = copy.deepcopy(student.state_dict())
        tag = " (best)"
    else:
        tag = ""
    scheduler.step()
    print(f"Epoch {epoch+1}/{KD_EPOCHS} - train_acc: {train_acc_epoch:.4f} - val_acc: {val_acc_epoch:.4f}{tag}")

elapsed = time.time() - start
if best_state is not None:
    student.load_state_dict(best_state)

# Save checkpoint
os.makedirs(CKPT_DIR, exist_ok=True)
ckpt_path = os.path.join(CKPT_DIR, "kd_student_v0.pth")
torch.save(student.state_dict(), ckpt_path)
print(f"Saved student checkpoint (v0) to {ckpt_path}")

# Final report
train_acc_final, _ = evaluate(student, train_loader, DEVICE)
val_acc_final, _ = evaluate(student, val_loader, DEVICE)
test_acc_final, _ = evaluate(student, test_loader, DEVICE)
print({
    "method": "Baseline CE (no teacher)",
    "train_time_sec": round(elapsed, 2),
    "train_acc": round(train_acc_final, 4),
    "val_acc": round(val_acc_final, 4),
    "test_acc": round(test_acc_final, 4)
})

Epoch 1/10 - train_acc: 0.4287 - val_acc: 0.4850 (best)
Epoch 2/10 - train_acc: 0.6060 - val_acc: 0.5872 (best)
Epoch 3/10 - train_acc: 0.6714 - val_acc: 0.5096
Epoch 4/10 - train_acc: 0.7094 - val_acc: 0.6482 (best)
Epoch 5/10 - train_acc: 0.7416 - val_acc: 0.6676 (best)
Epoch 6/10 - train_acc: 0.7661 - val_acc: 0.7648 (best)
Epoch 7/10 - train_acc: 0.7890 - val_acc: 0.7742 (best)
Epoch 8/10 - train_acc: 0.8105 - val_acc: 0.7892 (best)
Epoch 9/10 - train_acc: 0.8344 - val_acc: 0.8280 (best)
Epoch 10/10 - train_acc: 0.8469 - val_acc: 0.8378 (best)
Saved student checkpoint (v0) to ./checkpoints/kd_student_v0.pth
{'method': 'Baseline CE (no teacher)', 'train_time_sec': 109.38, 'train_acc': 0.8603, 'val_acc': 0.8378, 'test_acc': 0.8277}


## Phương pháp 1: Vanilla KD

**Ý tưởng:** Để student học theo "soft targets" của  thay vì chỉ dựa vào nhãn cứng. Soft targets giữ lại thông tin về độ tự tin giữa các lớp.

Ký hiệu cho một mẫu $(x, y)$:
- $z_t$ = logits của teacher, $z_s$ = logits của student.
- Nhiệt độ (temperature) $\tau>0$ làm mềm phân phối:  
  $$
  p_t^{(\tau)} = \mathrm{softmax}\!\left(\frac{z_t}{\tau}\right), \quad
  p_s^{(\tau)} = \mathrm{softmax}\!\left(\frac{z_s}{\tau}\right).
  $$

Hàm mất mát KD dùng KL-divergence giữa phân phối mềm của teacher và student, có hệ số hiệu chỉnh $\tau^2$:
$$
\mathcal{L}_{\mathrm{KD}} = \tau^2 \, \mathrm{KL}\big( p_t^{(\tau)} \,\Vert\, p_s^{(\tau)} \big).
$$

Kết hợp với cross-entropy (CE) chuẩn theo nhãn thật $y$:
$$
\mathcal{L} = \alpha\,\mathcal{L}_{\mathrm{KD}} + (1-\alpha)\, \mathrm{CE}(z_s, y).
$$
- $\alpha\in[0,1]$ điều chỉnh tỷ trọng giữa “học theo giáo viên” và “học theo nhãn thật”.
- $\tau$ lớn làm phân phối mềm hơn (giảm cực đoan), giúp student học được cấu trúc liên lớp.

In [4]:
# %% Train: Vanilla KD
# Hyperparams for KD
T = 4.0
ALPHA = 0.7
LR = 0.1

# KD loss function for Vanilla KD
def kd_loss_vanilla(logits_s, logits_t, y, T=4.0, alpha=0.5):
    ce = F.cross_entropy(logits_s, y)
    p_s = F.log_softmax(logits_s / T, dim=1)
    p_t = F.softmax(logits_t / T, dim=1)
    kd = F.kl_div(p_s, p_t, reduction='batchmean') * (T * T)
    return alpha * kd + (1 - alpha) * ce

student = SmallNet(NUM_CLASSES).to(DEVICE)
optimizer = optim.SGD(student.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=KD_EPOCHS)
scaler = GradScaler(enabled=torch.cuda.is_available())

best_val, best_state = 0.0, None
start = time.time()
for epoch in range(KD_EPOCHS):
    student.train()
    total_correct, total_samples = 0, 0
    for x, y in train_loader:
        x, y = x.to(DEVICE, non_blocking=True), y.to(DEVICE, non_blocking=True)
        with torch.no_grad():
            logits_t = teacher(x)
        optimizer.zero_grad(set_to_none=True)
        with autocast(enabled=torch.cuda.is_available()):
            logits_s = student(x)
            loss = kd_loss_vanilla(logits_s, logits_t, y, T=T, alpha=ALPHA)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_correct += (logits_s.argmax(1) == y).sum().item()
        total_samples += y.size(0)
    scheduler.step()
    train_acc_epoch = total_correct / max(1, total_samples)
    val_acc_epoch, _ = evaluate(student, val_loader, DEVICE)
    if val_acc_epoch > best_val:
        best_val = val_acc_epoch
        best_state = copy.deepcopy(student.state_dict())
        tag = " (best)"
    else:
        tag = ""
    print(f"Epoch {epoch+1}/{KD_EPOCHS} - train_acc: {train_acc_epoch:.4f} - val_acc: {val_acc_epoch:.4f}{tag}")

elapsed = time.time() - start
if best_state is not None:
    student.load_state_dict(best_state)

# Save student checkpoint for method 1
os.makedirs(CKPT_DIR, exist_ok=True)
ckpt_path = os.path.join(CKPT_DIR, "kd_student_v1.pth")
torch.save(student.state_dict(), ckpt_path)
print(f"Saved student checkpoint (v1) to {ckpt_path}")

train_acc_final, _ = evaluate(student, train_loader, DEVICE)
val_acc_final, _ = evaluate(student, val_loader, DEVICE)
test_acc_final, _ = evaluate(student, test_loader, DEVICE)
print({
    "method": "Vanilla KD",
    "train_time_sec": round(elapsed, 2),
    "train_acc": round(train_acc_final, 4),
    "val_acc": round(val_acc_final, 4),
    "test_acc": round(test_acc_final, 4)
})

Epoch 1/10 - train_acc: 0.4361 - val_acc: 0.5826 (best)
Epoch 2/10 - train_acc: 0.6066 - val_acc: 0.6080 (best)
Epoch 3/10 - train_acc: 0.6764 - val_acc: 0.6750 (best)
Epoch 4/10 - train_acc: 0.7204 - val_acc: 0.7090 (best)
Epoch 5/10 - train_acc: 0.7541 - val_acc: 0.6478
Epoch 6/10 - train_acc: 0.7768 - val_acc: 0.7130 (best)
Epoch 7/10 - train_acc: 0.7994 - val_acc: 0.7596 (best)
Epoch 8/10 - train_acc: 0.8232 - val_acc: 0.8162 (best)
Epoch 9/10 - train_acc: 0.8387 - val_acc: 0.8324 (best)
Epoch 10/10 - train_acc: 0.8541 - val_acc: 0.8432 (best)
Saved student checkpoint (v1) to ./checkpoints/kd_student_v1.pth
{'method': 'Vanilla KD', 'train_time_sec': 137.3, 'train_acc': 0.8638, 'val_acc': 0.8432, 'test_acc': 0.8413}


## Phương pháp 2: Hard-label Distillation

**Ý tưởng:** Biến dự đoán của teacher thành ''nhãn cứng'' giả (pseudo-label) rồi kết hợp với nhãn thật.

Với một mẫu $(x, y)$:
- $z_t$ = logits của teacher, nhãn giả của teacher là:  
  $$
  \tilde{y} = \arg\max\limits_{c}\; z_{t,c}.
  $$
- $z_s$ = logits của student.

Hàm mất mát kết hợp hai cross-entropy:
$$
\mathcal{L}_{\mathrm{CEE}} = \beta\,\mathrm{CE}(z_s, \tilde{y}) + (1-\beta)\, \mathrm{CE}(z_s, y).
$$
- Thành phần $\mathrm{CE}(z_s, y)$ giúp bám sát nhãn thật.
- Thành phần $\mathrm{CE}(z_s, \tilde{y})$ ép student bắt chước dự đoán mạnh nhất của teacher (hard target).
- $\beta$ điều chỉnh mức tin cậy vào teacher.

Khác với Vanilla KD, CEE không dùng nhiệt độ hay phân phối mềm; nó chỉ dựa vào lớp có xác suất cao nhất của teacher.

In [5]:
# %% Train: Hard-label Distillation (CEE)
BETA = 0.7
LR = 0.1

# CEE loss function
def cee_loss(logits_s, logits_t, y, beta=0.7):
    # CEE: combine CE(student, y) and CE(student, teacher_label)
    ce_y = F.cross_entropy(logits_s, y)
    pseudo = logits_t.argmax(dim=1)
    ce_t = F.cross_entropy(logits_s, pseudo)
    return beta * ce_t + (1 - beta) * ce_y

student = SmallNet(NUM_CLASSES).to(DEVICE)
optimizer = optim.SGD(student.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=KD_EPOCHS)
scaler = GradScaler(enabled=torch.cuda.is_available())

best_val, best_state = 0.0, None
start = time.time()
for epoch in range(KD_EPOCHS):
    student.train()
    total_correct, total_samples = 0, 0
    for x, y in train_loader:
        x, y = x.to(DEVICE, non_blocking=True), y.to(DEVICE, non_blocking=True)
        with torch.no_grad():
            logits_t = teacher(x)
        optimizer.zero_grad(set_to_none=True)
        with autocast(enabled=torch.cuda.is_available()):
            logits_s = student(x)
            loss = cee_loss(logits_s, logits_t, y, beta=BETA)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_correct += (logits_s.argmax(1) == y).sum().item()
        total_samples += y.size(0)
    scheduler.step()
    train_acc_epoch = total_correct / max(1, total_samples)
    val_acc_epoch, _ = evaluate(student, val_loader, DEVICE)
    if val_acc_epoch > best_val:
        best_val = val_acc_epoch
        best_state = copy.deepcopy(student.state_dict())
        tag = " (best)"
    else:
        tag = ""
    print(f"Epoch {epoch+1}/{KD_EPOCHS} - train_acc: {train_acc_epoch:.4f} - val_acc: {val_acc_epoch:.4f}{tag}")

elapsed = time.time() - start
if best_state is not None:
    student.load_state_dict(best_state)

# Save student checkpoint for method 2
os.makedirs(CKPT_DIR, exist_ok=True)
ckpt_path = os.path.join(CKPT_DIR, "kd_student_v2.pth")
torch.save(student.state_dict(), ckpt_path)
print(f"Saved student checkpoint (v2) to {ckpt_path}")

train_acc_final, _ = evaluate(student, train_loader, DEVICE)
val_acc_final, _ = evaluate(student, val_loader, DEVICE)
test_acc_final, _ = evaluate(student, test_loader, DEVICE)
print({
    "method": "Hard-label Distillation (CEE)",
    "train_time_sec": round(elapsed, 2),
    "train_acc": round(train_acc_final, 4),
    "val_acc": round(val_acc_final, 4),
    "test_acc": round(test_acc_final, 4)
})

Epoch 1/10 - train_acc: 0.4412 - val_acc: 0.5120 (best)
Epoch 2/10 - train_acc: 0.6141 - val_acc: 0.5824 (best)
Epoch 3/10 - train_acc: 0.6813 - val_acc: 0.6470 (best)
Epoch 4/10 - train_acc: 0.7173 - val_acc: 0.6382
Epoch 5/10 - train_acc: 0.7433 - val_acc: 0.6976 (best)
Epoch 6/10 - train_acc: 0.7716 - val_acc: 0.6708
Epoch 7/10 - train_acc: 0.7925 - val_acc: 0.7250 (best)
Epoch 8/10 - train_acc: 0.8150 - val_acc: 0.7822 (best)
Epoch 9/10 - train_acc: 0.8374 - val_acc: 0.8288 (best)
Epoch 10/10 - train_acc: 0.8497 - val_acc: 0.8496 (best)
Saved student checkpoint (v2) to ./checkpoints/kd_student_v2.pth
{'method': 'Hard-label Distillation (CEE)', 'train_time_sec': 136.03, 'train_acc': 0.8627, 'val_acc': 0.8496, 'test_acc': 0.8363}


## Phương pháp 3: Feature Distillation

**Ý tưởng:** Khớp đặc trưng trung gian (feature maps) giữa teacher và student để student học biểu diễn gần giống teacher.

Ký hiệu theo minibatch:
- $F_t \in \mathbb{R}^{N\times C_t\times H_t\times W_t}$: đặc trưng của teacher ở một tầng (ví dụ layer cuối conv).
- $F_s \in \mathbb{R}^{N\times C_s\times H_s\times W_s}$: đặc trưng của student ở tầng tương ứng.
- Do số kênh/không gian khác nhau, ta dùng một “đầu chiếu” $g_s(\cdot)$ để đưa $F_s$ về không gian của $F_t$ và/hoặc nội suy không gian về cùng kích thước.
- Chuẩn hoá theo kênh để giảm lệch về biên độ:  
  $$
  \widehat{F} = \frac{F}{\sqrt{\sum\limits_{c} F_c^2}+\varepsilon}.
  $$

Hàm mất mát tổng hợp:
$$
\mathcal{L} = \mathrm{CE}(z_s, y) + \lambda_{\mathrm{feat}}\, \big\|\, \widehat{g_s(F_s)} - \widehat{F_t} \,\big\|_2^2.
$$
- CE đảm bảo mục tiêu phân loại; 
- MSE giữa đặc trưng đã chuẩn hoá giúp student học cấu trúc biểu diễn của teacher.
- Trong thực nghiệm thường tăng dần hệ số $\lambda_{\mathrm{feat}}$ (ramp-up) để ổn định huấn luyện.

In [6]:
# %% Train: Feature Distillation (penultimate features)
# We'll tap student at its final conv output (before average pool) and teacher at layer4 output.
LR = 0.1
W_FEAT = 50.0  # lower weight; we will ramp it up during training

# Keep student architecture unchanged; just expose features
class StudentExposeFeat(SmallNet):
    def forward(self, x):
        f = self.features(x)  # N,128,4,4
        logits = self.classifier(f)
        return logits, f

student = StudentExposeFeat(NUM_CLASSES).to(DEVICE)

# Teacher hook at last conv block (layer4)
if not hasattr(teacher, 'layer4'):
    raise RuntimeError("Teacher doesn't have layer4; choose a ResNet-like teacher.")

hook_t = FeatureHook(teacher.layer4)

# Separate projection head (FitNet-style regressor) outside the student
proj_s = nn.Sequential(
    nn.Conv2d(128, 512, kernel_size=1, bias=False),
    nn.BatchNorm2d(512),
    nn.ReLU(inplace=True),
).to(DEVICE)

optimizer = optim.SGD(list(student.parameters()) + list(proj_s.parameters()), lr=LR, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=KD_EPOCHS)
scaler = GradScaler(enabled=torch.cuda.is_available())

# channel-wise L2 normalization helper
def norm_channel(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    return x / (x.pow(2).sum(dim=1, keepdim=True).sqrt().clamp_min(eps))

best_val, best_state = 0.0, None
start = time.time()
for epoch in range(KD_EPOCHS):
    student.train(); proj_s.train()
    total_correct, total_samples = 0, 0
    # linear ramp-up for feature loss (first half epochs)
    ramp = min(1.0, (epoch + 1) / max(1, KD_EPOCHS // 2))
    feat_w = W_FEAT * ramp
    for x, y in train_loader:
        x, y = x.to(DEVICE, non_blocking=True), y.to(DEVICE, non_blocking=True)
        with torch.no_grad():
            _ = teacher(x)  # populate hook_t.feat
            f_t = hook_t.feat
        optimizer.zero_grad(set_to_none=True)
        with autocast(enabled=torch.cuda.is_available()):
            logits_s, f_s = student(x)
            f_s_proj = proj_s(f_s)  # project to 512 channels
            # Align spatial dims if needed using adaptive avgpool to teacher spatial size
            if f_t is None:
                raise RuntimeError("Teacher feature hook not captured.")
            if f_s_proj.shape[-2:] != f_t.shape[-2:]:
                f_s_resized = F.adaptive_avg_pool2d(f_s_proj, f_t.shape[-2:])
            else:
                f_s_resized = f_s_proj
            # Normalize features along channel dimension to reduce scale mismatch
            nf_s = norm_channel(f_s_resized)
            nf_t = norm_channel(f_t)
            loss_ce = F.cross_entropy(logits_s, y)
            loss_feat = F.mse_loss(nf_s, nf_t.detach())
            loss = loss_ce + feat_w * loss_feat
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_correct += (logits_s.argmax(1) == y).sum().item()
        total_samples += y.size(0)
    scheduler.step()
    train_acc_epoch = total_correct / max(1, total_samples)
    student.eval(); proj_s.eval()
    val_acc_epoch, _ = evaluate(student, val_loader, DEVICE)
    if val_acc_epoch > best_val:
        best_val = val_acc_epoch
        best_state = {
            'student': copy.deepcopy(student.state_dict()),
            'proj_s': copy.deepcopy(proj_s.state_dict()),
        }
        tag = " (best)"
    else:
        tag = ""
    print(f"Epoch {epoch+1}/{KD_EPOCHS} - train_acc: {train_acc_epoch:.4f} - val_acc: {val_acc_epoch:.4f} - feat_w: {feat_w:.1f}{tag}")

elapsed = time.time() - start
if best_state is not None:
    student.load_state_dict(best_state['student'])
    proj_s.load_state_dict(best_state['proj_s'])

# Save student checkpoint for method 3
os.makedirs(CKPT_DIR, exist_ok=True)
ckpt_path = os.path.join(CKPT_DIR, "kd_student_v3.pth")
torch.save(student.state_dict(), ckpt_path)
print(f"Saved student checkpoint (v3) to {ckpt_path}")

train_acc_final, _ = evaluate(student, train_loader, DEVICE)
val_acc_final, _ = evaluate(student, val_loader, DEVICE)
test_acc_final, _ = evaluate(student, test_loader, DEVICE)
print({
    "method": "Feature Distillation (MSE)",
    "train_time_sec": round(elapsed, 2),
    "train_acc": round(train_acc_final, 4),
    "val_acc": round(val_acc_final, 4),
    "test_acc": round(test_acc_final, 4)
})

Epoch 1/10 - train_acc: 0.4300 - val_acc: 0.5252 - feat_w: 10.0 (best)
Epoch 2/10 - train_acc: 0.6136 - val_acc: 0.5818 - feat_w: 20.0 (best)
Epoch 3/10 - train_acc: 0.6775 - val_acc: 0.5326 - feat_w: 30.0
Epoch 4/10 - train_acc: 0.7166 - val_acc: 0.6638 - feat_w: 40.0 (best)
Epoch 5/10 - train_acc: 0.7431 - val_acc: 0.7366 - feat_w: 50.0 (best)
Epoch 6/10 - train_acc: 0.7678 - val_acc: 0.7694 - feat_w: 50.0 (best)
Epoch 7/10 - train_acc: 0.7936 - val_acc: 0.7200 - feat_w: 50.0
Epoch 8/10 - train_acc: 0.8144 - val_acc: 0.8056 - feat_w: 50.0 (best)
Epoch 9/10 - train_acc: 0.8353 - val_acc: 0.8348 - feat_w: 50.0 (best)
Epoch 10/10 - train_acc: 0.8498 - val_acc: 0.8494 - feat_w: 50.0 (best)
Saved student checkpoint (v3) to ./checkpoints/kd_student_v3.pth
{'method': 'Feature Distillation (MSE)', 'train_time_sec': 143.98, 'train_acc': 0.8634, 'val_acc': 0.8494, 'test_acc': 0.8349}


## Phương pháp 4: Attention Transfer

**Ý tưởng:** Khớp toàn bộ feature, AT khớp "bản đồ chú ý" (attention map) – độ mạnh tổng hợp theo kênh ở từng vị trí không gian.

Cho feature $F \in \mathbb{R}^{N\times C\times H\times W}$, bản đồ chú ý $A \in \mathbb{R}^{N\times H\times W}$ được định nghĩa (một biến thể hay dùng):
$$
A = \frac{\tfrac{1}{C}\sum\limits_{c=1}^C F_c^2}{\left\|\, \tfrac{1}{C}\sum\limits_{c=1}^C F_c^2 \right\|_2 + \varepsilon}.
$$

Với $A_s, A_t$ lần lượt của student và teacher (đã chuẩn hoá), hàm mất mát AT:
$$
\mathcal{L}_{\mathrm{AT}} = \lambda_{\mathrm{AT}}\, \big\| A_s - A_t \big\|_2^2.
$$
Tổng mất mát:
$$
\mathcal{L} = \mathrm{CE}(z_s, y) + \mathcal{L}_{\mathrm{AT}}.
$$

AT truyền "nơi nào quan trọng" trong ảnh theo teacher. Student học tập trung vào vùng hữu ích thay vì khớp mọi chi tiết của feature.

In [7]:
# %% Train: Attention Transfer
LR = 0.05
W_AT = 250.0

# Attention Transfer loss function
def attention_transfer_loss(f_s, f_t, w=1.0, eps=1e-6):
    # Attention Transfer (Zagoruyko & Komodakis): match normalized spatial attention maps
    def att_map(f):
        # f: N, C, H, W -> N, H, W
        am = f.pow(2).mean(dim=1)
        am = am / (am.flatten(1).norm(p=2, dim=1, keepdim=True).clamp_min(eps).view(-1,1,1))
        return am
    a_s, a_t = att_map(f_s), att_map(f_t)
    return w * F.mse_loss(a_s, a_t.detach())

class StudentWithFeatAT(SmallNet):
    def forward(self, x):
        f = self.features(x)
        logits = self.classifier(f)
        return logits, f

student = StudentWithFeatAT(NUM_CLASSES).to(DEVICE)

hook_t = FeatureHook(teacher.layer4)
optimizer = optim.SGD(student.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=KD_EPOCHS)
scaler = GradScaler(enabled=torch.cuda.is_available())

best_val, best_state = 0.0, None
start = time.time()
for epoch in range(KD_EPOCHS):
    student.train()
    total_correct, total_samples = 0, 0
    for x, y in train_loader:
        x, y = x.to(DEVICE, non_blocking=True), y.to(DEVICE, non_blocking=True)
        with torch.no_grad():
            _ = teacher(x)
            f_t = hook_t.feat
        optimizer.zero_grad(set_to_none=True)
        with autocast(enabled=torch.cuda.is_available()):
            logits_s, f_s = student(x)
            if f_t is None:
                raise RuntimeError("Teacher feature hook not captured.")
            if f_s.shape[-2:] != f_t.shape[-2:]:
                f_s_resized = F.adaptive_avg_pool2d(f_s, f_t.shape[-2:])
            else:
                f_s_resized = f_s
            loss = F.cross_entropy(logits_s, y) + attention_transfer_loss(f_s_resized, f_t, w=W_AT)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_correct += (logits_s.argmax(1) == y).sum().item()
        total_samples += y.size(0)
    scheduler.step()
    train_acc_epoch = total_correct / max(1, total_samples)
    val_acc_epoch, _ = evaluate(student, val_loader, DEVICE)
    if val_acc_epoch > best_val:
        best_val = val_acc_epoch
        best_state = copy.deepcopy(student.state_dict())
        tag = " (best)"
    else:
        tag = ""
    print(f"Epoch {epoch+1}/{KD_EPOCHS} - train_acc: {train_acc_epoch:.4f} - val_acc: {val_acc_epoch:.4f}{tag}")

elapsed = time.time() - start
if best_state is not None:
    student.load_state_dict(best_state)

# Save student checkpoint for method 4
os.makedirs(CKPT_DIR, exist_ok=True)
ckpt_path = os.path.join(CKPT_DIR, "kd_student_v4.pth")
torch.save(student.state_dict(), ckpt_path)
print(f"Saved student checkpoint (v4) to {ckpt_path}")

train_acc_final, _ = evaluate(student, train_loader, DEVICE)
val_acc_final, _ = evaluate(student, val_loader, DEVICE)
test_acc_final, _ = evaluate(student, test_loader, DEVICE)
print({
    "method": "Attention Transfer",
    "train_time_sec": round(elapsed, 2),
    "train_acc": round(train_acc_final, 4),
    "val_acc": round(val_acc_final, 4),
    "test_acc": round(test_acc_final, 4)
})

Epoch 1/10 - train_acc: 0.3885 - val_acc: 0.4950 (best)
Epoch 2/10 - train_acc: 0.5659 - val_acc: 0.5994 (best)
Epoch 3/10 - train_acc: 0.6412 - val_acc: 0.5738
Epoch 4/10 - train_acc: 0.6855 - val_acc: 0.6546 (best)
Epoch 5/10 - train_acc: 0.7167 - val_acc: 0.6630 (best)
Epoch 6/10 - train_acc: 0.7469 - val_acc: 0.7638 (best)
Epoch 7/10 - train_acc: 0.7724 - val_acc: 0.7658 (best)
Epoch 8/10 - train_acc: 0.7917 - val_acc: 0.7810 (best)
Epoch 9/10 - train_acc: 0.8083 - val_acc: 0.8066 (best)
Epoch 10/10 - train_acc: 0.8202 - val_acc: 0.8186 (best)
Saved student checkpoint (v4) to ./checkpoints/kd_student_v4.pth
{'method': 'Attention Transfer', 'train_time_sec': 142.78, 'train_acc': 0.833, 'val_acc': 0.8186, 'test_acc': 0.8073}


## Phương pháp 5: Logit Matching

**Ý tưởng:** Ép vector logit của student gần với logit của teacher bằng tổn thất $L_2$, đồng thời giữ CE theo nhãn thật.

Với một mẫu $(x, y)$:
- $z_t, z_s \in \mathbb{R}^C$ là logits (trước softmax) của teacher và student.
- Tổn thất logit matching:
  $$
  \mathcal{L}_{\mathrm{logit}} = \big\| z_s - z_t \big\|_2^2.
  $$
- Tổng mất mát:
  $$
  \mathcal{L} = \mathrm{CE}(z_s, y) + \lambda\, \mathcal{L}_{\mathrm{logit}}.
  $$

Khác Vanilla KD (dùng KL trên phân phối mềm), cách này làm việc trực tiếp ở không gian logit, thường đơn giản và ổn định nhưng có thể kém nhạy với cấu trúc phân phối so với KD dùng temperature.

In [8]:
# %% Train: Logit Matching (L2) + CE
LR = 0.1
W_LOGIT = 1.0

student = SmallNet(NUM_CLASSES).to(DEVICE)
optimizer = optim.SGD(student.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=KD_EPOCHS)
scaler = GradScaler(enabled=torch.cuda.is_available())

best_val, best_state = 0.0, None
start = time.time()
for epoch in range(KD_EPOCHS):
    student.train()
    total_correct, total_samples = 0, 0
    for x, y in train_loader:
        x, y = x.to(DEVICE, non_blocking=True), y.to(DEVICE, non_blocking=True)
        with torch.no_grad():
            logits_t = teacher(x)
        optimizer.zero_grad(set_to_none=True)
        with autocast(enabled=torch.cuda.is_available()):
            logits_s = student(x)
            loss = F.cross_entropy(logits_s, y) + W_LOGIT * F.mse_loss(logits_s, logits_t)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_correct += (logits_s.argmax(1) == y).sum().item()
        total_samples += y.size(0)
    scheduler.step()
    train_acc_epoch = total_correct / max(1, total_samples)
    val_acc_epoch, _ = evaluate(student, val_loader, DEVICE)
    if val_acc_epoch > best_val:
        best_val = val_acc_epoch
        best_state = copy.deepcopy(student.state_dict())
        tag = " (best)"
    else:
        tag = ""
    print(f"Epoch {epoch+1}/{KD_EPOCHS} - train_acc: {train_acc_epoch:.4f} - val_acc: {val_acc_epoch:.4f}{tag}")

elapsed = time.time() - start
if best_state is not None:
    student.load_state_dict(best_state)

# Save student checkpoint for method 5
os.makedirs(CKPT_DIR, exist_ok=True)
ckpt_path = os.path.join(CKPT_DIR, "kd_student_v5.pth")
torch.save(student.state_dict(), ckpt_path)
print(f"Saved student checkpoint (v5) to {ckpt_path}")

train_acc_final, _ = evaluate(student, train_loader, DEVICE)
val_acc_final, _ = evaluate(student, val_loader, DEVICE)
test_acc_final, _ = evaluate(student, test_loader, DEVICE)
print({
    "method": "Logit Matching (L2) + CE",
    "train_time_sec": round(elapsed, 2),
    "train_acc": round(train_acc_final, 4),
    "val_acc": round(val_acc_final, 4),
    "test_acc": round(test_acc_final, 4)
})

Epoch 1/10 - train_acc: 0.3546 - val_acc: 0.4810 (best)
Epoch 2/10 - train_acc: 0.5628 - val_acc: 0.6238 (best)
Epoch 3/10 - train_acc: 0.6456 - val_acc: 0.6496 (best)
Epoch 4/10 - train_acc: 0.7034 - val_acc: 0.5976
Epoch 5/10 - train_acc: 0.7427 - val_acc: 0.7186 (best)
Epoch 6/10 - train_acc: 0.7744 - val_acc: 0.7298 (best)
Epoch 7/10 - train_acc: 0.7980 - val_acc: 0.7794 (best)
Epoch 8/10 - train_acc: 0.8202 - val_acc: 0.7968 (best)
Epoch 9/10 - train_acc: 0.8380 - val_acc: 0.8304 (best)
Epoch 10/10 - train_acc: 0.8542 - val_acc: 0.8510 (best)
Saved student checkpoint (v5) to ./checkpoints/kd_student_v5.pth
{'method': 'Logit Matching (L2) + CE', 'train_time_sec': 135.97, 'train_acc': 0.8618, 'val_acc': 0.851, 'test_acc': 0.8385}


## Phương pháp 6: Focal Knowledge Distillation

**Ý tưởng:** Nhấn mạnh các mẫu/nhãn mà student còn khó (sai hoặc xác suất thấp ở lớp đúng) khi khớp với phân phối của teacher. Thay vì đối xử đồng đều như KL thông thường, Focal KD tái trọng số từng mẫu/lớp theo độ khó, giúp giảm nhiễu ở các trường hợp teacher quá tự tin và tăng cường học ở các trường hợp student còn yếu.

Ký hiệu cho một mẫu $(x, y)$:
- $z_t, z_s \in \mathbb{R}^C$: logits của teacher và student.
- Nhiệt độ $\tau>0$ làm mềm phân phối:
  $$
  p_t^{(\tau)} = \mathrm{softmax}\!\left(\frac{z_t}{\tau}\right),\quad
  p_s^{(\tau)} = \mathrm{softmax}\!\left(\frac{z_s}{\tau}\right).
  $$

Trọng số tiêu điểm (focal) cho từng lớp $c$ phụ thuộc vào độ khó theo student:
- Độ khó ở lớp $c$: $d_c = 1 - p_s^{(\tau)}(c)$.
- Trọng số:
  $$
  w_c = d_c^{\gamma} = \big(1 - p_s^{(\tau)}(c)\big)^{\gamma},\qquad \gamma \ge 0.
  $$
- Khi $\gamma$ lớn, mẫu/lớp khó (student dự đoán thấp) sẽ được nhấn mạnh mạnh hơn.

Hàm mất mát Focal KD dùng KL có trọng số:
$$
\mathcal{L}_{\mathrm{FKD}} = \tau^2\, \sum_{c=1}^C w_c\, p_t^{(\tau)}(c)\, \log\frac{p_t^{(\tau)}(c)}{p_s^{(\tau)}(c)}
\;=\; \tau^2\, \mathrm{KL}\big( p_t^{(\tau)} \,\Vert\, p_s^{(\tau)};\, w\big),
$$
trong đó ký hiệu $\mathrm{KL}(\cdot\,\Vert\,\cdot; w)$ là KL được tính với trọng số theo thành phần.

Tổng mất mát huấn luyện:
$$
\mathcal{L} = \mathrm{CE}(z_s, y) + \lambda_{\mathrm{FKD}}\, \mathcal{L}_{\mathrm{FKD}}.
$$
- $\tau$ điều chỉnh độ mềm; nhân $\tau^2$ để cân bằng thang độ như KD cổ điển.
- $\gamma$ điều chỉnh mức độ “tiêu điểm”; $\lambda_{\mathrm{FKD}}$ cân bằng giữa CE và FKD.
- Ưu điểm: tập trung vào những phần student còn yếu, giảm tác động của các thành phần đã dễ; thường cải thiện độ chính xác khi Vanilla KD/DKD kém hiệu quả do mất cân bằng độ khó.

In [9]:
# %% Train: Focal Knowledge Distillation (FKD)
T = 4.0
GAMMA_FKD = 2.0      # focal gamma
LAMBDA_FKD = 1.0     # weight for FKD loss
LR = 0.1


def fkd_loss(logits_s, logits_t, T=4.0, gamma=2.0, eps=1e-12):
    # Probabilities with temperature
    ps = F.softmax(logits_s / T, dim=1)   # N,C
    pt = F.softmax(logits_t / T, dim=1)   # N,C

    # Focal weights: higher weight where student is uncertain (harder)
    w = (1.0 - ps).clamp(min=0.0, max=1.0) ** gamma   # N,C

    # Weighted KL(pt || ps)
    # Compute element-wise: pt * (log pt - log ps) * w, average over batch
    log_ps = (ps + eps).log()
    log_pt = (pt + eps).log()
    kl_elem = w * pt * (log_pt - log_ps)
    loss = kl_elem.sum(dim=1).mean()  # batch mean
    return (T * T) * loss


student = SmallNet(NUM_CLASSES).to(DEVICE)
optimizer = optim.SGD(student.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=KD_EPOCHS)
scaler = GradScaler(enabled=torch.cuda.is_available())

best_val, best_state = 0.0, None
start = time.time()
for epoch in range(KD_EPOCHS):
    student.train()
    total_correct, total_samples = 0, 0
    for x, y in train_loader:
        x, y = x.to(DEVICE, non_blocking=True), y.to(DEVICE, non_blocking=True)
        with torch.no_grad():
            logits_t = teacher(x)

        optimizer.zero_grad(set_to_none=True)
        with autocast(enabled=torch.cuda.is_available()):
            logits_s = student(x)
            loss_ce  = F.cross_entropy(logits_s, y)
            loss_kd  = fkd_loss(logits_s, logits_t, T=T, gamma=GAMMA_FKD)
            loss     = loss_ce + LAMBDA_FKD * loss_kd
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_correct += (logits_s.argmax(1) == y).sum().item()
        total_samples += y.size(0)

    scheduler.step()
    train_acc_epoch = total_correct / max(1, total_samples)
    val_acc_epoch, _ = evaluate(student, val_loader, DEVICE)
    if val_acc_epoch > best_val:
        best_val = val_acc_epoch
        best_state = copy.deepcopy(student.state_dict())
        tag = " (best)"
    else:
        tag = ""
    print(f"Epoch {epoch+1}/{KD_EPOCHS} - train_acc: {train_acc_epoch:.4f} - val_acc: {val_acc_epoch:.4f}{tag}")

elapsed = time.time() - start
if best_state is not None:
    student.load_state_dict(best_state)

# Save student checkpoint for method 6 (FKD) - giữ nguyên tên file để phần đánh giá không cần sửa
os.makedirs(CKPT_DIR, exist_ok=True)
ckpt_path = os.path.join(CKPT_DIR, "kd_student_v6.pth")
torch.save(student.state_dict(), ckpt_path)
print(f"Saved student checkpoint (v6 FKD) to {ckpt_path}")

train_acc_final, _ = evaluate(student, train_loader, DEVICE)
val_acc_final, _ = evaluate(student, val_loader, DEVICE)
test_acc_final, _ = evaluate(student, test_loader, DEVICE)
print({
    "method": "Focal Knowledge Distillation (FKD)",
    "train_time_sec": round(elapsed, 2),
    "train_acc": round(train_acc_final, 4),
    "val_acc": round(val_acc_final, 4),
    "test_acc": round(test_acc_final, 4)
})

Epoch 1/10 - train_acc: 0.3813 - val_acc: 0.4538 (best)
Epoch 2/10 - train_acc: 0.5643 - val_acc: 0.5450 (best)
Epoch 3/10 - train_acc: 0.6486 - val_acc: 0.5306
Epoch 4/10 - train_acc: 0.6991 - val_acc: 0.6566 (best)
Epoch 5/10 - train_acc: 0.7353 - val_acc: 0.6620 (best)
Epoch 6/10 - train_acc: 0.7606 - val_acc: 0.7622 (best)
Epoch 7/10 - train_acc: 0.7887 - val_acc: 0.7702 (best)
Epoch 8/10 - train_acc: 0.8140 - val_acc: 0.7650
Epoch 9/10 - train_acc: 0.8316 - val_acc: 0.8098 (best)
Epoch 10/10 - train_acc: 0.8501 - val_acc: 0.8370 (best)
Saved student checkpoint (v6 FKD) to ./checkpoints/kd_student_v6.pth
{'method': 'Focal Knowledge Distillation (FKD)', 'train_time_sec': 140.84, 'train_acc': 0.8564, 'val_acc': 0.837, 'test_acc': 0.8275}


## Phương pháp 7: Relational Knowledge Distillation

**Ý tưởng:** RKD không khớp trực tiếp đặc trưng/logit mà khớp các “quan hệ” giữa các mẫu trong batch.

Biểu diễn vector hoá (sau pool) và chuẩn hoá $L_2$:  
$$
z_s = \mathrm{norm}(\mathrm{GAP}(F_s)), \quad z_t = \mathrm{norm}(\mathrm{GAP}(F_t)).
$$
với $\mathrm{GAP}$ là global average pooling, $\mathrm{norm}$ là chuẩn hoá $L_2$.

1) Khoảng cách cặp (pairwise distance):
- Ma trận khoảng cách Euclid:  
  $$
  D(z)_{ij} = \| z_i - z_j \|_2.
  $$
- Chuẩn hoá theo trung bình phần tử dương để bất biến tỉ lệ:  
  $$
  \tilde{D}_t = \frac{D(z_t)}{\mathrm{mean}\big( D(z_t)_{ij} : D(z_t)_{ij}>0 \big)}, \quad
  \tilde{D}_s = \frac{D(z_s)}{\mathrm{mean}\big( D(z_s)_{ij} : D(z_s)_{ij}>0 \big)}.
  $$
- Mất mát khoảng cách:  
  $$
  \mathcal{L}_{\mathrm{dist}} = \mathrm{SmoothL1}(\tilde{D}_s, \tilde{D}_t).
  $$

2) Góc bộ ba (triplet angle):
- Với mọi $i, j, k$:  
  $$
  v_{ij} = z_j - z_i,\; v_{ik} = z_k - z_i,\; \cos\angle(jik) = \frac{v_{ij}^\top v_{ik}}{\|v_{ij}\|\,\|v_{ik}\|}.
  $$
- Tập hợp vào tensor $A(z)$ với $A_{i,j,k} = \cos(\angle jik)$.  
  $$
  \mathcal{L}_{\mathrm{angle}} = \mathrm{SmoothL1}\big( A(z_s), A(z_t) \big).
  $$

Tổng mất mát:
$$
\mathcal{L} = \mathrm{CE}(z_s^{\text{logit}}, y) + \lambda_d\, \mathcal{L}_{\mathrm{dist}} + \lambda_a\, \mathcal{L}_{\mathrm{angle}}.
$$
Trong đó $z_s^{\text{logit}}$ là logits cho CE; $z_s, z_t$ cho phần RKD được lấy từ feature đã pool và chuẩn hoá. Cách này truyền cấu trúc hình học của không gian biểu diễn từ teacher sang student.

In [10]:
# %% Train: Relational Knowledge Distillation (RKD)
# Reference: Park et al., CVPR 2019. We implement RKD with both distance and angle losses.
# Idea: match relational structures between samples (pairwise distances and triplet angles) in feature space.

LR = 0.05
W_RKD_DIST = 25.0
W_RKD_ANGLE = 50.0

# Student wrapper to expose features before global pooling
class StudentWithFeatRKD(SmallNet):
    def forward(self, x):
        f = self.features(x)         # N, 128, H, W
        logits = self.classifier(f)  # N, C
        return logits, f

student = StudentWithFeatRKD(NUM_CLASSES).to(DEVICE)

# Teacher feature hook (layer4 output)
if not hasattr(teacher, 'layer4'):
    raise RuntimeError("Teacher doesn't have layer4; choose a ResNet-like teacher.")

hook_t = FeatureHook(teacher.layer4)

# Helper: global-average pool to vectors and L2-normalize
def to_vec_norm(fm: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # fm: N, C, H, W -> N, C
    v = F.adaptive_avg_pool2d(fm, 1).flatten(1)
    v = F.normalize(v, dim=1, eps=eps)
    return v

# RKD: Distance loss (pairwise)
def rkd_distance(z_s: torch.Tensor, z_t: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # pairwise Euclidean distances
    with torch.no_grad():
        # teacher pairwise distances normalized by mean
        d_t = torch.cdist(z_t, z_t, p=2)
        mean_t = d_t[d_t>0].mean().clamp_min(eps)
        d_t = d_t / mean_t
    d_s = torch.cdist(z_s, z_s, p=2)
    mean_s = d_s[d_s>0].mean().clamp_min(eps)
    d_s = d_s / mean_s
    return F.smooth_l1_loss(d_s, d_t)

# RKD: Angle loss (triplet angles)
def rkd_angle(z_s: torch.Tensor, z_t: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # For each anchor i, compute vectors to j and k: v_ij, v_ik, then their angle via cosine
    def angle_matrix(z: torch.Tensor) -> torch.Tensor:
        # z: N, D
        n = z.size(0)
        # compute pairwise differences: v_ij = z_j - z_i -> shape (N,N,D)
        diff = z.unsqueeze(1) - z.unsqueeze(0)
        # normalize along D
        diff = F.normalize(diff, dim=2, eps=eps)
        # cosine between v_ij and v_ik for all (j,k): cos = v_ij · v_ik
        # angle tensor A where A[i,j,k] = cos(angle_jik)
        A = torch.einsum('ijd,ikd->ijk', diff, diff)
        return A
    with torch.no_grad():
        A_t = angle_matrix(z_t)
    A_s = angle_matrix(z_s)
    return F.smooth_l1_loss(A_s, A_t)

optimizer = optim.SGD(student.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=KD_EPOCHS)
scaler = GradScaler(enabled=torch.cuda.is_available())

best_val, best_state = 0.0, None
start = time.time()
for epoch in range(KD_EPOCHS):
    student.train()
    total_correct, total_samples = 0, 0
    for x, y in train_loader:
        x, y = x.to(DEVICE, non_blocking=True), y.to(DEVICE, non_blocking=True)
        with torch.no_grad():
            _ = teacher(x)  # populate hook
            f_t = hook_t.feat
        optimizer.zero_grad(set_to_none=True)
        with autocast(enabled=torch.cuda.is_available()):
            logits_s, f_s = student(x)
            # CE term for classification
            loss_ce = F.cross_entropy(logits_s, y)
            # RKD on pooled features
            z_s = to_vec_norm(f_s)
            z_t = to_vec_norm(f_t)
            loss_dist = rkd_distance(z_s, z_t)
            loss_ang = rkd_angle(z_s, z_t)
            loss = loss_ce + W_RKD_DIST * loss_dist + W_RKD_ANGLE * loss_ang
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_correct += (logits_s.argmax(1) == y).sum().item()
        total_samples += y.size(0)
    scheduler.step()
    train_acc_epoch = total_correct / max(1, total_samples)
    val_acc_epoch, _ = evaluate(student, val_loader, DEVICE)
    if val_acc_epoch > best_val:
        best_val = val_acc_epoch
        best_state = copy.deepcopy(student.state_dict())
        tag = " (best)"
    else:
        tag = ""
    print(f"Epoch {epoch+1}/{KD_EPOCHS} - train_acc: {train_acc_epoch:.4f} - val_acc: {val_acc_epoch:.4f}{tag}")

elapsed = time.time() - start
if best_state is not None:
    student.load_state_dict(best_state)

# Save student checkpoint for method 7 (RKD)
os.makedirs(CKPT_DIR, exist_ok=True)
ckpt_path = os.path.join(CKPT_DIR, "kd_student_v7.pth")
torch.save(student.state_dict(), ckpt_path)
print(f"Saved student checkpoint (v7 RKD) to {ckpt_path}")

train_acc_final, _ = evaluate(student, train_loader, DEVICE)
val_acc_final, _ = evaluate(student, val_loader, DEVICE)
test_acc_final, _ = evaluate(student, test_loader, DEVICE)
print({
    "method": "Relational Knowledge Distillation (RKD)",
    "train_time_sec": round(elapsed, 2),
    "train_acc": round(train_acc_final, 4),
    "val_acc": round(val_acc_final, 4),
    "test_acc": round(test_acc_final, 4)
})

Epoch 1/10 - train_acc: 0.4824 - val_acc: 0.5318 (best)
Epoch 2/10 - train_acc: 0.6400 - val_acc: 0.6478 (best)
Epoch 3/10 - train_acc: 0.6944 - val_acc: 0.6522 (best)
Epoch 4/10 - train_acc: 0.7362 - val_acc: 0.7058 (best)
Epoch 5/10 - train_acc: 0.7631 - val_acc: 0.7214 (best)
Epoch 6/10 - train_acc: 0.7845 - val_acc: 0.7882 (best)
Epoch 7/10 - train_acc: 0.8052 - val_acc: 0.7894 (best)
Epoch 8/10 - train_acc: 0.8208 - val_acc: 0.7898 (best)
Epoch 9/10 - train_acc: 0.8358 - val_acc: 0.8302 (best)
Epoch 10/10 - train_acc: 0.8471 - val_acc: 0.8410 (best)
Saved student checkpoint (v7 RKD) to ./checkpoints/kd_student_v7.pth
{'method': 'Relational Knowledge Distillation (RKD)', 'train_time_sec': 172.05, 'train_acc': 0.8577, 'val_acc': 0.841, 'test_acc': 0.8392}


## Phương pháp 8: Contrastive Representation Distillation

**Ý tưởng:** CRD khớp biểu diễn thông qua mục tiêu tương phản (contrastive) kiểu InfoNCE giữa nhúng (embedding) của student và teacher.

Giả sử có $N$ mẫu trong batch. Sau khi chiếu về cùng không gian bởi hai đầu chiếu học được $h_s(\cdot), h_t(\cdot)$:
- $e_s = h_s(F_s) \in \mathbb{R}^{N\times d}$, $e_t = h_t(F_t) \in \mathbb{R}^{N\times d}$.
- Chuẩn hoá $L_2$: $z_s = e_s/\|e_s\|$, $z_t = e_t/\|e_t\|$ (theo từng mẫu).

Với mỗi $i$, ta coi $z_s[i]$ là query và $z_t[i]$ là positive; các $z_t[j]$ ($j\neq i$) là negatives. Logits tương phản:
$$
\ell_{i,j} = \frac{ z_s[i]^\top z_t[j] }{\tau}, \quad j=1,\dots,N.
$$
Tổn thất InfoNCE trong-batch:
$$
\mathcal{L}_{\mathrm{CRD}} = \frac{1}{N} \sum_{i=1}^N \mathrm{CE}\big( \ell_{i,:}, \, j^*=i \big).
$$

Tổng mất mát:
$$
\mathcal{L} = \mathrm{CE}(z_s^{\text{logit}}, y) + \lambda_{\mathrm{CRD}}\, \mathcal{L}_{\mathrm{CRD}}.
$$
- $\tau$ là nhiệt độ điều chỉnh độ sắc của phân bố tương phản.
- Mục tiêu: kéo cặp (student, teacher) của cùng mẫu lại gần, đẩy xa cặp khác mẫu, giúp student học không gian biểu diễn tương đồng teacher.

In [11]:
# %% Train: Contrastive Representation Distillation (CRD, in-batch)
# Ref: Tian et al., ICLR 2020. We implement a lightweight in-batch CRD:
#   - Take penultimate conv features from student and teacher
#   - Project to a shared embedding space with small MLP heads
#   - Use InfoNCE with in-batch negatives (z_s vs z_t of all samples)
#   - Optimize CE + W_CRD * CRD

LR = 0.1
W_CRD = 1.0
TAU = 0.07   # temperature for contrastive logits
EMB_DIM = 128

# Student wrapper to expose features
class StudentWithFeatCRD(SmallNet):
    def forward(self, x):
        f = self.features(x)         # N, 128, H, W
        logits = self.classifier(f)  # N, C
        return logits, f

student = StudentWithFeatCRD(NUM_CLASSES).to(DEVICE)

# Teacher feature hook (layer4)
if not hasattr(teacher, 'layer4'):
    raise RuntimeError("Teacher doesn't have layer4; choose a ResNet-like teacher.")

hook_t = FeatureHook(teacher.layer4)

# Projection heads (trainable). Teacher backbone is frozen, but this head is trainable.
proj_s = nn.Sequential(
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(128, EMB_DIM),
    nn.ReLU(inplace=True),
    nn.Linear(EMB_DIM, EMB_DIM),
).to(DEVICE)

proj_t = nn.Sequential(
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(512, EMB_DIM),
    nn.ReLU(inplace=True),
    nn.Linear(EMB_DIM, EMB_DIM),
).to(DEVICE)

# Optimizer includes student + projection heads
optimizer = optim.SGD(list(student.parameters()) + list(proj_s.parameters()) + list(proj_t.parameters()),
                      lr=LR, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=KD_EPOCHS)
scaler = GradScaler(enabled=torch.cuda.is_available())

# Contrastive loss (InfoNCE) using in-batch negatives
def crd_loss(emb_s: torch.Tensor, emb_t: torch.Tensor, tau: float = 0.07) -> torch.Tensor:
    # Normalize
    zs = F.normalize(emb_s, dim=1)
    zt = F.normalize(emb_t, dim=1)
    # Similarity logits: N x N
    logits = (zs @ zt.t()) / tau
    targets = torch.arange(logits.size(0), device=logits.device)
    return F.cross_entropy(logits, targets)

best_val, best_state = 0.0, None
start = time.time()
for epoch in range(KD_EPOCHS):
    student.train(); proj_s.train(); proj_t.train()
    total_correct, total_samples = 0, 0
    for x, y in train_loader:
        x, y = x.to(DEVICE, non_blocking=True), y.to(DEVICE, non_blocking=True)
        with torch.no_grad():
            _ = teacher(x)  # populate hook
            f_t = hook_t.feat  # N, 512, Ht, Wt
        optimizer.zero_grad(set_to_none=True)
        with autocast(enabled=torch.cuda.is_available()):
            logits_s, f_s = student(x)  # f_s: N, 128, Hs, Ws
            # Project to embeddings (vector)
            z_s = proj_s(f_s)
            z_t = proj_t(f_t.detach())
            loss_ce = F.cross_entropy(logits_s, y)
            loss_con = crd_loss(z_s, z_t, tau=TAU)
            loss = loss_ce + W_CRD * loss_con
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_correct += (logits_s.argmax(1) == y).sum().item()
        total_samples += y.size(0)
    scheduler.step()
    # Eval on validation using classifier head only
    student.eval(); proj_s.eval(); proj_t.eval()
    train_acc_epoch = total_correct / max(1, total_samples)
    val_acc_epoch, _ = evaluate(student, val_loader, DEVICE)
    if val_acc_epoch > best_val:
        best_val = val_acc_epoch
        best_state = {
            'student': copy.deepcopy(student.state_dict()),
            'proj_s': copy.deepcopy(proj_s.state_dict()),
            'proj_t': copy.deepcopy(proj_t.state_dict()),
        }
        tag = " (best)"
    else:
        tag = ""
    print(f"Epoch {epoch+1}/{KD_EPOCHS} - train_acc: {train_acc_epoch:.4f} - val_acc: {val_acc_epoch:.4f}{tag}")

elapsed = time.time() - start
if best_state is not None:
    student.load_state_dict(best_state['student'])
    proj_s.load_state_dict(best_state['proj_s'])
    proj_t.load_state_dict(best_state['proj_t'])

# Save student checkpoint for method 8
os.makedirs(CKPT_DIR, exist_ok=True)
ckpt_path = os.path.join(CKPT_DIR, "kd_student_v8.pth")
torch.save(student.state_dict(), ckpt_path)
print(f"Saved student checkpoint (v8) to {ckpt_path}")

train_acc_final, _ = evaluate(student, train_loader, DEVICE)
val_acc_final, _ = evaluate(student, val_loader, DEVICE)
test_acc_final, _ = evaluate(student, test_loader, DEVICE)
print({
    "method": "Contrastive Representation Distillation (CRD)",
    "train_time_sec": round(elapsed, 2),
    "train_acc": round(train_acc_final, 4),
    "val_acc": round(val_acc_final, 4),
    "test_acc": round(test_acc_final, 4)
})

Epoch 1/10 - train_acc: 0.4197 - val_acc: 0.4888 (best)
Epoch 2/10 - train_acc: 0.6045 - val_acc: 0.5758 (best)
Epoch 3/10 - train_acc: 0.6753 - val_acc: 0.6390 (best)
Epoch 4/10 - train_acc: 0.7229 - val_acc: 0.6842 (best)
Epoch 5/10 - train_acc: 0.7567 - val_acc: 0.7490 (best)
Epoch 6/10 - train_acc: 0.7822 - val_acc: 0.7540 (best)
Epoch 7/10 - train_acc: 0.8062 - val_acc: 0.8048 (best)
Epoch 8/10 - train_acc: 0.8258 - val_acc: 0.8236 (best)
Epoch 9/10 - train_acc: 0.8440 - val_acc: 0.8430 (best)
Epoch 10/10 - train_acc: 0.8518 - val_acc: 0.8522 (best)
Saved student checkpoint (v8) to ./checkpoints/kd_student_v8.pth
{'method': 'Contrastive Representation Distillation (CRD)', 'train_time_sec': 151.58, 'train_acc': 0.8645, 'val_acc': 0.8522, 'test_acc': 0.843}


## Phương pháp 9: Probabilistic Knowledge Transfer (PKT)

**Ý tưởng:** PKT không khớp trực tiếp giá trị đặc trưng/logit mà khớp “quan hệ xác suất” giữa các mẫu trong cùng batch. Cụ thể, với biểu diễn (embedding) của mỗi mẫu, ta xây dựng phân phối xác suất tương tự (similarity distribution) từ mỗi mẫu đến các mẫu còn lại và buộc student mô phỏng phân phối của teacher. Cách này bảo tồn cấu trúc cục bộ của không gian biểu diễn và ít nhạy cảm với biến đổi biên độ/chuẩn hoá.

Ký hiệu theo minibatch gồm $N$ mẫu, với embedding đã chuẩn hoá $L_2$:
- $z_t = \mathrm{norm}(e_t) \in \mathbb{R}^{N\times d}$, $z_s = \mathrm{norm}(e_s) \in \mathbb{R}^{N\times d}$.
- Hàm tương tự (similarity) dùng cosine: $\mathrm{sim}(u, v) = u^\top v$.
- Với mỗi mẫu $i$, định nghĩa phân phối xác suất trên các mẫu khác bằng softmax nhiệt độ $\sigma$ (khác với nhiệt độ KD trên logits):
  $$
  p_t(j\,|\,i) \,=\, \frac{\exp\big(\mathrm{sim}(z_t[i], z_t[j]) / \sigma\big)}{\sum\limits_{k\ne i} \exp\big(\mathrm{sim}(z_t[i], z_t[k]) / \sigma\big)},\quad j\ne i,
  $$
  và tương tự $p_s(j\,|\,i)$ từ $z_s$.

Mất mát PKT là KL trung bình qua tất cả điều kiện $i$ (bỏ qua phần tử tự so sánh $j=i$):
$$
\mathcal{L}_{\mathrm{PKT}} \,=\, \frac{1}{N}\sum_{i=1}^N \mathrm{KL}\big( p_t(\cdot\,|\,i) \,\Vert\, p_s(\cdot\,|\,i) \big).
$$

Tổng mất mát huấn luyện:
$$
\mathcal{L} = \mathrm{CE}(z_s^{\text{logit}}, y) + \lambda_{\mathrm{PKT}}\, \mathcal{L}_{\mathrm{PKT}}.
$$
- $\sigma$ điều chỉnh độ sắc của phân phối tương tự; $\lambda_{\mathrm{PKT}}$ cân bằng với CE.
- Khác RKD (dựa khoảng cách/góc), PKT dùng xác suất tương tự có chuẩn hoá theo từng gốc $i$, nên bền vững hơn với thay đổi tỉ lệ, và nhấn mạnh quan hệ cục bộ trong batch.

In [12]:
# %% Train: Probabilistic Knowledge Transfer (PKT)
LR = 0.1
SIGMA_PKT = 0.1      # temperature for similarity softmax
LAMBDA_PKT = 2.0     # weight for PKT loss

# We need embeddings for teacher and student. We'll reuse the penultimate conv feature then GAP + linear head to produce logits.
# Wrap SmallNet to expose an embedding vector per sample (after GAP, before final FC) and logits.
class SmallNetWithEmbed(SmallNet):
    def forward(self, x):
        f = self.features(x)              # N,C,H,W (e.g., 128x4x4)
        gap = F.adaptive_avg_pool2d(f, 1).flatten(1)  # N,C
        logits = self.classifier(f)       # original classifier may use conv+pool inside
        return logits, gap                # logits (for CE), gap as embedding e_s

# Teacher embed: build a small head to map teacher features to a vector; fallback to logits space if needed.
# Try to tap a common block (layer4) if exists; otherwise use logits as embedding.

def build_teacher_embedder(teacher_model: nn.Module, out_dim: int = None):
    # If teacher has layer4, hook it; else return a function that uses logits as embedding
    layer = getattr(teacher_model, 'layer4', None)
    if layer is None:
        # Fallback: use logits as embedding
        def teacher_embed(x):
            logits = teacher_model(x)
            e = logits if out_dim is None else F.linear(logits, torch.eye(logits.shape[1], device=logits.device))
            return logits, e
        return teacher_embed

    class Hook:
        def __init__(self, module):
            self.feat = None
            module.register_forward_hook(self._hook)
        def _hook(self, m, inp, out):
            self.feat = out
    h = Hook(layer)

    # simple projector after hook: GAP only to get vector
    def teacher_embed(x):
        logits = teacher_model(x)  # populates hook
        if h.feat is None:
            raise RuntimeError('Teacher feature hook not captured')
        e = F.adaptive_avg_pool2d(h.feat, 1).flatten(1)
        return logits, e

    return teacher_embed


def pkt_loss(e_s: torch.Tensor, e_t: torch.Tensor, sigma: float = 0.1, eps: float = 1e-12) -> torch.Tensor:
    # L2-normalize embeddings
    z_s = F.normalize(e_s, p=2, dim=1)  # N,d
    z_t = F.normalize(e_t, p=2, dim=1)  # N,d

    # Cosine similarity matrices (exclude self with mask later)
    S_s = z_s @ z_s.t()   # N,N
    S_t = z_t @ z_t.t()   # N,N

    # Mask out diagonal (self-similarity) by setting to -inf before softmax
    N = S_s.size(0)
    mask = torch.eye(N, device=S_s.device).bool()
    S_s = S_s.masked_fill(mask, float('-inf'))
    S_t = S_t.masked_fill(mask, float('-inf'))

    # Row-wise softmax with temperature sigma
    P_s = F.softmax(S_s / sigma, dim=1).clamp_min(eps)  # N,N
    P_t = F.softmax(S_t / sigma, dim=1).clamp_min(eps)  # N,N

    # Row-wise KL: average over rows
    loss = F.kl_div(P_s.log(), P_t, reduction='batchmean')
    return loss


student = SmallNetWithEmbed(NUM_CLASSES).to(DEVICE)
teacher_embed = build_teacher_embedder(teacher)
optimizer = optim.SGD(student.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=KD_EPOCHS)
scaler = GradScaler(enabled=torch.cuda.is_available())

best_val, best_state = 0.0, None
start = time.time()
for epoch in range(KD_EPOCHS):
    student.train()
    total_correct, total_samples = 0, 0
    for x, y in train_loader:
        x, y = x.to(DEVICE, non_blocking=True), y.to(DEVICE, non_blocking=True)
        with torch.no_grad():
            logits_t, e_t = teacher_embed(x)

        optimizer.zero_grad(set_to_none=True)
        with autocast(enabled=torch.cuda.is_available()):
            logits_s, e_s = student(x)
            loss_ce  = F.cross_entropy(logits_s, y)
            loss_pkt = pkt_loss(e_s, e_t, sigma=SIGMA_PKT)
            loss = loss_ce + LAMBDA_PKT * loss_pkt
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_correct += (logits_s.argmax(1) == y).sum().item()
        total_samples += y.size(0)

    scheduler.step()
    train_acc_epoch = total_correct / max(1, total_samples)
    val_acc_epoch, _ = evaluate(student, val_loader, DEVICE)
    if val_acc_epoch > best_val:
        best_val = val_acc_epoch
        best_state = copy.deepcopy(student.state_dict())
        tag = " (best)"
    else:
        tag = ""
    print(f"Epoch {epoch+1}/{KD_EPOCHS} - train_acc: {train_acc_epoch:.4f} - val_acc: {val_acc_epoch:.4f}{tag}")

elapsed = time.time() - start
if best_state is not None:
    student.load_state_dict(best_state)

# Save student checkpoint for method 9 (PKT)
os.makedirs(CKPT_DIR, exist_ok=True)
ckpt_path = os.path.join(CKPT_DIR, "kd_student_v9.pth")
torch.save(student.state_dict(), ckpt_path)
print(f"Saved student checkpoint (v9 PKT) to {ckpt_path}")

train_acc_final, _ = evaluate(student, train_loader, DEVICE)
val_acc_final, _ = evaluate(student, val_loader, DEVICE)
test_acc_final, _ = evaluate(student, test_loader, DEVICE)
print({
    "method": "Probabilistic Knowledge Transfer (PKT)",
    "train_time_sec": round(elapsed, 2),
    "train_acc": round(train_acc_final, 4),
    "val_acc": round(val_acc_final, 4),
    "test_acc": round(test_acc_final, 4)
})

Epoch 1/10 - train_acc: 0.4301 - val_acc: 0.5292 (best)
Epoch 2/10 - train_acc: 0.6020 - val_acc: 0.5908 (best)
Epoch 3/10 - train_acc: 0.6672 - val_acc: 0.6258 (best)
Epoch 4/10 - train_acc: 0.7121 - val_acc: 0.7066 (best)
Epoch 5/10 - train_acc: 0.7436 - val_acc: 0.6634
Epoch 6/10 - train_acc: 0.7680 - val_acc: 0.6818
Epoch 7/10 - train_acc: 0.7897 - val_acc: 0.7814 (best)
Epoch 8/10 - train_acc: 0.8155 - val_acc: 0.8016 (best)
Epoch 9/10 - train_acc: 0.8326 - val_acc: 0.8162 (best)
Epoch 10/10 - train_acc: 0.8486 - val_acc: 0.8538 (best)
Saved student checkpoint (v9 PKT) to ./checkpoints/kd_student_v9.pth
{'method': 'Probabilistic Knowledge Transfer (PKT)', 'train_time_sec': 146.62, 'train_acc': 0.8611, 'val_acc': 0.8538, 'test_acc': 0.8376}


## Đánh giá toàn diện

In [13]:
# %% Comprehensive Evaluation: Load checkpoints and compare models
import os
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from collections import OrderedDict

# Expect checkpoints in CKPT_DIR
ckpt_dir = CKPT_DIR if 'CKPT_DIR' in globals() else './checkpoints'
os.makedirs(ckpt_dir, exist_ok=True)

# Build teacher & student architectures (same as training)
teacher_eval = build_teacher(NUM_CLASSES).to(DEVICE)
student_eval = SmallNet(NUM_CLASSES).to(DEVICE)

# Utility: evaluate metrics over a dataloader
@torch.no_grad()
def collect_metrics(model, loader, device):
    model.eval()
    total, correct, loss_sum = 0, 0, 0.0
    logits_list, labels_list = [], []
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        out = model(x)
        logits = out[0] if isinstance(out, (tuple, list)) else out
        loss_sum += criterion(logits, y).item()
        correct += (logits.argmax(1) == y).sum().item()
        total += y.size(0)
        logits_list.append(logits.detach().cpu())
        labels_list.append(y.detach().cpu())
    logits_all = torch.cat(logits_list, dim=0)
    labels_all = torch.cat(labels_list, dim=0)
    acc = correct / max(1, total)
    avg_loss = loss_sum / max(1, total)
    return {
        'acc': acc,
        'loss': avg_loss,
        'logits': logits_all,
        'labels': labels_all,
    }

# Expected checkpoints
models = OrderedDict([
    ("teacher", os.path.join(ckpt_dir, "kd_teacher.pth")),
    ("v1_vanilla", os.path.join(ckpt_dir, "kd_student_v1.pth")),
    ("v2_cee", os.path.join(ckpt_dir, "kd_student_v2.pth")),
    ("v3_feat", os.path.join(ckpt_dir, "kd_student_v3.pth")),
    ("v4_at", os.path.join(ckpt_dir, "kd_student_v4.pth")),
    ("v5_logit", os.path.join(ckpt_dir, "kd_student_v5.pth")),
    ("v6_fkd", os.path.join(ckpt_dir, "kd_student_v6.pth")),
    ("v7_rkd", os.path.join(ckpt_dir, "kd_student_v7.pth")),
    ("v8_crd", os.path.join(ckpt_dir, "kd_student_v8.pth")),
    ("v9_pkt", os.path.join(ckpt_dir, "kd_student_v9.pth")),
])

# Load teacher
loaded = {}
if os.path.isfile(models['teacher']):
    teacher_eval.load_state_dict(torch.load(models['teacher'], map_location=DEVICE))
    loaded['teacher'] = teacher_eval
else:
    print(f"[WARN] Teacher checkpoint not found: {models['teacher']}")

# Load students into dict (same arch SmallNet)
for name, path in models.items():
    if name == 'teacher':
        continue
    if os.path.isfile(path):
        m = SmallNet(NUM_CLASSES).to(DEVICE)
        m.load_state_dict(torch.load(path, map_location=DEVICE))
        loaded[name] = m
    else:
        print(f"[WARN] Student checkpoint not found: {path}")

# Metrics to compute
# - acc_test: accuracy on test set
# - loss_test: CE loss on test set
# - ece: Expected Calibration Error (10-bin)
# - agree_t: agreement rate between student and teacher predictions
# - kl_to_t: KL(student || teacher) on test logits (softmax distributions)
# - cos_logits: cosine similarity between student and teacher logits
# - entropy: average predictive entropy (uncertainty)


def expected_calibration_error(probs: torch.Tensor, labels: torch.Tensor, n_bins: int = 10) -> float:
    # probs: N,C ; labels: N
    confidences, predictions = probs.max(dim=1)
    accuracies = predictions.eq(labels)
    bins = torch.linspace(0, 1, steps=n_bins + 1)
    ece = torch.zeros(1)
    for i in range(n_bins):
        in_bin = (confidences > bins[i]) & (confidences <= bins[i + 1]) if i < n_bins - 1 else (confidences > bins[i]) & (confidences <= bins[i + 1])
        prop = in_bin.float().mean()
        if prop.item() > 0:
            acc_bin = accuracies[in_bin].float().mean()
            conf_bin = confidences[in_bin].float().mean()
            ece += torch.abs(conf_bin - acc_bin) * prop
    return ece.item()


def kl_divergence(p_logits: torch.Tensor, q_logits: torch.Tensor, T: float = 1.0) -> float:
    # KL(P||Q) with temperature T
    p = F.log_softmax(p_logits / T, dim=1)
    q = F.softmax(q_logits / T, dim=1)
    return F.kl_div(p, q, reduction='batchmean').item()


def cosine_similarity_logits(a: torch.Tensor, b: torch.Tensor) -> float:
    a_flat = a.flatten(1)
    b_flat = b.flatten(1)
    a_n = F.normalize(a_flat, p=2, dim=1)
    b_n = F.normalize(b_flat, p=2, dim=1)
    return (a_n * b_n).sum(dim=1).mean().item()


def avg_entropy_from_logits(logits: torch.Tensor) -> float:
    probs = F.softmax(logits, dim=1)
    entropy = -(probs * (probs.clamp_min(1e-12).log())).sum(dim=1)
    return entropy.mean().item()

# Evaluate all models
results = []
metrics = [
    'acc', 'loss', 'ece', 'agree_t', 'kl_to_t', 'cos_logits', 'entropy'
]

# First collect teacher outputs
teacher_out = None
if 'teacher' in loaded:
    teacher_out = collect_metrics(loaded['teacher'], test_loader, DEVICE)

for name, model in loaded.items():
    out = collect_metrics(model, test_loader, DEVICE)
    probs = F.softmax(out['logits'], dim=1)
    ece = expected_calibration_error(probs, out['labels'])

    # Comparisons to teacher (only if teacher available and current is not teacher)
    if teacher_out is not None and name != 'teacher':
        agree = (out['logits'].argmax(1) == teacher_out['logits'].argmax(1)).float().mean().item()
        kl = kl_divergence(out['logits'], teacher_out['logits'])
        cos = cosine_similarity_logits(out['logits'], teacher_out['logits'])
    else:
        agree, kl, cos = np.nan, np.nan, np.nan

    ent = avg_entropy_from_logits(out['logits'])

    results.append({
        'model': name,
        'acc': round(out['acc'], 4),
        'loss': round(out['loss'], 4),
        'ece': round(ece, 4),
        'agree_t': round(agree, 4) if not np.isnan(agree) else np.nan,
        'kl_to_t': round(kl, 4) if not np.isnan(kl) else np.nan,
        'cos_logits': round(cos, 4) if not np.isnan(cos) else np.nan,
        'entropy': round(ent, 4),
    })

# Build results DataFrame (sorted by model order above)
df = pd.DataFrame(results)
# Optional: reorder rows to keep teacher first
order = [k for k in models.keys() if k in df['model'].values]
df['order_idx'] = df['model'].apply(lambda m: order.index(m) if m in order else 999)
df = df.sort_values('order_idx').drop(columns=['order_idx']).reset_index(drop=True)

print("\n===== Bảng kết quả (metrics) =====")
print(df.to_string(index=False))

# Ranking table: for each metric, compute rank (best rank = 1).
# For metrics where lower is better (loss, ece, kl, entropy), rank ascending. For higher-better (acc, agree_t, cos), rank descending.
rank_prefs = {
    'acc': 'desc',
    'loss': 'asc',
    'ece': 'asc',
    'agree_t': 'desc',
    'kl_to_t': 'asc',
    'cos_logits': 'desc',
    'entropy': 'asc',
}

rank_df = df.copy()
for col, pref in rank_prefs.items():
    series = rank_df[col]
    if series.isna().all():
        rank_df[col + '_rank'] = np.nan
        continue
    # For NaN values (e.g., teacher comparisons), assign worst rank
    fill_val = series.max() + 1 if pref == 'asc' else series.min() - 1
    series_filled = series.fillna(fill_val)
    ascending = (pref == 'asc')
    rank_df[col + '_rank'] = series_filled.rank(method='min', ascending=ascending)

# Keep only rank columns and model name
rank_cols = ['model'] + [c + '_rank' for c in rank_prefs.keys()]
rank_table = rank_df[rank_cols]
rank_table['avg_rank'] = rank_table[[c for c in rank_cols if c != 'model']].mean(axis=1)

print("\n===== Bảng xếp hạng (rank) =====")
print(rank_table.to_string(index=False))


===== Bảng kết quả (metrics) =====
     model    acc   loss    ece  agree_t  kl_to_t  cos_logits  entropy
   teacher 0.9444 0.2141 0.0318      NaN      NaN         NaN   0.0697
v1_vanilla 0.8413 0.5419 0.0651   0.8533   0.4258      0.8436   0.2650
    v2_cee 0.8363 0.4892 0.0148   0.8482   0.3921      0.7720   0.4351
   v3_feat 0.8349 0.4764 0.0090   0.8458   0.3872      0.7722   0.4488
     v4_at 0.8073 0.5509 0.0097   0.8195   0.4582      0.7437   0.5392
  v5_logit 0.8385 0.5159 0.0572   0.8520   0.4016      0.8748   0.3075
    v6_fkd 0.8275 0.5258 0.0435   0.8398   0.4238      0.8220   0.3720
    v7_rkd 0.8392 0.4861 0.0165   0.8491   0.3986      0.8354   0.4830
    v8_crd 0.8430 0.4591 0.0063   0.8541   0.3654      0.7805   0.4537
    v9_pkt 0.8376 0.4699 0.0071   0.8449   0.3875      0.8224   0.4987

===== Bảng xếp hạng (rank) =====
     model  acc_rank  loss_rank  ece_rank  agree_t_rank  kl_to_t_rank  cos_logits_rank  entropy_rank  avg_rank
   teacher       1.0        1.0       