## Testing Optimizer - trial 1 Run 


In [6]:
import os, math, random, re, time
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional
import random
import numpy as np
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from argparse import Namespace
import matplotlib.pyplot as plt

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    GenerationConfig,
    set_seed,
)

# LORA for simple stable run
try:
    from peft import LoraConfig, get_peft_model, TaskType
    PEFT_AVAILABLE = True
except Exception:
    PEFT_AVAILABLE = False

## Configuration Setup 

In [5]:

# Pick the compute device in priority order: CUDA, then Apple MPS, then CPU.
# (An `A or B` chain cannot express this: the first operand is a non-empty
# string and therefore always truthy, so the MPS branch would never run.)
if torch.cuda.is_available():
    _device = "cuda"
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    _device = "mps"
else:
    _device = "cpu"

config = Namespace(
    # model / LoRA adapter
    model_name=os.environ.get("MODEL_NAME", "gpt2"),
    use_lora=True,
    lora_r=8,
    lora_alpha=16,
    lora_dropout=0.05,

    # hardware: half precision only on CUDA, full precision elsewhere
    device=_device,
    dtype=torch.float16 if _device == "cuda" else torch.float32,

    # rollout / sampling settings (G = completions sampled per prompt)
    G=4,
    prompts_per_batch=8,
    updates=80,
    gen_max_new_tokens=128,
    temperature=0.9,
    top_p=0.95,

    # decoupled clipping (separate lower / upper bounds)
    eps_low=0.2,
    eps_high=0.28,

    # small epsilon for advantage computation
    # NOTE(review): presumably guards a division when normalising advantages - confirm at the use site
    adv_eps=1e-8,

    # length shaping: tokens beyond target_tokens are penalised at
    # len_penalty_alpha per token in the reward
    target_tokens=90,
    len_penalty_alpha=0.015,
    overlong_mask=True,

    # dynamic sampling
    dynamic_sampling=True,
    max_prompt_tries=200,

    # optimization
    lr=2e-5,
    weight_decay=0.0,
    grad_clip=1.0,
    microbatch=4,

    # reproducibility
    seed=42,
)

# Short alias: the reward and generation helpers in this file read their
# settings through the name `cfg`.
cfg = config

## Synthetic Dataset 

In [7]:
def _gcd(a: int, b: int) -> int:
    """Return the non-negative greatest common divisor of ``a`` and ``b``.

    Uses the iterative Euclidean algorithm; ``_gcd(0, 0)`` is ``0``.
    """
    x, y = a, b
    while y != 0:
        remainder = x % y
        x, y = y, remainder
    # Python's modulo can leave a negative value here; report the magnitude.
    return -x if x < 0 else x

def _simplify(n: int, d: int) -> tuple[int, int]:
    """Reduce the fraction ``n/d`` to lowest terms.

    The result is canonical: the denominator is always positive, so
    equal fractions always produce equal tuples (e.g. ``1/-2`` and
    ``-1/2`` both give ``(-1, 2)``).

    Raises:
        ZeroDivisionError: if ``d`` is zero (mirrors ``fractions.Fraction``).
    """
    if d == 0:
        raise ZeroDivisionError("fraction denominator must be non-zero")
    g = math.gcd(n, d)
    # Dividing by a negative gcd moves the sign onto the numerator.
    if d < 0:
        g = -g
    return n // g, d // g

def _add_frac(a: int, b: int, c: int, d: int) -> tuple[int, int]:
    """Return ``a/b + c/d`` as a simplified ``(numerator, denominator)`` pair."""
    # Cross-multiply onto the common denominator b*d, then reduce.
    return _simplify(a * d + c * b, b * d)

def _build_prompt(a: int, b: int, c: int, d: int) -> str:
    """Build the tutoring prompt for the sum ``a/b + c/d``.

    The prompt states a student's incorrect claim, obtained by adding the
    numerators and the denominators separately, and asks for a correction
    in a fixed "Mistake / Steps / Answer" layout.
    """
    # The classic novice slip: (a + c) / (b + d).
    claim = f"Student claims: {a}/{b} + {c}/{d} = {a + c}/{b + d}.\n"
    parts = [
        "You are an educational assistant.\n",
        claim,
        "Correct the mistake and provide the right answer.\n",
        "Use this format:\n",
        "Mistake: ...\n",
        "Steps:\n",
        "1) ...\n",
        "2) ...\n",
        "Answer: <fraction>\n",
    ]
    return "".join(parts)

def generate_synthetic_dataset(n: int = 800, seed: int = 0) -> list[dict]:
    """
    Build ``n`` fraction-addition correction examples from a private RNG.

    Each example is a dict:
      {
        "prompt": str,                 # text produced by _build_prompt
        "answer": "n/d",               # simplified ground-truth sum
        "rubric": [str, ...]           # keywords/signals for shaping rewards
      }

    Both fractions are proper (numerator < denominator) with denominators
    drawn from 2..9. The same ``seed`` always yields the same list.
    """
    rng = random.Random(seed)

    def _make_example() -> dict:
        # Draw order matters for reproducibility: both denominators first,
        # then the two numerators.
        left_den = rng.randint(2, 9)
        right_den = rng.randint(2, 9)
        left_num = rng.randint(1, left_den - 1)
        right_num = rng.randint(1, right_den - 1)

        sum_num, sum_den = _add_frac(left_num, left_den, right_num, right_den)
        return {
            "prompt": _build_prompt(left_num, left_den, right_num, right_den),
            "answer": f"{sum_num}/{sum_den}",
            "rubric": ["common denominator", "add numerators", "simplify"],
        }

    return [_make_example() for _ in range(n)]

data = generate_synthetic_dataset(n=800, seed=config.seed)  # use your Namespace config

# Shuffle with a dedicated, seeded RNG so the train/val split is the same on
# every run; the module-level `random.shuffle` depends on whatever state the
# global RNG happens to be in.
random.Random(config.seed).shuffle(data)

# First 650 examples train the policy, the remainder is held out.
TRAIN_SIZE = 650
train = data[:TRAIN_SIZE]
val = data[TRAIN_SIZE:]

print(f"train={len(train)}  val={len(val)}")
print(train[0]["prompt"][:220], "\nGT:", train[0]["answer"])

train=650  val=150
You are an educational assistant.
Student claims: 2/4 + 3/6 = 5/10.
Correct the mistake and provide the right answer.
Use this format:
Mistake: ...
Steps:
1) ...
2) ...
Answer: <fraction>
 
GT: 1/1


In [8]:

# Matches "Answer: <digits>/<digits>" (case-insensitive, optional spaces
# around the colon and the slash); group 1 = numerator, group 2 = denominator.
ANSWER_RE = re.compile(r"Answer\s*:\s*([0-9]+)\s*/\s*([0-9]+)", re.IGNORECASE)

def parse_answer_frac(text: str) -> Optional[Tuple[int,int]]:
    """Extract the first "Answer: n/d" fraction from ``text``.

    Returns:
        The fraction reduced to lowest terms as ``(numerator, denominator)``,
        or ``None`` when no answer line is found or the denominator is zero.
    """
    m = ANSWER_RE.search(text)
    if m is None:
        return None
    n, d = int(m.group(1)), int(m.group(2))
    if d == 0:
        return None
    # d > 0 here, so the gcd is at least 1 and the division is exact.
    g = math.gcd(n, d)
    return n // g, d // g

