In [None]:
# %% [markdown]
# # 🧮 Math 1.5B (A100-40GB) — SFT → RLVR → PRM → Repair → DPO (Long→Short) → Length-Aware RL → Eval
# Auto-detects existing teacher data at /content/data/teacher_verified.jsonl and reuses it.
# If not found, synthesizes Program-of-Thought traces via OpenAI teacher, verifies locally, then proceeds.

# %%capture
!pip install -U "transformers>=4.43" "accelerate>=0.30" "trl>=0.9.6" peft datasets bitsandbytes \
  "flash-attn>=2.5.8" sympy "lighteval>=0.4.0" --no-build-isolation
!pip install -U openai

import torch, platform, sys, os, json, random, re, time, textwrap, traceback, glob
from pathlib import Path
import sympy as sp
from datasets import load_dataset, Dataset, DatasetDict

print("Torch:", torch.__version__, "| CUDA:", torch.cuda.is_available(), "| Dev:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
print("Python:", sys.version, "| OS:", platform.platform())

# Paths
OUT  = Path("/content/output"); OUT.mkdir(parents=True, exist_ok=True)
DATA = Path("/content/data"); DATA.mkdir(parents=True, exist_ok=True)

# ==== Config (edit as needed) ====
MODEL_ID  = "Qwen/Qwen2.5-Math-1.5B"   # or "Qwen/Qwen2.5-Math-1.5B-Instruct"
USE_INSTRUCT_CHAT_TEMPLATE = False     # True if using the Instruct chat template

MAX_SEQ_LEN = 4096
SFT_EPOCHS  = 1
SFT_LR      = 2e-5

GRPO_STEPS  = 2000
GRPO_LR     = 1e-6
GRPO_GROUP  = 4                        # completions sampled per prompt in the GRPO stages
GRPO_KL     = 0.02                     # KL penalty coefficient for the GRPO stages

SYNTH_SAMPLES = 2000                   # how many PoT examples to synthesize if needed
TEACHER_MODEL = "gpt-5"                # replace with your org's model id
TEMPERATURE   = 0.2                    # teacher sampling temperature

EVAL_N = 200                           # quick eval size (increase for full test)
USE_MATH = False                       # optional (clones official MATH repo)

# Global seed so shuffles and torch sampling are reproducible under Restart & Run All.
SEED = 1337
random.seed(SEED)
torch.manual_seed(SEED)

# OpenAI key prompt if not present (the client below is always constructed, even when cached
# teacher data makes API calls unnecessary).
if not os.environ.get("OPENAI_API_KEY"):
    import getpass
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter your OPENAI_API_KEY (hidden): ")
print("API key present:", bool(os.environ.get("OPENAI_API_KEY")))


In [None]:
# %% [markdown]
# ## 📚 Load eval/train datasets (GSM8K; optional MATH via GitHub)
gsm = load_dataset("openai/gsm8k", "main")
gsm_train = gsm["train"]
gsm_test  = gsm["test"]
print("GSM8K train/test:", len(gsm_train), len(gsm_test))

math_items = []
if USE_MATH:
    !rm -rf /content/math && git clone --depth 1 https://github.com/hendrycks/math /content/math
    for fp in glob.glob("/content/math/train/**/*.json", recursive=True):
        try:
            ex = json.load(open(fp))
            prob, sol = ex.get("problem",""), ex.get("solution","")
            m = re.search(r"\\boxed\{([^}]*)\}", sol) or re.search(r"Answer:\s*([^\n]+)", sol)
            ans = m.group(1).strip() if m else None
            if prob and ans:
                math_items.append({"question": prob, "answer": ans})
        except Exception:
            pass
    print("Parsed MATH problems:", len(math_items))


In [None]:
# %% [markdown]
# ## üßë‚Äçüè´ Teacher synthesis (PoT) ‚òÖ Auto-skip if /content/data/teacher_verified.jsonl exists
from openai import OpenAI
# Module-level OpenAI client used by ask_teacher(). It is constructed unconditionally, even when
# cached teacher data exists and no API call will be made.
client = OpenAI()

# Teacher instructions: the reply must be a JSON object with cot_program / tests / final_answer.
# ask_teacher() parses that JSON, and verify_record() executes cot_program and the tests in the
# restricted sandbox defined below, so the allowed-module list here should match what the sandbox permits.
PROMPT = """You are a math tutor. Solve the user's problem by:
1) writing a short Python function `solve()` that computes the exact answer
2) include 1-3 python assertions that verify the result (sympy ok: import sympy as sp)
3) FINALLY print a single line 'Answer: <final>' where <final> is a plain number or simplified expression.

Rules:
- Keep code minimal/deterministic; allowed: math, fractions, decimal, itertools, sympy as sp
- The printed 'Answer: <final>' must be exactly the final value (string)
Return ONLY a JSON object with keys: cot_program (string), tests (list of strings), final_answer (string).
"""

def ask_teacher(question, model=TEACHER_MODEL, temperature=TEMPERATURE, use_responses_api=True):
    """Ask the OpenAI teacher to solve `question` and return its parsed JSON payload.

    Args:
        question: problem statement, sent as the user message.
        model: OpenAI model id.
        temperature: sampling temperature, or None to omit the parameter. Some models
            (e.g. reasoning models) reject a non-default temperature; if the API error
            mentions "temperature", the request is retried once without it.
        use_responses_api: use the Responses API (True) or Chat Completions (False).

    Returns:
        dict decoded from the first JSON object in the reply (PROMPT asks for the keys
        cot_program, tests, final_answer).

    Raises:
        ValueError: if the reply contains no parseable JSON object.
    """
    def _request(temp):
        extra = {} if temp is None else {"temperature": temp}
        if use_responses_api:
            resp = client.responses.create(
                model=model,
                instructions=PROMPT,
                input=[{"role":"user","content":question}],
                **extra,
            )
            return resp.output_text
        resp = client.chat.completions.create(
            model=model,
            messages=[{"role":"system","content":PROMPT},
                      {"role":"user","content":question}],
            **extra,
        )
        return resp.choices[0].message.content

    try:
        txt = _request(temperature)
    except Exception as e:
        if temperature is None or "temperature" not in str(e).lower():
            raise
        txt = _request(None)

    txt = txt or ""
    start = txt.find("{")
    if start < 0:
        raise ValueError("No JSON found in teacher response")
    try:
        # Decode exactly one object starting at the first "{", ignoring any trailing prose.
        return json.JSONDecoder().raw_decode(txt, start)[0]
    except json.JSONDecodeError:
        m = re.search(r"\{.*\}", txt, flags=re.S)
        if not m:
            raise ValueError("No JSON found in teacher response")
        try:
            return json.loads(m.group(0))
        except json.JSONDecodeError as e:
            raise ValueError(f"Unparseable JSON in teacher response: {e}") from e

# Sandbox for PoT verification.
# NOTE(review): restricted builtins are NOT a security boundary (object introspection can escape),
# and there is no timeout: an infinite loop in teacher/model code will hang the cell.
_ALLOWED_IMPORTS = {"math", "fractions", "decimal", "itertools", "sympy",
                    "functools", "collections", "operator", "statistics"}

def _restricted_import(name, globals_=None, locals_=None, fromlist=(), level=0):
    """`__import__` replacement that only admits the pure-math modules PROMPT allows (plus a few stdlib helpers)."""
    if level != 0 or name.split(".")[0] not in _ALLOWED_IMPORTS:
        raise ImportError(f"import of {name!r} is not allowed in the PoT sandbox")
    return __import__(name, globals_, locals_, fromlist, level)

ALLOWED_BUILTINS = {
    "abs": abs, "min": min, "max": max, "range": range, "len": len, "sum": sum, "print": print,
    # Numeric/iteration builtins that typical solve() programs use; without them most otherwise
    # correct programs fail with NameError.
    "int": int, "float": float, "str": str, "bool": bool, "round": round, "pow": pow,
    "divmod": divmod, "enumerate": enumerate, "zip": zip, "map": map, "filter": filter,
    "list": list, "tuple": tuple, "dict": dict, "set": set, "sorted": sorted,
    "reversed": reversed, "any": any, "all": all, "isinstance": isinstance,
    "Exception": Exception, "ValueError": ValueError, "ZeroDivisionError": ZeroDivisionError,
    "AssertionError": AssertionError,
    # PROMPT tells the teacher to `import sympy as sp` etc.; without __import__ every import fails.
    "__import__": _restricted_import,
}
SAFE_GLOBALS = {"__builtins__": ALLOWED_BUILTINS, "__name__": "__main__", "math": __import__("math"),
                "fractions": __import__("fractions"), "decimal": __import__("decimal"),
                "itertools": __import__("itertools"), "sp": sp}

def verify_record(rec):
    """Execute a PoT record in the sandbox and check it is self-consistent.

    Runs rec["cot_program"], then each string in rec["tests"], then (if defined) solve(),
    whose result must equal rec["final_answer"] either as stripped strings or symbolically.
    Note: this never compares against a gold label.

    Returns:
        (True, "ok") on success, else (False, reason).
    """
    code = rec["cot_program"]
    tests = rec.get("tests", []) or []
    if isinstance(tests, str):  # tolerate a single test string instead of a list
        tests = [tests]
    final = str(rec["final_answer"]).strip()
    # One fresh namespace per record: functions defined by `code` resolve names in their globals,
    # so helpers/imports must share the dict with solve(); a separate locals dict would hide them,
    # and a shared dict would leak state between records.
    ns = dict(SAFE_GLOBALS)
    try:
        exec(code, ns)
    except Exception as e:
        return False, f"exec error: {e}"
    for t in tests:
        try:
            exec(t, dict(ns))
        except Exception as e:
            return False, f"test fail: {e}"
    solve = ns.get("solve")
    if callable(solve):
        try:
            got = solve()
        except Exception as e:
            return False, f"solve() error: {e}"
        if str(got).strip() != final:
            try:
                same = sp.simplify(sp.nsimplify(got) - sp.nsimplify(final)) == 0
            except Exception:
                same = False
            if not same:
                return False, f"mismatch: solve()={got} vs final={final}"
    return True, "ok"

verified_path = DATA / "teacher_verified.jsonl"
raw_path      = DATA / "teacher_raw.jsonl"

def _agrees_with_gold(final, gold):
    """Check a teacher answer against a GSM8K-style reference ('... #### 72').

    Golds without '####' (MATH LaTeX answers) are accepted, because exact comparison is unreliable for them.
    """
    if "####" not in str(gold):
        return True
    want = str(gold).split("####")[-1].strip().replace(",", "")
    got = str(final).strip().replace(",", "").replace("$", "")
    try:
        return sp.simplify(sp.nsimplify(got) - sp.nsimplify(want)) == 0
    except Exception:
        return got == want

verified = []
if verified_path.exists() and verified_path.stat().st_size > 0:
    # Reuse cached data. NOTE: a partial file left by an interrupted synthesis run also lands here;
    # delete it to re-synthesize.
    print("🟢 Found existing verified teacher data:", verified_path)
    with verified_path.open() as fh:
        verified = [json.loads(l) for l in fh if l.strip()]
    print("Loaded verified records:", len(verified))
else:
    print("🟡 No verified teacher data found. Synthesizing now…")
    pool = [{"id": f"gsm8k_{i}", "question": ex["question"], "gold": ex["answer"]} for i, ex in enumerate(gsm_train)]
    pool += [{"id": f"math_{i}", "question": ex["question"], "gold": ex["answer"]} for i, ex in enumerate(math_items)]
    random.seed(1337); random.shuffle(pool)
    sel = pool[:SYNTH_SAMPLES]

    ok_count = 0
    # `with` closes both JSONL files even if the loop is interrupted (e.g. KeyboardInterrupt).
    with raw_path.open("w") as raw_out, verified_path.open("w") as ok_out:
        for k, ex in enumerate(sel, 1):
            try:
                js = ask_teacher(ex["question"], use_responses_api=True)
                rec = {
                    "id": ex["id"], "question": ex["question"],
                    "cot_program": js["cot_program"], "tests": js.get("tests", []),
                    "final_answer": str(js["final_answer"]).strip(), "tool_mode": "python"
                }
                print(json.dumps(rec), file=raw_out, flush=True)
                ok, msg = verify_record(rec)
                # Self-consistency alone cannot catch a teacher that is confidently wrong, so
                # also require agreement with the dataset's reference answer when it is known.
                if ok and _agrees_with_gold(rec["final_answer"], ex["gold"]):
                    ok_count += 1
                    verified.append(rec)
                    print(json.dumps(rec), file=ok_out, flush=True)
                if k % 25 == 0:
                    print(f"[{k}/{len(sel)}] verified_ok={ok_count}")
            except Exception:
                traceback.print_exc()
    print("Synthesized & verified:", ok_count, "of", len(sel))


In [None]:
# %% [markdown]
# ## 🧱 Build SFT dataset from verified teacher data
from transformers import AutoTokenizer

# Base-model tokenizer used to format SFT/RL text; pad with EOS when the checkpoint defines no pad token.
tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

def to_chat_text(q, prog, ans):
    """Render one verified teacher record as a single SFT training string.

    With USE_INSTRUCT_CHAT_TEMPLATE the tokenizer's chat template is applied; otherwise a plain
    <|system|>/<|user|>/<|assistant|> tag format is used. The assistant turn is always
    "# python", the program, a blank line, then "Answer: <ans>".
    """
    system = "You are a concise math solver. First write minimal Python to compute the answer, then output 'Answer: <value>'."
    reply = f"# python\n{prog}\n\nAnswer: {ans}"
    if USE_INSTRUCT_CHAT_TEMPLATE and hasattr(tok, "apply_chat_template"):
        messages = [
            {"role":"system","content":system},
            {"role":"user","content":q},
            {"role":"assistant","content":reply}
        ]
        return tok.apply_chat_template(messages, tokenize=False)
    # The plain tag format has no end-of-turn marker, so append EOS explicitly. Without it the model
    # never learns to stop after "Answer: ..." and keeps generating until max_new_tokens.
    return f"<|system|>\n{system}\n<|user|>\n{q}\n<|assistant|>\n{reply}{tok.eos_token or ''}"

# One {"text": ...} JSON line per verified record, consumed by the SFT cell below.
sft_path = DATA / "sft_train.jsonl"
n_sft_written = 0
with sft_path.open("w") as sft_file:
    for record in verified:
        line = json.dumps({"text": to_chat_text(record["question"], record["cot_program"], record["final_answer"])})
        print(line, file=sft_file)
        n_sft_written += 1
print("SFT records:", n_sft_written)


In [None]:
# %% [markdown]
# ## 📘 SFT (PoT) with TRL SFTTrainer
from trl import SFTTrainer, SFTConfig
from transformers import AutoModelForCausalLM
import dataclasses

train_ds = load_dataset("json", data_files=str(sft_path))["train"]

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype="auto",
    attn_implementation="flash_attention_2",
    device_map="auto"
)

