<div align="center">

# 🏴‍☠️ MAROONED: Hybrid RL + SFT Training

### Process Reward Modeling with Supervised Correction

**OpenEnv Hackathon 2025**

[![OpenEnv](https://img.shields.io/badge/Framework-OpenEnv-blue)](https://github.com/openenv)
[![Llama](https://img.shields.io/badge/Model-Llama_3.1_8B-green)](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instinct)
[![Hardware](https://img.shields.io/badge/Hardware-AMD_MI300X-red)](https://www.amd.com/en/products/accelerators/instinct/mi300.html)

</div>

---

## 🔬 Training Architecture

**Hybrid Approach: RL (strategy) + SFT (format learning)**

```
┌─────────────────────────────────────────────────┐
│  RL PHASE: PPO Training (Episodes 1-4)          │
│  Student → Teacher (vLLM) → Env → Rewards       │
│  Collect corrections: wrong → correct           │
└──────────────────┬──────────────────────────────┘
                   │ Every 25 steps
                   ▼
┌─────────────────────────────────────────────────┐
│  SFT PHASE: Supervised Fine-Tuning              │
│  Train on corrections: mimic teacher format     │
│  Clear dataset, continue RL                     │
└─────────────────────────────────────────────────┘
```

**Key Innovation:** Student learns format directly from teacher critiques via periodic SFT passes.

---

## ⚙️ Prerequisites

**Start vLLM teacher server in a separate terminal:**

```bash
pip install vllm

vllm serve unsloth/Meta-Llama-3.1-8B-Instruct \
    --dtype bfloat16 \
    --max-model-len 8192 \
    --port 8000
```

**Verify:**
```bash
curl http://localhost:8000/v1/models
```

Expected: `{"data": [{"id": "unsloth/Meta-Llama-3.1-8B-Instruct", ...}]}`

---

## 1️⃣ Install Dependencies

In [2]:
import os, importlib.util
!pip install --upgrade -qqq uv
if importlib.util.find_spec("torch") is None or "COLAB_" in "".join(os.environ.keys()):
    try: import numpy; get_numpy = f"numpy=={numpy.__version__}"
    except: get_numpy = "numpy"
    !uv pip install -qqq \
        "torch>=2.8.0" "triton>=3.4.0" {get_numpy} torchvision bitsandbytes "transformers==4.56.2" trackio \
        "unsloth_zoo[base] @ git+https://github.com/unslothai/unsloth-zoo" \
        "unsloth[base] @ git+https://github.com/unslothai/unsloth" \
        git+https://github.com/triton-lang/triton.git@05b2c186c1b6c9a08375389d5efe9cb4c401c075#subdirectory=python/triton_kernels
elif importlib.util.find_spec("unsloth") is None:
    !uv pip install -qqq unsloth trackio
!uv pip install --upgrade --no-deps transformers==4.56.2 tokenizers trl==0.22.2 unsloth unsloth_zoo

print("✅ Dependencies installed")

[1m[31merror[39m[0m: No virtual environment found; run `[32muv venv[39m` to create an environment, or pass `[32m--system[39m` to install into a non-virtual environment
✅ Dependencies installed


In [1]:
import torch
print(torch.version.hip)
print(torch.cuda.is_available())
print(torch.cuda.get_device_name(0))


7.0.51831-a3e329ad8
True



## 2️⃣ Load MAROONED Environment

In [2]:
import sys
import json
import random
from typing import Dict, Any, List

# Clear cached modules
modules_to_clear = [m for m in list(sys.modules.keys()) 
                   if 'marooned' in m or m in ['environment', 'config', 'models', 'game_state', 'view_map', 'llm_interface']]
for module in modules_to_clear:
    if module in sys.modules:
        del sys.modules[module]

sys.path.insert(0, '../marooned_env')

from environment import MaroonedEnv
from llm_interface import (
    get_system_prompt,
    observation_to_prompt,
    teacher_validate_student_output,
)
from config import ActionType, ResourceType, MapLevel, ShipComponent
from models import Action, Position, Observation

print("✅ MAROONED environment loaded")
print("✅ Teacher validation API imported")

✅ MAROONED environment loaded
✅ Teacher validation API imported


## 3️⃣ Load Student Model (Llama 3.1 8B with LoRA)

In [3]:
import os
os.environ["UNSLOTH_NO_TQDM"] = "1"
from unsloth import FastLanguageModel

import torch

# ROCm optimizations
os.environ["PYTORCH_ROCM_ARCH"] = "gfx942"
os.environ["HSA_FORCE_FINE_GRAIN_PCIE"] = "1"
os.environ["ATTN_BACKEND"] = "triton"
torch.backends.cudnn.benchmark = True

max_seq_length = 16384
lora_rank = 16

student_model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "/root/llama3_8b",  # your local path
    load_in_4bit = False,
    dtype = torch.bfloat16,
    max_seq_length = 16384,
    device_map = "auto",
)

# Add LoRA adapters
student_model = FastLanguageModel.get_peft_model(
    student_model,
    r = lora_rank,
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = lora_rank * 2,
    lora_dropout = 0.0,
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
    use_rslora = True,
)

print(f"✅ Student Model: Llama 3.1 8B (BF16, LoRA rank={lora_rank})")
print(f"   GPU: {torch.cuda.get_device_name(0)}")
print(f"   VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.




INFO 10-29 20:31:41 [__init__.py:225] Automatically detected platform rocm.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.10.11: Fast Llama patching. Transformers: 4.56.2. vLLM: 0.11.1rc3.dev39+gf417746ad.rocm700.
   \\   /|    AMD GPU Device. Num GPUs = 1. Max memory: 191.688 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0a0+git1c57644. ROCm Toolkit: 7.0.51831-a3e329ad8. Triton: 3.4.0
\        /    Bfloat16 = TRUE. FA [Xformers = None. FA2 = True]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Unsloth 2025.10.11 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


✅ Student Model: Llama 3.1 8B (BF16, LoRA rank=16)
   GPU: 
   VRAM: 191.7 GB


## 4️⃣ Verify vLLM Teacher Server

In [None]:
import requests

VLLM_API_URL = "http://localhost:8001/v1/chat/completions"
VLLM_MODELS_URL = "http://localhost:8001/v1/models"

print("Checking vLLM teacher server...")
try:
    response = requests.get(VLLM_MODELS_URL, timeout=5)
    if response.status_code == 200:
        models = response.json()
        print(f"✅ vLLM server running!")
        print(f"   Model: {[m['id'] for m in models.get('data', [])]}")
    else:
        print(f"⚠️  Server responded with status {response.status_code}")
except requests.exceptions.RequestException as e:
    print(f"❌ vLLM server not reachable!")
    print(f"   Error: {e}")
    print(f"\n   Start server in separate terminal:")
    print(f"   vllm serve unsloth/Meta-Llama-3.1-8B-Instruct --dtype bfloat16 --port 8001")
    raise SystemExit("Teacher server required for training")


Checking vLLM teacher server...
✅ vLLM server running!
   Model: ['Unsloth/Llama-3.3-70B-Instruct']


## 5️⃣ Test Teacher Validation

In [None]:
from config import MAX_ENERGY

print("="*80)
print("🧪 TESTING TEACHER VALIDATION WITH FULL CONTEXT")
print("="*80)

# Create environment and get real observation
env = MaroonedEnv(render_mode="ansi", seed=42)
observations = env.reset(seed=42)
alice_obs = observations["Alice"]
alice_role = env.state.sailors["Alice"].role.value

print(f"\n📋 Test Setup:")
print(f"   Sailor: Alice")
print(f"   Role: {alice_role.upper()}")
print(f"   Position: {alice_obs.position}")
print(f"   Energy: {alice_obs.energy}/{MAX_ENERGY}")
print(f"   Visible resources: {len(alice_obs.visible_resources) if hasattr(alice_obs, 'visible_resources') else 'N/A'}")

# Get proper system and user prompts
system_prompt = get_system_prompt(alice_role)
user_prompt = observation_to_prompt(alice_obs)

print(f"\n📏 Prompt sizes:")
print(f"   System prompt: {len(system_prompt)} chars")
print(f"   User prompt: {len(user_prompt)} chars")
print(f"   Total context: {len(system_prompt) + len(user_prompt)} chars")

# Test cases - simulating what an untrained student LLM might output
test_cases = [
    {
        "name": "Format Error (MOVING instead of MOVE)",
        "output": "REASONING: I should move northeast to explore\nACTION: MOVING NORTH"
    },
    {
        "name": "Invalid Command (CHECK_STATUS)",
        "output": "REASONING: Let me check my status\nACTION: CHECK_STATUS"
    },
    {
        "name": "Missing Resource ID",
        "output": "REASONING: I see wood nearby, gathering it\nACTION: GATHER wood"
    },
    {
        "name": "Truncated Output",
        "output": "REASONING: As the traitor, I should sabotagin"
    },
    {
        "name": "Correct Format",
        "output": "REASONING: Moving north to explore the area\nACTION: MOVE NORTH"
    }
]

print(f"\n{'='*80}")
print("🔬 RUNNING TEACHER VALIDATION TESTS")
print(f"{'='*80}\n")

for i, test in enumerate(test_cases, 1):
    print(f"Test {i}: {test['name']}")
    print(f"   Student output: {test['output'][:60]}...")
    
    # Call teacher validation with full context
    result = teacher_validate_student_output(
        student_response=test['output'],
        observation=alice_obs,
        sailor_id="Alice"
    )
    
    # Display results
    validity_icon = "✅" if result['valid'] else "❌"
    print(f"   {validity_icon} Valid: {result['valid']}")
    print(f"   🔧 Corrected action: {result['action'].action_type.value}")
    print(f"   💰 Process penalty: {result['penalty']}")
    print(f"   💬 Critique: {result['critique'][:80]}")
    
    # Show what would happen
    if result['valid']:
        print(f"   ✅ Action executes as-is (no penalty)")
    else:
        print(f"   ⚠️  Teacher corrected → student gets penalty {result['penalty']}")
    print()

print("="*80)
print("✅ TEACHER VALIDATION API WORKING!")
print("="*80)
print("\nKey Points:")
print("  • Teacher receives full game context (observation + system prompt)")
print("  • Invalid formats get corrected automatically")
print("  • Process penalties guide student learning")
print("  • Student focuses on strategy, not syntax")


🧪 TESTING TEACHER VALIDATION

Sending test cases to teacher...

Test 1: MOVING NORTH
⚠️  Teacher API error: HTTPConnectionPool(host='localhost', port=8000): Max retries exceeded with url: /v1/chat/completions (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7bff7f7ce180>: Failed to establish a new connection: [Errno 111] Connection refused'))
   Valid: False
   Corrected: wait
   Penalty: -2.0
   Critique: Teacher API unavailable - defaulting to WAIT...

Test 2: CHECK_STATUS
⚠️  Teacher API error: HTTPConnectionPool(host='localhost', port=8000): Max retries exceeded with url: /v1/chat/completions (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7bff7f7ce720>: Failed to establish a new connection: [Errno 111] Connection refused'))
   Valid: False
   Corrected: wait
   Penalty: -2.0
   Critique: Teacher API unavailable - defaulting to WAIT...

Test 3: GATHER wood
⚠️  Teacher API error: HTTPConnectionPool(host='localhost', port=800

## 6️⃣ Setup Correction Dataset for SFT

Collect student errors and teacher corrections during RL training.

In [None]:
correction_dataset = []

def add_correction_example(student_response: str, teacher_result: dict, observation: Observation):
    """Store invalid outputs for SFT training."""
    if not teacher_result["valid"]:
        action = teacher_result["action"]
        action_str = f"{action.action_type.value}"
        
        # Format action string with parameters
        if action.target_position:
            if "NORTH" in action.action_type.value:
                action_str = "MOVE NORTH"
            elif "SOUTH" in action.action_type.value:
                action_str = "MOVE SOUTH"
            elif "EAST" in action.action_type.value:
                action_str = "MOVE EAST"
            elif "WEST" in action.action_type.value:
                action_str = "MOVE WEST"
        elif action.target_resource_id:
            action_str = f"GATHER {action.target_resource_id}"
        elif action.resource_type and action.quantity:
            action_str = f"DEPOSIT {action.resource_type.value} {action.quantity}"
        elif action.ship_component:
            action_str = f"BUILD {action.ship_component.value}"
        elif action.target_sailor:
            if action.action_type == ActionType.OFFER_FOOD:
                action_str = f"POISON {action.target_sailor}"
            else:
                action_str = f"{action.action_type.value} {action.target_sailor}"
        
        correction = {
            "input": student_response,
            "output": f"REASONING: {teacher_result['critique']}\nACTION: {action_str}",
            "penalty": teacher_result["penalty"],
            "critique": teacher_result["critique"]
        }
        
        correction_dataset.append(correction)

print("✅ Correction collector initialized")
print("   Format: (student_wrong) → (teacher_correct + critique)")

## 7️⃣ Define SFT Correction Trainer

In [None]:
from trl import SFTTrainer, SFTConfig
from datasets import Dataset

def run_sft_correction_pass(correction_examples: list, num_epochs: int = 1):
    """
    Run supervised fine-tuning on collected corrections.
    Teaches student to mimic teacher's correct format.
    """
    if len(correction_examples) == 0:
        print("⚠️  No corrections to train on")
        return
    
    print(f"\n{'='*80}")
    print(f"🎓 SFT CORRECTION PASS")
    print(f"{'='*80}")
    print(f"   Examples: {len(correction_examples)}")
    print(f"   Epochs: {num_epochs}")
    
    # Convert to chat format
    sft_data = []
    for example in correction_examples:
        messages = [
            {"role": "user", "content": f"Fix this invalid action:\n{example['input']}"},
            {"role": "assistant", "content": example['output']}
        ]
        text = tokenizer.apply_chat_template(messages, tokenize=False)
        sft_data.append({"text": text})
    
    sft_dataset = Dataset.from_list(sft_data)
    
    # SFT configuration
    sft_config = SFTConfig(
        output_dir="outputs_marooned_rl/sft_corrections",
        num_train_epochs=num_epochs,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=2,
        learning_rate=1e-5,
        logging_steps=10,
        save_steps=100,
        max_seq_length=2048,
        packing=False,
    )
    
    # Train
    sft_trainer = SFTTrainer(
        model=student_model,
        args=sft_config,
        train_dataset=sft_dataset,
        tokenizer=tokenizer,
    )
    
    result = sft_trainer.train()
    
    print(f"\n✅ SFT complete! Loss: {result.training_loss:.4f}")
    print(f"{'='*80}\n")
    
    return result

print("✅ SFT trainer defined")

## 8️⃣ Setup PPO Trainer

In [None]:
from trl import PPOConfig, PPOTrainer, AutoModelForCausalLMWithValueHead
import numpy as np

ppo_config = PPOConfig(
    output_dir="outputs_marooned_rl",
    learning_rate=1e-5,
    batch_size=4,
    mini_batch_size=1,
    gradient_accumulation_steps=4,
    seed=42,
    num_ppo_epochs=4,
    kl_coef=0.2,
    vf_coef=0.1,
    cliprange=0.2,
    gamma=0.99,
    lam=0.95,
    temperature=0.3,
    response_length=256,
)

# Wrap student model with value head
model_with_value = AutoModelForCausalLMWithValueHead.from_pretrained(student_model)

# Compatibility patches
base_model = model_with_value.pretrained_model
if not hasattr(model_with_value, "base_model_prefix"):
    model_with_value.base_model_prefix = getattr(base_model, "base_model_prefix", "model")
setattr(model_with_value, model_with_value.base_model_prefix, base_model)
if not hasattr(model_with_value, "config"):
    model_with_value.config = base_model.config
if not hasattr(model_with_value, "generation_config"):
    model_with_value.generation_config = base_model.generation_config
if hasattr(base_model, "is_gradient_checkpointing"):
    model_with_value.is_gradient_checkpointing = base_model.is_gradient_checkpointing
else:
    model_with_value.is_gradient_checkpointing = False

# Minimal dataset
train_dataset = Dataset.from_dict({
    "prompt": ["stub"],
    "response": ["stub"],
    "reward": [0.0],
})

ppo_trainer = PPOTrainer(
    args=ppo_config,
    model=model_with_value,
    ref_model=None,
    reward_model=model_with_value,
    value_model=model_with_value,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)

print("✅ PPO Trainer initialized")

## 9️⃣ Hybrid RL + SFT Training Loop

**Training Flow:**
1. **RL Phase:** Student plays episodes, teacher validates, collect corrections
2. **SFT Phase (every 25 steps):** Train on corrections, clear dataset
3. **Repeat:** Continue RL with improved format knowledge

In [None]:
import torch
import time

NUM_TRAINING_STEPS = 100
EPISODE_MAX_SEQ_LENGTH = 16384
SFT_INTERVAL = 25

def generate_episode_with_teacher(max_turns=100, verbose=False):
    """Play episode with teacher validation and correction collection."""
    env = MaroonedEnv(render_mode="ansi")
    observations = env.reset()
    sailor_ids = list(env.agents)
    
    query_tensors, response_tensors, rewards_list = [], [], []
    
    FastLanguageModel.for_inference(student_model)
    
    for turn in range(max_turns):
        for sailor_id in sailor_ids:
            if not env.state.sailors[sailor_id].alive:
                continue
            
            obs = observations[sailor_id]
            role = env.state.sailors[sailor_id].role.value
            
            # Student generates
            system_prompt = get_system_prompt(role)
            user_prompt = observation_to_prompt(obs)
            
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ]
            
            text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=EPISODE_MAX_SEQ_LENGTH).to("cuda")
            query_tensor = inputs["input_ids"][0]
            
            with torch.no_grad():
                outputs = student_model.generate(
                    **inputs,
                    max_new_tokens=256,
                    temperature=0.3,
                    do_sample=True,
                    top_p=0.9,
                    top_k=40,
                    repetition_penalty=1.2,
                    pad_token_id=tokenizer.eos_token_id,
                )
                response_tensor = outputs[0]
            
            student_response = tokenizer.decode(response_tensor[len(query_tensor):], skip_special_tokens=True).strip()
            
            # Teacher validates
            teacher_result = teacher_validate_student_output(student_response, obs, sailor_id)
            action = teacher_result["action"]
            process_penalty = teacher_result["penalty"]
            
            # Collect correction
            add_correction_example(student_response, teacher_result, obs)
            
            # Environment executes
            actions_dict = {sid: Action(sailor_id=sid, action_type=ActionType.WAIT) for sid in env.agents}
            actions_dict[sailor_id] = action
            
            observations, rewards_dict, dones, truncated, info = env.step(actions_dict)
            env_reward = rewards_dict[sailor_id]
            
            # Combined reward
            total_reward = env_reward + process_penalty
            
            query_tensors.append(query_tensor)
            response_tensors.append(response_tensor[len(query_tensor):])
            rewards_list.append(torch.tensor(total_reward, dtype=torch.float32))
            
            if verbose and turn % 20 == 0:
                print(f"Turn {turn:03d} | {sailor_id}: {action.action_type.value:<12} | "
                      f"Env={env_reward:+.1f} Penalty={process_penalty:+.1f} Total={total_reward:+.1f}")
            
            if dones[sailor_id]:
                return query_tensors, response_tensors, rewards_list
    
    return query_tensors, response_tensors, rewards_list


# TRAINING LOOP
print("🚀 Starting Hybrid RL + SFT Training")
print(f"   Steps: {NUM_TRAINING_STEPS}")
print(f"   SFT interval: Every {SFT_INTERVAL} steps\n")

stats_rewards = []
stats_sft_runs = 0

for step in range(NUM_TRAINING_STEPS):
    start_time = time.time()
    batch_queries, batch_responses, batch_rewards = [], [], []
    
    # RL Phase
    for _ in range(ppo_config.batch_size):
        queries, responses, rewards = generate_episode_with_teacher(
            max_turns=100,
            verbose=(step % 50 == 0 and _ == 0)
        )
        batch_queries.extend(queries)
        batch_responses.extend(responses)
        batch_rewards.extend(rewards)
    
    # PPO update
    stats = ppo_trainer.step(batch_queries, batch_responses, batch_rewards)
    
    episode_reward = sum([r.item() for r in batch_rewards])
    stats_rewards.append(episode_reward)
    
    elapsed = time.time() - start_time
    avg_reward = np.mean(stats_rewards[-10:]) if len(stats_rewards) >= 10 else np.mean(stats_rewards)
    
    print(f"Step {step+1:03d}/{NUM_TRAINING_STEPS} | "
          f"Reward: {episode_reward:+6.1f} | "
          f"Avg(10): {avg_reward:+6.1f} | "
          f"Corrections: {len(correction_dataset):4d} | "
          f"Time: {elapsed:4.1f}s")
    
    # SFT Phase
    if (step + 1) % SFT_INTERVAL == 0 and len(correction_dataset) >= 10:
        print(f"\n{'─'*80}")
        print(f"🎓 SFT PASS #{stats_sft_runs + 1}")
        print(f"{'─'*80}")
        run_sft_correction_pass(correction_dataset, num_epochs=1)
        stats_sft_runs += 1
        correction_dataset.clear()
        print(f"{'─'*80}\n")
    
    # Checkpoint
    if (step + 1) % 50 == 0:
        checkpoint_path = f"outputs_marooned_rl/checkpoint_step{step+1}"
        ppo_trainer.save_pretrained(checkpoint_path)
        print(f"   💾 Checkpoint → {checkpoint_path}")

print("\n" + "="*80)
print("✅ TRAINING COMPLETE!")
print("="*80)
print(f"   Total steps: {NUM_TRAINING_STEPS}")
print(f"   SFT passes: {stats_sft_runs}")
print(f"   Avg reward: {np.mean(stats_rewards):.2f}")
print(f"   Final (10): {np.mean(stats_rewards[-10:]):.2f}")
print("="*80)

## 🔟 Visualize Training Progress

In [None]:
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 5))

# Rewards over time
plt.subplot(1, 2, 1)
plt.plot(stats_rewards, alpha=0.3, label='Raw Reward')
window = 10
if len(stats_rewards) >= window:
    ma_rewards = np.convolve(stats_rewards, np.ones(window)/window, mode='valid')
    plt.plot(range(window-1, len(stats_rewards)), ma_rewards, label=f'MA({window})', linewidth=2)
plt.xlabel('Training Step')
plt.ylabel('Episode Reward')
plt.title('Hybrid RL + SFT Training Progress')
plt.legend()
plt.grid(True, alpha=0.3)

# Mark SFT passes
for i in range(stats_sft_runs):
    sft_step = (i + 1) * SFT_INTERVAL
    if sft_step < len(stats_rewards):
        plt.axvline(x=sft_step, color='red', linestyle='--', alpha=0.5, linewidth=1)

# Improvement distribution
plt.subplot(1, 2, 2)
improvement = np.diff(stats_rewards)
plt.hist(improvement, bins=30, alpha=0.7, edgecolor='black')
plt.xlabel('Reward Change')
plt.ylabel('Frequency')
plt.title('Step-to-Step Improvement')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('outputs_marooned_rl/training_progress.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\n📊 Statistics:")
print(f"   Initial avg (10): {np.mean(stats_rewards[:10]):.2f}")
print(f"   Final avg (10): {np.mean(stats_rewards[-10:]):.2f}")
print(f"   Improvement: {np.mean(stats_rewards[-10:]) - np.mean(stats_rewards[:10]):.2f}")
print(f"   SFT passes: {stats_sft_runs}")

## 1️⃣1️⃣ Save Trained Model

In [None]:
output_dir = "outputs_marooned_rl/final_model"

student_model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

print(f"✅ Model saved to {output_dir}")
print(f"\nTo load:")
print(f"```python")
print(f"from unsloth import FastLanguageModel")
print(f"model, tokenizer = FastLanguageModel.from_pretrained('{output_dir}')")
print(f"```")
print(f"\n🎉 Training complete with hybrid RL + SFT approach!")