# Knowledge Distillation for LLM Optimization

This notebook applies knowledge distillation to compress DistilBERT for deployment on edge devices. Knowledge distillation is a technique in which a smaller model (the student) is trained to mimic the outputs of a larger, more capable model (the teacher). Here the teacher is the 6-layer DistilBERT and the student keeps only 2 transformer layers.

In [None]:
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Make the project root importable so the `from src...` imports below resolve.
# The root is taken to be the parent of the current working directory (in
# Jupyter this is normally the notebook's own directory). It is resolved to an
# absolute path so a later os.chdir() does not break imports. The append is
# skipped when the path is already present, so re-running this cell does not
# accumulate duplicate sys.path entries.
_PROJECT_ROOT = str(Path("..").resolve())
if _PROJECT_ROOT not in sys.path:
    sys.path.append(_PROJECT_ROOT)

In [None]:
# Import our modules
from src.models.distilbert import load_distilbert_model, get_device
from src.models.knowledge_distillation import create_student_model, DistillationTrainer
from src.data.dataset import load_and_prepare_data, prepare_batch_for_model
from src.utils.metrics import measure_performance, save_metrics, print_metrics

In [None]:
# Configuration
MODEL_NAME = "distilbert-base-uncased"
NUM_LABELS = 2  # Binary classification

# Dataset configuration
DATASET_NAME = "glue"
DATASET_CONFIG = "sst2"  # Stanford Sentiment Treebank
BATCH_SIZE = 16
MAX_LENGTH = 128  # Passed as max_length to load_and_prepare_data

# Student model configuration
STUDENT_NUM_LAYERS = 2  # Reduced from 6 in the original DistilBERT

# Training configuration
LEARNING_RATE = 5e-5
EPOCHS = 3
TEMPERATURE = 2.0  # Distillation temperature, passed to DistillationTrainer
ALPHA = 0.5  # Balance between hard and soft loss
EVAL_EVERY = 100  # Evaluate every N steps

# Fail fast on hyperparameter values that would otherwise only surface as
# errors (or silently wrong results) deep inside training or plotting.
# NOTE(review): the ALPHA range check assumes it is a mixing weight between the
# hard and soft losses (alpha / 1 - alpha); confirm in DistillationTrainer.
if STUDENT_NUM_LAYERS < 1:
    raise ValueError(f"STUDENT_NUM_LAYERS must be >= 1, got {STUDENT_NUM_LAYERS}")
if TEMPERATURE <= 0:
    raise ValueError(f"TEMPERATURE must be > 0, got {TEMPERATURE}")
if not 0.0 <= ALPHA <= 1.0:
    raise ValueError(f"ALPHA must be in [0, 1], got {ALPHA}")
if EVAL_EVERY < 1:
    raise ValueError(f"EVAL_EVERY must be >= 1, got {EVAL_EVERY}")

# Output path
OUTPUT_DIR = Path("../outputs")
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

# Device
DEVICE = get_device()
print(f"Using device: {DEVICE}")

## 1. Load Teacher Model and Data

First, we'll load the teacher model (DistilBERT) and prepare the dataset.

In [None]:
# Teacher model (configured for NUM_LABELS classes) together with its tokenizer.
# NOTE(review): MODEL_NAME is the base checkpoint, not an SST-2 fine-tuned one.
# Unless load_distilbert_model loads fine-tuned weights, the teacher's
# classification layer may start untrained and give the student little to learn
# from. Confirm.
print(f"Loading pre-trained {MODEL_NAME} model as teacher...")
teacher_model, tokenizer = load_distilbert_model(MODEL_NAME, NUM_LABELS)

# Dataset: the helper returns the tokenizer again alongside the train/eval loaders.
print("Loading dataset and preparing data loaders...")
data_options = {
    "dataset_name": DATASET_NAME,
    "dataset_config": DATASET_CONFIG,
    "batch_size": BATCH_SIZE,
    "max_length": MAX_LENGTH,
}
tokenizer, train_dataloader, eval_dataloader = load_and_prepare_data(tokenizer, **data_options)

