Knowledge Distillation is a procedure for model compression, in which a small (student) model is trained to match a large pre-trained (teacher) model. Knowledge is transferred from the teacher model to the student by minimizing a loss function, aimed at matching softened teacher logits as well as ground-truth labels. The logits are softened by applying a "temperature" scaling function in the softmax, effectively smoothing out the probability distribution and revealing inter-class relationships learned by the teacher.
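
The effect of temperature can be seen on a toy example. The sketch below is only illustrative: the logits are made-up values and `softmax_with_temperature` is a hypothetical helper, not part of any library. With these numbers the top class drops from roughly 0.86 at temperature 1 to roughly 0.46 at temperature 5, while the smaller classes gain probability mass.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Softmax over logits / temperature, shifted by the max for numerical stability."""
    scaled = [z / temperature for z in logits]
    max_scaled = max(scaled)
    exps = [math.exp(s - max_scaled) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical teacher logits for three classes (illustrative values only)
logits = [4.0, 2.0, 0.5]

sharp = softmax_with_temperature(logits, temperature=1.0)  # standard softmax
soft = softmax_with_temperature(logits, temperature=5.0)   # softened distribution

print([round(p, 3) for p in sharp])
print([round(p, 3) for p in soft])
```

A higher temperature keeps the ranking of the classes but flattens the distribution, which is what exposes the teacher's relative preferences among the non-top classes.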


How do you distill a fine-tuned ViT model (the teacher) into a MobileNetV2 model (the student) using the Trainer API of the transformers library?

### Step-0 :

In [2]:
!pip install transformers datasets accelerate tensorboard evaluate --upgrade --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.7/41.7 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.2/11.2 MB[0m [31m129.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.5/5.5 MB[0m [31m117.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m558.8/558.8 kB[0m [31m47.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m125.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m93.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
from datasets import load_dataset
from transformers import AutoImageProcessor

from transformers import TrainingArguments, Trainer
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate.test_utils.testing import get_backend

from transformers import AutoModelForImageClassification, MobileNetV2Config, MobileNetV2ForImageClassification

import evaluate
import numpy as np

from transformers import DefaultDataCollator

In [None]:
# Log in to the Hugging Face Hub with an HF token so that the Trainer can push our model to the Hub
from huggingface_hub import notebook_login

notebook_login()

### Step-1

1. The [teacher model](https://huggingface.co/merve/beans-vit-224) is the [original ViT model](https://huggingface.co/google/vit-base-patch16-224-in21k) fine-tuned on the beans dataset from the Hugging Face `datasets` library.

We will distill the teacher model into a randomly initialized MobileNetV2 model.

Note: both the teacher and the student are image classification models.

In [None]:
# Load the beans dataset
dataset = load_dataset("beans")
# Load the image processor of the fine-tuned teacher model
teacher_processor = AutoImageProcessor.from_pretrained("merve/beans-vit-224")

# Preprocess images with the teacher's processor so that both models
# receive identical inputs at the same resolution
def process(examples):
    processed_inputs = teacher_processor(examples["image"])
    return processed_inputs

# Apply the preprocessing to every split of the dataset using map()
processed_datasets = dataset.map(process, batched=True)

Aim: train the student model to mimic the teacher model.

1. Get the logits from both the teacher and the student.
2. Divide each set of logits by the `temperature` parameter, which controls how strongly the soft targets are smoothed.
3. Use the Kullback-Leibler (KL) divergence loss to measure how far the student's softened distribution is from the teacher's.
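
Written out, the three steps amount to the following. The notation is an assumption made here for illustration: $z^{T}$ and $z^{S}$ are the teacher and student logits, $\tau$ is the temperature, and the $\tau^{2}$ factor mirrors the `self.temperature ** 2` scaling used in the `ImageDistilTrainer` code in this notebook.

$$
p^{T}_{i} = \frac{\exp(z^{T}_{i} / \tau)}{\sum_{j} \exp(z^{T}_{j} / \tau)},
\qquad
p^{S}_{i} = \frac{\exp(z^{S}_{i} / \tau)}{\sum_{j} \exp(z^{S}_{j} / \tau)}
$$

$$
\mathcal{L}_{\text{distill}} = \tau^{2} \, \mathrm{KL}\left(p^{T} \,\|\, p^{S}\right)
= \tau^{2} \sum_{i} p^{T}_{i} \log \frac{p^{T}_{i}}{p^{S}_{i}}
$$

The teacher distribution is the target, so the sum is weighted by $p^{T}$; this matches `nn.KLDivLoss`, which takes the student's log-probabilities as input and the teacher's probabilities as target.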


Note:
- A parameter called `lambda` weighs the importance of the distillation loss against the ground-truth label loss. Example: `temperature=5` and `lambda=0.5`.
- Given two probability distributions P and Q, the KL divergence measures how much extra information is needed to represent P using Q. If the two are identical, their KL divergence is zero, since no extra information is needed to describe P from Q. This makes KL divergence a natural measure of how far the student's predictions are from the teacher's.
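
A small numeric sketch of both notes, in plain Python. The probability vectors are made-up values, `kl_divergence` is a hypothetical helper written for this illustration, and class 0 is assumed to be the true label; only `temperature=5` and `lambda=0.5` come from the example above.

```python
import math

def kl_divergence(p, q):
    """KL(P || Q) = sum_i p_i * log(p_i / q_i) in nats; terms with p_i == 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

# Hypothetical softened probabilities over three classes
teacher = [0.46, 0.31, 0.23]
student = [0.70, 0.20, 0.10]

print(kl_divergence(teacher, teacher))  # identical distributions give 0.0
print(kl_divergence(teacher, student))  # positive when the distributions differ

temperature, lambda_param = 5, 0.5

# Distillation term, scaled by temperature ** 2 as in the trainer code
distillation_loss = kl_divergence(teacher, student) * temperature ** 2
# Cross-entropy against the (assumed) true label, class 0
hard_label_loss = -math.log(student[0])

# lambda_param trades off the two terms
loss = (1.0 - lambda_param) * hard_label_loss + lambda_param * distillation_loss
print(loss)
```

Setting `lambda_param` to 0 recovers ordinary supervised training on the labels, while setting it to 1 trains the student on the teacher's soft targets alone.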


In [None]:
class ImageDistilTrainer(Trainer):
    def __init__(self, teacher_model=None, student_model=None, temperature=None, lambda_param=None,  *args, **kwargs):
        super().__init__(model=student_model, *args, **kwargs)
        self.teacher = teacher_model
        self.student = student_model
        self.loss_function = nn.KLDivLoss(reduction="batchmean")
        device, _, _ = get_backend() # automatically detects the underlying device type (CUDA, CPU, XPU, MPS, etc.)
        self.teacher.to(device)
        self.teacher.eval()
        self.temperature = temperature
        self.lambda_param = lambda_param

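    def compute_loss(self, student, inputs, return_outputs=False, num_items_in_batch=None):
        # num_items_in_batch is passed by recent versions of Trainer; it is not used here
        student_output = self.student(**inputs)

        # The teacher is frozen, so its forward pass needs no gradients
        with torch.no_grad():
            teacher_output = self.teacher(**inputs)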

        # Compute soft targets for teacher and student
        soft_teacher = F.softmax(teacher_output.logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_output.logits / self.temperature, dim=-1)

        # Compute the loss
        distillation_loss = self.loss_function(soft_student, soft_teacher) * (self.temperature ** 2)

        # Compute the true label loss
        student_target_loss = student_output.loss

        # Calculate final loss
        loss = (1. - self.lambda_param) * student_target_loss + self.lambda_param * distillation_loss
        return (loss, student_output) if return_outputs else loss

In [None]:
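# Placeholder repository name (an assumption: replace it with your own Hub repository name)
repo_name = "my-awesome-model"

training_args = TrainingArguments(
    output_dir=repo_name,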
    num_train_epochs=30,
    fp16=True,
    logging_dir=f"{repo_name}/logs",
    logging_strategy="epoch",
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    report_to="tensorboard",
    push_to_hub=True,
    hub_strategy="every_save",
    hub_model_id=repo_name,
    )

num_labels = len(processed_datasets["train"].features["labels"].names)

# initialize models
teacher_model = AutoModelForImageClassification.from_pretrained(
    "merve/beans-vit-224",
    num_labels=num_labels,
    ignore_mismatched_sizes=True
)

# training MobileNetV2 from scratch
student_config = MobileNetV2Config()
student_config.num_labels = num_labels
student_model = MobileNetV2ForImageClassification(student_config)
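
Since the point of distillation is compression, it can be useful to compare the sizes of the two models. The helper below is a minimal sketch: `count_parameters` is a hypothetical name, and it only assumes that the model exposes PyTorch's `parameters()` method.

```python
def count_parameters(model):
    """Return (total, trainable) parameter counts for a PyTorch-style module."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable

# Usage with the models defined above (uncomment inside the notebook):
# print("teacher:", count_parameters(teacher_model))
# print("student:", count_parameters(student_model))
```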

In [None]:
# compute_metrics function used to evaluate our model on the validation set during training.
# It computes the accuracy of the model's predictions.

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    acc = accuracy.compute(references=labels, predictions=np.argmax(predictions, axis=1))
    return {"accuracy": acc["accuracy"]}

In [None]:
# Initialize the Trainer with the training arguments and data collator.
data_collator = DefaultDataCollator()
trainer = ImageDistilTrainer(
    student_model=student_model,
    teacher_model=teacher_model,
    args=training_args,
    train_dataset=processed_datasets["train"],
    eval_dataset=processed_datasets["validation"],
    data_collator=data_collator,
    processing_class=teacher_processor,
    compute_metrics=compute_metrics,
    temperature=5,
    lambda_param=0.5
)

In [None]:
# training our model
trainer.train()

In [None]:
# Evaluate our model on the test split
trainer.evaluate(processed_datasets["test"])

### References
1. [Knowledge Distillation for Computer Vision](https://huggingface.co/docs/transformers/en/tasks/knowledge_distillation_for_image_classification)