In [1]:
#@title Imports
import torch, torch.nn as nn, torch.nn.functional as F, torchvision
from torch.utils.data import Subset, DataLoader
from torchvision import transforms, datasets, models
import numpy as np, matplotlib.pyplot as plt

In [2]:
#@title Device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
# @title Data (CIFAR-10)

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np

tfm = transforms.Compose([
    transforms.Resize(224),                     # for ResNet50
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],              # ImageNet normalization
        std=[0.229, 0.224, 0.225]
    )
])

train_full = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=tfm
)

valset = datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=tfm
)

# Use only 10% of training data
n = int(0.5 * len(train_full))
subset_idx = np.random.permutation(len(train_full))[:n]
trainset = Subset(train_full, subset_idx)



100%|██████████| 170M/170M [00:05<00:00, 29.3MB/s]


In [5]:
import torch.nn as nn
from torchvision import models

teacher = models.resnet50(
    weights=models.ResNet50_Weights.IMAGENET1K_V2
)

teacher.fc = nn.Linear(teacher.fc.in_features, 10)
teacher.to(device)


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [6]:
for param in teacher.parameters():
    param.requires_grad = False

# Unfreeze layer4 + fc
for param in teacher.layer4.parameters():
    param.requires_grad = True



In [7]:
from torchvision import transforms
from torch.utils.data import DataLoader

train_loader = DataLoader(
    trainset, batch_size=64, shuffle=True
)

val_loader = DataLoader(
    valset, batch_size=64, shuffle=False
)



In [8]:
criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(
    teacher.fc.parameters(),
    lr=1e-3
)

In [None]:
from tqdm import tqdm

teacher.train()

running_loss = 0.0
correct = 0
total = 0

pbar = tqdm(train_loader, desc="Training", leave=False)

for images, labels in pbar:
    images, labels = images.to(device), labels.to(device)

    outputs = teacher(images)
    loss = criterion(outputs, labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # stats
    running_loss += loss.item() * labels.size(0)
    _, preds = outputs.max(1)
    correct += (preds == labels).sum().item()
    total += labels.size(0)

    # update tqdm bar
    pbar.set_postfix({
        "loss": running_loss / total,
        "acc": correct / total
    })

epoch_loss = running_loss / total
epoch_acc = correct / total

print(f"Train Loss: {epoch_loss:.4f} | Train Accuracy: {epoch_acc:.4f}")


Training:   0%|          | 0/391 [00:00<?, ?it/s]

In [30]:
teacher.eval()
correct, total = 0, 0

with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = teacher(images)
        _, preds = outputs.max(1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

val_acc = correct / total
print("Val Accuracy:", val_acc)


Val Accuracy: 0.8702


## Student


In [52]:
#@title Variables
BATCH_SIZE = 64
ATTENTION_HEADS = 8
TRANSFORMER_LAYERS = 4
EMBED_DIM = 64
IMG_SIZE = 32
PATCH_SIZE = 4
CLASSES = 10
EPOCHS_STUDENT = 20
LR_STUDENT = 1e-4
TEMPERATURE = 2
ALPHA = 0.7
CHANNELS = 3

In [49]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465),
        std=(0.2023, 0.1994, 0.2010)
    )
])

trainset = torchvision.datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform
)

test1 = torchvision.datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transform
)

n = int(1 * len(trainset))
subset_idx = np.random.permutation(len(trainset))[:n]
train = Subset(trainset, subset_idx)

train_dl = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
test_dl = DataLoader(test1, batch_size=BATCH_SIZE)


