In [None]:
# SAPO Config 2 Experiment Configuration (I=4, J=4) - BEST CONFIG
# This notebook runs 5 nodes with BALANCED swarm collaboration (50% external)
#
# This cell defines every tunable used by the later cells (experiment name,
# rollout mix, round budget, node count, model, seed, checkpointing) and prints
# a summary of the selected configuration.

# ============================================
# PRE-CONFIGURED FOR CONFIG 2 (BEST)
# ============================================
EXPERIMENT_NAME = 'sapo_4loc4ext'
NUM_TRAIN_SAMPLES = 4        # I: Local rollouts per round
NUM_TRANSPLANT_TREES = 4     # J: External rollouts (50% external)

# ============================================
# TRAINING MODE: TESTING (default) or PRODUCTION
# ============================================
# Reference point used to scale time estimates: the full run is 2000 rounds
# and is expected to take ~21 hours on this setup.
PRODUCTION_ROUNDS = 2000
PRODUCTION_HOURS = 21

# TESTING MODE (default): short validation run (~32 minutes at 50 rounds)
MAX_ROUNDS = 50              # Testing: 50 rounds to verify everything works

# PRODUCTION MODE: Full training ~21 hours
# Uncomment line below for production run:
# MAX_ROUNDS = 2000          # Production: Full 2000 rounds (like paper)

# ============================================
# FIXED SETTINGS (same for all experiments)
# ============================================
NUM_NODES = 5                # Run 5 nodes (1 coordinator + 4 workers)
MODEL_NAME = 'HuggingFaceTB/SmolLM-360M-Instruct'
NUM_GENERATIONS = 8          # G: Completions per question (like paper)
SEED = 42                    # For reproducibility


# Checkpoint Configuration
CHECKPOINT_INTERVAL = 100    # Save checkpoints every 100 rounds (~1 hour in production)
                             # Allows resume if training crashes
MAX_STAGES = 1               # Stages per round (1=default)

# Optional: HuggingFace Token
HUGGINGFACE_TOKEN = None  # Set to your token or keep None


def format_estimated_time(rounds):
    """Return a human-readable wall-clock estimate for `rounds` training rounds.

    The estimate is a linear scaling of the production reference point
    (PRODUCTION_ROUNDS rounds in PRODUCTION_HOURS hours), so 10 rounds gives
    "~6 minutes" and 2000 rounds gives "~21 hours".
    """
    # Multiply before dividing so round-number inputs stay exact.
    total_minutes = rounds * PRODUCTION_HOURS * 60 / PRODUCTION_ROUNDS
    if total_minutes < 120:
        return f"~{round(total_minutes)} minutes"
    return f"~{round(total_minutes / 60)} hours"


# ============================================
# DISPLAY CONFIGURATION
# ============================================
# Anything shorter than the full production run is a validation ("TESTING") run.
# Deriving both values from MAX_ROUNDS keeps the banner truthful for any budget.
mode = "TESTING" if MAX_ROUNDS < PRODUCTION_ROUNDS else "PRODUCTION"
estimated_time = format_estimated_time(MAX_ROUNDS)
num_workers = NUM_NODES - 1  # node_0 is the coordinator

print("="*60)
print(f"SAPO Config 2 Experiment (BEST) - {mode} MODE")
print("="*60)
print(f"✓ Mode: {mode}")
print(f"✓ Nodes: {NUM_NODES} (1 coordinator + {num_workers} workers on single A100 80GB)")
print(f"✓ Model: {MODEL_NAME}")
print(f"✓ Config: I={NUM_TRAIN_SAMPLES}, J={NUM_TRANSPLANT_TREES}, G={NUM_GENERATIONS}")
print(f"✓ Experiment: {EXPERIMENT_NAME}")
print(f"✓ Max Rounds: {MAX_ROUNDS}")
print(f"✓ Checkpoints: Every {CHECKPOINT_INTERVAL} rounds")
print()
print(f"Expected VRAM: ~30 GB peak ({num_workers} workers train, coordinator doesn't)")
print(f"Expected Time: {estimated_time}")
print()

if mode == "TESTING":
    print("🧪 TESTING MODE ENABLED")
    print("   Quick validation run - verifies:")
    print("   ✓ All nodes start successfully")
    print("   ✓ Rollouts are published and shared")
    print("   ✓ Training progresses through rounds")
    print("   ✓ Logs are saved to Google Drive")
    print()
    print("   After validation succeeds, uncomment production line")
    print("   in this cell to run full 2000-round training.")
else:
    print("📊 PRODUCTION MODE - Config 2 (balanced swarm - 50% external)")
    print("   Expected cumulative reward: 500-700")
    print("   Expected improvement vs baseline: +110-150%")
    print("   Paper result (Qwen2.5): 1093 (+94%)")

print("="*60)


# Root directory under which experiments/<EXPERIMENT_NAME>/ is created.
# NOTE(review): '/content' looks like the Colab-local filesystem rather than a
# mounted Drive folder - confirm this is the intended persistence location.
GDRIVE_BASE_PATH = '/content'

## 2. Clone Repository & Install Dependencies

In [None]:
%cd /content

# Remove existing directory if it exists
if os.path.exists('/content/rl-swarm'):
    print("Removing existing repository...")
    !rm -rf /content/rl-swarm

# Clone fresh copy
print("Cloning repository...")
!git clone https://github.com/Elrashid/rl-swarm.git /content/rl-swarm

# Change to repo directory
%cd /content/rl-swarm

# Verify clone worked
if not os.path.exists('requirements.txt'):
    raise FileNotFoundError("Repository clone failed - requirements.txt not found")

print("‚úì Repository cloned successfully")
print()

# Install dependencies
print("Installing dependencies (this may take 3-5 minutes)...")
print("Note: Warnings about protobuf versions can be ignored")
print()

# Install main dependencies (without -q to show errors)
!pip install -r requirements.txt

# NOTE: GenRL is now vendored locally in rgym_exp/vendor/genrl/
# No need to install gensyn-genrl package separately

# Fix protobuf version explicitly to avoid conflicts
!pip install 'protobuf>=4.25.0,<5.0'

# Verify reasoning-gym was installed
try:
    import reasoning_gym
    print()
    print("‚úì Dependencies installed successfully")
    print("‚úì reasoning-gym verified")
    print("‚úì GenRL vendored locally (no separate install needed)")
except ImportError as e:
    print()
    print("‚ùå ERROR: reasoning-gym failed to install!")
    print("   Please report this issue with the error above")
    raise

## 3. Initialize Experiment

**Note:** Only the coordinator (node_0) creates the experiment structure. Workers will join it.

In [None]:
# Create the experiment structure for EXPERIMENT_NAME under GDRIVE_BASE_PATH,
# passing the training settings chosen in the configuration cell as overrides.
from rgym_exp.utils.experiment_manager import init_experiment

# Initialize experiment structure under GDRIVE_BASE_PATH (coordinator creates it).
# Keys are dotted config paths as expected by init_experiment.
config_overrides = {
    'training.max_round': MAX_ROUNDS,
    'training.num_generations': NUM_GENERATIONS,
    'training.num_transplant_trees': NUM_TRANSPLANT_TREES,
    'training.num_train_samples': NUM_TRAIN_SAMPLES,
    'training.seed': SEED,
}

init_experiment(
    gdrive_base_path=GDRIVE_BASE_PATH,
    experiment_name=EXPERIMENT_NAME,
    config_overrides=config_overrides
)

print(f"✓ Experiment initialized: {EXPERIMENT_NAME}")
print(f"  Path: {GDRIVE_BASE_PATH}/experiments/{EXPERIMENT_NAME}")
print(f"  Config: I={NUM_TRAIN_SAMPLES}, J={NUM_TRANSPLANT_TREES}, G={NUM_GENERATIONS}")

## 4. Launch 5-Node Swarm (KEY CELL)

**This cell:**
1. Spawns 5 separate Python processes
2. Each process runs `swarm_launcher.py` with unique NODE_ID
3. All processes share GPU 0 (CUDA_VISIBLE_DEVICES=0)
4. Coordinator (node_0) manages round progression
5. Workers (node_1-4) follow coordinator

