# Step 3: Advanced Optimization with GRPO

This step uses Group Relative Policy Optimization (GRPO), a reinforcement-learning (RL) method, to refine the model.
We use the **Verifiers** as reward functions to encourage:
1. Correct Format
2. Feasibility (Constraint Satisfaction)
3. Optimality (Finding the best solution)

In [None]:
# Standard-library modules and JAX.
import sys
import os
import jax

# Put ../src on the import search path so the local imports below resolve.
# The relative path is made absolute against the current working directory,
# so this presumably expects the kernel to be started from the notebook's
# own folder -- TODO confirm.
sys.path.append(os.path.abspath("../src"))

# Project-local helpers (presumably located in ../src -- verify): the prompt
# dataset, the prompt formatter, and the three reward functions used for RL.
from data_loader import OptimizationDataset
from format_utils import format_input
from rewards import format_reward_func, feasibility_reward_func, optimality_reward_func

# NOTE(review): confirm that GRPOTrainer, RLConfig and ModelConfig are
# exported from these module paths in the installed tunix version.
import tunix
from tunix.rl import GRPOTrainer, RLConfig
from tunix.config import ModelConfig

# Show which devices JAX can see before any training is started.
print(f"JAX Devices: {jax.devices()}")

In [None]:
# 1. Load Data
# Only the prompts are kept: each dataset record contributes its 'problem'
# field, passed through format_input.
dataset = OptimizationDataset(size=500)
prompts = [
    format_input(record['problem'])
    for record in dataset
]
print(f"Loaded {len(prompts)} prompts for RL.")

In [None]:
# 2. Configure Rewards
# The three verifier-backed reward functions are handed to the trainer as a
# list (presumably the form the trainer expects -- verify against the
# tunix API).
reward_funcs = [
    format_reward_func,       # basic gate: output format
    feasibility_reward_func,  # constraint check
    optimality_reward_func,   # optimality check
]

# 3. Configure RL Training
rl_config = RLConfig(
    # Where checkpoints are written.
    output_dir="../checkpoints/grpo_optimized",
    # Schedule and batching.
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    # Optimisation: a low learning rate for RL fine-tuning, and a KL
    # coefficient intended to keep the policy close to the SFT model.
    learning_rate=1e-6,
    kl_coeff=0.01,
    # Four completions are generated per prompt so they can be compared
    # within the group.
    num_generations=4,
    # Length limits for prompts and completions.
    max_prompt_length=256,
    max_completion_length=1024,
)

model_config = ModelConfig(
    # Training starts from the SFT checkpoint.
    base_model="../models/constraint-reasoner-v1",
    dtype="bfloat16",
    use_flash_attention=True,
    # LoRA adapter settings.
    lora_rank=8,
    lora_alpha=32,
)

In [None]:
# 4. Train
# Build the GRPO trainer from the prompts, reward functions and the two
# config objects created in the earlier cells.
trainer = GRPOTrainer(
    train_dataset=prompts,
    reward_funcs=reward_funcs,
    rl_config=rl_config,
    model_config=model_config,
)

# Run training, then save the resulting model under a new name so the SFT
# checkpoint it started from is left in place.
trainer.train()
trainer.save_model("../models/constraint-reasoner-v2-rl")
print("GRPO Training complete.")