In [1]:
import random
import sys
from dataclasses import dataclass
from pathlib import Path

import torch
from accelerate.utils import set_seed
from datasets import load_dataset
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from torch.nn.utils.rnn import pad_sequence
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
)

sys.path.append(str(Path.cwd().resolve().parent))

from src.config import (
    COT_SYSTEM_PROMPT,
    COT_USER_PROMPT,
    LABEL_ONLY_SYSTEM_PROMPT,
    LABEL_ONLY_USER_PROMPT,
    MODELS_DIR,
    PROCESSED_DATA_DIR,
)

In [2]:
# Turn off the HF `tokenizers` library's internal parallelism. Presumably this avoids
# fork-related warnings/deadlocks with the multi-worker DataLoader used for training
# (dataloader_num_workers=4 in the training cell) — TODO confirm intent.
%env TOKENIZERS_PARALLELISM=false

env: TOKENIZERS_PARALLELISM=false


In [3]:
SEED = 42
MODE = "cot"  # "label-only" or "cot"

# Fail fast on a typo: later cells branch on `MODE == "cot"` and would otherwise
# silently fall back to the label-only prompts/targets.
if MODE not in ("label-only", "cot"):
    raise ValueError(f"MODE must be 'label-only' or 'cot', got {MODE!r}")

# accelerate's set_seed seeds Python's `random`, NumPy and torch (CPU and all CUDA
# devices), so no separate per-library seeding calls are needed.
set_seed(SEED)

In [4]:
MODEL_ID = "Qwen/Qwen2.5-3B"
data_path = str(PROCESSED_DATA_DIR / "dataset.jsonl")

# The adapter output folder is named after the training mode.
_run_tag = "sctod" if MODE == "cot" else "labelonly"
output_dir = str(MODELS_DIR / f"qwen2.5_3b_{_run_tag}_lora")

MAX_SEQ_LENGTH = 2048  # token budget for prompt + answer per example (see encode_example)
TRAIN_SPLIT = 0.8  # fraction of question ids assigned to the training split

In [5]:
# Allow TF32 for CUDA matmuls and cuDNN convolutions (faster, slightly lower precision).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
device = "cuda" if torch.cuda.is_available() else "cpu"

# Use bf16 only where the hardware supports it. When this is False, the
# TrainingArguments cell switches to fp16 (fp16=not bf16 and cuda available), and the
# 4-bit compute dtype falls back to float16.
bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

In [6]:
# A single JSONL file loaded with the generic "json" builder, taken as its "train" split.
dataset = load_dataset(path="json", data_files=data_path, split="train")

In [7]:
summary = (
    f"Loaded dataset with {len(dataset)} samples and "
    f"{dataset.num_columns} columns: {dataset.column_names}"
)
print(summary)

Loaded dataset with 24952 samples and 7 columns: ['question_id', 'sample_id', 'question', 'gold_answer_text', 'gold_answer_number', 'teacher_answer_text', 'teacher_answer_number']


In [8]:
# Split by question id (not by row) so all samples of a question land in the same split.
qids = sorted(set(dataset["question_id"]))
# Use a dedicated RNG: re-running this cell on its own reproduces the same split,
# instead of drawing from (and advancing) the global `random` state.
random.Random(SEED).shuffle(qids)

cut = int(len(qids) * TRAIN_SPLIT)

train_qids = set(qids[:cut])
eval_qids = set(qids[cut:])

train_ds = dataset.filter(lambda ex: ex["question_id"] in train_qids)
eval_ds = dataset.filter(lambda ex: ex["question_id"] in eval_qids)

# Every row belongs to exactly one split.
assert len(train_ds) + len(eval_ds) == len(dataset)

print(f"Train examples: {len(train_ds):,}, Eval examples: {len(eval_ds):,}")

Train examples: 20,028, Eval examples: 4,924


In [None]:
def dedupe_by_question(ds):
    """Keep one row per ``question_id``: the one with the smallest ``sample_id``.

    Args:
        ds: a dataset supporting ``sort(columns)``, column access ``ds[col]``
            and ``select(indices)`` (e.g. a Hugging Face ``Dataset``).

    Returns:
        The dataset sorted by ``(question_id, sample_id)`` with only the first
        row of each question retained.
    """
    ds_sorted = ds.sort(["question_id", "sample_id"])
    keep_indices: list[int] = []
    seen: set[int] = set()
    # Read only the id column. Iterating whole rows would decode every column
    # (including the long teacher answer texts) just to look at one field.
    for i, raw_qid in enumerate(ds_sorted["question_id"]):
        qid = int(raw_qid)
        if qid not in seen:
            seen.add(qid)
            keep_indices.append(i)
    return ds_sorted.select(keep_indices)


