In [None]:
# Initialize the teacher/student model paths and the student's name
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments, DataCollatorWithPadding
import torch
import torch.nn.functional as F
from torch import nn

# --- Artifact locations ---
TEACHER_PATH = "./drive/MyDrive/bert_sst2_baseline/best_model"  # checkpoint the teacher is loaded from
OUTPUT_DIR = "./drive/MyDrive/bert_sst2_student/best_model"     # where the student is written
STUDENT_NAME = "bert_sst2_student"

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# --- Optimisation hyperparameters ---
num_epochs = 4
learning_rate = 3e-5
weight_decay = 0.01
warmup_ratio = 0.06
per_device_train_batch_size = 16
per_device_eval_batch_size = 64

# --- Distillation hyperparameters ---
temperature = 2.0  # divisor applied to teacher and student logits before the softmax
alpha = 0.5        # weight of the hard-label CE term; (1 - alpha) weights the KL term

In [None]:
# load the dataset
from datasets import load_dataset
from evaluate import load
# GLUE sub-task used for distillation, plus the metric used to score predictions.
tasks = "sst2"
metric = load("accuracy")
dataset = load_dataset(path="glue", name=tasks)

# The tokenizer comes from the bert-base-uncased checkpoint and is applied
# to every split further down.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

def tokenize(examples):
    """Tokenize the ``"sentence"`` field of a batch of examples.

    Truncation is enabled; no padding is applied here.
    """
    sentences = examples["sentence"]
    return tokenizer(sentences, truncation=True)

# Tokenize every split in batches, then drop the columns that are not model inputs.
tokenized_datasets = (
    dataset
    .map(tokenize, batched=True)
    .remove_columns(["sentence", "idx"])
)
tokenized_datasets.set_format(type="torch")

# One handle per split.
train_datasets, val_datasets, test_datasets = (
    tokenized_datasets[split] for split in ("train", "validation", "test")
)

# tokenize() does not pad, so padding is left to the collator.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [None]:
# Teacher: loaded from the local checkpoint, moved to DEVICE, switched to
# eval mode, and frozen so that none of its parameters require gradients.
teacher = AutoModelForSequenceClassification.from_pretrained(
    TEACHER_PATH,
    num_labels=2,
    local_files_only=True,
)
teacher.to(DEVICE).eval()
teacher.requires_grad_(False)

# Student: starts from the pretrained checkpoint named below, with a 2-label head.
PRETRAINED_STUDENT = "distilbert-base-uncased"
student = AutoModelForSequenceClassification.from_pretrained(
    PRETRAINED_STUDENT,
    num_labels=2,
)

