In [1]:
import re
from statistics import mean
from typing import Optional, Tuple, Dict, Any, List

import torch
from datasets import load_dataset

In [14]:
NUMBER_PATTERN = re.compile(r"[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?")

def parse_number(text: str) -> int | float | None:
    match = NUMBER_PATTERN.findall(text.replace(",", ""))
    if not match:
        raise Exception(f"No number found in text: {text}")
    val = match[-1]
    if "e" in val.lower() or "." in val:
        return float(val)
    return int(val)


def parse_gold_answer_number(answer: str) -> Optional[int | float]:
    part = answer.split("####")[-1].strip()
    return parse_number(part)


def parse_teacher_final_answer(answer: str) -> Optional[int | float]:
    for line in answer.splitlines():
        if line.lower().startswith("final answer:"):
            return parse_number(line.split(":", 1)[-1].strip())
    raise Exception(f"No 'Final Answer:' line found in teacher answer: {answer}")


In [15]:
def generate_one(
    question: str,
    max_new_tokens: int = 256,
    temperature: float = 0.0,
    top_p: float = 1.0,
    do_sample: bool = False,
    model=None,
) -> str:
    """Generate one completion for ``question`` and return only the new text.

    Relies on two notebook-level names that are not defined in this cell:
    ``build_prompt`` (question -> prompt string) and ``tokenizer``.

    Args:
        question: The question to wrap with ``build_prompt``.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature; forwarded only when ``do_sample``
            is true.
        top_p: Nucleus-sampling threshold; forwarded only when ``do_sample``
            is true.
        do_sample: Sample when true, otherwise decode without sampling.
        model: Model to generate with. When omitted, the notebook-level
            ``model`` variable is used, which is how existing callers that do
            not pass a model keep working.

    Returns:
        The decoded completion with surrounding whitespace stripped.

    Raises:
        NameError: If no model is passed and no notebook-level ``model`` exists.
    """
    if model is None:
        # The parameter shadows the global of the same name, so look the
        # notebook-level model up explicitly.
        model = globals().get("model")
        if model is None:
            raise NameError(
                "name 'model' is not defined: load a model in an earlier cell "
                "or pass model=... to generate_one"
            )

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if do_sample:
        # Sampling knobs are only meaningful when sampling; passing
        # temperature=0.0 alongside do_sample=False only produces warnings.
        generation_kwargs["temperature"] = temperature
        generation_kwargs["top_p"] = top_p

    prompt = build_prompt(question)
    model.eval()
    with torch.no_grad():
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_len = inputs["input_ids"].shape[-1]
        out = model.generate(**inputs, **generation_kwargs)

    # Drop the prompt by token count rather than by len(prompt) characters:
    # decoding does not always reproduce the prompt text exactly.
    # NOTE(review): assumes generate() returns prompt tokens followed by the new
    # tokens (decoder-only model), as the original string slicing also assumed.
    completion = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
    return completion.strip()

