In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_from_disk
import numpy as np
import evaluate

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


In [3]:
print("\n" + "="*50)
print("STEP 1: Loading and Preparing Custom Medical Dataset")
print("="*50 + "\n")

med_dataset = load_from_disk("med_dr_notes")

label_list = ['Dermatology', 'Gastroenterology', 'Endocrinology', 'Oncology', 'Pulmonology']
num_labels = len(label_list)
label2id = {label: idx for idx, label in enumerate(label_list)}
id2label = {idx: label for idx, label in enumerate(label_list)}

def encode_label(example):
    example['label'] = label2id[example['label']]
    return example

med_dataset = med_dataset.map(encode_label)

train_dataset = med_dataset['train']
eval_dataset = med_dataset['test']
test_dataset = med_dataset['test']

print(f"Train: {len(train_dataset)}, Validation: {len(eval_dataset)}, Test: {len(test_dataset)}")


STEP 1: Loading and Preparing Custom Medical Dataset

Train: 2000, Validation: 500, Test: 500


In [4]:
print("\n" + "="*50)
print("STEP 2: Simplified Tokenization with Compatible Models")
print("="*50 + "\n")

teacher_id = "bert-base-uncased"
student_id = "boltuix/bert-micro"  

# Since both models share the same tokenizer, we only need to load it once.
tokenizer = AutoTokenizer.from_pretrained(teacher_id)
print("Single tokenizer loaded, as models are from the same family.")

def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=512)

tokenized_train = train_dataset.map(tokenize_function, batched=True)
tokenized_eval = eval_dataset.map(tokenize_function, batched=True)
tokenized_test = test_dataset.map(tokenize_function, batched=True)
print("\nData tokenized successfully using a single tokenizer.")


STEP 2: Simplified Tokenization with Compatible Models

Single tokenizer loaded, as models are from the same family.

Data tokenized successfully using a single tokenizer.


In [5]:
print("\n" + "="*50)
print(f"STEP 3: Fine-tuning Teacher Model ({teacher_id})")
print("="*50 + "\n")


teacher_model = AutoModelForSequenceClassification.from_pretrained(
    teacher_id, num_labels=num_labels, id2label=id2label, label2id=label2id
).to(device)

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    f1 = evaluate.load("f1").compute(predictions=predictions, references=labels, average="macro")["f1"]
    acc = evaluate.load("accuracy").compute(predictions=predictions, references=labels)["accuracy"]
    return {"accuracy": acc, "f1": f1}

teacher_training_args = TrainingArguments(
    output_dir="models/teacher_bert_med_notes",
    num_train_epochs=3, per_device_train_batch_size=16, per_device_eval_batch_size=16,
    logging_steps=50, eval_strategy="epoch", save_strategy="epoch",
    load_best_model_at_end=True, metric_for_best_model="f1", report_to="none"
)
teacher_trainer = Trainer(
    model=teacher_model, args=teacher_training_args, train_dataset=tokenized_train,
    eval_dataset=tokenized_eval, compute_metrics=compute_metrics
)
teacher_trainer.train()
teacher_eval_results = teacher_trainer.evaluate(tokenized_test)

print(f"\nTeacher Model Evaluation: {teacher_eval_results}")


STEP 3: Fine-tuning Teacher Model (bert-base-uncased)



Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy,F1
1,0.7409,0.443258,0.856,0.85091
2,0.3017,0.31185,0.89,0.89151
3,0.1642,0.312883,0.906,0.905635



Teacher Model Evaluation: {'eval_loss': 0.31288275122642517, 'eval_accuracy': 0.906, 'eval_f1': 0.905634712405052, 'eval_runtime': 6.9595, 'eval_samples_per_second': 71.844, 'eval_steps_per_second': 4.598, 'epoch': 3.0}


In [7]:

print("\n" + "="*50)
print(f"STEP 4: Distilling Knowledge into Student ({student_id})")
print("="*50 + "\n")