In [None]:
class DistillationTrainer(Trainer):
    """Trainer whose loss distils a fixed teacher into the model being trained.

    The loss is::

        alpha * CE(student_logits, labels)
        + (1 - alpha) * T**2 * KL(softmax(teacher_logits / T) || softmax(student_logits / T))

    Args:
        teacher_model: Model that supplies the soft targets. It is switched to
            eval mode here and is only ever run under ``torch.no_grad()``.
        temperature: Divisor ``T`` applied to both sets of logits.
        alpha: Weight of the hard-label cross-entropy term; the KL term gets
            ``1 - alpha``.
        *args, **kwargs: Forwarded unchanged to ``Trainer.__init__``.
    """

    def __init__(self, teacher_model=None, temperature=1.0, alpha=0.5, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        if self.teacher is not None:
            # Only the student is trained by this class; keep the teacher's
            # dropout etc. disabled so its soft targets are deterministic.
            self.teacher.eval()
        self.temperature = temperature
        self.alpha = alpha
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """
        Compute combined CE + KL distillation loss.

        Accepts **kwargs to avoid unexpected keyword errors from Trainer.

        Args:
            model: The student being trained.
            inputs: Batch mapping; must contain ``"labels"``. ``"token_type_ids"``
                is dropped before either model is called.
            return_outputs: When true, also return the student's outputs.

        Returns:
            ``loss`` or ``(loss, student_outputs)``.
        """
        labels = inputs.get("labels")

        # Build the model inputs without mutating the caller's batch: labels are
        # handled here, and token_type_ids are not forwarded to either model.
        student_inputs = {
            k: v for k, v in inputs.items() if k not in ("labels", "token_type_ids")
        }
        outputs_student = model(**student_inputs)
        logits_student = outputs_student.logits

        # CE loss (hard labels)
        ce_loss = F.cross_entropy(logits_student, labels)

        # Teacher logits. Use the teacher's own device rather than a module-level
        # global so the class works wherever the teacher was placed, then bring
        # the logits back next to the student's for the loss.
        with torch.no_grad():
            teacher_device = next(self.teacher.parameters()).device
            teacher_inputs = {k: v.to(teacher_device) for k, v in student_inputs.items()}
            teacher_outputs = self.teacher(**teacher_inputs)
            logits_teacher = teacher_outputs.logits.to(logits_student.device)

        # Soft targets. KLDivLoss expects log-probabilities as input and
        # probabilities as target; T*T rescales the softened term.
        T = self.temperature
        student_log_probs = F.log_softmax(logits_student / T, dim=-1)
        teacher_probs = F.softmax(logits_teacher / T, dim=-1)
        kl = self.kl_loss(student_log_probs, teacher_probs) * (T * T)

        loss = self.alpha * ce_loss + (1.0 - self.alpha) * kl

        return (loss, outputs_student) if return_outputs else loss


In [None]:
# Training configuration for the student. Evaluation and checkpointing both
# happen once per epoch so that the best epoch (by accuracy) can be reloaded
# at the end of training.
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=num_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=per_device_eval_batch_size,
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    # warmup_ratio is declared with the other hyperparameters but was never
    # passed on, so the schedule silently ran without warmup.
    warmup_ratio=warmup_ratio,
    logging_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    # Mixed precision needs a GPU; DEVICE falls back to "cpu" when none is present.
    fp16=(DEVICE == "cuda"),
    push_to_hub=False,
    # Keep every dataset column in the batch; compute_loss decides what to
    # forward to the models.
    remove_unused_columns=False,
)

def compute_metrics(eval_pred):
    """Turn one evaluation pass into an accuracy score.

    ``eval_pred`` unpacks into ``(logits, labels)``. The predicted class is the
    arg-max over the last axis of the logits.

    Returns:
        ``{"accuracy": <value reported by the metric>}``
    """
    logits, labels = eval_pred
    predicted = logits.argmax(-1)
    scores = metric.compute(predictions=predicted, references=labels)
    return {"accuracy": scores["accuracy"]}

# Assemble the trainer: standard Trainer arguments first, then the
# distillation-specific ones consumed by DistillationTrainer.__init__.
# NOTE(review): newer transformers releases rename `tokenizer=` to
# `processing_class=` — confirm against the installed version.
trainer = DistillationTrainer(
    model=student,
    args=training_args,
    train_dataset=train_datasets,
    eval_dataset=val_datasets,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    teacher_model=teacher,
    temperature=temperature,
    alpha=alpha,
)

In [None]:
# Run training with the trainer built in the previous cell.
trainer.train()

In [None]:
import os
# Save the best model (load_best_model_at_end=True means the trainer holds it
# once training finishes). Use OUTPUT_DIR rather than repeating the literal so
# the save location cannot drift from the configured output directory.
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained("./drive/MyDrive/bert_sst2_student/tokenizer")