## 2. Create the Student Model

Now, we'll create a smaller student model with fewer transformer layers.

In [None]:
# Build the student by shrinking the teacher to STUDENT_NUM_LAYERS transformer layers.
print(f"Creating student model with {STUDENT_NUM_LAYERS} layers (reduced from 6 in original)...")
student_model = create_student_model(
    teacher_model=teacher_model,
    num_layers=STUDENT_NUM_LAYERS,
    num_labels=NUM_LABELS,
)


def _count_params(model):
    """Return the total number of scalar elements across all of ``model``'s parameters."""
    return sum(tensor.numel() for tensor in model.parameters())


# Report how much smaller the student is than the teacher, by parameter count.
teacher_params = _count_params(teacher_model)
student_params = _count_params(student_model)
reduction_pct = (1 - student_params / teacher_params) * 100

print(f"Teacher model parameters: {teacher_params:,}")
print(f"Student model parameters: {student_params:,}")
print(f"Reduction: {reduction_pct:.2f}%")

## 3. Train the Student Model using Knowledge Distillation

We'll train the student to mimic the teacher's behavior using knowledge distillation.

In [None]:
def prepare_batch(batch, device):
    """Prepare one dataloader batch for the model on ``device``.

    Thin adapter that delegates to ``prepare_batch_for_model`` (the same helper
    the baseline uses), so the trainer and the metrics code share one callback.
    """
    return prepare_batch_for_model(batch, device)


# Only the student's parameters are handed to the optimizer, so this optimizer
# never updates the teacher.
optimizer = torch.optim.AdamW(params=student_model.parameters(), lr=LEARNING_RATE)

# Pair teacher and student. The hard/soft loss weighting (ALPHA) is supplied
# later, in the train() call.
trainer = DistillationTrainer(
    student_model=student_model,
    teacher_model=teacher_model,
    temperature=TEMPERATURE,
)

In [None]:
# Run distillation. ALPHA balances the hard and soft losses, and evaluation
# happens every EVAL_EVERY steps (see the configuration cell). The returned
# stats feed the plots below.
print("Starting knowledge distillation training...")
train_kwargs = dict(
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    optimizer=optimizer,
    device=DEVICE,
    epochs=EPOCHS,
    alpha=ALPHA,
    eval_every=EVAL_EVERY,
    prepare_batch_fn=prepare_batch,
)
training_stats = trainer.train(**train_kwargs)

## 4. Evaluate and Compare Teacher vs Student

Compare the original teacher model (using the baseline metrics saved in `OUTPUT_DIR`) with our distilled student model in terms of size, latency, and accuracy.

In [None]:
# Plot training loss and evaluation accuracy
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 5))

# Training loss, as recorded by the trainer
ax1.plot(training_stats['train_losses'])
ax1.set_title('Training Loss')
ax1.set_xlabel('Training Steps')
ax1.set_ylabel('Loss')
ax1.grid(True)

# Evaluation accuracy.
# Use the step numbers recorded by the trainer when it exposes them. The key
# 'eval_steps' is an assumption; verify it against DistillationTrainer.
# Otherwise, fall back to assuming one evaluation every EVAL_EVERY steps. That
# fallback mislabels the x-axis if the trainer also evaluates at epoch
# boundaries or resets its step counter each epoch.
eval_accuracies = training_stats['eval_accuracies']
eval_steps = training_stats.get('eval_steps') if isinstance(training_stats, dict) else None
if eval_steps is None or len(eval_steps) != len(eval_accuracies):
    eval_steps = [EVAL_EVERY * (i + 1) for i in range(len(eval_accuracies))]
ax2.plot(eval_steps, eval_accuracies)
ax2.set_title('Evaluation Accuracy')
ax2.set_xlabel('Training Steps')
ax2.set_ylabel('Accuracy')
ax2.grid(True)

