<a href="https://colab.research.google.com/github/SAIROHITHARETI/LLM-FineTuning/blob/main/Distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using:", device)

# ------------------------
# 1️⃣ Load MNIST
# ------------------------
# ToTensor() scales pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])
train_data = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_data = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

# Only the training split is shuffled; evaluation order is irrelevant.
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
test_loader = DataLoader(test_data, batch_size=256, shuffle=False)

# ------------------------
# 2️⃣ Define Teacher and Student MLPs
# ------------------------
class TeacherMLP(nn.Module):
    """Larger fully connected classifier: 784 -> 512 -> 256 -> 10 with ReLU.

    Used as the teacher in distillation. Layers live in ``self.net`` as
    Flatten, Linear, ReLU, Linear, ReLU, Linear, in that order.
    """

    def __init__(self):
        super().__init__()
        widths = [28 * 28, 512, 256, 10]
        last_hidden = len(widths) - 2
        layers = [nn.Flatten()]
        for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            # Every linear layer except the output layer is followed by a ReLU.
            if idx < last_hidden:
                layers.append(nn.ReLU())
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw (unnormalised) class logits for the input batch."""
        return self.net(x)

class StudentMLP(nn.Module):
    """Compact fully connected classifier: 784 -> 128 -> 10 with one ReLU.

    Used as the distillation student. Layers live in ``self.net`` as
    Flatten, Linear, ReLU, Linear, in that order.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw (unnormalised) class logits for the input batch."""
        return self.net(x)

# ------------------------
# 3️⃣ Helper: Evaluate accuracy
# ------------------------
def accuracy(model, loader, device=None):
    """Return the top-1 accuracy of *model* over *loader*, as a percentage.

    Args:
        model: classifier; the argmax of its output over dim 1 is taken as the
            predicted class.
        loader: iterable of ``(inputs, labels)`` batches.
        device: device each batch is moved to. When None, the device of the
            model's first parameter is used (CPU for a parameter-less model).

    Returns:
        ``100 * correct / total`` as a float.

    Raises:
        ValueError: if *loader* yields no samples.

    The model's previous train/eval mode is restored before returning, so this
    can be called in the middle of a training loop without side effects.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = model.training
    model.eval()
    correct, total = 0, 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                preds = model(x).argmax(1)
                total += y.size(0)
                correct += (preds == y).sum().item()
    finally:
        # Restore the caller's mode even if evaluation raised.
        model.train(was_training)

    if total == 0:
        raise ValueError("accuracy() received a loader with no samples")
    return 100 * correct / total

# ------------------------
# 4️⃣ Train Teacher Normally
# ------------------------
def train_teacher(model, epochs=3, lr=1e-3):
    """Train *model* with plain cross-entropy on the module-level ``train_loader``.

    After each epoch, prints the mean per-batch training loss and the test-set
    accuracy computed by ``accuracy(model, test_loader)``.

    Args:
        model: classifier returning raw logits; moved to ``device`` in place.
        epochs: number of full passes over ``train_loader``.
        lr: Adam learning rate.
    """
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    n_batches = len(train_loader)

    for epoch in range(1, epochs + 1):
        # Re-enter training mode every epoch in case evaluation left the model in eval mode.
        model.train()
        running = 0
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()
            running += batch_loss.item()
        mean_loss = running / n_batches
        test_acc = accuracy(model, test_loader)
        print(f"Teacher Epoch {epoch}: Loss={mean_loss:.4f}, Acc={test_acc:.2f}%")

# ------------------------
# 5️⃣ Distillation Training for Student
# ------------------------
def distill_train(student, teacher, epochs=5, lr=1e-3, alpha=0.5, T=3.0):
    student.to(device)
    teacher.to(device).eval()  # freeze teacher

    opt = optim.Adam(student.parameters(), lr=lr)
    ce_loss = nn.CrossEntropyLoss()
    kl_loss = nn.KLDivLoss(reduction="batchmean")

    for e in range(epochs):
        student.train()
        total_loss = 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            with torch.no_grad():
                t_logits = teacher(x)

            s_logits = student(x)

            # Hard (ground-truth) loss
            loss_hard = ce_loss(s_logits, y)

            # Soft (teacher-student) loss
            p_teacher = F.softmax(t_logits / T, dim=1)
            p_student = F.log_softmax(s_logits / T, dim=1)
            loss_soft = kl_loss(p_student, p_teacher) * (T * T)

            # Combine
            loss = alpha * loss_hard + (1 - alpha) * loss_soft

            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss.item()
        print(f"Student Epoch {e+1}: Loss={total_loss/len(train_loader):.4f}, Acc={accuracy(student, test_loader):.2f}%")

# ------------------------
# 6️⃣ Run training
# ------------------------
# Stage 1: fit the larger teacher on the hard labels only.
teacher = TeacherMLP()
train_teacher(teacher, epochs=3)

# Stage 2: distil the trained teacher into the smaller student.
student = StudentMLP()
distill_train(student, teacher, epochs=5)

final_student_acc = accuracy(student, test_loader)
print("✅ Final Student Accuracy:", final_student_acc)


Using: cpu


100%|██████████| 9.91M/9.91M [00:00<00:00, 22.2MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 604kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 5.58MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 9.97MB/s]


Teacher Epoch 1: Loss=0.3326, Acc=94.59%
Teacher Epoch 2: Loss=0.1433, Acc=96.00%
Teacher Epoch 3: Loss=0.1025, Acc=96.43%
Student Epoch 1: Loss=1.3285, Acc=91.37%
Student Epoch 2: Loss=0.4559, Acc=94.34%
Student Epoch 3: Loss=0.2412, Acc=95.44%
Student Epoch 4: Loss=0.1663, Acc=95.88%
Student Epoch 5: Loss=0.1333, Acc=96.31%
✅ Final Student Accuracy: 96.31