# Label-only targets are built from the question and its gold answer only, so the
# several teacher samples of one question would presumably yield identical
# (prompt, target) pairs. Keep a single row per question.
if MODE == "label-only":
    train_ds, eval_ds = dedupe_by_question(train_ds), dedupe_by_question(eval_ds)

In [10]:
# Fast (Rust-backed) tokenizer for the base model.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    use_fast=True,
)

In [11]:
# Inspect the tokenizer's padding side. Training does not depend on it: DataCollator
# right-pads explicitly with pad_sequence, and generate_answer tokenizes one prompt at a time.
tokenizer.padding_side

'right'

In [12]:
# One dtype drives both the 4-bit compute dtype and the non-quantized weights.
compute_dtype = torch.bfloat16 if bf16 else torch.float16

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=compute_dtype,
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    dtype=compute_dtype,
    device_map="auto",
    trust_remote_code=True,
)
# The KV cache is disabled because gradient checkpointing is enabled below.
model.config.use_cache = False
model.gradient_checkpointing_enable()

model = prepare_model_for_kbit_training(model)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [13]:
# Fall back to EOS as the pad token and propagate its id to both model configs.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    for cfg in (model.config, model.generation_config):
        cfg.pad_token_id = tokenizer.pad_token_id

In [14]:
# Adapt the q/k/v/o and gate/up/down projection modules.
_lora_targets = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=_lora_targets,
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

trainable params: 29,933,568 || all params: 3,115,872,256 || trainable%: 0.9607


In [None]:
def _join_prompt(system_template: str, user_template: str, question: str) -> str:
    """Render a system template and a ``{question}`` user template into one plain-text prompt.

    Returns ``system + blank line + user + newline``.
    """
    system_part = system_template.strip()
    user_part = user_template.strip().format(question=question.strip())
    return "\n\n".join([system_part, user_part]) + "\n"


def build_prompt_cot(question: str) -> str:
    """Build the prompt from the CoT system/user templates."""
    return _join_prompt(COT_SYSTEM_PROMPT, COT_USER_PROMPT, question)


def build_prompt_label_only(question: str) -> str:
    """Build the prompt from the label-only system/user templates."""
    return _join_prompt(LABEL_ONLY_SYSTEM_PROMPT, LABEL_ONLY_USER_PROMPT, question)


def format_final_answer(num) -> str:
    """Render the label-only training target, e.g. ``Final Answer: 34``."""
    return "Final Answer: {}".format(num)


def encode_example(
    prompt: str, answer: str, tok=None, max_seq_length=None
) -> dict[str, list[int]]:
    """Tokenize one (prompt, answer) pair for causal-LM training with prompt masking.

    The answer is stripped and followed by the tokenizer's EOS token. Only the answer
    tokens are supervised: prompt positions get label ``-100``.

    Truncation keeps the result within ``max_seq_length`` tokens:
      * the answer (plus EOS) is kept first and is cut from the right only if it
        alone exceeds the budget (it then loses its EOS);
      * the prompt gets the remaining budget and is cut from the LEFT, so the text
        closest to the answer survives. With no budget left, the prompt is dropped.

    Args:
        prompt: full prompt text (tokenized without special tokens).
        answer: target text.
        tok: tokenizer to use; defaults to the notebook-level ``tokenizer``.
        max_seq_length: token budget; defaults to ``MAX_SEQ_LENGTH``.

    Returns:
        ``{"input_ids": [...], "labels": [...]}``, two equal-length lists.
    """
    tok = tokenizer if tok is None else tok
    max_len = MAX_SEQ_LENGTH if max_seq_length is None else max_seq_length

    prompt_ids = tok(prompt, add_special_tokens=False)["input_ids"]
    ans_text = answer.strip() + tok.eos_token
    answer_ids = tok(ans_text, add_special_tokens=False)["input_ids"]

    answer_ids = answer_ids[:max_len]
    allowed_prompt = max_len - len(answer_ids)
    # Explicit guard: `ids[-0:]` is the WHOLE list, not an empty one.
    prompt_ids = prompt_ids[-allowed_prompt:] if allowed_prompt > 0 else []

    input_ids = prompt_ids + answer_ids
    labels = [-100] * len(prompt_ids) + answer_ids
    return {"input_ids": input_ids, "labels": labels}


def preprocess_batch(batch):
    """Tokenize a batched slice of the raw dataset into ``input_ids`` and ``labels``.

    In ``"cot"`` mode the target is the teacher's answer text. Otherwise it is
    ``Final Answer: <gold_answer_number>``, paired with the label-only prompt.
    """
    use_cot = MODE == "cot"
    build_prompt = build_prompt_cot if use_cot else build_prompt_label_only

    rows = zip(
        batch["question"], batch["teacher_answer_text"], batch["gold_answer_number"]
    )
    encoded = [
        encode_example(
            build_prompt(question),
            teacher_text if use_cot else format_final_answer(gold_number),
        )
        for question, teacher_text, gold_number in rows
    ]
    return {
        "input_ids": [rec["input_ids"] for rec in encoded],
        "labels": [rec["labels"] for rec in encoded],
    }


