In [1]:
import numpy as np
import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    TrainingArguments,
    Trainer,
)
import evaluate

# -----------------------------
# Config
# -----------------------------
# The student is distilled from a teacher that is itself fine-tuned first.
STUDENT_MODEL_NAME: str = "distilbert-base-uncased"
TEACHER_MODEL_NAME: str = "bert-base-uncased"
NUM_LABELS: int = 2     # number of classes in the IMDB label column
MAX_LENGTH: int = 256   # reviews are truncated to this many tokens
SEED: int = 12

# Knowledge-distillation hyperparameters
TEMPERATURE: float = 2.0  # softmax temperature T applied to teacher and student logits
ALPHA: float = 0.5        # weight on hard-label CE; (1 - ALPHA) weights the KD term

# -----------------------------
# Data
# -----------------------------
dataset = load_dataset("imdb")

# Only "train" and "test" are kept; "test" is exposed under the name
# "validation" because the reference paper reports test-set accuracy.
dataset = dict(train=dataset["train"], validation=dataset["test"])

# The student's tokenizer produces the ids that are later fed to BOTH models.
# NOTE(review): this relies on bert-base-uncased and distilbert-base-uncased
# sharing the same uncased vocabulary — re-check if either checkpoint changes.
tokenizer = AutoTokenizer.from_pretrained(STUDENT_MODEL_NAME)

def tokenize_fn(batch):
    """Tokenize a batch of reviews, truncating each one to MAX_LENGTH tokens.

    No padding is applied here; padding is left to the data collator so each
    training batch is padded only to its own longest sequence.
    """
    texts = batch["text"]
    return tokenizer(texts, max_length=MAX_LENGTH, truncation=True)

tokenized = {
    split_name: dataset[split_name].map(tokenize_fn, batched=True)
    for split_name in ("train", "validation")
}

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Drop every column except the model inputs and the label, then rename
# "label" to "labels" (the name the model's forward() uses for targets).
keep_cols = ["input_ids", "attention_mask", "label"]
for split in tokenized:
    tokenized[split] = (
        tokenized[split]
        .remove_columns([c for c in tokenized[split].column_names if c not in keep_cols])
        .rename_column("label", "labels")
    )

# -----------------------------
# Metrics
# -----------------------------
accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    """Compute classification accuracy for a Trainer evaluation pass.

    Args:
        eval_pred: either an ``EvalPrediction`` (attributes ``predictions`` and
            ``label_ids``) or a plain ``(predictions, labels)`` pair.

    Returns:
        ``{"accuracy": float}``.
    """
    # Read named fields when available: tuple-unpacking an EvalPrediction
    # fails as soon as it also carries inputs/losses.
    if hasattr(eval_pred, "predictions"):
        logits, labels = eval_pred.predictions, eval_pred.label_ids
    else:
        logits, labels = eval_pred
    # When the model returns extra outputs (e.g. hidden states), predictions
    # arrive as a tuple whose first element is the class logits.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = np.argmax(logits, axis=-1)
    return {"accuracy": accuracy.compute(predictions=preds, references=labels)["accuracy"]}

# -----------------------------
# 1) Train Teacher (task-specific)
# -----------------------------
teacher = AutoModelForSequenceClassification.from_pretrained(
    TEACHER_MODEL_NAME,
    num_labels=NUM_LABELS,
)

teacher_args = TrainingArguments(
    output_dir="teacher-bert-imdb",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_steps=200,
    seed=SEED,
    report_to="none",
    save_strategy="no",   # keep simple; change if you want checkpoints
)

