# Step 3: Reinforcement Learning with GRPO (Advanced)

This notebook uses **Group Relative Policy Optimization (GRPO)** to further optimize the SFT model from Step 1.

**Key Idea**: Use formal verifiers as reward functions to encourage:
1. ✓ **Format Compliance**: Proper XML structure
2. ✓ **Feasibility**: Solutions satisfy constraints
3. ✓ **Optimality**: Solutions are optimal

**Prerequisites**:
- Complete `01_train_sft.ipynb` first (SFT baseline required)
- Model should be saved at `../models/constraint-reasoner-v1`

**Note**: This is an advanced, optional optimization step. The SFT model from Step 1 is already functional on its own.


In [None]:
import os
import jax

# Import from installed package (no sys.path hacks!)
from src.data_loader import OptimizationDataset
from src.format_utils import format_input
from src.rewards import (
    format_reward_func,
    feasibility_reward_func,
    optimality_reward_func,
    brevity_reward_func,
)

print("=" * 60)
print("STEP 3: REINFORCEMENT LEARNING (GRPO)")
print("=" * 60)
print(f"JAX Devices: {jax.devices()}")
print()

# Import Tunix (may not be available in all environments)
try:
    import tunix
    from tunix.rl import GRPOTrainer, RLConfig
    from tunix.config import ModelConfig
    print(f"✓ Tunix version: {tunix.__version__}")
    TUNIX_AVAILABLE = True
except ImportError as e:
    print(f"⚠️  Tunix not available: {e}")
    print('This notebook requires Tunix. Install with: pip install "google-tunix[prod]"')
    TUNIX_AVAILABLE = False

print()

## 1. Load Training Data

For RL, we only need prompts (not targets). The model generates completions and learns from rewards.


In [None]:
RL_DATASET_SIZE = 500  # Increase to 2000+ for production

print(f"Generating {RL_DATASET_SIZE} training prompts...")
dataset = OptimizationDataset(size=RL_DATASET_SIZE)
prompts = [format_input(item['problem']) for item in dataset]
print(f"✓ Generated {len(prompts)} prompts for RL training")
print()

# Show a sample prompt
print("Sample Prompt:")
print(prompts[0][:300] + "...")
print()


## 2. Configure Reward Functions

This notebook combines multiple reward functions to guide GRPO. Each reward function scores every sampled completion.

**Prioritized Reward Functions** (per judge recommendations):
1. `format_reward_func` (weight: 1.0): Checks XML structure (all tags present)
2. `feasibility_reward_func` (weight: 2.0): Verifies constraint satisfaction
3. `optimality_reward_func` (weight: 3.0): Verifies solution optimality
4. `brevity_reward_func` (weight: 0.5): Encourages concise outputs

The model learns to maximize the weighted sum of these rewards.
If each function returns a score between 0 and 1, the maximum total reward is 1.0 + 2.0 + 3.0 + 0.5 = 6.5.
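For a single completion, the combined reward is each function's score multiplied by its weight, summed. Here is a minimal sketch with made-up per-function scores, assuming each score lies in [0, 1]:

```python
# Weights from the list above; the per-function scores below are made-up examples.
REWARD_WEIGHTS = [1.0, 2.0, 3.0, 0.5]  # format, feasibility, optimality, brevity


def combined_reward(scores, weights=REWARD_WEIGHTS):
    """Weighted sum of one completion's per-function scores."""
    if len(scores) != len(weights):
        raise ValueError("need one score per reward function")
    return sum(w * s for w, s in zip(weights, scores))


perfect = [1.0, 1.0, 1.0, 1.0]        # passes every check
feasible_only = [1.0, 1.0, 0.0, 0.5]  # well-formed and feasible, but not optimal
print(combined_reward(perfect))        # 6.5
print(combined_reward(feasible_only))  # 3.25
```

With these weights, a well-formed, feasible but suboptimal answer earns exactly half of the maximum, so optimality dominates the signal, in line with the priority order above.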


In [None]:
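import functools

# Reward functions in priority order (gates first, bonus last)
base_reward_funcs = [
    format_reward_func,      # Gate 1: Valid format (weight: 1.0)
    feasibility_reward_func, # Gate 2: Feasible solution (weight: 2.0)
    optimality_reward_func,  # Gate 3: Optimal solution (weight: 3.0)
    brevity_reward_func,     # Bonus: Concise output (weight: 0.5)
]

# Reward weights: higher weights for more critical objectives
reward_weights = [1.0, 2.0, 3.0, 0.5]

def weighted(func, weight):
    """Scale a reward function's per-completion scores by `weight`.
    Assumes `func` returns a list of floats, one per completion."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return [weight * score for score in func(*args, **kwargs)]
    return wrapper

# The trainer below receives only `reward_funcs`, so the weights are baked in
# here; otherwise `reward_weights` would never be applied.
reward_funcs = [weighted(f, w) for f, w in zip(base_reward_funcs, reward_weights)]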

print("✓ Configured 4 prioritized reward functions:")
print("  1. Format compliance (weight: 1.0)")
print("  2. Feasibility verification (weight: 2.0)")
print("  3. Optimality verification (weight: 3.0)")
print("  4. Brevity bonus (weight: 0.5)")
print(f"  Total possible reward: {sum(reward_weights)}")
print()
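The reward functions themselves live in `src.rewards`, which is not shown here. As a rough illustration of the kind of check `format_reward_func` performs, here is a minimal sketch. It assumes a batch-style signature (prompts and completions in, one float per completion out) and hypothetical tag names `<reasoning>` and `<answer>`; the project's actual tags and signature may differ.

```python
import re

# Hypothetical tag names; the real project may use different ones.
REQUIRED_TAGS = ("reasoning", "answer")


def toy_format_reward(prompts, completions, **kwargs):
    """Return 1.0 for each completion containing every required tag pair, else 0.0."""
    scores = []
    for completion in completions:
        has_all = all(
            re.search(rf"<{tag}>.*?</{tag}>", completion, flags=re.DOTALL)
            for tag in REQUIRED_TAGS
        )
        scores.append(1.0 if has_all else 0.0)
    return scores


good = "<reasoning>x + y <= 4</reasoning><answer>x=1, y=3</answer>"
bad = "<reasoning>no answer tag</reasoning>"
print(toy_format_reward(["p", "p"], [good, bad]))  # [1.0, 0.0]
```

Note that the prompt is unused here: a format check only needs the completion, whereas feasibility and optimality checks would also need the problem stated in the prompt.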


## 3. Configure RL Training

Set up the GRPO training configuration.

**Key Parameters**:
- `kl_coeff`: KL divergence penalty (keeps model close to SFT baseline)
- `num_generations`: Number of completions per prompt for group comparison
- `learning_rate`: Lower than SFT (fine-tuning an already good model)
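To make `num_generations` and `kl_coeff` concrete: GRPO samples a group of completions per prompt and scores each one relative to its own group, normalizing rewards by the group mean and standard deviation, so no learned value function is needed. The KL term can be estimated per token with the `exp(d) - d - 1` estimator used in the original GRPO formulation (DeepSeekMath). The sketch below illustrates both on toy numbers. It is not Tunix's implementation; the epsilon, the population standard deviation, and the made-up probabilities are assumptions.

```python
import numpy as np  # jax.numpy offers the same functions used here


def group_advantages(rewards, eps=1e-6):
    """Group-relative advantages.

    rewards: shape (num_prompts, num_generations); row i holds the combined
    rewards of the completions sampled for prompt i.
    """
    mean = rewards.mean(axis=1, keepdims=True)
    std = rewards.std(axis=1, keepdims=True)  # population std (ddof=0)
    return (rewards - mean) / (std + eps)


def kl_k3(logp_policy, logp_ref):
    """Per-token estimate of KL(policy || ref): exp(d) - d - 1, d = logp_ref - logp_policy.

    Non-negative, and exactly 0 on tokens where both models assign the same probability.
    """
    d = logp_ref - logp_policy
    return np.exp(d) - d - 1.0


# Two prompts with num_generations=4 completions each (combined rewards, max 6.5).
rewards = np.array([[6.5, 3.25, 3.25, 1.0],
                    [1.0, 1.0, 1.0, 1.0]])
adv = group_advantages(rewards)
print(np.round(adv, 2))
# Row 0: above-average completions get positive advantage, below-average negative.
# Row 1: identical rewards give zero advantage, so that prompt adds no learning signal.

# Per-token penalty that kl_coeff scales (token probabilities are made up).
logp_policy = np.log(np.array([0.5, 0.2]))
logp_ref = np.log(np.array([0.5, 0.4]))
kl_coeff = 0.01
print(kl_coeff * kl_k3(logp_policy, logp_ref))
```

The second row shows why `num_generations` matters: if every completion in a group earns the same reward, the group provides no gradient, so larger groups make it more likely that some completions differ.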


In [None]:
if TUNIX_AVAILABLE:
    # RL training configuration
    rl_config = RLConfig(
        output_dir="../checkpoints/grpo_optimized",
        num_train_epochs=1,  # Increase to 2-3 for production
        per_device_train_batch_size=4,  # Adjust based on memory
        gradient_accumulation_steps=4,  # Effective batch = 4 x 4 = 16 per device (x number of devices)
        learning_rate=1e-6,  # Much lower than SFT (fine-tuning)
        kl_coeff=0.01,  # Penalty for diverging from SFT model
        num_generations=4,  # Generate 4 completions per prompt
        max_prompt_length=256,
        max_completion_length=1024,
    )

    # Model configuration (start from SFT checkpoint)
    sft_model_path = "../models/constraint-reasoner-v1"

    if not os.path.exists(sft_model_path):
        print(f"⚠️  SFT model not found at: {sft_model_path}")
        print("You must run 01_train_sft.ipynb first!")
        TUNIX_AVAILABLE = False  # Reuse the flag so the training cell below is skipped
    else:
        model_config = ModelConfig(
            base_model=sft_model_path,  # Start from SFT checkpoint
            dtype="bfloat16",
            use_flash_attention=True,
            lora_rank=8,
            lora_alpha=32
        )

        print("✓ RL configuration complete")
        print(f"  - Base model: {sft_model_path}")
        print(f"  - Epochs: {rl_config.num_train_epochs}")
        print(f"  - Learning rate: {rl_config.learning_rate}")
        print(f"  - KL coefficient: {rl_config.kl_coeff}")
        print(f"  - Generations per prompt: {rl_config.num_generations}")
else:
    print("⚠️  Skipping configuration (Tunix not available)")

print()
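A quick back-of-the-envelope check of what these settings imply per optimizer step. This sketch assumes that `per_device_train_batch_size` counts prompts (not completions) and that training runs on a single device; trainers differ on both points, so treat the numbers as estimates.

```python
import math

# Values copied from the RLConfig and dataset cells above.
dataset_size = 500           # RL_DATASET_SIZE
per_device_batch = 4         # per_device_train_batch_size
grad_accum = 4               # gradient_accumulation_steps
num_generations = 4
max_completion_length = 1024
num_devices = 1              # assumption: one accelerator; use len(jax.devices()) otherwise

prompts_per_step = per_device_batch * grad_accum * num_devices
completions_per_step = prompts_per_step * num_generations
max_new_tokens_per_step = completions_per_step * max_completion_length
steps_per_epoch = math.ceil(dataset_size / prompts_per_step)  # 31 if the last partial batch is dropped

print(prompts_per_step)         # 16
print(completions_per_step)     # 64
print(max_new_tokens_per_step)  # 65536
print(steps_per_epoch)          # 32
```

Generation usually dominates GRPO runtime, so `num_generations` and `max_completion_length` are the first knobs to reduce if a run is too slow.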


## 4. Train with GRPO

Run reinforcement learning. This optimizes the model to maximize the combined reward, while the KL penalty keeps it close to the SFT baseline.

**Expected Improvements**:
- Higher format compliance rate
- Better constraint satisfaction
- More optimal solutions

**Note**: This may take several hours depending on hardware.
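To check whether these improvements materialize, one option is to score the same held-out prompts with each model, run the completions through the same reward functions, and compare pass rates. This sketch assumes unweighted reward functions that return 1.0 on a pass; the score lists are made up and stand in for real evaluations.

```python
def pass_rate(scores, threshold=1.0):
    """Fraction of completions whose score reaches `threshold` (0.0 for an empty list)."""
    if not scores:
        return 0.0
    return sum(score >= threshold for score in scores) / len(scores)


# Made-up per-completion scores standing in for two models' evaluations
# (e.g. SFT v1 vs. GRPO v2), scored by the same reward functions.
model_a = {"format": [1.0, 1.0, 0.0, 1.0], "feasibility": [1.0, 0.0, 0.0, 1.0]}
model_b = {"format": [1.0, 0.0, 1.0, 1.0], "feasibility": [1.0, 1.0, 0.0, 0.0]}

for name in model_a:
    print(f"{name:<12} A: {pass_rate(model_a[name]):.0%}  B: {pass_rate(model_b[name]):.0%}")
```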


In [None]:
if TUNIX_AVAILABLE:
    try:
        print("Starting GRPO training...")
        print("This may take several hours depending on hardware.")
        print()

        trainer = GRPOTrainer(
            model_config=model_config,
            rl_config=rl_config,
            reward_funcs=reward_funcs,
            train_dataset=prompts
        )

        trainer.train()

        # Save the RL-optimized model
        rl_model_path = "../models/constraint-reasoner-v2-rl"
        os.makedirs(os.path.dirname(rl_model_path), exist_ok=True)
        trainer.save_model(rl_model_path)

        print()
        print("=" * 60)
        print("✓ GRPO TRAINING COMPLETE")
        print("=" * 60)
        print(f"Model saved to: {rl_model_path}")
        print()
        print("Next steps:")
        print("  1. Run 02_verify_and_export.ipynb with the new model")
        print("  2. Compare performance: SFT (v1) vs GRPO (v2)")
        print("  3. Export the best model for Kaggle submission")

    except Exception as e:
        print(f"✗ GRPO training failed: {e}")
        print("See the full traceback below for details.")
        raise
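else:
    print("⚠️  Cannot train: Tunix is unavailable or the SFT checkpoint is missing (see messages above)")
    print('Install Tunix with: pip install "google-tunix[prod]"; create the SFT checkpoint with 01_train_sft.ipynb')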