In [None]:
import re
from decimal import Decimal, InvalidOperation
from typing import Optional, Dict, Any

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
# --- Model sources ------------------------------------------------------------
# Identifier of the untuned base model.
BASE_MODEL_ID = "Qwen/Qwen2.5-3B"
# TODO: replace both placeholders with real checkpoint directories before running.
LABEL_ONLY_DIR = "<PATH_TO_LABEL_ONLY_MODEL_DIR>"
SCTOD_DIR = "<PATH_TO_SCTOD_MODEL_DIR>"  # e.g. the output_dir the best checkpoint was saved to

# --- Generation settings ------------------------------------------------------
bf16 = True  # True -> weights are loaded as bfloat16, False -> float16
MAX_NEW_TOKENS = 256  # default generation budget per prompt
# Extra keyword arguments forwarded to model.generate() on every call.
GEN_KW = {"do_sample": False, "temperature": 0.0, "top_p": 1.0}

In [None]:
def build_prompt_cot(question: str) -> str:
    """Build the chain-of-thought prompt (system text, blank line, user text, newline).

    Reuses the training prompt templates. NOTE(review): TEACHER_SYSTEM_PROMPT and
    TEACHER_USER_PROMPT are not defined in the visible part of this notebook; they
    must already exist in the kernel, otherwise this raises NameError.

    Args:
        question: The problem statement; surrounding whitespace is stripped.

    Returns:
        The stripped system prompt and the formatted user prompt joined by a blank
        line and terminated by a single newline.
    """
    system_part = TEACHER_SYSTEM_PROMPT.strip()
    # The user template is stripped first, then its {question} field is filled in.
    user_part = TEACHER_USER_PROMPT.strip().format(question=question.strip())
    return "\n\n".join((system_part, user_part)) + "\n"

def build_prompt_answer_only(question: str) -> str:
    """Build the minimal answer-only prompt used for the label-only SFT baseline.

    Args:
        question: The problem statement; surrounding whitespace is stripped.

    Returns:
        A fixed instruction, a blank line, then ``Problem: <question>`` and a newline.
    """
    instruction = (
        "You are a helpful math assistant. Solve the problem and provide only the final numeric answer "
        "in the format '#### <number>'."
    )
    problem_line = f"Problem: {question.strip()}"
    return f"{instruction}\n\n{problem_line}\n"

In [None]:
# One number: optional minus sign, digits with optional single commas *between*
# digits, and an optional fractional part that must contain at least one digit.
# This never captures a dangling "," or "." (e.g. the period ending a sentence).
_NUMBER_PATTERN = r"-?\d(?:,?\d)*(?:\.\d+)?"
# "#### <number>" marker; an optional "$" is tolerated between marker and number.
_MARKED_ANSWER_RE = re.compile(r"####\s*\$?\s*(" + _NUMBER_PATTERN + r")")
_ANY_NUMBER_RE = re.compile(_NUMBER_PATTERN)


def extract_final_number(text: str) -> Optional[str]:
    """Extract the final numeric answer from a solution or model generation.

    The first ``#### <number>`` marker wins. A currency symbol directly before the
    number (``#### $18``) and a unit directly after it (``#### 18kg``) are accepted,
    so such answers are not handed to the much weaker fallback. If no marker is
    found, the last number appearing anywhere in ``text`` is used.

    Args:
        text: Text to scan. ``None`` or an empty string yields ``None``.

    Returns:
        The number as a string with thousands separators removed (for example
        ``"1234.50"``), or ``None`` when ``text`` contains no number.
    """
    if not text:
        return None

    marked = _MARKED_ANSWER_RE.search(text)
    if marked:
        return marked.group(1).replace(",", "")

    # No explicit marker: fall back to the last number mentioned in the text.
    candidates = _ANY_NUMBER_RE.findall(text)
    if candidates:
        return candidates[-1].replace(",", "")
    return None

