In [None]:
# -*- coding: utf-8 -*-
# ============================================================
# Training (A100-40GB Colab) following your Mermaid workflow,
# but reusing a *pretrained GRPO checkpoint* instead of retraining GRPO.
#
# Flow kept:
#   A0 Datasets → A1 Teacher Synthesis → A2 Verification → A3 SFT →
#   A4 (load your GRPO) → A5 PRM → A6 PRM-guided self-evolution →
#   A7 Length-aware finishing → Eval
# ============================================================

# ============ A. ENV CHECK (no forced restarts) ============
import importlib, sys, subprocess, platform, os, json, re, random, glob, shutil, time, traceback
from pathlib import Path

def ver(pkg):
    """Report the installed version of module `pkg`.

    Returns the module's ``__version__`` string, ``"unknown"`` when the module
    imports but exposes no ``__version__``, or ``"not-installed"`` when importing
    (or reading the attribute) raises.
    """
    try:
        module = importlib.import_module(pkg)
        found = getattr(module, "__version__", "unknown")
    except Exception:
        found = "not-installed"
    return found

print("Python  :", sys.version.split()[0])
print("Platform:", platform.platform())
print("CUDA nvcc:", subprocess.getoutput("nvcc --version | tail -n1"))

# Library versions, one "<label> <version>" line each, in a fixed order.
for label, package in (
    ("torch   :", "torch"),
    ("transformers:", "transformers"),
    ("trl        :", "trl"),
    ("accelerate  :", "accelerate"),
    ("peft        :", "peft"),
    ("datasets    :", "datasets"),
):
    print(label, ver(package))

# Tip: only install if you actually hit an import error below.

In [None]:
# ============ B. IMPORTS & GLOBAL CONFIG ============
import torch, sympy as sp
from datasets import load_dataset, Dataset, DatasetDict
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification,
    TrainingArguments, Trainer, GenerationConfig
)
from trl import SFTTrainer, SFTConfig, DPOTrainer, DPOConfig
# GRPO is not used for training here; we only *load* your trained checkpoint for generation.

# Reproducibility: seed Python's `random` and torch before anything stochastic runs.
SEED = 1337
random.seed(SEED)
torch.manual_seed(SEED)

# ===== Paths (edit to point at your Drive if you want) =====
BASE_DIR = Path("/content")
OUT = BASE_DIR / "outputs"   # notebook-local outputs
DATA = BASE_DIR / "data"     # notebook-local data
OUT.mkdir(parents=True, exist_ok=True)
DATA.mkdir(parents=True, exist_ok=True)

# To keep teacher/verified files on Google Drive instead, mount it and repoint the paths:
# from google.colab import drive; drive.mount("/content/drive")
# DATA = Path("/content/drive/MyDrive/ncu_green_ai/data"); DATA.mkdir(parents=True, exist_ok=True)
# OUT  = Path("/content/drive/MyDrive/ncu_green_ai/output"); OUT.mkdir(parents=True, exist_ok=True)

# ===== Models =====
# Base checkpoint for SFT (Qwen2.5-Math 1.5B); override with the MODEL_ID env var.
MODEL_ID = os.getenv("MODEL_ID", "Qwen/Qwen2.5-Math-1.5B")
USE_INSTRUCT_CHAT_TEMPLATE = False  # True renders SFT text with the tokenizer's chat template

# *** IMPORTANT ***: directory of your trained GRPO checkpoint (config.json, weights,
# tokenizer files). Override with the GRPO_MODEL_PATH env var.
GRPO_MODEL_PATH = os.getenv("GRPO_MODEL_PATH", "/content/outputs/default-GRPO/final")

# ===== Training knobs (SFT, then later DPO/length-aware RL) =====
MAX_SEQ_LEN = 4096  # intended SFT sequence-length budget
SFT_EPOCHS = 2
SFT_LR = 2e-5

EVAL_N = 200  # number of GSM8K test problems in the quick eval

In [None]:
# ============ C. UTILITIES (parsers, verification) ============
ANS_LINE = re.compile(r"(?i)^Answer:\s*(.+)\s*$")
NUMERIC_TAIL = re.compile(r"[-+]?\d+(?:\.\d+)?(?:/[0-9]+)?")
# A comma used as a thousands separator ("1,250", "1,000,000"); removed before number
# matching so NUMERIC_TAIL does not stop at the first comma and return "1".
_THOUSANDS_SEP = re.compile(r"(?<=\d),(?=\d{3}(?!\d))")