def _tokenize_split(split):
    """Replace every raw column of ``split`` with the tokenized training fields."""
    return split.map(
        preprocess_batch,
        batched=True,
        remove_columns=split.column_names,
        load_from_cache_file=False,
    )


train_ds = _tokenize_split(train_ds)
eval_ds = _tokenize_split(eval_ds)

Map: 100%|##########| 20028/20028 [00:00<?, ? examples/s]

Map: 100%|##########| 4924/4924 [00:00<?, ? examples/s]

In [16]:
@dataclass
class DataCollator:
    """Right-pad pre-tokenized causal-LM features into a training batch.

    Each feature holds equal-length ``input_ids`` and ``labels`` lists (as returned
    by ``encode_example``). ``input_ids`` are padded with the tokenizer's pad id and
    ``labels`` with ``-100``, so padding is ignored by the loss. The batch width is
    rounded up to ``pad_to_multiple_of`` (None or 0 disables rounding).

    The attention mask is built from each sequence's true length, not by comparing
    tokens to ``pad_token_id``. When the pad token is the EOS token (e.g. via the
    ``pad_token = eos_token`` fallback), an id comparison would also mask the EOS
    that ``encode_example`` appends to every target.
    """

    tokenizer: "PreTrainedTokenizerBase"  # any HF tokenizer; only pad_token_id is read
    pad_to_multiple_of: "int | None" = 8  # for Tensor Cores

    @staticmethod
    def _pad_right(t: torch.Tensor, width: int, value: int) -> torch.Tensor:
        """Right-pad a 2-D tensor along dim 1 to ``width`` columns with ``value``."""
        extra = width - t.size(1)
        if extra <= 0:
            return t
        fill = torch.full((t.size(0), extra), value, dtype=t.dtype, device=t.device)
        return torch.cat([t, fill], dim=1)

    def __call__(self, features):
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            raise ValueError(
                "tokenizer.pad_token_id is None; set a pad token before collating"
            )

        input_ids = [torch.tensor(f["input_ids"], dtype=torch.long) for f in features]
        labels = [torch.tensor(f["labels"], dtype=torch.long) for f in features]
        lengths = torch.tensor([len(ids) for ids in input_ids], dtype=torch.long)

        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=pad_id)
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)

        width = max(input_ids.size(1), labels.size(1))
        m = self.pad_to_multiple_of
        if m:
            width = -(-width // m) * m  # round up to a multiple of m
        input_ids = self._pad_right(input_ids, width, pad_id)
        labels = self._pad_right(labels, width, -100)

        # 1 where position < true sequence length, 0 on padding.
        positions = torch.arange(width).unsqueeze(0)
        attention_mask = (positions < lengths.unsqueeze(1)).long()
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }


collator = DataCollator(tokenizer=tokenizer)  # default pad_to_multiple_of=8

In [None]:
total_train_tokens = sum(len(x) for x in train_ds["input_ids"])
print(f"Approx train tokens (pre-padding): {total_train_tokens:,}")

# Mode-specific hyperparameters. Settings shared by both modes are passed to
# TrainingArguments below. save_steps == eval_steps, so every saved checkpoint has
# an eval_loss for load_best_model_at_end.
if MODE == "cot":
    training_args = {
        "learning_rate": 2e-5,
        "num_train_epochs": 10,
        "per_device_train_batch_size": 8,
        "per_device_eval_batch_size": 8,
        "gradient_accumulation_steps": 2,
        "warmup_ratio": 0.03,
        "weight_decay": 0,
        "save_steps": 200,
        "eval_steps": 200,
    }
elif MODE == "label-only":
    training_args = {
        "learning_rate": 1e-5,
        "num_train_epochs": 10,
        "per_device_train_batch_size": 16,
        "per_device_eval_batch_size": 16,
        "gradient_accumulation_steps": 1,
        "warmup_ratio": 0.1,
        "weight_decay": 0.1,
        "max_grad_norm": 1.0,
        "save_steps": 50,
        "eval_steps": 50,
    }
else:
    # Without this, `training_args` would be undefined (NameError) or, worse, left
    # over from an earlier run of this cell with a different MODE.
    raise ValueError(f"Unknown MODE {MODE!r}; expected 'cot' or 'label-only'")

