### Reference LLM Distillation notebook: https://github.com/simranjeet97/LLM_Distillation/blob/main/LLM_Distillation.ipynb

In [1]:
!pip install -U transformers 



In [2]:
import os
import pandas as pd
import torch
from datasets import Dataset
from dotenv import load_dotenv
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    BitsAndBytesConfig,
    AutoModelForSequenceClassification,
    AutoModelForSeq2SeqLM,
)
from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training

load_dotenv()
hf_token = os.getenv("HUGGINGFACE_API_KEY")

  from .autonotebook import tqdm as notebook_tqdm


In [71]:
# ====== Tokenizer & Model Setup ======
model_id = "google-bert/bert-base-uncased" #"google/gemma-3-1b-it"

# tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("t5-base", token=hf_token, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# model = AutoModelForSequenceClassification.from_pretrained(model_id, token=hf_token, trust_remote_code=True)

# model = AutoModelForCausalLM.from_pretrained(
#     model_id,
#     token=hf_token,
#     trust_remote_code=True,
#     torch_dtype=torch.float16,
# )

model = AutoModelForSeq2SeqLM.from_pretrained("t5-base", 
                                            #  load_in_8bit=True,
                                            #  device_map="auto",
                                            #  quantization_config=BitsAndBytesConfig(
                                            #      load_in_8bit=True,
                                            #      llm_int8_threshold=6.0,
                                            #      llm_int8_enable_fp32_cpu_offload=True,
                                            #  ),
                                            #  trust_remote_code=True,
                                             )


# model = prepare_model_for_kbit_training(model)
# lora_config = LoraConfig(
#     r=8,
#     lora_alpha=32,
#     target_modules=["q_proj", "v_proj"],
#     lora_dropout=0.05,
#     bias="none",
#     task_type=TaskType.CAUSAL_LM
# )
# model = get_peft_model(model, lora_config) # TODO Why getting PEFT model? Paper and Reference notebook did not use


In [72]:
#Ensure tokenizer has special tokens:
tokenizer.add_special_tokens({
    'additional_special_tokens': ['[label]', '[rationale]']
})
model.resize_token_embeddings(len(tokenizer))

Embedding(32102, 768)

In [73]:
# ====== Load dataset ======
def load_partition(path: str) -> Dataset:
    df = pd.read_csv(path).head(10) #.head(10)
    return Dataset.from_pandas(df)

dataset = load_partition("../Student_Training_Data/GPT.csv") ## should be GPT.csv
print(f"Loaded {len(dataset)} samples from dataset.") 

Loaded 10 samples from dataset.


In [74]:
## Ignore this part, just for understanding
# encoded_text = tokenizer.tokenize("Paris is the what of France?", return_tensors="pt").to(model.device)

# '''The model(encoded_text) call is more commonly used during training or when you want direct access 
# to the model's raw predictions, while generate() is used when you want the model to complete/continue 
# a sequence.
# '''
# print(encoded_text)

# outputs = model(encoded_text) # different from model.generate which produces logits and loss if labels are provided.
# print(outputs) # logits, loss (if label was given), hidden_states, attentions

# completion = model.generate(encoded_text, max_length=50)
# print(completion)

# decoded_text = tokenizer.decode(completion[0], skip_special_tokens=True)
# print(decoded_text)

In [75]:
def add_special_tokens_if_missing(tokenizer):
    # Add task-specific tokens if not present
    special_tokens = []
    if "[label]" not in tokenizer.get_vocab():
        special_tokens.append("[label]")
    if "[rationale]" not in tokenizer.get_vocab():
        special_tokens.append("[rationale]")
    
    if special_tokens:
        tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})
    return tokenizer

# Update tokenizer with special tokens
tokenizer = add_special_tokens_if_missing(tokenizer)

