# 🎓 Knowledge Distillation: From Teacher to Student

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/model-quantization/blob/main/distillation_demo.ipynb)

Knowledge Distillation (KD) is a model compression technique where a smaller model (**Student**) is trained to reproduce the behavior of a larger, pre-trained model (**Teacher**). 

Instead of training only on hard labels (a single correct class per example), the student also learns from the teacher's "soft targets": the probability distributions the teacher assigns over all classes. These soft targets carry extra information, such as which incorrect classes the teacher considers plausible.
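
To make the contrast concrete, the following dependency-free sketch compares a hard label with a soft target. The class names and teacher logits are made up for illustration; they do not come from the GPT-2 teacher used later in this notebook.

```python
import math

def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]  # subtract the max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]

classes = ["cat", "dog", "truck"]      # hypothetical classes
teacher_logits = [4.0, 2.5, -1.0]      # hypothetical teacher scores for a cat image

hard_label = [1.0, 0.0, 0.0]           # one-hot: only says "cat"
soft_target = softmax(teacher_logits)  # also ranks "dog" above "truck"

for name, hard, soft in zip(classes, hard_label, soft_target):
    print(f"{name:>5}: hard={hard:.2f}  soft={soft:.3f}")
```

The hard label treats "dog" and "truck" as equally wrong, while the soft target preserves the teacher's ranking of the incorrect classes.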

In [None]:
!pip install transformers datasets torch -q

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Config, GPT2Tokenizer

# 1. Load Teacher (GPT-2)
teacher_model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
teacher_model.eval()  # Inference mode: disables dropout (gradients are blocked later with torch.no_grad())

# 2. Initialize a much smaller Student
# GPT-2 has 12 layers. Let's make a 2-layer student.
student_config = GPT2Config(
    n_layer=2, 
    n_head=8, 
    n_embd=512, # smaller embedding
    vocab_size=tokenizer.vocab_size
)
student_model = GPT2LMHeadModel(student_config)

print(f"Teacher Parameters: {sum(p.numel() for p in teacher_model.parameters())/1e6:.1f}M")
print(f"Student Parameters: {sum(p.numel() for p in student_model.parameters())/1e6:.1f}M")
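
As a rough cross-check on the printed counts, the arithmetic below estimates the parameter count from the configuration alone. It is a sketch that assumes the standard GPT-2 block layout (fused QKV projection, 4x MLP expansion, biases on every linear layer), the default `n_positions=1024`, and an output head tied to the input embeddings. `estimate_gpt2_params` is a helper name introduced here, not part of `transformers`.

```python
def estimate_gpt2_params(n_layer, n_embd, vocab_size, n_positions=1024):
    """Estimate a GPT-2 parameter count, assuming tied input/output embeddings."""
    d = n_embd
    embeddings = vocab_size * d + n_positions * d  # token + position embeddings
    attention = (d * 3 * d + 3 * d) + (d * d + d)  # fused QKV projection + output projection
    mlp = (d * 4 * d + 4 * d) + (4 * d * d + d)    # expand to 4d, then project back to d
    layer_norms = 2 * (2 * d)                      # two LayerNorms per block (scale + shift)
    per_block = attention + mlp + layer_norms
    final_norm = 2 * d
    return embeddings + n_layer * per_block + final_norm

teacher_est = estimate_gpt2_params(n_layer=12, n_embd=768, vocab_size=50257)
student_est = estimate_gpt2_params(n_layer=2, n_embd=512, vocab_size=50257)
print(f"Teacher estimate: {teacher_est / 1e6:.1f}M")
print(f"Student estimate: {student_est / 1e6:.1f}M")
print(f"Compression ratio: {teacher_est / student_est:.1f}x")
```

Under these assumptions, most of the student's parameters sit in the token embedding table (50257 x 512, about 25.7M), so cutting depth alone shrinks the model less than the layer count suggests.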

## 🧠 The Distillation Loss
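We use **KL divergence** to measure how closely the student's predicted distribution matches the teacher's. A **temperature** (T) divides both sets of logits before the softmax; T > 1 flattens the distributions so that low-probability classes contribute more to the loss. The soft term is multiplied by T² so that its gradients stay on a scale comparable to the hard-label term.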
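
The effect of the temperature can be seen on a toy example. The sketch below uses only the standard library and made-up logits; `softmax_T` and `kl_divergence` are illustrative helpers, not library functions.

```python
import math

def softmax_T(logits, T=1.0):
    """Temperature-scaled softmax: divide logits by T before normalizing."""
    scaled = [x / T for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions with strictly positive q."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

teacher_logits = [5.0, 2.0, 0.5]  # hypothetical values
student_logits = [3.0, 3.0, 1.0]  # hypothetical values

for T in (1.0, 2.0, 5.0):
    p_teacher = softmax_T(teacher_logits, T)
    p_student = softmax_T(student_logits, T)
    kl = kl_divergence(p_teacher, p_student)
    print(f"T={T}: teacher={[round(p, 3) for p in p_teacher]}  KL={kl:.4f}")
```

As T grows, the teacher's largest probability shrinks and the remaining classes gain weight, which is the smoothing the temperature is meant to provide.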

In [None]:
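def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    # Causal LM: the logits at position i predict token i + 1, so drop the last
    # position from the logits and the first token from the labels.
    vocab_size = student_logits.size(-1)
    student_flat = student_logits[:, :-1, :].reshape(-1, vocab_size)
    teacher_flat = teacher_logits[:, :-1, :].reshape(-1, vocab_size)
    targets = labels[:, 1:].reshape(-1)

    # 1. Distillation loss (KL divergence between temperature-softened distributions).
    # Flattening to (batch * seq, vocab) makes "batchmean" average over tokens
    # instead of dividing only by the batch size. T * T rescales the gradients.
    soft_loss = nn.KLDivLoss(reduction="batchmean")(
        F.log_softmax(student_flat / T, dim=-1),
        F.softmax(teacher_flat / T, dim=-1)
    ) * (T * T)

    # 2. Hard-label loss (standard cross-entropy on next-token targets)
    hard_loss = F.cross_entropy(student_flat, targets)

    return alpha * soft_loss + (1 - alpha) * hard_loss

print("Distillation loss function defined.")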
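
Written out for a single token position (and ignoring the averaging over positions), the loss computed by this function is shown below, with $z_s$ and $z_t$ the student and teacher logits, $y$ the target token, and $\sigma$ the softmax:

$$
\mathcal{L} = \alpha \, T^2 \, \mathrm{KL}\!\left( \sigma(z_t / T) \,\|\, \sigma(z_s / T) \right) + (1 - \alpha) \, \mathrm{CE}\!\left( \sigma(z_s), y \right)
$$

With the defaults `T=2.0` and `alpha=0.5`, the two terms are weighted equally and the KL term is scaled by 4.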

## 🔄 Training Loop Mockup
The cell below runs a single optimization step on one sentence. In a real scenario, you would iterate over many batches from a large dataset (like WikiText or OpenWebText).

In [None]:
optimizer = torch.optim.AdamW(student_model.parameters(), lr=5e-5)

dummy_input = tokenizer("Knowledge distillation is a powerful technique", return_tensors="pt")
labels = dummy_input["input_ids"]

# Forward pass
with torch.no_grad():
    teacher_outputs = teacher_model(**dummy_input)
    teacher_logits = teacher_outputs.logits

student_outputs = student_model(**dummy_input)
student_logits = student_outputs.logits

loss = distillation_loss(student_logits, teacher_logits, labels)
optimizer.zero_grad()  # Clear gradients left over from any previous step
loss.backward()
optimizer.step()

print(f"Distillation Step Complete. Loss: {loss.item():.4f}")
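
For a sense of how the single step extends to a loop, here is a minimal sketch that trains over several batches and epochs. To stay self-contained it assumes tiny stand-in models (`TinyLM`, a hypothetical embedding-plus-linear module) and random token IDs instead of GPT-2 and a real corpus, so the numbers it prints say nothing about real distillation quality.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
VOCAB, DIM_TEACHER, DIM_STUDENT = 50, 32, 8  # hypothetical toy sizes

class TinyLM(nn.Module):
    """Stand-in language model: token embedding followed by a linear head."""
    def __init__(self, vocab, dim):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.head = nn.Linear(dim, vocab)

    def forward(self, input_ids):
        return self.head(self.embed(input_ids))  # (batch, seq, vocab)

teacher = TinyLM(VOCAB, DIM_TEACHER).eval()
student = TinyLM(VOCAB, DIM_STUDENT).train()
optimizer = torch.optim.AdamW(student.parameters(), lr=1e-2)

def kd_step(input_ids, T=2.0, alpha=0.5):
    """Run one distillation update on a batch and return the scalar loss."""
    with torch.no_grad():  # the teacher is never updated
        t_flat = teacher(input_ids)[:, :-1].reshape(-1, VOCAB)
    s_flat = student(input_ids)[:, :-1].reshape(-1, VOCAB)
    targets = input_ids[:, 1:].reshape(-1)  # next-token targets

    soft = F.kl_div(
        F.log_softmax(s_flat / T, dim=-1),
        F.softmax(t_flat / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)
    hard = F.cross_entropy(s_flat, targets)
    loss = alpha * soft + (1 - alpha) * hard

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# Random token IDs stand in for tokenized text (toy data, not a real dataset).
batches = [torch.randint(0, VOCAB, (4, 16)) for _ in range(5)]

losses = []
for epoch in range(20):
    for batch in batches:
        losses.append(kd_step(batch))

print(f"First-epoch mean loss: {sum(losses[:5]) / 5:.4f}")
print(f"Last-epoch mean loss:  {sum(losses[-5:]) / 5:.4f}")
```

With the real models, the same structure would apply: replace `TinyLM` with `teacher_model` and `student_model`, read `.logits` from their outputs, and draw batches from a tokenized dataset through a `DataLoader`.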