class DistillationTrainer(Trainer):
    def __init__(self, *args, teacher_model=None, alpha=0.5, temperature=2.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model
        self.alpha = alpha
        self.temperature = temperature
        if self.teacher_model:
            self.teacher_model.to(self.args.device)
            self.teacher_model.eval()

    def compute_loss(self, model, inputs, return_outputs=False,num_items_in_batch = None):
        student_outputs = model(**inputs)
        student_loss = student_outputs.loss
        with torch.no_grad():
            teacher_outputs = self.teacher_model(**inputs)
        distillation_loss = F.kl_div(
            F.log_softmax(student_outputs.logits / self.temperature, dim=-1),
            F.softmax(teacher_outputs.logits / self.temperature, dim=-1),
            reduction='batchmean'
        ) * (self.temperature ** 2)
        loss = self.alpha * student_loss + (1.0 - self.alpha) * distillation_loss
        return (loss, student_outputs) if return_outputs else loss

# Load student model
student_model = AutoModelForSequenceClassification.from_pretrained(
    student_id, num_labels=num_labels, id2label=id2label, label2id=label2id
).to(device)

student_training_args = TrainingArguments(
    output_dir="models/student_bert_micro_med_notes", # <--- CHANGE HERE: New output directory
    num_train_epochs=5, # A much smaller model might benefit from more training epochs
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    logging_steps=50,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    report_to="none"
)

distillation_trainer = DistillationTrainer(
    model=student_model,
    args=student_training_args,
    teacher_model=teacher_model,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_eval,
    compute_metrics=compute_metrics,
    alpha=0.2, # Giving even more weight to the teacher (70%) as the student is very small
    temperature=3.0
)

distillation_trainer.train()
student_eval_results = distillation_trainer.evaluate(tokenized_test)
print(f"\nDistilled Student Model Evaluation: {student_eval_results}")


STEP 4: Distilling Knowledge into Student (boltuix/bert-micro)



Some weights of BertForSequenceClassification were not initialized from the model checkpoint at boltuix/bert-micro and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy,F1
1,3.3819,2.976609,0.746,0.728221
2,2.6101,2.334731,0.78,0.756328
3,2.2097,1.953837,0.792,0.77168
4,1.9438,1.744463,0.808,0.791012
5,1.8204,1.676093,0.82,0.806528



Distilled Student Model Evaluation: {'eval_loss': 1.6760931015014648, 'eval_accuracy': 0.82, 'eval_f1': 0.8065280551652521, 'eval_runtime': 7.3942, 'eval_samples_per_second': 67.62, 'eval_steps_per_second': 4.328, 'epoch': 5.0}


In [8]:
print("\n" + "="*50)
print("STEP 5: Final Results and Comparison")
print("="*50 + "\n")

teacher_params = teacher_model.num_parameters() / 1_000_000
student_params = student_model.num_parameters() / 1_000_000

print(f"Teacher Model ('{teacher_id}') size: {teacher_params:.2f}M parameters")
print(f"Student Model ('{student_id}') size: {student_params:.2f}M parameters")
print(f"Size Reduction: {100 * (1 - student_params / teacher_params):.2f}%\n")

teacher_f1 = teacher_eval_results['eval_f1']
student_f1 = student_eval_results['eval_f1']
performance_retention = (student_f1 / teacher_f1) * 100

print("--- Performance on Medical Notes Test Set ---")
print(f"{'Model':<35} | {'Macro F1-Score':<15}")
print("-" * 55)
print(f"{'1. Fine-tuned Teacher (BERT-base)':<35} | {teacher_f1:<15.4f}")
print(f"{'2. Distilled Student (BERT-micro)':<35} | {student_f1:<15.4f}") 
print("-" * 55)
print(f"\nPerformance Retained: The distilled student retained {performance_retention:.2f}% of the teacher's F1-score.")


STEP 5: Final Results and Comparison

Teacher Model ('bert-base-uncased') size: 109.49M parameters
Student Model ('boltuix/bert-micro') size: 4.39M parameters
Size Reduction: 95.99%

--- Performance on Medical Notes Test Set ---
Model                               | Macro F1-Score 
-------------------------------------------------------
1. Fine-tuned Teacher (BERT-base)   | 0.9056         
2. Distilled Student (BERT-micro)   | 0.8065         
-------------------------------------------------------

Performance Retained: The distilled student retained 89.06% of the teacher's F1-score.


In [11]:
TEACHER_MODEL_PATH = "models/teacher_bert_med_notes/checkpoint-375"
STUDENT_MODEL_PATH = "models/student_bert_micro_med_notes/checkpoint-625"

In [12]:


# --- Load the best models saved during training ---
print(f"Loading best teacher model from: {TEACHER_MODEL_PATH}")
teacher_model_loaded = AutoModelForSequenceClassification.from_pretrained(TEACHER_MODEL_PATH).to(device)

print(f"Loading best student model from: {STUDENT_MODEL_PATH}")
student_model_loaded = AutoModelForSequenceClassification.from_pretrained(STUDENT_MODEL_PATH).to(device)


# --- Helper Function for Comprehensive Evaluation ---
def evaluate_model_performance(model_name, model, dataset, raw_text_list):
    """
    Evaluates a model for performance and speed, including RTF, on the full test set.
    Returns a dictionary with all calculated metrics.
    """
    print(f"\n--- Evaluating: {model_name} on {len(dataset)} examples ---")
    model.eval()
    # We need to re-add the 'label' column for the dataloader
    dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
    data_loader = DataLoader(dataset, batch_size=EVAL_BATCH_SIZE)

    all_logits = []
    all_labels = []
    total_inference_time = 0.0

    for batch in tqdm(data_loader, desc=f"Inferencing with {model_name}"):
        labels = batch.pop('label').to(device) # .pop() is crucial here
        inputs = {k: v.to(device) for k, v in batch.items()}

        with torch.no_grad():
            start_time = time.perf_counter()
            outputs = model(**inputs)
            end_time = time.perf_counter()
            total_inference_time += (end_time - start_time)

        all_logits.append(outputs.logits.cpu().numpy())
        all_labels.append(labels.cpu().numpy())

    logits_array = np.concatenate(all_logits)
    labels_array = np.concatenate(all_labels)
    perf_metrics = compute_metrics((logits_array, labels_array))

    num_samples = len(dataset)
    avg_latency_ms = (total_inference_time / num_samples) * 1000
    throughput_sps = num_samples / total_inference_time

    # Calculate Real-Time Factor (RTF) assuming 2.5 words/sec human speech rate
    words_per_second_human = 2.5
    total_words = sum(len(text.split()) for text in raw_text_list)
    simulated_duration_sec = total_words / words_per_second_human
    rtf = total_inference_time / simulated_duration_sec

    results = {
        "model_name": model_name,
        "f1": perf_metrics["f1"],
        "accuracy": perf_metrics["accuracy"],
        "params_m": model.num_parameters() / 1_000_000,
        "total_time_s": total_inference_time,
        "latency_ms": avg_latency_ms,
        "throughput_sps": throughput_sps,
        "rtf": rtf
    }
    return results

# Use the original test_dataset to get the raw text for RTF calculation
raw_test_text = test_dataset['text']

# --- Run Evaluations on the full test set ---
teacher_results = evaluate_model_performance(
    f"Teacher ({teacher_id})", teacher_model_loaded, tokenized_test, raw_test_text
)
student_results = evaluate_model_performance(
    f"Student ({student_id})", student_model_loaded, tokenized_test, raw_test_text
)


# --- Display Final Comparison Table ---
print("\n" + "="*85)
print("Final Comparison: Performance, Size, and Speed on Medical Notes Test Set")
print("="*85)

header = f"{'Model':<30} | {'Macro F1':<10} | {'Params (M)':<12} | {'Time (s)':<10} | {'Latency (ms/ex)':<16} | {'RTF':<8}"
separator = "-" * len(header)

print(header)
print(separator)

print(f"{teacher_results['model_name']:<30} | {teacher_results['f1']:.4f}{'':<5} | "
      f"{teacher_results['params_m']:<12.2f} | {teacher_results['total_time_s']:<10.2f} | "
      f"{teacher_results['latency_ms']:.2f}{'':<10} | {teacher_results['rtf']:.4f}")

print(f"{student_results['model_name']:<30} | {student_results['f1']:.4f}{'':<5} | "
      f"{student_results['params_m']:<12.2f} | {student_results['total_time_s']:<10.2f} | "
      f"{student_results['latency_ms']:.2f}{'':<10} | {student_results['rtf']:.4f}")

print(separator)

# --- Summary & Interpretation ---
size_reduction = (1 - student_results['params_m'] / teacher_results['params_m']) * 100
performance_retention = (student_results['f1'] / teacher_results['f1']) * 100
speedup_factor = teacher_results['total_time_s'] / student_results['total_time_s']

print("\n--- Summary ---")
print(f"Size Reduction: The student model is {size_reduction:.2f}% smaller than the teacher.")
print(f"Performance Retained: The student retained {performance_retention:.2f}% of the teacher's F1-score.")
print(f"Inference Speedup: The student is {speedup_factor:.2f}x faster than the teacher.")
print(f"RTF Interpretation: An RTF < 1.0 means the model processes text faster than real-time speech.")
print("="*85)

Loading best teacher model from: models/teacher_bert_med_notes/checkpoint-375
Loading best student model from: models/student_bert_micro_med_notes/checkpoint-625

--- Evaluating: Teacher (bert-base-uncased) on 500 examples ---


Inferencing with Teacher (bert-base-uncased): 100%|█████████████████████████████████████| 16/16 [00:04<00:00,  3.51it/s]



--- Evaluating: Student (prajjwal1/bert-micro) on 500 examples ---


Inferencing with Student (prajjwal1/bert-micro): 100%|█████████████████████████████████| 16/16 [00:00<00:00, 269.90it/s]



Final Comparison: Performance, Size, and Speed on Medical Notes Test Set
Model                          | Macro F1   | Params (M)   | Time (s)   | Latency (ms/ex)  | RTF     
-----------------------------------------------------------------------------------------------------
Teacher (bert-base-uncased)    | 0.9056      | 109.49       | 0.72       | 1.45           | 0.0000
Student (prajjwal1/bert-micro) | 0.8065      | 4.39         | 0.02       | 0.03           | 0.0000
-----------------------------------------------------------------------------------------------------

--- Summary ---
Size Reduction: The student model is 95.99% smaller than the teacher.
Performance Retained: The student retained 89.06% of the teacher's F1-score.
Inference Speedup: The student is 46.38x faster than the teacher.
RTF Interpretation: An RTF < 1.0 means the model processes text faster than real-time speech.