def tokenize_function(examples):
    # Create base text inputs
    base_texts = [
        f"Section Name: {sn}\nText: {txt}" 
        for sn, txt in zip(examples["sectionName"], examples["string"])
    ]

    # print(f"Base texts: {base_texts}")

    # Create task-specific inputs
    label_inputs = [f"[label] {text}" for text in base_texts]
    rationale_inputs = [f"[rationale] {text}" for text in base_texts]

    # Tokenize base inputs (for potential shared encoder)
    base_encoded = tokenizer(
        base_texts,
        padding="max_length",
        truncation=True,
        max_length=256,  # Reserve space for prefixes
        return_tensors="pt"
    )

    # Tokenize label task inputs and targets
    label_encoded = tokenizer(
        label_inputs,
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt"
    )
    
    # Tokenize label targets (text labels, not indices)
    label_targets = tokenizer(
        examples["model_classification"],
        padding="max_length",
        truncation=True, 
        max_length=32,  # Short length for class labels
        return_tensors="pt"
    )

    # Tokenize rationale task inputs
    rationale_encoded = tokenizer(
        rationale_inputs,
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt"
    )

    # Tokenize rationale targets
    rationale_targets = tokenizer(
        examples["reasoning"],
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt"
    )

    return {
        # Base inputs (shared between tasks)
        "base_input_ids": base_encoded.input_ids,
        "base_attention_mask": base_encoded.attention_mask,

        # Label prediction task
        "label_input_ids": label_encoded.input_ids,
        "label_attention_mask": label_encoded.attention_mask,
        "label_target_ids": label_targets.input_ids,

        # Rationale generation task
        "rationale_input_ids": rationale_encoded.input_ids,
        "rationale_attention_mask": rationale_encoded.attention_mask,
        "rationale_target_ids": rationale_targets.input_ids,
    }

# Apply tokenization
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    batch_size=32,
    remove_columns=dataset.column_names  # Remove original columns
)

# Set format for PyTorch
tokenized_dataset.set_format(type="torch", columns=[
    "base_input_ids",
    "base_attention_mask",
    "label_input_ids",
    "label_attention_mask", 
    "label_target_ids",
    "rationale_input_ids",
    "rationale_attention_mask",
    "rationale_target_ids"
])

Map: 100%|██████████| 10/10 [00:00<00:00, 341.03 examples/s]


In [76]:
tokenized_dataset

Dataset({
    features: ['base_input_ids', 'base_attention_mask', 'label_input_ids', 'label_attention_mask', 'label_target_ids', 'rationale_input_ids', 'rationale_attention_mask', 'rationale_target_ids'],
    num_rows: 10
})

In [77]:
tokenized_dataset[0]