In [54]:
#@title Student ViT
class PatchEmbed(nn.Module):
    def __init__(self, img_size=IMG_SIZE, patch=PATCH_SIZE, dim=EMBED_DIM, channels = CHANNELS):
        super().__init__()
        self.proj=nn.Conv2d(CHANNELS,dim,patch,patch)
        self.n=(img_size//patch)**2
    def forward(self,x):
        x=self.proj(x).flatten(2).transpose(1,2)
        return x


class ViT(nn.Module):
    def __init__(self,num_classes=CLASSES,dim=EMBED_DIM,depth=TRANSFORMER_LAYERS,heads=ATTENTION_HEADS):
        super().__init__()
        self.patch=PatchEmbed()
        n=self.patch.n
        self.cls=nn.Parameter(torch.zeros(1,1,dim))
        self.distill=nn.Parameter(torch.zeros(1,1,dim))
        self.pos=nn.Parameter(torch.zeros(1,n+2,dim))
        self.blocks = nn.Sequential(*[
            nn.TransformerEncoderLayer(dim, heads, dim*4, batch_first=True)
            for _ in range(depth)
        ])
        self.norm=nn.LayerNorm(dim)
        self.head_cls=nn.Linear(dim,num_classes)
        self.head_dist=nn.Linear(dim,num_classes)

    def forward(self,x):
        B=x.size(0)
        x=self.patch(x)
        cls=self.cls.expand(B,-1,-1)
        dist=self.distill.expand(B,-1,-1)
        x=torch.cat([cls,x,dist],1)+self.pos
        x = self.blocks(x)
        x=self.norm(x)
        cls_out=x[:,0]; dist_out=x[:,-1]
        cls_logits=self.head_cls(cls_out)
        dist_logits=self.head_dist(dist_out)
        return cls_logits, dist_logits

student=ViT().to(device)
opt_s=torch.optim.AdamW(student.parameters(),lr=LR_STUDENT)

In [55]:
from tqdm import tqdm
import torch.nn.functional as F


def accuracy_from_logits(logits, targets):
    preds = logits.argmax(dim=1)
    return (preds == targets).float().mean().item()



# @title Train Student with Hard Distillation
def kd_loss(cls_logits, dist_logits, t_logits, y, T=TEMPERATURE, alpha=ALPHA):
    # CE loss: class token vs ground truth
    ce = F.cross_entropy(cls_logits, y)

    # KL loss: distillation token vs teacher
    kd = F.kl_div(
        F.log_softmax(dist_logits / T, dim=1),
        F.softmax(t_logits / T, dim=1),
        reduction='batchmean'
    ) * (T * T)

    return alpha * ce + (1 - alpha) * kd



print("Training student...")

for e in range(EPOCHS_STUDENT):
    # ---------- TRAIN ----------
    student.train()
    correct_train = 0
    total_train = 0

    for x, y in tqdm(train_dl, desc=f"Epoch {e+1} [Train]", leave=False):
        x, y = x.to(device), y.to(device)

        with torch.no_grad():
            t_logits = teacher(x)

        cls_logits, dist_logits = student(x)

        loss = kd_loss(
            cls_logits=cls_logits,
            dist_logits=dist_logits,
            t_logits=t_logits,
            y=y
        )

        opt_s.zero_grad()
        loss.backward()
        opt_s.step()

        # Train accuracy (class token only)
        preds = cls_logits.argmax(dim=1)
        correct_train += (preds == y).sum().item()
        total_train += y.size(0)

    train_acc = correct_train / total_train

    # ---------- VALIDATION ----------
    student.eval()
    correct_val = 0
    total_val = 0

    with torch.no_grad():
        for x, y in tqdm(test_dl, desc=f"Epoch {e+1} [Val]", leave=False):
            x, y = x.to(device), y.to(device)

            cls_logits, _ = student(x)

            preds = cls_logits.argmax(dim=1)
            correct_val += (preds == y).sum().item()
            total_val += y.size(0)

    val_acc = correct_val / total_val

    # ---------- PRINT ----------
    print(
        f"Epoch [{e+1}/{EPOCHS_STUDENT}] | "
        f"Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}"
    )

Training student...




Epoch [1/20] | Train Acc: 0.2752 | Val Acc: 0.3535




Epoch [2/20] | Train Acc: 0.3782 | Val Acc: 0.4059




Epoch [3/20] | Train Acc: 0.4390 | Val Acc: 0.4738




Epoch [4/20] | Train Acc: 0.4874 | Val Acc: 0.5101




Epoch [5/20] | Train Acc: 0.5166 | Val Acc: 0.5357




Epoch [6/20] | Train Acc: 0.5416 | Val Acc: 0.5526




Epoch [7/20] | Train Acc: 0.5569 | Val Acc: 0.5690




Epoch [8/20] | Train Acc: 0.5709 | Val Acc: 0.5745




Epoch [9/20] | Train Acc: 0.5831 | Val Acc: 0.5948




Epoch [10/20] | Train Acc: 0.5940 | Val Acc: 0.5990




Epoch [11/20] | Train Acc: 0.6045 | Val Acc: 0.6063




Epoch [12/20] | Train Acc: 0.6133 | Val Acc: 0.6141




Epoch [13/20] | Train Acc: 0.6232 | Val Acc: 0.6177




Epoch [14/20] | Train Acc: 0.6318 | Val Acc: 0.6180




Epoch [15/20] | Train Acc: 0.6380 | Val Acc: 0.6252




Epoch [16/20] | Train Acc: 0.6442 | Val Acc: 0.6321




Epoch [17/20] | Train Acc: 0.6484 | Val Acc: 0.6370




Epoch [18/20] | Train Acc: 0.6573 | Val Acc: 0.6475




Epoch [19/20] | Train Acc: 0.6604 | Val Acc: 0.6518


                                                                 

Epoch [20/20] | Train Acc: 0.6659 | Val Acc: 0.6469


