# Maximizing Model Performance with Knowledge Distillation in PyTorch

Tutorial followed by a "Medium" post.
[Link](!medium.com/artificialis/maximizing-model-performance-with-knowledge-distillation-in-pytorch-12b3960a486a)


Knowledge Distillation is the task of compress a large ANN model into a smaller one. This involves training a smaller model to mimic the behavior of the larger model.

The [Chest X-Ray Dataset (Phneumonia)](!https://www.kaggle.com/datasets/paultimothymooney/chest-xray-pneumonia) will be used.

In [1]:
# Imports
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.utils import  make_grid
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
import os
from tqdm.notebook import tqdm_notebook as tqdm
from torchvision.datasets import ImageFolder
import torch.nn.functional as F


In [2]:
# Preparing CUDA if available
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    print(f'Running on the GPU {torch.cuda.get_device_name(0)}')
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
else:
    device = torch.device('cpu')
    print('Running on the CPU')

Running on the GPU NVIDIA GeForce RTX 3070


---------------------
## Preparation of the dataset

In [3]:
# Prepare the dataset transforms
transforms_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
    ])

transforms_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# Load data
train_dir = r"D:/Programacion/TorchProjects/datasets/chest_xray/train/"
test_dir = r"D:/Programacion/TorchProjects/datasets/chest_xray/test/"

train_data = ImageFolder(root=train_dir, transform=transforms_train)
test_data = ImageFolder(root=test_dir, transform=transforms_test)


----------------

# Teacher

For the knowledge distillation, we need a teacher that teach the simpler network how to do the same job as the big model.

In this case, we will use a Resnet-18 finetuned for this dataset

In [20]:
from torchvision import models

class TeacherNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = models.resnet18(pretrained=True)

        for params in self.model.parameters():
            params.requires_grad_ = False

        n_filters = self.model.fc.in_features
        self.model.fc = nn.Linear(n_filters, 2)

    def forward(self, x):
        x = self.model(x)
        return x

In [21]:
# Training
def train(model, train_loader, test_loader, optimizer, criterion, num_epochs,  device):
    dataloaders = {"train": train_loader,
                   "val": test_loader}

    for epoch in tqdm(range(num_epochs)):
        print(f"Epoch {epoch}/{num_epochs-1}")
        print("".center(40, "-"))

        for phase in ["train", "val"]:
            if phase == "train":
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in tqdm(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase=="train"):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds==labels.data)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            print(f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")



In [22]:
# Training the teacher
teacher = TeacherNet().to(device)
optimizer = torch.optim.Adam(teacher.parameters(), lr=0.001)
criterion = nn.BCELoss()
num_epochs = 30
train(teacher, train_data, test_data, optimizer, criterion, num_epochs,  device)

  0%|          | 0/30 [00:00<?, ?it/s]

Epoch 0/29
----------------------------------------


  0%|          | 0/5216 [00:00<?, ?it/s]

ValueError: expected 4D input (got 3D input)

----------------
## Student

The student is a shallow CNN that will consist in few CNN layers and about 100k parameters.

If we train this model with the same training function than the above one (the Teacher) we won't use the knowledge learned by the teacher to improve better the student.

In [29]:
class StudentNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 4, kernel_size=3, padding=1),
            nn.BatchNorm2d(4),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.fc = nn.Linear(4 * 112 * 112, 2)

    def forward(self, x):
        out = self.layer1(x)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out


In [30]:
# Train the student without lteacher = TeacherNet().to(device)
student = StudentNet().to(device)
optimizer = torch.optim.Adam(teacher.parameters(), lr=0.001)
criterion = nn.BCELoss()
num_epochs = 30
train(student, train_data, test_data, optimizer, criterion, num_epochs,  device)

  0%|          | 0/30 [00:00<?, ?it/s]

Epoch 0/29
----------------------------------------


  0%|          | 0/5216 [00:00<?, ?it/s]

ValueError: expected 4D input (got 3D input)

**Prepare training Distillation**

The training loop will be basically the same, but the difference will be on the training loss computation. This will be done with the teacher's logits, the student's actual loss and the distillation loss.

The loss is a weighted sum of the classification loss (*student_target_loss*) and the cross entropy loss between the student logits and the teacher logits (*distillation_loss*)

In [8]:
class DistillationLoss:
    def __init__(self):
        self.student_loss = nn.CrossEntropyLoss()
        self.fistillation_loss = nn.KLDivLoss()
        self.temperature = 1
        self.alpha = 0.25


    def __call__(self, student_logits, student_target_loss, teacher_logits):
        distillation_loss = self.distillation_loss(F.log_softmax(student_logits / self.temperature, dim=1),
                                                   F.softmax(teacher_logits / self.temperature, dim=1))
        loss = (1 - self.alpha) * student_target_loss + self.alpha * distillation_loss
        return loss


We want to teach the student how the teacher "thinks", which also refers to its uncertainty. That means that if the teacher's final outputs probabilities are [0.53, 0.47], then the student will be equally uncertain and the difference between those two predictions is the destillation loss.
We can control the loss with the parameters **alpha** and **temperature**.
**alpha**: Weight of the distillation loss. 0 means we'd only consider the distillation loss, and vice versa.
**temperature**: This scales the uncertainty of the teacher's predictions.
