# Knowledge Distillation with hf_distiller
This notebook demonstrates:
1. Loading a teacher model from Hugging Face Hub
2. Creating a smaller student model
3. Preparing a toy dataset
4. Training the student using knowledge distillation
5. Visualizing training loss and logits comparison

You can replace the demo dataset with your own dataset for real training.

In [None]:
# Step 0 — Install requirements (run only once)
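# --no-deps skips dependency resolution, so transformers, datasets and torch must already be installed
# !pip install --no-deps git+https://github.com/Dhiraj309/transformers_distillation.git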

## Step 1 — Imports and Setup

In [None]:
import sys
import os
from transformers import AutoTokenizer, TrainingArguments
from datasets import Dataset
from transformers_distillation.models import load_teacher, load_student
from transformers_distillation import DistillTrainer
import torch

## Step 2 — Load Teacher Model

In [None]:
MODEL_NAME = 'google-bert/bert-base-uncased'

# Load teacher and tokenizer
teacher = load_teacher(model_name_or_path=MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
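# BERT already defines a [PAD] token; this fallback only matters for decoder-only
# checkpoints (e.g. GPT-2) that ship without a pad token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token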

print("Teacher model loaded:", teacher.__class__.__name__)
print("Tokenizer vocab size:", len(tokenizer))

## Step 3 — Create Student Model
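The student has a much smaller architecture than the teacher (4 layers, 4 attention heads and 256-dimensional embeddings, versus 12 layers, 12 heads and 768 dimensions in BERT-base), which gives faster inference and lower memory usage.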

In [None]:
student = load_student(
    model_name_or_path=MODEL_NAME,
    from_scratch=True,
    n_layers=4,
    n_heads=4,
    n_embd=256,
    is_pretrained=False
)
print("Student model created:", student.__class__.__name__)
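To put "smaller" into numbers, the sketch below counts the weights of a BERT-style encoder from its configuration. It assumes the student keeps BERT's layout and vocabulary (30,522 tokens, 512 positions, 2 token types) with a feed-forward size of 4x the hidden size, and it leaves out the pooler and the masked-LM head; `count_encoder_params` is a hypothetical helper, not part of transformers_distillation. The number of attention heads does not appear because splitting the hidden size across heads does not change the parameter count.

```python
def count_encoder_params(
    n_layers: int,
    hidden: int,
    vocab_size: int = 30522,
    max_positions: int = 512,
    type_vocab_size: int = 2,
    ffn_mult: int = 4,
) -> int:
    """Count weights + biases of a BERT-style encoder (assumed layout; no pooler, no task head)."""
    intermediate = ffn_mult * hidden
    layer_norm = 2 * hidden  # LayerNorm gain + bias

    # Token, position and token-type embedding tables, followed by one LayerNorm
    embeddings = (vocab_size + max_positions + type_vocab_size) * hidden + layer_norm

    # Self-attention: Q, K, V and output projections (hidden x hidden + bias each), then LayerNorm
    attention = 4 * (hidden * hidden + hidden) + layer_norm
    # Feed-forward: hidden -> intermediate -> hidden, each with bias, then LayerNorm
    ffn = (hidden * intermediate + intermediate) + (intermediate * hidden + hidden) + layer_norm

    return embeddings + n_layers * (attention + ffn)


teacher_params = count_encoder_params(n_layers=12, hidden=768)  # BERT-base
student_params = count_encoder_params(n_layers=4, hidden=256)   # Step 3 config (assumed BERT layout)
print(f"teacher: {teacher_params / 1e6:.1f}M params, student: {student_params / 1e6:.1f}M params")
```

Under these assumptions the student has roughly a tenth of the teacher's weights, and about 7.9M of its 11.1M parameters sit in the embedding tables, so most of the remaining size comes from the vocabulary rather than from depth.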

## Step 4 — Prepare Dataset
A small in-memory dataset is used for demonstration. Replace it with your own data for real training.

In [None]:
texts = [
    "Hello world!",
    "The quick brown fox jumps over the lazy dog.",
    "Artificial intelligence is transforming industries.",
    "Once upon a time, there was a curious developer.",
    "PyTorch makes deep learning both fun and powerful."
]
dataset = Dataset.from_dict({"text": texts})

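def tokenize(batch):
    # With batched=True, batch['text'] is a list of strings, so padding=True pads
    # every example in the mapped batch to the same (longest) length
    return tokenizer(batch['text'], max_length=128, padding=True, truncation=True)

tokenized_dataset = dataset.map(tokenize, batched=True, remove_columns=['text'])
# Reuse the first example for a quick evaluation demo (note: it is also in the training set)
eval_dataset = tokenized_dataset.select(range(1))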
print("Tokenized example:", tokenized_dataset[0])
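With `padding=True`, the tokenizer pads each call's batch to its longest sequence and returns an `attention_mask` marking real tokens (1) versus padding (0). The pure-Python sketch below mimics that behavior on already-tokenized ids, which also shows why a per-example (non-batched) `map` would leave sequences at different lengths. The token ids are illustrative only, `pad_batch` is a hypothetical helper, and the pad id 0 matches bert-base-uncased's `[PAD]` token.

```python
from typing import Dict, List


def pad_batch(
    sequences: List[List[int]], pad_id: int = 0, max_length: int = 128
) -> Dict[str, List[List[int]]]:
    """Truncate to max_length, then right-pad every sequence to the longest one in the batch."""
    # Naive truncation: real tokenizers truncate the text before adding [CLS]/[SEP]
    truncated = [seq[:max_length] for seq in sequences]
    target = max(len(seq) for seq in truncated)
    input_ids, attention_mask = [], []
    for seq in truncated:
        n_pad = target - len(seq)
        input_ids.append(seq + [pad_id] * n_pad)
        attention_mask.append([1] * len(seq) + [0] * n_pad)  # 1 = real token, 0 = padding
    return {"input_ids": input_ids, "attention_mask": attention_mask}


examples = [[101, 7592, 102], [101, 1996, 4248, 2829, 102]]
# Padded together (batched): both rows end up with the same length
print(pad_batch(examples))
# Padded one at a time (non-batched map): each row keeps its own length
print([pad_batch([seq])["input_ids"][0] for seq in examples])
```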

## Step 5 — Define Training Arguments

In [None]:
training_args = TrainingArguments(
    output_dir='./student-llm',
    per_device_train_batch_size=1,
    num_train_epochs=3,
    learning_rate=2e-4,
    logging_steps=1,
    save_steps=100,
    save_total_limit=5,
    report_to='none',
    lr_scheduler_type='cosine',
    warmup_steps=10,
)
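These arguments run far fewer optimizer steps than the values suggest: 5 examples at batch size 1 for 3 epochs give 15 steps (assuming DistillTrainer steps the optimizer once per batch, with no gradient accumulation), so `warmup_steps=10` spends two thirds of training warming up and `save_steps=100` is never reached mid-run. The sketch below reproduces linear warmup followed by a half-cosine decay to zero, following our understanding of the formula behind `lr_scheduler_type='cosine'` in transformers (`get_cosine_schedule_with_warmup` with its default half cycle); treat it as an approximation, not the library's code.

```python
import math


def cosine_with_warmup_lr(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Learning rate at a given optimizer step: linear warmup, then half-cosine decay to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


# Values from Steps 4 and 5: 5 examples, batch size 1, 3 epochs
num_examples, batch_size, epochs = 5, 1, 3
total_steps = math.ceil(num_examples / batch_size) * epochs  # 15 optimizer steps
for step in range(total_steps):
    print(step, f"{cosine_with_warmup_lr(step, 2e-4, 10, total_steps):.2e}")
```

Under this schedule the first step trains with a learning rate of 0 and the 2e-4 peak is reached only at step 10, leaving five steps of decay; for a run this short, a smaller `warmup_steps` (or `warmup_ratio`) is probably a better fit.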

## Step 6 — Initialize Distillation Trainer

In [None]:
trainer = DistillTrainer(
    teacher_model=teacher,
    student_model=student,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
    training_args=training_args,
    kd_alpha=0.5,
    temperature=2.0,
    is_pretrained=False
)
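A common way to use `kd_alpha` and `temperature` is the classic formulation of Hinton et al.: a KL-divergence term between the temperature-softened teacher and student distributions, scaled by T² so its gradients stay comparable across temperatures, mixed with ordinary cross-entropy on the hard label. The pure-Python sketch below computes that loss for a single position's logits. Whether DistillTrainer uses exactly this form, and whether `kd_alpha` weights the soft or the hard term, is an assumption here; `softmax` and `kd_loss` are hypothetical helpers.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable softmax of logits divided by a temperature."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss(
    student_logits: Sequence[float],
    teacher_logits: Sequence[float],
    label: int,
    kd_alpha: float = 0.5,
    temperature: float = 2.0,
) -> float:
    """kd_alpha * T^2 * KL(teacher_T || student_T) + (1 - kd_alpha) * CE(student, label)."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(pt * (math.log(pt) - math.log(ps)) for pt, ps in zip(p_teacher, p_student) if pt > 0)
    soft_loss = (temperature ** 2) * kl
    # Hard-label cross-entropy uses the unsoftened (T = 1) student distribution
    hard_loss = -math.log(softmax(student_logits)[label])
    return kd_alpha * soft_loss + (1.0 - kd_alpha) * hard_loss


demo_student_logits = [2.0, 1.5, 0.2]
demo_teacher_logits = [4.0, 1.0, 0.5]
print("teacher at T=1:", softmax(demo_teacher_logits, 1.0))
print("teacher at T=2:", softmax(demo_teacher_logits, 2.0))  # flatter, exposes "dark knowledge"
print("loss:", kd_loss(demo_student_logits, demo_teacher_logits, label=0))
```

Raising the temperature flattens the teacher's distribution, so the student also learns how the teacher ranks the wrong answers rather than only which answer is right.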

## Step 7 — Train Student Model
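The student learns from two signals at once: the teacher's temperature-softened output distribution and the ground-truth labels. `kd_alpha` and `temperature` (set in Step 6) control how the two are weighted and how strongly the teacher's distribution is softened.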

In [None]:
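# Assuming DistillTrainer follows the Hugging Face Trainer API, train() returns a TrainOutput
# whose training_loss is the run's *average* loss (a single float), not a per-step list
trainer_state = trainer.train()
avg_loss = getattr(trainer_state, 'training_loss', None)
state = getattr(trainer, 'state', None)  # per-step losses for visualization live in state.log_history
losses = [log['loss'] for log in state.log_history if 'loss' in log] if state is not None else []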
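If the trainer exposes a Hugging Face-style `log_history`, each logged training step is a dict with keys such as `loss`, `learning_rate`, `epoch` and `step`, evaluation entries use `eval_loss`, and the final summary entry carries `train_loss` instead. The hypothetical helpers below pull out (step, loss) pairs and a trailing moving average for plotting; they operate on plain lists of dicts, so they do not depend on DistillTrainer's internals.

```python
from typing import Any, Dict, List, Tuple


def extract_losses(log_history: List[Dict[str, Any]]) -> Tuple[List[int], List[float]]:
    """Return (steps, losses) for every log entry that records a training loss."""
    steps, losses = [], []
    for entry in log_history:
        if "loss" in entry:  # per-step training logs; eval and summary entries use other keys
            steps.append(int(entry.get("step", len(steps) + 1)))
            losses.append(float(entry["loss"]))
    return steps, losses


def moving_average(values: List[float], window: int = 3) -> List[float]:
    """Trailing moving average; early points average over however many values exist."""
    averaged = []
    for i in range(len(values)):
        chunk = values[max(0, i - window + 1): i + 1]
        averaged.append(sum(chunk) / len(chunk))
    return averaged
```

In the notebook this would be called as `extract_losses(trainer.state.log_history)` (assuming that attribute exists), and the two lists can then be passed to a plotting library such as matplotlib. With only 15 noisy steps, the smoothed curve is usually easier to read than the raw one.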

## Step 8 — Evaluate Student Model

In [None]:
results = trainer.evaluate(eval_dataset=eval_dataset)
print('Evaluation results:', results)
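For a language-modeling objective, `eval_loss` is often reported as perplexity, exp(loss), as the Hugging Face `run_mlm.py` and `run_clm.py` example scripts do. The sketch below assumes `results` contains an `eval_loss` key holding a mean token-level cross-entropy; if DistillTrainer instead reports the combined distillation loss, the exponent is no longer a true perplexity. `perplexity_from_results` is a hypothetical helper.

```python
import math
from typing import Dict


def perplexity_from_results(results: Dict[str, float]) -> float:
    """exp(eval_loss), returning inf if the loss is too large to exponentiate."""
    try:
        return math.exp(results["eval_loss"])
    except OverflowError:
        return float("inf")
```

In the notebook this would be used as `print('Perplexity:', perplexity_from_results(results))`. With a single evaluation example that also appears in the training set, the number says little about generalization.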