In [1]:
import sys
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import torch
from datasets import load_dataset
from peft import PeftModel
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)

sys.path.append(str(Path.cwd().resolve().parent))

In [2]:
from src.config import (
    GSM8K_PATH,
    TEACHER_SYSTEM_PROMPT,
    TEACHER_USER_PROMPT,
)
from src.dataset_generator.helpers.answers import (
    ParsingError,
    parse_gold_answer_number,
    parse_teacher_final_answer,
)

In [3]:
def build_prompt_cot(question: str) -> str:
    sys_txt = TEACHER_SYSTEM_PROMPT.strip()
    usr_txt = TEACHER_USER_PROMPT.strip().format(question=question.strip())
    return f"{sys_txt}\n\n{usr_txt}\n"


def build_prompt_label_only(question: str) -> str:
    shots = [
        (
            "A farm has 3 barns with 12 cows each. It sells 7 cows and buys 5 more. How many cows now?",
            "34",
        ),
        (
            "Pens cost $2 and notebooks $5. Alex buys 3 pens and 2 notebooks and pays with $20. How much change?",
            "4",
        ),
        (
            "A tank holds 250 liters. 35% is drained, then 40 liters are added. How many liters now?",
            "202.5",
        ),
    ]
    header = (
        "You are a concise math solver. Output only the final line as:\n"
        "Final Answer: <number>\n\n"
    )
    exemplars = [f"Question: {q}\nFinal Answer: {a}" for q, a in shots]
    exemplars_txt = "\n\n".join(exemplars)
    return f"{header}{exemplars_txt}\n\nQuestion: {question.strip()}\n"

In [8]:
def load_model_and_tokenizer(
    model_id: str,
    peft_or_merged_path: Optional[str] = None,
    use_4bit: bool = True,
    bf16: bool = True,
    device_map: str = "auto",
) -> Tuple[Any, Any]:
    """
    Loads either:
      - base model only (when peft_or_merged_path=None)
      - base+adapter (when peft_or_merged_path points to a PEFT dir with adapter_config.json)
      - merged model (when peft_or_merged_path points to a standard HF model dir)
    Returns (model, tokenizer)
    """
    quant_cfg = None
    if use_4bit:
        quant_cfg = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16 if bf16 else torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )

    load_path = peft_or_merged_path

    if load_path is None:
        # Base model only
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            quantization_config=quant_cfg,
            device_map=device_map,
            trust_remote_code=True,
        )
        tok_src = model_id
    else:
        # Base + adapter
        model = AutoModelForCausalLM.from_pretrained(
            load_path,
            quantization_config=quant_cfg,
            device_map=device_map,
            trust_remote_code=True,
        )
        tok_src = load_path

    tokenizer = AutoTokenizer.from_pretrained(
        tok_src, use_fast=True, trust_remote_code=True
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer


def generate_answer(
    model, tokenizer, mode: str, question: str, max_new_tokens: int = 256
) -> str:
    model.eval()
    with torch.no_grad():
        build_prompt = build_prompt_cot if mode == "cot" else build_prompt_label_only
        prompt = build_prompt(question)
        inputs = tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        out = model.generate(
            **inputs,
            do_sample=False,
            temperature=0.0,
            max_new_tokens=max_new_tokens,
            eos_token_id=tokenizer.eos_token_id,
        )
        text = tokenizer.decode(out[0], skip_special_tokens=True)
        return text[len(prompt) :].strip()


def evaluate_gsm8k_greedy(
    model,
    tokenizer,
    mode: str,
    split: str = "test",
    limit: Optional[int] = None,
) -> Dict[str, Any]:
    """
    mode: "cot" or "label-only"
    Returns metrics dict with 'accuracy' and 'n'
    """
    ds = load_dataset(GSM8K_PATH, "main", split=split)

    n_total = 0
    n_correct = 0
    it = ds if limit is None else ds.select(range(min(limit, len(ds))))

    for ex in tqdm(it, desc=f"Evaluating ({mode}, greedy)"):
        q = ex["question"]
        gold_text = ex["answer"]
        gold_num = parse_gold_answer_number(gold_text)

        try:
            gen = generate_answer(model, tokenizer, mode, q)
            pred_num = parse_teacher_final_answer(gen)
        except ParsingError:
            pred_num = None
        except Exception:
            pred_num = None

        if pred_num is not None and gold_num is not None and pred_num == gold_num:
            n_correct += 1
        n_total += 1

    return {"accuracy": n_correct / n_total if n_total else 0.0, "n": n_total}

In [9]:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
device = "cuda" if torch.cuda.is_available() else "cpu"
bf16 = True

In [None]:
MODEL_ID = "Qwen/Qwen2.5-3B"

SCTOD_PATH = "../artifacts/models/qwen2.5_3b_sctod_lora/best_checkpoint"
LABELONLY_PATH = "../artifacts/models/qwen2.5_3b_labelonly_lora/best_checkpoint"

RUNS = [
    {"name": "student_sctod", "mode": "cot", "path": SCTOD_PATH},
    {"name": "student_label_only", "mode": "label-only", "path": LABELONLY_PATH},
    {"name": "base_cot_prompting", "mode": "cot", "path": None},
    {"name": "base_label_only", "mode": "label-only", "path": None},
]


def main(limit: Optional[int] = None):
    results = []
    for run in RUNS:
        name = run["name"]
        mode = run["mode"]
        path = run["path"]

        print(f"\n=== Loading {name} ({'adapter/merged' if path else 'base'}) ===")
        model, tokenizer = load_model_and_tokenizer(
            model_id=MODEL_ID,
            peft_or_merged_path=path,
            use_4bit=True,
            bf16=True,
            device_map="auto",
        )
        metrics = evaluate_gsm8k_greedy(
            model, tokenizer, mode=mode, split="test", limit=limit
        )
        print(f"{name} -> accuracy: {metrics['accuracy']:.4f} (n={metrics['n']})")
        results.append((name, metrics))

    print("\n=== Summary (greedy only) ===")
    for name, m in results:
        print(f"{name:>24}: {m['accuracy']:.4f} (n={m['n']})")


if __name__ == "__main__":
    main(limit=100)


=== Loading student_sctod (adapter/merged) ===




README.md: 0.00B [00:00, ?B/s]

main/train-00000-of-00001.parquet:   0%|          | 0.00/2.31M [00:00<?, ?B/s]

main/test-00000-of-00001.parquet:   0%|          | 0.00/419k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]

Evaluating (cot, greedy):   0%|          | 0/100 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Evaluating (cot, greedy):   1%|          | 1/100 [00:05<09:46,  5.93s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Evaluating (cot, greedy):   2%|▏         | 2/100 [00:08<06:35,  4.03s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Evaluating (cot, greedy):   3%|▎         | 3/100 [00:17<10:26,  6.46s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Evaluating (cot, greedy):   4%|▍         | 4/100 [00:21<08:47,  5.50s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Evaluating (cot, greedy):   5%|▌         | 5/100 [00:26<08:07,  5.13s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation