# Distill a ViT teacher into a small CNN student on CIFAR-10

In [None]:
!pip install transformers datasets accelerate tensorboard evaluate --upgrade

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from datasets import load_dataset, DatasetDict
from transformers import (
    ViTFeatureExtractor, ViTForImageClassification,
    TrainingArguments, Trainer, DefaultDataCollator
)
import evaluate
from sklearn.model_selection import train_test_split

In [None]:
dataset = load_dataset("cifar10")

# Subset sizes keep preprocessing and training fast; raise them for a full run.
train_subset_size = 5000
val_subset_size = 1000
test_subset_size = 2000

# Carve a held-out validation split out of the official CIFAR-10 training set.
# The training subset must be drawn from `train_indices` (not simply the first
# rows of `train_data`): otherwise it can overlap the validation rows, and the
# validation accuracy used to pick the best checkpoint would be inflated.
train_data = dataset["train"].shuffle(seed=42)
train_indices, val_indices = train_test_split(
    range(len(train_data)), test_size=val_subset_size, random_state=42
)

dataset["train"] = train_data.select(train_indices[:train_subset_size])
dataset["validation"] = train_data.select(val_indices)
dataset["test"] = dataset["test"].shuffle(seed=42).select(range(test_subset_size))

# Load the feature extractor and teacher model
feature_extractor = ViTFeatureExtractor.from_pretrained("nateraw/vit-base-patch16-224-cifar10")
teacher_model = ViTForImageClassification.from_pretrained("nateraw/vit-base-patch16-224-cifar10")

# Preprocessing function
def process(examples):
    """Attach teacher-compatible ``pixel_values`` to a batch of examples.

    ``examples`` is the batch mapping handed over by ``dataset.map`` (called
    with ``batched=True``); its "img" column is run through the module-level
    ``feature_extractor``. The batch is updated in place and returned.
    """
    examples["pixel_values"] = feature_extractor(
        examples["img"], return_tensors="np"
    )["pixel_values"]
    return examples

# Preprocess every split in batches; the raw "img" column is dropped afterwards.
processed_datasets = dataset.map(
    process,
    batched=True,
    remove_columns=["img"],
)

