In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [2]:
# ToTensor turns each image into a float tensor in [0, 1]; Normalize then applies
# (x - mean) / std = (x - 0.5) / 0.5, which maps pixel values into [-1, 1].
transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

In [3]:
# MNIST train/test splits under ./data; download=True fetches them only if missing.
trainingData = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
testingData = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

# Shuffle only the training set; the test loader keeps the fixed dataset order,
# so its batches (e.g. the first one used for the sample predictions) are deterministic.
trainingLoader = DataLoader(trainingData, batch_size=64, shuffle=True)
testingLoader = DataLoader(testingData, batch_size=1000)

100%|██████████| 9.91M/9.91M [00:01<00:00, 5.11MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 133kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.28MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.21MB/s]


In [45]:
class TeacherModel(nn.Module):
  """Larger MLP teacher: flattened 28x28 input -> two hidden ReLU layers -> 10 logits.

  Architecture (held in ``self.net``, an ``nn.Sequential``):
      Flatten -> Linear(784, hidden1) -> ReLU
              -> Linear(hidden1, hidden2) -> ReLU -> Linear(hidden2, 10)

  Because of that ordering the state_dict keys are ``net.1.*``, ``net.3.*``
  and ``net.5.*``.

  Args:
    hidden1: width of the first hidden layer.
    hidden2: width of the second hidden layer.
  """

  def __init__(self, hidden1=512, hidden2=256):
    super().__init__()
    inFeatures = 28 * 28   # Flatten must produce exactly this many features per sample
    numClasses = 10
    layers = [
      nn.Flatten(),
      nn.Linear(inFeatures, hidden1),
      nn.ReLU(),
      nn.Linear(hidden1, hidden2),
      nn.ReLU(),
      nn.Linear(hidden2, numClasses),
    ]
    self.net = nn.Sequential(*layers)

  def forward(self, x):
    """Return raw, unnormalized class logits (no softmax is applied)."""
    logits = self.net(x)
    return logits

In [46]:
# Teacher with the default widths: 784 -> 512 -> 256 -> 10.
teacher = TeacherModel()

In [47]:
# Show the teacher's module structure as the cell output.
teacher

TeacherModel(
  (net): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=784, out_features=512, bias=True)
    (2): ReLU()
    (3): Linear(in_features=512, out_features=256, bias=True)
    (4): ReLU()
    (5): Linear(in_features=256, out_features=10, bias=True)
  )
)

In [39]:
def TeacherTraining(model, loader, epochs=5, lr=1e-3):
  """Train ``model`` in place with Adam and cross-entropy on hard labels.

  Args:
    model: network to optimize; it is switched to train mode.
    loader: iterable of ``(inputs, targets)`` batches that supports ``len()``.
    epochs: number of full passes over ``loader``.
    lr: Adam learning rate.

  After every epoch prints ``Teacher Epoch k: Loss = ...`` with the mean
  per-batch loss. Returns ``None``.
  """
  criterion = nn.CrossEntropyLoss()
  opt = optim.Adam(model.parameters(), lr=lr)
  model.train()

  for epochIdx in range(1, epochs + 1):
    runningLoss = 0
    for inputs, targets in loader:
      opt.zero_grad()
      batchLoss = criterion(model(inputs), targets)
      batchLoss.backward()
      opt.step()
      runningLoss += batchLoss.item()
    print(f"Teacher Epoch {epochIdx}: Loss = {runningLoss/len(loader):.4f}")


In [48]:
# Train the teacher on hard labels with the function defaults (5 epochs, lr=1e-3).
TeacherTraining(teacher, trainingLoader)

Teacher Epoch 1: Loss = 0.3024
Teacher Epoch 2: Loss = 0.1348
Teacher Epoch 3: Loss = 0.0999
Teacher Epoch 4: Loss = 0.0819
Teacher Epoch 5: Loss = 0.0696


In [41]:
class StudentModel(nn.Module):
  """Compact MLP student: flattened 28x28 input -> one hidden ReLU layer -> 10 logits.

  Architecture (held in ``self.net``):
      Flatten -> Linear(784, hidden) -> ReLU -> Linear(hidden, 10)

  State_dict keys are ``net.1.*`` and ``net.3.*``.

  Args:
    hidden: width of the single hidden layer.
  """

  def __init__(self, hidden=128):
    super().__init__()
    layers = [
      nn.Flatten(),               # flattens all but the batch dim; must yield 28*28 features
      nn.Linear(28 * 28, hidden),
      nn.ReLU(),
      nn.Linear(hidden, 10),      # one logit per class
    ]
    self.net = nn.Sequential(*layers)

  def forward(self, x):
    """Return raw class logits for the batch ``x``."""
    logits = self.net(x)
    return logits

In [42]:
# Student with the default hidden width: 784 -> 128 -> 10 (much smaller than the teacher).
student = StudentModel()

* S1) Pretraining the student model on hard labels (the ground-truth class indices)

In [43]:
def pretrainStudent(model, loader, epochs=3, lr=1e-3):
  """Supervised warm-up of the student on hard (ground-truth) labels.

  Args:
    model: network trained in place; switched to train mode.
    loader: iterable of ``(inputs, labels)`` batches that supports ``len()``.
    epochs: number of passes over ``loader``.
    lr: Adam learning rate.

  Prints ``Student Epoch k: Loss = ...`` (mean per-batch loss) once per epoch
  and returns ``None``.
  """
  lossFn = nn.CrossEntropyLoss()
  adam = optim.Adam(model.parameters(), lr=lr)
  model.train()

  for epochNo in range(1, epochs + 1):
    epochLoss = 0
    for inputs, labels in loader:
      adam.zero_grad()
      batchLoss = lossFn(model(inputs), labels)
      batchLoss.backward()
      adam.step()
      epochLoss += batchLoss.item()
    meanLoss = epochLoss / len(loader)
    print(f"Student Epoch {epochNo}: Loss = {meanLoss:.4f}")

In [44]:
# Warm up the student on hard labels before distillation (function defaults: 3 epochs, lr=1e-3).
pretrainStudent(student, trainingLoader)

Student Epoch 1: Loss = 0.3830
Student Epoch 2: Loss = 0.2005
Student Epoch 3: Loss = 0.1437


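* S2) Knowledge distillation: continue training the student on a weighted mix of a soft loss (KL divergence to the teacher's temperature-softened predictions) and a hard loss (cross-entropy on the true labels)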

In [49]:
# Distillation hyperparameters:
#   temperature - divides the logits before softmax; >1 softens the distributions
#   alpha       - weight of the soft (teacher) loss; (1 - alpha) weights the hard loss
temperature, alpha = 2.0, 0.7

crossEntropyLoss = nn.CrossEntropyLoss()
# NOTE(review): "Divergernce" is misspelled; the global name is kept unchanged
# because renaming a notebook global could break other cells that use it.
klDivergernceLoss = nn.KLDivLoss(reduction='batchmean')

# Adam over the student's parameters at lr = 1e-3.
optimizer = optim.Adam(student.parameters(), lr=1e-3)

In [52]:
def distillation(student, teacher, loader, epochs=5, lr=1e-3, temperature=2.0, alpha=0.7):
    """Train ``student`` to imitate ``teacher`` (knowledge distillation).

    For every batch the loss is ``alpha * soft + (1 - alpha) * hard`` where

    * ``soft`` is the KL divergence between the teacher's and the student's
      softmax distributions, both computed on ``logits / temperature``, and
      multiplied by ``temperature ** 2`` (the conventional scaling from
      Hinton et al., 2015, that keeps the soft-loss gradient scale roughly
      independent of the temperature);
    * ``hard`` is plain cross-entropy of the student's (unscaled) logits
      against the true labels ``y``.

    Args:
        student: model being trained; its parameters are updated in place and
            it is put in train mode.
        teacher: reference model; put in eval mode and never updated.
        loader: iterable of ``(x, y)`` batches that supports ``len()``.
        epochs: number of passes over ``loader``.
        lr: Adam learning rate for the student.
        temperature: softmax temperature applied to both models' logits (default 2.0).
        alpha: weight of the soft loss; ``1 - alpha`` weights the hard loss (default 0.7).

    Prints ``Student Epoch k: Loss = ...`` (mean per-batch combined loss) once
    per epoch and returns ``None``.
    """
    teacher.eval()

    # The optimizer is created here and is the one that is stepped, so `lr` is
    # honoured and the function does not depend on an optimizer from another cell.
    studentOptimizer = optim.Adam(student.parameters(), lr=lr)

    crossEntropyLoss = nn.CrossEntropyLoss()
    # KLDivLoss expects log-probabilities as input and probabilities as target;
    # 'batchmean' divides the summed divergence by the batch size.
    KLDivergenceLoss = nn.KLDivLoss(reduction="batchmean")

    for epoch in range(epochs):
        student.train()
        totalLoss = 0
        for x, y in loader:
            # The teacher only provides targets: no autograd graph is needed.
            with torch.no_grad():
                teacherLogits = teacher(x)
                teacherProbs = torch.softmax(teacherLogits / temperature, dim=1)

            studentLogits = student(x)
            studentLogProbs = torch.log_softmax(studentLogits / temperature, dim=1)

            softLoss = KLDivergenceLoss(studentLogProbs, teacherProbs) * (temperature ** 2)
            hardLoss = crossEntropyLoss(studentLogits, y)

            loss = alpha * softLoss + (1 - alpha) * hardLoss

            studentOptimizer.zero_grad()
            loss.backward()
            studentOptimizer.step()

            totalLoss += loss.item()

        print(f"Student Epoch {epoch+1}: Loss = {totalLoss/len(loader):.4f}")

In [53]:
# Continue training the pretrained student against the teacher's soft targets plus the hard labels.
distillation(student, teacher, trainingLoader)

Student Epoch 1: Loss = 0.1922
Student Epoch 2: Loss = 0.1323
Student Epoch 3: Loss = 0.1090
Student Epoch 4: Loss = 0.0972
Student Epoch 5: Loss = 0.0907


In [54]:
def evaluation(model, loader, name="Model"):
  """Measure, print and return the top-1 accuracy of ``model`` on ``loader``.

  The model is switched to eval mode (and left that way), and inference runs
  under ``torch.no_grad()``.

  Args:
    model: classifier returning one score per class along dim 1.
    loader: iterable of ``(inputs, targets)`` batches.
    name: label used in the printed ``"<name> Accuracy: xx.xx%"`` line.

  Returns:
    Accuracy as a percentage (float in [0, 100]).

  Raises:
    ZeroDivisionError: if ``loader`` yields no samples.
  """
  model.eval()
  numCorrect = 0
  numSeen = 0
  with torch.no_grad():
    for inputs, targets in loader:
      predicted = model(inputs).argmax(dim=1)
      numCorrect += (predicted == targets).sum().item()
      numSeen += targets.size(0)

  # Operation order (correct / total, then * 100) is deliberate: it fixes the
  # exact float returned, e.g. 97.11999999999999 rather than 97.12.
  accuracy = numCorrect / numSeen * 100
  print(f"{name} Accuracy: {accuracy:.2f}%")
  return accuracy

In [55]:
# Teacher accuracy on the test split; the returned percentage is also shown as the cell output.
evaluation(teacher, testingLoader, "Teacher")

Teacher Accuracy: 97.05%


97.05

In [56]:
# Student accuracy on the same test split, for comparison with the teacher.
evaluation(student, testingLoader, "Student")

Student Accuracy: 97.12%


97.11999999999999

In [57]:
def prediction(model, x):
  """Return the index of the highest-scoring class for each sample in ``x``.

  Switches ``model`` to eval mode (it stays there) and runs the forward pass
  without autograd.
  """
  model.eval()
  with torch.no_grad():
    scores = model(x)
  return scores.argmax(dim=1)

In [58]:
# First test batch (the test loader is not shuffled, so this is always the same
# 1000 images) and the student's predicted labels for it.
sampleBatch, sampleLabels = next(iter(testingLoader))
predictions = prediction(student, sampleBatch)

In [62]:
# Spot-check: first 10 predicted labels next to the ground truth.
print(f"Predicted Labels: {predictions[:10]}")
print(f"Actual Labels: {sampleLabels[:10]}")

Predicted Labels: tensor([7, 2, 1, 0, 4, 1, 4, 9, 5, 9])
Actual Labels: tensor([7, 2, 1, 0, 4, 1, 4, 9, 5, 9])


In [63]:
# Persist only the student's weights (state_dict). To reload, build a model with the
# same architecture first, e.g. StudentModel().load_state_dict(torch.load("distilledDL.pth")).
torch.save(student.state_dict(), "distilledDL.pth")
print("Saved")

Saved