{'base_input_ids': tensor([ 5568,  5570,    10, 18921,  5027,    10,   611,     6,   149,  8072,
         14727,    77,  6815,     7,    28,     8,  4163,    18,   134,  9068,
          2392, 17282,  3379,  3048, 19363,    38,  1223,    80,    18,   235,
            18,   782,  9944,    28,   284,  3876,   130,  2196,    41,   196,
             7,    75,   134,   784,  2122,     6,  2884, 13679,    27,     7,
            75,  1265,    87,   196,     7,    76,   536,   784, 11071,  2596,
             6,  2938,   908,    42,    27,  7331,  2596,    87,   196,     7,
            26,  2596,   784,  2534,     6,  1808,   908,   137,     1,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     

In [78]:
# ====== Training Args ======
# training_args = TrainingArguments( ## Original Training Args
#     output_dir="gemma3-phase1",
#     eval_strategy="epoch",
#     learning_rate=5e-5,
#     per_device_train_batch_size=8,
#     per_device_eval_batch_size=8,
#     num_train_epochs=3,
#     weight_decay=0.01,
#     save_strategy="epoch",
#     push_to_hub=False,
#     remove_unused_columns=False
# )

## New Training Args
training_args = TrainingArguments(
    output_dir="./results",
    # Disable fp16 for MPS devices
    fp16=False,  # ← THIS IS CRUCIAL
    bf16=True,   # You can try enabling this if you have newer hardware
    use_mps_device=True,  # Explicitly enable MPS
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,
    learning_rate=5e-5,
    num_train_epochs=3,
    logging_dir="./logs",
    report_to="none",
    save_strategy="no",
    remove_unused_columns=False
)






In [79]:
tokenized_dataset

Dataset({
    features: ['base_input_ids', 'base_attention_mask', 'label_input_ids', 'label_attention_mask', 'label_target_ids', 'rationale_input_ids', 'rationale_attention_mask', 'rationale_target_ids'],
    num_rows: 10
})

In [None]:
class MultiTaskTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        alpha = 0.3  # λ hyperparameter from the paper
        
        # Process Label Task --------------------------------------------------
        label_outputs = model(
            input_ids=inputs["label_input_ids"],
            attention_mask=inputs["label_attention_mask"],
            labels=inputs["label_target_ids"]
        )
        label_loss = label_outputs.loss

        # Process Rationale Task ----------------------------------------------
        rationale_outputs = model(
            input_ids=inputs["rationale_input_ids"],
            attention_mask=inputs["rationale_attention_mask"],
            labels=inputs["rationale_target_ids"]
        )
        rationale_loss = rationale_outputs.loss

        # Combine Losses ------------------------------------------------------
        total_loss = (1 - alpha) * label_loss + alpha * rationale_loss

        return (total_loss, (label_outputs, rationale_outputs)) if return_outputs else total_loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        # Handle both tasks during evaluation
        with torch.no_grad():
            loss = self.compute_loss(model, inputs)
        
        return (loss, None, None)  # (loss, predictions, labels)

# Initialize Trainer
trainer = MultiTaskTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    eval_dataset=tokenized_dataset,
    data_collator=lambda data: {
        "label_input_ids": torch.stack([d["label_input_ids"] for d in data]),
        "label_attention_mask": torch.stack([d["label_attention_mask"] for d in data]),
        "label_target_ids": torch.stack([d["label_target_ids"] for d in data]),
        "rationale_input_ids": torch.stack([d["rationale_input_ids"] for d in data]),
        "rationale_attention_mask": torch.stack([d["rationale_attention_mask"] for d in data]),
        "rationale_target_ids": torch.stack([d["rationale_target_ids"] for d in data])
    }
)


In [83]:
trainer.train()
trainer.save_model("./new_trained_model")

Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss


In [84]:
trainer.save_model("./distilled_t5_on_10_samples")

In [88]:
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sklearn.metrics import accuracy_score, classification_report

# Configuration
model_path = "./distilled_t5_on_10_samples" ## TODO: Note that this model is trained only on a 100 samples! Because the paper says 25% of full training ata was alr good enough, so i wanted to just test with a smaller number of samples first.
test_data_path = "../data/test.jsonl"
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Load the distilled model
# tokenizer = AutoTokenizer.from_pretrained(model_path)
# After loading tokenizer
model = AutoModelForSeq2SeqLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float32).to(device)

def load_test_data(file_path):
    """Load and parse test data"""
    test_data = []
    with open(file_path, 'r') as f:
        for line in f:
            entry = json.loads(line)
            test_data.append({
                "section": entry["sectionName"],
                "text": entry["string"],
                "true_label": entry["label"]
            })
    return test_data

def preprocess_input(section, text):
    """Format input with task prefix"""
    input_text = f"[label] Section: {section}\nText: {text}\nLabel:" ## TODO: NOTE THAT THIS IS KEYyyyy
    return tokenizer(
        input_text,
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt"
    ).to(device)

def predict_label(model, inputs):
    """Generate label prediction"""
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=10,
            # For deterministic results (default):
            do_sample=False,  # Disables sampling
            num_beams=3,     # Beam search works better for Seq2Seq
            early_stopping=False,
            # Remove temperature parameter when do_sample=False
            decoder_start_token_id=tokenizer.pad_token_id, #critical for T5
            pad_token_id=tokenizer.pad_token_id,
            # forced_bos_token_id=tokenizer.convert_tokens_to_ids("method"),
            # eos_token_id=tokenizer.eos_token_id,
        )

    # Debug raw outputs
    print("Raw output IDs:", outputs[0])
    print("Decoded output:", tokenizer.decode(outputs[0]))
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def clean_prediction(raw_prediction):
    """Extract label from model output"""
    # Split on "Label:" and take the first word after it
    print(f"raw: {raw_prediction}")
    # parts = raw_prediction.split("Label:")
    if len(raw_prediction) > 1:
        prediction = raw_prediction.strip().split()[0].lower()
        # Map to valid labels
        valid_labels = {"background", "method", "result"}
        print(f"Prediction: {prediction}")
        return prediction if prediction in valid_labels else "unknown"
    return "unknown"

