### Reference LLM Distillation notebook: https://github.com/simranjeet97/LLM_Distillation/blob/main/LLM_Distillation.ipynb

In [69]:
!pip install -U transformers 

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




In [70]:
import os
import pandas as pd
import torch
from datasets import Dataset
from dotenv import load_dotenv
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    BitsAndBytesConfig,
    AutoModelForSequenceClassification
)
from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training

load_dotenv()
hf_token = os.getenv("HUGGINGFACE_API_KEY")

In [71]:
# ====== Load dataset ======
def load_partition(path: str) -> Dataset:
    df = pd.read_csv(path).head(10)
    return Dataset.from_pandas(df)

dataset = load_partition("../Student_Training_Data/GPT.csv") ## should be GPT.csv
print(f"Loaded {len(dataset)} samples from dataset.") 

Loaded 10 samples from dataset.


In [None]:
# ====== Tokenizer & Model Setup ======
model_id = "google-bert/bert-base-uncased" #"google/gemma-3-1b-it"

tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# model = AutoModelForSequenceClassification.from_pretrained(model_id, token=hf_token, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=hf_token,
    trust_remote_code=True,
    torch_dtype=torch.float16,
)

# model = prepare_model_for_kbit_training(model)
# lora_config = LoraConfig(
#     r=8,
#     lora_alpha=32,
#     target_modules=["q_proj", "v_proj"],
#     lora_dropout=0.05,
#     bias="none",
#     task_type=TaskType.CAUSAL_LM
# )
# model = get_peft_model(model, lora_config) # TODO Why getting PEFT model? Paper and Reference notebook did not use