def frac_equal(ans_str: str, pred: Tuple[int,int]) -> bool:
    """Return True when the ground-truth string "n/d" equals ``pred``.

    ``ans_str`` is reduced to lowest terms before the comparison; ``pred`` is
    compared as given, so it is expected to be reduced already (as returned
    by ``parse_answer_frac``).

    Raises:
        ValueError: if ``ans_str`` is not of the form "<int>/<int>".
    """
    n_gt, d_gt = map(int, ans_str.split("/"))
    # gcd(0, 0) is 0; fall back to 1 so a degenerate "0/0" compares unequal
    # instead of raising.
    g = math.gcd(n_gt, d_gt) or 1
    return (n_gt // g, d_gt // g) == pred

def rubric_score(text: str, rubric: List[str]) -> float:
    """Score how many teaching steps ``text`` mentions, in [0.0, 1.0].

    Three keyword groups are checked case-insensitively by substring; each
    group that matches adds its weight. The ``rubric`` argument is accepted
    for interface compatibility but is not consulted: the keyword groups are
    fixed below.
    """
    lowered = text.lower()
    keyword_groups = (
        (("denominator", "common"), 0.33),
        (("numerator", "add"), 0.33),
        (("simplif", "reduce"), 0.34),
    )
    total = 0.0
    for keywords, weight in keyword_groups:
        if any(word in lowered for word in keywords):
            total += weight
    return total if total < 1.0 else 1.0

def compute_reward(sample_text: str, gt_answer: str, rubric: List[str], gen_token_count: int) -> Dict[str, float]:
    """Score one completion against its ground-truth fraction.

    The reward is the sum of three terms:
      * correctness: +1.0 when the parsed "Answer: n/d" equals ``gt_answer``,
        otherwise -1.0 (including when no answer can be parsed);
      * rubric bonus: 0.4 * rubric_score(...), added whether or not the
        answer is correct;
      * length penalty: cfg.len_penalty_alpha for every generated token
        beyond cfg.target_tokens.

    Returns:
        {"reward": total, "correct": 1.0 or 0.0, "rubric": rubric score}
    """
    predicted = parse_answer_frac(sample_text)
    is_correct = predicted is not None and frac_equal(gt_answer, predicted)
    correct = 1.0 if is_correct else 0.0

    # Correctness term, then the teaching-structure bonus.
    rub = rubric_score(sample_text, rubric)
    total = (1.0 if is_correct else -1.0) + (0.4 * rub)

    # Only tokens past the target length are penalised.
    overflow = max(0, gen_token_count - cfg.target_tokens)
    total -= cfg.len_penalty_alpha * overflow

    return {"reward": total, "correct": correct, "rubric": rub}

In [9]:
# =========================
# 4) Utilities: generation + logprobs on generated tokens
# =========================
@torch.no_grad()
def generate_group(model, prompts: List[str], G: int) -> List[Dict]:
    """
    Sample ``G`` completions for every prompt.

    The model is switched to eval mode (and not switched back here). Each
    completion is generated with its own ``model.generate`` call, so no
    padding is involved and ``gen_len`` is the true number of new tokens.

    Returns a flat list (prompt-major order) with one dict per completion:
      {
        "prompt":     str, the prompt text,
        "completion": str, decoded generated tokens (special tokens skipped),
        "full_ids":   tensor [1, seq], prompt + generated token ids,
        "prompt_len": int, number of prompt tokens,
        "gen_len":    int, number of generated tokens,
        "truncated":  bool, True when generation hit cfg.gen_max_new_tokens
                      without finishing on the EOS token,
      }
    """
    model.eval()
    gen_cfg = GenerationConfig(
        max_new_tokens=cfg.gen_max_new_tokens,
        do_sample=True,
        temperature=cfg.temperature,
        top_p=cfg.top_p,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    eos_id = tokenizer.eos_token_id

    items = []
    for p in prompts:
        enc = tokenizer(p, return_tensors="pt", padding=False).to(cfg.device)
        prompt_len = enc["input_ids"].shape[1]

        # Generate G completions (separately for simplicity)
        for _ in range(G):
            out = model.generate(**enc, generation_config=gen_cfg)
            full_ids = out[0]  # [seq]
            gen_ids = full_ids[prompt_len:]
            gen_len = int(gen_ids.shape[0])

            # A completion that reaches the token budget but ends on EOS
            # finished naturally; only flag it when it was actually cut off.
            ended_with_eos = (
                gen_len > 0 and eos_id is not None and int(gen_ids[-1]) == eos_id
            )
            truncated = gen_len >= cfg.gen_max_new_tokens and not ended_with_eos

            completion_text = tokenizer.decode(gen_ids, skip_special_tokens=True)
            items.append({
                "prompt": p,
                "completion": completion_text,
                "full_ids": full_ids.unsqueeze(0),  # [1, seq]
                "prompt_len": prompt_len,
                "gen_len": gen_len,
                "truncated": bool(truncated),
            })
    return items

def logprobs_on_generated_tokens(model, full_ids: torch.Tensor, prompt_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Per-token log-probabilities of the generated part of one sequence.

    full_ids:   [1, seq] prompt + generated token ids
    prompt_len: number of prompt tokens at the start of full_ids

    Returns:
      token_logprobs: [T] float32, log p(token_t | <t) for generated tokens only
      token_mask:     [T] float32, all ones (one entry per generated token)

    The log-softmax is taken in float32 and only over the generated
    positions, so a half-precision model neither loses precision in the
    log-probabilities nor materialises a [seq, vocab] tensor for the prompt.
    Gradients flow through ``token_logprobs`` unless the caller disables them.
    """
    logits = model(full_ids).logits  # [1, seq, V]

    # The logit at position i predicts token i+1. Generated tokens start at
    # index prompt_len, so their predictions come from positions
    # prompt_len-1 .. seq-2 (clamped at 0: the very first token has no
    # prediction).
    start = max(0, prompt_len - 1)
    gen_logits = logits[:, start:-1, :].float()  # [1, T, V]
    gen_targets = full_ids[:, start + 1:]        # [1, T]

    logp = F.log_softmax(gen_logits, dim=-1)
    token_logprobs = torch.gather(logp, dim=-1, index=gen_targets.unsqueeze(-1)).squeeze(-1)[0]  # [T]
    token_mask = torch.ones_like(token_logprobs, dtype=torch.float32)

    return token_logprobs, token_mask