In [None]:
# Student model
class SmallCNN(nn.Module):
    """Compact CNN used as the distillation student.

    Three stages of ``3x3 conv (padding 1) -> ReLU -> 2x2 max-pool`` are
    followed by a two-layer classifier head with dropout. Each pooling step
    halves (floor) the spatial size.

    Args:
        num_classes: Number of output logits.
        input_size: Side length of the square RGB input; used only to size
            the first fully connected layer.
    """

    def __init__(self, num_classes=10, input_size=224):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.5)

        # Probe the same feature stack that forward() uses, so fc1 can never
        # disagree with the real flattened feature size.
        self._flattened_size = self._get_flattened_size(input_size)

        self.fc1 = nn.Linear(self._flattened_size, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def _features(self, x):
        """Apply the three conv -> ReLU -> max-pool stages to ``x``."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(conv(x)))
        return x

    def _get_flattened_size(self, input_size):
        """Return the per-sample feature count produced for an input_size x input_size image."""
        with torch.no_grad():
            dummy = torch.zeros(1, 3, input_size, input_size)
            return self._features(dummy).numel()

    def forward(self, pixel_values, labels=None):
        """Classify a batch of images.

        Args:
            pixel_values: Image batch fed to ``conv1`` (3 input channels).
            labels: Optional class indices; when given, a cross-entropy loss
                is computed.

        Returns:
            dict with ``"logits"`` and ``"loss"`` (``None`` without labels).
        """
        # torch.flatten (unlike .view) also handles non-contiguous feature
        # maps, e.g. channels_last inputs.
        x = torch.flatten(self._features(pixel_values), 1)
        x = self.dropout(F.relu(self.fc1(x)))
        logits = self.fc2(x)

        loss = F.cross_entropy(logits, labels) if labels is not None else None
        return {"logits": logits, "loss": loss}

# Size the student's output layer from the label feature's class names.
num_labels = len(
    processed_datasets["train"].features["label"].names
)
small_cnn_model = SmallCNN(input_size=224, num_classes=num_labels)

In [None]:
# Define the Trainer for distillation
class DistillationTrainer(Trainer):
    """``Trainer`` that trains ``student_model`` against a frozen ``teacher_model``.

    Per batch, the loss is::

        (1 - lambda_param) * cross_entropy(student_logits, labels)
        + lambda_param * T**2
          * KL(softmax(teacher_logits / T) || softmax(student_logits / T))

    with ``T = temperature`` and the KL term averaged over the batch.

    Args:
        teacher_model: Model called with the same inputs as the student; its
            output must expose ``.logits``. It is moved to the training
            device, put in eval mode and only run under ``torch.no_grad()``.
        student_model: Model to optimise, passed to ``Trainer`` as ``model``;
            its output must be a mapping with a ``"logits"`` entry.
        temperature: Softmax temperature ``T`` applied to both distributions.
        lambda_param: Weight of the distillation term; the cross-entropy term
            gets ``1 - lambda_param``.
        *args, **kwargs: Forwarded to ``Trainer.__init__``.

    Raises:
        ValueError: If ``teacher_model``, ``temperature`` or ``lambda_param``
            is not given.
    """

    def __init__(self, teacher_model=None, student_model=None, temperature=None, lambda_param=None, *args, **kwargs):
        # Fail fast instead of crashing on the first training step.
        if teacher_model is None or temperature is None or lambda_param is None:
            raise ValueError(
                "DistillationTrainer requires teacher_model, temperature and lambda_param"
            )
        super().__init__(model=student_model, *args, **kwargs)
        self.teacher = teacher_model
        self.student = student_model
        self.temperature = temperature
        self.lambda_param = lambda_param
        self.loss_function = nn.KLDivLoss(reduction="batchmean")
        # Put the teacher on the same device the Trainer uses for the student
        # and the batches (correct for DDP local ranks, MPS, CPU, ...).
        self.teacher.to(self.args.device)
        self.teacher.eval()

    def compute_loss(self, model, inputs, return_outputs=False, *args, **kwargs):
        """Return the combined distillation + classification loss.

        Args:
            model: The (possibly wrapped) student model supplied by Trainer.
            inputs: Batch dict; must contain "labels" plus the model inputs.
            return_outputs: Also return the student's outputs.

        Returns:
            ``loss`` or ``(loss, student_outputs)``.
        """
        labels = inputs["labels"]
        # Forward everything except the labels; build a new dict so the
        # caller's batch is left untouched.
        model_inputs = {key: value for key, value in inputs.items() if key != "labels"}

        # Use the `model` argument: Trainer may pass a wrapped student
        # (DataParallel / DDP), which self.student would bypass.
        student_outputs = model(**model_inputs)
        student_logits = student_outputs["logits"]

        with torch.no_grad():
            teacher_logits = self.teacher(**model_inputs).logits

        # Distillation loss; the T**2 factor keeps its gradient scale
        # comparable to the hard-label term.
        temperature = self.temperature
        soft_teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
        soft_student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
        distillation_loss = self.loss_function(soft_student_log_probs, soft_teacher_probs) * (temperature ** 2)

        # Classification loss
        classification_loss = F.cross_entropy(student_logits, labels)

        # Combine the losses
        loss = (1.0 - self.lambda_param) * classification_loss + self.lambda_param * distillation_loss
        return (loss, student_outputs) if return_outputs else loss

# Training configuration: evaluate, log and checkpoint once per epoch, and keep
# the checkpoint with the highest validation "accuracy" at the end.
distillation_training_args = TrainingArguments(
    output_dir="distilled-small-cnn",
    # optimisation schedule
    num_train_epochs=10,
    learning_rate=1e-3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    # evaluation / logging / checkpoint cadence
    eval_strategy="epoch",
    logging_strategy="epoch",
    save_strategy="epoch",
    logging_dir=None,
    # best-checkpoint selection
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    # no external experiment trackers, no Hub upload
    report_to="none",
    push_to_hub=False,
)

In [None]:
accuracy = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Compute top-1 accuracy from the Trainer's evaluation output.

    ``eval_pred`` unpacks into ``(logits, label_ids)``; each row's predicted
    class is the argmax over axis 1 of the logits.
    """
    logits, label_ids = eval_pred
    result = accuracy.compute(
        references=label_ids,
        predictions=np.argmax(logits, axis=1),
    )
    return dict(accuracy=result["accuracy"])

In [None]:
# Collates preprocessed dataset rows into batches for the Trainer.
data_collator = DefaultDataCollator()

# Initialize the Distillation Trainer. temperature softens both distributions;
# lambda_param weights the distillation term (cross-entropy gets the rest).
distillation_trainer = DistillationTrainer(
    student_model=small_cnn_model,
    teacher_model=teacher_model,
    temperature=5,
    lambda_param=0.3,
    args=distillation_training_args,
    data_collator=data_collator,
    train_dataset=processed_datasets["train"],
    eval_dataset=processed_datasets["validation"],
    compute_metrics=compute_metrics,
)

In [None]:
# Train the student by distillation. Per `distillation_training_args`, this
# evaluates and checkpoints once per epoch and, at the end, reloads the
# checkpoint with the best validation accuracy.
distillation_trainer.train()

In [None]:
# Evaluate on the held-out test split. metric_key_prefix="test" keeps these
# numbers from being logged as "eval_*" next to the per-epoch validation
# results.
distillation_test_results = distillation_trainer.evaluate(
    processed_datasets["test"], metric_key_prefix="test"
)
print(f"Distilled model test accuracy: {distillation_test_results['test_accuracy']:.4f}")