In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [2]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

In [3]:
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, shuffle=False)

100%|██████████| 9.91M/9.91M [00:00<00:00, 38.4MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.07MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 9.82MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 11.9MB/s]


In [4]:
class TeacherMLP(nn.Module):
  def __init__(self, hidden1=512, hidden2=256):
    super().__init__()
    self.net = nn.Sequential(
        nn.Flatten(),
        nn.Linear(28*28, hidden1),
        nn.ReLU(),
        nn.Linear(hidden1,hidden2),
        nn.ReLU(),
        nn.Linear(hidden2,10)
    )

  def forward(self,x):
    return self.net(x)

In [5]:
teacher = TeacherMLP(hidden1=512, hidden2=256)

In [6]:
teacher

TeacherMLP(
  (net): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=784, out_features=512, bias=True)
    (2): ReLU()
    (3): Linear(in_features=512, out_features=256, bias=True)
    (4): ReLU()
    (5): Linear(in_features=256, out_features=10, bias=True)
  )
)

In [7]:
def train_teacher(model, loader, epochs=1, lr=1e-3):

  opt = optim.Adam(model.parameters(), lr=lr)

  loss_fn = nn.CrossEntropyLoss()

  model.train()

  for ep in range(epochs):
    total_loss = 0
    for x,y in loader:
      opt.zero_grad()
      out = model(x)
      loss = loss_fn(out,y)
      loss.backward()
      opt.step()
      total_loss += loss.item()
  print(f"Teacher Epoch: {ep} loss {total_loss}")

In [8]:
train_teacher(teacher, train_loader)

Teacher Epoch: 0 loss 282.7695094048977


In [9]:
print(teacher.state_dict()) # these are the updated weights of the model

OrderedDict({'net.1.weight': tensor([[ 0.0075, -0.0234,  0.0022,  ...,  0.0235,  0.0270,  0.0070],
        [ 0.0235,  0.0189, -0.0247,  ..., -0.0149,  0.0038,  0.0261],
        [-0.0226,  0.0030, -0.0190,  ..., -0.0184,  0.0064, -0.0256],
        ...,
        [-0.0169, -0.0235,  0.0375,  ..., -0.0139,  0.0089,  0.0351],
        [ 0.0107,  0.0285,  0.0291,  ...,  0.0138,  0.0184,  0.0191],
        [-0.0241,  0.0211, -0.0102,  ..., -0.0309,  0.0133,  0.0228]]), 'net.1.bias': tensor([ 1.0554e-02, -3.8497e-02, -1.2759e-02, -8.4531e-03, -3.0692e-02,
         1.5226e-02,  2.4943e-02, -1.7480e-02,  3.0454e-02,  1.7203e-02,
        -3.8991e-02,  2.1866e-02, -8.2824e-04, -2.2781e-02, -3.7418e-02,
        -3.0662e-02,  1.2659e-02, -2.2236e-03,  8.2246e-03, -1.7084e-02,
         1.4090e-02, -2.5376e-02, -3.8082e-02,  1.1131e-02, -1.9621e-02,
        -1.8718e-02, -1.8419e-02, -3.1126e-02, -1.5475e-02, -1.4378e-02,
        -2.5952e-02, -2.9957e-02, -3.0219e-02, -2.4979e-02, -3.0865e-02,
        -3.

In [10]:
class StudentMLP(nn.Module):
  def __init__(self, hidden=128):
    super().__init__()
    self.net = nn.Sequential(
        nn.Flatten(),
        nn.Linear(28*28, hidden),
        nn.ReLU(),
        nn.Linear(hidden,10)
    )

  def forward(self,x):
    return self.net(x)

In [11]:
student = StudentMLP(hidden=128)

In [12]:
print(student)

StudentMLP(
  (net): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=784, out_features=128, bias=True)
    (2): ReLU()
    (3): Linear(in_features=128, out_features=10, bias=True)
  )
)


In [13]:
def pretrain_student(student, loader, epochs=1, lr=1e-3):
  student.train()
  opt = optim.Adam(student.parameters(), lr=lr)
  ce_loss = nn.CrossEntropyLoss()
  for ep in range(epochs):
    for x,y in loader:
      opt.zero_grad()
      out = student(x)
      loss = ce_loss(out,y)
      loss.backward()
      opt.step()

In [14]:
pretrain_student(student, train_loader, epochs=1)

Now we have the trained student model, so what we will do is we will take the soft labels from the teacher model and then use the hard labels of the student model.

# Distil Training

In [15]:
temperature = 2.0
alpha = 0.7
ce_loss = nn.CrossEntropyLoss()
kl_loss = nn.KLDivLoss(reduction="batchmean")
optimizer = optim.Adam(student.parameters(), lr=1e-3)

In [16]:
def distil(student, teacher, loader, epochs=1):
  for ep in range(epochs):

    student.train() # This will not start the model training, this will just initialise the training.

    total_loss = 0 # We will calculate this later

    # Teacher model outputs
    for x,y in loader: # Here x is the 1d array and y is the ground truth
      with torch.no_grad():  # Torch.no_grad means that we are not going to update the teacher models gradients, this means we are not going to train the model.
        t_logits = teacher(x)
        t_probs = torch.softmax(t_logits / temperature, dim=1)

    # Student model outputs
      s_logits = student(x)
      s_log_probs = torch.log_softmax(s_logits / temperature, dim=1)

      # Calculating the Losses
      loss_soft = kl_loss(s_log_probs, t_probs) * (temperature ** 2)
      loss_hard = ce_loss(s_logits, y)
      loss = alpha * loss_soft + (1 - alpha) * loss_hard

      optimizer.zero_grad()

      loss.backward()

      optimizer.step()

      total_loss += loss.item()
    print(f" Student Epoch {ep} Loss {total_loss/len(loader):.4f}")



In [17]:
distil(student,teacher,train_loader)

 Student Epoch 0 Loss 0.1454


# Evaluating the performances of the teacher and student models

In [18]:
def evaluate(model, loader, name="Model"):
  model.eval()

  correct = 0
  total = 0

  with torch.no_grad():
    for x,y in loader:
      out = model(x)
      preds = out.argmax(dim=1)
      correct += (preds == y).sum().item()
      total += y.size(0)
  acc = correct / total * 100

  print(f"{name} Accuracy: {acc:.2f}%")
  return acc

In [19]:
evaluate(teacher, test_loader, "Teacher")

Teacher Accuracy: 96.08%


96.08

In [20]:
evaluate(student, test_loader, "Student")

Student Accuracy: 95.73%


95.73

Now we can see that if we use the pretrained student model, then the accuracy of the model is even higher and the model being smaller is very much usable in day to day tasks.