In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ----------------------------
# Basic 3x3 Convolution
# ----------------------------
def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

# ----------------------------
# Basic Residual Block
# ----------------------------
class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, in_planes, planes, stride=1, downsample=None, is_last=False):
        super().__init__()
        self.is_last = is_last
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        preact = out
        out = F.relu(out)

        if self.is_last:
            return out, preact
        else:
            return out

# ----------------------------
# ResNet CIFAR Modular
# ----------------------------
class ResNet(nn.Module):
    def __init__(self, depth, num_filters=[16,16,32,64], block_name='basicblock', num_classes=100):
        super().__init__()
        assert block_name.lower() == 'basicblock', "Currently only BasicBlock supported"
        assert (depth - 2) % 6 == 0, "Depth must be 6n+2 for BasicBlock"
        n = (depth - 2) // 6

        self.in_planes = num_filters[0]
        self.conv1 = conv3x3(3, num_filters[0])
        self.bn1 = nn.BatchNorm2d(num_filters[0])
        self.relu = nn.ReLU(inplace=True)

        # Residual layers
        self.layer1 = self._make_layer(BasicBlock, num_filters[1], n)
        self.layer2 = self._make_layer(BasicBlock, num_filters[2], n, stride=2)
        self.layer3 = self._make_layer(BasicBlock, num_filters[3], n, stride=2)

        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(num_filters[3]*BasicBlock.expansion, num_classes)

        # weight init
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.in_planes != planes*block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_planes, planes*block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes*block.expansion)
            )

        layers = [block(self.in_planes, planes, stride, downsample, is_last=(blocks==1))]
        self.in_planes = planes*block.expansion
        for i in range(1, blocks):
            layers.append(block(self.in_planes, planes, is_last=(i==blocks-1)))
        return nn.Sequential(*layers)

    def forward(self, x, is_feat=False, preact=False):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        f0 = x

        x, f1_pre = self.layer1(x)
        f1 = x
        x, f2_pre = self.layer2(x)
        f2 = x
        x, f3_pre = self.layer3(x)
        f3 = x

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        f4 = x
        x = self.fc(x)

        if is_feat:
            return ([f0,f1,f2,f3,f4], x) if not preact else ([f0,f1_pre,f2_pre,f3_pre,f4], x)
        else:
            return x

# ----------------------------
# Architecture builders
# ----------------------------
def resnet20(num_classes=100): return ResNet(20, num_classes=num_classes)
def resnet32(num_classes=100):
    """ResNet32 matching the standard CIFAR ResNet paper structure."""
    # Use num_filters=[32, 32, 64, 128] to match the checkpoint you want to load
    return ResNet(depth=32, num_filters=[32, 64, 128, 256], block_name='basicblock', num_classes=num_classes)

def resnet56(num_classes=100): return ResNet(56, num_classes=num_classes)
def resnet110(num_classes=100): return ResNet(110, num_classes=num_classes)
def resnet32_basic(num_classes=100): return ResNet(32, num_classes=num_classes)

In [2]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# ---------------------------
# Device
# ---------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------------------------
# CIFAR-100 test dataset
# ---------------------------
mean = (0.5071, 0.4867, 0.4408)
std  = (0.2675, 0.2565, 0.2761)

test_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

test_ds = datasets.CIFAR100("./data", train=False, transform=test_tf, download=True)
test_loader = DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=2, pin_memory=True)

# ---------------------------
# Model constructors
# ---------------------------
model_dict = {
    # "ResNet20": resnet20,
    # "ResNet32": resnet32(), # this is the resnet-32x4
    "ResNet56": resnet56(),
    "ResNet110": resnet110()
}

# Path mapping for pretrained weights (adjust paths if needed)
weights_dict = {
    # "ResNet20": "/path/to/resnet20.pth",
    # "ResNet32": "/kaggle/input/resnet32/pytorch/default/1/ckpt_epoch_240.pth",
    "ResNet56":  "/kaggle/input/resnet-56/ckpt_epoch_240.pth",
    "ResNet110": "/kaggle/input/resnet-110/ckpt_epoch_240.pth"
}

