# RL w/ Verifiable Rewards Experiments  

Running everything locally, so bear with me: sticking to small models and toy-ish tasks to start with.

#### imports

In [1]:
import os, re, random, math, json
import torch
from dataclasses import dataclass
from typing import List, Tuple
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from tqdm.auto import tqdm

from typing import List, Dict, Tuple
from datasets import Dataset
from transformers import AutoTokenizer, BitsAndBytesConfig
from trl import GRPOConfig, GRPOTrainer
from peft import LoraConfig

### measure baseline accuracy (no RL)  

In [2]:


# ----- Baseline evaluation config -----
MODEL_ID = os.environ.get("MODEL_ID", "Qwen/Qwen3-0.6B")  # swap to ...-Instruct if you have it
# Use the GPU when one is visible, otherwise fall back to CPU instead of
# failing on a hard-coded "cuda" device string.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float32
LOAD_IN_4BIT = True  # flip to False if bitsandbytes is problematic under WSL2
N_EVAL = 100  # TODO: increase later

# System prompt: ask for an integer answer in plain ASCII, with the final
# answer on its own last line, so that taking the last integer in the reply
# (see parse_answer) recovers it. The text is part of every prompt; changing
# it changes both the baseline numbers and the RL prompts.
SYSTEM_PROMPT = (
    "You are a calculator. Solve the user's math problem and reply with an integer. "
    "ASCII characters only, no markdown. "
    "Your final answer should be on a new line at the end of your response. "
)

def build_messages(problem: str):
    """Wrap a problem statement in the chat-message list sent to the model.

    Returns a two-element list: the shared system prompt (``SYSTEM_PROMPT``)
    followed by a user turn carrying ``problem`` verbatim.
    """
    system_turn = dict(role="system", content=SYSTEM_PROMPT)
    user_turn = dict(role="user", content=problem)
    return [system_turn, user_turn]

# Extract the last integer-like token (handles negatives)
NUM_RE = re.compile(r"[-+]?\d+")

def parse_answer(text: str) -> int | None:
    matches = NUM_RE.findall(text.strip())
    if not matches:
        return None
    return int(matches[-1])

def gen_synthetic_math(
    n: int = 500,
    seed: int = 0,
    max_operand: int = 99999,
    max_factor: int = 50,
) -> List[Tuple[str, int]]:
    """Generate ``n`` seeded integer arithmetic problems with their answers.

    Each problem is rendered as ``"<a> <op> <b> = ?"`` with ``op`` drawn
    uniformly from ``+``, ``-`` and ``*``. Addition and subtraction use
    operands in ``[-max_operand, max_operand]``; multiplication uses factors
    in ``[-max_factor, max_factor]`` to keep products modest.

    With the default ranges the output for a given ``(n, seed)`` is identical
    to what this function produced before the range parameters were added.

    Args:
        n: Number of problems to generate.
        seed: Seed for the private ``random.Random`` instance (global RNG
            state is not touched).
        max_operand: Inclusive magnitude bound for ``+`` / ``-`` operands.
        max_factor: Inclusive magnitude bound for ``*`` factors.

    Returns:
        A list of ``(problem_text, answer)`` tuples.

    Raises:
        ValueError: If ``max_operand`` or ``max_factor`` is negative.
    """
    if max_operand < 0 or max_factor < 0:
        raise ValueError("max_operand and max_factor must be non-negative.")

    rng = random.Random(seed)
    items: List[Tuple[str, int]] = []
    for _ in range(n):
        # The large operands are always drawn, even when the operator turns
        # out to be "*" and they are discarded: this keeps the RNG stream (and
        # therefore every dataset built from an existing seed) unchanged.
        a = rng.randint(-max_operand, max_operand)
        b = rng.randint(-max_operand, max_operand)
        op = rng.choice(["+", "-", "*"])
        if op == "*":
            a = rng.randint(-max_factor, max_factor)
            b = rng.randint(-max_factor, max_factor)
            ans = a * b
        elif op == "+":
            ans = a + b
        else:
            ans = a - b
        items.append((f"{a} {op} {b} = ?", ans))
    return items