In [None]:
# Fixed length (in tokens) that the latency sample below is padded/truncated to.
max_len = 128
import time
# parameter counts
def params_size_mb(model):
    """Count a model's parameters and the memory they occupy.

    The size uses each parameter's actual element size instead of assuming
    4 bytes, so half-precision or quantized parameters are reported correctly.

    Args:
        model: Any object exposing ``parameters()`` that yields tensors.

    Returns:
        ``(params, size_mb)``: total number of elements and their size in MiB.
    """
    params = 0
    size_bytes = 0
    for p in model.parameters():
        count = p.numel()
        params += count
        size_bytes += count * p.element_size()
    return params, size_bytes / (1024**2)

# Parameter count and in-memory size for both models.
(s_params, s_size), (t_params, t_size) = params_size_mb(student), params_size_mb(teacher)
print(
    f"Student params: {s_params} ≈ {s_size:.1f} MB",
    f"Teacher params: {t_params} ≈ {t_size:.1f} MB",
    sep="\n",
)

# folder size of saved student
def folder_size_mb(path):
    """Return the combined size, in MiB, of every file found under ``path``.

    The directory tree is traversed with ``os.walk``; when the walk yields no
    files the result is ``0.0``.
    """
    byte_count = sum(
        os.path.getsize(os.path.join(parent, filename))
        for parent, _subdirs, filenames in os.walk(path)
        for filename in filenames
    )
    return byte_count / (1024**2)

# Measure the directory the student was actually saved to. The previous
# OUTPUT_DIR/"best_student" sub-folder is never created, so the walk found no
# files and this always printed 0.0.
# NOTE(review): OUTPUT_DIR is also the Trainer's output_dir, so any checkpoint
# sub-folders still present there are included in this total.
print("Saved student folder size (MB):", folder_size_mb(OUTPUT_DIR))

# Latency measurement (batch=1)
device = DEVICE
student.to(device)
student.eval()
# A single sentence padded to max_len, so every timed call sees the same input.
sample = tokenizer("This is a sample sentence to measure latency.", return_tensors="pt", max_length=max_len, truncation=True, padding="max_length")
input_ids = sample['input_ids'].to(device)
attention_mask = sample['attention_mask'].to(device)

# warmup (not timed)
with torch.no_grad():
    for _ in range(10):
        _ = student(input_ids=input_ids, attention_mask=attention_mask)

N = 200
# Drain pending GPU work before starting the clock, and again before stopping
# it, so the interval covers exactly the N timed forward passes.
if device == "cuda":
    torch.cuda.synchronize()
# perf_counter is a monotonic, high-resolution clock intended for measuring
# short durations; time.time() can jump if the system clock is adjusted.
t0 = time.perf_counter()
with torch.no_grad():
    for _ in range(N):
        _ = student(input_ids=input_ids, attention_mask=attention_mask)
if device == "cuda":
    torch.cuda.synchronize()
t1 = time.perf_counter()
latency_ms = (t1 - t0) / N * 1000
print(f"Student batch=1 latency (avg over {N} runs): {latency_ms:.2f} ms")

In [None]:
# Reload the distilled student from disk, together with the tokenizer this
# notebook saved next to it. The tokenizer path previously pointed at the
# baseline run's folder, which made this cell depend on another run's output.
model_path = "./drive/MyDrive/bert_sst2_student/best_model"
tokenizer_path = "./drive/MyDrive/bert_sst2_student/tokenizer"

tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, local_files_only=True)
model = AutoModelForSequenceClassification.from_pretrained(model_path, local_files_only=True)

In [None]:
# Single-sentence smoke test of the reloaded model.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
model.to(device)
model.eval()

sentence = "bad"
# token_type_ids are not requested from the tokenizer for this model.
inputs = tokenizer(
    sentence,
    return_tensors="pt",
    padding=True,
    truncation=True,
    return_token_type_ids=False,
)
inputs = inputs.to(device)

with torch.no_grad():
    outputs = model(**inputs)
    probs = F.softmax(outputs.logits, dim=-1)
    pred = torch.argmax(probs, dim=-1).item()

print("Sentence:", sentence)
print("Prediction:", pred, "(probabilities:", probs.cpu().numpy(), ")")