In [None]:
# ---------- SETUP ----------
!pip install datasets timm > /dev/null
import torch, torch.nn as nn, torch.nn.functional as F
from torchvision import models, transforms
from torch.utils.data import DataLoader
from datasets import load_dataset
import numpy as np, random, shutil, os
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# ---------- CONFIG ----------
BATCH_SIZE = 128
EPOCHS = 45
LR = 1e-3
MOMENTUM = 0.9
WEIGHT_DECAY = 5e-4
TAU = 0.07
GAMMA, BETA = 1.0, 0.8
NUM_NEGS = 2048

# ---------- LOAD DATA ----------
dataset = load_dataset("Maysee/tiny-imagenet")

def subsample_split(split, keep=10000):
    n = len(dataset[split])
    idx = random.sample(range(n), min(keep, n))
    return dataset[split].select(idx)

train_hf = subsample_split("train", keep=9000)
val_hf   = subsample_split("valid", keep=1000)
print(f"Using {len(train_hf)} train / {len(val_hf)} val samples")

# ---------- TRANSFORMS ----------
transform_train = transforms.Compose([
    transforms.Lambda(lambda img: img.convert("RGB")),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.485,0.456,0.406), (0.229,0.224,0.225))
])
transform_val = transforms.Compose([
    transforms.Lambda(lambda img: img.convert("RGB")),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485,0.456,0.406), (0.229,0.224,0.225))
])

def transform_example(example, transform):
    example["pixel_values"] = transform(example["image"])
    return example

train_hf = train_hf.map(lambda e: transform_example(e, transform_train))
val_hf   = val_hf.map(lambda e: transform_example(e, transform_val))

class HFDataset(torch.utils.data.Dataset):
    def __init__(self, hf_dataset):
        self.ds = hf_dataset
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        item = self.ds[idx]
        y = int(item["label"])

        # --- 1. Handle multiple possible formats ---
        if "pixel_values" in item:
            x = item["pixel_values"]
        elif "image" in item:
            x = item["image"]
        else:
            raise KeyError("Expected 'image' or 'pixel_values' in dataset example.")

        # --- 2. Convert to torch tensor ---
        if isinstance(x, torch.Tensor):
            x = x.float()
        elif isinstance(x, np.ndarray):
            x = torch.from_numpy(x).float()
        elif isinstance(x, list):
            x = torch.tensor(np.array(x), dtype=torch.float32)
        elif hasattr(x, "convert"):  # PIL Image
            x = transforms.functional.pil_to_tensor(x).float()
        else:
            raise TypeError(f"Unsupported image type: {type(x)}")

        # --- 3. Fix dimensions ---
        if x.ndim == 3 and x.shape[-1] == 3:      # (H, W, 3)
            x = x.permute(2, 0, 1)
        elif x.ndim == 2:                         # (H, W)
            x = x.unsqueeze(0).repeat(3, 1, 1)
        elif x.ndim == 3 and x.shape[0] == 1:     # (1, H, W)
            x = x.repeat(3, 1, 1)
        elif x.ndim != 3 or x.shape[0] != 3:
            raise ValueError(f"Unexpected tensor shape {x.shape}")

        # --- 4. Normalize safely ---
        if x.max() > 1.0:
            x = x / 255.0
        x = self.normalize(x)

        return x.contiguous(), y
train_loader = DataLoader(HFDataset(train_hf), batch_size=BATCH_SIZE,
                          shuffle=True, num_workers=2, pin_memory=True)
val_loader   = DataLoader(HFDataset(val_hf), batch_size=BATCH_SIZE,
                          shuffle=False, num_workers=2, pin_memory=True)
print(f"Train={len(train_hf)}, Val={len(val_hf)} loaded")

# ---------- MODELS ----------
num_classes = len(set(train_hf["label"]))
teacher = models.vgg16(weights='IMAGENET1K_V1')
teacher.classifier[6] = nn.Linear(4096, num_classes)
teacher = teacher.to(device)
teacher.eval()

student = models.vgg11()
student.classifier[6] = nn.Linear(4096, num_classes)
student = student.to(device)
print("Teacher & Student ready:", num_classes)


Device: cuda
Using 9000 train / 1000 val samples


Map:   0%|          | 0/9000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Train=9000, Val=1000 loaded
Teacher & Student ready: 200


In [None]:
# ------------------ GPU SAFETY PRESET ------------------
BATCH_SIZE = 32          # <—— biggest Colab-safe size for two VGGs
NUM_NEGS = 512           # 2048 is too large for 15 GB VRAM
torch.cuda.empty_cache()
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True