args = TrainingArguments(
    **training_args,
    output_dir=output_dir,
    logging_steps=10,
    save_strategy="steps",
    eval_strategy="steps",
    lr_scheduler_type="linear",
    # Mixed precision: bf16 when enabled, otherwise fp16 on CUDA.
    fp16=not bf16 and torch.cuda.is_available(),
    bf16=bf16,
    optim="paged_adamw_8bit",
    gradient_checkpointing=True,
    group_by_length=True,
    report_to="tensorboard",
    dataloader_num_workers=4,
    dataloader_pin_memory=True,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    seed=SEED,
    data_seed=SEED,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    data_collator=collator,
    processing_class=tokenizer,
)

try:
    trainer.train()
except KeyboardInterrupt:
    # A manual stop is tolerated: checkpoints saved so far stay reachable via
    # trainer.state.best_model_checkpoint.
    print("Training interrupted by keyboard input")

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.


Approx train tokens (pre-padding): 11,975,867


Step,Training Loss,Validation Loss
200,0.2385,0.234745
400,0.1904,0.225248
600,0.1711,0.219601
800,0.1741,0.218803
1000,0.1811,0.215693
1200,0.1902,0.214166
1400,0.1597,0.214895
1600,0.1689,0.21491
1800,0.146,0.214741
2000,0.1762,0.213399


In [18]:
# Checkpoint directory selected by metric_for_best_model="eval_loss" (greater_is_better=False).
trainer.state.best_model_checkpoint

'/workspace/chain-of-thought-distillation/code/artifacts/models/qwen2.5_3b_sctod_lora/checkpoint-2000'

In [None]:
best_ckpt = trainer.state.best_model_checkpoint
if best_ckpt is None:
    raise RuntimeError(
        "No best checkpoint recorded; training stopped before the first evaluated save."
    )

# Rebuild the base model with the same quantization and dtype settings used for
# training, so the adapter runs on the same numerics it was trained with.
compute_dtype = torch.bfloat16 if bf16 else torch.float16
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    dtype=compute_dtype,
    device_map="auto",
    trust_remote_code=True,
)

# Frozen LoRA adapter from the best checkpoint, for inference only.
model = PeftModel.from_pretrained(base_model, best_ckpt, is_trainable=False)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



('/workspace/chain-of-thought-distillation/code/artifacts/models/qwen2.5_3b_sctod_lora/best_checkpoint/tokenizer_config.json',
 '/workspace/chain-of-thought-distillation/code/artifacts/models/qwen2.5_3b_sctod_lora/best_checkpoint/special_tokens_map.json',
 '/workspace/chain-of-thought-distillation/code/artifacts/models/qwen2.5_3b_sctod_lora/best_checkpoint/chat_template.jinja',
 '/workspace/chain-of-thought-distillation/code/artifacts/models/qwen2.5_3b_sctod_lora/best_checkpoint/vocab.json',
 '/workspace/chain-of-thought-distillation/code/artifacts/models/qwen2.5_3b_sctod_lora/best_checkpoint/merges.txt',
 '/workspace/chain-of-thought-distillation/code/artifacts/models/qwen2.5_3b_sctod_lora/best_checkpoint/added_tokens.json',
 '/workspace/chain-of-thought-distillation/code/artifacts/models/qwen2.5_3b_sctod_lora/best_checkpoint/tokenizer.json')

In [None]:
def generate_answer(question: str, max_new_tokens=None) -> str:
    """Greedily decode the model's answer to ``question``.

    The prompt is built the same way as during training for the current ``MODE``.

    Args:
        question: raw question text.
        max_new_tokens: optional cap on generated tokens. When None, the model's
            generation config decides (nothing is forwarded, so behaviour is unchanged).

    Returns:
        The generated continuation only (prompt removed), with special tokens
        skipped and surrounding whitespace stripped.
    """
    build_prompt = build_prompt_cot if MODE == "cot" else build_prompt_label_only
    prompt = build_prompt(question)

    # Forward the cap only when given: passing max_new_tokens=None would override
    # the value from the model's generation config.
    gen_kwargs = {}
    if max_new_tokens is not None:
        gen_kwargs["max_new_tokens"] = max_new_tokens

    model.eval()
    with torch.no_grad():
        enc = tokenizer(prompt, return_tensors="pt").to(model.device)
        out = model.generate(
            **enc,
            do_sample=False,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            use_cache=True,
            **gen_kwargs,
        )

    prompt_len = enc["input_ids"].size(1)
    return tokenizer.decode(out[0, prompt_len:], skip_special_tokens=True).strip()


sample_question = (
    "A farm has 3 barns with 12 cows each. It sells 7 cows and buys 5 more. "
    "How many cows now?"
)
print(generate_answer(sample_question))

The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Reasoning:
- Initially: 3×12 = 36
- After selling: 36 − 7 = 29
- After buying: 29 + 5 = 34
Final Answer: 34
