### Reference LLM Distillation notebook: https://github.com/simranjeet97/LLM_Distillation/blob/main/LLM_Distillation.ipynb

In [16]:
!pip install -U transformers 

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




In [3]:
import os
import pandas as pd
import torch
from datasets import Dataset
from dotenv import load_dotenv
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    BitsAndBytesConfig
)
from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training

load_dotenv()
hf_token = os.getenv("HUGGINGFACE_API_KEY")

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# ====== Load dataset ======
def load_partition(path: str) -> Dataset:
    df = pd.read_csv(path).head(10)
    return Dataset.from_pandas(df)

dataset = load_partition("../Student_Training_Data/GPT.csv") ## should be GPT.csv
print(f"Loaded {len(dataset)} samples from dataset.") 

Loaded 10 samples from dataset.


In [26]:
# ====== Tokenizer & Model Setup ======
model_id = "google/gemma-3-1b-it"

tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=hf_token,
    trust_remote_code=True,
    torch_dtype=torch.float16,
)

# model = prepare_model_for_kbit_training(model)
# lora_config = LoraConfig(
#     r=8,
#     lora_alpha=32,
#     target_modules=["q_proj", "v_proj"],
#     lora_dropout=0.05,
#     bias="none",
#     task_type=TaskType.CAUSAL_LM
# )
# model = get_peft_model(model, lora_config) # TODO Why getting PEFT model? Paper and Reference notebook did not use


In [50]:
# ====== Format data ======
def format_for_distillation(example):
    # parse the row with columns: sectionName,string,unique_id,model_classification,reasoning
    section_name, text, unique_id, classification, reasoning = example["sectionName"], example["string"], example["unique_id"], example["model_classification"], example["reasoning"]

    input_text = (f"Classify the following scientific text as one of [background, method, result].\n\n"
                f"Section Name: {section_name}, Text: {text}\n"
                f"Reply with the classification and nothing else.\n")
    rationale = reasoning
    # Convert classification to numerical index
    label_map = {"background": 0, "method": 1, "result": 2}
    label_idx = torch.tensor(label_map[classification])

    # Tokenize the input and rationale
    input_encoded = tokenizer(input_text, padding="max_length", truncation=True, max_length=512, return_tensors="pt") 
    reasoning_encoded = tokenizer(rationale, padding="max_length", truncation=True, max_length=512, return_tensors="pt")
    
    # print(f"Label: {label} | Unique ID: {unique_id} | Input: {input_text} | Rationale: {rationale}")

    print(label_idx)
    return {
        "input_ids": input_encoded["input_ids"],#[0], # indices of tokens in the tokenizer's vocabulary
        "attention_mask": input_encoded["attention_mask"],#[0],
        "labels": label_idx,
        "teacher_rationale": reasoning_encoded["input_ids"],#[0], 
        "rationale_mask": reasoning_encoded["attention_mask"],#[0]
    }

def format_for_distillation_batch_true(examples):
    # Create lists to store the processed data
    input_ids_list = []
    attention_mask_list = []
    labels_list = []
    teacher_rationale_list = []
    rationale_mask_list = []
    label_map = {"background": 0, "method": 1, "result": 2}

    # Process each example in the batch
    for idx in range(len(examples["sectionName"])):
        section_name = examples["sectionName"][idx]
        text = examples["string"][idx]
        unique_id = examples["unique_id"][idx]
        classification = examples["model_classification"][idx]
        reasoning = examples["reasoning"][idx]

        input_text = (f"Classify the following scientific text as one of [background, method, result].\n\n"
                    f"Section Name: {section_name}, Text: {text}\n"
                    f"Reply with the classification and nothing else.\n")
        
        # Tokenize the input and rationale
        input_encoded = tokenizer(input_text, padding="max_length", truncation=True, max_length=512)
        reasoning_encoded = tokenizer(reasoning, padding="max_length", truncation=True, max_length=512)
        
        # print(f"Label: {classification} | Unique ID: {unique_id} | Input: {input_text[:100]}... | Rationale: {reasoning[:100]}...")
        
        # Append to lists
        input_ids_list.append(input_encoded["input_ids"])
        attention_mask_list.append(input_encoded["attention_mask"])
        labels_list.append(torch.tensor(label_map[classification]))
        teacher_rationale_list.append(reasoning_encoded["input_ids"])
        rationale_mask_list.append(reasoning_encoded["attention_mask"])
    
    
    print(labels_list)
    return {
        "input_ids": input_ids_list,
        "attention_mask": attention_mask_list,
        "labels": labels_list,
        "teacher_rationale": teacher_rationale_list,
        "rationale_mask": rationale_mask_list
    }