def measure_baseline_accuracy():
    """Evaluate the un-tuned model on synthetic integer arithmetic and print a JSON report.

    Steps:
      1. Load ``MODEL_ID`` and its tokenizer (4-bit NF4 quantization when
         ``LOAD_IN_4BIT`` is true).
      2. Generate ``N_EVAL`` problems with ``gen_synthetic_math(seed=123)``.
      3. Greedily decode one completion per problem and score it by comparing
         ``parse_answer(completion)`` with the gold integer.
      4. Print model id, sample count, accuracy and the first ten samples as JSON.

    Returns:
        None. The report is printed, not returned.
    """
    kwargs = dict(
        dtype=DTYPE,
        device_map="auto",
        trust_remote_code=True,
    )
    if LOAD_IN_4BIT:
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float32,
        )
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
    tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    if tok.pad_token_id is None:
        tok.pad_token_id = tok.eos_token_id

    # Align model/generation config with tokenizer to avoid warnings
    bos_token_id = getattr(tok, 'bos_token_id', None)
    configs = [model.config]
    if hasattr(model, 'generation_config'):
        configs.append(model.generation_config)
    for cfg in configs:
        cfg.pad_token_id = tok.pad_token_id
        cfg.eos_token_id = tok.eos_token_id
        cfg.bos_token_id = bos_token_id

    # Best-effort dtype diagnostics for the input embeddings and output head.
    # This only prints; a failure here must never abort the evaluation, but it
    # is reported instead of being silently swallowed.
    try:
        emb_in = model.get_input_embeddings()
        emb_out = model.get_output_embeddings() or getattr(model, 'lm_head', None)
        print(
            'in_emb:', emb_in.weight.dtype if emb_in is not None else None,
            'out_emb:', emb_out.weight.dtype if emb_out is not None else None,
        )
    except Exception as exc:
        print(f"(embedding dtype diagnostics skipped: {exc!r})")

    torch.backends.cuda.matmul.allow_tf32 = True
    model.eval()

    data = gen_synthetic_math(n=N_EVAL, seed=123)
    correct = 0
    samples = []

    # Wrap the list (not the enumerate object) so tqdm knows the total and
    # shows real progress instead of "0it".
    for i, (problem, ans) in enumerate(tqdm(data, desc="baseline eval")):
        messages = build_messages(problem)
        text = tok.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
        )
        # With device_map="auto" the model decides its own placement, so send
        # the inputs to wherever the model actually lives.
        inputs = tok(text, return_tensors="pt").to(model.device)
        prompt_len = inputs["input_ids"].shape[1]
        with torch.no_grad():
            # Greedy decoding: do_sample=False already makes this deterministic,
            # so no temperature is passed (it is a sampling-only setting).
            out = model.generate(
                **inputs,
                max_new_tokens=999,
                do_sample=False,
                eos_token_id=tok.eos_token_id,
                pad_token_id=tok.pad_token_id,
            )
        # Drop the prompt tokens; keep only what the model generated.
        gen = tok.decode(out[0][prompt_len:], skip_special_tokens=True)
        pred = parse_answer(gen)
        is_ok = (pred == ans)
        correct += int(is_ok)

        if i < 10:  # keep a few examples
            samples.append({"problem": problem, "gold": ans, "raw": gen.strip(), "pred": pred, "ok": bool(is_ok)})

    # Guard the N_EVAL == 0 case instead of raising ZeroDivisionError.
    acc = correct / len(data) if data else 0.0
    print(json.dumps({"model": MODEL_ID, "n": len(data), "accuracy": acc, "samples": samples}, indent=2))



In [18]:
# Run the baseline evaluation; it prints a JSON report (accuracy plus a few samples).
measure_baseline_accuracy()

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


0it [00:00, ?it/s]

