In [None]:
import random
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List

import torch
from accelerate.utils import set_seed
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from torch.nn.utils.rnn import pad_sequence
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
)

# Put the parent of the current working directory on the import path so the
# `src` package imported below can be resolved.
# NOTE(review): assumes the kernel's working directory is a direct child of
# the repository root (e.g. `notebooks/`) — confirm.
sys.path.append(str(Path.cwd().resolve().parent))

from src.config import (
    MODELS_DIR,
    PROCESSED_DATA_DIR,
    TEACHER_SYSTEM_PROMPT,
    TEACHER_USER_PROMPT,
)

In [None]:
# Single seed shared by every random number generator seeded in this cell
# (and reused nowhere else unless a later cell reads SEED explicitly).
SEED = 42

# accelerate helper; presumably seeds several libraries at once, which would
# make the explicit calls below redundant but harmless — TODO confirm.
set_seed(SEED)

# Explicit per-library seeding: Python's `random` and PyTorch (CPU).
random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    # Seed the current CUDA device and then all CUDA devices.
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

In [None]:
# Base checkpoint passed to AutoTokenizer / AutoModelForCausalLM.
MODEL_ID = "Qwen/Qwen2.5-3B"
# JSONL training data and the directory the Trainer and save cells write to.
data_path = str(PROCESSED_DATA_DIR / "dataset.jsonl")
output_dir = str(MODELS_DIR / "qwen2.5_3b_sctod_lora")

# Maximum number of tokens kept per example by `encode_example`.
MAX_SEQ_LENGTH = 2048
# Fraction of distinct question ids assigned to the training split.
TRAIN_SPLIT = 0.95

# Optimisation hyper-parameters, forwarded to TrainingArguments.
lr = 2e-5
num_epochs = 2
per_device_train_bs = 2
per_device_eval_bs = 2
# Gradients are accumulated over this many batches before each optimizer step.
grad_accum_steps = 8
warmup_ratio = 0.03
logging_steps = 25

In [None]:
# Allow TF32 for CUDA matmuls and cuDNN operations.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# NOTE(review): `device` is not referenced by any other cell visible in this
# notebook (the model is placed via `device_map="auto"`).
device = "cuda" if torch.cuda.is_available() else "cpu"
# Precision switch: selects bfloat16 vs float16 for the quantization compute
# dtype and model dtype, and drives the fp16/bf16 flags in TrainingArguments.
bf16 = False

In [None]:
# Load the JSONL file at `data_path` as a single dataset; later cells read the
# "question_id", "question" and "teacher_answer_text" columns from it.
dataset = load_dataset("json", data_files=data_path, split="train")

In [None]:
# Report the size and schema of the freshly loaded dataset.
n_samples = len(dataset)
n_columns = dataset.num_columns
column_names = dataset.column_names
print(
    f"Loaded dataset with {n_samples} samples and {n_columns} columns: {column_names}"
)

In [None]:
# Split by question id (not by row) so every row belonging to one question
# lands entirely in either the training or the evaluation set.
qids = sorted(set(dataset["question_id"]))

# Shuffle with a private, freshly seeded generator instead of the global
# `random` state: the split is then identical every time this cell runs,
# regardless of how often it is re-executed or what consumed random numbers
# before it.
random.Random(SEED).shuffle(qids)

cut = int(len(qids) * TRAIN_SPLIT)

train_qids = set(qids[:cut])
eval_qids = set(qids[cut:])

train_ds = dataset.filter(lambda ex: ex["question_id"] in train_qids)
eval_ds = dataset.filter(lambda ex: ex["question_id"] in eval_qids)

# The two id sets partition `qids`, so the splits must cover every row once.
assert len(train_ds) + len(eval_ds) == len(dataset), "train/eval split lost rows"

print(f"Train examples: {len(train_ds):,}, Eval examples: {len(eval_ds):,}")

In [None]:
# Load the fast tokenizer that matches MODEL_ID.
# NOTE(review): `trust_remote_code=True` allows code shipped with the
# checkpoint to run locally; confirm it is actually required for this model.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID, use_fast=True, trust_remote_code=True
)

In [None]:
# 4-bit NF4 quantization with double quantization; the compute dtype follows
# the `bf16` switch (bfloat16 when enabled, float16 otherwise).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16 if bf16 else torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

# Load the base causal LM with the quantization config above.
# NOTE(review): `trust_remote_code=True` executes checkpoint-provided code;
# confirm it is needed for MODEL_ID.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    dtype=torch.bfloat16 if bf16 else torch.float16,
    device_map="auto",
    trust_remote_code=True,
)
# The generation KV cache is switched off for training because it does not
# combine with gradient checkpointing, which is enabled on the next line.
model.config.use_cache = False  # important for gradient checkpointing
model.gradient_checkpointing_enable()

# PEFT helper applied to the quantized model before LoRA adapters are attached.
# NOTE(review): presumably it also enables gradient checkpointing by default,
# making the explicit call above redundant — confirm against the PEFT version.
model = prepare_model_for_kbit_training(model)

In [None]:
# Guarantee a pad token: fall back to EOS when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Always mirror the tokenizer's pad id into the model and generation configs.
# Previously this only happened inside the fallback branch, so a tokenizer
# that already had a pad token left both configs unsynchronised.
model.config.pad_token_id = tokenizer.pad_token_id
model.generation_config.pad_token_id = tokenizer.pad_token_id

In [None]:
# Names of the linear sub-modules that receive LoRA adapters, grouped by
# where they sit in a transformer block (grouping inferred from their names).
attention_projections = ["q_proj", "k_proj", "v_proj", "o_proj"]
mlp_projections = ["gate_proj", "up_proj", "down_proj"]

# Rank-16 adapters with alpha 32 and 5% dropout; bias terms are not trained.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=attention_projections + mlp_projections,
)

# Wrap the prepared base model with the adapters and report how many
# parameters are trainable.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

In [None]:
def build_prompt(question: str) -> str:
    """Render the full text prompt for one question.

    The stripped system prompt and the stripped user template (with the
    stripped ``question`` substituted for its ``{question}`` placeholder) are
    joined by a blank line, and the result ends with a single newline.
    """
    user_part = TEACHER_USER_PROMPT.strip().format(question=question.strip())
    return "\n\n".join((TEACHER_SYSTEM_PROMPT.strip(), user_part)) + "\n"


def encode_example(prompt: str, answer: str) -> Dict[str, List[int]]:
    """Tokenize one (prompt, answer) pair into causal-LM inputs and labels.

    The token sequence is ``[prompt]["\\n" + answer + eos]``. Prompt positions
    are set to ``-100`` in ``labels`` so only the answer (and its EOS) are
    supervised.

    Sequences longer than ``MAX_SEQ_LENGTH`` are truncated from the LEFT, so
    the tail, which holds the answer and the EOS token, is kept whenever
    possible. The earlier right-side truncation silently dropped the end of
    the answer and the EOS for long examples.

    Args:
        prompt: Fully rendered prompt text.
        answer: Target answer text; surrounding whitespace is stripped.

    Returns:
        A dict with ``input_ids`` and ``labels``, two lists of equal length
        no longer than ``MAX_SEQ_LENGTH``.
    """
    prompt_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
    # A newline separates the answer from the prompt; EOS marks where to stop.
    ans_text = "\n" + answer.strip() + tokenizer.eos_token
    answer_ids = tokenizer(ans_text, add_special_tokens=False)["input_ids"]

    full_ids = prompt_ids + answer_ids
    # Number of leading tokens to drop (0 when the sequence already fits).
    overflow = max(0, len(full_ids) - MAX_SEQ_LENGTH)
    input_ids = full_ids[overflow:]

    # Prompt tokens that survived truncation; only those are masked out.
    prompt_len = max(0, len(prompt_ids) - overflow)
    labels = [-100] * prompt_len + input_ids[prompt_len:]

    return {"input_ids": input_ids, "labels": labels}


def preprocess_batch(batch):
    """Tokenize a batch of rows for ``Dataset.map(batched=True)``.

    Reads the parallel ``question`` and ``teacher_answer_text`` columns,
    builds the prompt for each question and encodes it with its answer.
    Returns a dict of two parallel lists: ``input_ids`` and ``labels``.
    """
    pairs = zip(batch["question"], batch["teacher_answer_text"])
    encoded = [
        encode_example(build_prompt(question), answer) for question, answer in pairs
    ]
    return {
        "input_ids": [record["input_ids"] for record in encoded],
        "labels": [record["labels"] for record in encoded],
    }


def _tokenize_split(split):
    """Tokenize one split in batches, dropping all of its original columns."""
    return split.map(
        preprocess_batch, batched=True, remove_columns=split.column_names
    )


# After this step both splits contain only `input_ids` and `labels`.
train_ds = _tokenize_split(train_ds)
eval_ds = _tokenize_split(eval_ds)