**Logs:** Each node writes to Google Drive at `{GDRIVE_BASE_PATH}/experiments/{EXPERIMENT_NAME}/logs/`

**Monitor:** Use the next cell (Section 5, "Monitor Training Progress") to track progress in real time

In [None]:
# Launch NUM_NODES swarm processes (node_0 = coordinator, the rest = workers),
# all on GPU 0, then run a health check after a 30-second initialization wait.
#
# Defines for later cells:
#   processes      - list of subprocess.Popen objects, index == node number
#   process_stderr - list of read-only file handles on each node's stderr log
#   start_time     - launch timestamp (time.time()), used for elapsed/ETA
import os
import subprocess
import sys
import time
from datetime import datetime

REPO_DIR = '/content/rl-swarm'
ROLLOUT_FIX_COMMIT = '027fb2d'
LOG_DIR = os.path.join(GDRIVE_BASE_PATH, 'experiments', EXPERIMENT_NAME, 'logs')

# Reuse the mode and time estimate computed by the configuration cell so the
# two banners cannot disagree.
mode_label = mode
estimated_duration = estimated_time

print("="*60)
print(f"Launching {NUM_NODES}-Node SAPO Swarm ({mode_label} MODE)")
print("="*60)
print(f"Experiment: {EXPERIMENT_NAME}")
print(f"Model: {MODEL_NAME}")
print(f"Config: I={NUM_TRAIN_SAMPLES}, J={NUM_TRANSPLANT_TREES}, G={NUM_GENERATIONS}")
print(f"Rounds: {MAX_ROUNDS} ({estimated_duration})")
print(f"Hardware: All {NUM_NODES} nodes on single GPU (A100 80GB)")
print("="*60)
print()

# =========================================
# CODE VERSION CHECK
# =========================================
print("Checking code version...")
result = subprocess.run(['git', 'log', '--oneline', '-5'],
                        capture_output=True, text=True, cwd=REPO_DIR)
# Drop blank lines: "".split('\n') would yield [''] and hide the fallback.
commits = [line for line in result.stdout.splitlines() if line.strip()]
latest_commit = commits[0] if commits else "unknown"

print(f"Latest commit: {latest_commit}")
print(f"Recent commits:")
for commit in commits[:3]:
    print(f"  {commit}")
print()

# Check for rollout publishing fix (critical for rollout sharing).
# Ask git whether the fix commit is an ancestor of HEAD, so the check keeps
# working after more than five newer commits have landed. The text search over
# the recent log is kept as a fallback (e.g. if the hash is not resolvable).
ancestor_check = subprocess.run(
    ['git', 'merge-base', '--is-ancestor', ROLLOUT_FIX_COMMIT, 'HEAD'],
    capture_output=True, text=True, cwd=REPO_DIR)
has_rollout_fix = ancestor_check.returncode == 0 or any(
    ROLLOUT_FIX_COMMIT in commit or 'rollout publishing' in commit.lower()
    for commit in commits)

if has_rollout_fix:
    print("✓ Has rollout publishing fix (commit 027fb2d)")
else:
    print("⚠️  WARNING: Missing rollout publishing fix!")
    print("   Rollout sharing may not work correctly")
    print("   Expected: commit 027fb2d 'fix: Implement proper rollout publishing'")

print()
print("="*60)
print("Starting node processes...")
print("="*60)
print()

os.makedirs(LOG_DIR, exist_ok=True)

processes = []
process_stderr = []  # Read handle on each node's stderr log file
start_time = time.time()  # For ETA calculation

