In [1]:
import gc
import os
import sys
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import torch
from datasets import load_dataset
from peft import PeftModel
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)

# Put the project root (the parent of the notebook's working directory) on
# sys.path so the `src.*` package imported in the next cell can be found.
sys.path.append(os.fspath(Path.cwd().resolve().parent))

In [2]:
from src.config import (
    GSM8K_PATH,
    TEACHER_SYSTEM_PROMPT,
    TEACHER_USER_PROMPT,
)
from src.dataset_generator.helpers.answers import (
    ParsingError,
    parse_gold_answer_number,
    parse_teacher_final_answer,
)

In [3]:
def build_prompt_cot(question: str) -> str:
    """Build the chain-of-thought prompt for one question.

    Layout: the stripped teacher system prompt, a blank line, then the stripped
    teacher user template with ``{question}`` filled by the stripped question.
    The result always ends with exactly one trailing newline.
    """
    parts = (
        TEACHER_SYSTEM_PROMPT.strip(),
        TEACHER_USER_PROMPT.strip().format(question=question.strip()),
    )
    return "\n\n".join(parts) + "\n"


def build_prompt_label_only(question: str) -> str:
    """Build a 3-shot "final answer only" prompt for one question.

    Sections, separated by blank lines: an instruction header, three worked
    exemplars (``Question: ...`` followed by ``Final Answer: ...``), and the
    stripped target question. The prompt ends with a newline after the target
    question; no answer cue is appended, so the model is expected to emit the
    ``Final Answer:`` line itself.
    """
    few_shot = (
        (
            "A farm has 3 barns with 12 cows each. It sells 7 cows and buys 5 more. How many cows now?",
            "34",
        ),
        (
            "Pens cost $2 and notebooks $5. Alex buys 3 pens and 2 notebooks and pays with $20. How much change?",
            "4",
        ),
        (
            "A tank holds 250 liters. 35% is drained, then 40 liters are added. How many liters now?",
            "202.5",
        ),
    )
    sections = [
        "You are a concise math solver. Output only the final line as:\n"
        "Final Answer: <number>"
    ]
    sections.extend(f"Question: {shot_q}\nFinal Answer: {shot_a}" for shot_q, shot_a in few_shot)
    sections.append(f"Question: {question.strip()}")
    return "\n\n".join(sections) + "\n"

In [4]:
def load_model_and_tokenizer(
    model_id: str,
    peft_or_merged_path: Optional[str] = None,
    use_4bit: bool = True,
    bf16: bool = True,
    device_map: str = "auto",
) -> Tuple[Any, Any]:
    """Load a causal LM plus a left-padding tokenizer for batched generation.

    Which weights are loaded depends on ``peft_or_merged_path``:
      - ``None``: the base model ``model_id`` only.
      - a local directory containing ``adapter_config.json``: the base model
        ``model_id`` with that PEFT adapter attached via ``PeftModel``.
      - anything else: a merged (standalone) HF model at that path/id.

    Args:
        model_id: Hub id or local path of the base model.
        peft_or_merged_path: Optional adapter directory or merged-model path.
        use_4bit: Quantize weights to 4-bit NF4 (with double quantization)
            via bitsandbytes.
        bf16: Use bfloat16 as compute / non-quantized dtype; float16 otherwise.
        device_map: Forwarded to ``from_pretrained`` (e.g. ``"auto"``).

    Returns:
        ``(model, tokenizer)``. The tokenizer pads and truncates on the left;
        if it defines no pad token, EOS is used and the model's generation
        config is updated to match.
    """
    dtype = torch.bfloat16 if bf16 else torch.float16

    quant_cfg = None
    if use_4bit:
        quant_cfg = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=dtype,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )

    def _load_causal_lm(src: str) -> Any:
        # Load a causal LM from `src` with the shared quantization/dtype/device settings.
        # torch_dtype is passed explicitly: without it a non-quantized load
        # defaults to float32, and for a 4-bit load transformers defaults the
        # non-quantized layers to float16 regardless of the requested compute dtype.
        return AutoModelForCausalLM.from_pretrained(
            src,
            quantization_config=quant_cfg,
            torch_dtype=dtype,
            device_map=device_map,
            trust_remote_code=True,
        )

    if peft_or_merged_path is None:
        model = _load_causal_lm(model_id)
        tok_src = model_id
    elif (Path(peft_or_merged_path) / "adapter_config.json").is_file():
        # PEFT adapter directory: load the base model, then attach the adapter.
        base_model = _load_causal_lm(model_id)
        model = PeftModel.from_pretrained(base_model, peft_or_merged_path)
        # Adapter checkpoints may not ship tokenizer files; fall back to the base.
        has_tokenizer = (Path(peft_or_merged_path) / "tokenizer_config.json").is_file()
        tok_src = peft_or_merged_path if has_tokenizer else model_id
    else:
        # Merged / standalone model.
        model = _load_causal_lm(peft_or_merged_path)
        tok_src = peft_or_merged_path

    tokenizer = AutoTokenizer.from_pretrained(
        tok_src, use_fast=True, trust_remote_code=True
    )
    # Left padding/truncation keeps every prompt's last token aligned at the
    # right edge, which batched decoder-only generation relies on.
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        model.generation_config.pad_token_id = tokenizer.pad_token_id
    return model, tokenizer