# TRL has renamed SFTConfig's sequence-length field across releases (max_seq_length -> max_length);
# use whichever field the installed version defines.
_sft_fields = {f.name for f in dataclasses.fields(SFTConfig)}
_sft_len_kw = "max_seq_length" if "max_seq_length" in _sft_fields else "max_length"

sft_cfg = SFTConfig(
    output_dir=str(OUT / "sft-poT"),
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    learning_rate=SFT_LR,
    num_train_epochs=SFT_EPOCHS,
    bf16=True, logging_steps=10, save_steps=200,
    gradient_checkpointing=True,
    # Dataset options belong on SFTConfig; recent SFTTrainer versions reject them as constructor kwargs.
    dataset_text_field="text",
    packing=False,
    **{_sft_len_kw: MAX_SEQ_LEN},
)

# `processing_class` replaces the removed `tokenizer=` kwarg (same convention as the GRPO cell).
sft_trainer = SFTTrainer(
    model=model, processing_class=tok, train_dataset=train_ds, args=sft_cfg
)
sft_trainer.train()
sft_trainer.save_model(str(OUT / "sft-poT" / "final"))
print("✅ SFT done ->", OUT / "sft-poT" / "final")


In [None]:
# %% [markdown]
# ## 🧪 RL dataset (prompts + ground truth)
def extract_numeric(a):
    """Extract the final numeric answer from a GSM8K-style answer string.

    GSM8K puts the final answer after '####' (the rationale before it contains many intermediate
    numbers), so only the text after the last '####' is searched when present. Thousands
    separators and '$' are removed before matching. Returns the first number (integer, decimal,
    or simple fraction) found; if there is none, returns the searched text stripped.
    """
    tail = str(a).split("####")[-1]
    cleaned = tail.replace(",", "").replace("$", "")
    m = re.search(r"[-+]?[0-9]*\.?[0-9]+(?:/[0-9]+)?", cleaned)
    return m.group(0) if m else tail.strip()