for node_id in range(NUM_NODES):
    # Environment variables for this node
    env = os.environ.copy()
    env['NODE_ID'] = f'node_{node_id}'
    env['NODE_ROLE'] = 'coordinator' if node_id == 0 else 'worker'
    env['MODEL_NAME'] = MODEL_NAME
    env['NUM_TRAIN_SAMPLES'] = str(NUM_TRAIN_SAMPLES)
    env['NUM_TRANSPLANT_TREES'] = str(NUM_TRANSPLANT_TREES)
    env['NUM_GENERATIONS'] = str(NUM_GENERATIONS)
    env['MAX_ROUNDS'] = str(MAX_ROUNDS)
    env['EXPERIMENT_NAME'] = EXPERIMENT_NAME
    env['GDRIVE_PATH'] = GDRIVE_BASE_PATH
    env['CUDA_VISIBLE_DEVICES'] = '0'  # All nodes share GPU 0
    env['SEED'] = str(SEED + node_id)  # Different seed per node (diversity)
    env['CHECKPOINT_INTERVAL'] = str(CHECKPOINT_INTERVAL)
    env['MAX_STAGES'] = str(MAX_STAGES)

    # Enable debug logging for state.trees tracking
    env['DEBUG_TREES'] = 'true'

    if HUGGINGFACE_TOKEN:
        env['HUGGINGFACE_ACCESS_TOKEN'] = HUGGINGFACE_TOKEN

    # Send the child's stdout/stderr to files instead of pipes. Nothing in this
    # notebook drains the pipes while training runs, so a PIPE would fill its
    # OS buffer and block the node indefinitely.
    stdout_path = os.path.join(LOG_DIR, f'launcher_node_{node_id}.stdout.log')
    stderr_path = os.path.join(LOG_DIR, f'launcher_node_{node_id}.stderr.log')
    with open(stdout_path, 'w') as stdout_file, open(stderr_path, 'w') as stderr_file:
        process = subprocess.Popen(
            [sys.executable, '-m', 'rgym_exp.runner.swarm_launcher'],
            env=env,
            cwd=REPO_DIR,
            stderr=stderr_file,
            stdout=stdout_file,
        )
    # The child keeps its own copies of the descriptors; the parent's write
    # handles are closed by the `with` block above.
    processes.append(process)
    # Separate read handle so `.read()` on these entries still returns stderr text.
    process_stderr.append(open(stderr_path, 'r', errors='replace'))

    role = "COORDINATOR" if node_id == 0 else "WORKER     "
    print(f"✓ Started node_{node_id} ({role}) - PID: {process.pid:5d}")

    # Stagger startup: coordinator gets 10s, workers get 5s
    delay = 10 if node_id == 0 else 5
    time.sleep(delay)

print()
print("="*60)
print("PROCESS HEALTH CHECK")
print("="*60)
print("Waiting 30 seconds for processes to initialize...")
time.sleep(30)

crashed_nodes = []
for i, p in enumerate(processes):
    returncode = p.poll()
    if returncode is not None:
        # Any exit this early is treated as a crash, whatever the exit code.
        crashed_nodes.append(i)
        print(f"❌ node_{i}: CRASHED (exit code {returncode})")
        # Try to read stderr
        try:
            stderr_output = process_stderr[i].read()
            if stderr_output:
                print(f"   Error output:")
                for line in stderr_output.split('\n')[-10:]:  # Last 10 lines
                    if line.strip():
                        print(f"     {line}")
        except (OSError, ValueError) as exc:
            print(f"   (Could not read error output: {exc})")
    else:
        print(f"✓ node_{i}: RUNNING")

print()

if crashed_nodes:
    print(f"⚠️  WARNING: {len(crashed_nodes)}/{NUM_NODES} nodes crashed!")
    print(f"   Crashed: {', '.join(f'node_{i}' for i in crashed_nodes)}")
    print(f"   Check logs in: {GDRIVE_BASE_PATH}/experiments/{EXPERIMENT_NAME}/logs/")
    print()
    user_input = input("Continue with remaining nodes? (yes/no): ")
    # Accept "y" as well as "yes"; anything else aborts the launch.
    if user_input.strip().lower() not in ('yes', 'y'):
        print("Terminating all processes...")
        for p in processes:
            if p.poll() is None:
                p.terminate()
        raise RuntimeError(f"{len(crashed_nodes)} nodes crashed - see errors above")
else:
    print(f"✅ All {NUM_NODES} nodes launched successfully!")

