# 🧠 Knowledge Distillation: GPT-2 → TinyTransformer
This notebook demonstrates how to train a small transformer (student) to mimic GPT-2 (teacher) on a next-token prediction task using knowledge distillation.

**Task**: Language Modeling (Next Token Prediction)  
**Teacher**: GPT-2  
**Student**: Custom 4-layer Transformer  
**Dataset**: WikiText-2 (subset)

In [None]:
!pip install transformers datasets

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from tqdm import tqdm
# Run on the GPU when PyTorch can see a CUDA device; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

## 🔗 Load GPT-2 Teacher

In [None]:
# Tokenizer and teacher are both loaded from the same pretrained checkpoint.
checkpoint = 'gpt2'

tokenizer = GPT2Tokenizer.from_pretrained(checkpoint)
# Reuse the end-of-sequence token as the padding token so fixed-length
# (padded) samples can be built later on.
tokenizer.pad_token = tokenizer.eos_token

# The teacher is only queried for logits, so it is moved to the target
# device and switched to evaluation mode.
teacher = GPT2LMHeadModel.from_pretrained(checkpoint)
teacher = teacher.to(device)
teacher.eval()

## 📚 Load and Preprocess WikiText-2

In [None]:
raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

# Keep only lines whose stripped length exceeds 30 characters, then cap the
# training subset at the first 5000 such lines.
long_enough = filter(lambda line: len(line.strip()) > 30,
                     raw_dataset['train']['text'])
train_texts = list(long_enough)[:5000]