def _rl_prompt(question):
    """Render `question` in the SFT prompt format, leaving the assistant turn open for generation."""
    system = "You are a concise math solver. First write minimal Python to compute the answer, then output 'Answer: <value>'."
    if USE_INSTRUCT_CHAT_TEMPLATE and hasattr(tok, "apply_chat_template"):
        messages = [{"role": "system", "content": system}, {"role": "user", "content": question}]
        return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    return f"<|system|>\n{system}\n<|user|>\n{question}\n<|assistant|>\n"

# `prompt` uses the same template the policy was SFT-trained on (a bare question is off-distribution).
# The raw `question` is kept as an extra column; GRPO passes extra columns to reward functions as kwargs.
rl_list = [
    {"prompt": _rl_prompt(ex["question"]), "question": ex["question"],
     "ground_truth": extract_numeric(ex["answer"])}
    for ex in gsm_train
]
for ex in math_items[:2000]:
    rl_list.append({"prompt": _rl_prompt(ex["question"]), "question": ex["question"],
                    "ground_truth": str(ex["answer"])})

random.Random(1337).shuffle(rl_list)  # local seeded RNG: reproducible regardless of earlier cells
rl_ds = Dataset.from_list(rl_list)
print("RL dataset size:", len(rl_ds))


In [None]:
# %% [markdown]
# ## 🏁 GRPO (verifiable reward): correctness + mild brevity shaping (+ optional PRM bonus)
from trl import GRPOTrainer, GRPOConfig
from transformers import AutoModelForCausalLM

