In [1]:
import os
import pandas as pd
import torch
from datasets import Dataset
from dotenv import load_dotenv
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    BitsAndBytesConfig
)
from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training

load_dotenv()
hf_token = os.getenv("HUGGINGFACE_API_KEY")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ====== Load dataset ======
def load_partition(path: str) -> Dataset:
    df = pd.read_csv(path)
    return Dataset.from_pandas(df)

dataset = load_partition("./merged_dataset.csv")

In [3]:
# ====== Tokenizer & Model Setup ======
model_id = "meta-llama/Llama-3.2-1B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=hf_token,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
)

model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM
)
model = get_peft_model(model, lora_config)


`low_cpu_mem_usage` was None, now default to True since model is quantized.


In [4]:
# ====== Format data ======
def format_for_distillation(examples):
    prompts, responses, reasonings = [], [], []
    for text, reasoning, classification in zip(examples["string"], examples["reasoning"], examples["model_classification"]):
        prompt = (f"<instruction>Classify the following scientific text as one of [background, method, result].\n\n"
                  f"Text: {text}\n"
                  f"Provide your classification and reasoning in JSON format.</instruction>")
        response = f'<response>{{"classification": "{classification}", "reasoning": "{reasoning}"}}'
        prompts.append(prompt + response)
        reasonings.append(reasoning)

    tokenized = tokenizer(prompts, padding="max_length", truncation=True, max_length=512, return_tensors="pt")
    labels = tokenized["input_ids"].clone()

    # Mask out instruction part
    for i, input_ids in enumerate(tokenized["input_ids"]):
        response_ids = tokenizer("<response>", add_special_tokens=False)["input_ids"]
        for j in range(len(input_ids) - len(response_ids)):
            if input_ids[j:j+len(response_ids)].tolist() == response_ids:
                labels[i, :j+len(response_ids)] = -100
                break

    return {
        "input_ids": tokenized["input_ids"],
        "attention_mask": tokenized["attention_mask"],
        "labels": labels,
        "student_reasoning": reasonings  # Keep for Phase 2
    }

tokenized_dataset = dataset.map(format_for_distillation, batched=True, remove_columns=["id"])

Map: 100%|██████████| 8194/8194 [01:08<00:00, 120.32 examples/s]


In [None]:
# ====== Training Args ======
training_args = TrainingArguments(
    output_dir="llama-student-phase1",
    num_train_epochs=2,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=1e-4,
    bf16=True,
    logging_steps=10,
    save_strategy="epoch",
    remove_unused_columns=False,
    max_grad_norm=1.0,
    report_to="none"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset
)

trainer.train()
model.save_pretrained("llama-student-phase1")
tokenizer.save_pretrained("llama-student-phase1")

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.
  return fn(*args, **kwargs)


Step,Training Loss


In [None]:
import torch.nn.functional as F

class ReasoningDistiller(Trainer):
    def __init__(self, *args, reasoning_weight=0.5, use_reasoning_loss=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.reasoning_weight = reasoning_weight
        self.use_reasoning_loss = use_reasoning_loss

        self.reasoning_tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
        self.reasoning_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            labels=inputs["labels"]
        )
        ce_loss = outputs.loss

        if self.use_reasoning_loss and "student_reasoning" in inputs:
            try:
                generated = model.generate(inputs["input_ids"], max_length=512)
                decoded = self.tokenizer.batch_decode(generated, skip_special_tokens=True)
                student_reasonings = [self.extract_reasoning(txt) for txt in decoded]
                teacher_reasonings = inputs["student_reasoning"]

                student_embeds = self.get_embeddings(student_reasonings)
                teacher_embeds = self.get_embeddings(teacher_reasonings)
                cosine_loss = 1 - F.cosine_similarity(student_embeds, teacher_embeds).mean()
                total_loss = ce_loss + self.reasoning_weight * cosine_loss
            except Exception as e:
                print(f"Skipping cosine loss due to error: {e}")
                total_loss = ce_loss
        else:
            total_loss = ce_loss

        return (total_loss, outputs) if return_outputs else total_loss

    def extract_reasoning(self, text):
        match = re.search(r'"reasoning"\s*:\s*"(.+?)"\s*}', text)
        return match.group(1).strip() if match else ""

    def get_embeddings(self, texts):
        inputs = self.reasoning_tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        with torch.no_grad():
            return self.reasoning_model(**inputs).last_hidden_state[:, 0, :]

In [None]:
from peft import PeftModel

model = AutoModelForCausalLM.from_pretrained("llama-student-phase1")
model = PeftModel.from_pretrained(model, "llama-student-phase1")

trainer = ReasoningDistiller(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    reasoning_weight=0.5,
    use_reasoning_loss=True
)

trainer.train()
model.save_pretrained("llama-student-phase2")
tokenizer.save_pretrained("llama-student-phase2")