In [8]:
def batch_generate(
    model,
    tokenizer,
    questions: list[str],
    mode: str,
    max_new_tokens: int = 256,
    batch_size: int = 16,
    progress_desc: str = "",
) -> list[str]:
    """Greedily generate one completion per question, in batches.

    Side effect: enables TF32 for CUDA matmul and cuDNN globally.

    Args:
        model: Causal LM (or PEFT wrapper) exposing ``eval``, ``device`` and
            ``generate``.
        tokenizer: Tokenizer configured for left padding (as set up by
            ``load_model_and_tokenizer``).
        questions: Raw question strings.
        mode: ``"cot"`` uses ``build_prompt_cot``; any other value uses
            ``build_prompt_label_only``.
        max_new_tokens: Generation budget per question.
        batch_size: Number of prompts per ``generate`` call.
        progress_desc: Label for the tqdm progress bar.

    Returns:
        The stripped newly generated text for each question, in input order.
    """
    build_prompt = build_prompt_cot if mode == "cot" else build_prompt_label_only
    outputs: list[str] = []
    model.eval()
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    with torch.inference_mode():
        for start in tqdm(range(0, len(questions), batch_size), desc=progress_desc):
            prompts = [build_prompt(q) for q in questions[start:start + batch_size]]
            enc = tokenizer(
                prompts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=2048,  # safe cap; adjust if needed
            )
            enc = {k: v.to(model.device) for k, v in enc.items()}
            gen = model.generate(
                **enc,
                do_sample=False,
                max_new_tokens=max_new_tokens,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True,
            )
            # generate() returns prompt + continuation. With left padding every
            # prompt ends at the same column, so the continuation is exactly the
            # tokens past the padded input width. Slicing token ids (instead of
            # the decoded text by len(prompt)) stays correct when the prompt was
            # truncated or does not round-trip through decode (e.g. special
            # tokens dropped by skip_special_tokens).
            new_tokens = gen[:, enc["input_ids"].shape[1]:]
            texts = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
            outputs.extend(text.strip() for text in texts)
    return outputs