print()
print(f"✓ Training will run for {estimated_duration} ({MAX_ROUNDS} rounds)")
print(f"✓ Logs location: {GDRIVE_BASE_PATH}/experiments/{EXPERIMENT_NAME}/logs/")
print(f"✓ Checkpoints: Every {CHECKPOINT_INTERVAL} rounds")
print()
if mode_label == "TESTING":
    print("🧪 TESTING MODE: After completion, check Cell 8 for results")
    print("   Then uncomment production line in Cell 2 for full training")
else:
    print("⚠️  Keep this notebook open (browser tab active)")
    print("⚠️  Colab may disconnect after 12-24 hours")
    print("⚠️  Training will continue - use Cell 7.5 to check progress after reconnect")
print()
print("Monitor progress in Cell 7 below...")

## 5. Monitor Training Progress

**This cell:**
- Shows real-time status of all 5 nodes
- Displays GPU memory usage
- Shows current round/stage progress
- Estimates time remaining (ETA)
- Updates every 60 seconds

**To stop training:** Click "Stop" button or press Ctrl+C

**Note:** You can re-run this cell anytime to check status

In [None]:
# Poll the swarm once a minute and redraw a status dashboard: node liveness,
# device-wide GPU memory, round progress with ETA, cumulative reward and the
# recent average reward. Stops when every node has exited; on interrupt it
# terminates (then force-kills) any node still running.
#
# Requires `processes` and `start_time` from the launch cell.
import time
from datetime import datetime

import pandas as pd
import torch
from IPython.display import clear_output

POLL_SECONDS = 60
BAR_LENGTH = 50
BOX_WIDTH = 68


def _boxed(text):
    """Return `text` as one row of the reward box, padded between side borders."""
    return "│" + text.ljust(BOX_WIDTH) + "│"


print("Starting training monitor...")
print("Press 'Stop' button or Ctrl+C to interrupt\n")

