# Day 28: Knowledge Distillation Implementation - Part 1

In this notebook, we'll implement knowledge distillation for language models. We'll focus on response-based distillation, where a smaller student model learns to mimic the outputs of a larger teacher model.

## Overview

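1. Setup and dependencies
2. Loading teacher and student models
3. Preparing a dataset for distillation
4. Fine-tuning the teacher model
5. Generating soft labels from the teacher
6. Implementing the distillation loss
7. Training the student model with distillation
8. Comparing teacher and student performance
9. Testing on individual examples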

## 1. Setup and Dependencies

First, let's install the necessary libraries:

In [None]:
!pip install -q transformers datasets torch evaluate accelerate

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding
)
import evaluate
import numpy as np

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 2. Loading Teacher and Student Models

For this example, we'll use BERT-base as our teacher model and DistilBERT as our student model. We'll fine-tune the teacher on a sentiment analysis task, then train the student on the same task to match the teacher's outputs.

In [None]:
# Define model names
teacher_model_name = "bert-base-uncased"
student_model_name = "distilbert-base-uncased"

# Load tokenizers
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
student_tokenizer = AutoTokenizer.from_pretrained(student_model_name)

# Load models for sequence classification (sentiment analysis)
teacher_model = AutoModelForSequenceClassification.from_pretrained(
    teacher_model_name,
    num_labels=2  # Binary classification for sentiment
)

student_model = AutoModelForSequenceClassification.from_pretrained(
    student_model_name,
    num_labels=2
)

# Move models to device
teacher_model = teacher_model.to(device)
student_model = student_model.to(device)

# Print model sizes
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Teacher model ({teacher_model_name}) has {count_parameters(teacher_model):,} parameters")
print(f"Student model ({student_model_name}) has {count_parameters(student_model):,} parameters")
print(f"Size reduction: {count_parameters(teacher_model) / count_parameters(student_model):.2f}x")
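
To make the last printed figure easier to read, here is a small sketch of the arithmetic behind it. The parameter counts are the approximate, commonly quoted sizes of the two checkpoints (an assumption for illustration; the cell above prints the exact counts), and the helper names are hypothetical.

```python
# Approximate, commonly quoted sizes; the notebook cell prints the exact counts.
TEACHER_PARAMS = 110_000_000  # bert-base-uncased (approximate)
STUDENT_PARAMS = 66_000_000   # distilbert-base-uncased (approximate)


def size_reduction(teacher: int, student: int) -> float:
    """How many times larger the teacher is than the student."""
    return teacher / student


def fraction_removed(teacher: int, student: int) -> float:
    """Share of the teacher's parameters that the student does not have."""
    return 1.0 - student / teacher


ratio = size_reduction(TEACHER_PARAMS, STUDENT_PARAMS)
removed = fraction_removed(TEACHER_PARAMS, STUDENT_PARAMS)
print(f"Size reduction: {ratio:.2f}x")
print(f"Parameters removed: {removed:.0%}")
```

A ratio of about 1.67x and a reduction of about 40% describe the same gap in two ways, so it helps to state which one is being reported.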

## 3. Preparing a Dataset for Distillation

We'll use the SST-2 (Stanford Sentiment Treebank) dataset for sentiment analysis.

In [None]:
# Load the SST-2 dataset
dataset = load_dataset("glue", "sst2")
print(dataset)

# Look at a few examples
for i in range(3):
    print(f"Example {i+1}:")
    print(f"Text: {dataset['train'][i]['sentence']}")
    print(f"Label: {dataset['train'][i]['label']}")
    print()

In [None]:
# Tokenize the dataset
def tokenize_function(examples):
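    # Both checkpoints are uncased BERT-style models that share the same WordPiece
    # vocabulary, so one tokenizer can serve the teacher and the student.
    # Padding is left to the data collator, which pads each batch dynamically.
    return student_tokenizer(examples["sentence"], truncation=True, max_length=128)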

tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Prepare the datasets for training
train_dataset = tokenized_datasets["train"]
eval_dataset = tokenized_datasets["validation"]

# Create a data collator
data_collator = DataCollatorWithPadding(tokenizer=student_tokenizer)
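
`DataCollatorWithPadding` pads each batch to the length of its longest sequence and builds the matching attention mask. The sketch below imitates that behaviour in plain Python to show what the collator produces. `pad_batch` is a hypothetical helper, not part of `transformers`, and apart from 101 (`[CLS]`) and 102 (`[SEP]`) the token ids are made up.

```python
from typing import Dict, List


def pad_batch(sequences: List[List[int]], pad_id: int = 0) -> Dict[str, List[List[int]]]:
    """Pad every sequence to the longest one in this batch (illustrative only)."""
    longest = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = longest - len(seq)
        input_ids.append(seq + [pad_id] * n_pad)
        # 1 marks a real token, 0 marks padding that the model should ignore.
        attention_mask.append([1] * len(seq) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


batch = pad_batch([[101, 2023, 102], [101, 2023, 3185, 2001, 102]])
print(batch["input_ids"])
print(batch["attention_mask"])
```

Padding per batch rather than to a fixed length keeps short batches short, which usually saves compute on a dataset of brief sentences like SST-2.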

## 4. Fine-tuning the Teacher Model

Before we can distill knowledge from the teacher, we need to fine-tune it on our task.

In [None]:
# Define metrics for evaluation
accuracy_metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy_metric.compute(predictions=predictions, references=labels)
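
As a rough illustration of what `compute_metrics` does, the following dependency-free sketch takes the argmax of each logit row and compares it with the label. The logits and labels are made up, and the function names are hypothetical.

```python
from typing import List, Sequence


def argmax(row: Sequence[float]) -> int:
    """Index of the largest value in a row of logits."""
    return max(range(len(row)), key=lambda i: row[i])


def accuracy_from_logits(logits: List[Sequence[float]], labels: List[int]) -> float:
    """Fraction of rows whose highest-scoring class equals the label."""
    predictions = [argmax(row) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return correct / len(labels)


# Made-up logits for four examples with two classes each.
example_logits = [[2.0, -1.0], [0.1, 0.3], [-0.5, 1.5], [1.2, 0.8]]
example_labels = [0, 1, 1, 1]
print(accuracy_from_logits(example_logits, example_labels))  # 3 of 4 correct -> 0.75
```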

In [None]:
# Define training arguments for the teacher
teacher_training_args = TrainingArguments(
    output_dir="./results/teacher-sst2",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    num_train_epochs=3,
    weight_decay=0.01,
    eval_strategy="epoch",  # called `evaluation_strategy` in older transformers releases
    save_strategy="epoch",
    load_best_model_at_end=True,
    push_to_hub=False,
    report_to="none",  # Disable wandb, tensorboard, etc.
)

# Create the trainer for the teacher
teacher_trainer = Trainer(
    model=teacher_model,
    args=teacher_training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=teacher_tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
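
To get a feel for how long this configuration trains, the sketch below estimates the number of optimizer steps. It assumes a single device, no gradient accumulation, and that the final partial batch is kept. The example count is the size of the SST-2 training split (67,349 sentences), and `total_optimizer_steps` is a hypothetical helper.

```python
import math


def total_optimizer_steps(num_examples: int, batch_size: int, epochs: int) -> int:
    """Optimizer steps for a run with one update per batch."""
    steps_per_epoch = math.ceil(num_examples / batch_size)
    return steps_per_epoch * epochs


SST2_TRAIN_EXAMPLES = 67_349
teacher_steps = total_optimizer_steps(SST2_TRAIN_EXAMPLES, batch_size=16, epochs=3)
print(f"Teacher fine-tuning: {teacher_steps:,} optimizer steps")  # 4,210 per epoch x 3
```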

In [None]:
# Train the teacher model
teacher_trainer.train()

In [None]:
# Evaluate the teacher model
teacher_eval_results = teacher_trainer.evaluate()
print(f"Teacher model evaluation results: {teacher_eval_results}")

# Save the fine-tuned teacher model
teacher_model_path = "./teacher-sst2"
teacher_model.save_pretrained(teacher_model_path)
teacher_tokenizer.save_pretrained(teacher_model_path)

## 5. Generating Soft Labels from the Teacher

Now, we'll use the fine-tuned teacher model to generate soft labels (logits) for our training data.
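
Raw logits become "soft" labels once they pass through a softmax, and dividing them by a temperature first spreads the probability mass more evenly. The sketch below shows the effect on one made-up pair of logits; `softmax_with_temperature` is a hypothetical helper written in plain Python so the numbers can be checked by hand.

```python
import math
from typing import List


def softmax_with_temperature(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Softmax of logits / temperature, computed in a numerically stable way."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


example_logits = [-2.0, 3.0]  # made-up teacher logits for (negative, positive)
for temperature in (1.0, 2.0, 5.0):
    probs = softmax_with_temperature(example_logits, temperature)
    print(f"T={temperature}: {[round(p, 3) for p in probs]}")
```

Higher temperatures leave the predicted class unchanged but give the other class a larger share, which is the extra signal the student learns from.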

In [None]:
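from torch.utils.data import DataLoader

# Function to generate soft labels (logits) from the teacher
def generate_soft_labels(model, dataset, batch_size=32):
    model.eval()  # Set the model to evaluation mode (disables dropout)

    # Keep only the columns the model consumes. The raw text cannot be collated
    # into tensors, and "idx" and "label" are not needed for a forward pass.
    model_columns = ("input_ids", "attention_mask")
    inputs_only = dataset.remove_columns(
        [name for name in dataset.column_names if name not in model_columns]
    )

    # shuffle=False keeps row i of the output aligned with example i of the dataset
    dataloader = DataLoader(
        inputs_only,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=data_collator
    )

    all_logits = []

    # Generate logits batch by batch
    with torch.no_grad():  # Disable gradient calculation
        for batch in dataloader:
            # Move batch to device
            batch = {k: v.to(device) for k, v in batch.items()}

            # Forward pass; store the logits on the CPU to free GPU memory
            logits = model(**batch).logits
            all_logits.append(logits.cpu())

    # Concatenate all logits into one (num_examples, num_labels) tensor
    return torch.cat(all_logits, dim=0)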

In [None]:
# Generate soft labels for the training dataset
print("Generating soft labels from the teacher model...")
teacher_logits = generate_soft_labels(teacher_model, train_dataset)
print(f"Generated logits shape: {teacher_logits.shape}")

# Look at a few examples of soft labels
for i in range(3):
    print(f"Example {i+1}:")
    print(f"Text: {train_dataset[i]['sentence']}")
    print(f"Hard label: {train_dataset[i]['label']}")
    print(f"Soft logits: {teacher_logits[i]}")
    print(f"Soft probabilities: {F.softmax(teacher_logits[i], dim=0)}")
    print()
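
The logits above are stored in dataset order, so row `i` belongs to training example `i`. A training loop that shuffles its batches therefore has to look up teacher rows by each example's original index, not by its position inside the batch. The toy sketch below illustrates the difference with stand-in strings instead of tensors; all names are hypothetical.

```python
import random

# Stand-in for the teacher output: row i belongs to example i.
num_examples = 8
teacher_table = [f"teacher_row_{i}" for i in range(num_examples)]

# One shuffled epoch, split into batches of original example indices.
rng = random.Random(0)
order = list(range(num_examples))
rng.shuffle(order)
batch_size = 4
batches = [order[i:i + batch_size] for i in range(0, num_examples, batch_size)]


def lookup_by_index(batch_indices):
    """Fetch the teacher row that belongs to each example in the batch."""
    return [teacher_table[i] for i in batch_indices]


def lookup_by_position(batch_indices):
    """Fetch rows 0..batch_size-1 regardless of which examples are in the batch."""
    return [teacher_table[pos] for pos in range(len(batch_indices))]


for batch_indices in batches:
    print("examples:   ", batch_indices)
    print("by index:   ", lookup_by_index(batch_indices))
    print("by position:", lookup_by_position(batch_indices))
```

Looking up by position returns the same first rows for every batch, so most examples would be paired with another example's teacher output.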

## 6. Implementing the Distillation Loss

Now, let's implement a custom trainer that combines the standard cross-entropy loss with the distillation loss.
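
Written out, the combined objective is

$$\mathcal{L} = \alpha \, \mathrm{CE}\big(y, \sigma(z_s)\big) + (1 - \alpha) \, T^2 \, \mathrm{KL}\big(\sigma(z_t / T) \,\|\, \sigma(z_s / T)\big)$$

where $z_s$ and $z_t$ are the student and teacher logits, $\sigma$ is the softmax, $T$ is the temperature, and $\alpha$ weights the hard-label term. The $T^2$ factor compensates for the $1/T^2$ scaling that the temperature introduces in the gradients of the soft term. The following dependency-free sketch evaluates this loss for a single example so the pieces can be checked by hand. The function names and logits are illustrative assumptions, not the notebook's API.

```python
import math
from typing import List


def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # for numerical stability
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(student_logits: List[float], label: int) -> float:
    """Hard-label loss: negative log-probability of the true class."""
    return -math.log(softmax(student_logits)[label])


def kl_divergence(p: List[float], q: List[float]) -> float:
    """KL(p || q) for two discrete distributions over the same classes."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def distillation_loss(student_logits, teacher_logits, label, alpha=0.5, temperature=2.0):
    hard = cross_entropy(student_logits, label)
    p_teacher = softmax(teacher_logits, temperature)
    q_student = softmax(student_logits, temperature)
    soft = kl_divergence(p_teacher, q_student) * temperature ** 2
    return alpha * hard + (1 - alpha) * soft


# Made-up logits: the student is less confident than the teacher.
print(distillation_loss([0.2, 0.9], [-1.5, 2.5], label=1))
```

When the student's logits equal the teacher's, the KL term vanishes and only the weighted cross-entropy remains.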

In [None]:
# Custom trainer for knowledge distillation
class DistillationTrainer(Trainer):
    def __init__(self, *args, teacher_logits=None, alpha=0.5, temperature=2.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher_logits = teacher_logits  # One row per training example, in dataset order
        self.alpha = alpha  # Weight for the hard-label (cross-entropy) loss
        self.temperature = temperature  # Temperature for softening the distributions

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments (such as num_items_in_batch) that
        # newer versions of Trainer pass to compute_loss.

        # "idx" is each example's row in the training set. It is not a model input,
        # so remove it before the forward pass. Evaluation batches do not carry it.
        idx = inputs.pop("idx", None)

        # Standard forward pass; outputs.loss is the cross-entropy on the hard labels
        outputs = model(**inputs)
        student_logits = outputs.logits
        hard_loss = outputs.loss

        if idx is None:
            # No teacher logits for this batch (evaluation): use the hard loss only
            return (hard_loss, outputs) if return_outputs else hard_loss

        # Get the teacher's logits for exactly the examples in this batch
        teacher_logits_batch = self.teacher_logits[idx.cpu()].to(student_logits.device)

        # Distillation loss (KL divergence between the softened distributions).
        # The T^2 factor keeps its gradients on a scale comparable to the hard loss.
        soft_targets = F.softmax(teacher_logits_batch / self.temperature, dim=-1)
        soft_predictions = F.log_softmax(student_logits / self.temperature, dim=-1)
        distillation_loss = F.kl_div(soft_predictions, soft_targets, reduction="batchmean") * (self.temperature ** 2)

        # Combined loss
        loss = self.alpha * hard_loss + (1 - self.alpha) * distillation_loss

        return (loss, outputs) if return_outputs else loss

## 7. Training the Student Model with Distillation

Now, let's train the student model using our custom distillation trainer.

In [None]:
# Give every training example an "idx" equal to its row number so compute_loss can
# fetch the matching teacher logits after shuffling. Drop the raw text, which the
# collator cannot convert to tensors.
distill_train_dataset = train_dataset.remove_columns(["sentence", "idx"])
distill_train_dataset = distill_train_dataset.add_column(
    "idx", list(range(len(distill_train_dataset)))
)
distill_eval_dataset = eval_dataset.remove_columns(["sentence", "idx"])

# Define training arguments for the student
student_training_args = TrainingArguments(
    output_dir="./results/student-sst2-distilled",
    learning_rate=5e-5,  # Slightly higher learning rate for the student
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    num_train_epochs=5,  # More epochs for the student
    weight_decay=0.01,
    eval_strategy="epoch",  # called `evaluation_strategy` in older transformers releases
    save_strategy="epoch",
    load_best_model_at_end=True,
    remove_unused_columns=False,  # Keep "idx" in each batch for the teacher lookup
    push_to_hub=False,
    report_to="none",
)

# Create the distillation trainer
distillation_trainer = DistillationTrainer(
    model=student_model,
    args=student_training_args,
    train_dataset=distill_train_dataset,
    eval_dataset=distill_eval_dataset,
    tokenizer=student_tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    teacher_logits=teacher_logits,
    alpha=0.5,  # Equal weight to hard and soft targets
    temperature=2.0,  # Temperature for softening
)

In [None]:
# Train the student model with distillation
distillation_trainer.train()

In [None]:
# Evaluate the distilled student model
student_eval_results = distillation_trainer.evaluate()
print(f"Distilled student model evaluation results: {student_eval_results}")

# Save the distilled student model
student_model_path = "./student-sst2-distilled"
student_model.save_pretrained(student_model_path)
student_tokenizer.save_pretrained(student_model_path)

## 8. Comparing Teacher and Student Performance

Let's compare the performance of the teacher and student models.

In [None]:
# Compare the results
print(f"Teacher model accuracy: {teacher_eval_results['eval_accuracy']:.4f}")
print(f"Distilled student model accuracy: {student_eval_results['eval_accuracy']:.4f}")
print(f"Performance retention: {student_eval_results['eval_accuracy'] / teacher_eval_results['eval_accuracy'] * 100:.2f}%")
print(f"Size reduction: {count_parameters(teacher_model) / count_parameters(student_model):.2f}x")

## 9. Testing on Individual Examples

Let's test both models on some individual examples to see how their predictions compare.

In [None]:
# Function to get predictions from a model
def get_prediction(model, tokenizer, text):
    model.eval()  # Disable dropout so repeated calls give the same prediction
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    
    logits = outputs.logits
    probabilities = F.softmax(logits, dim=-1)
    prediction = torch.argmax(probabilities, dim=-1).item()
    confidence = probabilities[0][prediction].item()
    
    return prediction, confidence, probabilities[0].cpu().numpy()

In [None]:
# Test examples
test_examples = [
    "This movie was fantastic! I really enjoyed it.",
    "What a terrible waste of time. I hated every minute.",
    "The film was neither good nor bad, just mediocre.",
    "While it had some flaws, overall I'd recommend seeing it."
]

# Compare predictions
for text in test_examples:
    print(f"Text: {text}")
    
    # Teacher prediction
    teacher_pred, teacher_conf, teacher_probs = get_prediction(teacher_model, teacher_tokenizer, text)
    teacher_sentiment = "positive" if teacher_pred == 1 else "negative"
    print(f"Teacher: {teacher_sentiment} (confidence: {teacher_conf:.4f}, probs: {teacher_probs})")
    
    # Student prediction
    student_pred, student_conf, student_probs = get_prediction(student_model, student_tokenizer, text)
    student_sentiment = "positive" if student_pred == 1 else "negative"
    print(f"Student: {student_sentiment} (confidence: {student_conf:.4f}, probs: {student_probs})")
    
    # Check if they agree
    agreement = "✓" if teacher_pred == student_pred else "✗"
    print(f"Agreement: {agreement}")
    print()
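
Beyond eyeballing individual rows, the per-example check marks can be summarized as a single agreement rate. The sketch below is a minimal, dependency-free version; `agreement_rate` is a hypothetical helper and the prediction lists are made up for illustration.

```python
from typing import List


def agreement_rate(teacher_preds: List[int], student_preds: List[int]) -> float:
    """Fraction of examples on which both models predict the same class."""
    if len(teacher_preds) != len(student_preds):
        raise ValueError("Both prediction lists must have the same length")
    if not teacher_preds:
        raise ValueError("At least one prediction is required")
    matches = sum(int(t == s) for t, s in zip(teacher_preds, student_preds))
    return matches / len(teacher_preds)


# Made-up predictions for four sentences (1 = positive, 0 = negative).
example_teacher_preds = [1, 0, 0, 1]
example_student_preds = [1, 0, 1, 1]
print(f"Agreement: {agreement_rate(example_teacher_preds, example_student_preds):.0%}")
```

Agreement measures how closely the student follows the teacher, which is a different question from how accurate either model is against the true labels.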

## Conclusion

In this notebook, we've implemented knowledge distillation for language models. We've seen how to:

1. Fine-tune a teacher model on a specific task
2. Generate soft labels from the teacher model
3. Implement a custom distillation loss that combines hard and soft targets
4. Train a smaller student model using knowledge distillation
5. Compare the performance of the teacher and student models

The distilled student is expected to retain most of the teacher's accuracy while having roughly 40% fewer parameters; the comparison in Section 8 reports the exact figures for your run. This trade-off is what makes knowledge distillation useful for building efficient models that can be deployed in resource-constrained environments.

In Part 2, we'll explore more advanced distillation techniques, including feature-based distillation and multi-task distillation.