def _first_number(fragment: str):
    """Return the first number in `fragment` (thousands separators removed), or None."""
    match = NUMERIC_TAIL.search(_THOUSANDS_SEP.sub("", fragment))
    return match.group(0) if match else None


def parse_final(text: str):
    """Extract the final numeric answer from model output.

    Lookup order:
      1. the last ``\\boxed{...}`` (tolerating a missing closing brace);
      2. the last ``Answer: ...`` line;
      3. the last non-empty line.
    In each case the first number in the chosen fragment is returned as a string,
    with thousands separators removed (``"1,250"`` -> ``"1250"``). Fractions such as
    ``"3/4"`` are kept verbatim.

    Returns:
        The number as a string, or None when `text` is empty or contains no number
        in the chosen fragment.
    """
    if not text:
        return None

    # Prefer \boxed{...}; fall back to an unterminated \boxed{... on one line.
    boxed = re.findall(r"\\boxed\{([^}]+)\}", text) or re.findall(r"\\boxed\{([^}\n]+)", text)
    if boxed:
        return _first_number(boxed[-1].strip().strip("`'\""))

    # Next, the last "Answer: ..." occurrence.
    answers = re.findall(r"(?i)Answer:\s*([^\n]+)", text)
    if answers:
        return _first_number(answers[-1].strip().strip("`'\""))

    # Fallback: first number on the last non-empty line.
    lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
    return _first_number(lines[-1]) if lines else None

def eq_correct(got, want):
    """Decide whether a predicted answer matches the reference.

    First tries an exact symbolic comparison with sympy (so "1.5" equals "3/2").
    If sympy cannot parse either side, compares the whitespace-trimmed strings.
    """
    try:
        difference = sp.nsimplify(got) - sp.nsimplify(want)
        return sp.simplify(difference) == 0
    except Exception:
        return str(got).strip() == str(want).strip()

def extract_gold_gsm(answer_text: str):
    """Return the GSM8K gold answer that follows ``####`` as a plain number string.

    Thousands separators are removed ("#### 1,000" -> "1000") so the gold value
    compares cleanly with parsed predictions.

    Returns:
        The number as a string, or None when `answer_text` is empty/None or has no
        ``####`` answer.
    """
    if not answer_text:
        return None
    m = re.search(r"####\s*([\-+]?\d[\d,]*(?:\.\d+)?)", answer_text)
    return m.group(1).replace(",", "").strip() if m else None

In [None]:
# ============ D. DATASETS (A0) ============
print("Loading GSM8K‚Ä¶")
gsm = load_dataset("openai/gsm8k", "main")
gsm_train, gsm_test = gsm["train"], gsm["test"]
print("GSM8K train/test:", len(gsm_train), len(gsm_test))

# Optional extra source: MATH problems (hendrycks/MATH) could be parsed into this list.
# It stays empty unless you clone and parse that dataset yourself.
math_items = []

In [None]:
# ============ E. TEACHER SYNTHESIS (A1) + LOCAL VERIFICATION (A2) ============
# Previously verified Program-of-Thought (PoT) records are reused when present;
# synthesis is optional and not performed in this cell.
verified_path, raw_path = DATA / "teacher_verified.jsonl", DATA / "teacher_raw.jsonl"

