In [1]:
import random
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List

import torch
from accelerate.utils import set_seed
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel
from torch.nn.utils.rnn import pad_sequence
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
)

sys.path.append(str(Path.cwd().resolve().parent))

from src.config import (
    MODELS_DIR,
    PROCESSED_DATA_DIR,
    TEACHER_SYSTEM_PROMPT,
    TEACHER_USER_PROMPT,
)

In [2]:
# Set TOKENIZERS_PARALLELISM=false in this kernel's environment.
# NOTE(review): presumably this avoids the tokenizers fork warning once the
# DataLoader spawns worker processes — confirm.
%env TOKENIZERS_PARALLELISM=false

env: TOKENIZERS_PARALLELISM=false


In [3]:
SEED = 42
MODE = "label-only"  # or "cot"

# Every mode switch in this notebook is written as `... if MODE == "cot" else ...`,
# so a misspelled MODE would silently select the label-only path. Fail fast.
if MODE not in ("label-only", "cot"):
    raise ValueError(f"MODE must be 'label-only' or 'cot', got {MODE!r}")

# accelerate's seeding helper first, then explicit per-library seeding.
set_seed(SEED)

random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

In [4]:
# Base checkpoint and input data location.
MODEL_ID = "Qwen/Qwen2.5-3B"
data_path = str(PROCESSED_DATA_DIR / "dataset.jsonl")

# The adapter output directory is tagged with the active training mode.
run_tag = "sctod" if MODE == "cot" else "labelonly"
output_dir = str(MODELS_DIR / f"qwen2.5_3b_{run_tag}_lora")

MAX_SEQ_LENGTH = 2048  # same token cap in both modes
TRAIN_SPLIT = 0.95

In [5]:
# Allow TF32 for CUDA matmuls and cuDNN kernels.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

device = "cuda" if torch.cuda.is_available() else "cpu"

# Request bf16 only where the hardware supports it; everywhere `bf16` is read
# in this notebook there is an fp16 alternative for the False case.
bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

In [6]:
# Read the processed JSONL file as a single "train" split.
dataset = load_dataset(
    "json",
    data_files=data_path,
    split="train",
)

In [7]:
# One-line summary: row count, column count and column names.
print(
    f"Loaded dataset with {len(dataset)} samples "
    f"and {dataset.num_columns} columns: {dataset.column_names}"
)

Loaded dataset with 24952 samples and 7 columns: ['question_id', 'sample_id', 'question', 'gold_answer_text', 'gold_answer_number', 'teacher_answer_text', 'teacher_answer_number']


In [8]:
# Split on question_id so that every sample of a given question lands on the
# same side of the train/eval boundary.
qids = sorted(set(dataset["question_id"]))

# Shuffle with a dedicated RNG seeded from SEED: re-running this cell then
# always produces the same split, no matter how much of the global `random`
# state other cells have consumed in the meantime.
random.Random(SEED).shuffle(qids)

cut = int(len(qids) * TRAIN_SPLIT)

train_qids = set(qids[:cut])
eval_qids = set(qids[cut:])

train_ds = dataset.filter(lambda ex: ex["question_id"] in train_qids)
eval_ds = dataset.filter(lambda ex: ex["question_id"] in eval_qids)

print(f"Train examples: {len(train_ds):,}, Eval examples: {len(eval_ds):,}")

Train examples: 23,688, Eval examples: 1,264


In [9]:
# Fast tokenizer belonging to the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    use_fast=True,
    trust_remote_code=True,
)

In [10]:
# Display which side the tokenizer pads on.
tokenizer.padding_side

'right'

In [11]:
# A single dtype drives both the 4-bit compute dtype and the model dtype.
compute_dtype = torch.bfloat16 if bf16 else torch.float16

# 4-bit NF4 quantization with double quantization.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=compute_dtype,
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    dtype=compute_dtype,
    device_map="auto",
    trust_remote_code=True,
)

# The KV cache is switched off before gradient checkpointing is turned on.
model.config.use_cache = False
model.gradient_checkpointing_enable()

model = prepare_model_for_kbit_training(model)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [12]:
# Fall back to EOS as the padding token when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Sync the pad id into the model configs unconditionally. The tokenizer may
# already ship a pad token that the model config does not know about, in which
# case the Trainer reports a PAD mismatch and realigns the configs itself.
model.config.pad_token_id = tokenizer.pad_token_id
model.generation_config.pad_token_id = tokenizer.pad_token_id

In [13]:
# Projection layers that receive LoRA adapters.
lora_target_modules = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]

lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=lora_target_modules,
)

# Wrap the prepared model with the adapters and report the trainable share.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

trainable params: 29,933,568 || all params: 3,115,872,256 || trainable%: 0.9607