# ---------------------------
# Evaluation loop
# ---------------------------
for name, constructor in model_dict.items():
    print(f"Evaluating {name}...")
    model = constructor.to(device)
    
    # Load pretrained weights if available
    weight_path = weights_dict.get(name)
    if weight_path:
        checkpoint = torch.load(weight_path, map_location=device,weights_only=False)
        if 'model' in checkpoint:
            state_dict = checkpoint['model']
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict)
    
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            logits = model(imgs)
            preds = logits.argmax(1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    test_acc = 100 * correct / total
    print(f"🎯 {name} Test Accuracy on CIFAR-100: {test_acc:.2f}%\n")


Evaluating ResNet56...
🎯 ResNet56 Test Accuracy on CIFAR-100: 72.41%

Evaluating ResNet110...
🎯 ResNet110 Test Accuracy on CIFAR-100: 74.31%



# Training the student models BASIC Resnet 20 / 32

import os
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from torchvision.transforms import AutoAugment, AutoAugmentPolicy, RandomErasing
from tqdm.auto import tqdm
#from torchvision.models import resnet18, resnet34



device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_dp = torch.cuda.device_count() > 1
torch.manual_seed(42)
np.random.seed(42)

mean = (0.5071, 0.4867, 0.4408)  # CIFAR-100 mean
std  = (0.2675, 0.2565, 0.2761)  # CIFAR-100 std

batch_size = 64

# Precompute DataLoaders for each resolution
stages = [(r) for r in [(32, 240)]]
dataloader_dict = {}

'''
Copied from the paper as it is.

they are not using any validation sets. Training it on the entire train set!
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=stdv),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=stdv),
    ])
    
    '''

for resolution, _ in stages:
    train_tf = transforms.Compose([
        transforms.RandomCrop(resolution,padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    
    train_set = datasets.CIFAR100('./data', train=True, download=False, transform=train_tf)
    
    train_loader = DataLoader(train_set, batch_size=batch_size,
                              shuffle=True, num_workers=0, pin_memory=True)
    dataloader_dict[resolution] = {
        'train': train_loader
    }

# Test loader (fixed resolution)
test_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

test_ds = datasets.CIFAR100('./data', train=False, transform=test_tf)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)


# learning_rate is divided by 10 

# ---------------------------
# Test loop
# ---------------------------

def test(Test_model):
    
    Test_model.eval()
    correct_val = total_val = 0
    
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            logits = Test_model(imgs)
            preds = logits.argmax(1)
            correct_val += (preds == labels).sum().item()
            total_val += labels.size(0)
    test_acc = 100 * correct_val / total_val

    # print(f"Test Acc = {val_acc:.2f}%")
    return test_acc

# ---------------------------
# Training loop
# ---------------------------

def train(model, model_type):
    
    # Loss + optimizer
    lr = 0.05 # as per the paper
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)

    # Step LR schedule: decay at 150, 180, 210 epochs
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[150, 180, 210], gamma=0.1)

    best_val_acc = 0.0
    num_epochs = 240
    
    for res, epochs in stages:
        
        print(f"\n=== Training at resolution {res}px ===")
        tr_loader = dataloader_dict[res]['train']
        
        for e in range(1, epochs+1):
            
            model.train()
            total_loss = 0
            correct = total = 0
        
            for imgs, labels in tqdm(tr_loader, desc=f"Epoch {e}/{num_epochs}"):
                
                imgs, labels = imgs.to(device), labels.to(device)
                optimizer.zero_grad()
                logits = model(imgs)
                loss = criterion(logits, labels)
                loss.backward()
                optimizer.step()
        
                total_loss += loss.item()
                preds = logits.argmax(1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        
            train_acc = 100 * correct / total
            scheduler.step()

            test_acc = test(model)
            
            print(f"Epoch {e}: Train Acc = {train_acc:.2f}% Test Accuracy = {test_acc:.2f}%") 
            
            # Save best model
            if test_acc > best_val_acc:
                best_val_acc = test_acc
                torch.save(model.state_dict(), f"resnet{model_type}_student.pth")
                print(f"→ Saved best model at epoch {e} with Test Acc = {train_acc:.2f}%")
    print("✅ Training completed!")

# KD Training Vanilla


In [3]:
import os
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from torchvision.transforms import AutoAugment, AutoAugmentPolicy, RandomErasing
from tqdm.auto import tqdm
#from torchvision.models import resnet18, resnet34
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


mean = (0.5071, 0.4867, 0.4408)
std  = (0.2675, 0.2565, 0.2761)
batch_size = 128
num_workers = 0
pin_mem = True

@torch.no_grad()
def evaluate(model, loader, device):
    model.eval()
    correct = total = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        pred = logits.argmax(1)
        correct += (pred == y).sum().item()
        total += y.size(0)
    return 100.0 * correct / total


torch.manual_seed(27)
np.random.seed(27)

torch.cuda.manual_seed_all(27)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True



In [4]:
stages = [(r) for r in [(32, 240)]]
dataloader_dict = {}

for resolution, _ in stages:
    train_tf = transforms.Compose([
        transforms.RandomCrop(resolution, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_set = datasets.CIFAR100('./data', train=True, download=True, transform=train_tf)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin_mem)

    dataloader_dict[resolution] = {'train': train_loader}


test_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])
test_set = datasets.CIFAR100('./data', train=False, download=True, transform=test_tf)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False,
                         num_workers=num_workers, pin_memory=pin_mem)


resolution, _ = stages[0]          # (32, 240)
train_loader = dataloader_dict[resolution]['train']


In [5]:
# learning_rate is divided by 10 
# ---------------------------
# Training loop
# ---------------------------

def kd_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.9): 
    """
    alpha = 0.1 as per the github repo
    Compute KD loss = α * KD + (1-α) * CE
    T = temperature
    α = weight for soft distillation loss
    """
    # Hard-label loss
    ce = F.cross_entropy(student_logits, labels)
    kd = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction="batchmean") * (T * T)
    
    return (1 - alpha) * ce + alpha * kd

In [6]:
def train_via_KD(t_model, s_model, train_loader, test_loader, device,
                 epochs=240, base_lr=0.05, wd=5e-4, milestones=(150,180,210),
                 T=4.0, alpha=0.9, save_path="student_kd.pth"):

    # Freeze teacher
    t_model.to(device).eval()
    for p in t_model.parameters():
        p.requires_grad = False

    s_model = s_model.to(device)

    optimizer = optim.SGD(s_model.parameters(), lr=base_lr, momentum=0.9, weight_decay=wd)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=list(milestones), gamma=0.1)

    best_test = -1.0
    for e in range(1, epochs + 1):
        s_model.train()
        running_loss, correct, total = 0.0, 0, 0

        for imgs, labels in tqdm(train_loader, desc=f"Epoch {e}/{epochs}"):
            imgs, labels = imgs.to(device), labels.to(device)

            with torch.no_grad():
                t_logits = t_model(imgs)

            s_logits = s_model(imgs)
            loss = kd_loss(s_logits, t_logits, labels, T=T, alpha=alpha)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * imgs.size(0)
            preds = s_logits.argmax(1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

        scheduler.step()
        train_loss = running_loss / total
        train_acc  = 100.0 * correct / total
        test_acc   = evaluate(s_model, test_loader, device)

        print(f"Epoch {e:3d}/{epochs} | loss {train_loss:.3f} | train {train_acc:5.2f}% | test {test_acc:5.2f}%")

        if test_acc > best_test:
            best_test = test_acc
            to_save = s_model.module.state_dict() if isinstance(s_model, torch.nn.DataParallel) else s_model.state_dict()
            torch.save(to_save, save_path)
            print(f"  ↳ Saved best @ epoch {e} (test {best_test:.2f}%) → {save_path}")

    print(f"✅ KD training finished. Best Test Acc: {best_test:.2f}%")


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
teacher = resnet56(num_classes = 100)
t_ckpt  = torch.load("/kaggle/input/resnet-56/ckpt_epoch_240.pth", map_location=device, weights_only=False)
teacher.load_state_dict(t_ckpt['model'] if 'model' in t_ckpt else t_ckpt)

student = resnet20(num_classes=100)

train_via_KD(teacher, student, train_loader, test_loader, device, save_path= "56_t-20_s.pth")


teacher = resnet110(num_classes = 100)
t_ckpt  = torch.load("/kaggle/input/resnet-110/ckpt_epoch_240.pth", map_location=device, weights_only=False)
teacher.load_state_dict(t_ckpt['model'] if 'model' in t_ckpt else t_ckpt)

student = resnet20(num_classes=100)

train_via_KD(teacher, student, train_loader, test_loader, device, save_path= "110_t-20_s.pth")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
teacher = resnet110(num_classes = 100)
t_ckpt  = torch.load("/kaggle/input/resnet-110/ckpt_epoch_240.pth", map_location=device, weights_only=False)
teacher.load_state_dict(t_ckpt['model'] if 'model' in t_ckpt else t_ckpt)

student = resnet32_basic(num_classes=100)

train_via_KD(teacher, student, train_loader, test_loader, device, save_path= "110_t-32_s.pth")

# DIFF KD training - latent diffusion for training


In [7]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm


def sinusoidal_embedding(t: torch.Tensor, dim: int) -> torch.Tensor:
    """[B] -> [B, dim]"""
    device = t.device
    half = dim // 2
    freqs = torch.exp(-math.log(10000) * torch.arange(half, device=device).float() / max(half,1))
    args = t.float().unsqueeze(1) * freqs.unsqueeze(0)
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
    if dim % 2: emb = F.pad(emb, (0,1))
    return emb


@torch.no_grad()
def ddim_update(x_t, eps_pred, abar_t, abar_prev):
    # x0 = (x_t - sqrt(1-abar_t)*eps) / sqrt(abar_t)
    x0 = (x_t - torch.sqrt(1.0 - abar_t) * eps_pred) / torch.sqrt(abar_t)
    x_prev = torch.sqrt(abar_prev) * x0 + torch.sqrt(1.0 - abar_prev) * eps_pred
    return x_prev

class DiffSchedule:
    def __init__(self, T=1000, device='cuda'):
        # linear beta works well for features/logits here
        betas = torch.linspace(1e-4, 2e-2, T, device=device)
        alphas = 1.0 - betas
        self.abar = torch.cumprod(alphas, dim=0)              # [T]
        self.T = T
        self.device = device

    def abar_at(self, t):  # t: Long[B] or scalar int
        return self.abar[t]

# ---------- Tiny denoisers ----------
class Bottleneck2D(nn.Module):
    def __init__(self, c, hidden=None):
        super().__init__()
        h = hidden or max(32, c // 4)
        self.net = nn.Sequential(
            nn.Conv2d(c, h, 1, bias=False), nn.BatchNorm2d(h), nn.ReLU(inplace=True),
            nn.Conv2d(h, h, 3, padding=1, bias=False), nn.BatchNorm2d(h), nn.ReLU(inplace=True),
            nn.Conv2d(h, c, 1, bias=False), nn.BatchNorm2d(c)
        )
        self.act = nn.ReLU(inplace=True)
    def forward(self, x):  # residual
        return self.act(self.net(x) + x)

class DiffusionDenoiser2D(nn.Module):
    """Φ_theta for feature maps [B,C,H,W]. Two bottlenecks + time conditioning."""
    def __init__(self, c, t_dim=128):
        super().__init__()
        self.t_mlp = nn.Sequential(nn.Linear(t_dim, 2*c), nn.SiLU(), nn.Linear(2*c, 2*c))
        self.bn = nn.BatchNorm2d(c)
        self.block1 = Bottleneck2D(c)
        self.block2 = Bottleneck2D(c)
        self.proj = nn.Conv2d(c, c, 1)
        self.t_dim = t_dim
        nn.init.zeros_(self.proj.weight)

    def forward(self, zt, t):
        # time -> FiLM
        temb = self.t_mlp(sinusoidal_embedding(t, self.t_dim))  # [B, 2C]
        B, C, H, W = zt.shape
        gamma, beta = temb[:, :C].view(B,C,1,1), temb[:, C:].view(B,C,1,1)
        h = self.bn(zt) * (1 + gamma) + beta
        h = self.block1(h)
        h = self.block2(h)
        eps = self.proj(h)
        return eps

class DiffusionDenoiser1D(nn.Module):
    """MLP denoiser for logits [B,num_classes]."""
    def __init__(self, d, t_dim=64, hidden=512):
        super().__init__()
        self.t_mlp = nn.Sequential(nn.Linear(t_dim, hidden), nn.SiLU(), nn.Linear(hidden, hidden))
        self.fc1 = nn.Linear(d, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.out = nn.Linear(hidden, d)
        nn.init.zeros_(self.out.weight)

        self.t_dim = t_dim

    def forward(self, zt, t):
        temb = self.t_mlp(sinusoidal_embedding(t, self.t_dim))  # [B,H]
        h = self.fc1(zt) + temb
        h = F.silu(h)
        h = F.silu(self.fc2(h))
        eps = self.out(h)
        return eps

# ---------- Noise adapters (γ) ----------
class NoiseAdapter2D(nn.Module):
    def __init__(self, c):
        super().__init__()
        self.conv = nn.Conv2d(c, max(8, c//8), 1)
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(max(8, c//8), 1)
        )
    def forward(self, z):    # [B,C,H,W] -> [B,1,1,1] in (0,1)
        g = torch.sigmoid(self.head(F.silu(self.conv(z))))
        return g.view(-1,1,1,1)

class NoiseAdapter1D(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(d, max(8, d//8)), nn.SiLU(), nn.Linear(max(8, d//8), 1))
    def forward(self, z):    # [B,D] -> [B,1] in (0,1)
        return torch.sigmoid(self.net(z))

# ---------- DiffKD trainer ----------
class DiffKDTrainer(nn.Module):
    def __init__(self, teacher, student, num_classes=100, device='cuda',
                 T=1000, t_start=500, nfe=5, lambda_diff=1.0, lambda_kd=1.0, lambda_kd_logits=1.0):
        super().__init__()
        self.teacher = teacher.eval()                # frozen
        for p in self.teacher.parameters(): p.requires_grad = False

        self.student = student

        self.device = device
        self.sch = DiffSchedule(T=T, device=device)
        self.t_start = t_start
        self.nfe = nfe

        # Feature channels (before avgpool) for your ResNet: last stage is 64 on CIFAR
        self.feat_c = 64

        # Denoisers
        self.phi_feat = DiffusionDenoiser2D(self.feat_c).to(device)
        self.phi_logit = DiffusionDenoiser1D(num_classes).to(device)

        # Noise adapters (γ)
        self.adapt_feat = NoiseAdapter2D(self.feat_c).to(device)
        self.adapt_logit = NoiseAdapter1D(num_classes).to(device)

        # in case teacher/student channels mismatch, add a 1×1 proj for student feature
        self.stu_proj = nn.Conv2d(self.feat_c, self.feat_c, 1).to(device)

        self.lambda_diff = lambda_diff
        self.lambda_kd = lambda_kd
        self.lambda_kd_logits = lambda_kd_logits

    def get_feats_logits(self, model, x):
        feats, logits = model(x, is_feat=True)   # feats: [f0,f1,f2,f3,f4]; we need f3
        f3 = feats[3]                            # [B, C=64, H=8, W=8] on CIFAR
        return f3, logits

    def train_step(self, batch, optimizer):
        x, y = batch
        x = x.to(self.device); y = y.to(self.device)

        # ----- teacher inference (no grad) -----
        with torch.no_grad():
            t_feat, t_logit = self.get_feats_logits(self.teacher, x)
            # losses use these as "clean" targets
            t_feat_detached = t_feat.detach()
            t_logit_detached = t_logit.detach()

        # ----- student forward -----
        s_feat, s_logit = self.get_feats_logits(self.student, x)
        s_feat = self.stu_proj(s_feat)

        # ----- (A) Train denoisers: noise-pred on teacher outputs -----
        B = x.size(0)
        # sample random diffusion times for noise-pred training
        t_rand = torch.randint(1, self.sch.T, (B,), device=self.device)
        abar_t = self.sch.abar_at(t_rand).view(B,1,1,1)
        eps_feat = torch.randn_like(t_feat_detached)
        zt_feat = torch.sqrt(abar_t) * t_feat_detached + torch.sqrt(1.0 - abar_t) * eps_feat
        eps_hat_feat = self.phi_feat(zt_feat, t_rand)
        loss_diff_feat = F.mse_loss(eps_hat_feat, eps_feat)

        # logits branch
        abar_tl = self.sch.abar_at(t_rand).view(B,1)
        eps_logit = torch.randn_like(t_logit_detached)
        zt_logit = torch.sqrt(abar_tl) * t_logit_detached + torch.sqrt(1.0 - abar_tl) * eps_logit
        eps_hat_logit = self.phi_logit(zt_logit, t_rand)
        loss_diff_logit = F.mse_loss(eps_hat_logit, eps_logit)

        loss_diff = loss_diff_feat + loss_diff_logit

        # ----- (B) Denoise student -> KD against teacher -----
        # start from t_start and do nfe jumps down to 0
        with torch.no_grad():
            # precompute the time grid
            steps = torch.linspace(self.t_start, 0, self.nfe+1, device=self.device).long()
        # Feature branch
        gamma_feat = self.adapt_feat(s_feat)                  # [B,1,1,1]
        z = gamma_feat * s_feat + (1.0 - gamma_feat) * torch.randn_like(s_feat)
        for i in range(self.nfe):
            t_cur = steps[i].expand(B)
            t_prev = steps[i+1].expand(B)
            abar_cur = self.sch.abar_at(t_cur).view(B,1,1,1)
            abar_prev = self.sch.abar_at(t_prev).view(B,1,1,1).clamp(min=1e-6)
            eps = self.phi_feat(z, t_cur)
            z = ddim_update(z, eps, abar_cur, abar_prev)
        s_feat_denoised = z
        loss_kd_feat = F.mse_loss(s_feat_denoised, t_feat_detached)

        # Logits branch
        gamma_log = self.adapt_logit(s_logit)                 # [B,1]
        z = gamma_log * s_logit + (1.0 - gamma_log) * torch.randn_like(s_logit)
        for i in range(self.nfe):
            t_cur = steps[i].expand(B)
            t_prev = steps[i+1].expand(B)
            abar_cur = self.sch.abar_at(t_cur).view(B,1)
            abar_prev = self.sch.abar_at(t_prev).view(B,1).clamp(min=1e-6)
            eps = self.phi_logit(z, t_cur)
            # 1D version of update
            x0 = (z - torch.sqrt(1.0 - abar_cur) * eps) / torch.sqrt(abar_cur)
            z  = torch.sqrt(abar_prev) * x0 + torch.sqrt(1.0 - abar_prev) * eps
        s_logit_denoised = z
        # logits KD with KL (temperature=1 per paper baseline on logits)
        loss_kd_log = F.kl_div(
            F.log_softmax(s_logit_denoised, dim=1),
            F.softmax(t_logit_detached, dim=1),
            reduction="batchmean"
        )

        # ----- (C) Task loss (CE) -----
        loss_task = F.cross_entropy(s_logit, y)

        # ----- total -----
        loss = loss_task + self.lambda_diff * loss_diff + self.lambda_kd * loss_kd_feat + self.lambda_kd_logits * loss_kd_log

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            pred = s_logit.argmax(1)
            acc = (pred == y).float().mean().item() * 100.0

        return {
            "loss_total": float(loss),
            "loss_task": float(loss_task),
            "loss_diff": float(loss_diff),
            "loss_kd_feat": float(loss_kd_feat),
            "loss_kd_log": float(loss_kd_log),
            "acc": acc,
        }

# -------------------------
# Training loop for DiffKD
# -------------------------
def train_via_DiffKD(teacher, student, train_loader, test_loader, device,
                     epochs=240, base_lr=0.05, wd=5e-4,
                     t_start=500, nfe=5, save_path="student_diffkd.pth"):
    torch.cuda.empty_cache()
    trainer = DiffKDTrainer(
        teacher.to(device),
        student.to(device),
        num_classes=100,
        device=device,
        T=1000,
        t_start=t_start,
        nfe=nfe,
        lambda_diff=1.0,         # λ1 (diffusion noise-pred losses)
        lambda_kd=1.0,           # λ3 on feature KD (MSE)
        lambda_kd_logits=1.0     # logits KD (KL)
    )

    # optimize student + denoisers + adapters + proj (teacher frozen)
    params = list(trainer.student.parameters()) + \
             list(trainer.phi_feat.parameters()) + list(trainer.phi_logit.parameters()) + \
             list(trainer.adapt_feat.parameters()) + list(trainer.adapt_logit.parameters()) + \
             list(trainer.stu_proj.parameters())

    optimizer = torch.optim.SGD(params, lr=base_lr, momentum=0.9, weight_decay=wd)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[150,180,210], gamma=0.1)

    @torch.no_grad()
    def evaluate(model, loader):
        model.eval()
        correct = total = 0
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            pred = logits.argmax(1)
            correct += (pred == y).sum().item()
            total += y.size(0)
        return 100.0 * correct / total

    best = -1.0
    for e in range(1, epochs+1):
        trainer.train()
        running = {"loss_total":0,"acc":0}
        n = 0
        for batch in tqdm(train_loader, desc=f"[DiffKD] Epoch {e}/{epochs}"):
            stats = trainer.train_step(batch, optimizer)
            bs = batch[0].size(0)
            n += bs
            running["loss_total"] += stats["loss_total"] * bs
            running["acc"] += stats["acc"] * bs

        scheduler.step()
        tr_loss = running["loss_total"]/n
        tr_acc  = running["acc"]/n
        te_acc  = evaluate(trainer.student, test_loader)

        print(f"Epoch {e:3d}/{epochs} | loss {tr_loss:.3f} | train {tr_acc:5.2f}% | test {te_acc:5.2f}%")

        if te_acc > best:
            best = te_acc
            state = trainer.student.module.state_dict() if isinstance(trainer.student, nn.DataParallel) else trainer.student.state_dict()
            torch.save(state, save_path)
            print(f"  ↳ Saved best @ epoch {e} (test {best:.2f}%) → {save_path}")

    print(f"✅ DiffKD finished. Best Test Acc: {best:.2f}%")
    return best


In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 1) Teacher
teacher = resnet56(num_classes=100)
t_ckpt  = torch.load("/kaggle/input/resnet-56/ckpt_epoch_240.pth", map_location=device, weights_only=False)
teacher.load_state_dict(t_ckpt['model'] if 'model' in t_ckpt else t_ckpt)

# 2) Student
student = resnet20(num_classes=100)

# 3) Train with DiffKD (feature+logits, NFEs=5, t_start=500)
train_via_DiffKD(
    teacher, student,
    train_loader, test_loader, device,
    epochs=240, base_lr=0.05, wd=5e-4,
    t_start=500, nfe=5,
    save_path="rn20_from_rn56_DiffKD.pth"
)


[DiffKD] Epoch 1/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch   1/240 | loss 18.687 | train  7.53% | test 12.48%
  ↳ Saved best @ epoch 1 (test 12.48%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 2/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch   2/240 | loss 16.571 | train 16.80% | test 19.78%
  ↳ Saved best @ epoch 2 (test 19.78%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 3/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch   3/240 | loss 15.853 | train 26.22% | test 26.87%
  ↳ Saved best @ epoch 3 (test 26.87%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 4/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch   4/240 | loss 15.328 | train 33.84% | test 28.85%
  ↳ Saved best @ epoch 4 (test 28.85%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 5/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch   5/240 | loss 14.951 | train 39.18% | test 38.87%
  ↳ Saved best @ epoch 5 (test 38.87%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 6/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch   6/240 | loss 14.636 | train 43.13% | test 35.69%


[DiffKD] Epoch 7/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch   7/240 | loss 14.401 | train 45.85% | test 41.83%
  ↳ Saved best @ epoch 7 (test 41.83%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 8/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch   8/240 | loss 14.209 | train 48.34% | test 43.55%
  ↳ Saved best @ epoch 8 (test 43.55%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 9/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch   9/240 | loss 14.057 | train 50.17% | test 45.37%
  ↳ Saved best @ epoch 9 (test 45.37%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 10/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  10/240 | loss 13.949 | train 51.65% | test 47.16%
  ↳ Saved best @ epoch 10 (test 47.16%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 11/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  11/240 | loss 13.882 | train 53.29% | test 46.48%


[DiffKD] Epoch 12/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  12/240 | loss 13.857 | train 54.12% | test 50.99%
  ↳ Saved best @ epoch 12 (test 50.99%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 13/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  13/240 | loss 13.861 | train 55.11% | test 49.27%


[DiffKD] Epoch 14/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  14/240 | loss 13.820 | train 56.17% | test 48.84%


[DiffKD] Epoch 15/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  15/240 | loss 13.807 | train 56.72% | test 49.81%


[DiffKD] Epoch 16/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  16/240 | loss 13.811 | train 57.40% | test 53.33%
  ↳ Saved best @ epoch 16 (test 53.33%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 17/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  17/240 | loss 13.764 | train 57.81% | test 51.00%


[DiffKD] Epoch 18/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  18/240 | loss 13.717 | train 58.32% | test 49.07%


[DiffKD] Epoch 19/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  19/240 | loss 13.637 | train 59.00% | test 53.39%
  ↳ Saved best @ epoch 19 (test 53.39%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 20/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  20/240 | loss 13.601 | train 59.26% | test 52.23%


[DiffKD] Epoch 21/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  21/240 | loss 13.577 | train 59.30% | test 49.38%


[DiffKD] Epoch 22/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  22/240 | loss 13.500 | train 59.86% | test 51.09%


[DiffKD] Epoch 23/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  23/240 | loss 13.451 | train 60.14% | test 52.03%


[DiffKD] Epoch 24/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  24/240 | loss 13.448 | train 60.52% | test 51.60%


[DiffKD] Epoch 25/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  25/240 | loss 13.371 | train 60.77% | test 52.75%


[DiffKD] Epoch 26/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  26/240 | loss 13.337 | train 60.73% | test 53.45%
  ↳ Saved best @ epoch 26 (test 53.45%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 27/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  27/240 | loss 13.319 | train 61.46% | test 54.56%
  ↳ Saved best @ epoch 27 (test 54.56%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 28/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  28/240 | loss 13.286 | train 61.66% | test 52.37%


[DiffKD] Epoch 29/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  29/240 | loss 13.231 | train 61.41% | test 52.54%


[DiffKD] Epoch 30/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  30/240 | loss 13.195 | train 61.98% | test 52.04%


[DiffKD] Epoch 31/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  31/240 | loss 13.193 | train 62.07% | test 55.43%
  ↳ Saved best @ epoch 31 (test 55.43%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 32/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  32/240 | loss 13.162 | train 62.14% | test 48.01%


[DiffKD] Epoch 33/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  33/240 | loss 13.177 | train 62.51% | test 56.48%
  ↳ Saved best @ epoch 33 (test 56.48%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 34/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  34/240 | loss 13.175 | train 62.40% | test 55.83%


[DiffKD] Epoch 35/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  35/240 | loss 13.165 | train 62.36% | test 54.76%


[DiffKD] Epoch 36/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  36/240 | loss 13.137 | train 62.88% | test 54.22%


[DiffKD] Epoch 37/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  37/240 | loss 13.142 | train 62.99% | test 50.91%


[DiffKD] Epoch 38/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  38/240 | loss 13.116 | train 62.84% | test 54.23%


[DiffKD] Epoch 39/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  39/240 | loss 13.146 | train 63.09% | test 55.43%


[DiffKD] Epoch 40/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  40/240 | loss 13.140 | train 63.49% | test 50.45%


[DiffKD] Epoch 41/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  41/240 | loss 13.083 | train 63.05% | test 52.65%


[DiffKD] Epoch 42/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  42/240 | loss 13.103 | train 63.20% | test 52.57%


[DiffKD] Epoch 43/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  43/240 | loss 13.115 | train 63.41% | test 56.37%


[DiffKD] Epoch 44/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  44/240 | loss 13.067 | train 63.71% | test 55.21%


[DiffKD] Epoch 45/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  45/240 | loss 13.077 | train 63.66% | test 55.07%


[DiffKD] Epoch 46/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  46/240 | loss 13.064 | train 63.84% | test 56.50%
  ↳ Saved best @ epoch 46 (test 56.50%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 47/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  47/240 | loss 13.092 | train 63.93% | test 52.73%


[DiffKD] Epoch 48/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  48/240 | loss 13.078 | train 64.00% | test 57.58%
  ↳ Saved best @ epoch 48 (test 57.58%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 49/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  49/240 | loss 13.067 | train 64.03% | test 54.62%


[DiffKD] Epoch 50/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  50/240 | loss 13.104 | train 63.73% | test 51.50%


[DiffKD] Epoch 51/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  51/240 | loss 13.079 | train 64.34% | test 55.47%


[DiffKD] Epoch 52/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  52/240 | loss 13.094 | train 64.19% | test 55.93%


[DiffKD] Epoch 53/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  53/240 | loss 13.051 | train 64.27% | test 54.76%


[DiffKD] Epoch 54/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  54/240 | loss 13.058 | train 64.36% | test 56.72%


[DiffKD] Epoch 55/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  55/240 | loss 13.046 | train 64.23% | test 56.71%


[DiffKD] Epoch 56/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  56/240 | loss 13.057 | train 64.27% | test 56.40%


[DiffKD] Epoch 57/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  57/240 | loss 13.063 | train 64.40% | test 56.36%


[DiffKD] Epoch 58/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  58/240 | loss 13.031 | train 64.64% | test 57.60%
  ↳ Saved best @ epoch 58 (test 57.60%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 59/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  59/240 | loss 13.016 | train 64.86% | test 53.99%


[DiffKD] Epoch 60/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  60/240 | loss 13.013 | train 64.44% | test 56.70%


[DiffKD] Epoch 61/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  61/240 | loss 13.033 | train 64.27% | test 56.98%


[DiffKD] Epoch 62/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  62/240 | loss 13.028 | train 64.60% | test 54.29%


[DiffKD] Epoch 63/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  63/240 | loss 13.024 | train 64.82% | test 55.97%


[DiffKD] Epoch 64/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  64/240 | loss 13.007 | train 65.10% | test 54.83%


[DiffKD] Epoch 65/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  65/240 | loss 13.034 | train 64.96% | test 53.57%


[DiffKD] Epoch 66/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  66/240 | loss 13.010 | train 64.71% | test 54.84%


[DiffKD] Epoch 67/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  67/240 | loss 13.027 | train 64.81% | test 57.28%


[DiffKD] Epoch 68/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  68/240 | loss 13.004 | train 65.19% | test 52.97%


[DiffKD] Epoch 69/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  69/240 | loss 12.984 | train 65.06% | test 54.06%


[DiffKD] Epoch 70/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  70/240 | loss 13.003 | train 65.27% | test 54.97%


[DiffKD] Epoch 71/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  71/240 | loss 12.959 | train 64.92% | test 53.49%


[DiffKD] Epoch 72/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  72/240 | loss 12.998 | train 65.34% | test 56.79%


[DiffKD] Epoch 73/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  73/240 | loss 12.987 | train 65.14% | test 57.52%


[DiffKD] Epoch 74/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  74/240 | loss 13.007 | train 64.98% | test 56.40%


[DiffKD] Epoch 75/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  75/240 | loss 12.962 | train 65.22% | test 56.34%


[DiffKD] Epoch 76/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  76/240 | loss 12.978 | train 65.29% | test 54.38%


[DiffKD] Epoch 77/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  77/240 | loss 12.956 | train 65.27% | test 54.94%


[DiffKD] Epoch 78/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  78/240 | loss 12.970 | train 65.24% | test 54.66%


[DiffKD] Epoch 79/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  79/240 | loss 12.938 | train 65.37% | test 53.65%


[DiffKD] Epoch 80/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  80/240 | loss 12.949 | train 65.23% | test 57.21%


[DiffKD] Epoch 81/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  81/240 | loss 12.953 | train 65.44% | test 56.19%


[DiffKD] Epoch 82/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  82/240 | loss 12.971 | train 65.45% | test 53.96%


[DiffKD] Epoch 83/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  83/240 | loss 12.938 | train 65.30% | test 52.35%


[DiffKD] Epoch 84/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  84/240 | loss 12.948 | train 65.46% | test 56.68%


[DiffKD] Epoch 85/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  85/240 | loss 12.918 | train 65.49% | test 58.30%
  ↳ Saved best @ epoch 85 (test 58.30%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 86/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  86/240 | loss 12.980 | train 65.24% | test 56.40%


[DiffKD] Epoch 87/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  87/240 | loss 12.970 | train 65.82% | test 52.95%


[DiffKD] Epoch 88/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  88/240 | loss 12.983 | train 65.50% | test 56.82%


[DiffKD] Epoch 89/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  89/240 | loss 12.950 | train 65.54% | test 55.82%


[DiffKD] Epoch 90/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  90/240 | loss 12.931 | train 65.68% | test 55.67%


[DiffKD] Epoch 91/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  91/240 | loss 12.951 | train 65.89% | test 57.19%


[DiffKD] Epoch 92/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  92/240 | loss 12.954 | train 65.75% | test 56.01%


[DiffKD] Epoch 93/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  93/240 | loss 12.949 | train 65.32% | test 55.19%


[DiffKD] Epoch 94/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  94/240 | loss 12.920 | train 65.58% | test 57.50%


[DiffKD] Epoch 95/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  95/240 | loss 12.936 | train 65.72% | test 56.33%


[DiffKD] Epoch 96/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  96/240 | loss 12.951 | train 65.61% | test 56.56%


[DiffKD] Epoch 97/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  97/240 | loss 12.950 | train 65.69% | test 57.41%


[DiffKD] Epoch 98/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  98/240 | loss 12.930 | train 65.61% | test 57.23%


[DiffKD] Epoch 99/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch  99/240 | loss 12.950 | train 65.69% | test 56.18%


[DiffKD] Epoch 100/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 100/240 | loss 12.915 | train 65.99% | test 57.39%


[DiffKD] Epoch 101/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 101/240 | loss 12.919 | train 65.87% | test 56.24%


[DiffKD] Epoch 102/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 102/240 | loss 12.935 | train 65.84% | test 56.09%


[DiffKD] Epoch 103/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 103/240 | loss 12.898 | train 65.96% | test 52.42%


[DiffKD] Epoch 104/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 104/240 | loss 12.917 | train 65.64% | test 56.26%


[DiffKD] Epoch 105/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 105/240 | loss 12.887 | train 65.92% | test 57.18%


[DiffKD] Epoch 106/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 106/240 | loss 12.931 | train 65.98% | test 56.13%


[DiffKD] Epoch 107/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 107/240 | loss 12.915 | train 65.90% | test 56.70%


[DiffKD] Epoch 108/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 108/240 | loss 12.917 | train 65.85% | test 56.64%


[DiffKD] Epoch 109/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 109/240 | loss 12.945 | train 65.68% | test 53.68%


[DiffKD] Epoch 110/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 110/240 | loss 12.896 | train 65.97% | test 58.04%


[DiffKD] Epoch 111/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 111/240 | loss 12.900 | train 66.03% | test 55.32%


[DiffKD] Epoch 112/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 112/240 | loss 12.922 | train 66.06% | test 52.77%


[DiffKD] Epoch 113/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 113/240 | loss 12.893 | train 65.98% | test 56.70%


[DiffKD] Epoch 114/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 114/240 | loss 12.915 | train 66.16% | test 53.69%


[DiffKD] Epoch 115/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 115/240 | loss 12.923 | train 66.00% | test 56.25%


[DiffKD] Epoch 116/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 116/240 | loss 12.916 | train 66.30% | test 55.45%


[DiffKD] Epoch 117/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 117/240 | loss 12.926 | train 65.98% | test 57.39%


[DiffKD] Epoch 118/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 118/240 | loss 12.913 | train 66.06% | test 57.57%


[DiffKD] Epoch 119/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 119/240 | loss 12.918 | train 66.32% | test 58.67%
  ↳ Saved best @ epoch 119 (test 58.67%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 120/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 120/240 | loss 12.915 | train 66.28% | test 58.31%


[DiffKD] Epoch 121/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 121/240 | loss 12.919 | train 66.04% | test 55.80%


[DiffKD] Epoch 122/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 122/240 | loss 12.903 | train 66.10% | test 56.28%


[DiffKD] Epoch 123/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 123/240 | loss 12.922 | train 65.83% | test 55.54%


[DiffKD] Epoch 124/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 124/240 | loss 12.904 | train 66.28% | test 59.08%
  ↳ Saved best @ epoch 124 (test 59.08%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 125/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 125/240 | loss 12.871 | train 66.19% | test 57.36%


[DiffKD] Epoch 126/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 126/240 | loss 12.889 | train 66.23% | test 57.43%


[DiffKD] Epoch 127/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 127/240 | loss 12.889 | train 66.17% | test 54.51%


[DiffKD] Epoch 128/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 128/240 | loss 12.918 | train 66.17% | test 56.14%


[DiffKD] Epoch 129/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 129/240 | loss 12.887 | train 66.36% | test 54.58%


[DiffKD] Epoch 130/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 130/240 | loss 12.926 | train 66.34% | test 55.45%


[DiffKD] Epoch 131/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 131/240 | loss 12.958 | train 66.39% | test 54.41%


[DiffKD] Epoch 132/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 132/240 | loss 12.907 | train 66.41% | test 57.35%


[DiffKD] Epoch 133/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 133/240 | loss 12.889 | train 66.35% | test 56.25%


[DiffKD] Epoch 134/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 134/240 | loss 12.920 | train 66.24% | test 53.96%


[DiffKD] Epoch 135/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 135/240 | loss 12.875 | train 66.38% | test 55.42%


[DiffKD] Epoch 136/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 136/240 | loss 12.937 | train 66.65% | test 57.91%


[DiffKD] Epoch 137/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 137/240 | loss 12.918 | train 66.34% | test 57.26%


[DiffKD] Epoch 138/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 138/240 | loss 12.929 | train 66.22% | test 57.02%


[DiffKD] Epoch 139/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 139/240 | loss 12.899 | train 66.07% | test 55.52%


[DiffKD] Epoch 140/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 140/240 | loss 12.886 | train 66.53% | test 53.23%


[DiffKD] Epoch 141/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 141/240 | loss 12.929 | train 66.43% | test 58.50%


[DiffKD] Epoch 142/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 142/240 | loss 12.856 | train 66.45% | test 53.31%


[DiffKD] Epoch 143/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 143/240 | loss 12.922 | train 66.21% | test 58.27%


[DiffKD] Epoch 144/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 144/240 | loss 12.868 | train 66.41% | test 55.14%


[DiffKD] Epoch 145/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 145/240 | loss 12.854 | train 66.50% | test 53.32%


[DiffKD] Epoch 146/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 146/240 | loss 12.868 | train 66.65% | test 57.04%


[DiffKD] Epoch 147/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 147/240 | loss 12.902 | train 66.53% | test 57.26%


[DiffKD] Epoch 148/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 148/240 | loss 12.900 | train 66.30% | test 55.43%


[DiffKD] Epoch 149/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 149/240 | loss 12.900 | train 66.44% | test 50.25%


[DiffKD] Epoch 150/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 150/240 | loss 12.880 | train 66.41% | test 56.60%


[DiffKD] Epoch 151/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 151/240 | loss 12.230 | train 74.21% | test 68.10%
  ↳ Saved best @ epoch 151 (test 68.10%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 152/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 152/240 | loss 12.092 | train 76.65% | test 68.19%
  ↳ Saved best @ epoch 152 (test 68.19%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 153/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 153/240 | loss 12.097 | train 77.45% | test 68.67%
  ↳ Saved best @ epoch 153 (test 68.67%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 154/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 154/240 | loss 12.125 | train 77.81% | test 68.66%


[DiffKD] Epoch 155/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 155/240 | loss 12.137 | train 78.60% | test 69.06%
  ↳ Saved best @ epoch 155 (test 69.06%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 156/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 156/240 | loss 12.153 | train 78.89% | test 69.03%


[DiffKD] Epoch 157/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 157/240 | loss 12.207 | train 79.04% | test 69.22%
  ↳ Saved best @ epoch 157 (test 69.22%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 158/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 158/240 | loss 12.225 | train 79.43% | test 69.26%
  ↳ Saved best @ epoch 158 (test 69.26%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 159/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 159/240 | loss 12.265 | train 79.65% | test 69.41%
  ↳ Saved best @ epoch 159 (test 69.41%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 160/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 160/240 | loss 12.287 | train 80.06% | test 69.33%


[DiffKD] Epoch 161/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 161/240 | loss 12.320 | train 80.07% | test 69.55%
  ↳ Saved best @ epoch 161 (test 69.55%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 162/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 162/240 | loss 12.357 | train 80.25% | test 69.27%


[DiffKD] Epoch 163/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 163/240 | loss 12.381 | train 80.54% | test 69.89%
  ↳ Saved best @ epoch 163 (test 69.89%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 164/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 164/240 | loss 12.414 | train 80.72% | test 69.26%


[DiffKD] Epoch 165/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 165/240 | loss 12.433 | train 81.08% | test 69.75%


[DiffKD] Epoch 166/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 166/240 | loss 12.472 | train 80.87% | test 69.35%


[DiffKD] Epoch 167/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 167/240 | loss 12.491 | train 81.15% | test 68.87%


[DiffKD] Epoch 168/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 168/240 | loss 12.534 | train 81.31% | test 69.10%


[DiffKD] Epoch 169/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 169/240 | loss 12.557 | train 81.33% | test 68.87%


[DiffKD] Epoch 170/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 170/240 | loss 12.577 | train 81.45% | test 69.67%


[DiffKD] Epoch 171/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 171/240 | loss 12.610 | train 81.55% | test 69.13%


[DiffKD] Epoch 172/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 172/240 | loss 12.639 | train 82.02% | test 69.14%


[DiffKD] Epoch 173/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 173/240 | loss 12.657 | train 81.86% | test 68.80%


[DiffKD] Epoch 174/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 174/240 | loss 12.668 | train 82.04% | test 69.05%


[DiffKD] Epoch 175/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 175/240 | loss 12.705 | train 82.10% | test 68.78%


[DiffKD] Epoch 176/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 176/240 | loss 12.739 | train 82.21% | test 68.87%


[DiffKD] Epoch 177/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 177/240 | loss 12.756 | train 82.16% | test 69.01%


[DiffKD] Epoch 178/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 178/240 | loss 12.774 | train 82.39% | test 69.72%


[DiffKD] Epoch 179/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 179/240 | loss 12.786 | train 82.31% | test 69.00%


[DiffKD] Epoch 180/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 180/240 | loss 12.825 | train 82.55% | test 68.68%


[DiffKD] Epoch 181/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 181/240 | loss 12.716 | train 84.07% | test 69.81%


[DiffKD] Epoch 182/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 182/240 | loss 12.691 | train 84.61% | test 69.82%


[DiffKD] Epoch 183/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 183/240 | loss 12.693 | train 84.69% | test 69.77%


[DiffKD] Epoch 184/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 184/240 | loss 12.678 | train 84.71% | test 70.00%
  ↳ Saved best @ epoch 184 (test 70.00%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 185/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 185/240 | loss 12.688 | train 84.75% | test 69.98%


[DiffKD] Epoch 186/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 186/240 | loss 12.684 | train 84.91% | test 70.12%
  ↳ Saved best @ epoch 186 (test 70.12%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 187/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 187/240 | loss 12.700 | train 84.97% | test 70.14%
  ↳ Saved best @ epoch 187 (test 70.14%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 188/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 188/240 | loss 12.698 | train 84.93% | test 70.25%
  ↳ Saved best @ epoch 188 (test 70.25%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 189/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 189/240 | loss 12.704 | train 84.96% | test 70.06%


[DiffKD] Epoch 190/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 190/240 | loss 12.698 | train 85.07% | test 70.25%


[DiffKD] Epoch 191/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 191/240 | loss 12.703 | train 85.13% | test 70.17%


[DiffKD] Epoch 192/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 192/240 | loss 12.702 | train 85.42% | test 70.24%


[DiffKD] Epoch 193/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 193/240 | loss 12.711 | train 85.04% | test 69.98%


[DiffKD] Epoch 194/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 194/240 | loss 12.714 | train 85.01% | test 69.95%


[DiffKD] Epoch 195/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 195/240 | loss 12.717 | train 85.39% | test 70.12%


[DiffKD] Epoch 196/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 196/240 | loss 12.727 | train 85.20% | test 70.06%


[DiffKD] Epoch 197/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 197/240 | loss 12.734 | train 85.29% | test 70.08%


[DiffKD] Epoch 198/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 198/240 | loss 12.728 | train 85.25% | test 70.05%


[DiffKD] Epoch 199/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 199/240 | loss 12.723 | train 85.35% | test 70.09%


[DiffKD] Epoch 200/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 200/240 | loss 12.720 | train 85.59% | test 70.50%
  ↳ Saved best @ epoch 200 (test 70.50%) → rn20_from_rn56_DiffKD.pth


[DiffKD] Epoch 201/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 201/240 | loss 12.734 | train 85.38% | test 70.12%


[DiffKD] Epoch 202/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 202/240 | loss 12.739 | train 85.40% | test 70.12%


[DiffKD] Epoch 203/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 203/240 | loss 12.742 | train 85.47% | test 70.30%


[DiffKD] Epoch 204/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 204/240 | loss 12.733 | train 85.39% | test 70.29%


[DiffKD] Epoch 205/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 205/240 | loss 12.751 | train 85.48% | test 70.38%


[DiffKD] Epoch 206/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 206/240 | loss 12.747 | train 85.50% | test 70.11%


[DiffKD] Epoch 207/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 207/240 | loss 12.751 | train 85.76% | test 70.17%


[DiffKD] Epoch 208/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 208/240 | loss 12.755 | train 85.69% | test 69.99%


[DiffKD] Epoch 209/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 209/240 | loss 12.759 | train 85.50% | test 70.11%


[DiffKD] Epoch 210/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 210/240 | loss 12.759 | train 85.63% | test 69.93%


[DiffKD] Epoch 211/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 211/240 | loss 12.743 | train 85.69% | test 70.10%


[DiffKD] Epoch 212/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 212/240 | loss 12.747 | train 85.72% | test 70.04%


[DiffKD] Epoch 213/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 213/240 | loss 12.745 | train 86.00% | test 69.89%


[DiffKD] Epoch 214/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 214/240 | loss 12.752 | train 85.75% | test 70.05%


[DiffKD] Epoch 215/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 215/240 | loss 12.753 | train 85.76% | test 70.10%


[DiffKD] Epoch 216/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 216/240 | loss 12.753 | train 85.80% | test 70.11%


[DiffKD] Epoch 217/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 217/240 | loss 12.753 | train 85.70% | test 70.16%


[DiffKD] Epoch 218/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 218/240 | loss 12.745 | train 85.87% | test 70.09%


[DiffKD] Epoch 219/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 219/240 | loss 12.753 | train 85.72% | test 70.22%


[DiffKD] Epoch 220/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 220/240 | loss 12.751 | train 85.93% | test 70.33%


[DiffKD] Epoch 221/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 221/240 | loss 12.747 | train 85.88% | test 70.03%


[DiffKD] Epoch 222/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 222/240 | loss 12.754 | train 85.80% | test 70.11%


[DiffKD] Epoch 223/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 223/240 | loss 12.747 | train 85.81% | test 70.11%


[DiffKD] Epoch 224/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 224/240 | loss 12.754 | train 85.80% | test 70.07%


[DiffKD] Epoch 225/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 225/240 | loss 12.746 | train 86.02% | test 70.12%


[DiffKD] Epoch 226/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 226/240 | loss 12.749 | train 85.87% | test 70.09%


[DiffKD] Epoch 227/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 227/240 | loss 12.752 | train 85.87% | test 70.03%


[DiffKD] Epoch 228/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 228/240 | loss 12.757 | train 85.74% | test 70.28%


[DiffKD] Epoch 229/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 229/240 | loss 12.757 | train 85.93% | test 70.21%


[DiffKD] Epoch 230/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 230/240 | loss 12.752 | train 85.81% | test 70.08%


[DiffKD] Epoch 231/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 231/240 | loss 12.751 | train 85.69% | test 70.29%


[DiffKD] Epoch 232/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 232/240 | loss 12.749 | train 85.80% | test 70.21%


[DiffKD] Epoch 233/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 233/240 | loss 12.756 | train 85.69% | test 70.26%


[DiffKD] Epoch 234/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 234/240 | loss 12.762 | train 85.88% | test 69.96%


[DiffKD] Epoch 235/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 235/240 | loss 12.760 | train 85.83% | test 70.12%


[DiffKD] Epoch 236/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 236/240 | loss 12.756 | train 86.07% | test 69.98%


[DiffKD] Epoch 237/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 237/240 | loss 12.754 | train 85.97% | test 70.01%


[DiffKD] Epoch 238/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 238/240 | loss 12.757 | train 85.98% | test 69.96%


[DiffKD] Epoch 239/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 239/240 | loss 12.765 | train 85.81% | test 70.10%


[DiffKD] Epoch 240/240:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 240/240 | loss 12.756 | train 85.84% | test 70.17%
✅ DiffKD finished. Best Test Acc: 70.50%


70.5