tokenized_dataset = dataset.map(format_for_distillation) #with batching is a bit faster.

Map: 100%|██████████| 10/10 [00:00<00:00, 1051.20 examples/s]

tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(1)
tensor(1)





In [46]:
# ====== Training Args ======
training_args = TrainingArguments(
    output_dir="gemma3-phase1",
    num_train_epochs=3,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    learning_rate=2e-5,
    max_steps=10,  
    logging_steps=1,
    save_strategy="no",
    remove_unused_columns=False,
    max_grad_norm=1.0,
    report_to="none"
)


In [55]:
class MultiTaskTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        labels = inputs.pop("labels")
        rationale_ids = inputs.pop("teacher_rationale", None)
        print(f"Labels: {labels} | Rationale IDs: {rationale_ids}")
        
        outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])

        # Reshape logits to [batch_size, num_classes]
        logits = outputs.logits[:, -1, :]  # Take last token's logits
        logits = logits[:, :3]  # Only take logits for the 3 classes
        loss_fn = torch.nn.CrossEntropyLoss()
        label_loss = loss_fn(logits, labels)
        
        if rationale_ids is not None:
            rationale_outputs = model(input_ids=rationale_ids, attention_mask=inputs["attention_mask"])
            rationale_loss = loss_fn(rationale_outputs.logits, rationale_ids)
            loss = label_loss + 0.5 * rationale_loss  # Weighted loss
        else:
            loss = label_loss
        
        return (loss, outputs) if return_outputs else loss

trainer = MultiTaskTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    eval_dataset=tokenized_dataset
)

In [56]:
trainer.train()
trainer.save_model("gemma3-phase1-v2")

Labels: tensor([0], device='mps:0') | Rationale IDs: tensor([[[     0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,  

ValueError: too many values to unpack (expected 4)

In [None]:
trainer = Trainer( # TODO Need to specify the loss function for the trainer. 
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset
) 

trainer.train() # TODO Why is training here when the loss function is defined below?? 
model.save_pretrained("gemma3-phase1")
tokenizer.save_pretrained("gemma3-phase1")

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


ValueError: The model did not return a loss from the inputs, only the following keys: logits. For reference, the inputs it received are input_ids,attention_mask,teacher_rationale,rationale_mask.

In [None]:
import torch.nn.functional as F

class ReasoningDistiller(Trainer): 
    def __init__(self, *args, reasoning_weight=0.5, use_reasoning_loss=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.reasoning_weight = reasoning_weight
        self.use_reasoning_loss = use_reasoning_loss

        self.reasoning_tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
        self.reasoning_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model( # TODO Forward pass needs to be on prompt and citation without the teacher response and classification
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            labels=inputs["labels"]
        )
        
        ce_loss = outputs.loss

        if self.use_reasoning_loss and "student_reasoning" in inputs:
            try:
                generated = model.generate(inputs["input_ids"], max_length=512)
                decoded = self.tokenizer.batch_decode(generated, skip_special_tokens=True)
                student_reasonings = [self.extract_reasoning(txt) for txt in decoded]
                teacher_reasonings = inputs["student_reasoning"]

                student_embeds = self.get_embeddings(student_reasonings)
                teacher_embeds = self.get_embeddings(teacher_reasonings)
                cosine_loss = 1 - F.cosine_similarity(student_embeds, teacher_embeds).mean()
                total_loss = ce_loss + self.reasoning_weight * cosine_loss
            except Exception as e:
                print(f"Skipping cosine loss due to error: {e}")
                total_loss = ce_loss
        else:
            total_loss = ce_loss

        return (total_loss, outputs) if return_outputs else total_loss

    def extract_reasoning(self, text):
        match = re.search(r'"reasoning"\s*:\s*"(.+?)"\s*}', text)
        return match.group(1).strip() if match else ""

    def get_embeddings(self, texts):
        inputs = self.reasoning_tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        with torch.no_grad():
            return self.reasoning_model(**inputs).last_hidden_state[:, 0, :]

In [None]:
from peft import PeftModel

model = AutoModelForCausalLM.from_pretrained("llama-student-phase1")
model = PeftModel.from_pretrained(model, "llama-student-phase1")

trainer = ReasoningDistiller(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    reasoning_weight=0.5,
    use_reasoning_loss=True
)

trainer.train()
model.save_pretrained("llama-student-phase2")
tokenizer.save_pretrained("llama-student-phase2")