def verify_record(rec):
    """Verify one teacher Program-of-Thought (PoT) record by executing its program.

    Checks, in order:
      * ``cot_program`` is present (a string);
      * ``final_answer`` contains something number-like;
      * ``cot_program`` executes without raising;
      * every snippet in ``tests`` executes without raising against the program's namespace;
      * if the program defines a callable ``solve``, ``solve()`` equals ``final_answer``
        (symbolically via sympy when both sides parse, otherwise as trimmed strings).

    Returns:
        ``(True, "ok")`` on success, else ``(False, <reason>)``.

    SECURITY NOTE(review): a restricted ``__builtins__`` dict is NOT a sandbox. Exec'd code can
    still reach arbitrary objects (e.g. through ``().__class__.__mro__``), and there is no
    timeout. Only run this on programs from a trusted source, or move it into a separate,
    resource-limited process.
    """
    allowed_builtins = {
        "abs": abs, "min": min, "max": max, "range": range, "len": len, "sum": sum, "print": print,
        # Pure conversion/iteration helpers that PoT programs routinely use; without them
        # valid programs (e.g. `return int(round(x))`) were rejected as "exec error".
        "int": int, "float": float, "str": str, "bool": bool, "round": round, "pow": pow,
        "divmod": divmod, "enumerate": enumerate, "zip": zip, "sorted": sorted,
        "reversed": reversed, "list": list, "tuple": tuple, "dict": dict, "set": set,
        "all": all, "any": any, "map": map, "filter": filter, "isinstance": isinstance,
    }
    safe_globals = {
        "__builtins__": allowed_builtins,
        "math": __import__("math"), "fractions": __import__("fractions"),
        "decimal": __import__("decimal"), "itertools": __import__("itertools"), "sp": sp,
    }

    code = rec.get("cot_program")
    if not isinstance(code, str):
        return False, "missing cot_program"
    tests = rec.get("tests") or []
    final = str(rec.get("final_answer", "")).strip()
    if NUMERIC_TAIL.search(final) is None:
        return False, "final not numeric-like"

    # Run the program in ONE namespace. With separate globals/locals dicts, top-level names
    # land in locals while functions resolve names via globals, so `solve()` could not see
    # the program's own constants or helper functions (NameError).
    namespace = dict(safe_globals)
    try:
        exec(code, namespace)
    except Exception as e:
        return False, f"exec error: {e}"

    for t in tests:
        try:
            exec(t, dict(namespace))  # copy so a test cannot leak new names into the next one
        except Exception as e:
            return False, f"test fail: {e}"

    solve = namespace.get("solve")
    if callable(solve):
        try:
            got = str(solve()).strip()
            try:
                if sp.simplify(sp.nsimplify(got) - sp.nsimplify(final)) != 0:
                    return False, f"mismatch: solve()={got} vs final={final}"
            except Exception:
                # sympy could not parse one side: fall back to exact string comparison.
                if got != final:
                    return False, f"mismatch: solve()={got} vs final={final}"
        except Exception as e:
            return False, f"solve() error: {e}"
    return True, "ok"

# If a non-empty verified file exists, only count its records here (cell F parses them).
if verified_path.exists() and verified_path.stat().st_size > 0:
    with verified_path.open() as fh:  # context manager: close the handle instead of leaking it
        n_verified = sum(1 for _ in fh)
    print(f"‚úÖ Using existing verified PoT: {verified_path} (records={n_verified})")
else:
    print("‚ÑπÔ∏è No verified PoT found; you can paste your earlier teacher-synthesis cell here if needed.")
    # Teacher synthesis (e.g. with your OpenAI teacher) would go here; it is omitted for
    # brevity, and the rest of the pipeline skips the stages that need verified data.

In [None]:
# ============ F. BUILD SFT DATASET (A3 input) ============
def to_chat_text(q, prog, ans, tok, use_template=False):
    """Render one SFT example: system instruction, question, and a PoT completion.

    The completion is ``# python\\n<prog>\\n\\nAnswer: <ans>``. When `use_template` is
    true and `tok` provides ``apply_chat_template``, the tokenizer's chat template
    renders the text (untokenized). Otherwise a plain ``<|system|>/<|user|>/<|assistant|>``
    layout is used.
    """
    system = "You are a concise math solver. First write minimal Python to compute the answer, then output 'Answer: <value>'."
    completion = f"# python\n{prog}\n\nAnswer: {ans}"
    if not (use_template and hasattr(tok, "apply_chat_template")):
        return f"<|system|>\n{system}\n<|user|>\n{q}\n<|assistant|>\n{completion}"
    return tok.apply_chat_template(
        [
            {"role": "system", "content": system},
            {"role": "user", "content": q},
            {"role": "assistant", "content": completion},
        ],
        tokenize=False,
    )

# Load verified PoT records. Lines that are malformed or missing required keys are counted
# and reported rather than silently dropped.
REQUIRED_KEYS = ("question", "cot_program", "final_answer")
verified = []
n_skipped = 0
if verified_path.exists():
    with verified_path.open() as fh:
        for line in fh:
            if not line.strip():
                continue  # tolerate blank lines
            try:
                js = json.loads(line)
            except json.JSONDecodeError:
                n_skipped += 1
                continue
            if isinstance(js, dict) and all(k in js for k in REQUIRED_KEYS):
                verified.append(js)
            else:
                n_skipped += 1
if n_skipped:
    print(f"Skipped {n_skipped} malformed/incomplete lines in {verified_path}")

# Render each verified record into the single "text" field consumed by SFT (cell G).
sft_path = DATA / "sft_train.jsonl"
if verified:
    tok_tmp = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
    if tok_tmp.pad_token is None: tok_tmp.pad_token = tok_tmp.eos_token
    with sft_path.open("w") as f:
        for r in verified:
            text = to_chat_text(r["question"], r["cot_program"], r["final_answer"], tok_tmp, USE_INSTRUCT_CHAT_TEMPLATE)
            print(json.dumps({"text": text, "question": r["question"], "final_answer": r["final_answer"]}), file=f)
    # json.dumps escapes newlines, so exactly one line was written per record.
    print("SFT records:", len(verified))
else:
    print("‚è≠Ô∏è No verified data ‚Üí SFT will be skipped unless you add teacher_verified.jsonl.")

In [None]:
# ============ G. SFT TRAINING (A3) ============
# Supervised fine-tuning of MODEL_ID on the verified PoT transcripts rendered by cell F.
# NOTE: later cells (H-M) start from GRPO_MODEL_PATH, not from this SFT checkpoint.
import dataclasses, gc, inspect

if sft_path.exists():
    # Train on the pre-rendered "text" column (question + PoT program + answer). The previous
    # formatting_func rebuilt each example from question/final_answer only, which dropped the
    # verified program entirely and trained the model to answer without computing.
    train_ds = load_dataset("json", data_files=str(sft_path))["train"].select_columns(["text"])
    sft_model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype="auto",
        attn_implementation="flash_attention_2",
        device_map="auto"
    )
    tok_sft = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
    if tok_sft.pad_token is None: tok_sft.pad_token = tok_sft.eos_token

    # TRL renamed SFTConfig.max_seq_length -> max_length; apply MAX_SEQ_LEN under whichever
    # name this TRL version defines.
    sft_config_fields = {f.name for f in dataclasses.fields(SFTConfig)}
    seq_len_kw = "max_length" if "max_length" in sft_config_fields else "max_seq_length"
    sft_cfg = SFTConfig(
        output_dir=str(OUT / "sft-poT"),
        per_device_train_batch_size=1,
        gradient_accumulation_steps=16,
        learning_rate=SFT_LR,
        num_train_epochs=SFT_EPOCHS,
        bf16=True,
        logging_steps=10,
        save_steps=200,
        gradient_checkpointing=True,
        dataset_text_field="text",
        **{seq_len_kw: MAX_SEQ_LEN},
    )

    # Newer TRL takes `processing_class=` (the `tokenizer=` kwarg was removed); older takes `tokenizer=`.
    tok_kw = "processing_class" if "processing_class" in inspect.signature(SFTTrainer.__init__).parameters else "tokenizer"
    sft_trainer = SFTTrainer(
        model=sft_model,
        args=sft_cfg,
        train_dataset=train_ds,
        **{tok_kw: tok_sft},
    )
    sft_trainer.train()
    sft_save = OUT / "sft-poT" / "final"
    sft_trainer.save_model(str(sft_save))
    tok_sft.save_pretrained(str(sft_save))
    print("‚úÖ SFT saved ->", sft_save)

    # Release the SFT model and optimizer state before later cells load the GRPO policy,
    # the 7B PRM and the DPO model on the same GPU.
    del sft_trainer, sft_model
    gc.collect()
    torch.cuda.empty_cache()
else:
    print("‚è≠Ô∏è Skipping SFT (no sft_train.jsonl).")


In [None]:
# ============ H. LOAD YOUR TRAINED GRPO (A4 replacement) ============
# Fail fast with an actionable message if the GRPO checkpoint directory is missing.
assert Path(GRPO_MODEL_PATH).exists(), (
    f"‚ùå GRPO_MODEL_PATH not found: {GRPO_MODEL_PATH}\n"
    "Set GRPO_MODEL_PATH to the directory of your trained GRPO checkpoint (contains config.json, tokenizer files, model weights)."
)

print(f"üîπ Using trained GRPO model from: {GRPO_MODEL_PATH}")
policy = AutoModelForCausalLM.from_pretrained(
    GRPO_MODEL_PATH,
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype="auto",
)

tok_policy = AutoTokenizer.from_pretrained(GRPO_MODEL_PATH, use_fast=True)
if tok_policy.pad_token is None:
    tok_policy.pad_token = tok_policy.eos_token  # reuse EOS as the pad token
tok_policy.padding_side = "left"  # left padding keeps prompts flush against generated tokens

# Deterministic decoding for evaluation: sampling off, sampling-only knobs cleared.
gen_cfg = GenerationConfig.from_model_config(policy.config)
gen_cfg.do_sample = False
gen_cfg.temperature, gen_cfg.top_p = None, None
gen_cfg.max_new_tokens = 192
policy.generation_config = gen_cfg

def make_prompt(question: str):
    """Build the terse evaluation prompt; it ends with "Answer: " so the model completes the value."""
    instructions = "Solve the problem briefly. Output ONLY one line:\nAnswer: <number>\n\n"
    body = f"Problem:\n{question}\n\n"
    return instructions + body + "Answer: "

@torch.no_grad()
def generate_once(model, tok, prompt, max_new_tokens=192):
    """Greedily decode a continuation of `prompt` and return its first line, stripped.

    Generation stops at EOS or at the token `tok` produces for "\\n", because callers
    only need one line. Returns "" when nothing was generated.
    """
    inputs = tok([prompt], return_tensors="pt").to(model.device)
    # Build the stop list defensively: encode("\n") may be empty, and eos_token_id may be None.
    newline_ids = tok.encode("\n", add_special_tokens=False)
    newline_id = newline_ids[-1] if newline_ids else None
    stop_ids = [t for t in (tok.eos_token_id, newline_id) if t is not None]
    out = model.generate(
        **inputs, do_sample=False, max_new_tokens=max_new_tokens,
        pad_token_id=tok.pad_token_id, eos_token_id=stop_ids or None,
    )[0]
    prompt_len = inputs["input_ids"].shape[1]
    text = tok.decode(out[prompt_len:], skip_special_tokens=True)
    lines = text.splitlines()
    return lines[0].strip() if lines else ""

In [None]:
# ============ I. PRM (A5) ============
# Load a process reward model (PRM) for cell J. Resolution order:
#   1. the official pretrained PRM (PRM_ID), wrapped to expose (batch, 2) `.logits`;
#   2. a previously fine-tuned local PRM at prm_dir;
#   3. train one from PRM800K, only if `_try_load_prm800k` / `_standardize_prm` were
#      defined by an earlier cell (this notebook does not define them);
#   4. otherwise prm_model = prm_tok = None (cell J then falls back to a neutral score).
import inspect
from types import SimpleNamespace

from transformers import AutoModel

PRM_ID = "Qwen/Qwen2.5-Math-PRM-7B"
prm_dir = OUT / "prm" / "final"
prm_model, prm_tok = None, None


class _StepPRMAsClassifier(torch.nn.Module):
    """Expose Qwen2.5-Math-PRM's per-token step rewards through a (batch, 2) `.logits` interface.

    Loading this checkpoint with AutoModelForSequenceClassification does not raise; it silently
    attaches a randomly initialised classification head, so its "scores" are noise. The real
    model (AutoModel + trust_remote_code) returns per-token logits, and its model card reads the
    step reward at `<extra_0>` separator tokens. This wrapper appends one separator to each
    input and returns the logits at that position. Downstream code that calls
    `prm_model(**inputs).logits` and reads `softmax(...)[0][1]` therefore keeps working.
    NOTE(review): assumes unpadded (or batch-size-1) inputs, which is how cell J calls it.
    """

    def __init__(self, prm, step_sep_id):
        super().__init__()
        self.prm = prm
        self.step_sep_id = step_sep_id
        self.config = prm.config  # cell J reads config._name_or_path to log the PRM source

    @property
    def device(self):
        return next(self.prm.parameters()).device

    def forward(self, input_ids, attention_mask=None, **_unused):
        sep = torch.full(
            (input_ids.size(0), 1), self.step_sep_id, dtype=input_ids.dtype, device=input_ids.device
        )
        input_ids = torch.cat([input_ids, sep], dim=1)
        if attention_mask is not None:
            attention_mask = torch.cat([attention_mask, torch.ones_like(sep)], dim=1)
        # Per the model card, outputs[0] holds per-token logits of shape (batch, seq, 2).
        token_logits = self.prm(input_ids=input_ids, attention_mask=attention_mask)[0]
        return SimpleNamespace(logits=token_logits[:, -1, :])


def _load_pretrained_prm(model_id):
    """Load the official PRM with its own head and wrap it; raises if anything is unsupported."""
    # SECURITY: trust_remote_code executes Python shipped with the checkpoint; use trusted repos only.
    tok = AutoTokenizer.from_pretrained(model_id, use_fast=True, trust_remote_code=True)
    sep_ids = tok.encode("<extra_0>", add_special_tokens=False)
    if len(sep_ids) != 1:
        raise ValueError(f"{model_id}: '<extra_0>' is not a single token; cannot read step rewards")
    base = AutoModel.from_pretrained(
        model_id, torch_dtype="auto", device_map="auto", trust_remote_code=True
    ).eval()
    return _StepPRMAsClassifier(base, sep_ids[0]).eval(), tok