def normalize_number(x: str) -> str:
    """Return a canonical string for a numeric string so equal values compare equal.

    Thousands separators and surrounding whitespace are removed, then the value is
    parsed exactly with ``Decimal`` (no binary-float round trip). As a result:
    large values keep every digit, fractional digits are never silently rounded,
    and any zero (including ``-0.0``) normalizes to ``"0"``.

    Args:
        x: Numeric text such as ``"1,234"``, ``"18.00"`` or ``"-0.5"``.

    Returns:
        The value in plain fixed-point notation without trailing fractional zeros
        (``"18.00"`` -> ``"18"``, ``"2.50"`` -> ``"2.5"``). Text that is not a
        finite number is returned cleaned (commas removed, stripped) but otherwise
        unchanged.
    """
    cleaned = x.replace(",", "").strip()
    try:
        value = Decimal(cleaned)
    except InvalidOperation:
        # Not parseable as a number: hand back the cleaned text as-is.
        return cleaned

    if not value.is_finite():
        # "inf" / "nan" have no canonical numeric form to compare on.
        return cleaned
    if value == 0:
        # Collapses "0", "-0", "0.000" and "-0.0" to a single representation.
        return "0"

    text = format(value, "f")  # fixed-point, never exponent notation
    if "." in text:
        # Only strip zeros after a decimal point; "100" must stay "100".
        text = text.rstrip("0").rstrip(".")
    return text

In [None]:
def load_model_and_tokenizer(path_or_id: str):
    """Load a causal LM and its tokenizer, ready for inference.

    Args:
        path_or_id: Value passed to both ``from_pretrained`` calls (the notebook
            uses a model id for the base model and directories for the others).

    Returns:
        A ``(model, tokenizer)`` tuple; the model has been switched to eval mode.
    """
    # NOTE(review): trust_remote_code=True lets the checkpoint supply its own Python
    # code; confirm every source passed here is trusted.
    tokenizer = AutoTokenizer.from_pretrained(path_or_id, use_fast=True, trust_remote_code=True)
    # Reuse EOS as the padding token when the tokenizer defines none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Precision is selected by the module-level `bf16` flag.
    weights_dtype = torch.bfloat16 if bf16 else torch.float16
    causal_lm = AutoModelForCausalLM.from_pretrained(
        path_or_id,
        device_map="auto",
        torch_dtype=weights_dtype,
        trust_remote_code=True,
    )
    causal_lm.eval()
    return causal_lm, tokenizer