# ------------------ MEMORY QUEUE ------------------
class MemoryQueue:
    def __init__(self, dim, size=NUM_NEGS):
        self.queue = F.normalize(torch.randn(size, dim), dim=1)
        self.ptr = 0
        self.size = size

    @torch.no_grad()
    def enqueue(self, feats):
        feats = feats.detach().cpu()
        n = feats.size(0)
        if n >= self.size:
            # keep only the most recent items
            self.queue[:] = feats[-self.size:]
            self.ptr = 0
            return
        end = self.ptr + n
        if end <= self.size:
            self.queue[self.ptr:end] = feats
        else:
            first = self.size - self.ptr
            self.queue[self.ptr:] = feats[:first]
            self.queue[:end - self.size] = feats[first:]
        self.ptr = (self.ptr + n) % self.size

    def get_negatives(self):
        return self.queue.to(device, non_blocking=True)


# ------------------ CRD LOSS ------------------
def crd_loss(f_s, f_t, mem_queue, tau=TAU):
    f_s = F.normalize(f_s, dim=1)
    f_t = F.normalize(f_t, dim=1)
    negs = mem_queue.get_negatives()          # [N, d]
    pos = torch.exp(torch.sum(f_t * f_s, dim=1) / tau)
    neg = torch.exp((f_t @ negs.T) / tau).sum(dim=1)
    return -torch.log(pos / (pos + neg)).mean()

# ------------------ TRAIN ONE EPOCH ------------------
def train_crd(student, teacher, loader, optimizer, mem_queue, scaler):
    student.train(); teacher.eval()
    ce_total = crd_total = correct = total = 0

    for x, y in tqdm(loader, leave=False):
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)

        with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
            ft = teacher.features(x)
            ft = F.adaptive_avg_pool2d(ft, 1).flatten(1)   # [B,512]

        optimizer.zero_grad(set_to_none=True)

        with torch.cuda.amp.autocast(dtype=torch.float16):
            fs_full = student.features(x)                  # [B,512,7,7]
            fs_flat = fs_full.flatten(1)                   # [B,25088]
            s_logits = student.classifier(fs_flat)
            fs = F.adaptive_avg_pool2d(fs_full, 1).flatten(1)  # [B,512]

            ce = F.cross_entropy(s_logits, y)
            l_crd = crd_loss(fs, ft, mem_queue)
            loss = ce + GAMMA * BETA * l_crd

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        mem_queue.enqueue(ft)

        with torch.no_grad():
            _, pred = s_logits.max(1)
            correct += pred.eq(y).sum().item()
            total += y.size(0)
            ce_total += ce.item() * y.size(0)
            crd_total += l_crd.item() * y.size(0)

        # ---- free GPU each batch ----
        del x, y, ft, fs_full, fs_flat, fs, s_logits, loss
        torch.cuda.empty_cache()

    return ce_total/total, crd_total/total, 100*correct/total

