# Knowledge Distillation

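This notebook distills knowledge from a large teacher model (**Bio_ClinicalBERT**) into a smaller, domain-adapted student model. The student is trained on a weighted sum of its own masked-language-modelling (MLM) loss and the mean-squared error between its last-layer hidden states and the teacher's.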


## 1. Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForMaskedLM,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

## 2. Configuration

**Crucial step:** Make sure `STUDENT_MODEL_PATH` points to the directory created by the `fix_model_state.ipynb` notebook.

In [None]:
class Config:
    """Paths and hyper-parameters for the distillation run.

    Every setting is a class attribute; ``config`` below is the instance the
    rest of the notebook reads from.
    """

    # Student model weights and tokenizer (kept in separate directories).
    STUDENT_MODEL_PATH: str = "./medtok/model"
    STUDENT_TOKENIZER_PATH: str = "./medtok/tokenizer"

    # Teacher checkpoint name, training corpus (JSON Lines) and output directory.
    TEACHER_MODEL_NAME: str = "emilyalsentzer/Bio_ClinicalBERT"
    TRAIN_FILE: str = "./data/normalize/pubmed.jsonl"
    OUTPUT_DIR: str = "./artifacts/distilled_model"

    # Training hyper-parameters.
    MAX_SEQ_LENGTH: int = 256
    NUM_TRAIN_EPOCHS: int = 20
    PER_DEVICE_TRAIN_BATCH_SIZE: int = 16
    LEARNING_RATE: float = 3e-5

    # Distillation: weight of the student's MLM loss; (1 - ALPHA) weights the
    # hidden-state MSE term.
    ALPHA: float = 0.5

config = Config()

## 3. Custom Distillation Trainer

In [None]:
class DistillTrainer(Trainer):
    """Hugging Face ``Trainer`` that adds hidden-state distillation to MLM training.

    The optimised loss is

        alpha * mlm_loss + (1 - alpha) * mse(student_last_hidden, teacher_last_hidden)

    where the MSE is taken over non-padding positions when an attention mask
    is available.

    Args:
        teacher_model: Teacher model. It is put in eval mode, moved to
            ``self.args.device`` and only run under ``torch.no_grad``. If
            ``None``, the trainer performs plain MLM training.
        teacher_tokenizer: Teacher tokenizer; its ``unk_token_id`` replaces
            student token ids that do not exist in the teacher vocabulary.
            Required when ``teacher_model`` is given.
        alpha: Weight of the MLM loss; ``1 - alpha`` weights the distillation loss.
        *args, **kwargs: Forwarded to ``transformers.Trainer``. Because the
            teacher-specific parameters come first, pass Trainer arguments by
            keyword.

    Raises:
        ValueError: If ``teacher_model`` is given without ``teacher_tokenizer``.
    """

    def __init__(self, teacher_model=None, teacher_tokenizer=None, alpha=0.5, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        self.teacher_tokenizer = teacher_tokenizer
        self.alpha = alpha

        if self.teacher is not None:
            # Fail at construction time rather than at the first training step.
            if self.teacher_tokenizer is None:
                raise ValueError("teacher_tokenizer is required when teacher_model is given.")
            # Inference mode (e.g. dropout off) on the training device.
            self.teacher.eval()
            self.teacher.to(self.args.device)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return ``alpha * MLM loss + (1 - alpha) * hidden-state MSE``.

        Args:
            model: The student model being trained.
            inputs: Collated batch; expected to contain ``input_ids`` and
                ``labels`` (so the student returns its MLM loss) and optionally
                ``attention_mask``.
            return_outputs: If True, also return the student's model outputs.
            **kwargs: Extra keyword arguments passed by ``Trainer``; unused.

        Returns:
            The scalar loss, or ``(loss, student_outputs)`` if ``return_outputs``.

        Raises:
            ValueError: If the student's and teacher's last hidden states
                differ in shape (MSE distillation would need a projection).
        """
        # 1. Student forward pass; `labels` in `inputs` make it return the MLM loss.
        outputs_s = model(**inputs, output_hidden_states=True)
        mlm_loss = outputs_s.loss

        if self.teacher is None:
            # No teacher configured: plain MLM training.
            return (mlm_loss, outputs_s) if return_outputs else mlm_loss

        hidden_s = outputs_s.hidden_states[-1]

        # 2. Teacher inputs. `labels` are dropped: the teacher's own loss is never
        #    used, and student label ids >= the teacher's vocab size would make
        #    the teacher's cross-entropy fail with an out-of-bounds target.
        teacher_inputs = {key: value for key, value in inputs.items() if key != "labels"}
        # Student ids outside the teacher's vocabulary become the teacher's
        # unknown-token id so the teacher's embedding lookup stays in range.
        # NOTE(review): ids below the teacher vocab size pass through unchanged,
        # which assumes both tokenizers map those ids to the same tokens (e.g. a
        # student vocab that extends the teacher's) — confirm, otherwise the
        # distillation targets are misaligned.
        student_ids = teacher_inputs["input_ids"]
        out_of_bounds_mask = student_ids >= self.teacher.config.vocab_size
        teacher_inputs["input_ids"] = student_ids.masked_fill(
            out_of_bounds_mask, self.teacher_tokenizer.unk_token_id
        )

        # 3. Teacher forward pass without building an autograd graph.
        with torch.no_grad():
            outputs_t = self.teacher(**teacher_inputs, output_hidden_states=True)
        hidden_t = outputs_t.hidden_states[-1]

        if hidden_s.shape != hidden_t.shape:
            raise ValueError(
                f"Student hidden states {tuple(hidden_s.shape)} and teacher hidden "
                f"states {tuple(hidden_t.shape)} differ in shape; hidden-state MSE "
                "distillation requires matching hidden sizes."
            )

        # 4. Distillation loss: MSE between last hidden states.
        attention_mask = inputs.get("attention_mask")
        if attention_mask is not None:
            # Indexing (batch, seq, hidden) states with a (batch, seq) boolean
            # mask keeps only the rows of non-padding tokens.
            keep = attention_mask.to(torch.bool)
            loss_kd = F.mse_loss(hidden_s[keep], hidden_t[keep])
        else:
            loss_kd = F.mse_loss(hidden_s, hidden_t)

        # 5. Convex combination of the two objectives.
        loss = self.alpha * mlm_loss + (1 - self.alpha) * loss_kd

        return (loss, outputs_s) if return_outputs else loss

## 4. Load Models and Tokenizers

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# 1. Load the student model and its tokenizer (they come from separate directories).
print(f"Loading student model from: {config.STUDENT_MODEL_PATH}")
print(f"Loading student tokenizer from: {config.STUDENT_TOKENIZER_PATH}")
student_tok = AutoTokenizer.from_pretrained(config.STUDENT_TOKENIZER_PATH)
student = AutoModelForMaskedLM.from_pretrained(config.STUDENT_MODEL_PATH)
student.to(device)

# 2. Load the teacher model and its tokenizer from the same checkpoint name.
print(f"Loading teacher model and tokenizer: {config.TEACHER_MODEL_NAME}")
teacher = AutoModelForMaskedLM.from_pretrained(config.TEACHER_MODEL_NAME)
teacher_tokenizer = AutoTokenizer.from_pretrained(config.TEACHER_MODEL_NAME)
# The teacher is only used for inference, so put it in eval mode right away.
teacher.eval()
teacher.to(device)

## 5. Diagnostic Check

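This cell checks, before training starts, that the student tokenizer's vocabulary size matches both the model config and the model's output (LM head) layer. Passing this check rules out vocabulary-size mismatches; it does not by itself guarantee that training will succeed.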

In [None]:
print("--- Verifying Model and Tokenizer Consistency ---")

tokenizer_vocab_size = len(student_tok)
model_config_vocab_size = student.config.vocab_size

# Use the generic Hugging Face accessor for the LM head instead of the
# DistilBERT-only `vocab_projector` attribute, so any masked-LM student works.
output_layer = student.get_output_embeddings()
if output_layer is None:
    raise RuntimeError("FATAL ERROR: Student model has no output (LM head) layer!")
model_output_layer_size = output_layer.out_features

print(f"Tokenizer Vocabulary Size: {tokenizer_vocab_size}")
print(f"Model Config Vocab Size: {model_config_vocab_size}")
print(f"Model Output Layer Size (n_classes): {model_output_layer_size}")

# Explicit raises instead of `assert`, which is skipped when Python runs with -O.
if tokenizer_vocab_size != model_config_vocab_size:
    raise RuntimeError("FATAL ERROR: Mismatch between tokenizer and model config!")
if tokenizer_vocab_size != model_output_layer_size:
    raise RuntimeError("FATAL ERROR: Mismatch between tokenizer and model's output layer!")

print("\n All sizes match. The student model is correctly configured and ready for training.")

## 6. Load and Process Dataset

In [None]:
print(f"Loading and tokenizing dataset from: {config.TRAIN_FILE}")
dataset = load_dataset("json", data_files={"train": config.TRAIN_FILE}, split="train")
text_column_name = "text"
if text_column_name not in dataset.column_names:
    raise ValueError(
        f"Column '{text_column_name}' not found in {config.TRAIN_FILE}; "
        f"available columns: {dataset.column_names}"
    )

def tokenize_fn(examples):
    """Tokenize a batch of texts with the student tokenizer, truncating to MAX_SEQ_LENGTH.

    No padding is applied here: the MLM data collator pads each batch to its
    longest sequence, avoiding work on padding tokens that the attention mask
    would ignore anyway.
    """
    return student_tok(
        examples[text_column_name],
        truncation=True,
        max_length=config.MAX_SEQ_LENGTH,
    )

# Remove *all* raw columns (not only the text) so that only tokenizer outputs
# reach the collator.
tokenized_dataset = dataset.map(tokenize_fn, batched=True, remove_columns=dataset.column_names)
# Pads each batch dynamically and masks 15% of tokens for the MLM objective.
data_collator = DataCollatorForLanguageModeling(tokenizer=student_tok, mlm=True, mlm_probability=0.15)

print("Dataset is ready for training.")

## 7. Initialize Trainer and Start Training

In [None]:
# Trainer configuration; every hyper-parameter is taken from `config`.
training_args = TrainingArguments(
    output_dir=config.OUTPUT_DIR,
    overwrite_output_dir=True,
    report_to="none",  # no external experiment trackers
    # Optimisation
    num_train_epochs=config.NUM_TRAIN_EPOCHS,
    per_device_train_batch_size=config.PER_DEVICE_TRAIN_BATCH_SIZE,
    learning_rate=config.LEARNING_RATE,
    fp16=torch.cuda.is_available(),  # mixed precision only when a GPU is present
    # Logging and checkpointing
    logging_steps=500,
    save_steps=10_000,
    save_total_limit=2,  # keep at most two checkpoints on disk
)

# The student is the trainable model; the teacher only supplies targets.
trainer = DistillTrainer(
    args=training_args,
    model=student,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
    teacher_model=teacher,
    teacher_tokenizer=teacher_tokenizer,
    alpha=config.ALPHA,
)

print("Starting corrected knowledge distillation training... ✨")
trainer.train()
print("Training complete!")

## 8. Save the Final Distilled Model

In [None]:
# Write the distilled student and its tokenizer into the same output directory.
output_dir = config.OUTPUT_DIR
print(f"Saving model to {output_dir}")
trainer.save_model(output_dir)
student_tok.save_pretrained(output_dir)
print("Script finished successfully.")