In [None]:
@dataclass
class DataCollator:
    """Pad a list of tokenized examples into one rectangular batch.

    Each feature is a mapping with ``input_ids`` and ``labels`` lists of equal
    length. Sequences are right-padded to the longest one in the batch and
    then, optionally, up to the next multiple of ``pad_to_multiple_of``.
    ``input_ids`` are padded with the tokenizer's pad id, ``labels`` with
    ``-100`` and ``attention_mask`` with ``0``.

    The attention mask is built from the true sequence lengths rather than
    from ``input_ids != pad_id``. When the pad token shares its id with a real
    token (for example EOS reused as padding), an id comparison would also
    mask out genuine tokens.
    """

    # Only `pad_token_id` is read from the tokenizer. The annotation is quoted
    # so it is not evaluated when the class is created.
    tokenizer: "AutoTokenizer"
    # Padded length is rounded up to a multiple of this; None or 0 disables it.
    pad_to_multiple_of: int = 8  # for Tensor Cores

    def _pad_to_multiple(self, batch, pad_value):
        """Right-pad a 2-D ``batch`` so its width is a multiple of the setting."""
        multiple = self.pad_to_multiple_of
        remainder = batch.size(1) % multiple
        if remainder == 0:
            return batch
        filler = torch.full(
            (batch.size(0), multiple - remainder), pad_value, dtype=batch.dtype
        )
        return torch.cat([batch, filler], dim=1)

    def __call__(self, features):
        """Collate ``features`` into ``input_ids``, ``attention_mask``, ``labels``."""
        pad_id = self.tokenizer.pad_token_id

        ids_list = [torch.tensor(f["input_ids"], dtype=torch.long) for f in features]
        label_list = [torch.tensor(f["labels"], dtype=torch.long) for f in features]
        # One "1" per real token; padding added below contributes zeros.
        mask_list = [torch.ones_like(ids) for ids in ids_list]

        input_ids = pad_sequence(ids_list, batch_first=True, padding_value=pad_id)
        labels = pad_sequence(label_list, batch_first=True, padding_value=-100)
        attention_mask = pad_sequence(mask_list, batch_first=True, padding_value=0)

        # Optional: round the sequence length up to a multiple (e.g. 8).
        if self.pad_to_multiple_of:
            input_ids = self._pad_to_multiple(input_ids, pad_id)
            labels = self._pad_to_multiple(labels, -100)
            attention_mask = self._pad_to_multiple(attention_mask, 0)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }


# Collator instance handed to the Trainer; uses the default
# pad_to_multiple_of of 8.
collator = DataCollator(tokenizer)

In [None]:
# Rough size of the training set in tokens, counted before any padding.
total_train_tokens = sum(len(x) for x in train_ds["input_ids"])
print(f"Approx train tokens (pre-padding): {total_train_tokens:,}")

# Training configuration. Each optimizer step sees
# per_device_train_bs * grad_accum_steps sequences per device.
args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_epochs,
    per_device_train_batch_size=per_device_train_bs,
    per_device_eval_batch_size=per_device_eval_bs,
    gradient_accumulation_steps=grad_accum_steps,
    learning_rate=lr,
    warmup_ratio=warmup_ratio,
    logging_steps=logging_steps,
    # Checkpoint once per epoch, keep at most two; evaluate every 500 steps.
    save_strategy="epoch",
    eval_strategy="steps",
    eval_steps=500,
    save_total_limit=2,
    lr_scheduler_type="linear",
    weight_decay=0.0,
    # Mixed precision: fp16 only on CUDA and only when bf16 is switched off.
    fp16=not bf16 and torch.cuda.is_available(),
    bf16=bf16,
    optim="paged_adamw_8bit",
    gradient_checkpointing=True,
    report_to="tensorboard",
)

# The custom collator pads each batch and builds the attention mask.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    data_collator=collator,
    processing_class=tokenizer,
)

# Runs the fine-tuning loop (long-running cell).
trainer.train()

In [None]:
# Save the LoRA adapter BEFORE merging. `merge_and_unload()` folds the adapter
# weights into the base model and removes the LoRA layers from the very model
# object the trainer holds, so saving the adapter afterwards would no longer
# capture the trained LoRA weights.
adapter_dir = Path(output_dir)
trainer.model.save_pretrained(str(adapter_dir))  # saves PEFT adapter weights
tokenizer.save_pretrained(str(adapter_dir))

# Merge the adapter into the base weights and store the standalone model in
# its own sub-directory, so adapter files and full-model files never share
# (and overwrite each other in) a single folder.
# NOTE(review): the base model was loaded 4-bit quantized; merging into
# quantized weights may introduce rounding error. For a production export,
# consider reloading the base model in half precision before merging.
merged_dir = adapter_dir / "merged"
model = model.merge_and_unload()
model.save_pretrained(str(merged_dir))
tokenizer.save_pretrained(str(merged_dir))

In [None]:
def generate_answer(question: str, max_new_tokens: int = 256) -> str:
    """Greedily generate the model's answer to ``question``.

    Args:
        question: Raw question text; it is wrapped by ``build_prompt``.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded completion only (prompt excluded), stripped of
        surrounding whitespace.
    """
    model.eval()
    prompt = build_prompt(question)
    # Tokenize exactly as during training (`encode_example` also passes
    # add_special_tokens=False) so the model sees the same prefix.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(
        model.device
    )
    prompt_len = inputs["input_ids"].shape[1]

    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # Greedy decoding. No `temperature` is passed: it only applies
            # when sampling and conflicts with do_sample=False.
            do_sample=False,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            # The KV cache was disabled on the config for training; re-enable
            # it for this call so decoding is not recomputed from scratch.
            use_cache=True,
        )

    # Slice by token count rather than by prompt character length: decoding
    # the prompt is not guaranteed to reproduce the original string exactly.
    completion_ids = out[0][prompt_len:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()


# Smoke test of the fine-tuned model on one word problem
# (3 * 12 - 7 + 5 = 34 cows is the arithmetic answer).
print(
    generate_answer(
        "A farm has 3 barns with 12 cows each. It sells 7 cows and buys 5 more. How many cows now?"
    )
)