def parse_final(text:str):
    """Extract the final answer string from a model/teacher completion.

    PoT completions usually contain `print(f"Answer: {x}")` inside the code *before* the real
    final "Answer: 72" line, so the LAST match wins. Preference order: a line that starts with
    "Answer:", then "Answer:" anywhere, then \\boxed{...}, then the last line of the text.
    Returns "" for empty/whitespace-only text.
    """
    for pattern in (r"(?im)^\s*Answer:\s*([^\n]+)", r"(?i)Answer:\s*([^\n]+)", r"\\boxed\{([^}]+)\}"):
        hits = re.findall(pattern, text)
        if hits:
            return hits[-1].strip()
    lines = text.strip().splitlines()
    return lines[-1] if lines else ""

def eq_correct(got, want):
    """Return True if two answers are equal, symbolically when possible.

    Both sides are normalized first: surrounding whitespace, '$' signs and a trailing period are
    removed, and commas are dropped from plain thousands-separated numbers ("1,234" -> "1234")
    while tuples/intervals like "(1,2)" keep theirs. Falls back to normalized string equality
    when sympy cannot parse or subtract the values.
    """
    def _normalize(ans):
        s = str(ans).strip().replace("\\$", "").replace("$", "")
        s = s.rstrip(".").strip()
        if re.fullmatch(r"[-+]?\d{1,3}(?:,\d{3})+(?:\.\d+)?", s):
            s = s.replace(",", "")
        return s

    g, w = _normalize(got), _normalize(want)
    try:
        # NOTE(review): nsimplify parses strings with sympify, which is eval-based; the inputs here
        # are model/dataset text, so only run this on trusted or sandboxed outputs.
        return sp.simplify(sp.nsimplify(g) - sp.nsimplify(w)) == 0
    except Exception:
        return g == w

def math_reward_func(prompts, completions, ground_truth, **kwargs):
    """GRPO reward: correctness, a mild length penalty, and an optional PRM bonus.

    reward = 1.0 if the parsed final answer matches `ground_truth` else 0.0,
             minus 0.0002 per completion character,
             plus 0.1 x mean PRM score over the first 5 non-empty completion lines.
    The PRM bonus is added only if a `get_prm_score` function exists in the notebook namespace.
    It is defined in a later cell, so on a fresh Run-All the bonus is skipped instead of raising
    NameError for every completion, which would collapse every reward to -0.1 and give GRPO no
    signal. PRM scoring uses the raw `question` column when the dataset provides one (the PRM is
    trained on raw questions), otherwise the prompt text. A PRM failure never erases the
    correctness reward. Completions whose answer cannot be parsed/compared get -0.1.
    """
    prm_scorer = globals().get("get_prm_score")
    questions = kwargs.get("question") or prompts
    rewards = []
    for question, comp, gt in zip(questions, completions, ground_truth):
        content = comp if isinstance(comp, str) else comp[0]["content"]
        try:
            correct = eq_correct(parse_final(content), gt)
        except Exception:
            rewards.append(-0.1)
            continue
        r = (1.0 if correct else 0.0) - 0.0002 * len(content)
        if prm_scorer is not None:
            lines = [ln.strip() for ln in content.splitlines() if ln.strip()][:5]
            try:
                r += 0.1 * sum(prm_scorer(question, ln, i + 1) for i, ln in enumerate(lines)) / max(1, len(lines))
            except Exception:
                pass  # PRM shaping is best-effort; keep the verifiable correctness signal
        rewards.append(float(r))
    return rewards