def _train_prm_from_pairs(pairs):
    """Fine-tune a 2-way sequence classifier on {"text", "label"} dicts, save it to prm_dir, return (model, tok)."""
    prm_train_ds = Dataset.from_list(pairs)
    tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
    if tok.pad_token is None: tok.pad_token = tok.eos_token

    def _tok_fn(ex):
        out = tok(ex["text"], truncation=True, max_length=1024)
        out["labels"] = int(ex["label"])
        return out

    prm_ds_tok = prm_train_ds.map(_tok_fn, remove_columns=prm_train_ds.column_names)
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_ID, num_labels=2, torch_dtype="auto", device_map="auto"
    )
    # transformers' *ForSequenceClassification pools at the last non-pad token and refuses
    # batch sizes > 1 when config.pad_token_id is unset.
    model.config.pad_token_id = tok.pad_token_id
    prm_args = TrainingArguments(
        output_dir=str(OUT / "prm"),
        per_device_train_batch_size=4,
        gradient_accumulation_steps=2,
        learning_rate=1e-5, num_train_epochs=1,
        bf16=True, logging_steps=20, save_steps=200
    )
    # transformers renamed Trainer(tokenizer=...) to processing_class=...; use what this version accepts.
    tok_kw = "processing_class" if "processing_class" in inspect.signature(Trainer.__init__).parameters else "tokenizer"
    trainer = Trainer(model=model, args=prm_args, train_dataset=prm_ds_tok, **{tok_kw: tok})
    trainer.train()
    trainer.save_model(str(prm_dir))
    tok.save_pretrained(str(prm_dir))
    return model.eval(), tok


try:
    print(f"üîπ Attempting to load pretrained PRM: {PRM_ID}")
    prm_model, prm_tok = _load_pretrained_prm(PRM_ID)
    print("‚úÖ Loaded pretrained PRM:", PRM_ID)

except Exception as e:
    print("‚ö†Ô∏è Could not load pretrained PRM:", e)
    if prm_dir.exists():
        # fallback to local fine-tuned PRM
        print("‚úÖ Reusing existing PRM at:", prm_dir)
        prm_model = AutoModelForSequenceClassification.from_pretrained(
            str(prm_dir), device_map="auto", torch_dtype="auto"
        ).eval()
        prm_tok = AutoTokenizer.from_pretrained(str(prm_dir), use_fast=True)
    else:
        # Fallback: train on PRM800K, but only if the loader helpers exist (calling them
        # unconditionally raised NameError on a fresh kernel).
        load_prm800k = globals().get("_try_load_prm800k")
        standardize_prm = globals().get("_standardize_prm")
        if not (callable(load_prm800k) and callable(standardize_prm)):
            print("PRM800K helpers (_try_load_prm800k / _standardize_prm) are not defined. Skipping PRM.")
        else:
            raw = load_prm800k()
            if raw is None:
                print("‚è≠Ô∏è PRM800K not found. Skipping PRM.")
            else:
                std = [pair for pair in map(standardize_prm, raw) if pair is not None]
                if std:
                    prm_model, prm_tok = _train_prm_from_pairs(std)
                    print("‚úÖ PRM trained & saved ->", prm_dir)
                else:
                    print("‚è≠Ô∏è PRM800K had no usable (text,label). Skipping PRM.")


In [None]:
# ============ J. PRM-GUIDED SELF-EVOLUTION (A6) ============
import torch.nn.functional as F

# Report which PRM source will score candidates, based on the loaded model's origin path.
if prm_model is None:
    print("‚ö†Ô∏è PRM not available ‚Üí using uniform fallback score (0.5).")
else:
    prm_origin = getattr(prm_model.config, "_name_or_path", "")
    if "Qwen2.5-Math-PRM-7B" in prm_origin:
        print("‚úÖ Using pretrained PRM: Qwen/Qwen2.5-Math-PRM-7B")
    elif str(prm_dir) in prm_origin:
        print("‚úÖ Using locally fine-tuned PRM (loaded from /prm/final)")
    else:
        print("‚úÖ Using custom-trained PRM (fallback from PRM800K or verified traces)")

def get_prm_score(question, code_line, step_num):
    """Score one solution step with the PRM.

    Returns the softmax probability of class 1 (treated as "step is correct") for the text
    ``"<question>\\n\\n# step <step_num>\\n<code_line>"``, or a neutral 0.5 when no PRM is loaded.
    """
    if prm_model is not None:
        step_text = f"{question}\n\n# step {step_num}\n{code_line}"
        encoded = prm_tok(step_text, return_tensors="pt", truncation=True, max_length=1024)
        encoded = encoded.to(prm_model.device)
        with torch.no_grad():
            class_probs = F.softmax(prm_model(**encoded).logits, dim=-1)
            return float(class_probs[0][1])
    return 0.5