teacher_trainer = Trainer(
    model=teacher,
    args=teacher_args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["validation"],
    # `tokenizer=` is deprecated for Trainer; `processing_class=` is its replacement.
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

teacher_trainer.train()
print("Teacher eval:", teacher_trainer.evaluate())

# Freeze the teacher for the KD stage. eval() disables dropout so its soft
# targets are deterministic; requires_grad_(False) turns off gradients for
# every parameter in one call.
teacher_trainer.model.eval()
teacher_trainer.model.requires_grad_(False)

# -----------------------------
# 2) Train Student with Task-Specific KD
# -----------------------------
# Student starts from the pretrained checkpoint with a NUM_LABELS-way classification head.
student = AutoModelForSequenceClassification.from_pretrained(STUDENT_MODEL_NAME, num_labels=NUM_LABELS)

class KDTrainer(Trainer):
    """
    Trainer that fine-tunes a student with task-specific knowledge distillation.

    Per-batch loss:
      L = alpha * CE(labels, student_logits)
        + (1-alpha) * T^2 * KL(softmax(teacher/T) || softmax(student/T))

    The T^2 factor compensates for the 1/T^2 shrinkage of soft-target
    gradients, keeping the KD term's scale comparable across temperatures.
    If a batch carries no labels, the student's CE loss is unavailable and
    the loss falls back to the KD term alone.

    Args:
        teacher_model: fine-tuned teacher. It is switched to eval mode and only
            run under ``torch.no_grad()``; it is never given to the optimizer.
        temperature: softmax temperature T; must be > 0.
        alpha: weight on the hard-label CE term; must be in [0, 1].
        *args, **kwargs: forwarded unchanged to ``Trainer``.

    Raises:
        ValueError: if ``temperature`` or ``alpha`` is out of range.
    """
    def __init__(self, teacher_model, temperature=2.0, alpha=0.5, *args, **kwargs):
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {alpha}")
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model
        # The teacher is not a submodule of the student, so Trainer never
        # toggles its mode; put it in eval mode once here (disables dropout).
        self.teacher_model.eval()
        self.temperature = temperature
        self.alpha = alpha

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Student forward; when "labels" is present the model also returns its CE loss.
        student_outputs = model(**inputs)
        student_logits = student_outputs.logits

        # The teacher only needs the features; passing labels would make it
        # compute a CE loss that is thrown away on every step.
        teacher_inputs = {k: v for k, v in inputs.items() if k != "labels"}

        # Keep the teacher on the same device as the student's outputs.
        device = student_logits.device
        if next(self.teacher_model.parameters()).device != device:
            self.teacher_model.to(device)

        with torch.no_grad():
            teacher_logits = self.teacher_model(**teacher_inputs).logits

        T = self.temperature

        # F.kl_div(input, target) expects log-probs as input and probs as
        # target, and computes KL(target || input-distribution), i.e.
        # KL(teacher || student) = sum p_t * (log p_t - log p_s).
        student_log_probs = F.log_softmax(student_logits / T, dim=-1)
        teacher_probs = F.softmax(teacher_logits / T, dim=-1)

        kd_loss = F.kl_div(
            student_log_probs,
            teacher_probs,
            reduction="batchmean",
        ) * (T * T)

        ce_loss = student_outputs.loss
        if ce_loss is None:
            # No labels in this batch: only the distillation signal is available.
            loss = kd_loss
        else:
            loss = self.alpha * ce_loss + (1.0 - self.alpha) * kd_loss

        # Trainer's evaluation path asks for the outputs to collect logits.
        return (loss, student_outputs) if return_outputs else loss

student_args = TrainingArguments(
    output_dir="student-distilbert-kd-imdb",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_steps=200,
    seed=SEED,
    report_to="none",
    save_strategy="no",
)

kd_trainer = KDTrainer(
    teacher_model=teacher_trainer.model,  # task-specific (already fine-tuned) teacher
    temperature=TEMPERATURE,
    alpha=ALPHA,
    model=student,
    args=student_args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["validation"],
    # `tokenizer=` is deprecated for Trainer; `processing_class=` is its replacement.
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

kd_trainer.train()
# NOTE(review): eval_loss here presumably comes from KDTrainer.compute_loss
# (CE + KD mix), so compare models on accuracy rather than on eval_loss.
print("Student KD eval:", kd_trainer.evaluate())


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  teacher_trainer = Trainer(


KeyboardInterrupt: 