<div align="center">

# 🏴‍☠️ MAROONED: Hybrid RL + SFT Training

### Process Reward Modeling with Supervised Correction

**OpenEnv Hackathon 2025**

[![OpenEnv](https://img.shields.io/badge/Framework-OpenEnv-blue)](https://github.com/openenv)
[![Llama](https://img.shields.io/badge/Model-Llama_3.1_8B-green)](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct)
[![Hardware](https://img.shields.io/badge/Hardware-AMD_MI300X-red)](https://www.amd.com/en/products/accelerators/instinct/mi300.html)
[![Teacher](https://img.shields.io/badge/Teacher-Mixtral_8x7B-orange)](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)

</div>

---

## 🔬 Training Architecture

**Hybrid Approach: RL (strategy) + SFT (format learning)**

```
┌─────────────────────────────────────────────────┐
│  RL PHASE: PPO Training (Episodes 1-4)          │
│  Student (Llama 3.1 8B) → Teacher (vLLM         │
│  Mixtral-8x7B) → Env → Rewards                  │
│  Collect corrections: wrong → correct           │
└──────────────────┬──────────────────────────────┘
                   │ Every 25 steps
                   ▼
┌─────────────────────────────────────────────────┐
│  SFT PHASE: Supervised Fine-Tuning              │
│  Train on corrections: mimic teacher format     │
│  Clear dataset, continue RL                     │
└─────────────────────────────────────────────────┘
```

**Key Innovation:** Student learns format directly from teacher critiques via periodic SFT passes.

**Teacher Model:** vLLM server running Mixtral-8x7B-Instruct-v0.1 at `localhost:8000`
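The loop's reward is additive. Each action's environment reward is summed with the teacher's process penalty: corrected outputs are penalized (the test cell below shows -0.5 and -1.0), and valid outputs carry no penalty. Below is a minimal sketch of that arithmetic. `TeacherVerdict` is a hypothetical stand-in for the dict that `teacher_validate_student_output` returns.

```python
from dataclasses import dataclass
from typing import Iterable, Tuple


@dataclass
class TeacherVerdict:
    """Hypothetical stand-in for the teacher's result dict ("valid", "penalty")."""
    valid: bool
    penalty: float  # 0.0 when valid; negative (e.g. -0.5, -1.0) when corrected


def step_reward(env_reward: float, verdict: TeacherVerdict) -> float:
    # Outcome reward from the environment plus the teacher's process penalty
    return env_reward + verdict.penalty


def episode_return(steps: Iterable[Tuple[float, TeacherVerdict]]) -> float:
    # Undiscounted sum, matching how the training loop accumulates total_reward
    return sum(step_reward(r, v) for r, v in steps)


steps = [
    (1.0, TeacherVerdict(valid=True, penalty=0.0)),
    (0.0, TeacherVerdict(valid=False, penalty=-0.5)),
    (0.5, TeacherVerdict(valid=False, penalty=-1.0)),
]
print(episode_return(steps))  # 1.0 + (-0.5) + (-0.5) = 0.0
```

Because the penalty is simply added to the environment reward, a strategically good move that needed a correction can still score positively. The correction is also stored for the next SFT pass.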

---

## ⚙️ Prerequisites

**Ensure the vLLM teacher server is running:**

```bash
# Start vLLM server with Mixtral-8x7B-Instruct-v0.1
vllm serve mistralai/Mixtral-8x7B-Instruct-v0.1 \
  --port 8000 \
  --gpu-memory-utilization 0.9 \
  --max-num-batched-tokens 8192 \
  --dtype float16 \
  --tokenizer-mode mistral
```

**Test the model:**
```bash
curl http://localhost:8000/v1/models
```

Expected: a JSON response whose `data` array includes an entry with `id` equal to `mistralai/Mixtral-8x7B-Instruct-v0.1`.
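The same check can be scripted with only the standard library. This is a small sketch that assumes the OpenAI-compatible response shape (`data[*].id`), which the verification cell further down also relies on.

```python
import json
import urllib.request

TEACHER_MODEL = "mistralai/Mixtral-8x7B-Instruct-v0.1"


def served_model_ids(body: str) -> list:
    """Extract model IDs from a /v1/models response body."""
    payload = json.loads(body)
    return [entry["id"] for entry in payload.get("data", [])]


def teacher_is_ready(base_url: str = "http://localhost:8000", timeout: float = 5.0) -> bool:
    """Return True if the vLLM server at base_url lists the teacher model."""
    with urllib.request.urlopen(f"{base_url}/v1/models", timeout=timeout) as resp:
        return TEACHER_MODEL in served_model_ids(resp.read().decode("utf-8"))


sample = '{"object": "list", "data": [{"id": "mistralai/Mixtral-8x7B-Instruct-v0.1", "object": "model"}]}'
print(served_model_ids(sample))
```

`teacher_is_ready()` raises `urllib.error.URLError` when nothing is listening on the port. The notebook's own check catches the equivalent `requests` exception.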

---

## 1️⃣ Install Dependencies

In [None]:
import os, importlib.util
!pip install --upgrade -qqq uv
if importlib.util.find_spec("torch") is None or "COLAB_" in "".join(os.environ.keys()):
    try:
        import numpy
        get_numpy = f"numpy=={numpy.__version__}"  # keep the already-installed NumPy version
    except ImportError:
        get_numpy = "numpy"
    !uv pip install -qqq \
        "torch>=2.8.0" "triton>=3.4.0" {get_numpy} torchvision bitsandbytes "transformers==4.56.2" trackio \
        "unsloth_zoo[base] @ git+https://github.com/unslothai/unsloth-zoo" \
        "unsloth[base] @ git+https://github.com/unslothai/unsloth" \
        git+https://github.com/triton-lang/triton.git@05b2c186c1b6c9a08375389d5efe9cb4c401c075#subdirectory=python/triton_kernels
elif importlib.util.find_spec("unsloth") is None:
    !uv pip install -qqq unsloth trackio
!uv pip install --upgrade --no-deps transformers==4.56.2 tokenizers trl==0.22.2 unsloth unsloth_zoo

print("✅ Dependencies installed")

In [1]:
import torch
print(torch.version.hip)
print(torch.cuda.is_available())
print(torch.cuda.get_device_name(0))


6.4.43484-123eb5128
True
AMD Instinct MI300X VF


## 2️⃣ Load MAROONED Environment

In [6]:
import sys
import json
import random
from typing import Dict, Any, List

# Clear cached modules
modules_to_clear = [m for m in list(sys.modules.keys()) 
                   if 'marooned' in m or m in ['environment', 'config', 'models', 'game_state', 'view_map', 'llm_interface']]
for module in modules_to_clear:
    if module in sys.modules:
        del sys.modules[module]

sys.path.insert(0, '../marooned_env')

from environment import MaroonedEnv
from llm_interface import (
    get_system_prompt,
    observation_to_prompt,
    teacher_validate_student_output,
)
from config import ActionType, ResourceType, MapLevel, ShipComponent
from models import Action, Position, Observation

print("✅ MAROONED environment loaded")
print("✅ Teacher validation API imported")

✅ MAROONED environment loaded
✅ Teacher validation API imported


## 3️⃣ Load Student Model (Llama 3.1 8B with LoRA)

In [7]:
import os
os.environ["UNSLOTH_NO_TQDM"] = "1"

# ROCm settings: set these before importing unsloth/torch. HSA_* variables are read when the
# HIP runtime initializes, so they only take effect if set before the first GPU call in the kernel.
os.environ["PYTORCH_ROCM_ARCH"] = "gfx942"  # MI300X
os.environ["HSA_FORCE_FINE_GRAIN_PCIE"] = "1"
os.environ["ATTN_BACKEND"] = "triton"

from unsloth import FastLanguageModel
import torch

torch.backends.cudnn.benchmark = True

max_seq_length = 16384
lora_rank = 16

student_model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Llama-3.1-8B-bnb-4bit",  # Hugging Face Hub ID or a local path
    load_in_4bit = False,
    dtype = torch.bfloat16,
    max_seq_length = max_seq_length,
    device_map = "auto",
)

# Set chat template for Llama 3.1
if tokenizer.chat_template is None:
    tokenizer.chat_template = "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% set messages = messages[1:] %}{% else %}{% set system_message = '' %}{% endif %}{% if system_message != '' %}{{ '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>' }}{% endif %}{% for message in messages %}{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"

# Add LoRA adapters
student_model = FastLanguageModel.get_peft_model(
    student_model,
    r = lora_rank,
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = lora_rank * 2,
    lora_dropout = 0.0,
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
    use_rslora = True,
)

print(f"✅ Student Model: Llama 3.1 8B (BF16, LoRA rank={lora_rank})")
print(f"   GPU: {torch.cuda.get_device_name(0)}")
print(f"   VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
print(f"✅ Chat template configured for Llama 3.1")

==((====))==  Unsloth 2025.10.11: Fast Llama patching. Transformers: 4.57.1.
   \\   /|    AMD Instinct MI300X VF. Num GPUs = 1. Max memory: 191.688 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+rocm6.4. ROCm Toolkit: 6.4.43484-123eb5128. Triton: 3.5.0
\        /    Bfloat16 = TRUE. FA [Xformers = None. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards: 100%|██████████| 4/4 [00:06<00:00,  1.51s/it]



✅ Student Model: Llama 3.1 8B (BF16, LoRA rank=16)
   GPU: AMD Instinct MI300X VF
   VRAM: 191.7 GB
✅ Chat template configured for Llama 3.1


## 4️⃣ Verify vLLM Teacher Server

In [10]:
import requests

VLLM_API_URL = "http://localhost:8000/v1/chat/completions"
VLLM_MODELS_URL = "http://localhost:8000/v1/models"

print("Checking vLLM teacher server...")
try:
    response = requests.get(VLLM_MODELS_URL, timeout=5)
    if response.status_code == 200:
        models = response.json()
        model_list = models.get('data', [])
        model_names = [m['id'] for m in model_list]
        print(f"✅ vLLM server running!")
        print(f"   Available models: {model_names}")
        
        if 'mistralai/Mixtral-8x7B-Instruct-v0.1' in model_names:
            print(f"   ✅ Mixtral-8x7B model ready for training")
        else:
            print(f"   ⚠️  Mixtral-8x7B model not found!")
            print(f"   Start server with:")
            print(f"   vllm serve mistralai/Mixtral-8x7B-Instruct-v0.1 --port 8000 --gpu-memory-utilization 0.9")
    else:
        print(f"⚠️  Server responded with status {response.status_code}")
except requests.exceptions.RequestException as e:
    print(f"❌ vLLM server not reachable!")
    print(f"   Error: {e}")
    print(f"\n   Start server:")
    print(f"   vllm serve mistralai/Mixtral-8x7B-Instruct-v0.1 \\")
    print(f"     --port 8000 \\")
    print(f"     --gpu-memory-utilization 0.9 \\")
    print(f"     --max-num-batched-tokens 8192 \\")
    print(f"     --dtype float16 \\")
    print(f"     --tokenizer-mode mistral")
    raise SystemExit("Teacher server required for training")


Checking vLLM teacher server...
✅ vLLM server running!
   Available models: ['mistralai/Mixtral-8x7B-Instruct-v0.1']
   ✅ Mixtral-8x7B model ready for training
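To debug the teacher outside `teacher_validate_student_output`, you can send a request directly to `VLLM_API_URL` using the OpenAI-style chat completions schema that vLLM serves. The prompt wording and the `build_teacher_request` / `extract_reply` helpers below are hypothetical. The notebook's actual validation prompt lives in `llm_interface` and is not shown here.

```python
import json

TEACHER_MODEL = "mistralai/Mixtral-8x7B-Instruct-v0.1"


def build_teacher_request(instructions: str, student_output: str, max_tokens: int = 256) -> dict:
    """Hypothetical chat completions body for the teacher.

    Instructions and student output share one user turn so the sketch does not
    depend on the served chat template accepting a system role.
    """
    return {
        "model": TEACHER_MODEL,
        "messages": [
            {"role": "user", "content": f"{instructions}\n\nStudent output:\n{student_output}"},
        ],
        "temperature": 0.0,  # deterministic critiques
        "max_tokens": max_tokens,
    }


def extract_reply(response: dict) -> str:
    """Return the assistant text from a chat completions response."""
    return response["choices"][0]["message"]["content"]


body = build_teacher_request("Check the ACTION line.", "ACTION: MOVING NORTH")
print(json.dumps(body, indent=2))
```

Against the running server, you would pass `requests.post(VLLM_API_URL, json=body, timeout=60).json()` to `extract_reply`.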


## 5️⃣ Test Teacher Validation (vLLM Mixtral)

In [20]:
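# Reload the LLM interface to pick up the latest source changes
import importlib
import llm_interface
importlib.reload(llm_interface)

# Re-import every name used below so none still points at the pre-reload functions
from llm_interface import get_system_prompt, observation_to_prompt, teacher_validate_student_output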

print("✅ LLM interface reloaded")

✅ LLM interface reloaded


In [21]:
from config import MAX_ENERGY

print("="*80)
print("🧪 TESTING TEACHER VALIDATION WITH FULL CONTEXT")
print("="*80)

# Create environment and get real observation
env = MaroonedEnv(render_mode="ansi", seed=42)
observations = env.reset(seed=42)
alice_obs = observations["Alice"]
alice_role = env.state.sailors["Alice"].role.value

print(f"\n📋 Test Setup:")
print(f"   Sailor: Alice")
print(f"   Role: {alice_role.upper()}")
print(f"   Position: {alice_obs.position}")
print(f"   Energy: {alice_obs.energy}/{MAX_ENERGY}")
print(f"   Visible resources: {len(alice_obs.visible_resources) if hasattr(alice_obs, 'visible_resources') else 'N/A'}")

# Get proper system and user prompts
system_prompt = get_system_prompt(alice_role)
user_prompt = observation_to_prompt(alice_obs)

print(f"\n📏 Prompt sizes:")
print(f"   System prompt: {len(system_prompt)} chars")
print(f"   User prompt: {len(user_prompt)} chars")
print(f"   Total context: {len(system_prompt) + len(user_prompt)} chars")

# Test cases - simulating what an untrained student LLM might output
test_cases = [
    {
        "name": "Format Error (MOVING instead of MOVE)",
        "output": "REASONING: I should move northeast to explore\nACTION: MOVING NORTH"
    },
    {
        "name": "Invalid Command (CHECK_STATUS)",
        "output": "REASONING: Let me check my status\nACTION: CHECK_STATUS"
    },
    {
        "name": "Missing Resource ID",
        "output": "REASONING: I see wood nearby, gathering it\nACTION: GATHER wood"
    },
    {
        "name": "Truncated Output",
        "output": "REASONING: As the traitor, I should sabotagin"
    },
    {
        "name": "Correct Format",
        "output": "REASONING: Moving north to explore the area\nACTION: MOVE NORTH"
    }
]

print(f"\n{'='*80}")
print("🔬 RUNNING TEACHER VALIDATION TESTS")
print(f"{'='*80}\n")

for i, test in enumerate(test_cases, 1):
    print(f"Test {i}: {test['name']}")
    print(f"   Student output: {test['output'][:60]}...")
    
    # Call teacher validation with full context
    result = teacher_validate_student_output(
        student_response=test['output'],
        observation=alice_obs,
        sailor_id="Alice"
    )
    
    # Display results
    validity_icon = "✅" if result['valid'] else "❌"
    print(f"   {validity_icon} Valid: {result['valid']}")
    print(f"   🔧 Corrected action: {result['action'].action_type.value}")
    print(f"   💰 Process penalty: {result['penalty']}")
    print(f"   💬 Critique: {result['critique'][:80]}")
    
    # Show what would happen
    if result['valid']:
        print(f"   ✅ Action executes as-is (no penalty)")
    else:
        print(f"   ⚠️  Teacher corrected → student gets penalty {result['penalty']}")
    print()

print("="*80)
print("✅ TEACHER VALIDATION API WORKING!")
print("="*80)
print("\nKey Points:")
print("  • Teacher receives full game context (observation + system prompt)")
print("  • Invalid formats get corrected automatically")
print("  • Process penalties guide student learning")
print("  • Student focuses on strategy, not syntax")


🧪 TESTING TEACHER VALIDATION WITH FULL CONTEXT

📋 Test Setup:
   Sailor: Alice
   Role: TRAITOR
   Position: Position(x=15, y=15, level=<MapLevel.GROUND: 0>)
   Energy: 100/100
   Visible resources: N/A

📏 Prompt sizes:
   System prompt: 9401 chars
   User prompt: 9127 chars
   Total context: 18528 chars

🔬 RUNNING TEACHER VALIDATION TESTS

Test 1: Format Error (MOVING instead of MOVE)
   Student output: REASONING: I should move northeast to explore
ACTION: MOVING...


   ❌ Valid: False
   🔧 Corrected action: move_north
   💰 Process penalty: -0.5
   💬 Critique: Use MOVE NORTH not MOVING NORTH - verb must be MOVE, and you should move east in
   ⚠️  Teacher corrected → student gets penalty -0.5

Test 2: Invalid Command (CHECK_STATUS)
   Student output: REASONING: Let me check my status
ACTION: CHECK_STATUS...
   ❌ Valid: False
   🔧 Corrected action: wait
   💰 Process penalty: -1.0
   💬 Critique: CHECK_STATUS doesn't exist - use WAIT for no-operation. The student's energy is 
   ⚠️  Teacher corrected → student gets penalty -1.0

Test 3: Missing Resource ID
   Student output: REASONING: I see wood nearby, gathering it
ACTION: GATHER wo...
   ❌ Valid: False
   🔧 Corrected action: wait
   💰 Process penalty: -1.0
   💬 Critique: CHECK_STATUS doesn't exist - use WAIT for no-operation. The student's energy is 
   ⚠️  Teacher corrected → student gets penalty -1.0


## 6️⃣ Setup Correction Dataset for SFT

Collect student errors and teacher corrections during RL training.

In [22]:
correction_dataset = []

def add_correction_example(student_response: str, teacher_result: dict, observation: Observation):
    """Store invalid outputs for SFT training."""
    if not teacher_result["valid"]:
        action = teacher_result["action"]
        action_type = action.action_type.value  # lowercase enum value, e.g. "move_north", "wait"
        action_str = action_type.upper()        # default: bare command, e.g. "WAIT"
        
        # Format action string with parameters
        if action.target_position:
            # Compare case-insensitively: "NORTH" never matches "move_north" directly
            for direction in ("NORTH", "SOUTH", "EAST", "WEST"):
                if direction in action_type.upper():
                    action_str = f"MOVE {direction}"
                    break
        elif action.target_resource_id:
            action_str = f"GATHER {action.target_resource_id}"
        elif action.resource_type and action.quantity:
            action_str = f"DEPOSIT {action.resource_type.value} {action.quantity}"
        elif action.ship_component:
            action_str = f"BUILD {action.ship_component.value}"
        elif action.target_sailor:
            if action.action_type == ActionType.OFFER_FOOD:
                action_str = f"POISON {action.target_sailor}"
            else:
                action_str = f"{action_type.upper()} {action.target_sailor}"
        
        correction = {
            "input": student_response,
            "output": f"REASONING: {teacher_result['critique']}\nACTION: {action_str}",
            "penalty": teacher_result["penalty"],
            "critique": teacher_result["critique"]
        }
        
        correction_dataset.append(correction)

print("✅ Correction collector initialized")
print("   Format: (student_wrong) → (teacher_correct + critique)")

✅ Correction collector initialized
   Format: (student_wrong) → (teacher_correct + critique)
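Each stored example pairs the raw student output with a target in the two-field `REASONING: ... / ACTION: ...` format. A small parser for that format helps sanity-check the collected `correction_dataset`, for example by confirming that every target has a non-empty ACTION line. This is a hypothetical local helper, not part of the MAROONED API. It checks structure only, not whether the command is legal in the game.

```python
import re
from typing import Optional, Tuple

# REASONING text (possibly multi-line) followed by a single ACTION line
_FORMAT = re.compile(r"REASONING:\s*(?P<reasoning>.*?)\s*ACTION:[ \t]*(?P<action>[^\n]*)", re.DOTALL)


def split_output(text: str) -> Tuple[Optional[str], Optional[str]]:
    """Return (reasoning, action), or (None, None) if the ACTION field is missing or empty."""
    match = _FORMAT.search(text)
    if match is None or not match.group("action").strip():
        return None, None
    return match.group("reasoning").strip(), match.group("action").strip()


print(split_output("REASONING: Moving north to explore the area\nACTION: MOVE NORTH"))
print(split_output("REASONING: As the traitor, I should sabotagin"))
```

For example, `[ex for ex in correction_dataset if split_output(ex["output"]) == (None, None)]` lists any malformed targets before they reach an SFT pass.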


## 7️⃣ Define SFT Correction Trainer

In [23]:
from trl import SFTTrainer, SFTConfig
from datasets import Dataset

def run_sft_correction_pass(correction_examples: list, num_epochs: int = 1):
    """
    Run supervised fine-tuning on collected corrections.
    Teaches student to mimic teacher's correct format.
    """
    if len(correction_examples) == 0:
        print("⚠️  No corrections to train on")
        return
    
    print(f"\n{'='*80}")
    print(f"🎓 SFT CORRECTION PASS")
    print(f"{'='*80}")
    print(f"   Examples: {len(correction_examples)}")
    print(f"   Epochs: {num_epochs}")
    
    # Convert to chat format
    sft_data = []
    for example in correction_examples:
        messages = [
            {"role": "user", "content": f"Fix this invalid action:\n{example['input']}"},
            {"role": "assistant", "content": example['output']}
        ]
        text = tokenizer.apply_chat_template(messages, tokenize=False)
        sft_data.append({"text": text})
    
    sft_dataset = Dataset.from_list(sft_data)
    
    # SFT configuration (the pinned TRL 0.22 calls the length limit `max_length`)
    sft_config = SFTConfig(
        output_dir="outputs_marooned_rl/sft_corrections",
        num_train_epochs=num_epochs,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=2,
        learning_rate=1e-5,
        logging_steps=10,
        save_steps=100,
        max_length=2048,
        packing=False,
    )
    
    # The episode loop switches the model to inference mode; re-enable training first
    FastLanguageModel.for_training(student_model)
    
    # Train
    sft_trainer = SFTTrainer(
        model=student_model,
        args=sft_config,
        train_dataset=sft_dataset,
        processing_class=tokenizer,
    )
    
    result = sft_trainer.train()
    
    print(f"\n✅ SFT complete! Loss: {result.training_loss:.4f}")
    print(f"{'='*80}\n")
    
    return result

print("✅ SFT trainer defined")

✅ SFT trainer defined


## 8️⃣ Setup PPO Trainer

In [24]:
from trl import PPOConfig, PPOTrainer, AutoModelForCausalLMWithValueHead
import numpy as np

ppo_config = PPOConfig(
    output_dir="outputs_marooned_rl",
    learning_rate=1e-5,
    batch_size=4,
    mini_batch_size=1,
    gradient_accumulation_steps=4,
    seed=42,
    num_ppo_epochs=4,
    kl_coef=0.2,
    vf_coef=0.1,
    cliprange=0.2,
    gamma=0.99,
    lam=0.95,
    temperature=0.3,
    response_length=256,
)

# Wrap student model with value head
model_with_value = AutoModelForCausalLMWithValueHead.from_pretrained(student_model)

# Compatibility patches
base_model = model_with_value.pretrained_model
if not hasattr(model_with_value, "base_model_prefix"):
    model_with_value.base_model_prefix = getattr(base_model, "base_model_prefix", "model")
setattr(model_with_value, model_with_value.base_model_prefix, base_model)
if not hasattr(model_with_value, "config"):
    model_with_value.config = base_model.config
if not hasattr(model_with_value, "generation_config"):
    model_with_value.generation_config = base_model.generation_config
if hasattr(base_model, "is_gradient_checkpointing"):
    model_with_value.is_gradient_checkpointing = base_model.is_gradient_checkpointing
else:
    model_with_value.is_gradient_checkpointing = False

# Placeholder dataset: PPOTrainer requires a train_dataset, but episodes are generated online below
train_dataset = Dataset.from_dict({
    "prompt": ["stub"],
    "response": ["stub"],
    "reward": [0.0],
})

ppo_trainer = PPOTrainer(
    args=ppo_config,
    model=model_with_value,
    ref_model=None,
    reward_model=model_with_value,
    value_model=model_with_value,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)

print("✅ PPO Trainer initialized")



✅ PPO Trainer initialized


## 9️⃣ Hybrid RL + SFT Training Loop

**Training Flow:**
1. **RL Phase:** Student plays episodes, teacher validates, collect corrections
2. **SFT Phase (every 25 steps):** Train on corrections, clear dataset
3. **Repeat:** Continue RL with improved format knowledge
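The loop's trigger logic works as follows. An SFT pass runs every `SFT_INTERVAL` steps, but only once at least 10 corrections have accumulated, and the buffer is cleared afterwards. You can simulate this without a model to see when passes will actually fire. Below is a minimal sketch. `simulate_sft_schedule` is a hypothetical helper, and the per-step correction counts are made-up inputs.

```python
from typing import List


def simulate_sft_schedule(corrections_per_step: List[int], sft_interval: int, min_examples: int = 10) -> List[int]:
    """Return the 1-based training steps at which an SFT correction pass would run."""
    buffer = 0  # corrections collected since the last SFT pass
    sft_steps = []
    for step, new_corrections in enumerate(corrections_per_step, start=1):
        buffer += new_corrections
        if step % sft_interval == 0 and buffer >= min_examples:
            sft_steps.append(step)
            buffer = 0  # dataset is cleared after each pass
        # below the threshold, corrections carry over to the next interval
    return sft_steps


print(simulate_sft_schedule([3] * 10, sft_interval=5))  # [5, 10]
print(simulate_sft_schedule([1] * 10, sft_interval=5))  # [10]
```

With the quick-test settings in the next cell (5 steps, `SFT_INTERVAL = 10`), `step + 1` never reaches 10, so no SFT pass runs at all. To exercise the SFT path, raise `NUM_TRAINING_STEPS` or lower `SFT_INTERVAL`.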

## ⚡ Quick Test Configuration (Reduce Training Load)

**For initial testing, use these reduced parameters:**

In [29]:
# TEMPORARY: Reduce load for testing
# Comment these out once you confirm training works

NUM_TRAINING_STEPS = 5          # Was 100 - test with just 5 steps
EPISODE_MAX_TURNS = 10          # Was 100 - shorter episodes
BATCH_SIZE = 1                  # Was 4 - single episode per step
SFT_INTERVAL = 10               # Was 25 - faster SFT testing

print("⚡ REDUCED CONFIGURATION FOR TESTING:")
print(f"   Training steps: {NUM_TRAINING_STEPS}")
print(f"   Episode max turns: {EPISODE_MAX_TURNS}")
print(f"   Batch size: {BATCH_SIZE}")
print(f"   SFT interval: {SFT_INTERVAL}")
# 5 = assumed sailors acting per turn, 2 = assumed seconds per teacher call
print(f"\n   Estimated time: ~{NUM_TRAINING_STEPS * BATCH_SIZE * EPISODE_MAX_TURNS * 5 * 2 / 60:.1f} minutes")
print(f"   (assumes ~2sec per teacher API call)")
print(f"\n⚠️  Run this cell, then RESTART the training loop cell!")

⚡ REDUCED CONFIGURATION FOR TESTING:
   Training steps: 5
   Episode max turns: 10
   Batch size: 1
   SFT interval: 10

   Estimated time: ~8.3 minutes
   (assumes ~2sec per teacher API call)

⚠️  Run this cell, then RESTART the training loop cell!
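The estimate printed above is plain arithmetic over teacher calls. The sketch below makes the factors explicit. `sailors=5` and `seconds_per_call=2.0` are the notebook's own rough assumptions. Student generation and SFT passes are not counted, and episodes that end early take less time.

```python
def estimate_teacher_minutes(training_steps: int, batch_size: int, max_turns: int,
                             sailors: int = 5, seconds_per_call: float = 2.0) -> float:
    """Rough wall-clock minutes spent waiting on the teacher: one call per sailor action."""
    calls = training_steps * batch_size * max_turns * sailors
    return calls * seconds_per_call / 60


quick = estimate_teacher_minutes(5, 1, 10)      # quick-test settings
full = estimate_teacher_minutes(100, 4, 100)    # default settings
print(f"quick: {quick:.1f} min, full: {full:.0f} min (~{full / 60:.0f} h)")
```

For the defaults, this gives 200,000 teacher calls, roughly 6,667 minutes (about 111 hours) of teacher time alone. Keep this in mind before switching back to the default configuration.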


In [None]:
from IPython.display import clear_output
from config import MAX_ENERGY, MapLevel

def visualize_game_state(env, turn_num, sailor_actions=None, sailor_reasoning=None):
    """
    Display comprehensive game state visualization with map, sailor status, and actions.
    
    Args:
        env: MaroonedEnv instance
        turn_num: Current turn number
        sailor_actions: Dict of sailor_id -> action string
        sailor_reasoning: Dict of sailor_id -> reasoning string
    """
    output = []
    
    # Header
    output.append("="*100)
    output.append(f"🏴‍☠️  TURN {turn_num} | DAY {env.state.current_day} | GAME STATE")
    output.append("="*100)
    
    # Sailor Status Table
    output.append("\n📊 SAILOR STATUS:")
    output.append("─"*100)
    output.append(f"{'Name':<8} | {'Role':<10} | {'HP':<10} | {'Energy':<12} | {'Position':<15} | {'Status':<10}")
    output.append("─"*100)
    
    for sailor_id in env.agents:
        sailor = env.state.sailors[sailor_id]
        role = sailor.role.value
        
        # Health bar
        hp_icon = "💚" if sailor.alive else "💀"
        hp_str = f"{hp_icon} {'ALIVE' if sailor.alive else 'DEAD'}"
        
        # Energy bar (visual representation)
        energy_pct = sailor.energy / MAX_ENERGY
        energy_bars = int(energy_pct * 10)
        energy_visual = "█" * energy_bars + "░" * (10 - energy_bars)
        energy_str = f"{energy_visual} {sailor.energy}/{MAX_ENERGY}"
        
        # Position
        pos = sailor.position
        pos_str = f"({pos.x},{pos.y}) {pos.level.name}"
        
        # Status indicators
        status_parts = []
        if not sailor.alive:
            status_parts.append("DEAD")
        if sailor_id == env.state.traitor_id:
            status_parts.append("🔪TRAITOR")
        if sailor.poisoned:
            status_parts.append("☠️ POISON")
        status_str = " ".join(status_parts) if status_parts else "OK"
        
        output.append(f"{sailor_id:<8} | {role:<10} | {hp_str:<10} | {energy_str:<12} | {pos_str:<15} | {status_str:<10}")
    
    output.append("─"*100)
    
    # Ship Progress
    ship = env.state.ship_progress
    ship_pct = ship.total_percentage
    ship_bars = int(ship_pct / 10)
    ship_visual = "▓" * ship_bars + "░" * (10 - ship_bars)
    output.append(f"\n🚢 SHIP PROGRESS: {ship_visual} {ship_pct:.1f}%")
    output.append(f"   Hull: {ship.hull}/{ship.hull_max} | Sail: {ship.sail}/{ship.sail_max} | Engine: {ship.engine}/{ship.engine_max}")
    
    # Base Storage
    storage = env.state.base_storage
    output.append(f"\n📦 BASE STORAGE:")
    output.append(f"   🌲 Wood: {storage.wood} | ⚙️ Metal: {storage.metal} | 🍎 Food: {storage.food} | 🌿 Antidote: {storage.antidote_herbs}")
    
    # Actions this turn (if provided)
    if sailor_actions:
        output.append(f"\n⚔️  ACTIONS THIS TURN:")
        output.append("─"*100)
        for sailor_id, action in sailor_actions.items():
            sailor = env.state.sailors[sailor_id]
            if not sailor.alive:
                continue
            
            reasoning = sailor_reasoning.get(sailor_id, "N/A") if sailor_reasoning else "N/A"
            # Truncate long reasoning
            reasoning_short = (reasoning[:70] + "...") if len(reasoning) > 70 else reasoning
            
            output.append(f"  [{sailor_id}] {action}")
            output.append(f"          💭 {reasoning_short}")
        output.append("─"*100)
    
    # Map visualization (all 3 levels side by side)
    output.append(f"\n🗺️  ISLAND MAP (Day {env.state.current_day}):")
    output.append("")
    
    # Render all three levels
    for level in [MapLevel.GROUND, MapLevel.MOUNTAIN, MapLevel.CAVE]:
        map_str = env.render_map(level, use_emoji=True)
        output.append(map_str)
    
    output.append("\n" + "="*100)
    
    # Print everything
    print("\n".join(output))


print("✅ Game state visualization function loaded")
print("   Use: visualize_game_state(env, turn_num, sailor_actions, sailor_reasoning)")
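The status bars in `visualize_game_state` are proportional fills with `int(fraction * 10)` filled cells. If a value falls outside its range (for example, negative energy or ship progress above 100%), the string comes out the wrong length. The clamped variant below is a hypothetical drop-in helper.

```python
def text_bar(value: float, maximum: float, width: int = 10, fill: str = "█", empty: str = "░") -> str:
    """Fixed-width proportional bar; values outside [0, maximum] are clamped."""
    if maximum <= 0:
        raise ValueError("maximum must be positive")
    fraction = min(max(value / maximum, 0.0), 1.0)
    filled = int(fraction * width)
    return fill * filled + empty * (width - filled)


print(text_bar(45, 100))   # ████░░░░░░
print(text_bar(-15, 100))  # ░░░░░░░░░░
print(text_bar(130, 100))  # ██████████
```

Inside the function, this would read `energy_visual = text_bar(sailor.energy, MAX_ENERGY)` and `ship_visual = text_bar(ship_pct, 100, fill="▓")`.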


## 🎮 Test Game Visualization (Optional)

Before training, you can test the visualization with a quick demo episode:

In [None]:
# TEST VISUALIZATION - Run a quick 5-turn demo episode with visualization
# This shows what the training visualization looks like
import time  # needed for the pause between turns below

print("🎬 Running demo episode with visualization...")
print("   This will show the full game state for 5 turns\n")

demo_env = MaroonedEnv(seed=42, render_mode="ansi")
demo_obs = demo_env.reset(seed=42)

for demo_turn in range(5):
    # Collect actions for all sailors
    demo_actions = {}
    demo_reasoning = {}
    
    for sailor_id in demo_env.agents:
        sailor = demo_env.state.sailors[sailor_id]
        if not sailor.alive:
            continue
        
        # Placeholder action for the demo: every living sailor waits
        action = Action(sailor_id=sailor_id, action_type=ActionType.WAIT)
        demo_actions[sailor_id] = f"{action.action_type.value} (demo)"
        demo_reasoning[sailor_id] = "This is a demo - just waiting"
    
    # Show visualization
    clear_output(wait=True)
    visualize_game_state(demo_env, demo_turn, demo_actions, demo_reasoning)
    
    # Execute actions
    actions_dict = {sid: Action(sailor_id=sid, action_type=ActionType.WAIT) for sid in demo_env.agents}
    demo_obs, _, dones, _, _ = demo_env.step(actions_dict)
    
    time.sleep(1.5)  # Pause to see each turn

print("\n✅ Demo complete! You can now run the training loop.")
print("   The first episode will show this same visualization.")


In [None]:
import torch
import time
from IPython.display import clear_output

# Use variables from config cell above if defined, otherwise defaults
if 'NUM_TRAINING_STEPS' not in dir():
    NUM_TRAINING_STEPS = 100
    EPISODE_MAX_TURNS = 100
    BATCH_SIZE = 4
    SFT_INTERVAL = 25
    print("⚠️  Using default configuration - run config cell above to reduce load!")

EPISODE_MAX_SEQ_LENGTH = 16384

def generate_episode_with_teacher(max_turns=None, verbose=False, visualize=False):
    """
    Play one episode with teacher validation.
    Returns training data (prompts, responses) and rewards.
    
    Args:
        max_turns: Maximum turns per episode
        verbose: Print detailed action logs
        visualize: Show full game state visualization every turn
    """
    if max_turns is None:
        max_turns = EPISODE_MAX_TURNS
        
    env = MaroonedEnv(render_mode="ansi")
    observations = env.reset()
    sailor_ids = list(env.agents)
    
    # Collect episode data
    episode_data = []
    total_reward = 0.0
    
    FastLanguageModel.for_inference(student_model)
    
    if verbose:
        print(f"\n🎮 Starting episode (max {max_turns} turns)...")
    
    for turn in range(max_turns):
        # Collect turn data for visualization
        turn_actions = {}
        turn_reasoning = {}
        turn_actions_count = 0
        
        for sailor_id in sailor_ids:
            sailor = env.state.sailors[sailor_id]
            
            # Skip dead sailors
            if not sailor.alive:
                continue
            
            obs = observations[sailor_id]
            role = sailor.role.value
            
            # Student generates action
            system_prompt = get_system_prompt(role)
            user_prompt = observation_to_prompt(obs)
            
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ]
            
            full_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=EPISODE_MAX_SEQ_LENGTH).to("cuda")
            
            with torch.no_grad():
                outputs = student_model.generate(
                    **inputs,
                    max_new_tokens=256,
                    temperature=0.3,
                    do_sample=True,
                    top_p=0.9,
                    top_k=40,
                    repetition_penalty=1.2,
                    pad_token_id=tokenizer.eos_token_id,
                )
            
            student_response = tokenizer.decode(outputs[0][len(inputs["input_ids"][0]):], skip_special_tokens=True).strip()
            
            # Parse reasoning from student response
            reasoning = "N/A"
            if "REASONING:" in student_response:
                try:
                    reasoning = student_response.split("REASONING:")[1].split("ACTION:")[0].strip()
                except:
                    reasoning = student_response[:100]
            
            # Teacher validates and potentially corrects
            teacher_result = teacher_validate_student_output(student_response, obs, sailor_id)
            action = teacher_result["action"]
            process_penalty = teacher_result["penalty"]
            
            # Collect correction if needed
            add_correction_example(student_response, teacher_result, obs)
            
            # Execute action in environment (only this sailor acts, others WAIT)
            actions_dict = {sid: Action(sailor_id=sid, action_type=ActionType.WAIT) for sid in env.agents}
            actions_dict[sailor_id] = action
            
            observations, rewards_dict, dones, truncated, info = env.step(actions_dict)
            env_reward = rewards_dict[sailor_id]
            
            # Combined reward (environment + process penalty)
            step_reward = env_reward + process_penalty
            total_reward += step_reward
            
            # Store for visualization
            action_str = action.action_type.value
            if action.target_position:
                action_str = f"{action.action_type.value}"
            elif action.target_resource_id:
                action_str = f"{action.action_type.value} (resource {action.target_resource_id})"
            elif action.target_sailor:
                action_str = f"{action.action_type.value} {action.target_sailor}"
            
            turn_actions[sailor_id] = f"{action_str} | Reward: {step_reward:+.1f}"
            turn_reasoning[sailor_id] = reasoning
            
            # Store training example
            episode_data.append({
                "prompt": full_prompt,
                "response": student_response,
                "reward": step_reward,
                "sailor_id": sailor_id,
                "action": action.action_type.value,
            })
            
            turn_actions_count += 1
        
        # Visualize game state after all sailors acted this turn
        if visualize and turn_actions_count > 0:
            clear_output(wait=True)
            visualize_game_state(env, turn, turn_actions, turn_reasoning)
            time.sleep(0.5)  # Brief pause to see the state
        
        if verbose and not visualize:
            print(f"--- Turn {turn}: {turn_actions_count} sailors acted ---")
        
        # Early exit if no one acted (all dead); checked first because
        # `dones` is only assigned once some sailor has stepped the env
        if turn_actions_count == 0:
            if verbose or visualize:
                print(f"\n⚠️  No sailors acted at turn {turn} (all dead)")
            break
        
        # Check if episode is over
        if env.state.game_over or all(dones.values()):
            if verbose or visualize:
                print(f"\n✅ Episode ended at turn {turn}: game_over={env.state.game_over}")
            break
    
    if verbose or visualize:
        print(f"\n📊 Episode complete: {len(episode_data)} actions, total reward: {total_reward:.1f}")
    
    return episode_data, total_reward


# ============================================================================
# SIMPLIFIED TRAINING LOOP (Without PPO.step)
# ============================================================================
print("🚀 Starting Hybrid RL + SFT Training (SIMPLIFIED)")
print(f"   Steps: {NUM_TRAINING_STEPS}")
print(f"   Episode max turns: {EPISODE_MAX_TURNS}")
print(f"   Batch size: {BATCH_SIZE}")
print(f"   SFT interval: Every {SFT_INTERVAL} steps\n")
print("⚠️  NOTE: This is a simplified training loop focused on SFT corrections")
print("   PPO updates are disabled due to UnslothPPOTrainer API limitations")
print("   The student learns primarily through teacher corrections\n")

stats_rewards = []
stats_sft_runs = 0
stats_corrections = []

for step in range(NUM_TRAINING_STEPS):
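    start_time = time.time()
    batch_data = []
    batch_reward = 0.0
    corrections_before = len(correction_dataset)  # pending corrections before this step's rollouts
    
    # RL Phase: Collect episodes
    for episode_num in range(BATCH_SIZE):
        print(f"\n📍 Step {step+1}/{NUM_TRAINING_STEPS} - Episode {episode_num+1}/{BATCH_SIZE}")
        
        # Only visualize the first episode of the first step
        should_visualize = (step == 0 and episode_num == 0)
        
        episode_data, episode_reward = generate_episode_with_teacher(
            max_turns=EPISODE_MAX_TURNS,
            verbose=False,  # Per-turn prints off; the first episode is visualized instead
            visualize=should_visualize
        )
        
        batch_data.extend(episode_data)
        batch_reward += episode_reward
        
        print(f"   ✓ Episode complete: {len(episode_data)} actions, reward: {episode_reward:.1f}")
    
    stats_rewards.append(batch_reward)
    # Record only the corrections collected during this step. correction_dataset
    # accumulates across steps until an SFT pass clears it, so appending its full
    # length would double-count corrections in sum(stats_corrections).
    stats_corrections.append(len(correction_dataset) - corrections_before)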
    
    elapsed = time.time() - start_time
    avg_reward = np.mean(stats_rewards[-10:])  # slicing already handles fewer than 10 entries
    
    print(f"\n{'='*80}")
    print(f"Step {step+1:03d}/{NUM_TRAINING_STEPS} | "
          f"Reward: {batch_reward:+6.1f} | "
          f"Avg(10): {avg_reward:+6.1f} | "
          f"Corrections: {len(correction_dataset):4d} | "
          f"Time: {elapsed:4.1f}s")
    print(f"{'='*80}")
    
    # SFT Phase: Train on corrections
    if (step + 1) % SFT_INTERVAL == 0 and len(correction_dataset) >= 10:
        print(f"\n{'─'*80}")
        print(f"🎓 SFT PASS #{stats_sft_runs + 1}")
        print(f"{'─'*80}")
        run_sft_correction_pass(correction_dataset, num_epochs=1)
        stats_sft_runs += 1
        correction_dataset.clear()
        print(f"{'─'*80}\n")
    
    # Checkpoint
    if (step + 1) % 50 == 0:
        checkpoint_path = f"outputs_marooned_rl/checkpoint_step{step+1}"
        student_model.save_pretrained(checkpoint_path)
        tokenizer.save_pretrained(checkpoint_path)
        print(f"   💾 Checkpoint → {checkpoint_path}")

print("\n" + "="*80)
print("✅ TRAINING COMPLETE!")
print("="*80)
print(f"   Total steps: {NUM_TRAINING_STEPS}")
print(f"   SFT passes: {stats_sft_runs}")
print(f"   Avg reward: {np.mean(stats_rewards):.2f}")
print(f"   Final avg (10): {np.mean(stats_rewards[-10:]):.2f}")
print(f"   Total corrections: {sum(stats_corrections)}")
print("="*80)


🚀 Starting Hybrid RL + SFT Training (SIMPLIFIED)
   Steps: 5
   Episode max turns: 10
   Batch size: 1
   SFT interval: Every 10 steps

⚠️  NOTE: This is a simplified training loop focused on SFT corrections
   PPO updates are disabled due to UnslothPPOTrainer API limitations
   The student learns primarily through teacher corrections


📍 Step 1/5 - Episode 1/1

🎮 Starting episode (max 10 turns)...

--- Turn 0 START ---


  [Alice] wait            | Env=+0.0 Penalty=-0.1 Total=-0.1
  [Bob] move_east       | Env=+0.0 Penalty=+0.0 Total=+0.0
  [Charlie] wait            | Env=+0.0 Penalty=+0.0 Total=+0.0
  [Diana] move_north      | Env=-0.0 Penalty=+0.0 Total=-0.0
  [Eve] wait            | Env=+0.0 Penalty=+0.0 Total=+0.0
--- Turn 0 COMPLETE: 5 sailors acted ---

--- Turn 1 START ---
  [Alice] wait            | Env=+0.0 Penalty=-0.1 Total=-0.1
  [Bob] wait            | Env=+0.0 Penalty=-0.1 Total=-0.1
  [Charlie] wait            | Env=+0.0 Penalty=+0.0 Total=+0.0
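The SFT and checkpoint triggers in the loop depend only on the 1-based step counter: an SFT pass is attempted when `(step + 1) % SFT_INTERVAL == 0` (and then runs only if at least 10 corrections are pending), and a checkpoint is written when `(step + 1) % 50 == 0`. A minimal sketch that lists which steps would fire, assuming exactly those two conditions (`training_schedule` is a hypothetical helper, not part of the notebook), shows that the logged configuration (5 steps, SFT interval 10) never attempts an SFT pass or a checkpoint:

```python
from typing import List, Tuple


def training_schedule(num_steps: int, sft_interval: int,
                      checkpoint_every: int = 50) -> Tuple[List[int], List[int]]:
    """Return the 1-based steps that would attempt an SFT pass and save a checkpoint.

    Hypothetical helper mirroring the loop's conditions:
      (step + 1) % sft_interval == 0      -> SFT pass attempted
      (step + 1) % checkpoint_every == 0  -> checkpoint saved
    It ignores the extra "at least 10 corrections" guard on SFT passes.
    """
    sft_steps = [s + 1 for s in range(num_steps) if (s + 1) % sft_interval == 0]
    ckpt_steps = [s + 1 for s in range(num_steps) if (s + 1) % checkpoint_every == 0]
    return sft_steps, ckpt_steps


# Configuration from the logged run: 5 steps, SFT interval 10
print(training_schedule(5, 10))     # ([], [])
# An interval of 25, as in the architecture diagram, over 100 steps
print(training_schedule(100, 25))   # ([25, 50, 75, 100], [50, 100])
```

To see any SFT pass in a short smoke test, `NUM_TRAINING_STEPS` has to be at least `SFT_INTERVAL`.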

## 🔟 Visualize Training Progress

In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np

plt.figure(figsize=(12, 5))

# Rewards over time
plt.subplot(1, 2, 1)
plt.plot(stats_rewards, alpha=0.3, label='Raw Reward')
window = 10
if len(stats_rewards) >= window:
    # Trailing moving average: 'valid' mode yields len - window + 1 points, aligned to x = window-1 ... len-1
    ma_rewards = np.convolve(stats_rewards, np.ones(window)/window, mode='valid')
    plt.plot(range(window-1, len(stats_rewards)), ma_rewards, label=f'MA({window})', linewidth=2)
plt.xlabel('Training Step')
plt.ylabel('Batch Reward (sum over episodes)')
plt.title('Hybrid RL + SFT Training Progress')
plt.legend()
plt.grid(True, alpha=0.3)

# Mark SFT passes. The loop triggers a pass after 0-based step k*SFT_INTERVAL - 1.
# This assumes every scheduled pass ran; a pass is skipped when fewer than 10
# corrections are pending, in which case later markers are misplaced.
for i in range(stats_sft_runs):
    sft_step = (i + 1) * SFT_INTERVAL - 1
    if sft_step < len(stats_rewards):
        plt.axvline(x=sft_step, color='red', linestyle='--', alpha=0.5, linewidth=1)

# Improvement distribution
plt.subplot(1, 2, 2)
improvement = np.diff(stats_rewards)
plt.hist(improvement, bins=30, alpha=0.7, edgecolor='black')
plt.xlabel('Reward Change')
plt.ylabel('Frequency')
plt.title('Step-to-Step Improvement')
plt.grid(True, alpha=0.3)

plt.tight_layout()
# The output directory is otherwise only created by a checkpoint save (every 50 steps)
os.makedirs('outputs_marooned_rl', exist_ok=True)
plt.savefig('outputs_marooned_rl/training_progress.png', dpi=150, bbox_inches='tight')
plt.show()

# With fewer than 20 steps the first and last 10-step windows overlap, which understates the change
print(f"\n📊 Statistics:")
print(f"   Initial avg (10): {np.mean(stats_rewards[:10]):.2f}")
print(f"   Final avg (10): {np.mean(stats_rewards[-10:]):.2f}")
print(f"   Improvement: {np.mean(stats_rewards[-10:]) - np.mean(stats_rewards[:10]):.2f}")
print(f"   SFT passes: {stats_sft_runs}")
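The MA curve is `np.convolve(..., np.ones(window)/window, mode='valid')`, which with a uniform kernel is a trailing mean over the last `window` steps; its first point therefore sits at x = `window - 1`. A pure-Python sketch of the same computation (`trailing_moving_average` is an illustrative name, not part of the notebook) makes the x-alignment easy to check by hand:

```python
from typing import List, Tuple


def trailing_moving_average(values: List[float], window: int) -> Tuple[List[int], List[float]]:
    """Mean of each full window, equivalent to np.convolve(values, ones(window)/window, mode='valid').

    Returns (x positions, averages); x is the 0-based index of each window's last
    element, matching the plot's range(window-1, len(values)).
    """
    if window <= 0 or len(values) < window:
        return [], []
    xs, avgs = [], []
    running = sum(values[:window])
    for end in range(window - 1, len(values)):
        if end >= window:
            # Slide the window: add the newest element, drop the oldest one
            running += values[end] - values[end - window]
        xs.append(end)
        avgs.append(running / window)
    return xs, avgs


print(trailing_moving_average([1.0, 2.0, 3.0, 4.0], 2))  # ([1, 2, 3], [1.5, 2.5, 3.5])
```

Keeping a running sum makes this O(n) rather than O(n·window); for a reward list of a few hundred steps either approach is fast enough.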

## 1️⃣1️⃣ Save Trained Model

In [None]:
output_dir = "outputs_marooned_rl/final_model"

student_model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

print(f"✅ Model saved to {output_dir}")
print("\nTo load:")
print("```python")
print("from unsloth import FastLanguageModel")
print(f"model, tokenizer = FastLanguageModel.from_pretrained('{output_dir}')")
print("```")
print(f"\n🎉 Training complete with hybrid RL + SFT approach!")
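Intermediate checkpoints are saved as `outputs_marooned_rl/checkpoint_step{N}` every 50 steps, next to `final_model`. If a run stops before the final save, a small helper can pick the most recent checkpoint by parsing `N` from the directory name. This is a sketch: `latest_checkpoint` is a hypothetical helper that only inspects directory names, so it assumes each matching directory holds a complete `save_pretrained` output.

```python
import os
import re
from typing import Optional

# Matches the naming used by the training loop: checkpoint_step{N}
_CKPT_RE = re.compile(r"^checkpoint_step(\d+)$")


def latest_checkpoint(root: str = "outputs_marooned_rl") -> Optional[str]:
    """Return the path of the checkpoint_step{N} directory with the largest N, or None."""
    if not os.path.isdir(root):
        return None
    best_step, best_path = -1, None
    for name in os.listdir(root):
        match = _CKPT_RE.match(name)
        path = os.path.join(root, name)
        if match and os.path.isdir(path):
            step = int(match.group(1))
            # Compare numerically so step100 beats step50 (string order would not)
            if step > best_step:
                best_step, best_path = step, path
    return best_path


# Possible use with the loading snippet printed above:
# ckpt = latest_checkpoint()
# if ckpt is not None:
#     model, tokenizer = FastLanguageModel.from_pretrained(ckpt)
```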