# 04 — GRPO Training & Monitoring

Runs **Group Relative Policy Optimization (GRPO)** training on
Qwen2.5-3B-Instruct with a LoRA adapter.
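For orientation: GRPO samples a group of completions for each prompt and scores each one with the reward functions. Each completion's advantage is how far its reward sits above or below its group's mean, scaled by the group's standard deviation, so no separate value (critic) model is needed. The sketch below shows only that normalisation step, in plain Python. It is an illustration, not the code inside `src/train` or TRL. The `eps` value and the use of the population standard deviation are assumptions, because implementations differ on these details.

```python
from statistics import fmean, pstdev


def group_relative_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """Normalise one prompt's group of rewards: (r_i - mean) / (std + eps)."""
    mu = fmean(rewards)
    sigma = pstdev(rewards)  # population std; some implementations use the sample std
    return [(r - mu) / (sigma + eps) for r in rewards]


# Four sampled completions for one prompt: two correct (reward 1.0), two wrong (0.0)
adv = group_relative_advantages([1.0, 0.0, 1.0, 0.0])
print([round(a, 3) for a in adv])  # → [1.0, -1.0, 1.0, -1.0]
```

If every completion in a group receives the same reward, all of its advantages are zero and that group contributes no learning signal.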

**Requirements**: GPU with >= 12 GB VRAM (e.g. Colab T4 / A100).
If running on Colab, make sure to select **Runtime → Change runtime type → GPU**.
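Here is a rough back-of-envelope for why a ~3B-parameter model can fit in this budget. The frozen base weights take about parameters × bytes per parameter. LoRA trains only two low-rank factors per adapted matrix, which is r × (d_in + d_out) parameters instead of d_in × d_out. The numbers below (a round 3 × 10^9 parameters, a 2048 × 2048 projection, rank 16) are illustrative assumptions, not values read from the model config or from `src/train`. Activations, generation buffers, and optimizer state add to the total.

```python
def weight_memory_gb(n_params: float, bytes_per_param: int) -> float:
    """Memory for the model weights alone, in decimal gigabytes."""
    return n_params * bytes_per_param / 1e9


def lora_params(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters LoRA adds to one d_out x d_in weight (A: r x d_in, B: d_out x r)."""
    return r * (d_in + d_out)


# Hypothetical round numbers, for illustration only
print(weight_memory_gb(3e9, 2))  # 16-bit weights → 6.0
print(weight_memory_gb(3e9, 4))  # 32-bit weights → 12.0
print(lora_params(2048, 2048, r=16), "vs", 2048 * 2048, "full")  # → 65536 vs 4194304 full
```

At 32-bit precision the weights alone would use the entire 12 GB budget. That suggests the base model needs to be held at 16-bit (or lower) precision to leave room for training.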

```python
# Setup: install dependencies if trl is missing (e.g. on a fresh Colab runtime)
import os
import subprocess
import sys

try:
    import trl  # noqa: F401  (only checking that it is importable)
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q",
        "torch", "transformers>=4.40", "datasets", "accelerate",
        "peft>=0.10", "trl>=0.15", "pandas", "matplotlib"])

import torch

# Make the project root importable so `src.train` resolves
# (assumes this notebook lives one directory below the repo root)
sys.path.insert(0, os.path.abspath(".."))

print(f"Python: {sys.version.split()[0]}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # The device-properties attribute is `total_memory` (in bytes)
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
```

## Quick Sanity Check (5 steps)

Run a short training pass to confirm that the reward functions return non-zero values
and that training runs end to end without errors.

```python
from src.train import train

# Quick check: 5 steps only
train(max_steps=5)
```

## Full Training (100 steps)

Run the full training session. The `correct_answer_reward_func` mean
should trend **upward** over time as the model learns to count letters accurately.

Expected time: ~15–30 min on a T4, ~5–10 min on an A100.
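TRL's `GRPOTrainer` calls each reward function with keyword arguments. These include `prompts`, `completions`, and every extra dataset column, and the function must return one float per completion. The real `correct_answer_reward_func` lives in `src/train` and is not shown here. The sketch below is a hypothetical stand-in, named `letter_count_reward`, and every detail of it is an assumption. It expects plain-text (non-conversational) completions and an `answer` dataset column holding the true letter count. It treats the last integer in the text as the prediction and gives a reward of 1.0 for a correct count.

```python
import re


def letter_count_reward(completions: list[str], answer: list[int], **kwargs) -> list[float]:
    """Hypothetical: 1.0 if the last integer in a completion equals the true count, else 0.0."""
    rewards = []
    for text, true_count in zip(completions, answer):
        numbers = re.findall(r"-?\d+", text)  # every integer in the completion
        predicted = int(numbers[-1]) if numbers else None
        rewards.append(1.0 if predicted == int(true_count) else 0.0)
    return rewards


completions = ["There are 3 r's in strawberry.", "The answer is 2.", "I am not sure."]
rewards = letter_count_reward(completions, answer=[3, 3, 3])
print(rewards)                       # → [1.0, 0.0, 0.0]
print(sum(rewards) / len(rewards))   # batch mean, the quantity that should rise
```

TRL logs a per-step mean for each reward function under that function's name. For the correctness reward, this logged mean is the value that should trend upward.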

```python
train(max_steps=100)
```

## Plot Training Rewards

Visualise the mean correctness reward at each logged training step.
A clear upward trend indicates that GRPO training is improving the model's letter-counting accuracy.
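Per-step means can be noisy. One way to put a number on "upward" is the least-squares slope of reward against step, together with a trailing moving average. The sketch below computes both in plain Python. The window size is an arbitrary assumption, the sample rewards are made up, and a positive slope is evidence of improvement rather than proof.

```python
def least_squares_slope(xs: list[float], ys: list[float]) -> float:
    """Slope of the ordinary least-squares line through the points (xs, ys)."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    var = sum((x - mx) ** 2 for x in xs)
    return cov / var


def trailing_mean(values: list[float], window: int = 10) -> list[float]:
    """Mean of the last `window` values at each position (fewer at the start)."""
    out = []
    for i in range(len(values)):
        chunk = values[max(0, i - window + 1): i + 1]
        out.append(sum(chunk) / len(chunk))
    return out


steps = [1, 2, 3, 4, 5]
rewards = [0.0, 0.5, 0.25, 0.75, 1.0]  # made-up values for illustration
print(least_squares_slope(steps, rewards))  # positive: rewards rise on average
print(trailing_mean(rewards, window=2))     # → [0.0, 0.25, 0.375, 0.5, 0.875]
```

With the logged DataFrame, pass `df["step"].tolist()` and `df["mean_correctness_reward"].tolist()`. In pandas, `df["mean_correctness_reward"].rolling(10, min_periods=1).mean()` gives the same trailing average.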

```python
import pandas as pd
import matplotlib.pyplot as plt

# Reward log written during training (path is relative to this notebook)
csv_path = "../outputs/logs/mean_correctness_over_time.csv"
df = pd.read_csv(csv_path)

print("Reward log (last 10 rows):")
print(df.tail(10).to_string(index=False))

plt.figure(figsize=(8, 4))
plt.plot(df["step"], df["mean_correctness_reward"], marker="o", linewidth=2)
plt.xlabel("Training Step")
plt.ylabel("Mean Correctness Reward")
plt.title("Correctness Reward Over GRPO Training")
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
```