plt.tight_layout()
# Keep a copy of the training curves next to the other saved outputs.
plt.savefig(OUTPUT_DIR / 'knowledge_distillation_training.png')
plt.show()

In [None]:
# Load the baseline (teacher) metrics to compare with. They are read from disk
# rather than recomputed here, so they must have been saved beforehand.
baseline_metrics_path = OUTPUT_DIR / "baseline_metrics.pt"
if not baseline_metrics_path.is_file():
    raise FileNotFoundError(
        f"Baseline metrics not found at {baseline_metrics_path}; "
        "generate and save the baseline metrics before running this comparison."
    )
# map_location="cpu" lets metrics saved on a GPU machine load on a CPU-only one.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True.
# That is fine for plain Python numbers/strings but may reject other objects
# (e.g. NumPy scalars). Confirm what save_metrics writes.
baseline_metrics = torch.load(baseline_metrics_path, map_location="cpu")

# Measure student performance. Switch the student to eval mode first, because
# the trainer may leave it in training mode (dropout active) after its last step.
print("Measuring student model performance...")
student_model.eval()
student_metrics = measure_performance(student_model, eval_dataloader, DEVICE, prepare_batch)

# Save student metrics
save_metrics(student_metrics, file_path=OUTPUT_DIR / "student_metrics.pt")

In [None]:
# Print the teacher (baseline) report followed by the distilled-student report,
# both tagged with the same dataset description.
dataset_info = f"Text Classification - {DATASET_NAME}/{DATASET_CONFIG}"
student_label = f"Distilled {MODEL_NAME} ({STUDENT_NUM_LAYERS} layers)"

report_specs = [
    ("\n===== Teacher Model (Baseline) =====", baseline_metrics, MODEL_NAME),
    ("\n===== Student Model (Distilled) =====", student_metrics, student_label),
]
for banner, metrics_dict, model_label in report_specs:
    print(banner)
    print_metrics(metrics_dict, model_name=model_label, dataset_info=dataset_info)

