<div align="center">

# 🏴‍☠️ MAROONED: Process Reward Modeling

### LLM-as-Judge: Teacher-Guided RL Training

**OpenEnv Hackathon 2025**

[![OpenEnv](https://img.shields.io/badge/Framework-OpenEnv-blue)](https://github.com/openenv)
[![Llama](https://img.shields.io/badge/Model-Llama_3.1_8B-green)](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct)
[![Hardware](https://img.shields.io/badge/Hardware-AMD_MI300X-red)](https://www.amd.com/en/products/accelerators/instinct/mi300.html)

</div>

---

## 🔬 Innovation: Teacher LLM for Action Validation

**Problem:** Student LLM outputs invalid actions (30-40% parse failure rate)

**Solution:** Teacher LLM validates, corrects, and critiques student outputs in real-time

```
Student LLM (training) → Generates action
    ↓
Teacher LLM (frozen) → Validates + Corrects + Critiques
    ↓
Environment → Executes corrected action
    ↓
Student receives:
    - Environment reward (ship progress)
    - Process penalty (format/quality)
    - Critique (for learning)
```
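Step 4 of the diagram can be sketched in plain Python. This sketch assumes the teacher result is a dict with the `valid`, `penalty`, and `critique` keys that the notebook reads later, and that penalties are zero or negative. The reward values below are made up for illustration.

```python
def combined_reward(env_reward: float, teacher_result: dict) -> float:
    """Reward used for the student update: environment reward plus process penalty."""
    return env_reward + teacher_result["penalty"]


# Illustrative turns: (environment reward, teacher result). Values are made up.
turns = [
    (0.1, {"valid": True, "penalty": 0.0, "critique": "Well-formed GATHER."}),
    (0.0, {"valid": False, "penalty": -0.5, "critique": "'MOVING' is not a valid action."}),
    (0.5, {"valid": True, "penalty": 0.0, "critique": "Valid BUILD."}),
]

per_turn = [combined_reward(env_r, result) for env_r, result in turns]
# Fraction of turns where the teacher had to correct the student.
invalid_rate = sum(not result["valid"] for _, result in turns) / len(turns)

print("per-turn rewards:", per_turn)
print(f"episode return: {sum(per_turn):+.2f}, invalid outputs: {invalid_rate:.0%}")
```

The critique is not folded into the scalar reward. In this sketch it is carried alongside the reward so it can be reused, for example as supervision text.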

---

## ⚙️ Prerequisites

**Before running this notebook, start the vLLM teacher server in a separate terminal:**

```bash
# Install vLLM
pip install vllm

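# Start teacher inference server.
# vLLM pre-allocates 90% of GPU memory by default; if the student trains on the
# same GPU, lower --gpu-memory-utilization to leave it room.
# Student prompts in this notebook run to ~8.9k tokens; if the teacher prompt
# includes the full observation, raise --max-model-len accordingly.
vllm serve unsloth/Meta-Llama-3.1-8B-Instruct \
    --dtype bfloat16 \
    --max-model-len 8192 \
    --gpu-memory-utilization 0.3 \
    --port 8000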
```

**Verify it's running:**
```bash
curl http://localhost:8000/v1/models
```

Expected output:
```json
{"data": [{"id": "unsloth/Meta-Llama-3.1-8B-Instruct", "object": "model"}]}
```

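The teacher runs as a **separate inference server** for:
- ⚡ Fast validation: vLLM serves generations roughly 10-20x faster than in-process `generate()` with transformers
- 🔄 Isolation: teacher weights and KV cache live in their own process, separate from the student's training state
- 📊 Scalability: multiple teacher instances can sit behind the same API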
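For reference, a teacher request goes through vLLM's OpenAI-compatible chat completions endpoint. The sketch below uses only the standard library. The real prompt that `teacher_validate_student_output` sends lives in `llm_interface` and is not shown here, so the message contents, `temperature`, and `max_tokens` are assumptions. `build_teacher_request`, `extract_reply`, and `ask_teacher` are hypothetical helpers.

```python
import json
import urllib.request

VLLM_CHAT_URL = "http://localhost:8000/v1/chat/completions"
TEACHER_MODEL = "unsloth/Meta-Llama-3.1-8B-Instruct"  # must match the name given to `vllm serve`


def build_teacher_request(system_prompt: str, student_output: str,
                          temperature: float = 0.0, max_tokens: int = 256) -> dict:
    """JSON payload for an OpenAI-style chat completion request."""
    return {
        "model": TEACHER_MODEL,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": student_output},
        ],
        "temperature": temperature,
        "max_tokens": max_tokens,
    }


def extract_reply(response_json: dict) -> str:
    """Pull the assistant text out of an OpenAI-style chat completion response."""
    return response_json["choices"][0]["message"]["content"]


def ask_teacher(payload: dict, url: str = VLLM_CHAT_URL, timeout: float = 30.0) -> str:
    """POST the payload to the running vLLM server and return the teacher's reply."""
    request = urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return extract_reply(json.loads(response.read().decode("utf-8")))
```

A greedy `temperature=0.0` is used here on the assumption that a validator should be deterministic. The notebook does not state which temperature the teacher actually uses.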

---

## 1️⃣ Environment Setup

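Install Unsloth, TRL 0.22.2, and transformers 4.56.2. The first branch below (fresh environment or Colab) installs the default PyPI PyTorch wheels, which are CUDA builds. On AMD MI300X, install a ROCm build of PyTorch beforehand so that branch is skipped.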

In [None]:
%%capture
import os, importlib.util
!pip install --upgrade -qqq uv
if importlib.util.find_spec("torch") is None or "COLAB_" in "".join(os.environ.keys()):
    try: import numpy; get_numpy = f"numpy=={numpy.__version__}"
    except: get_numpy = "numpy"
    !uv pip install -qqq \
        "torch>=2.8.0" "triton>=3.4.0" {get_numpy} torchvision bitsandbytes "transformers==4.56.2" trackio \
        "unsloth_zoo[base] @ git+https://github.com/unslothai/unsloth-zoo" \
        "unsloth[base] @ git+https://github.com/unslothai/unsloth" \
        git+https://github.com/triton-lang/triton.git@05b2c186c1b6c9a08375389d5efe9cb4c401c075#subdirectory=python/triton_kernels
elif importlib.util.find_spec("unsloth") is None:
    !uv pip install -qqq unsloth trackio
!uv pip install --upgrade --no-deps transformers==4.56.2 tokenizers trl==0.22.2 unsloth unsloth_zoo

## 2️⃣ Load MAROONED Environment

Import custom environment and verify game configuration.

In [None]:
import sys
import json
import random
from typing import Dict, Any, List

# Clear cached modules to reload changes
modules_to_clear = [m for m in list(sys.modules.keys()) 
                   if 'marooned' in m or m in ['environment', 'config', 'models', 'game_state', 'view_map', 'llm_interface']]
for module in modules_to_clear:
    if module in sys.modules:
        del sys.modules[module]

sys.path.insert(0, '../marooned_env')

from environment import MaroonedEnv
from llm_interface import (
    get_system_prompt,
    observation_to_prompt,
    teacher_validate_student_output,  # ⭐ Main function for process reward modeling
)
from config import ActionType, ResourceType, MapLevel, ShipComponent
from models import Action, Position, Observation

print("✅ MAROONED environment loaded")
print("✅ Process Reward Modeling API imported")
print("\nTeacher-Student Training Flow:")
print("  1. Student generates action")
print("  2. Teacher validates via vLLM (http://localhost:8000)")
print("  3. Environment executes corrected action")
print("  4. Student receives: env_reward + process_penalty")

MAROONED environment successfully loaded.

Environment Reward Configuration:
  Colonist - Resource Gathering: +0.1
  Colonist - Resource Deposit: +0.2
  Colonist - Ship Construction: +0.5
  Colonist - Mission Success: +100.0
  Traitor - Sabotage: +2.0
  Traitor - Elimination: +10.0
  Traitor - Mission Success: +100.0
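A quick calculation with the printed values shows how small the shaping rewards are next to the terminal reward. The event names are hypothetical labels, and the sketch assumes one reward per event; the environment's actual accounting is not shown here. The discount uses `gamma=0.99`, the value configured for PPO later in the notebook.

```python
# Values copied from the printed reward configuration; event names are hypothetical labels.
COLONIST_REWARDS = {"gather": 0.1, "deposit": 0.2, "build": 0.5, "mission_success": 100.0}
TRAITOR_REWARDS = {"sabotage": 2.0, "elimination": 10.0, "mission_success": 100.0}


def undiscounted_return(events, table):
    """Sum of per-event rewards, ignoring when they happen."""
    return sum(table[event] for event in events)


def discounted_value(reward, steps_ahead, gamma=0.99):
    """Present value of a reward received `steps_ahead` steps in the future."""
    return reward * gamma ** steps_ahead


colonist_events = ["gather"] * 3 + ["deposit"] * 3 + ["build"] * 2
traitor_events = ["sabotage", "sabotage", "elimination"]

print(f"colonist shaping reward: {undiscounted_return(colonist_events, COLONIST_REWARDS):.2f}")
print(f"traitor reward: {undiscounted_return(traitor_events, TRAITOR_REWARDS):.2f}")
print(f"+100 win, 100 steps away, gamma=0.99: {discounted_value(100.0, 100):.1f}")
```

Even discounted over 100 steps, the +100 mission reward (about 36.6) is far larger than a handful of per-action shaping rewards.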


## 3️⃣ Load Student Model (Llama 3.1 8B)

Student LLM that will learn to play the game through PPO training.

In [2]:
from unsloth import FastLanguageModel
import torch
import os

# ROCm optimizations for MI300X
os.environ["PYTORCH_ROCM_ARCH"] = "gfx942"
os.environ["HSA_FORCE_FINE_GRAIN_PCIE"] = "1"
os.environ["ATTN_BACKEND"] = "triton"
torch.backends.cudnn.benchmark = True

max_seq_length = 16384
lora_rank = 16

student_model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Meta-Llama-3.1-8B-Instruct",
    load_in_4bit = False,
    max_seq_length = max_seq_length,
    dtype = torch.bfloat16,
    device_map = "auto",
)

print(f"✅ Student Model Loaded: Llama 3.1 8B (BF16)")
print(f"   GPU: {torch.cuda.get_device_name(0)}")
print(f"   VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

bitsandbytes library load error: Configured ROCm binary not found at /root/AIAC/colony-collapse/.venv/lib/python3.12/site-packages/bitsandbytes/libbitsandbytes_rocm64.so
Traceback (most recent call last):
  File "/root/AIAC/colony-collapse/.venv/lib/python3.12/site-packages/bitsandbytes/cextension.py", line 313, in <module>
    lib = get_native_library()
          ^^^^^^^^^^^^^^^^^^^^
  File "/root/AIAC/colony-collapse/.venv/lib/python3.12/site-packages/bitsandbytes/cextension.py", line 282, in get_native_library
    raise RuntimeError(f"Configured {BNB_BACKEND} binary not found at {cuda_binary_path}")
RuntimeError: Configured ROCm binary not found at /root/AIAC/colony-collapse/.venv/lib/python3.12/site-packages/bitsandbytes/libbitsandbytes_rocm64.so


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


  from .autonotebook import tqdm as notebook_tqdm
    PyTorch 2.8.0+cu128 with CUDA 1208 (you have 2.9.0+rocm6.4)
    Python  3.9.23 (you have 3.12.3)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details


Switching to PyTorch attention since your Xformers is broken.

Unsloth: Xformers was not installed correctly.
Please install xformers separately first.
Then confirm if it's correctly installed by running:
python -m xformers.info

Longer error message:
xFormers can't load C++/CUDA extensions. xFormers was built for:
    PyTorch 2.8.0+cu128 with CUDA 1208 (you have 2.9.0+rocm6.4)
    Python  3.9.23 (you have 3.12.3)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
🦥 Unsloth Zoo will now patch everything to make training faster!
Unsloth: AMD currently is not stable with 4bit bitsandbytes. Disabling for now.
==((====))==  Unsloth 2025.10.9: Fast Llama patching. Transformers: 4.56.2.
   \\   /|    AMD Instinct MI300X VF. Num GPUs = 1. Max memory: 191.688 GB. Platform: Linux.
O^O/ \_/ \    T

INFO:accelerate.utils.modeling: We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).
Loading checkpoint shards: 100%|██████████| 4/4 [00:10<00:00,  2.66s/it]



✅ Student Model Loaded: Llama 3.1 8B (BF16)
   GPU: AMD Instinct MI300X VF
   VRAM: 191.7 GB


## 4️⃣ Configure LoRA for Student

Parameter-efficient fine-tuning with rank-16 adapters.

In [3]:
student_model = FastLanguageModel.get_peft_model(
    student_model,
    r = lora_rank,
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = lora_rank * 2,
    lora_dropout = 0.0,
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
    use_rslora = True,
)

print("✅ LoRA adapters configured (rank=16)")

Unsloth 2025.10.9 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


✅ LoRA adapters configured (rank=16)
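Because `use_rslora=True`, PEFT scales the adapter output by `lora_alpha / sqrt(r)` rather than `lora_alpha / r`. With `lora_alpha = 2 * r = 32`, that is a factor of 8. The common "alpha = 2r" rule of thumb assumes a factor of 2. The arithmetic:

```python
import math


def lora_scaling(alpha: float, r: int, use_rslora: bool) -> float:
    """Factor that multiplies the low-rank update B @ A in a LoRA layer."""
    # Standard LoRA scales by alpha / r; rank-stabilized LoRA (rsLoRA) by alpha / sqrt(r).
    return alpha / math.sqrt(r) if use_rslora else alpha / r


lora_rank = 16
lora_alpha = lora_rank * 2  # as in the cell above

print("standard LoRA scale:", lora_scaling(lora_alpha, lora_rank, use_rslora=False))
print("rsLoRA scale:       ", lora_scaling(lora_alpha, lora_rank, use_rslora=True))
```

If a scale of 2 was intended, `lora_alpha = 8` gives that under rsLoRA. Otherwise the larger scale is a tuning consideration that may interact with the learning rate.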


## 5️⃣ Verify the vLLM Teacher Server

Start the teacher server as described in **⚙️ Prerequisites**, then confirm it is reachable before training.

In [None]:
import requests
import time

# Verify vLLM server is running
VLLM_API_URL = "http://localhost:8000/v1/chat/completions"
VLLM_MODELS_URL = "http://localhost:8000/v1/models"

print("Checking vLLM server...")
try:
    response = requests.get(VLLM_MODELS_URL, timeout=5)
    if response.status_code == 200:
        models = response.json()
        print(f"✅ vLLM server running!")
        print(f"   Available models: {[m['id'] for m in models.get('data', [])]}")
    else:
        print(f"⚠️  vLLM server responded with status {response.status_code}")
except requests.exceptions.RequestException as e:
    print(f"❌ vLLM server not reachable!")
    print(f"   Error: {e}")
    print(f"\n   Please start vLLM server in a separate terminal:")
    print(f"   vllm serve unsloth/Meta-Llama-3.1-8B-Instruct --dtype bfloat16 --max-model-len 8192 --port 8000")

📏 Prompt length: 8852 tokens

🤖 Generating response...



REASONING: As the traitor, my goal is to prevent the group from escaping the island. Since we are currently at 0% total ship progress, sabotaging any part of the ship would significantly hinder our chances of completing it within the given timeframe. Additionally, poisoning Alice could weaken her and make her more vulnerable to elimination, which might not be ideal as she seems relatively strong. Therefore, I should focus on delaying progress rather than directly targeting individuals.

ACTION: SABOTAGE hull
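A single request fails if the server is still loading weights, which can take a while for an 8B model. A small polling helper can wait instead. `wait_for_server` is hypothetical, uses only the standard library, and defaults to the `/v1/models` URL checked above. The `probe` argument exists only so the helper can be exercised without a live server.

```python
import time
import urllib.request

VLLM_MODELS_URL = "http://localhost:8000/v1/models"


def wait_for_server(url: str = VLLM_MODELS_URL, timeout_s: float = 300.0,
                    interval_s: float = 5.0, probe=None) -> bool:
    """Poll `url` until it answers HTTP 200 or `timeout_s` seconds elapse."""

    def http_probe(u: str) -> bool:
        try:
            with urllib.request.urlopen(u, timeout=5) as response:
                return response.status == 200
        except OSError:  # URLError, HTTPError, connection refused/reset, timeouts
            return False

    probe = probe or http_probe
    deadline = time.monotonic() + timeout_s
    while True:
        if probe(url):
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval_s)
```

Usage: `if not wait_for_server(timeout_s=600): raise RuntimeError("vLLM teacher did not come up")`.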


## 5️⃣ Test Teacher Validation

With the teacher API available at `http://localhost:8000/v1/chat/completions`, send a few simulated student outputs through `teacher_validate_student_output`. Then inspect the corrected action, validity flag, penalty, and critique for each.

In [None]:
print("="*80)
print("🧪 TESTING TEACHER VALIDATION")
print("="*80)

# Create sample environment
env = MaroonedEnv(render_mode="ansi")
observations = env.reset(seed=42)
alice_obs = observations["Alice"]

# Simulated student responses (the last one is well-formed)
test_cases = [
    "REASONING: I should move\nACTION: MOVING",
    "REASONING: Need to check\nACTION: CHECK_STATUS",
    "REASONING: Gather resources\nACTION: GATHER wood",
    "REASONING: Moving north to explore\nACTION: MOVE NORTH",
]

print("\nSending test cases to teacher API...\n")

for i, student_response in enumerate(test_cases):
    print(f"Test {i+1}: Student says: '{student_response.split('ACTION:')[1].strip()}'")
    
    result = teacher_validate_student_output(student_response, alice_obs, "Alice")
    
    print(f"   ✓ Teacher corrected to: {result['action'].action_type.value}")
    print(f"   ✓ Valid: {result['valid']}")
    print(f"   ✓ Penalty: {result['penalty']}")
    print(f"   ✓ Critique: {result['critique'][:80]}...")
    print()

print("="*80)
print("✅ Teacher API working correctly!")
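The teacher is the authority on validity. A cheap local check still makes it easy to measure the raw parse-failure rate the introduction cites (30-40%) and to track it over training. The grammar below is an assumption inferred from the examples in this notebook, not the environment's definition; the authoritative list is the `ActionType` enum. `quick_format_check` and `parse_failure_rate` are hypothetical helpers.

```python
import re

# Assumed subset of the action grammar, based on the examples in this notebook.
MOVE_DIRECTIONS = {"NORTH", "SOUTH", "EAST", "WEST"}
ONE_ARG_VERBS = {"GATHER", "BUILD", "SABOTAGE", "POISON"}
ACTION_LINE = re.compile(r"^ACTION:[ \t]*(\S.*?)[ \t]*$", re.MULTILINE)


def quick_format_check(student_response: str) -> bool:
    """Syntactic check of the ACTION line; says nothing about whether the action is legal."""
    match = ACTION_LINE.search(student_response)
    if match is None:
        return False
    verb, *args = match.group(1).split()
    verb = verb.upper()
    if verb == "MOVE":
        return len(args) == 1 and args[0].upper() in MOVE_DIRECTIONS
    if verb == "WAIT":
        return not args
    if verb in ONE_ARG_VERBS:
        return len(args) == 1
    if verb == "DEPOSIT":
        return len(args) == 2 and args[1].isdigit()
    return False


def parse_failure_rate(responses) -> float:
    """Fraction of responses whose ACTION line fails the format check."""
    return sum(not quick_format_check(r) for r in responses) / len(responses)


test_cases = [
    "REASONING: I should move\nACTION: MOVING",
    "REASONING: Need to check\nACTION: CHECK_STATUS",
    "REASONING: Gather resources\nACTION: GATHER wood",
    "REASONING: Moving north to explore\nACTION: MOVE NORTH",
]
print([quick_format_check(r) for r in test_cases])
print(f"parse failure rate: {parse_failure_rate(test_cases):.0%}")
```

`GATHER wood` passes this purely syntactic check even though the test cell treats it as a bad response. Judging whether `wood` is an acceptable target is left to the teacher.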

## 5️⃣ Teacher Validation Test (verbose variant)

Same check as above, printed with separators between test cases.

In [None]:
print("="*80)
print("🧪 TESTING TEACHER VALIDATION (via vLLM API)")
print("="*80)

# Create a sample observation
env = MaroonedEnv(render_mode="ansi")
observations = env.reset(seed=42)
alice_obs = observations["Alice"]

# Simulate bad student responses
test_cases = [
    "REASONING: I should move northeast\nACTION: MOVING",
    "REASONING: Check my status\nACTION: CHECK_STATUS",
    "REASONING: Gather wood\nACTION: GATHER wood",
    "REASONING: Moving north to explore\nACTION: MOVE NORTH",  # Good one
]

print("\nSending 4 test cases to vLLM teacher...\n")

for i, student_response in enumerate(test_cases):
    print(f"{'─'*80}")
    print(f"Test {i+1}: Student says: {student_response.split('ACTION:')[1].strip()}")
    
    result = teacher_validate_student_output(student_response, alice_obs, "Alice")
    
    print(f"   ✓ Teacher corrected to: {result['action'].action_type.value}")
    print(f"   ✓ Valid: {result['valid']}")
    print(f"   ✓ Penalty: {result['penalty']}")
    print(f"   ✓ Critique: {result['critique'][:80]}...")

print(f"\n{'='*80}")
print("✅ vLLM teacher validation working!")

## 6️⃣ PPO Trainer Setup

Proximal Policy Optimization (PPO) configuration for the student: wrap the LoRA model with a value head and build the TRL `PPOTrainer`.

In [None]:
import unsloth
from trl import PPOConfig, PPOTrainer, AutoModelForCausalLMWithValueHead
from datasets import Dataset

ppo_config = PPOConfig(
    output_dir="outputs_marooned_rl",
    learning_rate=1e-5,
    batch_size=4,
    mini_batch_size=1,
    gradient_accumulation_steps=4,
    seed=42,
    num_ppo_epochs=4,
    kl_coef=0.2,
    vf_coef=0.1,
    cliprange=0.2,
    gamma=0.99,
    lam=0.95,
    temperature=0.3,
    response_length=256,
)

# Wrap student model with value head
model_with_value = AutoModelForCausalLMWithValueHead.from_pretrained(student_model)

# Compatibility patches
base_model = model_with_value.pretrained_model
if not hasattr(model_with_value, "base_model_prefix"):
    model_with_value.base_model_prefix = getattr(base_model, "base_model_prefix", "model")
setattr(model_with_value, model_with_value.base_model_prefix, base_model)
if not hasattr(model_with_value, "config"):
    model_with_value.config = base_model.config
if not hasattr(model_with_value, "generation_config"):
    model_with_value.generation_config = base_model.generation_config
if hasattr(base_model, "is_gradient_checkpointing"):
    model_with_value.is_gradient_checkpointing = base_model.is_gradient_checkpointing
else:
    model_with_value.is_gradient_checkpointing = False

# Minimal train dataset
train_dataset = Dataset.from_dict({
    "prompt": ["stub"],
    "response": ["stub"],
    "reward": [0.0],
})

ppo_trainer = PPOTrainer(
    args=ppo_config,
    model=model_with_value,
    ref_model=None,
    reward_model=model_with_value,
    value_model=model_with_value,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)

print("✅ PPO Trainer initialized")
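The `gamma`, `lam`, and `cliprange` values above parameterize two standard PPO quantities: Generalized Advantage Estimation (GAE) and the clipped surrogate objective. The scalar sketch below shows the textbook formulas. It is not TRL's implementation, which vectorizes these, works per token, and adds value-function clipping and a KL penalty (`kl_coef`).

```python
import math


def gae_advantages(rewards, values, last_value=0.0, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation over one trajectory.

    values[t] is the critic's estimate V(s_t); last_value is V of the state after
    the final step (0.0 if the episode terminated there).
    """
    advantages = [0.0] * len(rewards)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t + 1 < len(rewards) else last_value
        delta = rewards[t] + gamma * next_value - values[t]  # one-step TD error
        gae = delta + gamma * lam * gae
        advantages[t] = gae
    return advantages


def ppo_clipped_objective(logp_new, logp_old, advantage, cliprange=0.2):
    """Per-sample PPO surrogate (to be maximized): min(r*A, clip(r, 1-eps, 1+eps)*A)."""
    ratio = math.exp(logp_new - logp_old)
    clipped_ratio = min(max(ratio, 1.0 - cliprange), 1.0 + cliprange)
    return min(ratio * advantage, clipped_ratio * advantage)


# A reward of 1.0 on the last step flows back with weight gamma * lam = 0.9405.
print(gae_advantages([0.0, 1.0], [0.0, 0.0]))
# A probability ratio of 1.5 with positive advantage is capped at 1 + cliprange.
print(ppo_clipped_objective(math.log(1.5), 0.0, advantage=1.0))
```

With `cliprange=0.2`, a single update gains nothing from moving a token's probability more than 20% away from the rollout policy in the direction the advantage favors.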


## 🔟 Training Loop with vLLM Teacher

**Process Reward Modeling Flow:**
1. Student generates action
2. **Teacher validates via vLLM API** (fast, parallel)
3. Environment executes corrected action
4. Student receives: env_reward + process_penalty + critique

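**Why vLLM for the teacher:**
- Roughly 10-20x faster generation than in-process transformers
- Continuous batching of concurrent requests (this loop validates one sailor at a time, so requests arrive sequentially)
- Prefix caching, which can reuse the shared teacher system prompt across requests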
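In the loop below, each teacher call finishes before the next sailor acts, so the server sees one request at a time. Within one episode that ordering is inherent, because each `env.step` produces the observations the next sailor acts on. Responses from independent episodes, however, could be validated together. The following is a minimal thread-pool sketch. It assumes `teacher_validate_student_output` is safe to call from several threads (it is HTTP-bound); `validate_concurrently` is hypothetical.

```python
from concurrent.futures import ThreadPoolExecutor


def validate_concurrently(validate_fn, jobs, max_workers=8):
    """Call validate_fn(*job) for every job using a thread pool.

    Threads suit I/O-bound work such as HTTP requests to the teacher;
    pool.map returns results in the order of `jobs`, not completion order.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(lambda job: validate_fn(*job), jobs))


# Hypothetical usage with (response, observation, sailor_id) triples gathered
# from several independent episodes:
# results = validate_concurrently(teacher_validate_student_output, pending_jobs)
```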

In [None]:
import torch, time, numpy as np
from llm_interface import get_system_prompt, observation_to_prompt

NUM_TRAINING_STEPS = 100
EPISODE_MAX_SEQ_LENGTH = 16384

def generate_episode_with_teacher(max_turns=100, verbose=False):
    """
    Play episode with teacher guidance:
    - Student generates action
    - Teacher validates and corrects
    - Environment executes corrected action
    - Student receives combined reward
    """
    env = MaroonedEnv(render_mode="ansi")
    observations = env.reset()
    sailor_ids = list(env.agents)
    
    query_tensors, response_tensors, rewards_list = [], [], []
    
    # Enable inference mode
    FastLanguageModel.for_inference(student_model)
    
    for turn in range(max_turns):
        for sailor_id in sailor_ids:
            if not env.state.sailors[sailor_id].alive:
                continue
            
            obs = observations[sailor_id]
            role = env.state.sailors[sailor_id].role.value
            
            # === STEP 1: Student generates response ===
            system_prompt = get_system_prompt(role)
            user_prompt = observation_to_prompt(obs)
            
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ]
            
            text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=EPISODE_MAX_SEQ_LENGTH).to("cuda")
            query_tensor = inputs["input_ids"][0]
            
            with torch.no_grad():
                outputs = student_model.generate(
                    **inputs,
                    max_new_tokens=256,
                    temperature=0.3,
                    do_sample=True,
                    top_p=0.9,
                    top_k=40,
                    repetition_penalty=1.2,
                    pad_token_id=tokenizer.eos_token_id,
                )
                response_tensor = outputs[0]
            
            student_response = tokenizer.decode(response_tensor[len(query_tensor):], skip_special_tokens=True).strip()
            
            # === STEP 2: Teacher validates and corrects ===
            teacher_result = teacher_validate_student_output(student_response, obs, sailor_id)
            action = teacher_result["action"]
            process_penalty = teacher_result["penalty"]
            
            # === STEP 3: Environment executes corrected action ===
            actions_dict = {sid: Action(sailor_id=sid, action_type=ActionType.WAIT) for sid in env.agents}
            actions_dict[sailor_id] = action
            
            observations, rewards_dict, dones, truncated, info = env.step(actions_dict)
            env_reward = rewards_dict[sailor_id]
            
            # === STEP 4: Combined reward ===
            total_reward = env_reward + process_penalty
            
            # Store experience
            query_tensors.append(query_tensor)
            response_tensors.append(response_tensor[len(query_tensor):])
            rewards_list.append(torch.tensor(total_reward, dtype=torch.float32))
            
            if verbose and turn % 20 == 0:
                print(f"Turn {turn:03d} | {sailor_id}: {action.action_type.value:<12} | "
                      f"Env={env_reward:+.1f} Penalty={process_penalty:+.1f} Total={total_reward:+.1f}")
            
            if dones[sailor_id]:
                return query_tensors, response_tensors, rewards_list
    
    return query_tensors, response_tensors, rewards_list


# === TRAINING LOOP ===
print("🚀 Starting Process Reward Modeling Training\n")

stats_rewards = []
stats_parse_success = []
stats_corrections = []

for step in range(NUM_TRAINING_STEPS):
    start_time = time.time()
    batch_queries, batch_responses, batch_rewards = [], [], []
    
    # Generate episodes
    for _ in range(ppo_config.batch_size):
        queries, responses, rewards = generate_episode_with_teacher(
            max_turns=100,
            verbose=(step % 50 == 0 and _ == 0)
        )
        batch_queries.extend(queries)
        batch_responses.extend(responses)
        batch_rewards.extend(rewards)
    
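    # PPO step. Switch Unsloth back to training mode first: the rollouts above
    # called FastLanguageModel.for_inference(student_model).
    # NOTE: .step(queries, responses, rewards) is the legacy TRL (<0.12) PPOTrainer
    # API; the PPOTrainer in trl 0.22.2 trains via .train() and has no .step().
    FastLanguageModel.for_training(student_model)
    stats = ppo_trainer.step(batch_queries, batch_responses, batch_rewards)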
    
    # Track metrics
    episode_reward = sum([r.item() for r in batch_rewards])
    stats_rewards.append(episode_reward)
    
    elapsed = time.time() - start_time
    avg_reward = np.mean(stats_rewards[-10:]) if len(stats_rewards) >= 10 else np.mean(stats_rewards)
    
    print(f"Step {step+1}/{NUM_TRAINING_STEPS} | "
          f"Reward: {episode_reward:+.1f} | "
          f"Avg(10): {avg_reward:+.1f} | "
          f"Turns: {len(batch_rewards)} | "
          f"Time: {elapsed:.1f}s")
    
    # Checkpoint
    if (step + 1) % 50 == 0:
        checkpoint_path = f"outputs_marooned_rl/checkpoint_step{step+1}"
        ppo_trainer.save_model(checkpoint_path)
        print(f"   💾 Checkpoint saved → {checkpoint_path}")

print("\n✅ Training complete!")

🚀 Starting PPO training...


🎮 Starting episode...
   - Alice: TRAITOR
   - Bob: HONEST
   - Charlie: HONEST
   - Diana: HONEST
   - Eve: HONEST

   📏 Prompt length: 8912 tokens (max: 16384)



   📝 Sample response for Alice:
      REASONING: As the traitor, my goal is to maintain secrecy and gain an upper hand in eliminating potential threats. Since there are four other sailors who could potentially discover my true identity, I...
   Turn 000 | Alice: move_east  | Reward = -0.01

   📝 Sample response for Bob:
      REASONING: I should move towards the nearby woods to collect more resources and contribute to our shared goal of escaping the island as soon as possible. The visible resources include several types su...
   Turn 000 | Bob: move_north | Reward = +0.04

   📝 Sample response for Charlie:
      REASONING: I need to start gathering resources as soon as possible since our ship progress 

## 8️⃣ Training Loop with Teacher Guidance

**Process Reward Modeling Flow:**
1. Student generates action
2. Teacher validates via vLLM API & corrects
3. Environment executes corrected action
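4. Student receives: env_reward + process_penalty

This version also records correction examples via `add_correction_example`, which is defined in **7️⃣ Correction Dataset for SFT**. Run that cell before this one, or the loop raises a `NameError`.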

In [None]:
import torch, time, numpy as np

NUM_TRAINING_STEPS = 100
EPISODE_MAX_SEQ_LENGTH = 16384

def generate_episode_with_teacher(max_turns=100, verbose=False):
    """
    Play episode with teacher guidance (vLLM API):
    - Student generates action
    - Teacher validates via vLLM and corrects
    - Environment executes corrected action
    - Student receives combined reward (env + process penalty)
    """
    env = MaroonedEnv(render_mode="ansi")
    observations = env.reset()
    sailor_ids = list(env.agents)
    
    query_tensors, response_tensors, rewards_list = [], [], []
    
    # Enable inference mode
    FastLanguageModel.for_inference(student_model)
    
    for turn in range(max_turns):
        for sailor_id in sailor_ids:
            if not env.state.sailors[sailor_id].alive:
                continue
            
            obs = observations[sailor_id]
            role = env.state.sailors[sailor_id].role.value
            
            # === STEP 1: Student generates response ===
            system_prompt = get_system_prompt(role)
            user_prompt = observation_to_prompt(obs)
            
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ]
            
            text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=EPISODE_MAX_SEQ_LENGTH).to("cuda")
            query_tensor = inputs["input_ids"][0]
            
            with torch.no_grad():
                outputs = student_model.generate(
                    **inputs,
                    max_new_tokens=256,
                    temperature=0.3,
                    do_sample=True,
                    top_p=0.9,
                    top_k=40,
                    repetition_penalty=1.2,
                    pad_token_id=tokenizer.eos_token_id,
                )
                response_tensor = outputs[0]
            
            student_response = tokenizer.decode(response_tensor[len(query_tensor):], skip_special_tokens=True).strip()
            
            # === STEP 2: Teacher validates via vLLM API ===
            teacher_result = teacher_validate_student_output(student_response, obs, sailor_id)
            action = teacher_result["action"]
            process_penalty = teacher_result["penalty"]
            
            # === STEP 2.5: Collect correction examples for SFT ===
            add_correction_example(student_response, teacher_result, obs)
            
            # === STEP 3: Environment executes corrected action ===
            actions_dict = {sid: Action(sailor_id=sid, action_type=ActionType.WAIT) for sid in env.agents}
            actions_dict[sailor_id] = action
            
            observations, rewards_dict, dones, truncated, info = env.step(actions_dict)
            env_reward = rewards_dict[sailor_id]
            
            # === STEP 4: Combined reward ===
            total_reward = env_reward + process_penalty
            
            # Store experience
            query_tensors.append(query_tensor)
            response_tensors.append(response_tensor[len(query_tensor):])
            rewards_list.append(torch.tensor(total_reward, dtype=torch.float32))
            
            if verbose and turn % 20 == 0:
                print(f"Turn {turn:03d} | {sailor_id}: {action.action_type.value:<12} | "
                      f"Env={env_reward:+.1f} Penalty={process_penalty:+.1f} Total={total_reward:+.1f}")
            
            if dones[sailor_id]:
                return query_tensors, response_tensors, rewards_list
    
    return query_tensors, response_tensors, rewards_list


# === TRAINING LOOP ===
print("🚀 Starting Process Reward Modeling Training\n")

stats_rewards = []

for step in range(NUM_TRAINING_STEPS):
    start_time = time.time()
    batch_queries, batch_responses, batch_rewards = [], [], []
    
    # Generate episodes
    for _ in range(ppo_config.batch_size):
        queries, responses, rewards = generate_episode_with_teacher(
            max_turns=100,
            verbose=(step % 50 == 0 and _ == 0)
        )
        batch_queries.extend(queries)
        batch_responses.extend(responses)
        batch_rewards.extend(rewards)
    
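    # PPO step. Switch Unsloth back to training mode first: the rollouts above
    # called FastLanguageModel.for_inference(student_model).
    # NOTE: .step(queries, responses, rewards) is the legacy TRL (<0.12) PPOTrainer
    # API; the PPOTrainer in trl 0.22.2 trains via .train() and has no .step().
    FastLanguageModel.for_training(student_model)
    stats = ppo_trainer.step(batch_queries, batch_responses, batch_rewards)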
    
    # Track metrics
    episode_reward = sum([r.item() for r in batch_rewards])
    stats_rewards.append(episode_reward)
    
    elapsed = time.time() - start_time
    avg_reward = np.mean(stats_rewards[-10:]) if len(stats_rewards) >= 10 else np.mean(stats_rewards)
    
    print(f"Step {step+1}/{NUM_TRAINING_STEPS} | "
          f"Reward: {episode_reward:+.1f} | "
          f"Avg(10): {avg_reward:+.1f} | "
          f"Turns: {len(batch_rewards)} | "
          f"Time: {elapsed:.1f}s")
    
    # Checkpoint
    if (step + 1) % 50 == 0:
        checkpoint_path = f"outputs_marooned_rl/checkpoint_step{step+1}"
        ppo_trainer.save_model(checkpoint_path)
        print(f"   💾 Checkpoint saved → {checkpoint_path}")

print("\n✅ Training complete!")

## 7️⃣ Correction Dataset for SFT

**Option 2: Supervised Fine-Tuning Pass**

We collect student errors and teacher corrections during training:
- Invalid student outputs → Teacher corrections
- Build correction dataset: `(student_wrong, teacher_correct + critique)`
- Periodically run SFT to teach format directly

The aim is to teach the output format faster than RL penalties alone would.

In [None]:
# Correction dataset storage
correction_dataset = []

def add_correction_example(student_response: str, teacher_result: dict, observation: Observation):
    """
    Store correction examples for SFT training.
    
    Format:
    - Input: Student's wrong output + observation context
    - Output: Teacher's corrected action + critique explanation
    """
    if not teacher_result["valid"]:  # Only store corrections for invalid outputs
        # Render the corrected action in the student's "ACTION: VERB args" format.
        # ActionType values are lowercase (e.g. "move_north"), so normalize to uppercase.
        action = teacher_result["action"]
        action_type_str = action.action_type.value.upper()
        action_str = action_type_str
        
        # Add parameters based on action type
        if action_type_str.startswith("MOVE_"):
            # e.g. "MOVE_NORTH" -> "MOVE NORTH"
            action_str = "MOVE " + action_type_str.split("_", 1)[1]
        elif action.target_resource_id:
            action_str = f"GATHER {action.target_resource_id}"
        elif action.resource_type and action.quantity:
            action_str = f"DEPOSIT {action.resource_type.value} {action.quantity}"
        elif action.ship_component:
            # Both BUILD and SABOTAGE target a ship component
            verb = "SABOTAGE" if "SABOTAGE" in action_type_str else "BUILD"
            action_str = f"{verb} {action.ship_component.value}"
        elif action.target_sailor:
            if action.action_type == ActionType.OFFER_FOOD:
                action_str = f"POISON {action.target_sailor}"
            else:
                action_str = f"{action_type_str} {action.target_sailor}"
        
        # Build correction example
        correction = {
            "input": student_response,  # What student said (wrong)
            "output": f"REASONING: {teacher_result['critique']}\nACTION: {action_str}",  # Correct format + explanation
            "penalty": teacher_result["penalty"],
            "critique": teacher_result["critique"]
        }
        
        correction_dataset.append(correction)

print("✅ Correction dataset collector initialized")
print("   Format: (student_wrong_output) → (teacher_correct_action + critique)")
print("   Usage: Collect errors during RL, then run SFT pass")
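The hybrid loop plan clears the in-memory dataset after each SFT pass, so the raw corrections are lost unless saved. Appending them to a JSONL file keeps a record that can be inspected or reused later. This is a minimal sketch; `save_corrections_jsonl`, `load_corrections_jsonl`, and the file path are hypothetical.

```python
import json
from pathlib import Path


def save_corrections_jsonl(examples, path="outputs_marooned_rl/corrections.jsonl"):
    """Append correction dicts (input, output, penalty, critique) as one JSON object per line."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("a", encoding="utf-8") as f:
        for example in examples:
            f.write(json.dumps(example, ensure_ascii=False) + "\n")


def load_corrections_jsonl(path="outputs_marooned_rl/corrections.jsonl"):
    """Read back every stored correction, skipping blank lines."""
    with Path(path).open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
```

Usage: call `save_corrections_jsonl(correction_dataset)` just before clearing the list. If penalties come back as NumPy scalars, convert them with `float(...)` first, since `json` cannot serialize them.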

In [None]:
from trl import SFTTrainer, SFTConfig
from datasets import Dataset

def run_sft_correction_pass(correction_examples: list, num_epochs: int = 1):
    """
    Run supervised fine-tuning on collected correction examples.
    
    This teaches the student model the correct format directly:
    - Input: Student's wrong output
    - Target: Teacher's corrected action + critique
    
    Args:
        correction_examples: List of {input, output, penalty, critique} dicts
        num_epochs: Number of SFT epochs (default 1)
    
    Returns:
        SFT trainer statistics
    """
    if len(correction_examples) == 0:
        print("⚠️  No corrections to train on")
        return
    
    print(f"\n{'='*80}")
    print(f"🎓 Running SFT Correction Pass")
    print(f"{'='*80}")
    print(f"   Corrections collected: {len(correction_examples)}")
    print(f"   SFT epochs: {num_epochs}")
    
    # Convert to dataset format
    sft_data = []
    for example in correction_examples:
        # Create chat format
        messages = [
            {"role": "user", "content": f"Fix this invalid action:\n{example['input']}"},
            {"role": "assistant", "content": example['output']}
        ]
        text = tokenizer.apply_chat_template(messages, tokenize=False)
        sft_data.append({"text": text})
    
    sft_dataset = Dataset.from_list(sft_data)
    
    # Configure SFT
    sft_config = SFTConfig(
        output_dir="outputs_marooned_rl/sft_corrections",
        num_train_epochs=num_epochs,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=2,
        learning_rate=1e-5,
        logging_steps=10,
        save_steps=100,
        max_seq_length=2048,
        packing=False,
    )
    
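    # Re-enable training mode: the rollout loop calls FastLanguageModel.for_inference
    FastLanguageModel.for_training(student_model)
    
    # Run SFT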
    sft_trainer = SFTTrainer(
        model=student_model,
        args=sft_config,
        train_dataset=sft_dataset,
        tokenizer=tokenizer,
    )
    
    sft_result = sft_trainer.train()
    
    print(f"\n✅ SFT correction pass complete!")
    print(f"   Examples trained: {len(correction_examples)}")
    print(f"   Final loss: {sft_result.training_loss:.4f}")
    print(f"{'='*80}\n")
    
    return sft_result

print("✅ SFT correction trainer defined")
print("   Usage: run_sft_correction_pass(correction_dataset, num_epochs=1)")
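The conversion loop in `run_sft_correction_pass` turns each correction dict into one `{"text": ...}` record. The sketch below reproduces that step without a GPU or tokenizer so the record shape can be inspected. `format_chat` is a hypothetical stand-in for `tokenizer.apply_chat_template(..., tokenize=False)`, and the example correction (including the `ACTION: MOVE NORTH` string) is made up; only the field names come from the docstring.

```python
# Hypothetical stand-in for tokenizer.apply_chat_template(..., tokenize=False):
# it just joins role-tagged turns so the sketch runs without downloading a tokenizer.
def format_chat(messages):
    return "".join(f"<|{m['role']}|>\n{m['content']}\n" for m in messages)


def corrections_to_sft_records(correction_examples):
    """Mirror the conversion loop: one {"text": ...} record per correction example."""
    records = []
    for example in correction_examples:
        messages = [
            {"role": "user", "content": f"Fix this invalid action:\n{example['input']}"},
            {"role": "assistant", "content": example["output"]},
        ]
        records.append({"text": format_chat(messages)})
    return records


# Made-up correction example (field names follow the docstring: input, output, penalty, critique)
examples = [{
    "input": "I think I should go north maybe",
    "output": "ACTION: MOVE NORTH",
    "penalty": -1.0,
    "critique": "Missing ACTION: prefix.",
}]
records = corrections_to_sft_records(examples)
print(records[0]["text"])
```

Note that only `input` and `output` reach the training text; `penalty` and `critique` are used for SFT only if the collector has already folded them into `output`.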

## 8️⃣ Hybrid RL + SFT Training Loop

**Training Flow:**
1. **RL Phase (PPO)**: Student learns via rewards (env + process penalties)
2. **Collect Corrections**: Store student errors → teacher corrections
3. **SFT Phase (every `SFT_INTERVAL` = 25 steps)**: Directly teach the correct format via supervised learning, provided at least 10 corrections have been collected
4. **Repeat**: Clear the correction dataset, continue RL

This hybrid approach combines:
- **RL**: Strategic decision-making, long-term planning
- **SFT**: Fast format learning, explicit error correction
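The schedule described above can be simulated in plain Python to see exactly when SFT passes and checkpoints fire. This is a sketch under assumptions: `hybrid_schedule` and `corrections_per_step` are hypothetical names, and the number of corrections collected per step is an input you choose rather than something measured. It follows the loop's rules: SFT runs at every multiple of the interval only if at least 10 corrections are buffered, the buffer is cleared after a pass, and checkpoints are saved every 50 steps.

```python
def hybrid_schedule(num_steps, sft_interval=25, checkpoint_interval=50,
                    corrections_per_step=None, min_corrections=10):
    """Return the 1-based steps at which SFT passes and checkpoints would run.

    corrections_per_step: hypothetical count of new corrections collected at each
    step (defaults to 1 per step).
    """
    if corrections_per_step is None:
        corrections_per_step = [1] * num_steps
    buffered = 0
    sft_steps, checkpoint_steps = [], []
    for step in range(1, num_steps + 1):
        buffered += corrections_per_step[step - 1]
        if step % sft_interval == 0 and buffered >= min_corrections:
            sft_steps.append(step)
            buffered = 0  # mirrors correction_dataset.clear()
        if step % checkpoint_interval == 0:
            checkpoint_steps.append(step)
    return sft_steps, checkpoint_steps


print(hybrid_schedule(100))
# No corrections during the first 30 steps: the pass at step 25 is skipped
print(hybrid_schedule(100, corrections_per_step=[0] * 30 + [1] * 70))
```

A skipped pass does not clear the buffer, so its corrections roll over into the next interval.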

In [None]:
import torch, time, numpy as np

NUM_TRAINING_STEPS = 100
EPISODE_MAX_SEQ_LENGTH = 16384
SFT_INTERVAL = 25  # Run SFT every N steps

def generate_episode_with_teacher(max_turns=100, verbose=False):
    """
    Play episode with teacher guidance (vLLM API):
    - Student generates action
    - Teacher validates via vLLM and corrects
    - Collect corrections for SFT
    - Environment executes corrected action
    - Student receives combined reward (env + process penalty)
    """
    env = MaroonedEnv(render_mode="ansi")
    observations = env.reset()
    sailor_ids = list(env.agents)
    
    query_tensors, response_tensors, rewards_list = [], [], []
    
    # Enable inference mode
    FastLanguageModel.for_inference(student_model)
    
    for turn in range(max_turns):
        for sailor_id in sailor_ids:
            if not env.state.sailors[sailor_id].alive:
                continue
            
            obs = observations[sailor_id]
            role = env.state.sailors[sailor_id].role.value
            
            # === STEP 1: Student generates response ===
            system_prompt = get_system_prompt(role)
            user_prompt = observation_to_prompt(obs)
            
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ]
            
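            text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            # The chat template already inserts <|begin_of_text|>; don't let the tokenizer add a second BOS
            inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False,
                               truncation=True, max_length=EPISODE_MAX_SEQ_LENGTH).to("cuda")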
            query_tensor = inputs["input_ids"][0]
            
            with torch.no_grad():
                outputs = student_model.generate(
                    **inputs,
                    max_new_tokens=256,
                    temperature=0.3,
                    do_sample=True,
                    top_p=0.9,
                    top_k=40,
                    repetition_penalty=1.2,
                    pad_token_id=tokenizer.eos_token_id,
                )
                response_tensor = outputs[0]
            
            student_response = tokenizer.decode(response_tensor[len(query_tensor):], skip_special_tokens=True).strip()
            
            # === STEP 2: Teacher validates via vLLM API ===
            teacher_result = teacher_validate_student_output(student_response, obs, sailor_id)
            action = teacher_result["action"]
            process_penalty = teacher_result["penalty"]
            
            # === STEP 2.5: Collect correction examples for SFT ===
            add_correction_example(student_response, teacher_result, obs)
            
            # === STEP 3: Environment executes corrected action ===
            actions_dict = {sid: Action(sailor_id=sid, action_type=ActionType.WAIT) for sid in env.agents}
            actions_dict[sailor_id] = action
            
            observations, rewards_dict, dones, truncated, info = env.step(actions_dict)
            env_reward = rewards_dict[sailor_id]
            
            # === STEP 4: Combined reward ===
            total_reward = env_reward + process_penalty
            
            # Store experience
            query_tensors.append(query_tensor)
            response_tensors.append(response_tensor[len(query_tensor):])
            rewards_list.append(torch.tensor(total_reward, dtype=torch.float32))
            
            if verbose and turn % 20 == 0:
                print(f"Turn {turn:03d} | {sailor_id}: {action.action_type.value:<12} | "
                      f"Env={env_reward:+.1f} Penalty={process_penalty:+.1f} Total={total_reward:+.1f}")
            
            if dones[sailor_id]:
                return query_tensors, response_tensors, rewards_list
    
    return query_tensors, response_tensors, rewards_list


# === HYBRID RL + SFT TRAINING LOOP ===
print("🚀 Starting Hybrid RL + SFT Training")
print(f"   Total steps: {NUM_TRAINING_STEPS}")
print(f"   SFT interval: Every {SFT_INTERVAL} steps")
print(f"   Strategy: PPO for strategy, SFT for format learning\n")

stats_rewards = []
stats_sft_runs = 0

for step in range(NUM_TRAINING_STEPS):
    start_time = time.time()
    batch_queries, batch_responses, batch_rewards = [], [], []
    
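    # === RL PHASE: Roll out episodes with teacher guidance ===
    for episode_idx in range(ppo_config.batch_size):
        queries, responses, rewards = generate_episode_with_teacher(
            max_turns=100,
            verbose=(step % 50 == 0 and episode_idx == 0)
        )
        batch_queries.extend(queries)
        batch_responses.extend(responses)
        batch_rewards.extend(rewards)
    
    # PPO update: switch the Unsloth model back to training mode, then step on full
    # batches only (legacy TRL PPOTrainer.step requires exactly ppo_config.batch_size
    # samples per call, while each episode yields one sample per sailor turn)
    FastLanguageModel.for_training(student_model)
    bs = ppo_config.batch_size
    for i in range(0, len(batch_rewards) - bs + 1, bs):
        stats = ppo_trainer.step(batch_queries[i:i + bs], batch_responses[i:i + bs], batch_rewards[i:i + bs])
    
    # Track metrics: mean total reward per episode at this step
    episode_reward = sum(r.item() for r in batch_rewards) / ppo_config.batch_size
    stats_rewards.append(episode_reward)
    
    elapsed = time.time() - start_time
    avg_reward = np.mean(stats_rewards[-10:])  # slicing already handles fewer than 10 entries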
    
    print(f"Step {step+1:03d}/{NUM_TRAINING_STEPS} | "
          f"Reward: {episode_reward:+6.1f} | "
          f"Avg(10): {avg_reward:+6.1f} | "
          f"Corrections: {len(correction_dataset):4d} | "
          f"Time: {elapsed:4.1f}s")
    
    # === SFT PHASE: Periodic correction pass ===
    if (step + 1) % SFT_INTERVAL == 0 and len(correction_dataset) >= 10:
        print(f"\n{'─'*80}")
        print(f"🎓 SFT CORRECTION PASS #{stats_sft_runs + 1}")
        print(f"{'─'*80}")
        run_sft_correction_pass(correction_dataset, num_epochs=1)
        stats_sft_runs += 1
        # Clear dataset after training
        correction_dataset.clear()
        print(f"{'─'*80}\n")
    
    # Checkpoint
    if (step + 1) % 50 == 0:
        checkpoint_path = f"outputs_marooned_rl/checkpoint_step{step+1}"
        ppo_trainer.save_pretrained(checkpoint_path)
        print(f"   💾 Checkpoint saved → {checkpoint_path}")

print("\n" + "="*80)
print("✅ TRAINING COMPLETE!")
print("="*80)
print(f"   Total steps: {NUM_TRAINING_STEPS}")
print(f"   SFT passes run: {stats_sft_runs}")
print(f"   Average reward: {np.mean(stats_rewards):.2f}")
print(f"   Final reward (last 10): {np.mean(stats_rewards[-10:]):.2f}")
print("="*80)
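The control flow of `generate_episode_with_teacher` can be exercised without a GPU by swapping the student and teacher for plain callables and `MaroonedEnv` for a tiny stub. Everything here is hypothetical: `StubEnv`, `StubSailor`, `rollout`, the fixed +1.0 env reward and the `-0.5` penalty are illustrative, and the stub keeps sailors in `env.sailors` instead of `env.state.sailors`. The loop structure (skip dead sailors, others WAIT, reward = env reward + process penalty, return on the acting sailor's done flag) follows the function above.

```python
from dataclasses import dataclass


@dataclass
class StubSailor:
    alive: bool = True


class StubEnv:
    """Hypothetical stand-in for MaroonedEnv: every sailor gets +1.0 per env step,
    and the episode ends after `horizon` env steps."""

    def __init__(self, agents, horizon):
        self.agents = list(agents)
        self.sailors = {a: StubSailor() for a in self.agents}
        self.horizon = horizon
        self.t = 0

    def reset(self):
        self.t = 0
        return {a: f"obs-{a}-{self.t}" for a in self.agents}

    def step(self, actions):
        self.t += 1
        obs = {a: f"obs-{a}-{self.t}" for a in self.agents}
        rewards = {a: 1.0 for a in self.agents}
        dones = {a: self.t >= self.horizon for a in self.agents}
        return obs, rewards, dones, {}, {}


def rollout(env, policy, teacher, max_turns):
    """Same loop shape as generate_episode_with_teacher, with LLM calls replaced by callables."""
    observations = env.reset()
    records = []  # (sailor_id, raw student output, env reward + process penalty)
    for _turn in range(max_turns):
        for sid in env.agents:
            if not env.sailors[sid].alive:
                continue
            raw = policy(observations[sid])          # student generation
            action, penalty = teacher(raw)           # teacher validation / correction
            actions = {other: "WAIT" for other in env.agents}
            actions[sid] = action                    # only the acting sailor moves
            observations, rewards, dones, _, _ = env.step(actions)
            records.append((sid, raw, rewards[sid] + penalty))
            if dones[sid]:
                return records
    return records


def policy(obs):
    # Made-up student: sailor A emits a valid action, sailor B emits garbage
    return "MOVE" if obs.startswith("obs-A") else "???"


def teacher(raw):
    # Made-up teacher: keep valid actions, replace anything else with WAIT plus a penalty
    return (raw, 0.0) if raw == "MOVE" else ("WAIT", -0.5)


env = StubEnv(["A", "B"], horizon=3)
records = rollout(env, policy, teacher, max_turns=100)
for rec in records:
    print(rec)
```

As written, each sailor's action is a separate `env.step` with every other sailor set to WAIT, so the environment advances once per living sailor per turn.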

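To our understanding, the legacy TRL `PPOTrainer.step` (TRL 0.11 and earlier) checks that it receives exactly `config.batch_size` queries, responses and scores, whereas each episode here contributes a variable number of samples (one per sailor turn). One way to reconcile the two is to split the collected turns into full batches and drop the remainder. `chunk_into_batches` is a hypothetical helper and the 11 made-up turns are illustrative.

```python
def chunk_into_batches(queries, responses, rewards, batch_size):
    """Split per-turn experience into batches of exactly batch_size; drop the remainder."""
    if not (len(queries) == len(responses) == len(rewards)):
        raise ValueError("queries, responses and rewards must have the same length")
    batches = []
    for i in range(0, len(rewards) - batch_size + 1, batch_size):
        batches.append((queries[i:i + batch_size],
                        responses[i:i + batch_size],
                        rewards[i:i + batch_size]))
    return batches


# 11 made-up turns collected across several episodes
q = [f"q{i}" for i in range(11)]
r = [f"r{i}" for i in range(11)]
rw = [float(i) for i in range(11)]
batches = chunk_into_batches(q, r, rw, batch_size=4)
print(len(batches))  # 2 full batches; the last 3 turns are dropped
```

Dropping the remainder keeps every update the same size; shuffling turns across episodes before chunking is a possible variation if the dropped turns should not always be the final ones.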
## 9️⃣ Visualize Training Progress

Plot the per-step reward with a 10-step moving average (SFT passes marked as red dashed lines), plus the distribution of step-to-step reward changes.

In [None]:
import matplotlib.pyplot as plt

# Plot rewards over time
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(stats_rewards, alpha=0.3, label='Raw Reward')
# Moving average
window = 10
if len(stats_rewards) >= window:
    ma_rewards = np.convolve(stats_rewards, np.ones(window)/window, mode='valid')
    plt.plot(range(window-1, len(stats_rewards)), ma_rewards, label=f'MA({window})', linewidth=2)
plt.xlabel('Training Step')
plt.ylabel('Episode Reward')
plt.title('Training Progress (RL + SFT)')
plt.legend()
plt.grid(True, alpha=0.3)

# Mark SFT passes
for i in range(stats_sft_runs):
    sft_step = (i + 1) * SFT_INTERVAL
    if sft_step < len(stats_rewards):
        plt.axvline(x=sft_step, color='red', linestyle='--', alpha=0.5, linewidth=1)

plt.subplot(1, 2, 2)
# Distribution of step-to-step reward changes
improvement = np.diff(stats_rewards)
plt.hist(improvement, bins=30, alpha=0.7, edgecolor='black')
plt.xlabel('Reward Change (step to step)')
plt.ylabel('Frequency')
plt.title('Step-to-Step Reward Change Distribution')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('outputs_marooned_rl/training_progress.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\n📊 Training Statistics:")
print(f"   Initial avg reward (first 10): {np.mean(stats_rewards[:10]):.2f}")
print(f"   Final avg reward (last 10): {np.mean(stats_rewards[-10:]):.2f}")
print(f"   Improvement: {np.mean(stats_rewards[-10:]) - np.mean(stats_rewards[:10]):.2f}")
print(f"   SFT passes executed: {stats_sft_runs}")
print(f"\n✅ Plot saved to outputs_marooned_rl/training_progress.png")
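The moving-average curve uses `np.convolve(..., mode='valid')`, which returns `len(values) - window + 1` points; the plot places each average at the last step of its window, hence the x-range starting at `window - 1`. A minimal sketch of that alignment (the helper name `moving_average` and the toy rewards are illustrative):

```python
import numpy as np


def moving_average(values, window):
    """Trailing moving average via np.convolve(mode='valid'), with matching x positions."""
    values = np.asarray(values, dtype=float)
    ma = np.convolve(values, np.ones(window) / window, mode="valid")
    x = np.arange(window - 1, len(values))  # each average sits at the last step of its window
    return x, ma


rewards = [1.0, 2.0, 3.0, 4.0, 5.0]
x, ma = moving_average(rewards, window=2)
print(x.tolist(), ma.tolist())  # [1, 2, 3, 4] [1.5, 2.5, 3.5, 4.5]
```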

## 🔟 Save Trained Model

Save the final student model and tokenizer.

In [None]:
output_dir = "outputs_marooned_rl/final_model"

# Save student model
student_model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

print(f"✅ Model saved to {output_dir}")
print(f"\nTo load the model:")
print(f"```python")
print(f"from unsloth import FastLanguageModel")
print(f"model, tokenizer = FastLanguageModel.from_pretrained('{output_dir}')")
print(f"```")
print(f"\nModel training complete with hybrid RL + SFT approach!")
print(f"   - Process reward modeling (teacher validates student)")
print(f"   - Periodic SFT passes (direct format learning)")
print(f"   - Combined strategy: RL for game strategy, SFT for format correctness")

## 1️⃣1️⃣ Save PPO Policy

Save the PPO-trained policy from `ppo_trainer`, together with the tokenizer.

In [None]:
save_path = "outputs_marooned_rl/final_ppo_model"

ppo_trainer.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)

print(f"✅ Model saved to {save_path}")
print(f"\n🎉 Process Reward Modeling Training Complete!")
print(f"\n📈 Summary:")
print(f"   - Teacher-guided action parsing")
print(f"   - Process penalties for invalid outputs")
print(f"   - Real-time correction and feedback")
print(f"   - {NUM_TRAINING_STEPS} training steps completed")