# Model (Knowledge) Distillation

This notebook discusses knowledge distillation, a technique for training a smaller model (the student) to approach the accuracy of a much larger model (the teacher) by learning from the teacher's outputs.

The notebook is based on the following sources: [a blog post on knowledge distillation](https://josehoras.github.io/knowledge-distillation/) and [the PyTorch CIFAR-10 tutorial](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html).

In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet50
from torch.nn import functional as F
from torchmetrics import Accuracy
import pytorch_lightning as pl

## Data preparation
Define transforms, batch size and device

In [2]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

batch_size = 32

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
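To make the effect of the `Normalize` call concrete, here is a minimal plain-Python sketch of the per-channel arithmetic it applies, `(x - mean) / std`. The helper name `normalize` is hypothetical and only used for illustration. Since `ToTensor()` scales pixel values to `[0, 1]`, a mean and standard deviation of 0.5 map them to `[-1, 1]`.

```python
# Per-channel arithmetic of Normalize(mean, std): (x - mean) / std
# `normalize` is a hypothetical helper for illustration only.
def normalize(value, mean=0.5, std=0.5):
    return (value - mean) / std

# ToTensor() yields values in [0, 1]; check both endpoints and the midpoint
for pixel in (0.0, 0.5, 1.0):
    print(pixel, "->", normalize(pixel))
```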

Download the CIFAR-10 training and test sets and create the data loaders

In [3]:
# Train
train_set = CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

# Validation
val_set = CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform
)

val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
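As a quick sanity check on the loaders, the number of steps per epoch follows from the dataset sizes and the batch size. The sketch below assumes the standard CIFAR-10 split (50,000 training and 10,000 test images) and the default `DataLoader` behaviour of keeping the last, smaller batch.

```python
import math

train_size = 50_000  # CIFAR-10 training images
val_size = 10_000    # CIFAR-10 test images, used here for validation
batch_size = 32

# With drop_last=False (the DataLoader default) the last partial batch is kept
train_steps = math.ceil(train_size / batch_size)
val_steps = math.ceil(val_size / batch_size)
last_train_batch = train_size - (train_steps - 1) * batch_size

print(train_steps, val_steps, last_train_batch)
```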


## Creation of teacher model
The teacher is a pretrained ResNet-50. Because it was pretrained on ImageNet (1000 classes), we replace its final layer with a 10-class layer and fine-tune it on CIFAR-10.

First, define teacher model:

In [4]:
class Teacher(pl.LightningModule):
    def __init__(self):
        super().__init__()

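        # Load ResNet-50 with ImageNet weights
        # (`weights=` replaces the deprecated `pretrained=True` in torchvision >= 0.13)
        self.model = resnet50(weights="IMAGENET1K_V1")
        # Replace the 1000-class ImageNet head with a 10-class CIFAR-10 head
        self.model.fc = nn.Linear(2048, 10)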
        
        # Define loss and accuracy
        self.accuracy = Accuracy(task="multiclass", num_classes=10, top_k=1)
        self.loss = F.cross_entropy

    def forward(self, x):
        return self.model(x)

    def step(self, batch):
        # Shared by training and validation: forward pass, loss, batch accuracy
        x, y = batch
        logits = self(x)
        loss = self.loss(logits, y)
        acc = self.accuracy(logits, y)
        return {'loss': loss, 'acc': acc}

    def training_step(self, batch, batch_idx):
        return self.step(batch)

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        return self.step(batch)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        return optimizer

    def epoch_end(self, outputs, mode):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['acc'] for x in outputs]).mean()
        print(f"\n{mode}: loss = {avg_loss:.2f}, accuracy = {avg_acc:.2f}")

    def training_epoch_end(self, outputs):
        self.epoch_end(outputs, "Train")

    def validation_epoch_end(self, outputs):
        self.epoch_end(outputs, "Validation")
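Note that `epoch_end` averages the per-batch means, so every batch has the same weight regardless of its size. The plain-Python sketch below uses hypothetical batch sizes and accuracies to show how this can differ from a sample-weighted average when the last batch is smaller. With well over a thousand batches per epoch, as in this notebook, the difference should be negligible.

```python
def mean(values):
    return sum(values) / len(values)

# (batch_size, batch_accuracy) pairs; hypothetical numbers for illustration
batches = [(32, 0.75), (32, 0.75), (16, 0.25)]

# What epoch_end computes: every batch counts equally
unweighted = mean([acc for _, acc in batches])

# Sample-weighted average: every example counts equally
weighted = sum(n * acc for n, acc in batches) / sum(n for n, _ in batches)

print(f"unweighted = {unweighted:.3f}, weighted = {weighted:.3f}")
```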

Train the teacher:

In [5]:
# Create model
teacher = Teacher()

# Create trainer
teacher_trainer = pl.Trainer(
    devices=1,
    max_epochs=3,
    accelerator=device.type
)

# Train
teacher_trainer.fit(teacher, train_loader, val_loader)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth

Sanity Checking:
Validation: loss = 2.48, accuracy = 0.12

Validation: loss = 0.57, accuracy = 0.81
Train: loss = 0.89, accuracy = 0.70

Validation: loss = 0.47, accuracy = 0.84
Train: loss = 0.50, accuracy = 0.83

Validation: loss = 0.48, accuracy = 0.84
Train: loss = 0.34, accuracy = 0.88


After only 3 epochs of fine-tuning, the teacher reaches a validation accuracy of 0.84.

## Create student model
The student model is a simple CNN.

We first define an abstract class that implements the shared layers and methods. We then train and evaluate two subclasses: StudentWithoutTeacher (trained on the labels only) and StudentWithTeacher (trained on the labels and the teacher's outputs). Comparing the two shows whether knowledge distillation actually improves accuracy.

Abstract student class:

In [6]:
class StudentAbstractClass(pl.LightningModule):
    def __init__(self):
        super().__init__()

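        # Define model layers (shapes are for 3x32x32 CIFAR-10 inputs)
        self.conv1 = nn.Conv2d(3, 16, 5)        # -> 16x28x28
        self.pool1 = nn.MaxPool2d(2, 2)         # -> 16x14x14
        self.conv2 = nn.Conv2d(16, 64, 5)       # -> 64x10x10
        self.pool2 = nn.MaxPool2d(2, 2)         # -> 64x5x5
        self.fc1 = nn.Linear(64 * 5 * 5, 1024)  # flattened 64x5x5 feature map
        self.fc2 = nn.Linear(1024, 128)
        self.fc3 = nn.Linear(128, 10)           # one logit per CIFAR-10 class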
        
        # Define accuracy
        self.accuracy = Accuracy(task="multiclass", num_classes=10, top_k=1)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
    def loss(self, x, logits, y):
        raise NotImplementedError("loss function is not implemented in abstract class")

    def step(self, batch):
        # Shared by training and validation; `loss` is supplied by the subclass
        x, y = batch
        logits = self(x)
        loss = self.loss(x, logits, y)
        acc = self.accuracy(logits, y)
        return {'loss': loss, 'acc': acc}

    def training_step(self, batch, batch_idx):
        return self.step(batch)

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        return self.step(batch)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        return optimizer

    def epoch_end(self, outputs, mode):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['acc'] for x in outputs]).mean()
        print(f"\n{mode}: loss = {avg_loss:.2f}, accuracy = {avg_acc:.2f}")

    def training_epoch_end(self, outputs):
        self.epoch_end(outputs, "Train")

    def validation_epoch_end(self, outputs):
        self.epoch_end(outputs, "Validation")
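The input size of `fc1` (`64 * 5 * 5`) follows from the convolution and pooling layers above. A minimal sketch of that arithmetic, using the standard output-size formula for unpadded convolutions and pooling (the helper names `conv_out` and `pool_out` are hypothetical):

```python
def conv_out(size, kernel, stride=1, padding=0):
    # Spatial output size of a convolution along one dimension
    return (size + 2 * padding - kernel) // stride + 1

def pool_out(size, kernel, stride):
    # Spatial output size of max pooling along one dimension
    return (size - kernel) // stride + 1

size = 32                    # CIFAR-10 images are 32x32
size = conv_out(size, 5)     # conv1: 32 -> 28
size = pool_out(size, 2, 2)  # pool1: 28 -> 14
size = conv_out(size, 5)     # conv2: 14 -> 10
size = pool_out(size, 2, 2)  # pool2: 10 -> 5

flat_features = 64 * size * size  # 64 channels after conv2
print(size, flat_features)
```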

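Switch the teacher to evaluation mode, freeze its parameters, and move it to the device

In [7]:
teacher.eval()  # use running batch-norm statistics instead of batch statistics
teacher.requires_grad_(False)  # eval() alone does not disable gradient tracking
teacher = teacher.to(device)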

Define student model without teacher:

In [8]:
class StudentWithoutTeacher(StudentAbstractClass):
    def __init__(self):
        super().__init__()
    
    def loss(self, x, logits, y):
        return F.cross_entropy(logits, y)

Train student without teacher

In [9]:
# Create model
student_without_teacher = StudentWithoutTeacher()

# Create trainer
trainer_without_teacher = pl.Trainer(
    devices=1,
    max_epochs=10,
    accelerator=device.type
)

# Train
trainer_without_teacher.fit(student_without_teacher, train_loader, val_loader)

Sanity Checking:
Validation: loss = 2.31, accuracy = 0.11

Validation: loss = 1.42, accuracy = 0.48
Train: loss = 1.63, accuracy = 0.41

Validation: loss = 1.29, accuracy = 0.54
Train: loss = 1.35, accuracy = 0.51

Validation: loss = 1.19, accuracy = 0.57
Train: loss = 1.21, accuracy = 0.57

Validation: loss = 1.12, accuracy = 0.60
Train: loss = 1.11, accuracy = 0.60

Validation: loss = 1.07, accuracy = 0.62
Train: loss = 1.03, accuracy = 0.64

Validation: loss = 1.03, accuracy = 0.64
Train: loss = 0.96, accuracy = 0.66

Validation: loss = 0.96, accuracy = 0.66
Train: loss = 0.90, accuracy = 0.68

Validation: loss = 0.96, accuracy = 0.66
Train: loss = 0.84, accuracy = 0.70

Validation: loss = 0.92, accuracy = 0.68
Train: loss = 0.79, accuracy = 0.73

Validation: loss = 0.90, accuracy = 0.68
Train: loss = 0.73, accuracy = 0.75


Define student model with teacher:

In [10]:
class StudentWithTeacher(StudentAbstractClass):
    def __init__(self):
        super().__init__()
    
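    def loss(self, x, logits, y, T=3):
        # The teacher is a fixed target, so its forward pass needs no gradients
        with torch.no_grad():
            teacher_logits = teacher(x)
        # Distillation term: MSE between temperature-softened class probabilities
        teacher_loss = F.mse_loss(
            F.softmax(logits / T, dim=-1),
            F.softmax(teacher_logits / T, dim=-1)
        )
        # Supervised term: cross-entropy against the ground-truth labels
        labels_loss = F.cross_entropy(logits, y)
        return teacher_loss + labels_loss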
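Written out, the loss above is the sum of a distillation term and a label term. The notation here is an assumption introduced for illustration: $z^{s}$ and $z^{t}$ are the student and teacher logits, $N$ is the batch size, $C = 10$ is the number of classes, $y_n$ is the label of example $n$, and $T = 3$ is the temperature. Both terms use the default mean reduction of `F.mse_loss` and `F.cross_entropy`.

```latex
\[
p^{(T)}_{n,c}(z) = \frac{\exp(z_{n,c}/T)}{\sum_{k=1}^{C}\exp(z_{n,k}/T)}
\]
\[
\mathcal{L}
  = \underbrace{\frac{1}{NC}\sum_{n=1}^{N}\sum_{c=1}^{C}
      \left(p^{(T)}_{n,c}(z^{s}) - p^{(T)}_{n,c}(z^{t})\right)^{2}}_{\text{distillation term (MSE)}}
  \;\underbrace{-\,\frac{1}{N}\sum_{n=1}^{N}\log p^{(1)}_{n,y_n}(z^{s})}_{\text{label term (cross-entropy)}}
\]
```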

Train student with teacher

In [11]:
# Create model
student_with_teacher = StudentWithTeacher()

# Create trainer
trainer_with_teacher = pl.Trainer(
    devices=1,
    max_epochs=10,
    accelerator=device.type
)

# Train
trainer_with_teacher.fit(student_with_teacher, train_loader, val_loader)

Sanity Checking:
Validation: loss = 2.33, accuracy = 0.19

Validation: loss = 1.45, accuracy = 0.48
Train: loss = 1.66, accuracy = 0.40

Validation: loss = 1.30, accuracy = 0.54
Train: loss = 1.38, accuracy = 0.51

Validation: loss = 1.19, accuracy = 0.58
Train: loss = 1.24, accuracy = 0.56

Validation: loss = 1.11, accuracy = 0.61
Train: loss = 1.13, accuracy = 0.60

Validation: loss = 1.07, accuracy = 0.62
Train: loss = 1.05, accuracy = 0.63

Validation: loss = 1.03, accuracy = 0.64
Train: loss = 0.98, accuracy = 0.66

Validation: loss = 0.99, accuracy = 0.65
Train: loss = 0.92, accuracy = 0.68

Validation: loss = 0.96, accuracy = 0.67
Train: loss = 0.87, accuracy = 0.70

Validation: loss = 0.93, accuracy = 0.68
Train: loss = 0.81, accuracy = 0.72

Validation: loss = 0.92, accuracy = 0.68
Train: loss = 0.77, accuracy = 0.74


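The results above show only a small effect. With distillation, validation accuracy is 0.01 higher at epochs 3, 4 and 8, 0.01 lower at epoch 7, and identical (0.68) after 10 epochs. The gap between training and validation accuracy is slightly smaller with the teacher (0.74 vs. 0.68) than without it (0.75 vs. 0.68), which suggests slightly less overfitting. These come from a single run of each model without a fixed random seed, so differences of this size may be within run-to-run variation.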
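One possible reason the two students end up so close is the scale of the distillation term. The plain-Python sketch below, using hypothetical logits, shows how a temperature of 3 softens a distribution. It also shows that the mean squared error between two 10-class probability vectors cannot exceed 0.2, while the cross-entropy of an untrained 10-class model is around ln(10) ≈ 2.30. This is consistent with the logged losses of the two students differing by only a few hundredths. The helper names `softmax` and `mse` are illustrative, not part of the notebook.

```python
import math

def softmax(logits, T=1.0):
    # Temperature-scaled softmax, shifted by the max for numerical stability
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def mse(p, q):
    # Mean over classes, matching the default 'mean' reduction of F.mse_loss
    return sum((a - b) ** 2 for a, b in zip(p, q)) / len(p)

logits = [6.0, 3.0] + [0.0] * 8  # hypothetical 10-class logits
sharp = softmax(logits, T=1.0)
soft = softmax(logits, T=3.0)
print(f"top probability: T=1 -> {max(sharp):.3f}, T=3 -> {max(soft):.3f}")

# Worst case for the MSE term: two different one-hot distributions
one_hot_a = [1.0] + [0.0] * 9
one_hot_b = [0.0, 1.0] + [0.0] * 8
max_mse = mse(one_hot_a, one_hot_b)

# Cross-entropy of a uniform 10-class prediction
uniform_ce = math.log(10)
print(f"max MSE term = {max_mse:.2f}, uniform cross-entropy = {uniform_ce:.2f}")
```

If the distillation signal is meant to carry more weight, one option to experiment with is scaling the teacher term by a constant factor; whether that helps here would need to be checked empirically.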