If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`


In [None]:
## Ignore this part, just for understanding
# encoded_text = tokenizer.tokenize("Paris is the what of France?", return_tensors="pt").to(model.device)

# '''The model(encoded_text) call is more commonly used during training or when you want direct access 
# to the model's raw predictions, while generate() is used when you want the model to complete/continue 
# a sequence.
# '''
# print(encoded_text)

# outputs = model(encoded_text) # different from model.generate which produces logits and loss if labels are provided.
# print(outputs) # logits, loss (if label was given), hidden_states, attentions

# completion = model.generate(encoded_text, max_length=50)
# print(completion)

# decoded_text = tokenizer.decode(completion[0], skip_special_tokens=True)
# print(decoded_text)

CausalLMOutputWithCrossAttentions(loss=None, logits=tensor([[[ -7.0781,  -7.0195,  -7.0234,  ...,  -6.1992,  -6.2891,  -4.2930],
         [-10.1797, -10.5234, -10.4297,  ...,  -9.0859,  -9.5234,  -5.8125],
         [-11.9531, -12.0312, -12.1016,  ...,  -9.8359,  -8.3203,  -9.9219],
         ...,
         [-11.2969, -12.0391, -11.9844,  ...,  -9.0234, -10.1484,  -7.1445],
         [-12.4375, -12.2266, -12.7891,  ..., -10.4844, -11.9922,  -4.1992],
         [-11.7969, -11.5000, -12.0156,  ..., -10.4453, -10.6094,  -8.8047]]],
       device='mps:0', dtype=torch.float16, grad_fn=<LinearBackward0>), past_key_values=None, hidden_states=None, attentions=None, cross_attentions=None)
tensor([[ 101, 3000, 2003, 1996, 2054, 1997, 2605, 1029,  102, 1012, 1012, 1012,
         1012, 1012, 1012, 1012, 1012, 1012, 1012, 1012, 1012, 1012, 1012, 1012,
         1012, 1012, 1012, 1012, 1012, 1012, 1012, 1012, 1012, 1012, 1012, 1012,
         1012, 1012, 1012, 1012, 1012, 1012, 1012, 1012, 1012, 1012, 1012

IndexError: index 1 is out of bounds for dimension 0 with size 1

In [177]:
# ====== Format data ======
def format_for_distillation(example):
    # parse the row with columns: sectionName,string,unique_id,model_classification,reasoning
    section_name, text, unique_id, classification, reasoning = example["sectionName"], example["string"], example["unique_id"], example["model_classification"], example["reasoning"]

    input_text = (f"Classify the following scientific text as one of [background, method, result].\n\n"
                f"Section Name: {section_name}, Text: {text}\n"
                f"Reply with the classification and nothing else.\n")
                
    # Convert classification to numerical index
    label_map = {"background": 0, "method": 1, "result": 2}
    label_idx = torch.tensor(label_map[classification])

    # Tokenize the input and rationale
    input_encoded = tokenizer(input_text, padding="max_length", truncation=True, max_length=512, return_tensors="pt") 
    reasoning_encoded = tokenizer(reasoning, padding="max_length", truncation=True, max_length=512, return_tensors="pt")
    
    # print(f"Label: {label} | Unique ID: {unique_id} | Input: {input_text} | Rationale: {rationale}")

    print(label_idx)
    return {
        "input_ids": input_encoded["input_ids"],#[0], # indices of tokens in the tokenizer's vocabulary
        "attention_mask": input_encoded["attention_mask"],#[0],
        "labels": label_idx,
        "rationale_ids": reasoning_encoded["input_ids"],#[0], 
        "rationale_mask": reasoning_encoded["attention_mask"],#[0]
    }

### Just in case we want to do it by batch (but actually produces the same result as the one above.)
def format_for_distillation_batch_true(examples):
    # Create lists to store the processed data
    input_ids_list = []
    attention_mask_list = []
    labels_list = []
    teacher_rationale_list = []
    rationale_mask_list = []
    label_map = {"background": 0, "method": 1, "result": 2}

    # Process each example in the batch
    for idx in range(len(examples["sectionName"])):
        section_name = examples["sectionName"][idx]
        text = examples["string"][idx]
        unique_id = examples["unique_id"][idx]
        classification = examples["model_classification"][idx]
        reasoning = examples["reasoning"][idx]

        input_text = (f"Classify the following scientific text as one of [background, method, result].\n\n"
                    f"Section Name: {section_name}, Text: {text}\n"
                    f"Reply with the classification and nothing else.\n")
        
        # Tokenize the input and rationale
        input_encoded = tokenizer(input_text, padding="max_length", truncation=True, max_length=512)
        reasoning_encoded = tokenizer(reasoning, padding="max_length", truncation=True, max_length=512)
        
        # print(f"Label: {classification} | Unique ID: {unique_id} | Input: {input_text[:100]}... | Rationale: {reasoning[:100]}...")
        
        # Append to lists
        input_ids_list.append(input_encoded["input_ids"])
        attention_mask_list.append(input_encoded["attention_mask"])
        labels_list.append(torch.tensor(label_map[classification]))
        teacher_rationale_list.append(reasoning_encoded["input_ids"])
        rationale_mask_list.append(reasoning_encoded["attention_mask"])
    
    
    print(labels_list)
    return {
        "input_ids": input_ids_list,
        "attention_mask": attention_mask_list,
        "labels": labels_list,
        "rationale_ids": teacher_rationale_list,
        "rationale_mask": rationale_mask_list
    }

tokenized_dataset = dataset.map(format_for_distillation_batch_true, batched=True)
# tokenized_dataset_unbatched = dataset.map(format_for_distillation, batched=False) #, remove_columns=['sectionName', 'string', 'unique_id', 'model_classification', 'reasoning']) #with batching is a bit faster.
                                # remove_columns=['sectionName', 'string', 'unique_id', 'model_classification', 'reasoning']) #with batching is a bit faster.

Map: 100%|██████████| 10/10 [00:00<00:00, 112.68 examples/s]

[tensor(0), tensor(0), tensor(0), tensor(0), tensor(2), tensor(2), tensor(0), tensor(2), tensor(1), tensor(1)]





In [178]:
tokenized_dataset[0] #for batched version, the tokenzied stuff have no additional empty dimension, like [101, 26268,
# tokenized_dataset_unbatched[0] #for UNbatched version, the tokenzied stuff have an additional empty dimension, like [[101, 26268,

{'sectionName': 'Introduction',
 'string': 'However, how frataxin interacts with the Fe-S cluster biosynthesis components remains unclear as direct one-to-one interactions with each component were reported (IscS [12,22], IscU/Isu1 [6,11,16] or ISD11/Isd11 [14,15]).',
 'unique_id': '1872080baa7d30ec8fb87be9a65358cd3a7fb649>894be9b4ea46a5c422e81ef3c241072d4c73fdc0_11',
 'model_classification': 'background',
 'reasoning': 'The citation provides context, references prior work, or summarizes existing knowledge.',
 'input_ids': [101,
  26268,
  1996,
  2206,
  4045,
  3793,
  2004,
  2028,
  1997,
  1031,
  4281,
  1010,
  4118,
  1010,
  2765,
  1033,
  1012,
  2930,
  2171,
  1024,
  4955,
  1010,
  3793,
  1024,
  2174,
  1010,
  2129,
  25312,
  2696,
  20303,
  11835,
  2015,
  2007,
  1996,
  10768,
  1011,
  1055,
  9324,
  16012,
  6508,
  3372,
  24124,
  6177,
  3464,
  10599,
  2004,
  3622,
  2028,
  1011,
  2000,
  1011,
  2028,
  10266,
  2007,
  2169,
  6922,
  2020,
  2988,
 

In [179]:
# ====== Training Args ======
training_args = TrainingArguments( ## Original Training Args
    output_dir="gemma3-phase1",
    eval_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    save_strategy="epoch",
    push_to_hub=False,
    remove_unused_columns=False
)

#### Training args in phase 1 distillation before edits
#     num_train_epochs=3,
#     per_device_train_batch_size=1,
#     gradient_accumulation_steps=1,
#     learning_rate=2e-5,
#     max_steps=10,  
#     logging_steps=1,
#     save_strategy="no",
#     remove_unused_columns=False,
#     max_grad_norm=1.0,
#     report_to="none"
# )


In [180]:
class MultiTaskTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        print(inputs.keys())
        labels = inputs.pop("labels")
        rationale_ids = inputs.pop("rationale_ids", None)
        print(f"Labels: {labels} | Rationale IDs: {rationale_ids}")
        
        # outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
        outputs = model(
            **inputs
            # input_ids=inputs["input_ids"],
            # attention_mask=inputs["attention_mask"],
            # labels=inputs["labels"],
            # return_dict=True,
        )
        print(f"Outputs: {outputs}")
        print(f"logits: {outputs.logits}")

        # Reshape logits to [batch_size, num_classes]
        logits = outputs.logits[:, -1, :]  # Take last token's logits
        logits = logits[:, :3]  # Only take logits for the 3 classes
        loss_fn = torch.nn.CrossEntropyLoss()
        label_loss = loss_fn(logits, labels)
        
        if rationale_ids is not None:
            rationale_outputs = model(input_ids=rationale_ids, attention_mask=inputs["attention_mask"])
            rationale_loss = loss_fn(rationale_outputs.logits, rationale_ids)
            loss = label_loss + 0.5 * rationale_loss  # Weighted loss
        else:
            loss = label_loss
        
        return (loss, outputs) if return_outputs else loss

trainer = MultiTaskTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    eval_dataset=tokenized_dataset
)

In [181]:
trainer.train()
trainer.save_model("gemma3-phase1-v2")

dict_keys(['input_ids', 'attention_mask', 'labels', 'rationale_ids', 'rationale_mask'])
Labels: tensor([0, 0, 0, 1, 2, 2, 0, 1], device='mps:0') | Rationale IDs: tensor([[  101,  1996, 11091,  ...,     0,     0,     0],
        [  101,  1996, 11091,  ...,     0,     0,     0],
        [  101,  1996, 11091,  ...,     0,     0,     0],
        ...,
        [  101,  1996, 11091,  ...,     0,     0,     0],
        [  101,  1996, 11091,  ...,     0,     0,     0],
        [  101,  1996, 11091,  ...,     0,     0,     0]], device='mps:0')
Outputs: CausalLMOutputWithCrossAttentions(loss=None, logits=tensor([[[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
       

RuntimeError: Expected target size [8, 30522], got [8, 512]

In [None]:
trainer = Trainer( # TODO Need to specify the loss function for the trainer. 
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset
) 

trainer.train() # TODO Why is training here when the loss function is defined below?? 
model.save_pretrained("gemma3-phase1")
tokenizer.save_pretrained("gemma3-phase1")

ValueError: You have set `args.eval_strategy` to IntervalStrategy.EPOCH but you didn't pass an `eval_dataset` to `Trainer`. Either set `args.eval_strategy` to `no` or pass an `eval_dataset`. 

In [None]:
import torch.nn.functional as F

class ReasoningDistiller(Trainer): 
    def __init__(self, *args, reasoning_weight=0.5, use_reasoning_loss=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.reasoning_weight = reasoning_weight
        self.use_reasoning_loss = use_reasoning_loss

        self.reasoning_tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
        self.reasoning_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model( # TODO Forward pass needs to be on prompt and citation without the teacher response and classification
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            labels=inputs["labels"]
        )
        
        ce_loss = outputs.loss

        if self.use_reasoning_loss and "student_reasoning" in inputs:
            try:
                generated = model.generate(inputs["input_ids"], max_length=512)
                decoded = self.tokenizer.batch_decode(generated, skip_special_tokens=True)
                student_reasonings = [self.extract_reasoning(txt) for txt in decoded]
                teacher_reasonings = inputs["student_reasoning"]

                student_embeds = self.get_embeddings(student_reasonings)
                teacher_embeds = self.get_embeddings(teacher_reasonings)
                cosine_loss = 1 - F.cosine_similarity(student_embeds, teacher_embeds).mean()
                total_loss = ce_loss + self.reasoning_weight * cosine_loss
            except Exception as e:
                print(f"Skipping cosine loss due to error: {e}")
                total_loss = ce_loss
        else:
            total_loss = ce_loss

        return (total_loss, outputs) if return_outputs else total_loss

    def extract_reasoning(self, text):
        match = re.search(r'"reasoning"\s*:\s*"(.+?)"\s*}', text)
        return match.group(1).strip() if match else ""

    def get_embeddings(self, texts):
        inputs = self.reasoning_tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        with torch.no_grad():
            return self.reasoning_model(**inputs).last_hidden_state[:, 0, :]

In [None]:
from peft import PeftModel

model = AutoModelForCausalLM.from_pretrained("llama-student-phase1")
model = PeftModel.from_pretrained(model, "llama-student-phase1")

trainer = ReasoningDistiller(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    reasoning_weight=0.5,
    use_reasoning_loss=True
)

trainer.train()
model.save_pretrained("llama-student-phase2")
tokenizer.save_pretrained("llama-student-phase2")