In [14]:
def build_prompt_cot(question: str) -> str:
    """Assemble the chain-of-thought prompt from the teacher templates.

    The stripped system prompt and the stripped user template (with the
    stripped ``question`` substituted for its ``question`` field) are joined
    by a blank line, followed by a single trailing newline.
    """
    system_part = TEACHER_SYSTEM_PROMPT.strip()
    user_template = TEACHER_USER_PROMPT.strip()
    user_part = user_template.format(question=question.strip())
    return "\n\n".join((system_part, user_part)) + "\n"

def build_prompt_label_only(question: str) -> str:
    """Return the answer-only prompt for ``question``.

    The prompt consists of a fixed instruction, the expected output format
    line, a blank line, and the whitespace-stripped question on its own line.
    """
    instruction = "You are a concise math solver. Output only the final line as:\n"
    answer_format = "Final Answer: <number>\n\n"
    question_line = f"Question: {question.strip()}\n"
    return instruction + answer_format + question_line

def format_final_answer(num: int) -> str:
    """Render ``num`` as the ``Final Answer: <num>`` target line."""
    return "Final Answer: {}".format(num)

def encode_example(prompt: str, answer: str) -> Dict[str, List[int]]:
    """Tokenize one prompt/answer pair for causal-LM fine-tuning.

    The sequence is laid out as ``[prompt][\\n][answer][eos]`` and cut off at
    ``MAX_SEQ_LENGTH`` tokens. Labels mirror ``input_ids`` except that every
    prompt position is set to ``-100``.

    Returns:
        Dict with ``input_ids`` and ``labels``, two lists of equal length.
    """
    prompt_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]

    # A newline separates the answer from the prompt; EOS closes the answer.
    completion = "\n" + answer.strip() + tokenizer.eos_token
    completion_ids = tokenizer(completion, add_special_tokens=False)["input_ids"]

    input_ids = (prompt_ids + completion_ids)[:MAX_SEQ_LENGTH]

    # Truncation may cut into the prompt itself, so cap the masked span.
    n_masked = min(len(prompt_ids), len(input_ids))
    labels = [-100] * n_masked
    labels.extend(input_ids[n_masked:])

    return {"input_ids": input_ids, "labels": labels}


def preprocess_batch(batch):
    """Tokenize a batch of dataset rows into ``input_ids`` / ``labels`` lists.

    Args:
        batch: Column-oriented batch providing the ``question``,
            ``teacher_answer_text`` and ``gold_answer_number`` columns.

    Returns:
        Dict with ``input_ids`` and ``labels``, one list of token ids per row.
        In ``"cot"`` mode the target is the teacher's answer text; otherwise
        it is the ``Final Answer: <n>`` line built from the gold number.
    """
    use_cot = MODE == "cot"
    # The prompt builder depends only on MODE, so select it once per batch.
    build_prompt = build_prompt_cot if use_cot else build_prompt_label_only

    inputs, labels = [], []
    for question, teacher_answer_text, gold_answer_number in zip(
        batch["question"], batch["teacher_answer_text"], batch["gold_answer_number"]
    ):
        if use_cot:
            # The CoT target is the teacher's answer text from the dataset row.
            answer = teacher_answer_text
        else:
            answer = format_final_answer(int(gold_answer_number))

        rec = encode_example(build_prompt(question), answer)
        inputs.append(rec["input_ids"])
        labels.append(rec["labels"])

    return {"input_ids": inputs, "labels": labels}


# Options shared by both map calls: batched processing, no cache reuse.
map_options = dict(batched=True, load_from_cache_file=False)

# Replace the raw columns of each split with the tokenized fields.
train_ds = train_ds.map(
    preprocess_batch,
    remove_columns=train_ds.column_names,
    **map_options,
)
eval_ds = eval_ds.map(
    preprocess_batch,
    remove_columns=eval_ds.column_names,
    **map_options,
)

Map:   0%|          | 0/23688 [00:00<?, ? examples/s]

Map:   0%|          | 0/1264 [00:00<?, ? examples/s]

In [15]:
@dataclass
class DataCollator:
    """Pad a list of tokenized features into one right-padded batch.

    Each feature is a mapping with ``input_ids`` and ``labels`` (lists of
    token ids). ``input_ids`` are padded with the tokenizer's pad id and
    ``labels`` with ``-100``. When ``pad_to_multiple_of`` is set, the sequence
    dimension is extended to the next multiple of it.

    The attention mask is built from the true sequence lengths rather than by
    comparing against the pad id, so real tokens whose id equals the pad id
    (e.g. EOS when pad and EOS share an id) stay attended.
    """

    # Quoted so the annotation is not evaluated; only ``pad_token_id`` is read.
    tokenizer: "AutoTokenizer"
    pad_to_multiple_of: int = 8  # for Tensor Cores

    def _pad_to_multiple(self, t, pad_value):
        """Right-pad ``t`` along dim 1 to the next multiple of ``pad_to_multiple_of``."""
        multiple = self.pad_to_multiple_of
        if not multiple:
            return t
        remainder = t.size(1) % multiple
        if remainder == 0:
            return t
        filler = torch.full(
            (t.size(0), multiple - remainder), pad_value, dtype=t.dtype
        )
        return torch.cat([t, filler], dim=1)

    def __call__(self, features):
        """Collate ``features`` into ``input_ids``, ``attention_mask`` and ``labels`` tensors."""
        pad_id = self.tokenizer.pad_token_id

        input_ids = [torch.tensor(f["input_ids"], dtype=torch.long) for f in features]
        labels = [torch.tensor(f["labels"], dtype=torch.long) for f in features]
        # One "1" per real token, taken before any padding is applied.
        attention_mask = [torch.ones_like(ids) for ids in input_ids]

        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=pad_id)
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)
        attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)

        input_ids = self._pad_to_multiple(input_ids, pad_id)
        labels = self._pad_to_multiple(labels, -100)
        attention_mask = self._pad_to_multiple(attention_mask, 0)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }


# Collator instance that is passed to the Trainer as `data_collator`.
collator = DataCollator(tokenizer)

In [16]:
# Token count over the tokenized training split, before any batch padding.
total_train_tokens = sum(len(x) for x in train_ds["input_ids"])
print(f"Approx train tokens (pre-padding): {total_train_tokens:,}")

# Name the run after the active MODE so that logs of the two variants are not
# both filed under the label-only name.
run_name = "CoT Fine-Tuning" if MODE == "cot" else "Label-only Fine-Tuning"

args = TrainingArguments(
    output_dir=output_dir,
    learning_rate=2e-5,
    num_train_epochs=2,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,
    warmup_ratio=0.03,
    logging_steps=10,
    # Save and evaluate on the same step schedule so the best checkpoint can
    # be restored at the end of training.
    save_strategy="steps",
    eval_strategy="steps",
    save_steps=100,
    eval_steps=100,
    lr_scheduler_type="linear",
    weight_decay=0.0,
    fp16=not bf16 and torch.cuda.is_available(),
    bf16=bf16,
    optim="paged_adamw_8bit",
    gradient_checkpointing=True,
    group_by_length=True,
    report_to="tensorboard",
    dataloader_num_workers=4,
    dataloader_pin_memory=True,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    seed=SEED,  # keep the Trainer's seed tied to the notebook-wide SEED
    run_name=run_name,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    data_collator=collator,
    processing_class=tokenizer,
)

trainer.train()

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.


Approx train tokens (pre-padding): 2,119,559


Step,Training Loss,Validation Loss
100,0.2666,0.404177
200,0.2069,0.42657
300,0.1124,0.490882
400,0.0454,0.659585
500,0.0602,0.725561


KeyboardInterrupt: 

In [None]:
# Built from MODELS_DIR (like `output_dir`) instead of a machine-specific
# absolute path.
# NOTE(review): assumes MODELS_DIR is the `artifacts/models` directory that the
# previous absolute path pointed into — confirm. This is a checkpoint of the
# "sctod" (MODE == "cot") run, while the merged model is written under
# `output_dir`, which follows the current MODE — confirm that is intended.
best_ckpt = str(MODELS_DIR / "qwen2.5_3b_sctod_lora" / "checkpoint-1900")

In [None]:
# Reload the base model in 4-bit NF4 (compute dtype fixed to bfloat16).
# NOTE(review): the adapter is merged into quantized weights here — confirm
# this is preferred over merging into an unquantized base model.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)

# Attach the adapter checkpoint in inference mode and fold it into the base model.
merged = PeftModel.from_pretrained(
    base_model, best_ckpt, is_trainable=False
).merge_and_unload()

# Model and tokenizer are saved side by side in the same directory.
export_dir = f"{output_dir}/best_checkpoint"
merged.save_pretrained(export_dir)
tokenizer.save_pretrained(export_dir)

In [None]:
def generate_answer(question: str, max_new_tokens: int = 256) -> str:
    """Greedy-decode the model's completion for ``question``.

    The prompt is built with the same MODE-dependent builder that
    ``preprocess_batch`` uses and tokenized without special tokens, matching
    ``encode_example``.

    Args:
        question: The question text to answer.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded continuation (prompt excluded), stripped of surrounding
        whitespace.
    """
    build_prompt = build_prompt_cot if MODE == "cot" else build_prompt_label_only
    prompt = build_prompt(question)

    model.eval()
    enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(
        model.device
    )
    with torch.no_grad():
        out = model.generate(
            **enc,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding; no sampling parameters are passed
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode only the newly generated tokens. Slicing the decoded text by
    # len(prompt) is fragile because decode(encode(prompt)) need not reproduce
    # the prompt character for character.
    prompt_len = enc["input_ids"].shape[1]
    return tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True).strip()


# Smoke test on a single hand-written word problem.
sample_question = (
    "A farm has 3 barns with 12 cows each. "
    "It sells 7 cows and buys 5 more. How many cows now?"
)
print(generate_answer(sample_question))