In [None]:
def generate_greedy(model, tokenizer, prompt: str, max_new_tokens: int = MAX_NEW_TOKENS) -> str:
    """Generate a continuation for ``prompt`` and return only the newly generated text.

    The prompt is removed by token count rather than by slicing the decoded string
    with ``len(prompt)``: decoding does not always reproduce the prompt character
    for character (special tokens are skipped, whitespace may be normalised), and
    any such difference would shift a character-based cut.

    Args:
        model: Model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.
        prompt: Full prompt text.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded continuation with special tokens removed and whitespace stripped.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    # Pass an explicit pad token (falling back to EOS) unless GEN_KW already sets one,
    # so generate() does not have to pick a default on every call.
    extra_kwargs = dict(GEN_KW)
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    extra_kwargs.setdefault("pad_token_id", pad_id)

    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            eos_token_id=tokenizer.eos_token_id,
            **extra_kwargs,
        )

    # generate() returns prompt tokens followed by the continuation; keep the latter.
    new_tokens = out[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

In [None]:
def evaluate_split_greedy(model, tokenizer, dataset, prompt_builder, limit: Optional[int] = None) -> Dict[str, Any]:
    """Evaluate exact-match accuracy of greedy generations on a dataset split.

    For each example the gold number is extracted from ``"answer"``, a prompt is
    built from ``"question"``, a greedy generation is produced, and the predicted
    number is compared with the gold number after normalization.

    Args:
        model: Model passed through to ``generate_greedy``.
        tokenizer: Tokenizer passed through to ``generate_greedy``.
        dataset: Indexable collection of rows with ``"question"`` and ``"answer"``.
        prompt_builder: Callable mapping a question string to a prompt string.
        limit: Evaluate only the first ``limit`` rows; ``None`` means all rows.
            Values below zero are treated as zero.

    Returns:
        A dict with:
            ``n``: number of rows evaluated.
            ``correct``: rows whose prediction equals the gold number.
            ``total``: rows where both gold and predicted numbers were found.
            ``accuracy``: ``correct / total`` (rows counted in ``missing`` are
                excluded from the denominator); ``0.0`` when ``total`` is 0.
            ``missing``: rows where the gold or predicted number was not found.
            ``coverage``: ``(n - missing) / n``; ``0.0`` when ``n`` is 0.
    """
    size = len(dataset)
    # Clamp so an empty split, limit=0 or a negative limit cannot yield a negative
    # n or a division by zero below.
    n = size if limit is None else max(0, min(limit, size))
    correct, total, missing = 0, 0, 0
    for i in range(n):
        example = dataset[i]  # fetch the row once
        gold_raw = extract_final_number(example["answer"])
        gold_num = normalize_number(gold_raw) if gold_raw is not None else None

        prompt = prompt_builder(example["question"])
        gen = generate_greedy(model, tokenizer, prompt, max_new_tokens=MAX_NEW_TOKENS)
        pred_raw = extract_final_number(gen)
        pred_num = normalize_number(pred_raw) if pred_raw is not None else None

        if gold_num is None or pred_num is None:
            missing += 1
        else:
            correct += int(pred_num == gold_num)
            total += 1

        if (i + 1) % 50 == 0:
            print(f"Processed {i+1}/{n}: acc_so_far={(correct/max(1,total)):.4f}, missing={missing}")

    accuracy = correct / max(1, total)
    coverage = (n - missing) / n if n else 0.0
    return {"n": n, "correct": correct, "total": total, "accuracy": accuracy, "missing": missing, "coverage": coverage}

In [None]:
# Validation uses the split expression "train[-1000:]" (a slice of the train split);
# test uses the dataset's own "test" split.
gsm8k_val = load_dataset(path="openai/gsm8k", name="main", split="train[-1000:]")
gsm8k_test = load_dataset(path="openai/gsm8k", name="main", split="test")

In [None]:
# Load the three checkpoints compared in this notebook. Each call returns a
# (model, tokenizer) pair.
# NOTE(review): all three models stay referenced at the same time; presumably the
# machine needs enough accelerator memory for all of them. Confirm before running.
print("Loading Base model...")
base_model, base_tok = load_model_and_tokenizer(BASE_MODEL_ID)

# LABEL_ONLY_DIR must have been changed from its placeholder value.
print("Loading Label-only SFT model...")
label_model, label_tok = load_model_and_tokenizer(LABEL_ONLY_DIR)

# SCTOD_DIR must have been changed from its placeholder value.
print("Loading SCoTD student model...")
sctod_model, sctod_tok = load_model_and_tokenizer(SCTOD_DIR)

In [None]:
LIMIT = None  # set to an int to evaluate only the first LIMIT rows of each split
results = []

# One entry per model: (banner to print, display name, model, tokenizer, prompt builder).
eval_plan = [
    ("\nEvaluating Base (prompt-only, CoT prompt)...", "Base Qwen2.5-3B", base_model, base_tok, build_prompt_cot),
    ("\nEvaluating Label-only SFT (answer-only prompt)...", "Label-only SFT", label_model, label_tok, build_prompt_answer_only),
    ("\nEvaluating SCoTD student (CoT prompt)...", "SCoTD student", sctod_model, sctod_tok, build_prompt_cot),
]
# Splits are evaluated in this order for every model: validation first, then test.
eval_splits = (("val", gsm8k_val), ("test", gsm8k_test))

for banner, model_name, eval_model, eval_tok, prompt_builder in eval_plan:
    print(banner)
    for split_name, split_data in eval_splits:
        res = evaluate_split_greedy(eval_model, eval_tok, split_data, prompt_builder, limit=LIMIT)
        results.append((model_name, split_name, res))

In [None]:
# Summary table. Every cell uses the same width as its header column so that the
# accuracy and coverage values line up under "Acc" and "Coverage".
print("\nResults (greedy decoding only)")
print(f"{'Model':<18} {'Split':<6} {'Acc':>7} {'Coverage':>10} {'n':>6} {'correct':>8} {'total':>6} {'missing':>8}")
for name, split, r in results:
    print(
        f"{name:<18} {split:<6} {r['accuracy']:>7.4f} {r['coverage']:>10.4f} "
        f"{r['n']:>6d} {r['correct']:>8d} {r['total']:>6d} {r['missing']:>8d}"
    )