### Loading

In [None]:
from unsloth import FastLanguageModel
import torch
import os

# Total token budget shared by the model loader and the GRPO config
# (prompt + completion).  Previously 4096; raise it for longer reasoning traces.
max_seq_length = 6144
# LoRA rank: a larger rank gives more capacity but is slower.
lora_rank = 8

# Select vLLM's V0 engine.  NOTE(review): this is set after unsloth has been
# imported; presumably vLLM only reads it when the engine is created by
# from_pretrained(fast_inference=True) below — confirm if V1 still shows up.
os.environ["VLLM_USE_V1"] = "0"

In [None]:
# Load base weights + tokenizer through Unsloth, with vLLM-backed generation.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="<MODEL>",
    max_seq_length=max_seq_length,
    max_lora_rank=lora_rank,
    load_in_4bit=False,            # keep 16-bit weights (16-bit LoRA)
    fast_inference=True,           # enable vLLM fast inference
    gpu_memory_utilization=0.6,    # lower this if you run out of memory
    fix_tokenizer=False,           # use the Hugging Face tokenizer as-is
)


In [None]:
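# Optional: load an existing LoRA adapter.
# NOTE: `base_model` is not defined anywhere in this notebook — the loading
# cell above binds the model to `model`, so adjust the names before enabling.
# lora_adapter_path = "./YOUR_LORA_PATH"

# if lora_adapter_path is not None and os.path.exists(lora_adapter_path):
#     from peft import PeftModel
#     print(f"Loading LoRA adapter from: {lora_adapter_path}")
#     model = PeftModel.from_pretrained(base_model, lora_adapter_path)
# else:
#     print("No valid LoRA adapter found. Using base model.")
#     model = base_model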


In [None]:
# Wrap the model with LoRA adapters on the listed projection layers.
model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,                  # any int > 0; 8, 16, 32, 64, 128 are typical
    lora_alpha=lora_rank,         # alpha equal to the rank
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",   # drop these first on OOM
        "gate_proj", "up_proj", "down_proj",
    ],
    use_gradient_checkpointing="unsloth",  # Unsloth checkpointing for long contexts
    random_state=3407,
)

### Data Prep
<a name="Data"></a>

In [None]:
import os

from huggingface_hub import login

import wandb

# Credentials are taken from the environment (WANDB_API_KEY / HF_TOKEN) so that
# secrets do not have to be pasted into the notebook source.  The placeholder
# strings are only a fallback and must be replaced if the variables are unset.
wandb.login(key=os.environ.get("WANDB_API_KEY", "<WANDB_TOKEN>"))

login(token=os.environ.get("HF_TOKEN", "<HF_TOKEN>"))

In [None]:
from datasets import load_dataset

# Read the local training CSV and keep only its "train" split as a Dataset.
dataset = load_dataset(
    "csv",
    data_files="data/train_dataset.csv",
    split="train",
)

from ast import literal_eval

def parse_prompt(example):
    """
    Decode the string-serialised ``prompt`` column of one dataset row.

    The cell is expected to hold the ``repr`` of a chat message list
    (str -> list[dict]); ``literal_eval`` rebuilds it without executing code.

    Rows whose cell cannot be parsed (e.g. an empty cell read as ``None``,
    truncated text, a malformed literal) get an empty list, so the row is kept
    and the column type stays consistent.  Only the exceptions
    ``literal_eval`` actually raises are caught, so ``KeyboardInterrupt`` and
    genuine programming errors are no longer swallowed.

    Returns the (mutated) ``example`` dict, as ``Dataset.map`` expects.
    """
    raw = example["prompt"]
    try:
        example["prompt"] = literal_eval(raw)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        example["prompt"] = []
    return example

dataset = dataset.map(function=parse_prompt)  # turn each prompt string into a message list

In [None]:
first_prompt = dataset[0]["prompt"]; print(first_prompt)  # sanity-check the parsed prompt

In [None]:
import re
# ====================================================================== #
# 1.  CONSTANTS & REGEXES                                                #
# ====================================================================== #
# First "\boxed{...}" whose closing brace is on the same line; non-greedy, so
# the capture stops at the first '}' (nested braces get truncated).
BOXED_RE = re.compile(r"\\boxed\{(.*?)\}")
# "Option <digits>", case-insensitive; group(1) is the option number.
OPTION_RE = re.compile(r"Option\s+([0-9]+)", re.IGNORECASE)

# Reward for each outcome of the three grading stages.
# (Earlier tuning, same order: -2.0 / -3.0 / +3.0 / 0.0 / -3.0.)
FORMAT_PENALTY = -4.0   # Stage-A: structure check failed
PARSE_PENALTY = -5.0    # Stage-B: no answer could be extracted
CORRECT_BOX = 5.0       # Stage-C: exact match with the gold answer
PARTIAL = 2.0           # Stage-C: only the option number matches
WRONG = -3.0            # Stage-C: wrong answer

In [None]:
# =================================================================== #
# 2.  FORMAT CHECK  (Stage‑A)                                          #
# =================================================================== #
def format_score(text: str) -> float:
    """
    Score the structural contract of a completion (Stage-A gate).

    Two rules are scored independently and summed:

      1. the text contains exactly one </think> tag
         -> +0.5 if satisfied, -1.0 otherwise;
      2. the first BOXED_RE match (if any) starts after the first </think>
         -> +0.25 if so or if there is no box at all,
         -1.0 if the box comes first or if </think> is missing.

    Rule 2 is evaluated even when rule 1 fails.  The only positive result is
    0.75 (both rules satisfied); every other combination is <= 0, which the
    caller treats as a Stage-A failure.  The sum is rounded to 2 decimals.
    """
    tag_count_ok = text.count("</think>") == 1
    total = 0.5 if tag_count_ok else -1.0

    tag_pos = text.find("</think>")
    if tag_pos == -1:
        # No closing tag at all: ordering cannot be satisfied.
        total -= 1.0
    else:
        box = BOXED_RE.search(text)
        box_too_early = box is not None and box.start() < tag_pos
        total += -1.0 if box_too_early else 0.25

    return round(total, 2)

# =================================================================== #
# 3.  ANSWER EXTRACTION  (Stage‑B)   — AFTER </think> ONLY            #
# =================================================================== #
def _tail_after_think(text: str) -> str:
    """
    Return the substring that starts *immediately after* the one and only
    </think> tag.  (Stage-A guarantees that tag exists exactly once.)
    """
    end = text.index("</think>") + len("</think>")
    return text[end:]


def extract_boxed(text: str) -> str:
    """
    Find the first complete \\boxed{...} *after* </think>; return its inner
    text (whitespace-stripped), or '' if there is none.

    Braces are matched by nesting depth instead of with the non-greedy
    BOXED_RE.  Content that itself contains braces (for example
    \\boxed{\\frac{1}{2}} or \\boxed{\\text{Option 2}}) is therefore returned
    whole instead of being cut at the first '}'.  The content may also span
    several lines.  A \\boxed{ that is never closed is skipped, and the search
    continues with the next one.
    """
    tail = _tail_after_think(text)
    opener = "\\boxed{"

    start = tail.find(opener)
    while start != -1:
        body_start = start + len(opener)
        depth = 1
        for i in range(body_start, len(tail)):
            ch = tail[i]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    return tail[body_start:i].strip()
        # This opener was never closed; try the next occurrence.
        start = tail.find(opener, start + 1)
    return ""


def extract_unboxed(text: str) -> str:
    """
    Fallback parser used when no box was found.

    Searches only the text after </think> for the first OPTION_RE match
    ('Option <digits>', case-insensitive) and returns it normalised as
    'Option <digits>'.  Returns '' when there is no such label; text before
    the tag is never consulted.
    """
    match = OPTION_RE.search(_tail_after_think(text))
    if match is None:
        return ""
    return "Option " + match.group(1)

def same_option(gold: str, pred: str) -> bool:
    """
    Return True when *gold* and *pred* each contain an 'Option <n>' label
    (first OPTION_RE match, case-insensitive) with the same number <n>.

    Returns False when either string has no such label.  The result is
    always a real ``bool``.
    """
    g = OPTION_RE.search(gold)
    p = OPTION_RE.search(pred)
    # Explicit None checks: a bare `g and p and ...` chain would return None or
    # a Match object instead of a bool when a label is missing.
    return g is not None and p is not None and g.group(1) == p.group(1)

# =================================================================== #
# 4.  HIERARCHICAL REWARD FUNCTION  (with debug)                       #
# =================================================================== #
def hierarchical_reward(prompts, completions, answer, **kw):
    """
    GRPO reward: grade each completion through three sequential gates.

    Stage-A : structure check via format_score (score <= 0 -> FORMAT_PENALTY)
    Stage-B : first \\boxed{...} after </think>, else the first 'Option X'
              label after </think> (nothing found -> PARSE_PENALTY)
    Stage-C : exact match with the gold answer -> CORRECT_BOX,
              same option number only       -> PARTIAL,
              otherwise                     -> WRONG

    Args:
        prompts: chat prompts, one per completion; the last message's
            'content' is shown in the debug output.
        completions: one entry per sampled completion; comp[0]["content"]
            holds the generated text.
        answer: gold answers, one per completion (the dataset's "answer"
            column as forwarded by the trainer).
        **kw: other columns forwarded by the trainer; ignored.

    Returns:
        One float reward per completion, in the same order.

    Raises:
        ValueError: if prompts, completions and answer differ in length.

    Each completion is graded against its *own* answer.  A single call may
    contain completions for several different prompts, so grading every
    completion against answer[0] would use the wrong gold for all but the
    first prompt group.
    """
    if not (len(prompts) == len(completions) == len(answer)):
        raise ValueError(
            "prompts, completions and answer must be parallel lists; got "
            f"{len(prompts)}, {len(completions)}, {len(answer)}"
        )

    rewards = []
    for prompt, comp, gold_raw in zip(prompts, completions, answer):
        gold = str(gold_raw).strip()
        resp = comp[0]["content"]
        boxed = unboxed = parsed = ""

        # -------- A. structure gate --------------------------------------
        if format_score(resp) <= 0:
            reward = FORMAT_PENALTY
        else:
            # -------- B. parse gate --------------------------------------
            boxed = extract_boxed(resp)
            unboxed = "" if boxed else extract_unboxed(resp)
            parsed = boxed or unboxed

            # -------- C. correctness -------------------------------------
            if not parsed:
                reward = PARSE_PENALTY
            elif parsed == gold:
                reward = CORRECT_BOX        # exact match (boxed or fallback)
            elif same_option(gold, parsed):
                reward = PARTIAL            # only option number matches
            else:
                reward = WRONG

        rewards.append(reward)

        # -------- debug output ------------------------------------------
        print("-" * 40)
        print(f"Question:\n{prompt[-1]['content']}\n")
        print(f"Target   : {gold}")
        print(f"Response :\n{resp}\n")
        print(f"Extracted boxed   : {boxed or '—'}")
        print(f"Extracted unboxed : {unboxed or '—'}")
        print(f"Final parsed ans. : {parsed or '—'}")
        print(f"Reward   : {reward}")

    return rewards

<a name="Train"></a>
### Train the model


In [None]:
# Prompt token budget (previously 512); completions get the rest of max_seq_length.
max_prompt_length = 1024
output_dir = "<OUTPUT_DIR>"

from trl import GRPOConfig, GRPOTrainer

training_args = GRPOConfig(
    output_dir=output_dir,
    report_to="wandb",               # log to Weights & Biases
    # --- optimiser ----------------------------------------------------
    optim="paged_adamw_8bit",
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    max_grad_norm=0.1,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    # --- batching / generation ----------------------------------------
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,   # raise further for smoother updates
    num_generations=8,               # decrease if out of memory
    max_prompt_length=max_prompt_length,
    max_completion_length=max_seq_length - max_prompt_length,
    # --- schedule / logging -------------------------------------------
    num_train_epochs=5,              # alternatively cap with max_steps (e.g. 350)
    save_steps=50,
    logging_steps=1,
)

In [None]:
# Replace the tokenizer's chat template with a Jinja template built on
# <｜User｜> / <｜Assistant｜> / <｜end▁of▁sentence｜> and tool-call markers.
# - The content of the (last) system message is emitted right after bos_token.
# - User turns are prefixed with <｜User｜>; assistant turns with <｜Assistant｜>
#   and terminated with <｜end▁of▁sentence｜>.
# - Assistant tool calls and tool outputs are wrapped in the tool▁* markers.
# NOTE(review): with add_generation_prompt=True the template ends with
# '<｜Assistant｜></think>\n' — a *closing* think tag.  format_score requires
# exactly one '</think>' in the completion itself.  If the model treats the
# think block as already closed, it may never emit the tag, and every sample
# would get FORMAT_PENALTY.  If the intent was to open a reasoning block, this
# probably should be '<think>\n' — confirm before training.
tokenizer.chat_template = \
"{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{% for message in messages %}{% if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{% endif %}{% endfor %}{{bos_token}}{{ns.system_prompt}}{% for message in messages %}{% if message['role'] == 'user' %}{% set ns.is_tool = false %}{{'<｜User｜>' + message['content']}}{% endif %}{% if message['role'] == 'assistant' and message['content'] is none %}{% set ns.is_tool = false %}{% for tool in message['tool_calls']%}{% if not ns.is_first %}{{'<｜Assistant｜><｜tool▁calls▁begin｜><｜tool▁call▁begin｜>' + tool['type'] + '<｜tool▁sep｜>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<｜tool▁call▁end｜>'}}{% set ns.is_first = true %}{% else %}{{'\\n' + '<｜tool▁call▁begin｜>' + tool['type'] + '<｜tool▁sep｜>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<｜tool▁call▁end｜>'}}{{'<｜tool▁calls▁end｜><｜end▁of▁sentence｜>'}}{% endif %}{% endfor %}{% endif %}{% if message['role'] == 'assistant' and message['content'] is not none %}{% if ns.is_tool %}{{'<｜tool▁outputs▁end｜>' + message['content'] + '<｜end▁of▁sentence｜>'}}{% set ns.is_tool = false %}{% else %}{% set content = message['content'] %}{{'<｜Assistant｜>' + content + '<｜end▁of▁sentence｜>'}}{% endif %}{% endif %}{% if message['role'] == 'tool' %}{% set ns.is_tool = true %}{% if ns.is_output_first %}{{'<｜tool▁outputs▁begin｜><｜tool▁output▁begin｜>' + message['content'] + '<｜tool▁output▁end｜>'}}{% set ns.is_output_first = false %}{% else %}{{'\\n<｜tool▁output▁begin｜>' + message['content'] + '<｜tool▁output▁end｜>'}}{% endif %}{% endif %}{% endfor %}{% if ns.is_tool %}{{'<｜tool▁outputs▁end｜>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<｜Assistant｜></think>\\n'}}{% endif %}"

In [None]:

from contextlib import redirect_stdout, redirect_stderr
import torch
import os

# SECURITY NOTE(review): this forces ``weights_only=False`` for EVERY
# ``torch.load`` call in the process, overriding whatever the caller passes.
# Full pickle loading can execute arbitrary code embedded in a malicious file,
# so only load checkpoints you produced yourself while this patch is active.
# Presumably needed so that checkpoint state with non-tensor objects can be
# unpickled under newer torch defaults — TODO confirm and narrow if possible.
#
# The unpatched loader is remembered on the wrapper, so re-running this cell
# re-wraps the original function instead of stacking wrapper on wrapper.
_original_load = getattr(torch.load, "_unpatched_torch_load", torch.load)


def _unsafe_load(*args, **kwargs):
    kwargs["weights_only"] = False      # override whatever the caller passes
    return _original_load(*args, **kwargs)


_unsafe_load._unpatched_torch_load = _original_load
torch.load = _unsafe_load

# ───── build trainer ─────────────────────────────────────────────────
trainer = GRPOTrainer(
    model            = model,
    processing_class = tokenizer,
    reward_funcs     = [hierarchical_reward],
    args             = training_args,
    train_dataset    = dataset,
)

# stdout (including the reward function's debug prints) goes to training.log,
# stderr to training.err.  The directory is created up front because the log
# files are opened before the trainer writes anything to output_dir.  UTF-8
# is forced because the debug output contains non-ASCII characters (e.g. '—').
# TODO: change the path of .log and .err if needed.
os.makedirs(output_dir, exist_ok=True)
log_path = os.path.join(output_dir, "training.log")
err_path = os.path.join(output_dir, "training.err")

# To resume: trainer.train(resume_from_checkpoint="<checkpoint dir>")
with open(log_path, "w", encoding="utf-8") as out, open(err_path, "w", encoding="utf-8") as err:
    with redirect_stdout(out), redirect_stderr(err):
        trainer.train()


In [None]:
# Fill these in before running this cell.  hub_repo_id is the Hugging Face Hub
# repository (e.g. "your-user/your-model"); lora_save_dir is presumably a local
# directory for the LoRA weights — confirm against Unsloth's save_lora docs.
hub_repo_id = ""
lora_save_dir = ""

# Fail fast with a clear message instead of an opaque error inside the
# upload / save calls when the placeholders are still empty.
if not hub_repo_id or not lora_save_dir:
    raise ValueError("Set hub_repo_id and lora_save_dir before pushing/saving the model.")

model.push_to_hub(hub_repo_id)
tokenizer.push_to_hub(hub_repo_id)
model.save_lora(lora_save_dir)