In [18]:
def evaluate_split(dataset, greedy: bool = True, sc_k: int = 0, sc_temp: float = 0.7, sc_top_p: float = 0.95, max_new_tokens: int = 256, limit: Optional[int] = None) -> Dict[str, Any]:
    """
    Evaluate on a HF dataset split with columns: 'question', 'answer'
    (where 'answer' contains the gold GSM8K solution text that includes the final numeric answer).

    Decoding mode is selected by ``sc_k``:

    - sc_k == 0: a single greedy generation per question.
    - sc_k > 0: self-consistency with k sampled generations (``sc_temp``,
      ``sc_top_p``); the prediction is the majority vote over the samples whose
      answer could be parsed.

    Args:
        dataset: Indexable collection of rows with 'question' and 'answer' keys.
        greedy: Accepted for call-site readability; the mode is decided by
            ``sc_k`` alone, so this flag is not consulted.
        sc_k: Number of sampled generations per question (0 disables sampling).
        sc_temp: Sampling temperature used in self-consistency mode.
        sc_top_p: Nucleus-sampling threshold used in self-consistency mode.
        max_new_tokens: Generation length cap forwarded to ``generate_one``.
        limit: Evaluate only the first ``limit`` rows; ``None`` means all rows.

    Returns:
        Dict with keys "n" (rows visited), "correct", "total" (rows with both a
        gold and a predicted number), "accuracy" (correct / total), "missing"
        (rows where either number could not be parsed) and "coverage"
        ((n - missing) / n, or 0.0 when n is 0).
    """
    from collections import Counter

    def _parse_or_none(parser, text):
        # The parsers in this notebook signal "no number found" by raising
        # rather than returning None. An unparseable text must count as
        # missing instead of aborting the whole evaluation run.
        try:
            return parser(text)
        except Exception:
            return None

    n = len(dataset) if limit is None else max(0, min(limit, len(dataset)))
    correct = 0
    total = 0
    missing = 0

    for i in range(n):
        example = dataset[i]  # fetch the row once instead of once per column
        q = example["question"]
        gold_num = _parse_or_none(parse_gold_answer_number, example["answer"])

        if sc_k and sc_k > 0:
            preds = []
            for _ in range(sc_k):
                gen = generate_one(q, max_new_tokens=max_new_tokens, temperature=sc_temp, top_p=sc_top_p, do_sample=True)
                num = _parse_or_none(parse_teacher_final_answer, gen)
                if num is not None:
                    preds.append(num)
            # Majority vote over the parseable samples; 18 and 18.0 hash alike,
            # so int and float spellings of the same value pool their votes.
            pred_num = Counter(preds).most_common(1)[0][0] if preds else None
        else:
            gen = generate_one(q, max_new_tokens=max_new_tokens, temperature=0.0, top_p=1.0, do_sample=False)
            pred_num = _parse_or_none(parse_teacher_final_answer, gen)

        if gold_num is None or pred_num is None:
            missing += 1
        else:
            correct += int(pred_num == gold_num)
            total += 1

        if (i + 1) % 50 == 0:
            print(f"Processed {i+1}/{n}: acc_so_far={(correct/max(1,total)):.4f}, missing={missing}")

    accuracy = correct / max(1, total)
    # Guard the empty case (empty split or limit <= 0) against ZeroDivisionError.
    coverage = (n - missing) / n if n else 0.0
    return {"n": n, "correct": correct, "total": total, "accuracy": accuracy, "missing": missing, "coverage": coverage}


# Test split of the "main" configuration of the openai/gsm8k dataset, loaded via
# datasets.load_dataset; this is the split passed to evaluate_split below.
gsm8k_test = load_dataset("openai/gsm8k", "main", split="test")

In [19]:
# Baseline run: sc_k=0 selects evaluate_split's single-generation (greedy) path;
# limit=None means every row of the split is visited.
print("Running greedy decoding on full test set…")
res_greedy = evaluate_split(gsm8k_test, greedy=True, sc_k=0, max_new_tokens=256, limit=None)
print("Greedy:", res_greedy)

# Self-consistency run: sc_k > 0 selects evaluate_split's sampling path, which
# requests SC_K generations per question, so this is roughly SC_K times the
# generation work of the greedy run above.
SC_K = 20
print(f"Running self-consistency with k={SC_K}…")
res_sc = evaluate_split(gsm8k_test, greedy=False, sc_k=SC_K, sc_temp=0.7, sc_top_p=0.95, max_new_tokens=256, limit=None)
print("Self-Consistency:", res_sc)

# Compact summary
def pct(x):
    """Render a fraction as a percentage string with two decimals (0.5 -> "50.00%")."""
    scaled = 100.0 * x
    return "{:.2f}%".format(scaled)

# Headline metrics for both runs, formatted as percentages by pct().
summary = dict(
    greedy_accuracy=pct(res_greedy["accuracy"]),
    greedy_coverage=pct(res_greedy["coverage"]),
    sc_k=SC_K,
    sc_accuracy=pct(res_sc["accuracy"]),
    sc_coverage=pct(res_sc["coverage"]),
)
print("Summary:", summary)

Running greedy decoding on full test set…


NameError: name 'model' is not defined