# Day 28: Knowledge Distillation Implementation - Part 2

In this notebook, we'll explore advanced knowledge distillation techniques for language models, focusing on feature-based distillation and generating synthetic data for distillation.

## Overview

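1. Setup and dependencies
2. Feature-based distillation (hidden states)
3. Generating synthetic data for distillation
4. Custom trainer for feature-based distillation
5. Training with feature-based distillation
6. Evaluating the distilled model
7. Measuring efficiency gains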

## 1. Setup and Dependencies

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
import numpy as np
from tqdm import tqdm

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 2. Feature-Based Distillation

In feature-based distillation, the student learns to mimic the internal representations (hidden states) of the teacher model, not just the final outputs.
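
As a minimal sketch of what "mimicking hidden states" can look like, the snippet below compares random stand-in tensors shaped like GPT-2 (768-wide) and GPT-2 Medium (1024-wide) hidden states. The projection layer, the padding mask, and the name `masked_feature_loss` are illustrative assumptions, not part of the notebook's pipeline.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch_size, seq_len = 2, 8
student_dim, teacher_dim = 768, 1024  # hidden sizes of gpt2 and gpt2-medium

# Random stand-ins for real hidden states (no model download needed)
student_hidden = torch.randn(batch_size, seq_len, student_dim)
teacher_hidden = torch.randn(batch_size, seq_len, teacher_dim)

# 1 = real token, 0 = padding; the second sequence has 3 padded positions
attention_mask = torch.ones(batch_size, seq_len)
attention_mask[1, 5:] = 0

# Hypothetical learned projection that maps student features to the teacher's width
projection = nn.Linear(student_dim, teacher_dim)

def masked_feature_loss(student_h, teacher_h, mask, proj):
    """Mean squared error computed over non-padded positions only."""
    squared_error = (proj(student_h) - teacher_h) ** 2  # (batch, seq, teacher_dim)
    mask = mask.unsqueeze(-1)                            # (batch, seq, 1)
    return (squared_error * mask).sum() / (mask.sum() * teacher_h.size(-1))

loss = masked_feature_loss(student_hidden, teacher_hidden, attention_mask, projection)
print(loss.shape, loss.item())
```

Without the projection, the two tensors could not be compared element-wise, because their last dimensions differ.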

In [None]:
# Load teacher and student models for causal language modeling
teacher_model_name = "gpt2-medium"  # ~355M parameters
student_model_name = "gpt2"         # 124M parameters

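# Load the tokenizer (gpt2 and gpt2-medium share the same vocabulary, so one tokenizer serves both)
tokenizer = AutoTokenizer.from_pretrained(student_model_name)
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 defines no pad token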

# Load models with output_hidden_states=True to access internal representations
teacher_model = AutoModelForCausalLM.from_pretrained(
    teacher_model_name,
    output_hidden_states=True
)

student_model = AutoModelForCausalLM.from_pretrained(
    student_model_name,
    output_hidden_states=True
)

# Move models to device
teacher_model = teacher_model.to(device)
student_model = student_model.to(device)

# Print model sizes
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Teacher model ({teacher_model_name}) has {count_parameters(teacher_model):,} parameters")
print(f"Student model ({student_model_name}) has {count_parameters(student_model):,} parameters")
print(f"Size reduction: {count_parameters(teacher_model) / count_parameters(student_model):.2f}x")

### 2.1 Implementing Feature-Based Distillation Loss

We'll create a custom loss function that combines:
1. Language modeling loss (next token prediction)
2. Hidden state distillation loss (MSE between teacher and student hidden states)
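
Written out, the combined objective can be sketched as follows. The weight $\lambda$ and the projection matrix $W$ are assumptions introduced here for illustration: adding the two terms directly corresponds to $\lambda = 1$, and a learned projection is one common way to compare a 768-dimensional student state with a 1024-dimensional teacher state.

$$
\mathcal{L} = \mathcal{L}_{\text{LM}} + \lambda \, \mathcal{L}_{\text{feat}},
\qquad
\mathcal{L}_{\text{feat}} = \frac{1}{N \, d_T} \sum_{i=1}^{N} \left\lVert W h^{S}_{i} - h^{T}_{i} \right\rVert_2^2
$$

Here $N$ is the number of token positions in the batch, $d_T$ is the teacher's hidden size, and $h^{S}_{i}$, $h^{T}_{i}$ are the last-layer hidden states of the student and teacher at position $i$. The normalization by $N \, d_T$ matches the default mean reduction of `F.mse_loss`.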

In [None]:
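# Custom model for feature-based distillation
class FeatureDistillationModel(nn.Module):
    def __init__(self, student_model, teacher_hidden_size=1024):
        super().__init__()
        self.student = student_model
        # gpt2 (768-dim) and gpt2-medium (1024-dim) have different hidden sizes,
        # so a learned linear layer projects student states to the teacher's width.
        # teacher_hidden_size defaults to gpt2-medium's hidden size.
        self.projection = nn.Linear(
            student_model.config.hidden_size, teacher_hidden_size
        ).to(student_model.device)
    
    def forward(self, input_ids, attention_mask=None, labels=None, teacher_hidden_states=None):
        # Forward pass through student model
        outputs = self.student(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=True
        )
        
        # Standard language modeling loss (None when no labels are passed)
        lm_loss = outputs.loss
        
        # Feature distillation loss
        feature_loss = torch.zeros((), device=input_ids.device)
        if teacher_hidden_states is not None:
            # We'll use the last hidden state for distillation
            # For more comprehensive distillation, you could use multiple layers
            teacher_last_hidden = teacher_hidden_states[-1]
            student_last_hidden = self.projection(outputs.hidden_states[-1])
            
            # MSE loss between hidden states, ignoring padded positions
            squared_error = (student_last_hidden - teacher_last_hidden) ** 2
            if attention_mask is not None:
                mask = attention_mask.unsqueeze(-1).to(squared_error.dtype)
                feature_loss = (squared_error * mask).sum() / (
                    mask.sum().clamp(min=1) * squared_error.size(-1)
                )
            else:
                feature_loss = squared_error.mean()
        
        # Combined loss
        loss = lm_loss + feature_loss if lm_loss is not None else feature_loss
        
        return {
            "loss": loss,
            "lm_loss": lm_loss,
            "feature_loss": feature_loss,
            "logits": outputs.logits,
            "hidden_states": outputs.hidden_states
        }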
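
The comment about using multiple layers can be made concrete with a layer mapping. The sketch below pairs each student layer with an evenly spaced teacher layer; the function name and the even-spacing rule are assumptions chosen for illustration, and other mappings are possible. Indices refer to positions in a `hidden_states` tuple, where index 0 is the embedding output.

```python
def map_student_to_teacher_layers(num_student_layers, num_teacher_layers):
    """Pair each student layer with an evenly spaced teacher layer.

    Index 0 is the embedding output; index k is the output of transformer block k.
    """
    if num_teacher_layers % num_student_layers != 0:
        raise ValueError(
            "this simple scheme assumes the teacher depth is a multiple of the student depth"
        )
    stride = num_teacher_layers // num_student_layers
    return [(s, s * stride) for s in range(num_student_layers + 1)]

# gpt2 has 12 transformer blocks, gpt2-medium has 24
layer_pairs = map_student_to_teacher_layers(12, 24)
print(layer_pairs[:3], layer_pairs[-1])
```

A multi-layer feature loss would then average the per-pair losses. Each pair would still need its widths matched, for example with one projection per pair.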

## 3. Generating Synthetic Data for Distillation

One powerful approach to distillation is to generate synthetic data using the teacher model. This allows us to create a large, diverse dataset tailored to our needs.

In [None]:
# Function to generate synthetic data using the teacher model
def generate_synthetic_data(model, tokenizer, prompts, num_samples=100, max_length=128):
    model.eval()
    generated_texts = []
    
    # Cycle through the prompts until num_samples texts have been generated
    for i in tqdm(range(num_samples)):
        prompt = prompts[i % len(prompts)]
        
        # Tokenize the prompt
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        
        # Generate text
        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=max_length,  # total length, prompt included
                do_sample=True,
                top_p=0.9,
                temperature=0.7,
                num_return_sequences=1,
                pad_token_id=tokenizer.eos_token_id  # GPT-2 defines no pad token
            )
        
        # Decode the generated text
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        generated_texts.append(generated_text)
    
    return generated_texts
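
To build intuition for the `temperature` and `top_p` settings used above, here is a toy, pure-Python sketch of both steps on a made-up four-token distribution. It illustrates the idea only and is not how `generate` is implemented internally; the function names and logits are hypothetical.

```python
import math

def apply_temperature(logits, temperature):
    """Softmax over logits divided by the temperature."""
    scaled = [x / temperature for x in logits]
    max_scaled = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - max_scaled) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def top_p_filter(probs, top_p):
    """Keep the smallest set of most likely tokens whose total mass reaches
    top_p, then renormalize. Returns {token_index: probability}."""
    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    kept, cumulative = [], 0.0
    for index, prob in ranked:
        kept.append((index, prob))
        cumulative += prob
        if cumulative >= top_p:
            break
    mass = sum(prob for _, prob in kept)
    return {index: prob / mass for index, prob in kept}

toy_logits = [2.0, 1.0, 0.5, -1.0]  # made-up scores for a 4-token vocabulary
probs_t1 = apply_temperature(toy_logits, 1.0)
probs_t07 = apply_temperature(toy_logits, 0.7)
nucleus = top_p_filter(probs_t07, 0.9)
print(probs_t1, probs_t07, nucleus)
```

A temperature below 1 sharpens the distribution toward the most likely token, and the top-p step then discards the low-probability tail before sampling.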

In [None]:
# Define some diverse prompts for generation
prompts = [
    "The best way to learn a new language is",
    "In the future, artificial intelligence will",
    "Climate change is affecting our planet by",
    "The most important scientific discovery of the last century was",
    "When it comes to healthy eating habits,"
]

# Generate synthetic data (small sample for demonstration)
print("Generating synthetic data...")
synthetic_texts = generate_synthetic_data(teacher_model, tokenizer, prompts, num_samples=10)

# Display some examples
for i, text in enumerate(synthetic_texts[:3]):
    print(f"Example {i+1}:\n{text}\n")

In [None]:
# Create a dataset from the synthetic texts
from datasets import Dataset

synthetic_dataset = Dataset.from_dict({"text": synthetic_texts})
print(synthetic_dataset)

### 3.1 Preparing the Dataset for Feature-Based Distillation

Now, we'll tokenize the synthetic data and extract the teacher's hidden states.

In [None]:
# Tokenize the synthetic dataset and drop the raw "text" column,
# which the data collator cannot convert to tensors
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)

tokenized_dataset = synthetic_dataset.map(tokenize_function, batched=True, remove_columns=["text"])

# Create a data collator
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False  # We're doing causal language modeling, not masked language modeling
)

In [None]:
# Function to extract teacher hidden states
def extract_teacher_hidden_states(model, dataset, batch_size=2):
    model.eval()
    all_hidden_states = []
    
    # Create a dataloader (no shuffling, so batch i always holds the same examples)
    from torch.utils.data import DataLoader
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=data_collator
    )
    
    # Extract hidden states batch by batch
    with torch.no_grad():
        for batch in tqdm(dataloader):
            # Move batch to device
            batch = {k: v.to(device) for k, v in batch.items()}
            
            # Forward pass
            outputs = model(**batch, output_hidden_states=True)
            
            # Store only the last layer, the only one the distillation loss reads;
            # keeping every gpt2-medium layer would use about 25x more memory
            all_hidden_states.append([outputs.hidden_states[-1].cpu()])
    
    return all_hidden_states

# Extract teacher hidden states (this would be done for the full dataset in practice)
print("Extracting teacher hidden states...")
teacher_hidden_states = extract_teacher_hidden_states(teacher_model, tokenized_dataset)
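
Caching hidden states costs memory in proportion to the number of examples, the sequence length, the hidden size, and the number of layers kept. The back-of-the-envelope sketch below compares keeping every gpt2-medium hidden-state tensor (24 blocks plus the embedding output) with keeping only the last one, assuming float32 values and the 128-token padding used above. The helper name is hypothetical.

```python
def hidden_state_cache_bytes(num_examples, seq_len, hidden_size, num_tensors, bytes_per_value=4):
    """Bytes needed to cache hidden states as float32 by default."""
    return num_examples * seq_len * hidden_size * num_tensors * bytes_per_value

MIB = 1024 ** 2

# gpt2-medium: 24 blocks + embedding output = 25 tensors of width 1024
all_layers = hidden_state_cache_bytes(10, 128, 1024, num_tensors=25)
last_layer_only = hidden_state_cache_bytes(10, 128, 1024, num_tensors=1)

print(f"All layers, 10 examples: {all_layers / MIB:.1f} MiB")
print(f"Last layer only, 10 examples: {last_layer_only / MIB:.1f} MiB")
```

The cost grows linearly with dataset size, so for large synthetic datasets it is often more practical to run the teacher during training than to cache its outputs up front.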

## 4. Custom Trainer for Feature-Based Distillation

Now, let's create a custom trainer that incorporates the teacher's hidden states.

In [None]:
# Custom trainer for feature-based distillation
from torch.utils.data import SequentialSampler

class FeatureDistillationTrainer(Trainer):
    def __init__(self, *args, teacher_hidden_states=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher_hidden_states = teacher_hidden_states
        self._current_batch_idx = 0
    
    def _get_train_sampler(self, *args, **kwargs):
        # Teacher states were cached in dataset order, so iterate in that same
        # order instead of Trainer's default shuffling. This overrides a private
        # Trainer hook whose signature differs between versions, hence *args/**kwargs.
        # It also assumes the training batch size matches the extraction batch size.
        return SequentialSampler(self.train_dataset)
    
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments (such as num_items_in_batch)
        # that newer transformers versions pass to compute_loss
        batch_idx = self._current_batch_idx
        
        # Get teacher hidden states for this batch
        if self.teacher_hidden_states and batch_idx < len(self.teacher_hidden_states):
            teacher_hidden = self.teacher_hidden_states[batch_idx]
            teacher_hidden = [h.to(inputs["input_ids"].device) for h in teacher_hidden]
        else:
            teacher_hidden = None
        
        # Increment batch index for next call, wrapping around at the end of each epoch
        if self.teacher_hidden_states:
            self._current_batch_idx = (batch_idx + 1) % len(self.teacher_hidden_states)
        
        # Forward pass with feature distillation
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            labels=inputs.get("labels"),
            teacher_hidden_states=teacher_hidden
        )
        
        loss = outputs["loss"]
        
        return (loss, outputs) if return_outputs else loss
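
Looking up cached teacher states by batch position only works when training visits the batches in exactly the order used during extraction. A more robust alternative, sketched below with toy stand-ins, keys the cache by example index so that any batch order finds the right features. All names here are hypothetical, and the strings stand in for real tensors.

```python
import random

# Toy stand-ins: one cached "teacher feature" per example, keyed by example index
teacher_cache = {idx: f"teacher_features_{idx}" for idx in range(6)}

def make_batches(indices, batch_size):
    """Split a list of example indices into consecutive batches."""
    return [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]

def lookup_teacher_features(batch_indices, cache):
    """Fetch the cached features for each example in the batch."""
    return [cache[idx] for idx in batch_indices]

rng = random.Random(0)
shuffled = list(range(6))
rng.shuffle(shuffled)  # simulate a shuffling sampler

for batch in make_batches(shuffled, batch_size=2):
    features = lookup_teacher_features(batch, teacher_cache)
    # Every example is paired with its own cached features, whatever the order
    assert features == [f"teacher_features_{idx}" for idx in batch]
```

In a real pipeline this would mean carrying an index column through the dataset and collator, which is not shown here.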

## 5. Training with Feature-Based Distillation

Now, let's train the student model using feature-based distillation.

In [None]:
# Wrap the student model in our feature distillation model
feature_distillation_model = FeatureDistillationModel(student_model)

# Define training arguments
training_args = TrainingArguments(
    output_dir="./results/gpt2-feature-distilled",
    learning_rate=5e-5,
    per_device_train_batch_size=2,  # Small batch size for demonstration
    num_train_epochs=3,
    weight_decay=0.01,
    logging_steps=10,
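    # The wrapper is a plain nn.Module around GPT-2's tied embedding weights, which
    # Trainer's default safetensors checkpointing can reject; skip checkpoints here
    # and call student_model.save_pretrained(...) after training if needed
    save_strategy="no",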
    report_to="none",
)

# Create the feature distillation trainer
feature_trainer = FeatureDistillationTrainer(
    model=feature_distillation_model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
    teacher_hidden_states=teacher_hidden_states
)

In [None]:
# Train the model with feature distillation
# Note: In practice, you would use a larger dataset and more epochs
feature_trainer.train()

## 6. Evaluating the Distilled Model

Let's compare the teacher and student models on text generation.

In [None]:
# Function to generate text with a model
def generate_text(model, tokenizer, prompt, max_length=100):
    model.eval()  # training leaves the student in train mode (dropout active)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_length=max_length,
            do_sample=True,
            top_p=0.9,
            temperature=0.7,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id  # GPT-2 defines no pad token
        )
    
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

In [None]:
# Test prompts
test_prompts = [
    "The future of artificial intelligence is",
    "The most effective way to combat climate change is",
    "When considering the ethics of technology, it's important to"
]

# Compare teacher and student outputs
for prompt in test_prompts:
    print(f"Prompt: {prompt}")
    
    # Teacher output
    teacher_output = generate_text(teacher_model, tokenizer, prompt)
    print(f"\nTeacher:\n{teacher_output}")
    
    # Student output
    student_output = generate_text(student_model, tokenizer, prompt)
    print(f"\nStudent:\n{student_output}")
    
    print("\n" + "-"*50 + "\n")
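
Side-by-side generations are qualitative. A common quantitative complement is perplexity, the exponential of the mean negative log-likelihood per token. The sketch below computes it from a list of per-token log-probabilities; the values are made up for illustration. For a Hugging Face causal LM called with `labels`, `outputs.loss` is already the mean cross-entropy, so `exp(loss)` gives the same quantity.

```python
import math

def perplexity(token_log_probs):
    """Perplexity = exp of the mean negative log-likelihood per token."""
    if not token_log_probs:
        raise ValueError("need at least one token")
    mean_nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(mean_nll)

# Made-up natural-log probabilities: a model that is always uniform over 4 choices
uniform_over_4 = [math.log(0.25)] * 5
# A more confident model assigns higher probability to each observed token
confident = [math.log(0.8)] * 5

print(perplexity(uniform_over_4), perplexity(confident))
```

Lower is better: comparing teacher and student perplexity on the same held-out text gives a single number to track across distillation runs.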

## 7. Measuring Efficiency Gains

Let's measure the inference speed and memory usage of both models.

In [None]:
import time

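# Function to measure inference time
def measure_inference_time(model, tokenizer, prompt, num_runs=10):
    model.eval()
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    generate_kwargs = dict(
        attention_mask=inputs.attention_mask,
        max_length=50,
        pad_token_id=tokenizer.eos_token_id,
    )
    
    # Warm-up run
    with torch.no_grad():
        _ = model.generate(inputs.input_ids, **generate_kwargs)
    
    # CUDA kernels run asynchronously, so synchronize before reading the clock
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    
    # Timed runs
    start_time = time.perf_counter()
    with torch.no_grad():
        for _ in range(num_runs):
            _ = model.generate(inputs.input_ids, **generate_kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    end_time = time.perf_counter()
    
    avg_time = (end_time - start_time) / num_runs
    return avg_time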

# Measure inference time for both models
prompt = "The future of artificial intelligence is"
teacher_time = measure_inference_time(teacher_model, tokenizer, prompt)
student_time = measure_inference_time(student_model, tokenizer, prompt)

print(f"Teacher inference time: {teacher_time:.4f} seconds")
print(f"Student inference time: {student_time:.4f} seconds")
print(f"Speedup: {teacher_time / student_time:.2f}x")
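
A single average can be skewed by one slow run. As a hedged alternative, the generic helper below times any callable and reports the median and the best run; `benchmark` is a hypothetical name, and the workload in the usage line is a placeholder rather than a model call.

```python
import statistics
import time

def benchmark(fn, num_runs=10, warmup=1):
    """Return (median, best) wall-clock seconds over num_runs calls to fn."""
    for _ in range(warmup):
        fn()  # warm-up calls are not timed
    timings = []
    for _ in range(num_runs):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings), min(timings)

# Placeholder workload; with a model, fn would wrap a generate() call
median_s, best_s = benchmark(lambda: sum(range(100_000)), num_runs=5)
print(f"median: {median_s:.6f}s, best: {best_s:.6f}s")
```

When timing GPU work this way, the wrapped callable should end with `torch.cuda.synchronize()` so that the clock is read only after the kernels have finished.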

In [None]:
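# Measure memory usage
def measure_memory_usage(model, tokenizer, prompt):
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        # Both models stay resident on the GPU, so the raw peak would include the
        # other model's weights. Measure the extra memory used during generation
        # and add this model's own weights instead.
        baseline = torch.cuda.memory_allocated()
        
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            _ = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=50,
                pad_token_id=tokenizer.eos_token_id
            )
        
        generation_overhead = torch.cuda.max_memory_allocated() - baseline
        weight_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
        memory_usage = (weight_bytes + generation_overhead) / 1024**2  # MB
        return memory_usage
    else:
        return "N/A (CUDA not available)"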

# Measure memory usage for both models
if torch.cuda.is_available():
    teacher_memory = measure_memory_usage(teacher_model, tokenizer, prompt)
    student_memory = measure_memory_usage(student_model, tokenizer, prompt)
    
    print(f"Teacher memory usage: {teacher_memory:.2f} MB")
    print(f"Student memory usage: {student_memory:.2f} MB")
    print(f"Memory reduction: {teacher_memory / student_memory:.2f}x")
else:
    print("Memory usage measurement requires CUDA.")
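
Even without a GPU, the memory taken by the weights alone can be estimated from the parameter count and the bytes per parameter. The sketch below uses rounded parameter counts as assumptions; `count_parameters` from earlier in the notebook reports the exact values.

```python
def weight_memory_mib(num_parameters, bytes_per_parameter):
    """Memory needed to hold the weights, in MiB."""
    return num_parameters * bytes_per_parameter / 1024 ** 2

# Rounded parameter counts, used here only for a rough estimate
approx_params = {"gpt2": 124_000_000, "gpt2-medium": 355_000_000}

for name, num_params in approx_params.items():
    fp32 = weight_memory_mib(num_params, 4)  # float32: 4 bytes per parameter
    fp16 = weight_memory_mib(num_params, 2)  # float16: 2 bytes per parameter
    print(f"{name}: {fp32:.0f} MiB in fp32, {fp16:.0f} MiB in fp16")
```

Activations and the key/value cache used during generation come on top of this, which is why a measured peak is higher than the weights-only estimate.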

## Conclusion

In this notebook, we've explored advanced knowledge distillation techniques for language models:

1. **Feature-Based Distillation**: We implemented distillation based on the internal representations (hidden states) of the teacher model, not just the final outputs.

2. **Synthetic Data Generation**: We used the teacher model to generate synthetic data for distillation, which can be particularly useful when suitable training text is scarce.

3. **Efficiency Evaluation**: We measured the inference speed and memory usage improvements achieved through distillation.

The student model is roughly a third the size of the teacher, and the timing and memory measurements above quantify how much cheaper it is to run. The side-by-side generations give only a qualitative impression of quality: with a ten-example synthetic dataset, this notebook demonstrates the mechanics of feature-based distillation but does not establish how much of the teacher's capability the student retains.

In practice, you would want to:
- Use a larger synthetic dataset
- Train for more epochs
- Experiment with different distillation objectives and hyperparameters
- Evaluate on standardized benchmarks

Knowledge distillation is an active area of research, and new techniques are constantly being developed to improve the efficiency and performance of language models.