In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [2]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

In [3]:
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, shuffle=False)

100%|██████████| 9.91M/9.91M [00:01<00:00, 4.96MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 129kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.23MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 9.87MB/s]


In [5]:
class TeacherMLP(nn.Module):
  def __init__(self, hidden1=512, hidden2=256):
    super().__init__()
    self.net = nn.Sequential(
        nn.Flatten(),
        nn.Linear(28*28, hidden1),
        nn.ReLU(),
        nn.Linear(hidden1,hidden2),
        nn.ReLU(),
        nn.Linear(hidden2,10)
    )

  def forward(self,x):
    return self.net(x)

In [6]:
teacher = TeacherMLP(hidden1=512, hidden2=256)

In [7]:
teacher

TeacherMLP(
  (net): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=784, out_features=512, bias=True)
    (2): ReLU()
    (3): Linear(in_features=512, out_features=256, bias=True)
    (4): ReLU()
    (5): Linear(in_features=256, out_features=10, bias=True)
  )
)

In [10]:
def train_teacher(model, loader, epochs=1, lr=1e-3):

  opt = optim.Adam(model.parameters(), lr=lr)

  loss_fn = nn.CrossEntropyLoss()

  model.train()

  for ep in range(epochs):
    total_loss = 0
    for x,y in loader:
      opt.zero_grad()
      out = model(x)
      loss = loss_fn(out,y)
      loss.backward()
      opt.step()
      total_loss += loss.item()
  print(f"Teacher Epoch: {ep} loss {total_loss}")

In [11]:
train_teacher(teacher, train_loader)

Teacher Epoch: 0 loss 280.74651252664626


In [14]:
print(teacher.state_dict()) # these are the updated weights of the model

OrderedDict({'net.1.weight': tensor([[ 0.0031, -0.0300, -0.0247,  ..., -0.0118, -0.0216, -0.0018],
        [-0.0303,  0.0119,  0.0244,  ..., -0.0267,  0.0087,  0.0063],
        [ 0.0218, -0.0121,  0.0245,  ...,  0.0248,  0.0027,  0.0077],
        ...,
        [ 0.0228,  0.0148,  0.0160,  ..., -0.0133, -0.0179, -0.0127],
        [-0.0027, -0.0122, -0.0084,  ...,  0.0172,  0.0155,  0.0356],
        [ 0.0256,  0.0028, -0.0232,  ...,  0.0194, -0.0161,  0.0202]]), 'net.1.bias': tensor([ 1.7872e-02,  3.7583e-03,  2.2602e-02,  9.3915e-03,  1.0925e-02,
         2.4752e-02, -2.9046e-02, -3.2058e-02, -3.6937e-02, -1.6323e-02,
         1.9877e-02, -6.8324e-03, -3.8924e-02,  5.7817e-03, -3.7199e-02,
         1.2491e-02, -3.0926e-02, -3.4160e-02, -3.9047e-02,  2.5146e-02,
         4.2276e-03,  1.3058e-02,  5.8560e-03, -1.7277e-02, -3.2338e-02,
         2.4362e-02,  1.6240e-02,  1.6024e-02, -3.4222e-02, -1.9115e-02,
        -2.6796e-02,  1.5300e-02,  2.1581e-02, -2.0665e-02, -3.3326e-02,
         3.

In [16]:
class StudentMLP(nn.Module):
  def __init__(self, hidden=128):
    super().__init__()
    self.net = nn.Sequential(
        nn.Flatten(),
        nn.Linear(28*28, hidden),
        nn.ReLU(),
        nn.Linear(hidden,10)
    )

  def forward(self,x):
    return self.net(x)

In [17]:
student = StudentMLP(hidden=128)

In [18]:
print(student)

StudentMLP(
  (net): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=784, out_features=128, bias=True)
    (2): ReLU()
    (3): Linear(in_features=128, out_features=10, bias=True)
  )
)


Now we have the untrained student model, so what we will do is we will take the soft labels from the teacher model and then use the hard labels of the student model.

# Distil Training

In [19]:
temperature = 2.0
alpha = 0.7
ce_loss = nn.CrossEntropyLoss()
kl_loss = nn.KLDivLoss(reduction="batchmean")
optimizer = optim.Adam(student.parameters(), lr=1e-3)

In [22]:
def distil(student, teacher, loader, epochs=1):
  for ep in range(epochs):

    student.train() # This will not start the model training, this will just initialise the training.

    total_loss = 0 # We will calculate this later

    # Teacher model outputs
    for x,y in loader: # Here x is the 1d array and y is the ground truth
      with torch.no_grad():  # Torch.no_grad means that we are not going to update the teacher models gradients, this means we are not going to train the model.
        t_logits = teacher(x)
        t_probs = torch.softmax(t_logits / temperature, dim=1)

    # Student model outputs
      s_logits = student(x)
      s_log_probs = torch.log_softmax(s_logits / temperature, dim=1)

      # Calculating the Losses
      loss_soft = kl_loss(s_log_probs, t_probs) * (temperature ** 2)
      loss_hard = ce_loss(s_logits, y)
      loss = alpha * loss_soft + (1 - alpha) * loss_hard

      optimizer.zero_grad()

      loss.backward()

      optimizer.step()

      total_loss += loss.item()
    print(f" Student Epoch {ep} Loss {total_loss/len(loader):.4f}")



In [23]:
distil(student,teacher,train_loader)

 Student Epoch 0 Loss 0.5280


# Evaluating the performances of the teacher and student models

In [24]:
def evaluate(model, loader, name="Model"):
  model.eval()

  correct = 0
  total = 0

  with torch.no_grad():
    for x,y in loader:
      out = model(x)
      preds = out.argmax(dim=1)
      correct += (preds == y).sum().item()
      total += y.size(0)
  acc = correct / total * 100

  print(f"{name} Accuracy: {acc:.2f}%")
  return acc

In [25]:
evaluate(teacher, test_loader, "Teacher")

Teacher Accuracy: 95.25%


95.25

In [26]:
evaluate(student, test_loader, "Student")

Student Accuracy: 92.54%


92.54

From here we can see that the student model's performance is near to the teacher model and in real world cases where performance matters more than the accuracy of the result the student model outshines the teacher model as it is a smaller model than the teacher model and thus is much faster.

# Sample Predictions

In [27]:
def predict(model, x):
  model.eval()
  with torch.no_grad():
    out = model(x)
    return out.argmax(dim=1)


In [28]:
sample_batch, sample_labels = next(iter(test_loader))

In [29]:
preds = predict(student, sample_batch)

In [30]:
print("Student Predictions: ", preds[:20])

Student Predictions:  tensor([7])


In [31]:
print("True Labels: ", sample_labels[:20])

True Labels:  tensor([7])