def best_of_n_with_prm(question, n=4, model=policy, tok=tok_policy, temperature=0.7, top_p=0.95):
    """Sample `n` one-line answers and return the ``(text, score)`` pair with the best PRM score.

    Candidates are drawn with sampling (`temperature`, `top_p`). Greedy decoding, which the
    previous version used via generate_once, returns the same text on every draw, which made
    best-of-n pointless. Each candidate is split into lines, up to 10 lines are scored with
    get_prm_score, and the mean is the candidate's score. An empty candidate scores 0.0, and
    a failed PRM call counts as a neutral 0.5.

    Returns:
        ``(text, score)`` of the highest-scoring candidate, or ``("", 0.0)`` when ``n < 1``.
    """
    if n < 1:
        return ("", 0.0)

    prompt = (
        "Write minimal Python to compute the result, then output only one line:\n"
        "Answer: <number>\n\n"
        f"Problem:\n{question}\n\n"
        "Answer: "
    )
    inputs = tok([prompt], return_tensors="pt").to(model.device)
    # Stop on EOS or on the first newline token, mirroring generate_once.
    newline_ids = tok.encode("\n", add_special_tokens=False)
    stop_ids = [t for t in (tok.eos_token_id, newline_ids[-1] if newline_ids else None) if t is not None]
    with torch.no_grad():
        outputs = model.generate(
            **inputs, do_sample=True, temperature=temperature, top_p=top_p,
            num_return_sequences=n, max_new_tokens=192,
            pad_token_id=tok.pad_token_id, eos_token_id=stop_ids or None,
        )
    prompt_len = inputs["input_ids"].shape[1]

    candidates = []
    for seq in outputs:
        decoded = tok.decode(seq[prompt_len:], skip_special_tokens=True)
        text = decoded.splitlines()[0].strip() if decoded else ""
        lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
        if not lines:
            candidates.append((text, 0.0))
            continue
        scores = []
        for i, line in enumerate(lines[:10]):
            try:
                scores.append(get_prm_score(question, line, i + 1))
            except Exception:
                scores.append(0.5)  # best-effort: a failed PRM call counts as neutral
        candidates.append((text, sum(scores) / len(scores)))
    return max(candidates, key=lambda cand: cand[1])

print("üîÑ Generating PRM-guided enhanced data‚Ä¶")
enhanced_records = []
evolution_samples = gsm_train.select(range(min(800, len(gsm_train))))
for i, ex in enumerate(evolution_samples):
    try:
        text, s = best_of_n_with_prm(ex["question"], n=4)
        pred = parse_final(text)
        gold = extract_gold_gsm(ex["answer"])
        if pred and gold and eq_correct(pred, gold) and s > 0.6:
            enhanced_records.append({
                "question": ex["question"],
                "cot_program": text,
                "final_answer": pred,
                "prm_score": s
            })
        if (i+1) % 50 == 0:
            print(f"[{i+1}/{len(evolution_samples)}] enhanced={len(enhanced_records)}")
    except Exception as e:
        if i < 3: print("Example error:", e)

enhanced_path = DATA / "prm_enhanced.jsonl"
with enhanced_path.open("w") as f:
    for rec in enhanced_records:
        print(json.dumps(rec), file=f)
print("‚úÖ Saved enhanced data:", enhanced_path, "count:", len(enhanced_records))

In [None]:
# ============ K. DPO (A7 part 1: long→short preferences) ============
# Build (prompt, chosen=short, rejected=long) pairs from the verified PoT records (if present)
# and run one DPO epoch starting from the GRPO policy.
import inspect


def make_short(code: str, final_ans: str) -> str:
    """Drop comment lines and non-answer print() calls from a PoT program, then append the answer."""
    kept = []
    for ln in code.splitlines():
        s = ln.strip()
        if s.startswith("#"):      continue
        if "print(" in s and "Answer:" not in s: continue
        kept.append(ln)
    # Join outside the f-string: a backslash inside an f-string expression is a
    # SyntaxError before Python 3.12, which broke this whole cell.
    body = "\n".join(kept)
    return f"{body}\n\nAnswer: {final_ans}"


pref_recs = []
n_no_contrast = 0
if verified:
    for r in verified[:1500]:
        q = r["question"]
        long = f"{r['cot_program']}\n\nAnswer: {r['final_answer']}"
        short = make_short(r["cot_program"], r["final_answer"])
        # Nothing was removed (only whitespace differs): the pair carries no preference signal.
        if short.split() == long.split():
            n_no_contrast += 1
            continue
        pref_recs.append({"prompt": q, "chosen": short, "rejected": long})
    if n_no_contrast:
        print(f"Skipped {n_no_contrast} records whose short form equals the long form.")