def evaluate_gsm8k_greedy_batched(
    model,
    tokenizer,
    mode: str,
    split: str = "test",
    limit: Optional[int] = None,
    batch_size: int = 16,
    max_new_tokens: int = 256,
) -> Dict[str, Any]:
    """Greedy-decode GSM8K questions and score exact-match accuracy.

    Args:
        model: Model as returned by ``load_model_and_tokenizer``.
        tokenizer: Tokenizer as returned by ``load_model_and_tokenizer``.
        mode: "cot" or "label-only"; forwarded to ``batch_generate`` to pick
            the prompt format.
        split: GSM8K split to evaluate.
        limit: If given, evaluate only the first ``limit`` examples.
        batch_size: Generation batch size.
        max_new_tokens: Generation budget per example.

    Returns:
        Dict with ``accuracy`` (0.0 when nothing was evaluated), ``n`` (number
        of examples), ``n_correct``, and ``n_unparsed`` (predictions for which
        the final-answer parser failed or returned None).
    """
    ds = load_dataset(GSM8K_PATH, "main", split=split)
    if limit is not None:
        ds = ds.select(range(min(limit, len(ds))))

    # Materialize as plain lists: newer `datasets` versions may return lazy
    # column objects, and batch_generate slices the questions.
    questions = list(ds["question"])
    gold_nums = [parse_gold_answer_number(t) for t in ds["answer"]]

    gens = batch_generate(
        model,
        tokenizer,
        questions,
        mode=mode,
        max_new_tokens=max_new_tokens,
        batch_size=batch_size,
        progress_desc=f"Evaluating ({mode}, greedy, bs={batch_size})",
    )

    n_total = len(gens)
    n_correct = 0
    n_unparsed = 0
    for pred_text, gold_num in zip(gens, gold_nums):
        try:
            pred_num = parse_teacher_final_answer(pred_text)
        except Exception:
            # Best-effort scoring: ParsingError or any other parser failure
            # counts as a wrong answer, but is tallied so it stays visible.
            pred_num = None

        if pred_num is None:
            n_unparsed += 1
        elif gold_num is not None and pred_num == gold_num:
            n_correct += 1

    return {
        "accuracy": n_correct / n_total if n_total else 0.0,
        "n": n_total,
        "n_correct": n_correct,
        "n_unparsed": n_unparsed,
    }

In [9]:
# Allow TF32 for CUDA matmuls and cuDNN ops (trades a little fp32 precision
# for speed on GPUs that support it).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Turn off the HF `tokenizers` library's internal parallelism.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [None]:
MODEL_ID = "Qwen/Qwen2.5-3B"
# Point these to your checkpoints. If directory contains adapter_config.json -> treated as adapter (PEFT).
SCTOD_PATH = "../artifacts/models/qwen2.5_3b_sctod_lora/best_checkpoint"
LABELONLY_PATH = "../artifacts/models/qwen2.5_3b_labelonly_lora/best_checkpoint"

# Each run pairs a checkpoint (None = base model) with the prompt format used to score it.
RUNS = [
    {"name": "student_sctod", "mode": "cot", "path": SCTOD_PATH},
    {"name": "student_label_only", "mode": "label-only", "path": LABELONLY_PATH},
    {"name": "base_cot_prompting", "mode": "cot", "path": None},
    {"name": "base_label_only", "mode": "label-only", "path": None},
]

limit = None   # e.g., 100 for a quick smoke test
batch_size = 16
max_new_tokens = 256

results = []
for run in RUNS:
    name = run["name"]
    mode = run["mode"]
    path = run["path"]
    print(f"\n=== Loading {name} ({'adapter/merged' if path else 'base'}) ===")
    model, tokenizer = load_model_and_tokenizer(
        model_id=MODEL_ID,
        peft_or_merged_path=path,
        use_4bit=True,
        bf16=True,
        device_map="auto",
    )
    # Ensure eval-time cache is on (may have been disabled in training config)
    if hasattr(model, "config"):
        model.config.use_cache = True

    metrics = evaluate_gsm8k_greedy_batched(
        model,
        tokenizer,
        mode=mode,
        split="test",
        limit=limit,
        batch_size=batch_size,
        max_new_tokens=max_new_tokens,
    )
    print(f"{name} -> accuracy: {metrics['accuracy']:.4f} (n={metrics['n']})")
    results.append((name, metrics))

    # Release this run's model before the next iteration loads another one.
    # Otherwise `model` keeps the old weights alive while the next model is
    # being loaded, so two models are resident at once.
    del model, tokenizer
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

print("\n=== Summary (greedy only) ===")
for name, m in results:
    print(f"{name:>24}: {m['accuracy']:.4f} (n={m['n']})")


=== Loading student_sctod (adapter/merged) ===


Evaluating (cot, greedy, bs=16):   5%|▍         | 4/83 [00:58<17:50, 13.54s/it]