# ------------------ MAIN LOOP ------------------
def train_loop(student, teacher):
    opt = torch.optim.SGD(student.parameters(), lr=LR,
                          momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
    scaler = torch.amp.GradScaler('cuda')

    mem_queue = MemoryQueue(dim=512)

    for ep in range(1, EPOCHS+1):
        ce, crd, acc = train_crd(student, teacher, train_loader, opt, mem_queue, scaler)
        print(f"[Epoch {ep:02d}] CE={ce:.4f} | CRD={crd:.4f} | Acc={acc:.2f}%")
        if ep % 5 == 0:
            torch.save(student.state_dict(), f"student_crd_ep{ep}.pth")
            torch.cuda.empty_cache()

    return student


In [None]:
student_crd = train_loop(student, teacher)


  with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
  with torch.cuda.amp.autocast(dtype=torch.float16):


[Epoch 01] CE=5.3082 | CRD=inf | Acc=0.54%




[Epoch 02] CE=5.3083 | CRD=inf | Acc=0.60%




[Epoch 03] CE=5.3093 | CRD=inf | Acc=0.57%




[Epoch 04] CE=5.3097 | CRD=inf | Acc=0.53%




[Epoch 05] CE=nan | CRD=nan | Acc=0.44%




[Epoch 06] CE=nan | CRD=nan | Acc=0.46%




[Epoch 07] CE=nan | CRD=nan | Acc=0.46%




[Epoch 08] CE=nan | CRD=nan | Acc=0.46%




[Epoch 09] CE=nan | CRD=nan | Acc=0.46%




[Epoch 10] CE=nan | CRD=nan | Acc=0.46%




[Epoch 11] CE=nan | CRD=nan | Acc=0.46%




[Epoch 12] CE=nan | CRD=nan | Acc=0.46%




[Epoch 13] CE=nan | CRD=nan | Acc=0.46%




[Epoch 14] CE=nan | CRD=nan | Acc=0.46%




[Epoch 15] CE=nan | CRD=nan | Acc=0.46%


 20%|█▉        | 14/71 [02:18<07:46,  8.18s/it]

In [None]:
from scipy.spatial.distance import cosine, wasserstein_distance
from scipy.stats import entropy

def topk_acc(output, target, topk=(1,5)):
    maxk = max(topk)
    _, pred = output.topk(maxk, dim=1)
    pred = pred.t()
    correct = pred.eq(target.view(1,-1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / target.size(0)).item())
    return res

def evaluate(model, loader):
    model.eval()
    top1_sum=top5_sum=0; total=0
    with torch.no_grad():
        for x,y in loader:
            x,y = x.to(device), y.to(device)
            out = model(x)
            t1,t5 = topk_acc(out,y)
            top1_sum+=t1*y.size(0); top5_sum+=t5*y.size(0); total+=y.size(0)
    return top1_sum/total, top5_sum/total

def compare_distributions(teacher, student, loader, n_batches=20):
    KLs, JSDs, WDs, CSs = [],[],[],[]
    with torch.no_grad():
        for i,(x,_) in enumerate(loader):
            if i>=n_batches: break
            x = x.to(device)
            pt = F.softmax(teacher(x),dim=1).cpu().numpy()
            ps = F.softmax(student(x),dim=1).cpu().numpy()
            for a,b in zip(pt,ps):
                KLs.append(entropy(a,b))
                M = 0.5*(a+b)
                JSDs.append(0.5*entropy(a,M)+0.5*entropy(b,M))
                WDs.append(wasserstein_distance(a,b))
                CSs.append(1 - cosine(a,b))
    return np.mean(KLs), np.mean(JSDs), np.mean(WDs), np.mean(CSs)

# ---------- RUN EVAL ----------
top1, top5 = evaluate(student_crd, val_loader)
print(f"Val Top1={top1:.2f}% | Top5={top5:.2f}%")

kl, jsd, wd, cs = compare_distributions(teacher, student_crd, val_loader)
print(f"KL={kl:.4f}, JSD={jsd:.4f}, W={wd:.4f}, CosSim={cs:.4f}")


In [None]:
from torchvision.models.feature_extraction import create_feature_extractor
import matplotlib.pyplot as plt

# ---------- Grad-CAM ----------
def gradcam(model, x, target_layer='features.28'):
    model.eval()
    extractor = create_feature_extractor(model, {target_layer: 'feat'})
    x = x.unsqueeze(0).to(device).requires_grad_(True)
    out = extractor(x)
    logits = model.classifier(out['feat'].mean([2,3]))
    pred = logits.argmax(1)
    loss = logits[0, pred]
    loss.backward()
    grad = x.grad[0].mean(0).cpu()
    return grad.abs() / grad.abs().max()

def compare_cam(teacher, student, loader, n=50):
    sims=[]
    for i,(x,_) in enumerate(loader):
        if i>=n: break
        x=x[0]
        cam_t = gradcam(teacher, x)
        cam_s = gradcam(student, x)
        sims.append(F.cosine_similarity(cam_t.flatten(), cam_s.flatten(), dim=0).item())
    return np.mean(sims)

# ---------- Color-Invariance ----------
color_jitter = transforms.ColorJitter(0.4,0.4,0.4,0.2)
def color_eval(model, loader):
    total, corr_norm, corr_jit = 0,0,0
    model.eval()
    with torch.no_grad():
        for x,y in loader:
            x,y = x.to(device), y.to(device)
            x_j = torch.stack([color_jitter(xx.cpu()) for xx in x.cpu()]).to(device)
            out_n, out_j = model(x), model(x_j)
            pred_n, pred_j = out_n.argmax(1), out_j.argmax(1)
            corr_norm += pred_n.eq(y).sum().item()
            corr_jit += pred_j.eq(y).sum().item()
            total += y.size(0)
    acc_n, acc_j = corr_norm/total*100, corr_jit/total*100
    return acc_n, acc_j, abs(acc_n-acc_j)

acc_n, acc_j, gap = color_eval(student_crd, val_loader)
sim_cam = compare_cam(teacher, student_crd, val_loader)
print(f"Color Invariance Gap Δinv={gap:.2f}% | GradCAM agreement={sim_cam:.3f}")