# Load test data
test_data = load_test_data(test_data_path)

# Run predictions
true_labels = []
pred_labels = []

for example in test_data:
    # Preprocess input
    inputs = preprocess_input(example["section"], example["text"])
    
    # Get prediction
    raw_pred = predict_label(model, inputs)
    cleaned_label = clean_prediction(raw_pred)
    
    # Store results
    true_labels.append(example["true_label"])
    pred_labels.append(cleaned_label)
    
    # Print example (optional)
    print(f"Section: {example['section']}")
    print(f"Text: {example['text'][:100]}...")
    print(f"True: {example['true_label']} | Pred: {cleaned_label}")
    print("-" * 80)

# Calculate accuracy
accuracy = accuracy_score(true_labels, pred_labels)
print(f"\nTest Accuracy: {accuracy:.4f}")

# Save results
with open("predictions_t5_trained.csv", "w") as f:
    f.write("true_label,predicted_label\n")
    for true, pred in zip(true_labels, pred_labels):
        f.write(f"{true},{pred}\n")

Raw output IDs: tensor([0, 0], device='mps:0')
Decoded output: <pad><pad>
raw: 
Section: 
Text: Chapel, as well as X10 [2], UPC [3] , CoArray Fortran [6], and Titanium [5], rely on the Partitioned...
True: background | Pred: unknown
--------------------------------------------------------------------------------
Raw output IDs: tensor([0, 0], device='mps:0')
Decoded output: <pad><pad>
raw: 
Section: Discussion
Text: In addition, the result of the present study supports previous studies, which did not find increased...
True: result | Pred: unknown
--------------------------------------------------------------------------------
Raw output IDs: tensor([0, 0], device='mps:0')
Decoded output: <pad><pad>
raw: 
Section: Discussion
Text: Several instruments that more specifically address patient-reported outcomes following gastrectomy a...
True: background | Pred: unknown
--------------------------------------------------------------------------------
Raw output IDs: tensor([0, 0], device='mps

KeyboardInterrupt: 

### Model Performance Comparison (Base vs Distilled)

In [89]:
#%%
import pandas as pd
from tqdm import tqdm

def evaluate_model(model, tokenizer, test_data, model_name="Model"):
    """Evaluate model performance on test data"""
    true_labels = []
    pred_labels = []
    
    for example in tqdm(test_data, desc=f"Evaluating {model_name}"):
        inputs = preprocess_input(example["section"], example["text"])
        raw_pred = predict_label(model, inputs)
        cleaned_label = clean_prediction(raw_pred)
        
        true_labels.append(example["true_label"])
        pred_labels.append(cleaned_label)
    
    accuracy = accuracy_score(true_labels, pred_labels)
    class_report = classification_report(true_labels, pred_labels, output_dict=True)
    
    return {
        "model": model_name,
        "accuracy": accuracy,
        "precision_background": class_report["background"]["precision"],
        "recall_background": class_report["background"]["recall"],
        "precision_method": class_report["method"]["precision"],
        "recall_method": class_report["method"]["recall"],
        "precision_result": class_report["result"]["precision"],
        "recall_result": class_report["result"]["recall"],
    }

#%% [markdown]
#### 1. Load Base Model (Pre-trained)
#%%
base_model = AutoModelForSeq2SeqLM.from_pretrained("t5-base").to(device)
base_tokenizer = tokenizer

# # Add special tokens if missing
# special_tokens = ["[label]", "[rationale]"]
# base_tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})
# base_model.resize_token_embeddings(len(base_tokenizer))

#%% [markdown]
#### 2. Load Distilled Model (Fine-tuned)
#%%
distilled_model = AutoModelForSeq2SeqLM.from_pretrained("./distilled_t5_on_10_samples").to(device)
distilled_tokenizer = tokenizer

#%% [markdown]
#### 3. Evaluate Both Models
#%%
test_data = load_test_data(test_data_path)[:5]  # Use subset for faster evaluation

base_results = evaluate_model(base_model, base_tokenizer, test_data, "Base Model")
distilled_results = evaluate_model(distilled_model, distilled_tokenizer, test_data, "Distilled Model")

#%% [markdown]
#### 4. Display Comparison
#%%
results_df = pd.DataFrame([base_results, distilled_results])
print("\nPerformance Comparison:")
display(results_df.style
       .format("{:.2%}", subset=["accuracy", "precision_background", "recall_background", 
                                "precision_method", "recall_method", 
                                "precision_result", "recall_result"])
       .background_gradient(cmap="Blues", subset=["accuracy"]))

#%% [markdown]
#### 5. Sample Predictions Comparison
#%%
print("\nSample Prediction Comparison:")
sample_data = test_data[:3]  # First 3 examples

for example in sample_data:
    # Base model prediction
    inputs = preprocess_input(example["section"], example["text"])
    base_pred = clean_prediction(predict_label(base_model, base_tokenizer, inputs))
    
    # Distilled model prediction
    inputs = preprocess_input(example["section"], example["text"])
    distilled_pred = clean_prediction(predict_label(distilled_model, distilled_tokenizer, inputs))
    
    print(f"\nSection: {example['section']}")
    print(f"Text: {example['text'][:100]}...")
    print(f"True Label: {example['true_label']}")
    print(f"Base Model: {base_pred} | Distilled Model: {distilled_pred}")
    print("-" * 80)

Evaluating Base Model:  20%|██        | 1/5 [00:04<00:16,  4.01s/it]

Raw output IDs: tensor([    0,  5481,   254,    18, 26296,  4652,   188,   357,   516,     1],
       device='mps:0')
Decoded output: <pad> HPC-EUROPA2 project</s>
raw: HPC-EUROPA2 project
Prediction: hpc-europa2


Evaluating Base Model:  40%|████      | 2/5 [00:06<00:08,  2.89s/it]

Raw output IDs: tensor([    0, 32099, 16330,  5027,    10,    86,   811,     6,     8,   915,
          810], device='mps:0')
Decoded output: <pad><extra_id_0> Discussion Text: In addition, the present study
raw: Discussion Text: In addition, the present study
Prediction: discussion


Evaluating Base Model:  60%|██████    | 3/5 [00:07<00:04,  2.32s/it]

Raw output IDs: tensor([    0, 32099,     5, 32098,  5568,    10, 16330,  5027,    10,     3,
         8656], device='mps:0')
Decoded output: <pad><extra_id_0>.<extra_id_1> Section: Discussion Text: Several
raw: . Section: Discussion Text: Several
Prediction: .


Evaluating Base Model:  80%|████████  | 4/5 [00:08<00:01,  1.87s/it]

Raw output IDs: tensor([    0, 10747,     7,    15,     1], device='mps:0')
Decoded output: <pad> False</s>
raw: False
Prediction: false


Evaluating Base Model: 100%|██████████| 5/5 [00:10<00:00,  2.02s/it]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Raw output IDs: tensor([    0, 10747,     7,    15,     1], device='mps:0')
Decoded output: <pad> False</s>
raw: False
Prediction: false


Evaluating Distilled Model:  40%|████      | 2/5 [00:00<00:00,  6.21it/s]

Raw output IDs: tensor([0, 0], device='mps:0')
Decoded output: <pad><pad>
raw: 
Raw output IDs: tensor([0, 0], device='mps:0')
Decoded output: <pad><pad>
raw: 


Evaluating Distilled Model:  80%|████████  | 4/5 [00:00<00:00,  7.51it/s]

Raw output IDs: tensor([0, 0], device='mps:0')
Decoded output: <pad><pad>
raw: 
Raw output IDs: tensor([0, 0], device='mps:0')
Decoded output: <pad><pad>
raw: 


Evaluating Distilled Model: 100%|██████████| 5/5 [00:00<00:00,  7.18it/s]

Raw output IDs: tensor([0, 0], device='mps:0')
Decoded output: <pad><pad>
raw: 

Performance Comparison:



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Unnamed: 0,model,accuracy,precision_background,recall_background,precision_method,recall_method,precision_result,recall_result
0,Base Model,0.00%,0.00%,0.00%,0.00%,0.00%,0.00%,0.00%
1,Distilled Model,0.00%,0.00%,0.00%,0.00%,0.00%,0.00%,0.00%



Sample Prediction Comparison:


TypeError: predict_label() takes 2 positional arguments but 3 were given