policy = AutoModelForCausalLM.from_pretrained(
    str(OUT / "sft-poT" / "final"),
    torch_dtype="auto",
    attn_implementation="flash_attention_2",
    device_map="auto"
)

grpo_cfg = GRPOConfig(
    output_dir=str(OUT / "grpo"),
    learning_rate=GRPO_LR,
    # GRPO requires the batch of completions to be a multiple of num_generations. One full group
    # per device step, with 8 // GRPO_GROUP accumulation steps, keeps the effective batch near 1 x 8.
    per_device_train_batch_size=GRPO_GROUP,
    gradient_accumulation_steps=max(1, 8 // GRPO_GROUP),
    bf16=True, logging_steps=10, save_steps=100,
    max_prompt_length=1024, max_completion_length=512,
    num_generations=GRPO_GROUP,
    beta=GRPO_KL,          # GRPOConfig's KL coefficient is `beta` (there is no `kl_coeff` field)
    max_steps=GRPO_STEPS,  # Trainer.train() does not accept max_steps; it belongs on the config
)

grpo_trainer = GRPOTrainer(
    model=policy,
    reward_funcs=math_reward_func,
    train_dataset=rl_ds.select(range(min(len(rl_ds), 6000))),
    processing_class=tok, args=grpo_cfg
)
grpo_trainer.train()
grpo_trainer.save_model(str(OUT / "grpo" / "final"))
print("✅ GRPO done ->", OUT / "grpo" / "final")


In [None]:
# %% [markdown]
# ## 🧭 PRM (Process Reward Model) — simple heuristic step labels from verified traces
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

def build_prm_dataset(verified_records):
    """Turn verified PoT programs into per-line PRM classification samples.

    Each non-blank line of a record's `cot_program` becomes one sample with text
    "<question>", a blank line, "# step <k>" (k is 1-based), then the line itself.
    Every sample is labeled 1 because verified traces are treated as all-correct, so the
    dataset contains no label-0 examples.
    NOTE(review): a 2-class head trained only on label 1 cannot learn to separate good from bad
    steps. TODO: add negative (label 0) steps.
    """
    samples = [
        {"text": f"{rec['question']}\n\n# step {step_no}\n{line}", "label": 1}
        for rec in verified_records
        for step_no, line in enumerate(
            (ln for ln in rec["cot_program"].splitlines() if ln.strip()), start=1
        )
    ]
    return Dataset.from_list(samples)

# Reload verified from disk (robust to fresh runtime)
with open(DATA / "teacher_verified.jsonl") as fh:
    verified = [json.loads(l) for l in fh if l.strip()]
prm_train_ds = build_prm_dataset(verified)
print("PRM samples:", len(prm_train_ds))

from transformers import AutoTokenizer
tok2 = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
if tok2.pad_token is None:
    tok2.pad_token = tok2.eos_token

def tokenize_prm(ex):
    """Tokenize one PRM sample, truncated to 1024 tokens (padding is done per batch by the collator)."""
    return tok2(ex["text"], truncation=True, max_length=1024)
prm_ds = prm_train_ds.map(tokenize_prm)

prm_model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_ID, num_labels=2, torch_dtype="auto", device_map="auto"
)
# Sequence-classification heads locate each row's last real token via config.pad_token_id; when it
# is unset, any batch with more than one example raises an error.
prm_model.config.pad_token_id = tok2.pad_token_id

prm_args = TrainingArguments(
    output_dir=str(OUT / "prm"),
    per_device_train_batch_size=4,
    learning_rate=1e-5,
    num_train_epochs=1,
    bf16=True, logging_steps=20, save_steps=200
)
prm_trainer = Trainer(
    model=prm_model, args=prm_args,
    train_dataset=prm_ds, processing_class=tok2
)
prm_trainer.train()
prm_trainer.save_model(str(OUT / "prm" / "final"))
print("✅ PRM done ->", OUT / "prm" / "final")


In [None]:
# %% [markdown]
# ## 🔁 StepCo-style Verify-Then-Revise helper
from transformers import AutoModelForCausalLM, AutoTokenizer

# The repair / self-evolution helpers below run on the GRPO policy checkpoint.
final_model_path = str(OUT / "grpo" / "final")
eval_model = AutoModelForCausalLM.from_pretrained(
    final_model_path,
    torch_dtype="auto",
    attn_implementation="flash_attention_2",
    device_map="auto",
)
# Same base tokenizer as the training cells; pad with EOS if the checkpoint defines no pad token.
eval_tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
if eval_tok.pad_token is None:
    eval_tok.pad_token = eval_tok.eos_token

