# Distillation using the state-of-the-art model
## NOTE IMPORTANTE:
Dans ce notebook, on utilise comme modèle teacher un ViT fine-tuné sur CIFAR-10, que l'on charge depuis Hugging Face pour faciliter l'utilisation sur Colab. Ce fine-tuning a cependant aussi été réalisé "manuellement" dans un autre notebook : il faudra remplacer ce modèle par le nôtre.

In [None]:
!pip install transformers datasets accelerate tensorboard evaluate --upgrade

Collecting transformers
  Downloading transformers-4.46.3-py3-none-any.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.1/44.1 kB[0m [31m160.0 kB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting tensorboard
  Downloading tensorboard-2.18.0-py3-none-any.whl.metadata (1.6 kB)
Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downlo

In [None]:
import evaluate
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import DatasetDict
from datasets import load_dataset
from PIL import Image
from transformers import AutoModelForImageClassification, MobileNetV2Config, MobileNetV2ForImageClassification
from transformers import TrainingArguments, Trainer
from transformers import ViTFeatureExtractor, ViTForImageClassification
from transformers import ViTImageProcessor

In [None]:
# Fetch CIFAR-10 from the Hugging Face Hub as a DatasetDict ("train"/"test").
dataset = load_dataset(path="cifar10")

README.md:   0%|          | 0.00/5.16k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/120M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/23.9M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/50000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [None]:
# Show the split sizes, then the column names available in each split.
for summary in (dataset, dataset.column_names):
    print(summary)

DatasetDict({
    train: Dataset({
        features: ['img', 'label'],
        num_rows: 50000
    })
    test: Dataset({
        features: ['img', 'label'],
        num_rows: 10000
    })
})
{'train': ['img', 'label'], 'test': ['img', 'label']}


In [None]:
feature_extractor = ViTFeatureExtractor.from_pretrained("nateraw/vit-base-patch16-224-cifar10") # from HuggingFace repository -> change to our teacher model directory
model = ViTForImageClassification.from_pretrained("nateraw/vit-base-patch16-224-cifar10") # from HuggingFace repository -> change to our teacher model directory

split_dataset = dataset["train"].train_test_split(test_size=0.1, seed=42)  # 90% train, 10% validation

dataset = DatasetDict({
    "train": split_dataset["train"],
    "validation": split_dataset["test"],
    "test": dataset["test"]  # Retain the original test set
})


def process(examples):
    processed_inputs = feature_extractor(images=examples["img"], return_tensors="pt")
    return {
        "pixel_values": processed_inputs["pixel_values"],
        "label": examples["label"]  # Ensure labels are preserved
    }


processed_datasets = dataset.map(process, batched=True, remove_columns=["img"])

# Export the processed dataset
processed_datasets.save_to_disk("./processed_cifar10")




Map:   0%|          | 0/45000 [00:00<?, ? examples/s]

In [None]:
class ImageDistilTrainer(Trainer):
    """Trainer that distils a frozen teacher classifier into a student.

    The student is the model being optimised: it is handed to ``Trainer`` as
    ``model``. The training loss is a weighted mix of two terms:

    * the student's own hard-label loss (``student_output.loss``), and
    * the KL divergence between the temperature-softened teacher and student
      distributions, multiplied by ``temperature ** 2``.

    The mix is::

        loss = (1 - lambda_param) * hard_loss + lambda_param * distill_loss

    Args:
        teacher_model: pre-trained classifier providing soft targets. It is
            put in eval mode and only run under ``torch.no_grad()``.
        student_model: classifier being trained. It must return ``.loss`` when
            ``labels`` are in the batch, and ``.logits`` with the same class
            dimension as the teacher.
        temperature: softmax temperature (> 0) applied to both models' logits.
        lambda_param: weight in [0, 1] of the distillation term.
        *args, **kwargs: forwarded to ``transformers.Trainer``.

    Raises:
        ValueError: if ``teacher_model`` is missing, or if ``temperature`` /
            ``lambda_param`` are missing or out of range.
    """

    def __init__(self, teacher_model=None, student_model=None, temperature=None, lambda_param=None,  *args, **kwargs):
        # Fail fast here instead of with an opaque TypeError at the first step.
        if teacher_model is None:
            raise ValueError("ImageDistilTrainer requires a `teacher_model`.")
        if temperature is None or temperature <= 0:
            raise ValueError(f"`temperature` must be a positive number, got {temperature!r}.")
        if lambda_param is None or not 0.0 <= lambda_param <= 1.0:
            raise ValueError(f"`lambda_param` must be in [0, 1], got {lambda_param!r}.")
        super().__init__(model=student_model, *args, **kwargs)
        self.teacher = teacher_model
        self.student = student_model
        self.loss_function = nn.KLDivLoss(reduction="batchmean")
        # Place the teacher on the device configured in the TrainingArguments
        # rather than guessing cuda/cpu independently of the Trainer.
        self.teacher.to(self.args.device)
        self.teacher.eval()
        self.temperature = temperature
        self.lambda_param = lambda_param

    def compute_loss(self, student, inputs, return_outputs=False, **kwargs):
        """Return the combined hard-label + distillation loss for one batch.

        Args:
            student: the model supplied by ``Trainer``. It may be a wrapped
                version of ``self.student`` (e.g. DataParallel/DDP), so it is
                the one called.
            inputs: batch dict of model inputs; must include ``labels``.
            return_outputs: when True, also return the student's outputs
                (used by ``Trainer`` for evaluation/prediction).
            **kwargs: extra arguments from ``Trainer`` (e.g.
                ``num_items_in_batch``), accepted and ignored.

        Returns:
            ``loss``, or ``(loss, student_output)`` if ``return_outputs``.

        Raises:
            ValueError: if the student returns no loss (no ``labels`` in batch).
        """
        student_output = student(**inputs)

        # The teacher only needs the image inputs; dropping `labels` avoids
        # computing a teacher loss that would never be used.
        teacher_inputs = {key: value for key, value in inputs.items() if key != "labels"}
        with torch.no_grad():
            teacher_output = self.teacher(**teacher_inputs)

        # True-label loss computed by the student itself.
        student_target_loss = student_output.loss
        if student_target_loss is None:
            raise ValueError(
                "The student model did not return a loss; make sure the batch contains `labels`."
            )

        # Soft targets: teacher probabilities vs. student log-probabilities,
        # the input pair nn.KLDivLoss expects.
        soft_teacher = F.softmax(teacher_output.logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_output.logits / self.temperature, dim=-1)

        # T**2 keeps the distillation gradients on a comparable scale across temperatures.
        distillation_loss = self.loss_function(soft_student, soft_teacher) * (self.temperature ** 2)

        loss = (1. - self.lambda_param) * student_target_loss + self.lambda_param * distillation_loss
        return (loss, student_output) if return_outputs else loss

In [None]:
training_args = TrainingArguments(
    output_dir="distilled-model",
    num_train_epochs=30,
    # fp16 mixed precision is a GPU feature (some transformers versions reject
    # it outright on CPU), so only enable it when CUDA is available.
    fp16=torch.cuda.is_available(),
    logging_dir=None,
    logging_strategy="epoch",
    eval_strategy="epoch",
    save_strategy="epoch",
    # Keep the checkpoint with the best validation "accuracy" metric.
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    report_to="none",
    push_to_hub=False,
    # NOTE(review): learning rate and batch size are left at Trainer defaults;
    # these are likely worth tuning for a MobileNetV2 trained from scratch.
    )

label_names = processed_datasets["train"].features["label"].names
num_labels = len(label_names)

# initialize models
# The teacher checkpoint is already fine-tuned on CIFAR-10. Loading it without
# `ignore_mismatched_sizes` makes a head-size mismatch an error instead of
# silently re-initialising the classifier (which would yield a random teacher).
# NOTE(review): the teacher's label order is assumed to match the dataset's
# ClassLabel order — confirm via teacher_model.config.id2label.
teacher_model = AutoModelForImageClassification.from_pretrained(
    "nateraw/vit-base-patch16-224-cifar10", # from HuggingFace repository -> change to our teacher model directory
    num_labels=num_labels,
)

# training MobileNetV2 from scratch (student)
student_config = MobileNetV2Config()
student_config.num_labels = num_labels
# Human-readable class names in the saved student config.
student_config.id2label = dict(enumerate(label_names))
student_config.label2id = {name: idx for idx, name in enumerate(label_names)}
student_model = MobileNetV2ForImageClassification(student_config)

In [None]:
# Accuracy metric from the `evaluate` library, used by compute_metrics below.
accuracy = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Return top-1 accuracy for a Trainer evaluation pass.

    Args:
        eval_pred: ``(predictions, labels)`` pair supplied by ``Trainer``.
            ``predictions`` holds per-class scores with the class axis last.
            It may also be a tuple whose first element is the logits, when
            the model returns extra outputs.

    Returns:
        ``{"accuracy": float}``.
    """
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    acc = accuracy.compute(references=labels, predictions=np.argmax(predictions, axis=-1))
    return {"accuracy": acc["accuracy"]}

In [None]:
from transformers import DefaultDataCollator

# Stacks the pre-computed features into batches. Per the transformers docs,
# a scalar `label` key is batched as `labels`.
data_collator = DefaultDataCollator()

trainer = ImageDistilTrainer(
    # models
    teacher_model=teacher_model,
    student_model=student_model,
    # distillation hyper-parameters
    temperature=5,      # softening temperature applied to both models' logits
    lambda_param=0.5,   # weight of the distillation term vs. the hard-label loss
    # standard Trainer plumbing
    args=training_args,
    data_collator=data_collator,
    train_dataset=processed_datasets["train"],
    eval_dataset=processed_datasets["validation"],
    processing_class=feature_extractor,
    compute_metrics=compute_metrics,
)

In [None]:
# Run distillation for `num_train_epochs` epochs, evaluating and checkpointing
# on the validation split after every epoch (eval/save strategies are "epoch").
trainer.train()

In [None]:
# Final evaluation on the original CIFAR-10 test split. The "test" prefix
# reports these metrics as test_loss / test_accuracy / ..., so they are not
# confused with the per-epoch validation metrics logged under "eval_".
# With load_best_model_at_end=True, the weights evaluated here are those of
# the best validation checkpoint.
trainer.evaluate(processed_datasets["test"], metric_key_prefix="test")