In [1]:
# 📓 Distillation Notebook: Train a Student LLM from a Teacher

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TrainingArguments, DataCollatorForSeq2Seq, Trainer
from datasets import load_dataset
import torch
import torch.nn.functional as F
import pandas as pd
from peft import PeftModel

In [3]:
class DistillationTrainer(Trainer):
    def __init__(self, teacher_model, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher = teacher_model.eval()

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        with torch.no_grad():
            teacher_outputs = self.teacher(**inputs)
        student_outputs = model(**inputs)

        loss = F.kl_div(
            input=F.log_softmax(student_outputs.logits / 1.0, dim=-1),
            target=F.softmax(teacher_outputs.logits / 1.0, dim=-1),
            reduction="batchmean"
        )
        return (loss, student_outputs) if return_outputs else loss



In [4]:
# Define models
teacher_model = "google/flan-t5-base"
student_model = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(student_model)
teacher = AutoModelForSeq2SeqLM.from_pretrained(teacher_model)
student = AutoModelForSeq2SeqLM.from_pretrained(student_model)

# 📁 Load and preprocess data
file_path = "../data/sample_train.jsonl"
dataset = load_dataset("json", data_files=file_path, split="train")

def preprocess(example):
    x = tokenizer(example["text"], padding="max_length", truncation=True, max_length=128)
    x["labels"] = x["input_ids"].copy()
    return x

dataset = dataset.map(preprocess, batched=True)

# 🛠️ Training configuration
args = TrainingArguments(
    output_dir="../distilled-student",
    per_device_train_batch_size=2,
    max_steps=100,
    logging_steps=10,
    save_steps=100,
    report_to="none",
    fp16=False,
)

trainer = DistillationTrainer(
    model=student,
    teacher_model=teacher,
    args=args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=student)
)

# Train the student model
trainer.train()

# Save outputs
trainer.model.save_pretrained("../distilled-student")
tokenizer.save_pretrained("../distilled-student")


Generating train split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/3 [00:00<?, ? examples/s]

  super().__init__(*args, **kwargs)


RuntimeError: Placeholder storage has not been allocated on MPS device!