def generate_reply(model, tok, prompt, temperature=0.0, max_new_tokens=512):
    """Generate one reply to `prompt` and return only the newly generated text.

    The prompt is wrapped in the same system prompt and format used for SFT training
    (chat template when USE_INSTRUCT_CHAT_TEMPLATE, else the plain tag format), so inference
    matches what the model was trained on. temperature <= 0 means greedy decoding; a sampling
    temperature is only passed when sampling.
    """
    system = "You are a concise math solver. First write minimal Python to compute the answer, then output 'Answer: <value>'."
    if USE_INSTRUCT_CHAT_TEMPLATE and hasattr(tok, "apply_chat_template"):
        messages = [
            {"role":"system","content":system},
            {"role":"user","content":prompt}
        ]
        full = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    else:
        full = f"<|system|>\n{system}\n<|user|>\n{prompt}\n<|assistant|>\n"
    inputs = tok([full], return_tensors="pt").to(model.device)
    gen_kwargs = {"max_new_tokens": max_new_tokens, "pad_token_id": tok.pad_token_id}
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature)
    else:
        gen_kwargs["do_sample"] = False
    out = model.generate(**inputs, **gen_kwargs)[0]
    return tok.decode(out[inputs["input_ids"].shape[1]:], skip_special_tokens=True)

def solve_with_pot(question, model=None, tok=None):
    """Greedily generate a PoT solution and self-check it by executing it.

    Args:
        question: problem text.
        model, tok: generation model/tokenizer. None means the *current* global eval_model /
            eval_tok, looked up at call time. A default of `model=eval_model` would freeze
            whichever model was loaded when this cell ran, even after a later cell rebinds it.

    Returns:
        (text, final_answer, ok), where ok is verify_record's verdict: whether the generated text
        executes and is self-consistent, not whether the answer is correct.
    """
    model = eval_model if model is None else model
    tok = eval_tok if tok is None else tok
    text = generate_reply(model, tok, question, temperature=0.0)
    # NOTE(review): this exec()s model-generated text in the restricted-builtins sandbox, which is
    # not a security boundary and has no timeout.
    rec = {"cot_program": text, "tests": [], "final_answer": parse_final(text)}
    ok, _ = verify_record(rec)
    return text, rec["final_answer"], ok

def stepco_repair(question, max_rounds=2, model=None, tok=None):
    """Solve `question`; if the solution fails its self-check, ask the model to revise it.

    Args:
        question: problem text.
        max_rounds: maximum number of revise attempts after the first solution.
        model, tok: None means the current global eval_model / eval_tok at call time. The same
            model is used for the first attempt and for every repair round.

    Returns:
        (text, final_answer, ok) for the first self-consistent attempt, or the last attempt with
        ok=False when none passes. "ok" comes from verify_record and never uses a gold label.
    """
    model = eval_model if model is None else model
    tok = eval_tok if tok is None else tok
    text, pred, ok = solve_with_pot(question, model=model, tok=tok)
    if ok:
        return text, pred, True
    for _ in range(max_rounds):
        repair_prompt = f"The following solution seems incorrect. Fix only the wrong steps and keep it concise.\n\nQuestion:\n{question}\n\nSolution:\n{text}"
        text = generate_reply(model, tok, repair_prompt, temperature=0.0)
        rec = {"cot_program": text, "tests": [], "final_answer": parse_final(text)}
        ok, _ = verify_record(rec)
        if ok:
            return text, rec["final_answer"], True
    return text, parse_final(text), False

print("Repair helper ready.")


In [None]:
# %% [markdown]
# ## 🔄 PRM-Guided Self-Evolution (A6): Best-of-N with PRM reranking for better training data (shallow MCTS not implemented here)

from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch.nn.functional as F

# Reload the trained PRM from its checkpoint (rebinding `prm_model`), with the base-model tokenizer.
prm_ckpt_dir = str(OUT / "prm" / "final")
prm_model = AutoModelForSequenceClassification.from_pretrained(
    prm_ckpt_dir,
    torch_dtype="auto",
    device_map="auto",
)
prm_tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)

def get_prm_score(question, code_line, step_num):
    """Return the PRM probability (class index 1) that `code_line` is a correct step `step_num` of `question`.

    The input text uses the same "<question>, blank line, # step <k>, line" layout as
    build_prm_dataset, truncated to 1024 tokens.
    """
    step_text = f"{question}\n\n# step {step_num}\n{code_line}"
    encoded = prm_tok(step_text, return_tensors="pt", truncation=True, max_length=1024)
    encoded = encoded.to(prm_model.device)
    with torch.no_grad():
        class_probs = F.softmax(prm_model(**encoded).logits, dim=-1)
        return float(class_probs[0][1])  # index 1 == the "correct" label