In [None]:
# Calculate key metrics for comparison.
# Optimization targets: the student should be at least 50% smaller (by
# model_size_mb) and keep at least 90% of the teacher's accuracy.
SIZE_REDUCTION_TARGET = 0.5
ACCURACY_RETENTION_TARGET = 0.9


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is zero.

    NaN compares False against any target, so a degenerate metric is reported
    as "Not achieved" instead of crashing the notebook.
    """
    return numerator / denominator if denominator else float('nan')


size_reduction = 1 - _safe_ratio(student_metrics['model_size_mb'], baseline_metrics['model_size_mb'])
# The latency ratio is unit-free. Use the millisecond field, the same one the
# plots and the summary use.
speed_improvement = _safe_ratio(baseline_metrics['avg_latency_ms'], student_metrics['avg_latency_ms'])
accuracy_retention = _safe_ratio(student_metrics['accuracy'], baseline_metrics['accuracy'])

print("\n===== Performance Comparison =====")
print(f"Size reduction: {size_reduction * 100:.2f}%")
print(f"Speed improvement: {speed_improvement:.2f}x")
print(f"Accuracy retention: {accuracy_retention * 100:.2f}%")

# Check against optimization targets
print("\n===== Optimization Targets =====")
print(f"Size target ({SIZE_REDUCTION_TARGET:.0%} reduction): {'✅ Achieved' if size_reduction >= SIZE_REDUCTION_TARGET else '❌ Not achieved'}")
print(f"Accuracy target ({ACCURACY_RETENTION_TARGET:.0%} retention): {'✅ Achieved' if accuracy_retention >= ACCURACY_RETENTION_TARGET else '❌ Not achieved'}")

In [None]:
# Plot size, speed, and accuracy comparison
labels = ['Teacher', 'Student']

# Prepare data for comparison
sizes = [baseline_metrics['model_size_mb'], student_metrics['model_size_mb']]
latencies = [baseline_metrics['avg_latency_ms'], student_metrics['avg_latency_ms']]
accuracies = [baseline_metrics['accuracy'] * 100, student_metrics['accuracy'] * 100]
params = [baseline_metrics['num_parameters'] / 1e6, student_metrics['num_parameters'] / 1e6]


def _bar_with_values(ax, categories, values, title, ylabel, value_fmt):
    """Draw a bar chart on ``ax`` and annotate each bar with its formatted value.

    Labels are offset by 1% of the tallest bar rather than a fixed number of
    data units. A fixed offset is out of scale for small metrics such as
    millisecond latencies. The y-axis also gets 10% headroom so the labels are
    not clipped at the top of the plot.
    """
    ax.bar(categories, values)
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    offset = 0.01 * max(values, default=0)
    for x, value in enumerate(values):
        ax.text(x, value + offset, value_fmt.format(value), ha='center', va='bottom')
    ax.margins(y=0.1)


# Create bar chart comparisons
fig, axs = plt.subplots(2, 2, figsize=(15, 10))
_bar_with_values(axs[0, 0], labels, sizes, 'Model Size (MB)', 'Size (MB)', '{:.2f}')
_bar_with_values(axs[0, 1], labels, latencies, 'Inference Latency (ms)', 'Latency (ms)', '{:.2f}')
_bar_with_values(axs[1, 0], labels, accuracies, 'Accuracy (%)', 'Accuracy (%)', '{:.2f}%')
_bar_with_values(axs[1, 1], labels, params, 'Number of Parameters (Millions)', 'Parameters (M)', '{:.2f}M')

plt.tight_layout()
plt.savefig(OUTPUT_DIR / 'knowledge_distillation_comparison.png')
plt.show()

## 5. Save the Optimized Model

Save the student model for later use or deployment.

In [None]:
# Persist the student weights and the tokenizer into the same directory.
student_save_path = OUTPUT_DIR / "distilled_model"
student_save_path.mkdir(parents=True, exist_ok=True)

for artifact in (student_model, tokenizer):
    artifact.save_pretrained(student_save_path)

print(f"Student model saved to {student_save_path}")

## 6. Summary and Conclusion

Summarize the knowledge-distillation results and check them against the optimization targets: at least a 50% size reduction and at least 90% accuracy retention.

In [None]:
# Summary of knowledge distillation results: side-by-side size, accuracy and
# latency figures, a verdict against the 50% size / 90% accuracy targets, and
# suggested follow-ups.
base, dist = baseline_metrics, student_metrics

print("===== Knowledge Distillation Summary =====")
print(f"\nTeacher model: {base['num_parameters']:,} parameters, {base['model_size_mb']:.2f} MB")
print(f"Student model: {dist['num_parameters']:,} parameters, {dist['model_size_mb']:.2f} MB")
print(f"Size reduction: {size_reduction * 100:.2f}%")
print(f"\nTeacher accuracy: {base['accuracy'] * 100:.2f}%")
print(f"Student accuracy: {dist['accuracy'] * 100:.2f}%")
print(f"Accuracy retention: {accuracy_retention * 100:.2f}%")
print(f"\nTeacher inference time: {base['avg_latency_ms']:.2f} ms")
print(f"Student inference time: {dist['avg_latency_ms']:.2f} ms")
print(f"Speed improvement: {speed_improvement:.2f}x")

print("\nConclusion:")
size_met = size_reduction >= 0.5
accuracy_met = accuracy_retention >= 0.9
if size_met and accuracy_met:
    verdict = "✅ Knowledge distillation successfully achieved our optimization targets."
elif size_met:
    verdict = "⚠️ Size reduction target was met, but accuracy retention fell below the 90% target."
elif accuracy_met:
    verdict = "⚠️ Accuracy retention target was met, but size reduction fell below the 50% target."
else:
    verdict = "❌ Neither size reduction nor accuracy retention targets were met."
print(verdict)

print("\nNext steps could include:")
for suggestion in (
    "- Tuning knowledge distillation hyperparameters (temperature, alpha)",
    "- Combining distillation with other optimization approaches",
    "- Further reducing the student model size",
):
    print(suggestion)