try:
    while True:
        clear_output(wait=True)

        # Check process status
        running = sum(1 for p in processes if p.poll() is None)
        completed = len(processes) - running

        current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        elapsed_hours = (time.time() - start_time) / 3600

        print("="*70)
        print(f" SAPO Training Monitor - {EXPERIMENT_NAME}")
        print(f" Time: {current_time} | Elapsed: {elapsed_hours:.1f}h")
        print("="*70)
        print()

        # Node status
        print(f"Nodes: {running}/{NUM_NODES} running, {completed} completed")
        print()

        # GPU memory. mem_get_info reports device-wide free/total bytes, so it
        # includes memory held by the node subprocesses; the per-process
        # allocator counters would only describe this notebook kernel.
        if torch.cuda.is_available():
            free_bytes, total_bytes = torch.cuda.mem_get_info(0)
            reserved = (total_bytes - free_bytes) / 1e9
            total = total_bytes / 1e9
            utilization = (reserved / total) * 100

            print(f"GPU: {reserved:.1f} / {total:.1f} GB ({utilization:.1f}%)")

            if utilization > 90:
                print(f"  ⚠️  WARNING: High memory usage!")
            elif utilization > 75:
                print(f"  ⚠️  Memory usage elevated")
            print()

        # Training progress (best effort: any failure is reported, not raised)
        try:
            # Imported lazily so an import problem is shown as a status message.
            from rgym_exp.utils.experiment_manager import get_experiment_status
            status = get_experiment_status(GDRIVE_BASE_PATH, EXPERIMENT_NAME)

            if status:
                current_round = status.get('current_round', 0)
                progress_pct = (current_round / MAX_ROUNDS) * 100 if MAX_ROUNDS > 0 else 0.0

                print(f"Progress:")
                print(f"  Round: {current_round:4d} / {MAX_ROUNDS} ({progress_pct:5.1f}%)")

                # Progress bar (clamped so an out-of-range round cannot distort it)
                filled = max(0, min(BAR_LENGTH, int(BAR_LENGTH * progress_pct / 100)))
                bar = '█' * filled + '░' * (BAR_LENGTH - filled)
                print(f"  [{bar}]")

                # ETA (only once enough rounds have run for a stable rate)
                if current_round > 10:
                    hours_per_round = elapsed_hours / current_round
                    remaining_rounds = MAX_ROUNDS - current_round
                    eta_hours = remaining_rounds * hours_per_round
                    print(f"  ETA: {eta_hours:.1f} hours (~{eta_hours/24:.1f} days)")

                print()

                # ==========================================
                # TOTAL CUMULATIVE REWARD (PRIMARY METRIC)
                # ==========================================
                if 'total_cumulative_reward' in status:
                    cumulative_rewards = status.get('cumulative_rewards', {})
                    total_cumulative = status['total_cumulative_reward']

                    print("┌" + "─"*BOX_WIDTH + "┐")
                    print("│" + " TOTAL CUMULATIVE REWARD (SAPO Paper Metric)".center(BOX_WIDTH) + "│")
                    print("├" + "─"*BOX_WIDTH + "┤")
                    print(_boxed(f"  TOTAL: {total_cumulative:8.2f}"))

                    if cumulative_rewards:
                        print(_boxed(""))
                        print(_boxed("  Per-Node Breakdown:"))
                        for node_id in sorted(cumulative_rewards.keys()):
                            node_reward = cumulative_rewards[node_id]
                            pct = (node_reward / total_cumulative * 100) if total_cumulative > 0 else 0
                            print(_boxed(f"    {node_id:10s}: {node_reward:7.2f} ({pct:5.1f}%)"))

                    # Average per round
                    avg_per_round = total_cumulative / current_round if current_round > 0 else 0
                    print(_boxed(""))
                    print(_boxed(f"  Avg per round: {avg_per_round:6.4f}"))

                    # Comparison to paper benchmarks
                    print(_boxed(""))
                    print(_boxed("  Paper Benchmarks (Qwen2.5, 2000 rounds):"))
                    print(_boxed("    Baseline (8/0):    562"))
                    print(_boxed("    Config 2 (4/4):  1,093 (+94%) ⭐"))
                    print("└" + "─"*BOX_WIDTH + "┘")
                    print()

                # Recent performance
                try:
                    from rgym_exp.utils.experiment_manager import get_experiment_metrics
                    df = get_experiment_metrics(GDRIVE_BASE_PATH, EXPERIMENT_NAME)

                    if not df.empty:
                        recent_reward = df.tail(10)['my_reward'].mean()
                        print(f"Recent Performance:")
                        print(f"  Avg reward: {recent_reward:6.4f} (last 10 rounds)")
                        print()
                except Exception as metrics_error:
                    # Metrics are optional; say why they are missing instead of hiding it.
                    print(f"Recent Performance: unavailable ({metrics_error})")
                    print()

        except Exception as e:
            print(f"Progress: Unable to load status ({e})")
            print()

        # Instructions
        print("-"*70)
        print("Press 'Stop' button or Ctrl+C to halt training")
        print(f"Next update in {POLL_SECONDS} seconds...")

        # Exit if all completed
        if running == 0:
            # poll() above has populated returncode for every exited process.
            failed = [(i, p.returncode) for i, p in enumerate(processes) if p.returncode != 0]
            print()
            print("="*70)
            if failed:
                details = ', '.join(f"node_{i} (exit code {code})" for i, code in failed)
                print(f"❌ All nodes exited, {len(failed)} with errors: {details}")
            else:
                print("✅ All nodes completed successfully!")
            print("="*70)
            break

        time.sleep(POLL_SECONDS)

except KeyboardInterrupt:
    print("\n" + "="*70)
    print("⚠️  Training interrupted by user")
    print("="*70)
    print("\nTerminating all node processes...")

    # First ask nicely ...
    for i, p in enumerate(processes):
        if p.poll() is None:
            print(f"  Stopping node_{i}... (PID: {p.pid})")
            p.terminate()

    time.sleep(5)

    # ... then force-kill anything that ignored the termination request.
    for i, p in enumerate(processes):
        if p.poll() is None:
            print(f"  Force killing node_{i}... (PID: {p.pid})")
            p.kill()

    print("\n✓ All processes terminated")
    print("\n💾 Note: Training state is checkpointed.")
    print("   Re-run this notebook to resume from last checkpoint.")