# Knowledge Distillation Basics

In [1]:
import lightning as L
import torch
import torch.nn.functional as F
import torchmetrics
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, RichProgressBar
from lightning.pytorch.loggers import TensorBoardLogger
from torch import nn
from torchvision.models import (
    MobileNet_V2_Weights,
    SqueezeNet1_1_Weights,
    mobilenet_v2,
    squeezenet1_1,
)

from dataset import PetDataModule
from model import PetClassifier

In [2]:
%reload_ext watermark
%watermark --iversions

lightning   : 2.1.1
torchmetrics: 1.2.0
torch       : 2.1.0



In [3]:
SEED = 42
seed_everything(SEED, workers=True)

Seed set to 42


42

In [4]:
class PetClassifierDistilled(L.LightningModule):
    def __init__(self, num_classes: int, lr: float, wd: float, teacher_model=None, only_student: bool = False):
        super().__init__()
        # self.save_hyperparameters()
        self.lr = lr
        self.wd = wd

        self.student_model = self._create_student_model(num_classes)
        self.train_accuracy = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)
        self.val_accuracy = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)
        self.test_accuracy = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)
        self.only_student = only_student
        self.temperature = 5
        self.lambda_param = 0.5
        self.teacher_model = None
        if not only_student:
            assert teacher_model, "Teacher model is none"
            self.teacher_model = teacher_model
            self.teacher_model.eval()
            self.teacher_model.to("cuda")

    def compute_loss_and_logits(self, imgs, labels):
        if self.only_student:
            student_logits = self.student_model(imgs)
            loss = F.cross_entropy(student_logits, labels)
        else:
            student_logits = self.student_model(imgs)
            teacher_logits = self.teacher_model(imgs)
            student_target_loss = F.cross_entropy(student_logits, labels)
            soft_teacher = F.softmax(teacher_logits / self.temperature, dim=-1)
            soft_student = F.log_softmax(student_logits / self.temperature, dim=-1)
            distillation_loss = F.kl_div(soft_student, soft_teacher, reduction="batchmean") * (self.temperature**2)
            loss = (1.0 - self.lambda_param) * student_target_loss + self.lambda_param * distillation_loss
        return loss, student_logits

    def training_step(self, batch):
        imgs, labels = batch

        loss, logits = self.compute_loss_and_logits(imgs, labels)

        prob = torch.softmax(logits, dim=1)
        preds = torch.argmax(prob, dim=1)
        self.train_accuracy(preds, labels)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log(
            "train_acc",
            self.train_accuracy,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss

    def validation_step(self, batch):
        imgs, labels = batch

        loss, logits = self.compute_loss_and_logits(imgs, labels)

        prob = torch.softmax(logits, dim=1)
        preds = torch.argmax(prob, dim=1)
        self.val_accuracy(preds, labels)
        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log(
            "val_acc",
            self.val_accuracy,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

    def test_step(self, batch):
        imgs, labels = batch
        logits = self.student_model(imgs)
        prob = torch.softmax(logits, dim=1)
        preds = torch.argmax(prob, dim=1)
        self.test_accuracy(preds, labels)
        self.log(
            "test_acc",
            self.test_accuracy,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

    def forward(self, data):
        return self.student_model(data)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.student_model.parameters(), lr=self.lr, weight_decay=self.wd)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, self.trainer.max_epochs, 0)
        return [optimizer], [scheduler]

    def _create_student_model(self, num_classes):
        student_model = squeezenet1_1()  # weights=SqueezeNet1_1_Weights.IMAGENET1K_V1
        student_model.classifier._modules["1"] = nn.Conv2d(512, num_classes, kernel_size=(1, 1))
        student_model.num_classes = num_classes
        return student_model

In [5]:
BATCH_SIZE = 32
NUM_WORKERS = 8
NUM_CLASSES = 37
MAX_EPOCHS = 100

lr = 1e-4
weight_decay = 1e-6

dm = PetDataModule(BATCH_SIZE, NUM_WORKERS)
student_model = PetClassifierDistilled(NUM_CLASSES, lr, weight_decay, only_student=True)

In [6]:
logger = TensorBoardLogger("logs", name="student_model")
checkpoint_callback = ModelCheckpoint(dirpath="st_model_checkpoints", monitor="val_acc", mode="max")
bar = RichProgressBar()
early_stopping = EarlyStopping("val_loss", patience=5)
trainer = L.Trainer(
    max_epochs=MAX_EPOCHS,
    accelerator="gpu",
    logger=logger,
    callbacks=[bar, checkpoint_callback, early_stopping],
    deterministic=True,
    check_val_every_n_epoch=3,
)
trainer.fit(model=student_model, datamodule=dm)

Trainer will use only 1 of 3 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=3)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/ababu/mlw_2023/.venv1/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:634: Checkpoint directory st_model_checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]


Output()

In [7]:
trainer.test(model=student_model, datamodule=dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]


Output()

[{'test_acc_epoch': 0.2319078892469406}]

In [8]:
CHECKPOINT = "/home/ababu/mlw_2023/checkpoints/epoch=8-step=720.ckpt"
teacher_model = PetClassifier.load_from_checkpoint(CHECKPOINT)

kd_model = PetClassifierDistilled(NUM_CLASSES, lr, weight_decay, teacher_model=teacher_model, only_student=False)

logger = TensorBoardLogger("logs", name="kd_model")
checkpoint_callback = ModelCheckpoint(dirpath="kd_model_checkpoints", monitor="val_acc", mode="max")
bar = RichProgressBar()
early_stopping = EarlyStopping("val_loss", patience=5)
trainer = L.Trainer(
    max_epochs=MAX_EPOCHS,
    accelerator="gpu",
    logger=logger,
    callbacks=[bar, checkpoint_callback, early_stopping],
    deterministic=True,
    check_val_every_n_epoch=3,
)
trainer.fit(model=kd_model, datamodule=dm)

Trainer will use only 1 of 3 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=3)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/ababu/mlw_2023/.venv1/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:634: Checkpoint directory kd_model_checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]


Output()

In [9]:
trainer.test(model=kd_model, datamodule=dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]


Output()

[{'test_acc_epoch': 0.30427631735801697}]