### Reference LLM Distillation notebook: https://github.com/simranjeet97/LLM_Distillation/blob/main/LLM_Distillation.ipynb

### Install Dependencies
Install the latest version of the `transformers` library to ensure compatibility with the latest Hugging Face models and tokenizers.


In [None]:
!pip install -U transformers 

### Import necessary packages

This section loads all essential Python packages for the training workflow. We also load the Hugging Face token stored in the `.env` file here.

In [None]:
import os
import pandas as pd
import torch
from datasets import Dataset
from dotenv import load_dotenv
from transformers import (
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    AutoModelForSeq2SeqLM,
)

# Populate os.environ from a local .env file, then read the Hugging Face access
# token from it (None when HUGGINGFACE_API_KEY is not defined).
load_dotenv()
hf_token = os.environ.get("HUGGINGFACE_API_KEY")

### Tokenizer & Model Setup
Set up the tokenizer and model for training

In [None]:
# Tokenizer & Model Setup
student_checkpoint = "t5-base"  # Hugging Face Hub id of the student model to fine-tune

tokenizer = AutoTokenizer.from_pretrained(student_checkpoint, token=hf_token, trust_remote_code=True)
# The T5 tokenizer already defines a pad token; this fallback only matters if
# the checkpoint is swapped for a tokenizer that lacks one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Use the same access token as the tokenizer so gated/private checkpoints also load.
model = AutoModelForSeq2SeqLM.from_pretrained(student_checkpoint, token=hf_token, device_map="auto")

### Add Special Tokens
Add task-specific tokens (`[label]`, `[rationale]`) to the tokenizer and resize the model's embedding layer to accommodate them.


In [None]:
# Register the task prefixes ([label], [rationale]) as additional special tokens,
# then size the model's embedding table to the tokenizer's new vocabulary length.
special_token_map = {"additional_special_tokens": ["[label]", "[rationale]"]}
tokenizer.add_special_tokens(special_token_map)
vocab_len = len(tokenizer)
model.resize_token_embeddings(vocab_len)

### Load Dataset
Load the training dataset (training examples plus the teacher-generated classifications and teacher-generated rationales) from a CSV file and convert it into a Hugging Face `Dataset` object.

In [None]:
# Load dataset
def load_partition(path: str) -> Dataset:
    """Read the CSV file at *path* and wrap its rows in a Hugging Face ``Dataset``."""
    return Dataset.from_pandas(pd.read_csv(path))

student_training_csv = "../Student_Training_Data/The_King.csv"
dataset = load_partition(student_training_csv)
print(f"Loaded {len(dataset)} samples from dataset.")

### Helper Function: Add Special Tokens If Missing
Checks whether the `[label]` and `[rationale]` tokens are already in the tokenizer vocabulary and adds any that are missing.

In [None]:
def add_special_tokens_if_missing(tokenizer, model=None):
    """Add the ``[label]`` and ``[rationale]`` task tokens to ``tokenizer`` if absent.

    Args:
        tokenizer: Hugging Face tokenizer, updated in place.
        model: Optional model whose token embeddings are resized whenever new
            tokens are added, so every new token id has an embedding row.
            Defaults to ``None`` (no resize), matching the original behaviour.

    Returns:
        The same ``tokenizer`` object, for convenient re-assignment.
    """
    vocab = tokenizer.get_vocab()  # look the vocabulary up once for both checks
    missing = [tok for tok in ("[label]", "[rationale]") if tok not in vocab]

    if missing:
        tokenizer.add_special_tokens({"additional_special_tokens": missing})
        # Adding tokens after the model was sized would leave their ids without
        # embedding rows; keep the two in sync when the caller supplies the model.
        if model is not None:
            model.resize_token_embeddings(len(tokenizer))
    return tokenizer

# Update tokenizer with special tokens (and keep the model's embeddings in sync)
tokenizer = add_special_tokens_if_missing(tokenizer, model)


### Tokenize Dataset
Tokenizes the input data for both label classification and rationale generation tasks.

In [None]:
def mask_pad_tokens(ids, pad_token_id):
    """Return a copy of target ``ids`` with every ``pad_token_id`` replaced by -100.

    -100 is the ignore index of the cross-entropy losses used in training, so
    padded target positions stop contributing to the loss. (T5 maps -100 back
    to the pad id when it right-shifts labels into decoder inputs.)
    """
    return ids.masked_fill(ids == pad_token_id, -100)


def tokenize_function(examples):
    """Tokenize a batch of rows for the label and rationale tasks.

    Args:
        examples: Batched columns from ``dataset.map``; must contain
            ``sectionName``, ``string``, ``model_classification`` (teacher label
            text) and ``reasoning`` (teacher rationale text).

    Returns:
        A dict of tensors: base encoder inputs, label-task inputs and targets,
        and rationale-task inputs and targets. Padding positions in the two
        ``*_target_ids`` tensors are set to -100 so the losses ignore them.
    """
    # Shared "section + text" description of each citation
    base_texts = [
        f"Section Name: {sn}\nText: {txt}"
        for sn, txt in zip(examples["sectionName"], examples["string"])
    ]

    # Task-specific inputs: the prefix token tells the model which output to produce
    label_inputs = [f"[label] {text} \nLabel (either background, method or result):" for text in base_texts]
    rationale_inputs = [f"[rationale] {text} \nRationale:" for text in base_texts]

    # Base inputs without a task prefix (the trainer's collator in this notebook does not use them)
    base_encoded = tokenizer(
        base_texts,
        padding="max_length",
        truncation=True,
        max_length=256,
        return_tensors="pt"
    )

    # Label task inputs
    label_encoded = tokenizer(
        label_inputs,
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt"
    )

    # Label targets (label text, not class indices)
    label_targets = tokenizer(
        examples["model_classification"],
        padding="max_length",
        truncation=True,
        max_length=32,  # class labels are only a few tokens long
        return_tensors="pt"
    )

    # Rationale task inputs
    rationale_encoded = tokenizer(
        rationale_inputs,
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt"
    )

    # Rationale targets
    rationale_targets = tokenizer(
        examples["reasoning"],
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt"
    )

    pad_id = tokenizer.pad_token_id

    return {
        # Base inputs (shared between tasks)
        "base_input_ids": base_encoded.input_ids,
        "base_attention_mask": base_encoded.attention_mask,

        # Label prediction task (padding masked out of the targets)
        "label_input_ids": label_encoded.input_ids,
        "label_attention_mask": label_encoded.attention_mask,
        "label_target_ids": mask_pad_tokens(label_targets.input_ids, pad_id),

        # Rationale generation task (padding masked out of the targets)
        "rationale_input_ids": rationale_encoded.input_ids,
        "rationale_attention_mask": rationale_encoded.attention_mask,
        "rationale_target_ids": mask_pad_tokens(rationale_targets.input_ids, pad_id),
    }

# Tokenize the whole dataset in batches of 32. The raw CSV columns are dropped
# because every field the trainer needs is produced by tokenize_function.
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    batch_size=32,
    remove_columns=dataset.column_names,
)

# Expose every tokenized field as a PyTorch tensor
tensor_columns = [
    "base_input_ids",
    "base_attention_mask",
    "label_input_ids",
    "label_attention_mask",
    "label_target_ids",
    "rationale_input_ids",
    "rationale_attention_mask",
    "rationale_target_ids",
]
tokenized_dataset.set_format(type="torch", columns=tensor_columns)

### TrainingArguments Configuration

- `output_dir`: Directory where model checkpoints and training outputs will be saved
- `logging_dir`: Directory where training logs will be stored
- `fp16`: 16-bit (half-precision) mixed-precision training; it uses less memory than fp32 but has a narrower numeric range than bf16 (disabled here)
- `bf16`: bfloat16 precision format, which provides better numerical stability than fp16 while still offering memory savings (requires compatible hardware)
- `per_device_train_batch_size`: Number of training examples processed per device (GPU/TPU) in each forward pass
- `per_device_eval_batch_size`: Number of evaluation examples processed per device in each forward pass
- `gradient_accumulation_steps`: Accumulates gradients over 2 batches before performing a parameter update
- `learning_rate`: Initial learning rate for the optimizer
- `num_train_epochs`: Number of complete passes through the training dataset
- `report_to`: Integration with tracking platforms
- `save_strategy`: Automatic model checkpointing during training
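- `remove_unused_columns`: Set to `False` so the Trainer keeps dataset columns that do not match the model's `forward` arguments; otherwise the custom multi-task fields would be dropped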

In [None]:
training_args = TrainingArguments(
    output_dir="./results",
    fp16=False,
    # Enable bf16 mixed precision only on a CUDA GPU that supports it; requesting
    # bf16 on unsupported hardware (CPU, older GPUs) makes TrainingArguments error out.
    bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,
    learning_rate=5e-5,
    num_train_epochs=3,
    logging_dir="./logs",
    report_to="none",
    save_strategy="no",
    remove_unused_columns=False  # keep the custom multi-task columns
)

### Define trained model name
Sets the name used for saving or tracking the trained model.


In [None]:
# Directory name (relative to the working directory) for the fine-tuned model and tokenizer.
new_trained_model_name: str = "distilled_t5"

### Custom Multi-Task Trainer: `MultiTaskTrainer`

This section defines a custom subclass of Hugging Face's `Trainer` to handle **multi-task learning** with both label classification and rationale generation.

---

#### MultiTaskTrainer Class:

- Inherits from `transformers.Trainer` and overrides the below functions.
  
#### 1. `compute_loss(...)`

This method overrides the base `Trainer`’s loss computation to enable multi-task training. This method calculates the combined loss for both the label prediction and rationale generation tasks.

##### `Process`
- `Configuration`: Sets alpha = 0.3 as the weighting hyperparameter (λ) to balance between label and rationale tasks
  Initializes the CrossEntropyLoss function with `ignore_index=-100`, so target positions set to -100 are excluded from the loss

- `Label Task Processing`: Shifts the target ids right to create decoder inputs for the label task
  Passes label inputs through the model to generate label predictions
  Calculates cross-entropy loss between predicted and target label distributions


- `Rationale Task Processing`: Passes rationale inputs through the model
Retrieves the rationale loss directly from model outputs

- `Final Weighted Loss`: The two losses are combined using a task-balancing hyperparameter $ \alpha $.

    $$
    \text{total\_loss} = (1 - \alpha) \cdot \text{label\_loss} + \alpha \cdot \text{rationale\_loss}
    $$

#### `Return`

- If return_outputs is `True`: Returns a tuple containing the total loss and model outputs <br>
- If return_outputs is `False`: Returns only the total loss

---

#### 2. `prediction_step(...)`

This method handles model predictions during evaluation and testing phases.

##### `Process`

- `Input Preparation`: Separates inputs into label-specific and rationale-specific dictionaries


- `Individual Predictions`: Uses the parent Trainer.prediction_step method to get predictions for each task
Extracts loss, logits, and labels from both prediction outputs


- `Combined Loss Calculation`: Applies the weighting formula

    $$
    \text{total\_loss} = \alpha \cdot \text{label\_loss} + (1 - \alpha) \cdot \text{rationale\_loss}
    $$
  Note: The weighting formula here appears to be reversed from the one in compute_loss

#### `Return`

- If prediction_loss_only is `True`: Returns only the combined loss <br>
- If prediction_loss_only is `False`: Returns a tuple containing: The combined loss, a list containing logits from both tasks, a list containing labels from both tasks

In [None]:
class MultiTaskTrainer(Trainer):
    """``Trainer`` that jointly optimises label prediction and rationale generation.

    Each batch must provide encoder inputs and target ids for both tasks:
    ``label_input_ids`` / ``label_attention_mask`` / ``label_target_ids`` and
    ``rationale_input_ids`` / ``rationale_attention_mask`` / ``rationale_target_ids``.
    The per-task losses are combined as
    ``(1 - alpha) * label_loss + alpha * rationale_loss`` in both training and
    evaluation.
    """

    # λ hyperparameter from the paper: weight of the rationale loss. Defined once
    # so compute_loss and prediction_step cannot drift apart.
    alpha = 0.3

    @staticmethod
    def _run_task(model, inputs, prefix):
        """Run one task (``prefix`` is ``"label"`` or ``"rationale"``) through ``model``.

        The target ids are passed as ``labels``. Following standard Hugging Face
        seq2seq behaviour, the model right-shifts them into decoder inputs and
        returns the token-level cross-entropy in ``.loss``, ignoring positions
        equal to -100.
        """
        return model(
            input_ids=inputs[f"{prefix}_input_ids"],
            attention_mask=inputs[f"{prefix}_attention_mask"],
            labels=inputs[f"{prefix}_target_ids"],
            return_dict=True,
        )

    @classmethod
    def _combine_losses(cls, label_loss, rationale_loss):
        """Return the alpha-weighted sum of the two task losses."""
        return (1 - cls.alpha) * label_loss + cls.alpha * rationale_loss

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        """Return the combined multi-task loss for one training batch.

        Args:
            model: The seq2seq model being trained.
            inputs: Batch dict from the data collator (keys listed on the class).
            return_outputs: If True, also return ``(label_outputs, rationale_outputs)``.
            num_items_in_batch: Accepted for ``Trainer`` signature compatibility; unused.
            **kwargs: Accepted for ``Trainer`` signature compatibility; unused.

        Returns:
            ``total_loss``, or ``(total_loss, (label_outputs, rationale_outputs))``
            when ``return_outputs`` is True.
        """
        label_outputs = self._run_task(model, inputs, "label")
        rationale_outputs = self._run_task(model, inputs, "rationale")
        total_loss = self._combine_losses(label_outputs.loss, rationale_outputs.loss)
        return (total_loss, (label_outputs, rationale_outputs)) if return_outputs else total_loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """Evaluate both tasks on one batch without tracking gradients.

        The base ``Trainer.prediction_step`` is not reused: for batches carrying a
        ``labels`` key it calls ``self.compute_loss``, which expects the multi-task
        keys above and would raise ``KeyError``.

        Args:
            model: The model under evaluation.
            inputs: Batch dict from the data collator.
            prediction_loss_only: If True, return only the loss.
            ignore_keys: Accepted for ``Trainer`` compatibility; unused because only
                each task's ``logits`` are returned.

        Returns:
            ``(loss, None, None)`` if ``prediction_loss_only``, otherwise
            ``(loss, [label_logits, rationale_logits], [label_target_ids, rationale_target_ids])``.
        """
        inputs = self._prepare_inputs(inputs)  # move tensors to the training device
        with torch.no_grad(), self.compute_loss_context_manager():
            label_outputs = self._run_task(model, inputs, "label")
            rationale_outputs = self._run_task(model, inputs, "rationale")

        # .mean() collapses per-device losses when the model is wrapped for multi-GPU;
        # it is a no-op on a scalar loss.
        loss = self._combine_losses(
            label_outputs.loss.mean(), rationale_outputs.loss.mean()
        ).detach()

        if prediction_loss_only:
            return (loss, None, None)

        return (
            loss,
            [label_outputs.logits.detach(), rationale_outputs.logits.detach()],
            [inputs["label_target_ids"], inputs["rationale_target_ids"]],
        )

### Trainer Initialization:

This code initializes and executes the training process for a multi-task citation intent classification model using the custom MultiTaskTrainer class.

#### Parameters

- `model`: The pre-trained transformer model being fine-tuned
- `args`: Training arguments defining hyperparameters and settings
- `train_dataset`: The tokenized dataset used for training
- `eval_dataset`: The tokenized dataset used for evaluation (same as training dataset in this case)
- `data_collator`: A function that collates individual data points into batches. It extracts and stacks tensors for both tasks (label and rationale).

In [None]:
# Initialize Trainer
_MULTITASK_KEYS = (
    "label_input_ids",
    "label_attention_mask",
    "label_target_ids",
    "rationale_input_ids",
    "rationale_attention_mask",
    "rationale_target_ids",
)


def _collate_multitask(batch):
    """Stack the label- and rationale-task tensors of each example into batch tensors.

    Only the six task fields are kept; the ``base_*`` fields are not forwarded.
    """
    return {key: torch.stack([row[key] for row in batch]) for key in _MULTITASK_KEYS}


trainer = MultiTaskTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    eval_dataset=tokenized_dataset,  # evaluation reuses the training data
    data_collator=_collate_multitask,
)


### Training Execution
This cell starts training with the configured trainer. After training, it saves the trained model weights, configuration, and tokenizer to the directory named by `new_trained_model_name`.

In [None]:
save_dir = f"./{new_trained_model_name}"
trainer.train()
# Persist the weights/config, then the tokenizer (with its task tokens) next to them.
trainer.save_model(save_dir)
tokenizer.save_pretrained(save_dir)

## Model Evaluation
This part of the code loads the trained model for evaluation.

In [None]:
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sklearn.metrics import accuracy_score, classification_report

# Configuration
# TODO: this model was trained on only ~1,000 samples. The paper reports that 25% of
# the full training data is already sufficient, so a small subset was used as a first test.
model_path = f"./{new_trained_model_name}"
test_data_path = "../data/test.jsonl"

# Prefer a CUDA GPU, then Apple MPS, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Load the distilled model in float32 and the tokenizer saved alongside it.
model = AutoModelForSeq2SeqLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float32).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

def load_test_data(file_path):
    """Load citation-intent test examples from a JSON Lines file.

    Each non-blank line must be a JSON object with ``sectionName``, ``string`` and
    ``label`` keys. These are renamed to ``section``, ``text`` and ``true_label``.

    Args:
        file_path: Path to the ``.jsonl`` file.

    Returns:
        A list of dicts, one per example, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks one of the required keys.
    """
    test_data = []
    # Decode explicitly as UTF-8 so non-ASCII citation text reads the same on every platform.
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines, e.g. an empty final line
            entry = json.loads(line)
            test_data.append({
                "section": entry["sectionName"],
                "text": entry["string"],
                "true_label": entry["label"]
            })
    return test_data

def build_label_prompt(section, text):
    """Build the ``[label]`` task prompt with the same wording as ``tokenize_function`` used in training."""
    return f"[label] Section Name: {section}\nText: {text} \nLabel (either background, method or result):"


def preprocess_input(section, text):
    """Tokenize one example for label prediction and move it to ``device``.

    The prompt must match the training-time label prompt exactly (including the
    ``"Section Name:"`` prefix); any wording drift changes what the model sees.

    Returns:
        A ``BatchEncoding`` with ``input_ids`` and ``attention_mask`` of shape
        ``(1, 512)``, placed on ``device``.
    """
    return tokenizer(
        build_label_prompt(section, text),
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt"
    ).to(device)

def predict_label(model, inputs):
    """Generate a short label string for one encoded example.

    Decoding is deterministic 3-beam search limited to 10 new tokens. The raw
    generated ids are printed for debugging. The first sequence is decoded with
    special tokens stripped.
    """
    pad_id = tokenizer.pad_token_id
    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=10,
        do_sample=False,         # deterministic: no sampling, so no temperature
        num_beams=3,             # beam search
        early_stopping=False,
        decoder_start_token_id=pad_id,  # T5 starts decoding from the pad token
        pad_token_id=pad_id,
    )
    with torch.no_grad():
        generated = model.generate(**generation_kwargs)

    first_sequence = generated[0]
    print("Raw output IDs:", first_sequence)
    return tokenizer.decode(first_sequence, skip_special_tokens=True)

def clean_prediction(raw_prediction):
    """Map a raw generated string to one of the valid citation-intent labels.

    The first whitespace-separated word is lower-cased and stripped of
    surrounding punctuation, so ``"Method."`` becomes ``"method"``. The word is
    returned if it is ``background``, ``method`` or ``result``. Otherwise
    ``"unknown"`` is returned, including for empty or whitespace-only output.
    The raw string and the extracted word are printed for debugging.
    """
    import string

    print(f"Raw: {raw_prediction}")
    words = raw_prediction.split()
    if not words:
        # Empty or whitespace-only output: nothing to classify.
        return "unknown"

    prediction = words[0].strip(string.punctuation).lower()
    valid_labels = {"background", "method", "result"}
    print(f"Prediction: {prediction}")
    return prediction if prediction in valid_labels else "unknown"

# Load the evaluation split
test_data = load_test_data(test_data_path)

# Predict a label for every example, collecting gold and predicted labels in order
true_labels, pred_labels = [], []

for example in test_data:
    encoded = preprocess_input(example["section"], example["text"])
    cleaned_label = clean_prediction(predict_label(model, encoded))

    true_labels.append(example["true_label"])
    pred_labels.append(cleaned_label)

    # Per-example trace (optional)
    print(f"Section: {example['section']}")
    print(f"Text: {example['text'][:100]}...")
    print(f"True: {example['true_label']} | Pred: {cleaned_label}")
    print("-" * 80)

# Overall accuracy
accuracy = accuracy_score(true_labels, pred_labels)
print(f"\nTest Accuracy: {accuracy:.4f}")

# Write a two-column CSV of gold vs. predicted labels
with open("predictions_t5_trained.csv", "w") as f:
    f.write("true_label,predicted_label\n")
    f.writelines(f"{true},{pred}\n" for true, pred in zip(true_labels, pred_labels))

In [None]:
#%%
import pandas as pd
from tqdm import tqdm

def evaluate_and_save(model, tokenizer, test_data, model_name="Model", save_path=None):
    """Evaluate ``model`` on ``test_data``, optionally save predictions, and return metrics.

    Args:
        model: Seq2seq model passed to ``predict_label``.
        tokenizer: NOTE(review): currently unused. ``preprocess_input`` and
            ``predict_label`` read the module-level ``tokenizer``, so every model
            is evaluated with that tokenizer. This is only harmless while the
            tokenizers being compared map text to identical ids.
        test_data: List of dicts with ``section``, ``text`` and ``true_label`` keys
            (as returned by ``load_test_data``).
        model_name: Name shown in the progress bar and stored in the result.
        save_path: If given, per-example predictions are written there as CSV.

    Returns:
        Dict with ``model``, ``accuracy`` and ``precision_*`` / ``recall_*`` for
        each of background, method and result.
    """
    true_labels = []
    pred_labels = []
    raw_preds = []
    sections = []
    texts = []

    for example in tqdm(test_data, desc=f"Evaluating {model_name}"):
        inputs = preprocess_input(example["section"], example["text"])
        raw_pred = predict_label(model, inputs)
        cleaned_label = clean_prediction(raw_pred)

        # Collect data for CSV
        sections.append(example["section"])
        texts.append(example["text"])
        true_labels.append(example["true_label"])
        pred_labels.append(cleaned_label)
        raw_preds.append(raw_pred)

    results_df = pd.DataFrame({
        "section": sections,
        "text": texts,
        "true_label": true_labels,
        "predicted_label": pred_labels,
        "raw_prediction": raw_preds
    })

    if save_path:
        results_df.to_csv(save_path, index=False)
        print(f"Saved predictions to {save_path}")

    accuracy = accuracy_score(true_labels, pred_labels)

    # Pin the label set. With a small test subset, a class can be absent from both
    # true_labels and pred_labels; classification_report would then omit its key
    # and the lookups below would raise KeyError. zero_division=0 reports 0.0 for
    # a class that is never predicted instead of emitting a warning.
    class_labels = ["background", "method", "result"]
    class_report = classification_report(
        true_labels, pred_labels, labels=class_labels, output_dict=True, zero_division=0
    )

    metrics = {"model": model_name, "accuracy": accuracy}
    for label in class_labels:
        metrics[f"precision_{label}"] = class_report[label]["precision"]
        metrics[f"recall_{label}"] = class_report[label]["recall"]
    return metrics
#%% [markdown]
#### 1. Load Base Model (Pre-trained)
#%%
# Baseline: the untouched pre-trained checkpoint
baseline_checkpoint = "t5-base"
base_model = AutoModelForSeq2SeqLM.from_pretrained(baseline_checkpoint).to(device)
base_tokenizer = AutoTokenizer.from_pretrained(baseline_checkpoint)

# Give the baseline tokenizer the same task-prefix tokens the distilled model
# uses, then resize the embedding table to the tokenizer's new length.
special_tokens = ["[label]", "[rationale]"]
base_tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})
base_model.resize_token_embeddings(len(base_tokenizer))

#%% [markdown]
#### 2. Load Distilled Model (Fine-tuned)
#%%
# Fine-tuned student saved by the training cell
distilled_dir = f"./{new_trained_model_name}"
distilled_model = AutoModelForSeq2SeqLM.from_pretrained(distilled_dir).to(device)
distilled_tokenizer = AutoTokenizer.from_pretrained(distilled_dir)

#%% [markdown]
#### 3. Evaluate Both Models
#%%
test_data = load_test_data(test_data_path)[:5]  # small subset keeps the comparison quick

# Score both models on the same subset, writing each one's per-example predictions to CSV
base_results = evaluate_and_save(
    base_model, base_tokenizer, test_data,
    model_name="Base Model", save_path="base_model_predictions.csv",
)
distilled_results = evaluate_and_save(
    distilled_model, distilled_tokenizer, test_data,
    model_name="Distilled Model", save_path="distilled_model_predictions.csv",
)

#%% [markdown]
#### 4. Display Comparison
#%%
results_df = pd.DataFrame([base_results, distilled_results])
print("\nPerformance Comparison:")

# Render every metric as a percentage and shade the accuracy column
metric_columns = [
    "accuracy",
    "precision_background", "recall_background",
    "precision_method", "recall_method",
    "precision_result", "recall_result",
]
styled = results_df.style.format("{:.2%}", subset=metric_columns)
styled = styled.background_gradient(cmap="Blues", subset=["accuracy"])
display(styled)

#%% [markdown]
#### 5. Sample Predictions Comparison
#%%
print("\nSample Prediction Comparison:")
sample_data = test_data[:3]  # First 3 examples

for example in sample_data:
    # Both models consume the same encoded prompt, so tokenize it only once.
    inputs = preprocess_input(example["section"], example["text"])
    base_pred = clean_prediction(predict_label(base_model, inputs))
    distilled_pred = clean_prediction(predict_label(distilled_model, inputs))

    print(f"\nSection: {example['section']}")
    print(f"Text: {example['text'][:100]}...")
    print(f"True Label: {example['true_label']}")
    print(f"Base Model: {base_pred} | Distilled Model: {distilled_pred}")
    print("-" * 80)