In [None]:
pip install transformers datasets evaluate rouge_score pandas tqdm

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.0-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.2/491.2 kB[0m [31m18.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading

In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    get_scheduler
)
from torch.optim import AdamW
import evaluate
from tqdm import tqdm
import pandas as pd
import os

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Models
TEACHER_MODEL = "gpt2-large"  #
STUDENT_MODEL = "openai-community/gpt2"

# Hyperparameters
MAX_LENGTH = 256
BATCH_SIZE = 1  # To avoid OOM errors
LEARNING_RATE = 5e-5
NUM_EPOCHS = 3
TEMPERATURE = 2.0
ALPHA = 0.5

# Load SQuAD subset
def load_squad_sample(sample_size=100):
    from datasets import load_dataset
    return load_dataset("squad", split=f"train[:{sample_size}]")

# Dataset class
class KnowledgeDistillationDataset(Dataset):
    def __init__(self, dataset, teacher_tokenizer, student_tokenizer, max_length):
        self.dataset = dataset
        self.teacher_tokenizer = teacher_tokenizer
        self.student_tokenizer = student_tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        question = item["question"]
        context = item["context"]
        answer = item["answers"]["text"][0] if item["answers"]["text"] else ""
        prompt = f"Context: {context}\nQuestion: {question}\nAnswer:"

        teacher_inputs = self.teacher_tokenizer(prompt, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt")
        student_inputs = self.student_tokenizer(prompt, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt")
        teacher_labels = self.teacher_tokenizer(answer, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt")
        student_labels = self.student_tokenizer(answer, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt")

        return {
            "teacher_input_ids": teacher_inputs.input_ids.squeeze(),
            "teacher_attention_mask": teacher_inputs.attention_mask.squeeze(),
            "teacher_labels": teacher_labels.input_ids.squeeze(),
            "student_input_ids": student_inputs.input_ids.squeeze(),
            "student_attention_mask": student_inputs.attention_mask.squeeze(),
            "student_labels": student_labels.input_ids.squeeze(),
            "prompt": prompt,
            "answer": answer
        }

# Distillation Loss
def distillation_loss(student_logits, teacher_logits, labels, temperature, alpha, ignore_index):
    vocab_size = student_logits.size(-1)
    if teacher_logits.size(-1) > vocab_size:
        teacher_logits = teacher_logits[..., :vocab_size]

    soft_targets = F.softmax(teacher_logits / temperature, dim=-1)
    soft_prob = F.log_softmax(student_logits / temperature, dim=-1)
    kd_loss = F.kl_div(soft_prob, soft_targets, reduction='batchmean') * (temperature ** 2)
    ce_loss = F.cross_entropy(student_logits.view(-1, vocab_size), labels.view(-1), ignore_index=ignore_index)
    return alpha * ce_loss + (1 - alpha) * kd_loss, ce_loss, kd_loss

# Training loop
def train_model(teacher_model, student_model, train_loader, optimizer, scheduler, ignore_index):
    student_model.train()
    teacher_model.eval()
    total_loss = total_ce_loss = total_kd_loss = 0
    progress_bar = tqdm(train_loader, desc="Training")

    for batch in progress_bar:
        student_input_ids = batch["student_input_ids"].to(device)
        student_attention_mask = batch["student_attention_mask"].to(device)
        student_labels = batch["student_labels"].to(device)
        teacher_input_ids = batch["teacher_input_ids"].to(device)
        teacher_attention_mask = batch["teacher_attention_mask"].to(device)

        with torch.no_grad():
            teacher_logits = teacher_model(teacher_input_ids, attention_mask=teacher_attention_mask).logits

        student_outputs = student_model(student_input_ids, attention_mask=student_attention_mask)
        student_logits = student_outputs.logits

        loss, ce_loss, kd_loss = distillation_loss(student_logits, teacher_logits, student_labels, TEMPERATURE, ALPHA, ignore_index)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        total_loss += loss.item()
        total_ce_loss += ce_loss.item()
        total_kd_loss += kd_loss.item()
        progress_bar.set_postfix({
            "Loss": total_loss / (progress_bar.n + 1),
            "CE": total_ce_loss / (progress_bar.n + 1),
            "KD": total_kd_loss / (progress_bar.n + 1)
        })

# Evaluation
def evaluate_model(model, tokenizer, eval_dataset, name="Model"):
    model.eval()
    squad_metric = evaluate.load("squad")
    rouge_metric = evaluate.load("rouge")
    bleu_metric = evaluate.load("bleu")

    results = []
    loader = DataLoader(eval_dataset, batch_size=1)
    all_preds, all_refs = [], []

    for idx, batch in enumerate(tqdm(loader, desc=f"Evaluating {name}")):
        input_ids = batch[f"{name.lower()}_input_ids"].to(device)
        attention_mask = batch[f"{name.lower()}_attention_mask"].to(device)
        prompt = batch["prompt"][0]
        true_answer = batch["answer"][0]

        with torch.no_grad():
            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=32)
            prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)

        generated_answer = prediction.replace(prompt, "").strip()
        reference = true_answer.strip()

        results.append({
            "prompt": prompt,
            "reference": reference,
            "prediction": generated_answer
        })

        all_preds.append(generated_answer)
        all_refs.append(reference)

        squad_metric.add(
            prediction={"id": str(idx), "prediction_text": generated_answer},
            reference={"id": str(idx), "answers": {"text": [reference], "answer_start": [0]}}
        )

    squad_scores = squad_metric.compute()
    rouge_scores = rouge_metric.compute(predictions=all_preds, references=all_refs)
    bleu_scores = bleu_metric.compute(predictions=all_preds, references=[[ref] for ref in all_refs])

    combined = {
        "squad": squad_scores,
        "rouge": rouge_scores,
        "bleu": bleu_scores
    }

    return results, combined

# Print metrics
def print_metrics(name, scores):
    print(f"\n{name} Evaluation:")
    print(f"  SQuAD - EM: {scores['squad']['exact_match']:.2f}, F1: {scores['squad']['f1']:.2f}")
    print(f"  ROUGE:")
    for k, v in scores["rouge"].items():
        print(f"    {k}: {v:.2f}")
    print(f"  BLEU: {scores['bleu']['bleu']:.2f}")

# Main function
def main():
    print("📚 Loading dataset...")
    dataset = load_squad_sample(100)

    print("🧠 Loading teacher model...")
    teacher_tokenizer = AutoTokenizer.from_pretrained(TEACHER_MODEL)
    teacher_tokenizer.pad_token = teacher_tokenizer.eos_token
    teacher_model = AutoModelForCausalLM.from_pretrained(TEACHER_MODEL).to(device)

    print("🎓 Loading student model...")
    student_tokenizer = AutoTokenizer.from_pretrained(STUDENT_MODEL)
    student_tokenizer.pad_token = student_tokenizer.eos_token
    student_pad_token_id = student_tokenizer.pad_token_id
    student_model = AutoModelForCausalLM.from_pretrained(STUDENT_MODEL).to(device)

    print("🛠️ Preparing dataset...")
    kd_dataset = KnowledgeDistillationDataset(dataset, teacher_tokenizer, student_tokenizer, MAX_LENGTH)
    train_size = int(0.8 * len(kd_dataset))
    train_dataset, eval_dataset = random_split(kd_dataset, [train_size, len(kd_dataset) - train_size])
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    optimizer = AdamW(student_model.parameters(), lr=LEARNING_RATE)
    scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_loader)*NUM_EPOCHS)

    for epoch in range(NUM_EPOCHS):
        print(f"\n🚀 Epoch {epoch+1}/{NUM_EPOCHS}")
        train_model(teacher_model, student_model, train_loader, optimizer, scheduler, student_pad_token_id)

        out_dir = f"student_checkpoint_epoch_{epoch+1}"
        os.makedirs(out_dir, exist_ok=True)
        student_model.save_pretrained(out_dir)
        student_tokenizer.save_pretrained(out_dir)

    print("\n📏 Evaluating models...")
    student_results, student_scores = evaluate_model(student_model, student_tokenizer, eval_dataset, name="Student")
    teacher_results, teacher_scores = evaluate_model(teacher_model, teacher_tokenizer, eval_dataset, name="Teacher")

    print_metrics("Student", student_scores)
    print_metrics("Teacher", teacher_scores)

    pd.DataFrame({
        "Prompt": [r["prompt"] for r in student_results],
        "True Answer": [r["reference"] for r in student_results],
        "Student Prediction": [r["prediction"] for r in student_results],
        "Teacher Prediction": [r["prediction"] for r in teacher_results]
    }).to_csv("distillation_eval_results.csv", index=False)

    print("\n✅ Results saved to distillation_eval_results.csv")

if __name__ == "__main__":
    main()


Using device: cuda
📚 Loading dataset...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/7.62k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/14.5M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/1.82M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/87599 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/10570 [00:00<?, ? examples/s]

🧠 Loading teacher model...


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/666 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/3.25G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

🎓 Loading student model...


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

🛠️ Preparing dataset...

🚀 Epoch 1/3


Training: 100%|██████████| 80/80 [00:18<00:00,  4.22it/s, Loss=106, CE=10.3, KD=202]



🚀 Epoch 2/3


Training: 100%|██████████| 80/80 [00:18<00:00,  4.27it/s, Loss=85.5, CE=9.78, KD=161]



🚀 Epoch 3/3


Training: 100%|██████████| 80/80 [00:18<00:00,  4.35it/s, Loss=80.1, CE=9.41, KD=151]



📏 Evaluating models...


Downloading builder script:   0%|          | 0.00/4.53k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/3.32k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/6.27k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/5.94k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/3.34k [00:00<?, ?B/s]

Evaluating Student:   0%|          | 0/20 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Evaluating Student:   5%|▌         | 1/20 [00:00<00:11,  1.63it/s]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Evaluating Student:  10%|█         | 2/20 [00:00<00:08,  2.19it/s]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Evaluating Student:  15%|█▌        | 3/20 [00:01<00:07,  2.39it/s]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Evaluating Student:  20%|██        | 4/20 [00:01<00:06,  2.56it/s]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Evaluating Student:  25%|██▌       | 5/20 [00:02<00:05,  2.62it/s]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Evaluating Student:  30%|███       | 6/20 [00:02<00:05,  2.67it/s]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Evaluating Student:  35%|███▌      | 7/20 


Student Evaluation:
  SQuAD - EM: 0.00, F1: 5.33
  ROUGE:
    rouge1: 0.06
    rouge2: 0.02
    rougeL: 0.06
    rougeLsum: 0.06
  BLEU: 0.01

Teacher Evaluation:
  SQuAD - EM: 0.00, F1: 6.87
  ROUGE:
    rouge1: 0.07
    rouge2: 0.03
    rougeL: 0.07
    rougeLsum: 0.07
  BLEU: 0.01

✅ Results saved to distillation_eval_results.csv



