In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

# Seed PyTorch's global RNG once, before any model or DataLoader is created,
# so that random weight initialisation and batch shuffling start from a fixed
# seed on every "Restart Kernel -> Run All".
SEED = 42
torch.manual_seed(SEED)

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cuda


In [24]:
# Preprocessing pipeline applied to every MNIST image: tensor conversion only.
# No normalization step is added; the MLPs in this notebook train without it.
transform = transforms.Compose([transforms.ToTensor()])


In [25]:
# MNIST train / test splits, stored under ./data and run through `transform`.
trainset = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)
testset = datasets.MNIST(
    root="./data",
    train=False,
    download=True,
    transform=transform,
)

# Only the training batches are shuffled; the test loader keeps a fixed order.
trainloader = DataLoader(trainset, batch_size=128, shuffle=True)
testloader = DataLoader(testset, batch_size=256, shuffle=False)



In [26]:
class Teacher(nn.Module):
    """Three-layer MLP classifier used as the distillation teacher.

    Args:
        in_features: Number of values per sample after flattening
            (default ``28*28``, one MNIST image).
        hidden_dim: Width of both hidden layers (default 512).
        num_classes: Number of output logits (default 10).
    """

    def __init__(self, in_features=28*28, hidden_dim=512, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return raw (un-normalised) logits of shape ``(batch, num_classes)``.

        ``x`` may have any trailing shape as long as each sample holds
        ``in_features`` values; it is flattened to ``(batch, in_features)``.
        """
        # reshape (unlike view) also accepts non-contiguous inputs, and naming
        # the target width instead of -1 keeps an empty batch well-defined.
        x = x.reshape(x.size(0), self.fc1.in_features)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

class SmallStudent(nn.Module):
    """Single-linear-layer classifier used as the distillation student.

    With no hidden layer this is a deliberately weak student.

    Args:
        in_features: Number of values per sample after flattening
            (default ``28*28``, one MNIST image).
        num_classes: Number of output logits (default 10).
    """

    def __init__(self, in_features=28*28, num_classes=10):
        super().__init__()
        # single linear layer → guaranteed weak student
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return raw (un-normalised) logits of shape ``(batch, num_classes)``."""
        # reshape (unlike view) also accepts non-contiguous inputs, and naming
        # the target width instead of -1 keeps an empty batch well-defined.
        x = x.reshape(x.size(0), self.fc.in_features)
        return self.fc(x)


In [27]:
# Build both networks (teacher first, then student) and move them to `device`.
teacher, student = Teacher().to(device), SmallStudent().to(device)


In [28]:
def evaluate(model, loader=None):
    """Return the top-1 accuracy of ``model`` as a percentage (0-100).

    Args:
        model: Module that maps an input batch to per-class scores with the
            class axis at dimension 1.
        loader: Iterable of ``(inputs, labels)`` batches. When omitted, the
            notebook-level ``testloader`` is used.

    Returns:
        float: ``100 * correct / total`` over every sample in ``loader``.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.

    Note:
        The model is switched to eval mode and left that way; callers that
        continue training must call ``model.train()`` again.
    """
    if loader is None:
        loader = testloader  # notebook-level default

    # Feed batches to wherever the model's weights live, so the result does
    # not depend on the notebook-level `device` matching the model. A model
    # without parameters falls back to that global.
    first_param = next(model.parameters(), None)
    target_device = device if first_param is None else first_param.device

    model.eval()   # << IMPORTANT
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(target_device), y.to(target_device)
            pred = model(x).argmax(1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    return 100 * correct / total


In [29]:
# Hard-label objective and optimizer for the teacher.
# NOTE: `teacher` is created in an earlier cell, so re-running this cell keeps
# training the existing weights with a fresh optimizer.
criterion = nn.CrossEntropyLoss()
opt_t = optim.Adam(teacher.parameters(), lr=1e-3)

print("Training Teacher...")
for epoch in range(5):
    # evaluate() at the end of each epoch leaves the model in eval mode,
    # so training mode is restored at the top of every epoch.
    teacher.train()
    for images, labels in tqdm(trainloader):
        images = images.to(device)
        labels = labels.to(device)

        opt_t.zero_grad()
        loss = criterion(teacher(images), labels)
        loss.backward()
        opt_t.step()

    acc = evaluate(teacher)
    print(f"Teacher Epoch {epoch}: Test Accuracy = {acc:.2f}%")


Training Teacher...


100%|██████████| 469/469 [00:06<00:00, 69.90it/s]


Teacher Epoch 0: Test Accuracy = 96.61%


100%|██████████| 469/469 [00:07<00:00, 60.30it/s]


Teacher Epoch 1: Test Accuracy = 97.20%


100%|██████████| 469/469 [00:07<00:00, 64.57it/s]


Teacher Epoch 2: Test Accuracy = 98.01%


100%|██████████| 469/469 [00:07<00:00, 62.00it/s]


Teacher Epoch 3: Test Accuracy = 97.82%


100%|██████████| 469/469 [00:06<00:00, 71.36it/s]


Teacher Epoch 4: Test Accuracy = 97.65%


In [30]:
def kd_loss(student_logits, teacher_logits, labels, T=4, alpha=0.9):
    """Blend hard-label cross-entropy with a temperature-softened KL term.

    Args:
        student_logits: Student scores, class axis at dimension 1.
        teacher_logits: Teacher scores with the same layout.
        labels: Ground-truth class indices for the cross-entropy term.
        T: Temperature dividing both sets of logits before the softmax.
        alpha: Weight of the hard-label cross-entropy term; the
            distillation (KL) term is weighted by ``1 - alpha``.

    Returns:
        Scalar tensor ``alpha * CE + (1 - alpha) * T*T * KL``.
    """
    # Supervised term on the un-softened student logits.
    hard_term = F.cross_entropy(student_logits, labels)

    # Softened distributions: probabilities for the teacher (the KL target),
    # log-probabilities for the student (the form F.kl_div expects as input).
    teacher_probs = F.softmax(teacher_logits / T, dim=1)
    student_log_probs = F.log_softmax(student_logits / T, dim=1)

    # The KL term is multiplied by T*T, as in the original formulation.
    soft_term = F.kl_div(
        student_log_probs, teacher_probs, reduction="batchmean"
    ) * (T * T)

    return alpha * hard_term + (1 - alpha) * soft_term


In [33]:
# NOTE: `student` is created in an earlier cell, so re-running this cell keeps
# training the existing weights with a fresh optimizer.
opt_s = optim.Adam(student.parameters(), lr=5e-4)

# The teacher is a fixed target during distillation. Put it in eval mode
# explicitly instead of relying on an earlier evaluate(teacher) call.
teacher.eval()

print("Training Student with KD...")
for epoch in range(4):
    # evaluate() at the end of each epoch leaves the student in eval mode,
    # so training mode is restored at the top of every epoch.
    student.train()
    for x, y in tqdm(trainloader):
        x, y = x.to(device), y.to(device)

        # Teacher forward pass is inference only: no graph, no gradients.
        with torch.no_grad():
            t_logits = teacher(x)

        s_logits = student(x)
        loss = kd_loss(s_logits, t_logits, y, T=4, alpha=0.9)

        opt_s.zero_grad()
        loss.backward()
        opt_s.step()

    acc = evaluate(student)
    print(f"Student Epoch {epoch}: Test Accuracy = {acc:.2f}%")


Training Student with KD...


100%|██████████| 469/469 [00:07<00:00, 63.07it/s]


Student Epoch 0: Test Accuracy = 91.17%


100%|██████████| 469/469 [00:06<00:00, 71.10it/s]


Student Epoch 1: Test Accuracy = 91.35%


100%|██████████| 469/469 [00:07<00:00, 64.89it/s]


Student Epoch 2: Test Accuracy = 91.45%


100%|██████████| 469/469 [00:07<00:00, 64.72it/s]


Student Epoch 3: Test Accuracy = 91.55%


In [34]:
# Final side-by-side summary: test accuracy of each network, between two rules.
rule = "======================================"
print(rule)
for caption, net in (
    ("Final Teacher Accuracy :", teacher),
    ("Final Student Accuracy :", student),
):
    print(caption, evaluate(net))
print(rule)


Final Teacher Accuracy : 97.65
Final Student Accuracy : 91.55