{
  "model": "Qwen/Qwen3-0.6B",
  "n": 100,
  "accuracy": 0.83,
  "samples": [
    {
      "problem": "-86273 + -29830 = ?",
      "gold": -116103,
      "raw": "-86273 + -29830 = -86273 - 29830 = -116103.",
      "pred": -116103,
      "ok": true
    },
    {
      "problem": "6756 + -30124 = ?",
      "gold": -23368,
      "raw": "6756 + (-30124) = 6756 - 30124 = -23368.",
      "pred": -23368,
      "ok": true
    },
    {
      "problem": "21 * -8 = ?",
      "gold": -168,
      "raw": "21 * -8 = -168.",
      "pred": -168,
      "ok": true
    },
    {
      "problem": "-10661 + -86394 = ?",
      "gold": -97055,
      "raw": "-10661 + -86394 = -97055",
      "pred": -97055,
      "ok": true
    },
    {
      "problem": "-8 * 39 = ?",
      "gold": -312,
      "raw": "-8 * 39 = -312",
      "pred": -312,
      "ok": true
    },
    {
      "problem": "-35731 + -57066 = ?",
      "gold": -92797,
      "raw": "-35731 + -57066 = -35731 - 57066 = -92797.",
      "pred": -92797,
     

### RL w/ Verifiable Rewards (GRPO)  

In [None]:
# train GRPO


# ----- Config -----
# NOTE(review): MODEL_ID, DTYPE and LOAD_IN_4BIT rebind the globals set in the
# baseline cell. measure_baseline_accuracy() reads them at call time, so
# re-running the baseline after this cell evaluates with bfloat16 and without
# 4-bit quantization.
MODEL_ID = os.environ.get("MODEL_ID", "Qwen/Qwen3-0.6B")  # use base + strict system prompt
DTYPE = torch.bfloat16
LOAD_IN_4BIT = False          # was being flaky with this enabled
SEED = 42

# Keep completions tiny; GRPO will sample multiple per prompt (num_generations)
MAX_PROMPT_TOK = 128       # passed to GRPOConfig as max_prompt_length
MAX_COMPLETION_TOK = 128   # passed to GRPOConfig as max_completion_length
NUM_GENERATIONS = 4          # group size G; 4–8 is typical for small models

# ----- Prompt rendering with chat template -----
def _pairs_to_rows(pairs: List[Tuple[str, int]]) -> List[Dict[str, int]]:
    return [{"problem": p, "gold": int(g)} for p, g in pairs]

def render_prompts(rows: List[Dict[str, int]]) -> Dataset:
    """Render problem rows into chat-templated prompt strings for GRPO.

    Each row is turned into the same system + user message pair used by the
    baseline evaluation (via ``build_messages``) and rendered with the
    tokenizer's chat template, generation prompt included, thinking disabled.

    Args:
        rows: Dicts with ``"problem"`` and ``"gold"`` keys. The keys ``"0"``
            and ``"1"`` are accepted as fallbacks for the same two fields.

    Returns:
        A ``Dataset`` with exactly two columns: ``"prompt"`` (the rendered
        string) and ``"gold"`` (the expected integer answer).
    """
    # Only string rendering happens here (tokenize=False), so no padding
    # options need to be configured on this tokenizer instance.
    tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

    def _render(row):
        problem = row.get("problem", row.get("0"))
        gold = row.get("gold", row.get("1"))
        # Render to a single prompt string that already includes the assistant tag.
        # build_messages keeps training prompts identical to the baseline prompts.
        prompt = tok.apply_chat_template(
            build_messages(problem), tokenize=False, add_generation_prompt=True, enable_thinking=False
        )
        return {"prompt": prompt, "gold": gold}

    ds = Dataset.from_list(rows)
    # Drop the input columns; the columns returned by _render are what remain.
    return ds.map(_render, remove_columns=list(ds.column_names))

# Verifiable reward: scored with the same parser used for the baseline evaluation
def reward_correct_integer(completions: List[str], gold: List[int], **kwargs) -> List[float]:
    """Binary reward per completion.

    A completion earns 1.0 when ``parse_answer`` of its text equals the
    matching entry of ``gold`` and 0.0 otherwise (including when nothing
    parses). Extra keyword arguments are accepted and ignored.
    """
    return [
        1.0 if parse_answer(completion) == target else 0.0
        for completion, target in zip(completions, gold)
    ]

def train_grpo_integer_math() -> None:
    """Fine-tune ``MODEL_ID`` with TRL's ``GRPOTrainer`` on synthetic integer arithmetic.

    Builds 4000 training and 400 evaluation prompts (seeds ``SEED`` and
    ``SEED + 1``), attaches a LoRA adapter, trains with
    ``reward_correct_integer`` as the only reward, and finally calls
    ``trainer.save_model()``. Returns nothing.
    """
    random.seed(SEED)

    # Train/eval problems come from different seeds, rendered to prompt strings.
    train_rows = _pairs_to_rows(gen_synthetic_math(n=4000, seed=SEED))
    eval_rows  = _pairs_to_rows(gen_synthetic_math(n=400,  seed=SEED + 1))
    train_ds = render_prompts(train_rows)
    eval_ds  = render_prompts(eval_rows)

    # LoRA for small, fast updates (QLoRA if LOAD_IN_4BIT=True)
    lora = LoraConfig(
        r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM",
        target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"]
    )

    # GRPO configuration: grouped sampling with a KL penalty weighted by `beta`.
    args = GRPOConfig(
        output_dir="qwen3-06b-grpo-math",
        seed=SEED,
        tf32=True,
        bf16=(DTYPE == torch.bfloat16),     # bf16 training only when the model dtype is bfloat16
        per_device_train_batch_size=16,     # effective batch = this * (world size)
        gradient_accumulation_steps=1,
        learning_rate=1e-5,
        logging_steps=10,
        save_steps=200,
        max_prompt_length=MAX_PROMPT_TOK,
        max_completion_length=MAX_COMPLETION_TOK,
        # steps_per_generation=None,
        num_generations=NUM_GENERATIONS,    # group size G; used for group-relative baseline
        temperature=1.0,                    # allow exploration
        top_p=0.9,
        # GRPO-specific knobs:
        beta=0.05,                          # KL strength
        epsilon=0.2,                        # clipping (GRPO's PPO-style clip)
        # NOTE(review): scale_rewards presumably selects how rewards are scaled by their
        # standard deviation (per group here), not the mean subtraction -- confirm in TRL docs.
        scale_rewards="group",
        # NOTE(review): whether "dapo" is the TRL default depends on the installed version;
        # the KL term looks to be governed by `beta`, not by this setting -- confirm.
        loss_type="dapo",
        # Optional speed/memory extras:
        # use_liger_loss=True,              # enable if you pip-install TRL[liger] from source
        # use_vllm=True,                    # if you want an external vLLM server for gen
        # vllm_mode="colocate",
        # Model init (quantization config is only added when LOAD_IN_4BIT is true).
        # NOTE(review): the 4-bit compute dtype is fixed to float32 while the model dtype
        # is DTYPE -- confirm that mismatch is intended before enabling LOAD_IN_4BIT.
        model_init_kwargs=dict(
            dtype=DTYPE,
            trust_remote_code=True,
            device_map="auto",
            **({
                "quantization_config": BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float32)
            } if LOAD_IN_4BIT else {})
        ),
    )
    # Workaround, first half: clear generation_batch_size before the trainer is
    # built; it is restored further down once the trainer exists. The order of
    # these two steps relative to GRPOTrainer(...) matters.
    args.generation_batch_size = None

    tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    if tok.pad_token_id is None:
        tok.pad_token_id = tok.eos_token_id
    tok.padding_side = "left"

    trainer = GRPOTrainer(
        model=MODEL_ID,
        reward_funcs=reward_correct_integer,     # custom verifiable reward
        args=args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        processing_class=tok,                    # tokenizer for policy/ref processing
        peft_config=lora
    )
    # Align model/generation config with tokenizer to avoid warnings
    model = trainer.model
    model.config.pad_token_id = tok.pad_token_id
    model.config.eos_token_id = tok.eos_token_id
    model.config.bos_token_id = getattr(tok, 'bos_token_id', None)
    if hasattr(model, 'generation_config'):
        model.generation_config.pad_token_id = tok.pad_token_id
        model.generation_config.eos_token_id = tok.eos_token_id
        model.generation_config.bos_token_id = getattr(tok, 'bos_token_id', None)
    # Workaround, second half: after TRL constructs the trainer (and internally
    # re-created args), restore a concrete generation_batch_size if it is None so
    # downstream dataloaders and samplers can derive batch/replication correctly.
    if trainer.args.generation_batch_size is None:
        trainer.args.generation_batch_size = (
            trainer.args.per_device_train_batch_size * trainer.args.world_size * trainer.args.steps_per_generation
        )

    trainer.train()
    trainer.save_model()  # PEFT adapter if LoRA; full if no PEFT



In [None]:
# Kick off GRPO training; the function trains and then saves via trainer.save_model().
train_grpo_integer_math()

Map:   0%|          | 0/4000 [00:00<?, ? examples/s]

Map:   0%|          | 0/400 [00:00<?, ? examples/s]



Step,Training Loss
10,-0.0121
20,-0.0022