class WikiDataset(Dataset):
    """Fixed-length next-token-prediction samples built from raw text lines.

    Every text is tokenized and truncated/padded to ``seq_len + 1`` ids, then
    split into ``(inputs, targets)`` where ``targets`` is the id sequence
    shifted left by one position.

    Targets that fall on padding are replaced by ``IGNORE_INDEX`` so that a
    cross-entropy loss using the default ``ignore_index`` does not train the
    model to predict padding. The padding id itself cannot be used for this
    because the pad token may be the same id as a real token (e.g. EOS).

    Args:
        texts: iterable of strings, one sample per string.
        tokenizer: callable returning a mapping with ``'input_ids'`` and
            ``'attention_mask'`` tensors of shape ``(1, seq_len + 1)`` when
            called with ``return_tensors="pt"``.
        seq_len: number of input positions per sample.
    """

    # Matches the default ``ignore_index`` of ``torch.nn.functional.cross_entropy``.
    IGNORE_INDEX = -100

    def __init__(self, texts, tokenizer, seq_len=32):
        self.samples = []
        for text in texts:
            encoded = tokenizer(text, return_tensors="pt", truncation=True,
                                max_length=seq_len + 1, padding="max_length")
            ids = encoded['input_ids'].squeeze(0)
            is_real = encoded['attention_mask'].squeeze(0).bool()

            # Clone before masking: ids[1:] is a view that shares storage
            # with the inputs ids[:-1].
            targets = ids[1:].clone()
            targets[~is_real[1:]] = self.IGNORE_INDEX
            self.samples.append((ids[:-1], targets))

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the ``(inputs, targets)`` pair stored at ``idx``."""
        return self.samples[idx]

# Tokenize the filtered texts once up front, then serve them in shuffled
# mini-batches of 8 samples.
train_ds = WikiDataset(texts=train_texts, tokenizer=tokenizer)
train_dl = DataLoader(dataset=train_ds, batch_size=8, shuffle=True)

## 🧑‍🎓 Define TinyTransformer Student

In [None]:
class TinyTransformer(nn.Module):
    """Small causal Transformer language model used as the distillation student.

    Token ids are embedded, combined with learned position embeddings, passed
    through a stack of ``nn.TransformerEncoderLayer`` blocks under a causal
    attention mask, and projected back to vocabulary logits.

    The causal mask is essential for next-token prediction: without it each
    position can attend to the very tokens it is supposed to predict.

    Args:
        vocab_size: number of token ids (size of the input embedding and of
            the output logits).
        d_model: embedding / hidden width.
        nhead: attention heads per layer.
        num_layers: number of encoder layers.
        ff_dim: hidden width of each layer's feed-forward sub-block.
        max_len: longest sequence the position embedding table supports.
    """

    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=4, ff_dim=512,
                 max_len=1024):
        super().__init__()
        self.max_len = max_len
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        encoder = nn.TransformerEncoderLayer(d_model, nhead, ff_dim)
        self.transformer = nn.TransformerEncoder(encoder, num_layers)
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Return next-token logits of shape ``(batch, seq, vocab_size)``.

        Args:
            x: integer token ids of shape ``(batch, seq)``.

        Raises:
            ValueError: if the sequence is longer than ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len={self.max_len}"
            )

        positions = torch.arange(seq_len, device=x.device)
        hidden = self.embed(x) + self.pos_embed(positions)

        # Additive mask: -inf strictly above the diagonal blocks attention
        # from position i to any position j > i.
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float('-inf'),
                       device=hidden.device, dtype=hidden.dtype),
            diagonal=1,
        )

        # The encoder layers were built with the default sequence-first
        # layout, so swap to (seq, batch, dim) and back.
        hidden = self.transformer(hidden.permute(1, 0, 2), mask=causal_mask)
        return self.lm_head(hidden.permute(1, 0, 2))

# The student's vocabulary size is taken from the tokenizer's length.
student = TinyTransformer(vocab_size=len(tokenizer))
student = student.to(device)

# Only the student's parameters are optimized.
optimizer = optim.Adam(params=student.parameters(), lr=5e-4)

## 🔥 Train Student with Knowledge Distillation

In [None]:
T = 2.0      # softmax temperature applied to both student and teacher logits
alpha = 0.7  # weight of the soft (teacher-matching) term; the hard term gets 1 - alpha


def distillation_loss(student_logits, teacher_logits, target, *,
                      temperature=None, soft_weight=None):
    """Blend a temperature-softened KL term with ordinary cross-entropy.

    Both terms are averaged per token. The logits are flattened to
    ``(tokens, vocab)`` before calling ``F.kl_div`` because
    ``reduction='batchmean'`` divides only by the size of the first
    dimension; on ``(batch, seq, vocab)`` input the KL term would be summed
    over the sequence and dwarf the per-token cross-entropy term.

    Args:
        student_logits: tensor ``(..., vocab)`` of student logits.
        teacher_logits: tensor of the same shape with teacher logits.
        target: integer class ids with shape ``student_logits.shape[:-1]``.
            Entries equal to -100 are skipped by the cross-entropy term
            (its default ``ignore_index``).
        temperature: softmax temperature; defaults to the module-level ``T``.
        soft_weight: weight of the KL term; defaults to the module-level
            ``alpha``.

    Returns:
        Scalar tensor ``soft_weight * KL * temperature**2
        + (1 - soft_weight) * CE``.
    """
    # Resolve defaults at call time so later changes to T / alpha take effect.
    temperature = T if temperature is None else temperature
    soft_weight = alpha if soft_weight is None else soft_weight

    flat_student = student_logits.reshape(-1, student_logits.size(-1))
    flat_teacher = teacher_logits.reshape(-1, teacher_logits.size(-1))

    # The temperature**2 factor keeps soft-term gradient magnitudes
    # comparable across temperatures.
    soft = F.kl_div(
        F.log_softmax(flat_student / temperature, dim=-1),
        F.softmax(flat_teacher / temperature, dim=-1),
        reduction='batchmean',
    ) * temperature * temperature
    hard = F.cross_entropy(flat_student, target.reshape(-1))
    return soft_weight * soft + (1 - soft_weight) * hard

In [None]:
# Per-epoch averages of the training loss and of three diagnostics.
losses, kls, entropies, accs = [], [], [], []
for epoch in range(10):
    total_loss = total_kl = total_ent = total_acc = 0
    student.train()
    for x, y in tqdm(train_dl, desc=f"Epoch {epoch+1}"):
        x, y = x.to(device), y.to(device)

        # The teacher is a fixed reference: no gradients are needed for it.
        with torch.no_grad():
            t_logits = teacher(x).logits
        s_logits = student(x)
        loss = distillation_loss(s_logits, t_logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Diagnostics only: computed without autograd so they neither extend
        # the graph nor hold on to extra memory.
        with torch.no_grad():
            vocab = s_logits.size(-1)
            # Flatten to (tokens, vocab) so 'batchmean' yields a per-token
            # mean KL rather than a sum over sequence positions.
            kl = F.kl_div(
                F.log_softmax(s_logits.reshape(-1, vocab) / T, dim=-1),
                F.softmax(t_logits.reshape(-1, vocab) / T, dim=-1),
                reduction='batchmean',
            ).item()

            # Mean entropy of the student's (temperature-1) predictive distribution.
            log_probs = F.log_softmax(s_logits, dim=-1)
            ent = -(log_probs.exp() * log_probs).sum(-1).mean().item()

            # Accuracy over positions that carry a real label; -100 is the
            # default ignore_index of cross_entropy and is skipped if present.
            scored = y != -100
            hits = (s_logits.argmax(dim=-1) == y) & scored
            acc = (hits.float().sum() / scored.float().sum().clamp(min=1.0)).item()

        total_loss += loss.item()
        total_kl += kl
        total_ent += ent
        total_acc += acc

    # Guard against an empty loader so the averages never divide by zero.
    num_batches = max(len(train_dl), 1)
    losses.append(total_loss / num_batches)
    kls.append(total_kl / num_batches)
    entropies.append(total_ent / num_batches)
    accs.append(total_acc / num_batches)

## 📊 Visualize Training Metrics

In [None]:
# One row of four panels, one per tracked per-epoch metric.
plt.figure(figsize=(16, 4))
panels = [
    ("Loss", losses),
    ("KL", kls),
    ("Entropy", entropies),
    ("Accuracy", accs),
]
for position, (title, series) in enumerate(panels, start=1):
    plt.subplot(1, 4, position)
    plt.plot(series)
    plt.title(title)
    plt.grid(True)
plt.tight_layout()
plt.show()

## 🔍 Compare Predictions for a Token

In [None]:
def visualize_prediction(sentence, token_index):
    """Plot teacher vs. student probabilities at one position of ``sentence``.

    The sentence is tokenized, both models are run on it, and the output
    distributions at ``token_index`` are compared on the ten tokens the
    teacher rates most likely (bars ordered by ascending teacher
    probability). The plot title shows the input token found at
    ``token_index``.

    The student is switched to evaluation mode for the forward pass so that
    dropout does not make the plot vary between calls; its previous mode is
    restored afterwards.

    Args:
        sentence: text to feed to both models.
        token_index: position within the tokenized sentence; negative values
            index from the end.

    Raises:
        IndexError: if ``token_index`` lies outside the tokenized sentence.
    """
    inputs = tokenizer(sentence, return_tensors="pt").input_ids.to(device)
    seq_len = inputs.size(1)
    if not -seq_len <= token_index < seq_len:
        raise IndexError(
            f"token_index {token_index} is out of range for a sentence of "
            f"{seq_len} tokens"
        )

    was_training = student.training
    student.eval()
    try:
        with torch.no_grad():
            t_logits = teacher(inputs).logits
            s_logits = student(inputs)
    finally:
        # Leave the student in whatever mode the caller had it in.
        student.train(was_training)

    t_probs = F.softmax(t_logits[0, token_index], dim=-1).cpu().numpy()
    s_probs = F.softmax(s_logits[0, token_index], dim=-1).cpu().numpy()

    # argsort is ascending, so the last ten entries are the teacher's top 10.
    topk = np.argsort(t_probs)[-10:]
    labels = [tokenizer.decode([i]) for i in topk]
    x = np.arange(len(labels))
    plt.bar(x - 0.2, t_probs[topk], 0.4, label='Teacher')
    plt.bar(x + 0.2, s_probs[topk], 0.4, label='Student')
    plt.xticks(x, labels, rotation=45)
    plt.legend()
    plt.title(f"Token {token_index} prediction: '{tokenizer.decode([inputs[0, token_index].item()])}'")
    plt.tight_layout()
    plt.show()

# visualize_prediction("The quick brown fox jumps over the lazy", 5)