def best_of_n_with_prm(question, n=4, model=None, tok=None):
    """Sample `n` PoT solutions (temperature 0.3) and return the one the PRM rates highest.

    Each candidate's score is the mean PRM score over its first 10 non-empty, non-comment lines;
    a line whose scoring raises counts as 0.5, and a candidate with no such lines scores 0.0.
    model/tok default to the *current* global eval_model / eval_tok at call time.

    Returns:
        (best_text, best_score)

    Raises:
        ValueError: if n < 1.
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    model = eval_model if model is None else model
    tok = eval_tok if tok is None else tok
    candidates = []
    for _ in range(n):
        text = generate_reply(model, tok, question, temperature=0.3, max_new_tokens=512)

        # Calculate average PRM score across code steps
        lines = [ln.strip() for ln in text.splitlines() if ln.strip() and not ln.strip().startswith('#')]
        if not lines:
            candidates.append((text, 0.0))
            continue

        scores = []
        for i, line in enumerate(lines[:10]):  # limit to first 10 steps
            try:
                scores.append(get_prm_score(question, line, i+1))
            except Exception:
                scores.append(0.5)  # neutral score when the PRM cannot score this line
        candidates.append((text, sum(scores) / len(scores)))

    best_text, best_score = max(candidates, key=lambda x: x[1])
    return best_text, best_score

# Generate enhanced training data using PRM guidance
print("🔄 Generating PRM-guided enhanced training data...")
enhanced_records = []
n_errors = 0

# Use subset of original GSM8K for self-evolution
evolution_samples = gsm_train.select(range(min(800, len(gsm_train))))

for i, ex in enumerate(evolution_samples):
    try:
        # Generate best-of-N solution guided by PRM
        enhanced_text, prm_score = best_of_n_with_prm(ex["question"])
        rec = {
            "cot_program": enhanced_text,
            "tests": [],
            "final_answer": parse_final(enhanced_text)
        }
        # verify_record only shows the program runs and agrees with itself; it never sees the label,
        # so also require agreement with the GSM8K reference answer.
        runs_ok, _ = verify_record(rec)
        is_correct = runs_ok and eq_correct(rec["final_answer"], extract_numeric(ex["answer"]))

        # Keep only solutions that are correct AND rated highly by the PRM.
        if is_correct and prm_score > 0.6:  # PRM quality threshold
            enhanced_records.append({
                "question": ex["question"],
                "cot_program": enhanced_text,
                "final_answer": rec["final_answer"],
                "prm_score": prm_score
            })
    except Exception as e:
        n_errors += 1
        if n_errors <= 5:  # only print the first few errors
            print(f"Error on sample {i}: {e}")

    if (i+1) % 50 == 0:
        print(f"[{i+1}/{len(evolution_samples)}] enhanced_collected={len(enhanced_records)}")

print(f"✅ Generated {len(enhanced_records)} PRM-enhanced training examples ({n_errors} errors)")

# Save enhanced data for potential SFT fine-tuning (no later cell shown here reads this file).
enhanced_path = DATA / "prm_enhanced.jsonl"
with enhanced_path.open("w") as f:
    for rec in enhanced_records:
        print(json.dumps(rec), file=f)

print(f"Saved enhanced data to: {enhanced_path}")

In [None]:
# %% [markdown]
# ## ✂️ Long→Short preference data & DPO training
from trl import DPOTrainer, DPOConfig
from transformers import AutoModelForCausalLM

def make_short_from_program(code:str, final_ans:str):
    """Build the 'short' DPO response: the program minus comment and debug-print lines, plus the answer.

    Lines whose stripped form starts with '#' are dropped, as are lines containing "print(" unless
    they also contain "Answer:". Kept lines keep their original indentation, and the result ends
    with a blank line followed by "Answer: <final_ans>".
    """
    def _keep(line):
        stripped = line.strip()
        if stripped.startswith("#"):
            return False
        is_debug_print = "print(" in stripped and "Answer:" not in stripped
        return not is_debug_print

    body = "\n".join(line for line in code.splitlines() if _keep(line))
    return f"{body}\n\nAnswer: {final_ans}"

def _dpo_prompt(question):
    """Render `question` in the SFT prompt format, with the assistant turn left open."""
    system = "You are a concise math solver. First write minimal Python to compute the answer, then output 'Answer: <value>'."
    if USE_INSTRUCT_CHAT_TEMPLATE and hasattr(eval_tok, "apply_chat_template"):
        messages = [{"role": "system", "content": system}, {"role": "user", "content": question}]
        return eval_tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    return f"<|system|>\n{system}\n<|user|>\n{question}\n<|assistant|>\n"

# Prompts and responses mirror the SFT format (templated prompt; response starts with "# python"),
# so DPO shifts the model in the same context it is evaluated in.
pref_recs = []
n_identical = 0
for r in verified[:1500]:
    long = f"# python\n{r['cot_program']}\n\nAnswer: {r['final_answer']}"
    short = "# python\n" + make_short_from_program(r["cot_program"], r["final_answer"])
    if short == long:
        # Nothing to strip: with chosen == rejected the DPO loss has zero gradient, so skip the pair.
        n_identical += 1
        continue
    pref_recs.append({"prompt": _dpo_prompt(r["question"]), "chosen": short, "rejected": long})
print(f"Preference pairs: {len(pref_recs)} (skipped {n_identical} with identical chosen/rejected)")

pref_path = DATA / "short_vs_long.jsonl"
with pref_path.open("w") as f:
    for ex in pref_recs:
        print(json.dumps(ex), file=f)

pref_ds = load_dataset("json", data_files=str(pref_path))["train"]

dpo_model = AutoModelForCausalLM.from_pretrained(
    str(OUT / "grpo" / "final"),
    torch_dtype="auto", attn_implementation="flash_attention_2", device_map="auto"
)

dpo_cfg = DPOConfig(
    output_dir=str(OUT / "dpo"),
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    num_train_epochs=1,
    learning_rate=5e-6,
    bf16=True, logging_steps=20, save_steps=200,
    max_length=1024
)

# `processing_class` replaces DPOTrainer's removed `tokenizer=` kwarg.
dpo_trainer = DPOTrainer(
    model=dpo_model, ref_model=None, processing_class=eval_tok,
    args=dpo_cfg, train_dataset=pref_ds
)
dpo_trainer.train()
dpo_trainer.save_model(str(OUT / "dpo" / "final"))
print("✅ DPO done ->", OUT / "dpo" / "final")


In [None]:
# %% [markdown]
# ## 🧠 Length-aware RL (optional but included): stronger brevity penalty after DPO
from trl import GRPOTrainer, GRPOConfig

def lapo_like_reward(prompts, completions, ground_truth, **kwargs):
    """Length-aware GRPO reward used after DPO.

    reward = 1.0 if the parsed final answer matches `ground_truth` else 0.0, minus 0.0005 per
    completion character (a stronger brevity penalty than the first GRPO stage). Completions whose
    answer cannot be parsed/compared get -0.1. `prompts` and extra kwargs are unused.
    """
    def _score(comp, gt):
        text = comp if isinstance(comp, str) else comp[0]["content"]
        try:
            hit = eq_correct(parse_final(text), gt)
            return float((1.0 if hit else 0.0) - 0.0005 * len(text))
        except Exception:
            return -0.1

    return [_score(comp, gt) for comp, gt in zip(completions, ground_truth)]

lapo_model = AutoModelForCausalLM.from_pretrained(
    str(OUT / "dpo" / "final"),
    torch_dtype="auto", attn_implementation="flash_attention_2", device_map="auto"
)

LAPO_STEPS = 1000  # optimizer steps for the length-aware RL stage

lapo_cfg = GRPOConfig(
    output_dir=str(OUT / "grpo-lapo"),
    learning_rate=1e-6,
    # The batch of completions must be a multiple of num_generations (one full group per device step).
    per_device_train_batch_size=GRPO_GROUP,
    gradient_accumulation_steps=max(1, 8 // GRPO_GROUP),
    bf16=True,
    logging_steps=10, save_steps=100,
    max_prompt_length=1024, max_completion_length=512,
    num_generations=GRPO_GROUP,
    beta=GRPO_KL,          # KL coefficient field is `beta` (no `kl_coeff`)
    max_steps=LAPO_STEPS,  # set on the config; Trainer.train() does not accept max_steps
)

lapo_trainer = GRPOTrainer(
    model=lapo_model, reward_funcs=lapo_like_reward,
    train_dataset=rl_ds.select(range(min(len(rl_ds), 4000))),
    processing_class=eval_tok, args=lapo_cfg
)
lapo_trainer.train()
lapo_trainer.save_model(str(OUT / "grpo-lapo" / "final"))
print("✅ Length-aware RL done ->", OUT / "grpo-lapo" / "final")


In [None]:
# %% [markdown]
# ## ✅ Evaluation (GSM8K quick pass@1 with PoT execution)
from transformers import AutoModelForCausalLM

# Prefer the length-aware RL checkpoint; fall back to the DPO checkpoint if that stage did not run.
lapo_final_dir = OUT / "grpo-lapo" / "final"
final_eval_path = str(lapo_final_dir if lapo_final_dir.exists() else OUT / "dpo" / "final")
eval_model = AutoModelForCausalLM.from_pretrained(
    final_eval_path,
    torch_dtype="auto",
    attn_implementation="flash_attention_2",
    device_map="auto",
)

def evaluate_gsm8k(n=EVAL_N):
    """Evaluate the global `eval_model` on the first `n` GSM8K test problems (greedy decoding).

    Reports two accuracies:
      * acc: the first generation as-is;
      * acc_with_repair: after one StepCo repair round, triggered only when the generated program
        fails its own execution check (verify_record). It is never triggered by the gold label,
        so the number is not oracle-assisted.

    Returns:
        dict with keys n, acc, acc_with_repair (accuracies are 0.0 when no examples are evaluated).
    """
    subset = gsm_test.select(range(min(n, len(gsm_test))))
    total = len(subset)
    correct = correct_repair = 0
    for i, ex in enumerate(subset):
        gold = extract_numeric(ex["answer"])
        text = generate_reply(eval_model, eval_tok, ex["question"], temperature=0.0, max_new_tokens=512)
        pred = parse_final(text)
        correct += int(eq_correct(pred, gold))
        self_ok, _ = verify_record({"cot_program": text, "tests": [], "final_answer": pred})
        if not self_ok:
            _, pred, _ = stepco_repair(ex["question"], max_rounds=1)
        correct_repair += int(eq_correct(pred, gold))
        if (i+1) % 25 == 0:
            print(f"[{i+1}/{total}] acc={correct/(i+1):.3f} acc_with_repair={correct_repair/(i+1):.3f}")
    acc = correct / total if total else 0.0
    acc_repair = correct_repair / total if total else 0.0
    print(f"Final GSM8K@{total}: acc={acc:.3f} | acc_with_repair={acc_repair:.3f}")
    return {"n": total, "acc": acc, "acc_with_repair": acc_repair}

eval_results = evaluate_gsm8k(EVAL_N)
eval_results
