<a href="https://www.kaggle.com/code/aisuko/knowledge-distillation?scriptVersionId=164846818" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# Overview

Knowledge distillation is a technique for transferring knowledge from a larger, more complex model (the teacher) to a smaller, simpler model (the student). To distill knowledge from one model to another, we take a pre-trained teacher model trained on a certain task (image classification in this case) and randomly initialize a student model to be trained on the same task. Next, we train the student to minimize the difference between its outputs and the teacher's outputs, so that it mimics the teacher's behavior. The technique was introduced in [Distilling the Knowledge in a Neural Network](https://arxiv.org/abs/1503.02531). Here, we distill a teacher model into a student model on an image classification task.

In [1]:
%%capture
!pip install transformers==4.35.2
!pip install datasets==2.15.0
!pip install evaluate==0.4.1

In [2]:
import os
from huggingface_hub import login
from kaggle_secrets import UserSecretsClient

user_secrets = UserSecretsClient()

login(token=user_secrets.get_secret("HUGGINGFACE_TOKEN"))

os.environ["WANDB_API_KEY"]=user_secrets.get_secret("WANDB_API_KEY")
os.environ["WANDB_PROJECT"] = "Fine-tuning merve-beans-vit-224-on-beands"
os.environ["WANDB_NOTES"] = "Distill merve/beans-vit-224 into a randomly initialized MobileNetV2"
os.environ["WANDB_NAME"] = "distill-beans-vit-224-to-mobile-net-v2"

Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


Here, we use the `merve/beans-vit-224` model as the teacher. It is `google/vit-base-patch16-224-in21k` fine-tuned on the beans dataset. We will distill it into a randomly initialized MobileNetV2.

In [3]:
from datasets import load_dataset

dataset=load_dataset("beans", split="train[:500]")
dataset

Dataset({
    features: ['image_file_path', 'image', 'labels'],
    num_rows: 500
})

In [4]:
dataset=dataset.train_test_split(test_size=0.2)
dataset

DatasetDict({
    train: Dataset({
        features: ['image_file_path', 'image', 'labels'],
        num_rows: 400
    })
    test: Dataset({
        features: ['image_file_path', 'image', 'labels'],
        num_rows: 100
    })
})
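Because `train_test_split` is called without a seed above, the 400/100 split will differ from run to run. `Dataset.train_test_split` accepts a `seed` argument for reproducibility. The sketch below only illustrates the idea in plain Python and is not how the `datasets` library implements it; `split_indices` and the seed value 42 are hypothetical choices made for this example.

```python
import random


def split_indices(num_rows, test_size=0.2, seed=42):
    """Hypothetical helper: shuffle row indices with a fixed seed, then cut."""
    indices = list(range(num_rows))
    random.Random(seed).shuffle(indices)  # same seed -> same shuffle
    num_test = round(num_rows * test_size)
    return indices[num_test:], indices[:num_test]


train_idx, test_idx = split_indices(500)
print(len(train_idx), len(test_idx))
```

Calling `split_indices(500)` twice returns the same two index lists, which is the property a fixed seed provides.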

# Preprocessing

We can use an image processor from either of the models; here we take the teacher's. The goal is for every image to be preprocessed identically and resized to the same resolution, so that the teacher and the student receive the same inputs. We use the dataset's `map()` method to apply the preprocessing to every split.

In [5]:
from transformers import AutoImageProcessor

teacher_processor=AutoImageProcessor.from_pretrained("merve/beans-vit-224")
print(teacher_processor)

ViTImageProcessor {
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "ViTImageProcessor",
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 224,
    "width": 224
  }
}
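
The config printed above implies simple per-pixel arithmetic: each 8-bit channel value is multiplied by `rescale_factor` (1/255) and then normalized with mean 0.5 and standard deviation 0.5, which maps the range 0..255 to roughly -1..1. The following is a minimal sketch of that arithmetic only, not the library's implementation; `preprocess_pixel` is a hypothetical helper, and the resize step is left out.

```python
# Values copied from the processor config printed above.
RESCALE_FACTOR = 0.00392156862745098  # 1 / 255
IMAGE_MEAN = 0.5
IMAGE_STD = 0.5


def preprocess_pixel(value: int) -> float:
    """Hypothetical helper: rescale one 8-bit channel value, then normalize it."""
    rescaled = value * RESCALE_FACTOR            # do_rescale: [0, 255] -> [0, 1]
    return (rescaled - IMAGE_MEAN) / IMAGE_STD   # do_normalize: [0, 1] -> [-1, 1]


for raw in (0, 128, 255):
    print(raw, round(preprocess_pixel(raw), 4))
```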



In [6]:
def process(examples):
    processed_inputs=teacher_processor(examples["image"])
    return processed_inputs

processed_datasets=dataset.map(process, batched=True)

Essentially, we want the student model (a randomly initialized MobileNetV2) to mimic the teacher model (a fine-tuned vision transformer). To achieve this, we first get the logits from the teacher and the student. Then we divide each of them by a `temperature` parameter, which controls how soft the resulting probability distributions are: a higher temperature flattens the distribution, so the teacher's smaller probabilities carry more weight in the loss. A second parameter, `lambda`, weights the distillation loss against the ordinary loss on the true labels.

Here, we will use `temperature=5` and `lambda=0.5`. We use the Kullback-Leibler (KL) divergence loss to measure how far the student's distribution is from the teacher's. Given two probability distributions P and Q, the KL divergence measures how much extra information is needed to represent P using Q. If the two are identical, the KL divergence is zero, because no extra information is needed to describe P from Q. This makes it a natural objective for knowledge distillation: minimizing it pulls the student's predictions toward the teacher's.
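
To make the role of the temperature and of the KL divergence concrete, here is a small sketch in plain Python. The logits are made-up values for illustration, and `softmax` and `kl_divergence` are helper names chosen for this example; the training code in this notebook uses PyTorch's built-in functions instead.

```python
import math


def softmax(logits, temperature=1.0):
    """Softmax over logits divided by a temperature."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(P || Q) in nats; terms with p_i == 0 contribute nothing."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


teacher_logits = [4.0, 1.0, 0.5]  # made-up values for illustration
student_logits = [2.0, 1.5, 1.0]  # made-up values for illustration

for temperature in (1.0, 5.0):
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    print(temperature, [round(x, 3) for x in p], round(kl_divergence(p, q), 4))

# Identical distributions have zero divergence.
p = softmax(teacher_logits, 5.0)
print(kl_divergence(p, p))
```

At the higher temperature the teacher's distribution is visibly flatter, so the relative ranking of the non-top classes contributes more to the divergence.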

In [7]:
from transformers import TrainingArguments, Trainer
import torch
import torch.nn as nn
import torch.nn.functional as F

class ImageDistilTrainer(Trainer):
    def __init__(self, teacher_model=None, student_model=None, temperature=5, lambda_param=0.5, *args, **kwargs):
        super().__init__(model=student_model, *args, **kwargs)
        self.teacher=teacher_model
        self.student=student_model
        # With log-probabilities as input and probabilities as target,
        # KLDivLoss computes KL(teacher || student), averaged over the batch
        self.loss_function=nn.KLDivLoss(reduction="batchmean")
        # Use the device the Trainer selected instead of hard-coding "cuda"
        self.teacher.to(self.args.device)
        # The teacher is frozen: evaluation mode, no gradient updates
        self.teacher.eval()
        self.temperature=temperature
        self.lambda_param=lambda_param

    def compute_loss(self, model, inputs, return_outputs=False):
        # `model` is the student as prepared by the Trainer
        student_output=model(**inputs)

        # The teacher only provides targets, so no gradients are needed
        with torch.no_grad():
            teacher_output=self.teacher(**inputs)

        # Compute soft targets for teacher and student
        soft_teacher=F.softmax(teacher_output.logits/self.temperature, dim=-1)
        soft_student=F.log_softmax(student_output.logits/self.temperature, dim=-1)

        # Compute the distillation loss, rescaled by temperature squared
        distillation_loss=self.loss_function(soft_student, soft_teacher)*(self.temperature**2)

        # Compute the true label loss
        student_target_loss=student_output.loss

        # Calculate the final loss as a weighted sum of both terms
        loss=(1.-self.lambda_param)*student_target_loss+self.lambda_param*distillation_loss
        return (loss, student_output) if return_outputs else loss
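
Written out, the loss assembled in `compute_loss` above appears to correspond to the following expression. The symbols are notation introduced here for illustration, not taken from the notebook: $z_s$ and $z_t$ are the student and teacher logits, $y$ is the true label, $T$ is the temperature, and $\lambda$ is the `lambda_param` weight. It assumes the student's built-in loss is cross-entropy, which is what single-label classification heads in `transformers` normally use.

$$
\mathcal{L} = (1-\lambda)\,\mathcal{L}_{\text{CE}}(y, z_s) + \lambda\,T^{2}\,\mathrm{KL}\big(\operatorname{softmax}(z_t/T)\,\big\|\,\operatorname{softmax}(z_s/T)\big)
$$

The $T^{2}$ factor follows the original distillation paper: the gradients of the softened term scale roughly as $1/T^{2}$, so multiplying by $T^{2}$ keeps the two terms on a comparable scale when the temperature changes. With `temperature=5` and `lambda_param=0.5`, both terms get equal weight and the KL term is multiplied by 25.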

In [8]:
from transformers import AutoModelForImageClassification, MobileNetV2Config, MobileNetV2ForImageClassification

training_args=TrainingArguments(
    output_dir=os.getenv("WANDB_NAME"),
    num_train_epochs=5,
    fp16=True,
    logging_dir=f"{os.getenv('WANDB_NAME')}/logs",
    logging_strategy="epoch",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False,
    hub_strategy="every_save",
    hub_model_id=os.getenv("WANDB_NAME"),
    report_to="wandb", # or report_to="tensorboard"
    run_name=os.getenv("WANDB_NAME"),
)

num_labels=len(processed_datasets["train"].features["labels"].names)

# initialize models
teacher_model=AutoModelForImageClassification.from_pretrained(
    "merve/beans-vit-224",
    num_labels=num_labels,
    ignore_mismatched_sizes=True
)

# training MobileNetV2 from scratch
student_config=MobileNetV2Config() # initiate randomly
student_config.num_labels=num_labels
student_model=MobileNetV2ForImageClassification(student_config)

We can use a `compute_metrics` function to evaluate the model on the test set. The `Trainer` calls it during training to compute the model's accuracy.

In [9]:
import evaluate
import numpy as np

accuracy=evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels=eval_pred
    acc=accuracy.compute(references=labels, predictions=np.argmax(predictions, axis=1))
    return {"accuracy":acc["accuracy"]}
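
To see what the accuracy metric measures, the sketch below reproduces the same argmax-then-compare logic in plain Python on made-up logits. `argmax` and `accuracy_from_logits` are hypothetical helpers for illustration; the notebook itself relies on the `evaluate` library.

```python
def argmax(row):
    """Index of the largest value in a list of logits."""
    return max(range(len(row)), key=row.__getitem__)


def accuracy_from_logits(logits, labels):
    """Fraction of rows whose highest-scoring class equals the label."""
    predictions = [argmax(row) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return correct / len(labels)


# Made-up logits for four images and three classes.
toy_logits = [
    [2.0, 0.1, -1.0],
    [0.3, 1.2, 0.0],
    [0.5, 0.4, 0.9],
    [1.5, 1.1, 0.2],
]
toy_labels = [0, 1, 2, 1]

print(accuracy_from_logits(toy_logits, toy_labels))
```

Three of the four toy predictions match their labels; the last row predicts class 0 while the label is 1.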

In [10]:
from transformers import DefaultDataCollator

data_collator=DefaultDataCollator()

trainer=ImageDistilTrainer(
    student_model=student_model,
    teacher_model=teacher_model,
    temperature=5,
    lambda_param=0.5,
    args=training_args,
    train_dataset=processed_datasets["train"],
    eval_dataset=processed_datasets["test"],
    data_collator=data_collator,
    tokenizer=teacher_processor,
    compute_metrics=compute_metrics,
)

trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33murakiny[0m ([33mcausal_language_trainer[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: wandb version 0.16.3 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade
[34m[1mwandb[0m: Tracking run with wandb version 0.16.1
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20240229_114002-loj7qu1m[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mdistill-beans-vit-224-to-mobile-net-v2[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/causal_language_trainer/Fine-tuning%20merve-beans-vit-224-on-beands[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/causal_language_trainer/Fine-tuning%20merve-beans-vit-224-on-beands/runs/loj7qu1m[0m


| Epoch | Training Loss | Validation Loss | Accuracy |
|---|---|---|---|
| 1 | 0.7502 | 0.509075 | 0.71 |
| 2 | 0.5833 | 0.452112 | 0.71 |
| 3 | 0.5376 | 0.548939 | 0.71 |
| 4 | 0.5283 | 0.406016 | 0.72 |
| 5 | 0.5164 | 0.441271 | 0.70 |


TrainOutput(global_step=250, training_loss=0.5831593399047852, metrics={'train_runtime': 416.2823, 'train_samples_per_second': 4.804, 'train_steps_per_second': 0.601, 'total_flos': 4024001802240000.0, 'train_loss': 0.5831593399047852, 'epoch': 5.0})

In [11]:
trainer.evaluate(processed_datasets["test"])

{'eval_loss': 0.4060158431529999,
 'eval_accuracy': 0.72,
 'eval_runtime': 14.7784,
 'eval_samples_per_second': 6.767,
 'eval_steps_per_second': 0.88,
 'epoch': 5.0}

In [12]:
kwargs={
    'model_name': f'{os.getenv("WANDB_NAME")}',
#     'finetuned_from': model_name,
    'tasks': 'Image Classification',
#     'dataset_tags':'',
    'dataset':'beans'
}


teacher_processor.push_to_hub(os.getenv("WANDB_NAME"))
trainer.push_to_hub(**kwargs)

'https://huggingface.co/aisuko/distill-beans-vit-224-to-mobile-net-v2/tree/main/'