pref_path = DATA / "short_vs_long.jsonl"
if pref_recs:
    with pref_path.open("w") as f:
        for ex in pref_recs:
            print(json.dumps(ex), file=f)
    pref_ds = load_dataset("json", data_files=str(pref_path))["train"]

    dpo_model = AutoModelForCausalLM.from_pretrained(
        GRPO_MODEL_PATH,  # start DPO from your GRPO policy
        torch_dtype="auto", attn_implementation="flash_attention_2", device_map="auto"
    )
    dpo_cfg = DPOConfig(
        output_dir=str(OUT / "dpo"),
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        num_train_epochs=1,
        learning_rate=5e-6,
        bf16=True, logging_steps=20, save_steps=200,
        max_length=1024
    )
    # Newer TRL takes `processing_class=` (the `tokenizer=` kwarg was removed); older takes `tokenizer=`.
    tok_kw = "processing_class" if "processing_class" in inspect.signature(DPOTrainer.__init__).parameters else "tokenizer"
    dpo_trainer = DPOTrainer(
        model=dpo_model, ref_model=None,
        args=dpo_cfg, train_dataset=pref_ds, **{tok_kw: tok_policy}
    )
    dpo_trainer.train()
    dpo_save = OUT / "dpo" / "final"
    dpo_trainer.save_model(str(dpo_save))
    tok_policy.save_pretrained(str(dpo_save))
    print("‚úÖ DPO saved ->", dpo_save)
else:
    if verified:
        print("No usable short-vs-long pairs -> skipping DPO stage.")
    else:
        print("‚è≠Ô∏è No verified data ‚Üí skipping DPO stage.")
    dpo_save = None

In [None]:
# ============ L. Length-aware RL "finishing" (A7 part 2; optional) ============
# GRPO is deliberately not re-run here to keep the notebook light on Colab. True RL
# finishing can be plugged back in later, starting from dpo_save or GRPO_MODEL_PATH.
# Evaluate the DPO checkpoint when it was produced and saved; otherwise the GRPO policy.
if dpo_save is not None and Path(dpo_save).exists():
    final_base_for_eval = str(dpo_save)
else:
    final_base_for_eval = GRPO_MODEL_PATH
print("üîπ Final base for eval:", final_base_for_eval)

In [None]:
# ============ M. Evaluation (GSM8K quick pass@1) ============
# Load the checkpoint chosen in cell L, plus its tokenizer, for evaluation.
eval_model = AutoModelForCausalLM.from_pretrained(
    final_base_for_eval,
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype="auto",
)
eval_tok = AutoTokenizer.from_pretrained(final_base_for_eval, use_fast=True)
if eval_tok.pad_token is None:
    eval_tok.pad_token = eval_tok.eos_token  # reuse EOS as the pad token

@torch.no_grad()
def evaluate_gsm8k(n=EVAL_N):
    """Greedy pass@1 accuracy of `eval_model` on the first `n` GSM8K test problems.

    Each problem is prompted with make_prompt, decoded with generate_once, and parsed with
    parse_final, then compared to the gold answer with eq_correct. Running accuracy is
    printed every 20 problems.

    Returns:
        Final accuracy as a float, or NaN when no problems were evaluated (n <= 0).
    """
    subset = gsm_test.select(range(min(n, len(gsm_test))))
    total = len(subset)
    if total == 0:
        # Previously this divided by zero.
        print("Final GSM8K@0: no problems evaluated")
        return float("nan")
    correct = 0
    for i, ex in enumerate(subset):
        prompt = make_prompt(ex["question"])
        text = generate_once(eval_model, eval_tok, prompt, max_new_tokens=192)
        pred = parse_final(text)
        gold = extract_gold_gsm(ex["answer"])
        good = (pred is not None and gold is not None and eq_correct(pred, gold))
        correct += int(good)
        if (i+1) % 20 == 0:
            print(f"[{i+1}/{total}] acc={correct/(i+1):.3f}")
    acc = correct / total
    print(f"Final GSM8K@{total}: acc={acc:.3f}")
    return acc

# Run the quick GSM8K pass@1 evaluation on the first EVAL_N test problems.
print("\n‚ñ∂Ô∏è Running evaluation‚Ä¶")
evaluate_gsm8k(n=EVAL_N)