# SAPO Config 3: 5 GPT-2 Nodes (I=1, J=3) on Single A100 80GB

This notebook runs **SAPO Config 3** with **heavy swarm collaboration** (75% external rollouts).

**Configuration:**
- **I=1** (local rollouts per round)
- **J=3** (external rollouts per round - 75% external)
- **G=8** (completions per question)
- **Model**: GPT-2 (124M params)
- **Hardware**: 5 nodes (1 coordinator + 4 workers) on 1× A100 80GB
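The "75% external" figure is the share J/(I+J) of each round's rollouts that come from other nodes. The paper's Config 3 used I=2, J=6 (per the comparison code in Section 8), so this notebook keeps the same share with half the total. Here is a minimal sketch of that arithmetic. The helper name `external_share` is illustrative.

```python
from fractions import Fraction

# (I, J) pairs: this notebook's Config 3 and the paper's Config 3
configs = {"ours (Config 3)": (1, 3), "paper (Config 3)": (2, 6)}

def external_share(i_local: int, j_external: int) -> Fraction:
    """Fraction of rollouts per round that come from other nodes: J / (I + J)."""
    return Fraction(j_external, i_local + j_external)

for name, (i, j) in configs.items():
    share = external_share(i, j)
    print(f"{name}: I={i}, J={j} -> {float(share):.0%} external")
```

Because `Fraction` is exact, the two configurations compare equal even though their absolute I and J differ.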

**Purpose:** Measure the effect of heavy swarm collaboration (3 of every 4 rollouts per round come from other nodes) on GPT-2.

**Expected Results:**
- Cumulative reward: **450-650**
- Improvement vs GPT-2 baseline (I=8, J=0): **+100-130%**
- Paper (Qwen2.5-0.5B), Config 3 (I=2, J=6, same 75% external share): 946 (+68% vs the paper's baseline)

**Memory Usage:** ~33 GB peak VRAM (5 nodes × ~6.5 GB each; fits comfortably on an A100 80GB)
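The ~33 GB estimate comes from the 6.5 GB-per-node figure used by the GPU check in Section 3, multiplied by 5 nodes. The sketch below works through that budget and uses 80 GB as the nominal capacity. PyTorch may report a slightly different total. `vram_budget` is an illustrative helper.

```python
NUM_NODES = 5
GB_PER_NODE = 6.5  # per-node estimate, same value as the GPU check

def vram_budget(num_nodes: int, gb_per_node: float, total_gb: float) -> tuple:
    """Return (required_gb, headroom_gb) for num_nodes processes sharing one GPU."""
    required = num_nodes * gb_per_node
    return required, total_gb - required

required, headroom = vram_budget(NUM_NODES, GB_PER_NODE, total_gb=80.0)
print(f"required ~{required:.1f} GB, headroom ~{headroom:.1f} GB")
```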

**Timeline:** ~21 hours (2000 rounds)
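Spread over 2000 rounds, ~21 hours works out to about 38 seconds per round. The monitor in Section 7 estimates the remaining time by linear extrapolation from the average round time so far. Below is a standalone sketch of both calculations. The function names are illustrative.

```python
def seconds_per_round(total_hours: float, rounds: int) -> float:
    """Average wall-clock seconds per round implied by a total runtime."""
    return total_hours * 3600 / rounds

def eta_hours(elapsed_hours: float, current_round: int, max_rounds: int) -> float:
    """Linear ETA: assume remaining rounds take as long, on average, as completed ones."""
    if current_round <= 0:
        raise ValueError("need at least one completed round")
    return (max_rounds - current_round) * elapsed_hours / current_round

print(f"{seconds_per_round(21, 2000):.1f} s/round")
print(f"ETA after 500 rounds in 5.25 h: {eta_hours(5.25, 500, 2000):.2f} h")
```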

**Scientific Justification:** See `EXPERIMENTAL_DESIGN_JUSTIFICATION.md`

**Paper Reference:** arXiv:2509.08721 - SAPO (Gensyn AI Team, 2025)

## 1. Configuration

**This notebook is pre-configured for Config 3 (I=1, J=3).**

Just run all cells - no changes needed!

In [None]:
# SAPO Config 3 Experiment Configuration (I=1, J=3)
# This notebook runs 5 nodes with HEAVY swarm collaboration (75% external)

# ============================================
# PRE-CONFIGURED FOR CONFIG 3
# ============================================
EXPERIMENT_NAME = 'sapo_gpt2_config3_1loc3ext'
NUM_TRAIN_SAMPLES = 1        # I: Local rollouts per round
NUM_TRANSPLANT_TREES = 3     # J: External rollouts (75% external)

# ============================================
# FIXED SETTINGS (same for all experiments)
# ============================================
NUM_NODES = 5                # Run 5 nodes (1 coordinator + 4 workers)
MODEL_NAME = 'gpt2'          # GPT-2 (124M params, fits in memory)
NUM_GENERATIONS = 8          # G: Completions per question (as in the paper)
MAX_ROUNDS = 2000            # Train for 2000 rounds (as in the paper)
SEED = 42                    # Base seed; each node uses SEED + node index

# Rollout Sharing Configuration
ROLLOUT_PUBLISH_FREQUENCY = 'stage'  # When to share rollouts
ROLLOUT_CLEANUP_ENABLED = True       # Enable cleanup to save space
ROLLOUT_KEEP_LAST_N_ROUNDS = 20      # Keep recent rollouts only
ROLLOUT_ARCHIVE_OLD = False          # Don't archive (saves space)

# Optional: HuggingFace Token
HUGGINGFACE_TOKEN = None  # Set to your token or keep None

print("="*60)
print(f"SAPO Config 3 Experiment")
print("="*60)
print(f"✓ Nodes: {NUM_NODES} (1 coordinator + 4 workers on single A100 80GB)")
print(f"✓ Model: {MODEL_NAME}")
print(f"✓ Config: I={NUM_TRAIN_SAMPLES}, J={NUM_TRANSPLANT_TREES}, G={NUM_GENERATIONS}")
print(f"✓ Experiment: {EXPERIMENT_NAME}")
print(f"✓ Max Rounds: {MAX_ROUNDS}")
print()
print(f"Expected VRAM: ~33 GB (80 GB available)")
print(f"Expected Time: ~21 hours")
print()
print("📊 Config 3 (heavy swarm collaboration - 75% external)")
print("   Expected reward: 450-650 (+100-130% vs baseline)")
print("   Paper (Qwen2.5): 946 (+68%)")

## 2. Mount Google Drive

In [None]:
from google.colab import drive
import os

# Mount Google Drive
drive.mount('/content/drive')

# Set base path (MUST BE SAME ACROSS ALL NODES)
GDRIVE_BASE_PATH = '/content/drive/MyDrive/rl-swarm'
os.makedirs(GDRIVE_BASE_PATH, exist_ok=True)

print(f"✓ Google Drive mounted at: {GDRIVE_BASE_PATH}")

## 3. System Setup & GPU Verification

In [None]:
import torch

print("="*60)
print("GPU Verification")
print("="*60)

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    total_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    
    print(f"✓ GPU: {gpu_name}")
    print(f"✓ Total VRAM: {total_memory:.1f} GB")
    print()
    
    # Check if we have enough memory
    required_memory = NUM_NODES * 6.5  # 6.5 GB per GPT-2 node
    print(f"Memory Requirements:")
    print(f"  Required: {required_memory:.1f} GB ({NUM_NODES} nodes × 6.5 GB)")
    print(f"  Available: {total_memory:.1f} GB")
    print(f"  Margin: {total_memory - required_memory:.1f} GB")
    print()
    
    if total_memory < required_memory:
        print("⚠️  WARNING: Insufficient VRAM!")
        print(f"   Need at least {required_memory:.0f} GB, but have {total_memory:.1f} GB")
        print(f"   Consider reducing NUM_NODES to {int(total_memory / 6.5)}")
        raise RuntimeError("Insufficient GPU memory")
    elif total_memory < 75:
        print("⚠️  WARNING: Tight fit! Expected A100 80GB.")
        print(f"   Have {total_memory:.1f} GB. May still work, but monitor memory closely.")
    else:
        print(f"✅ Sufficient VRAM for {NUM_NODES} GPT-2 nodes")
else:
    raise RuntimeError("No GPU detected! Select A100 GPU runtime: Runtime > Change runtime type > A100 GPU")

## 4. Clone Repository & Install Dependencies

In [None]:
%cd /content

# Remove existing directory if it exists
if os.path.exists('/content/rl-swarm'):
    print("Removing existing repository...")
    !rm -rf /content/rl-swarm

# Clone fresh copy
print("Cloning repository...")
!git clone https://github.com/Elrashid/rl-swarm.git /content/rl-swarm

# Change to repo directory
%cd /content/rl-swarm

# Verify clone worked
if not os.path.exists('requirements.txt'):
    raise FileNotFoundError("Repository clone failed - requirements.txt not found")

print("✓ Repository cloned successfully")
print()

# Install dependencies
print("Installing dependencies (this may take 3-5 minutes)...")
print("Note: Warnings about protobuf versions can be ignored")
print()

# Install main dependencies (without -q to show errors)
!pip install -r requirements.txt

# Install GenRL explicitly
!pip install gensyn-genrl==0.1.9

# Fix protobuf version explicitly to avoid conflicts
!pip install 'protobuf>=4.25.0,<5.0'

# Verify reasoning-gym was installed
try:
    import reasoning_gym
    print()
    print("✓ Dependencies installed successfully")
    print("✓ reasoning-gym verified")
except ImportError as e:
    print()
    print("❌ ERROR: reasoning-gym failed to install!")
    print("   Please report this issue with the error above")
    raise
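You can confirm the pinned versions after installation without importing the packages. The sketch below uses the standard library's `importlib.metadata`. The small parser is an assumption: it only handles plain numeric versions and approximates `==0.1.9` as the half-open range [0.1.9, 0.1.10). For full PEP 440 semantics, use `packaging.version`.

```python
import itertools
from importlib.metadata import PackageNotFoundError, version

def parse_version(v: str) -> tuple:
    """Leading numeric components, padded to 3: '4.25.3rc1' -> (4, 25, 3), '4.25' -> (4, 25, 0)."""
    parts = []
    for piece in v.split("."):
        digits = "".join(itertools.takewhile(str.isdigit, piece))
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts + [0] * (3 - len(parts)))

def in_range(v: str, lo: tuple, hi: tuple) -> bool:
    """True if lo <= v < hi, matching a '>=lo,<hi' pin."""
    return lo <= parse_version(v) < hi

# Pins from the install cell: protobuf>=4.25.0,<5.0 and gensyn-genrl==0.1.9
pins = [("protobuf", (4, 25, 0), (5, 0, 0)), ("gensyn-genrl", (0, 1, 9), (0, 1, 10))]
for pkg, lo, hi in pins:
    try:
        installed = version(pkg)
        print(f"{pkg} {installed}: {'OK' if in_range(installed, lo, hi) else 'OUT OF RANGE'}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed")
```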

## 5. Initialize Experiment

**Note:** Only the coordinator (node_0) creates the experiment structure. Workers will join it.
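The keys in `config_overrides` use dotted paths such as `training.max_round`. This notebook does not show how `init_experiment` applies them. A common convention treats each dot as one level of nesting. The hypothetical helper below illustrates that interpretation; it is not the repository's code.

```python
import copy

def apply_dotted_overrides(config: dict, overrides: dict) -> dict:
    """Return a copy of config with 'a.b.c'-style keys written into nested dicts (hypothetical helper)."""
    result = copy.deepcopy(config)  # leave the caller's config unchanged
    for dotted_key, value in overrides.items():
        *parents, leaf = dotted_key.split(".")
        node = result
        for key in parents:
            node = node.setdefault(key, {})  # create intermediate sections as needed
        node[leaf] = value
    return result

base = {"training": {"max_round": 100, "seed": 0}}
merged = apply_dotted_overrides(base, {"training.max_round": 2000, "training.num_transplant_trees": 3})
print(merged)
```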

In [None]:
from rgym_exp.utils.experiment_manager import init_experiment

# Initialize the experiment structure in Google Drive before any node is launched
config_overrides = {
    'training.max_round': MAX_ROUNDS,
    'training.num_generations': NUM_GENERATIONS,
    'training.num_transplant_trees': NUM_TRANSPLANT_TREES,
    'training.num_train_samples': NUM_TRAIN_SAMPLES,
    'training.seed': SEED,
}

init_experiment(
    gdrive_base_path=GDRIVE_BASE_PATH,
    experiment_name=EXPERIMENT_NAME,
    config_overrides=config_overrides
)

print(f"✓ Experiment initialized: {EXPERIMENT_NAME}")
print(f"  Path: {GDRIVE_BASE_PATH}/experiments/{EXPERIMENT_NAME}")
print(f"  Config: I={NUM_TRAIN_SAMPLES}, J={NUM_TRANSPLANT_TREES}, G={NUM_GENERATIONS}")

## 6. Launch 5-Node Swarm (KEY CELL)

**This cell:**
1. Spawns 5 separate Python processes
2. Each process runs `swarm_launcher.py` with unique NODE_ID
3. All processes share GPU 0 (CUDA_VISIBLE_DEVICES=0)
4. Coordinator (node_0) manages round progression
5. Workers (node_1-4) follow coordinator

**Logs:** Each node writes to Google Drive at `{GDRIVE_BASE_PATH}/experiments/{EXPERIMENT_NAME}/logs/`

**Monitor:** Use next cell (Cell 7) to track progress in real-time
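The per-node differences in the launch cell come down to a few environment variables: the ID, the role (node 0 coordinates), and a per-node seed offset. The sketch below factors them into a pure function, which makes it easy to check role and seed assignment before spawning processes. The name `build_node_env` is hypothetical, and the sketch covers only a subset of the variables the cell sets.

```python
def build_node_env(node_id: int, base_env: dict, *, model_name: str, base_seed: int, gpu: str = "0") -> dict:
    """Environment for one node process (hypothetical helper; subset of the launch cell's variables)."""
    env = dict(base_env)                    # copy so the parent environment is untouched
    env["NODE_ID"] = f"node_{node_id}"
    env["NODE_ROLE"] = "coordinator" if node_id == 0 else "worker"
    env["MODEL_NAME"] = model_name
    env["CUDA_VISIBLE_DEVICES"] = gpu       # every node shares the same GPU
    env["SEED"] = str(base_seed + node_id)  # distinct seed per node
    return env

envs = [build_node_env(i, {}, model_name="gpt2", base_seed=42) for i in range(5)]
for e in envs:
    print(e["NODE_ID"], e["NODE_ROLE"], e["SEED"])
```

Passing `os.environ.copy()` as `base_env` reproduces what the launch cell does for these variables.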

In [None]:
import subprocess
import os
import sys
import time
from datetime import datetime

print("="*60)
print(f"Launching {NUM_NODES}-Node SAPO Swarm")
print("="*60)
print(f"Experiment: {EXPERIMENT_NAME}")
print(f"Model: {MODEL_NAME}")
print(f"Config: I={NUM_TRAIN_SAMPLES}, J={NUM_TRANSPLANT_TREES}, G={NUM_GENERATIONS}")
print(f"Hardware: All {NUM_NODES} nodes on single GPU (A100 80GB)")
print("="*60)
print()

# Directory for per-node stdout/stderr logs on Google Drive
LOG_DIR = f"{GDRIVE_BASE_PATH}/experiments/{EXPERIMENT_NAME}/logs"
os.makedirs(LOG_DIR, exist_ok=True)

processes = []
log_files = []            # Closed by the monitor cell (Cell 7) when it exits
start_time = time.time()  # Reference point for the ETA in Cell 7

for node_id in range(NUM_NODES):
    # Environment variables for this node
    env = os.environ.copy()
    env['NODE_ID'] = f'node_{node_id}'
    env['NODE_ROLE'] = 'coordinator' if node_id == 0 else 'worker'
    env['MODEL_NAME'] = MODEL_NAME
    env['NUM_TRAIN_SAMPLES'] = str(NUM_TRAIN_SAMPLES)
    env['NUM_TRANSPLANT_TREES'] = str(NUM_TRANSPLANT_TREES)
    env['NUM_GENERATIONS'] = str(NUM_GENERATIONS)
    env['MAX_ROUNDS'] = str(MAX_ROUNDS)
    env['EXPERIMENT_NAME'] = EXPERIMENT_NAME
    env['GDRIVE_PATH'] = GDRIVE_BASE_PATH
    env['CUDA_VISIBLE_DEVICES'] = '0'  # All nodes share GPU 0
    env['SEED'] = str(SEED + node_id)  # Different seed per node (diversity)
    env['ROLLOUT_PUBLISH_FREQUENCY'] = ROLLOUT_PUBLISH_FREQUENCY
    env['ROLLOUT_CLEANUP_ENABLED'] = str(ROLLOUT_CLEANUP_ENABLED)
    env['ROLLOUT_KEEP_LAST_N_ROUNDS'] = str(ROLLOUT_KEEP_LAST_N_ROUNDS)
    env['ROLLOUT_ARCHIVE_OLD'] = str(ROLLOUT_ARCHIVE_OLD)
    env['PYTHONUNBUFFERED'] = '1'  # Write node output to the log file immediately
    
    if HUGGINGFACE_TOKEN:
        env['HUGGINGFACE_ACCESS_TOKEN'] = HUGGINGFACE_TOKEN
    
    # Capture this node's stdout and stderr in its own log file
    log_file = open(os.path.join(LOG_DIR, f'node_{node_id}_stdout.log'), 'a')
    log_files.append(log_file)
    
    # Launch process
    process = subprocess.Popen(
        [sys.executable, '-m', 'rgym_exp.runner.swarm_launcher'],
        env=env,
        cwd='/content/rl-swarm',
        stdout=log_file,
        stderr=subprocess.STDOUT,
    )
    processes.append(process)
    
    role = "COORDINATOR" if node_id == 0 else "WORKER     "
    print(f"✓ Started node_{node_id} ({role}) - PID: {process.pid:5d}")
    
    # Stagger startup to avoid race conditions (no wait needed after the last node)
    if node_id < NUM_NODES - 1:
        time.sleep(10)

# Popen only confirms that a process started; flag any node that has already exited
exited_nodes = [f'node_{i}' for i, p in enumerate(processes) if p.poll() is not None]

print()
if exited_nodes:
    print(f"⚠️  Already exited: {', '.join(exited_nodes)} (check logs in {LOG_DIR})")
else:
    print(f"✅ All {NUM_NODES} nodes launched")
print(f"✓ Training will run for approximately 21 hours ({MAX_ROUNDS} rounds)")
print(f"✓ Logs location: {LOG_DIR}/")
print()
print("⚠️  Keep this notebook open (browser tab active)")
print("⚠️  Colab may disconnect after 12-24 hours; the nodes run only while the Colab runtime is alive")
print()
print("Monitor progress in Cell 7 below...")

## 7. Monitor Training Progress

**This cell:**
- Shows real-time status of all 5 nodes
- Displays GPU memory usage
- Shows current round/stage progress
- Estimates time remaining (ETA)
- Updates every 60 seconds

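**To stop training:** Click the "Stop" button or press Ctrl+C. Interrupting this cell terminates every node process, not just the monitor.

**Note:** To check status without stopping training, leave this cell running. It exits on its own once all node processes have exited.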
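The "Recent avg" metric is meant to cover the last 10 rounds. If the metrics table holds one row per node per round, as the `groupby('node_id')` and `groupby('round')` calls in Section 8 suggest, the last 10 rows cover only the last few rounds. The sketch below illustrates the difference on toy data. It assumes the `round` and `my_reward` columns used in this notebook, and `recent_round_mean` is an illustrative helper.

```python
import pandas as pd

# Toy metrics: 3 nodes x 20 rounds; reward equals the round number so the result is easy to check
rows = [{"node_id": f"node_{n}", "round": r, "my_reward": float(r)}
        for r in range(20) for n in range(3)]
df = pd.DataFrame(rows)

def recent_round_mean(df: pd.DataFrame, n_rounds: int = 10) -> float:
    """Mean of the per-round average reward over the last n_rounds distinct rounds."""
    per_round = df.groupby("round")["my_reward"].mean()  # index is sorted by round
    return float(per_round.tail(n_rounds).mean())

print(recent_round_mean(df))            # rounds 10..19 -> 14.5
print(df.tail(10)["my_reward"].mean())  # last 10 rows only span rounds 16..19
```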

In [None]:
import time
from datetime import datetime
from IPython.display import clear_output

print("Starting training monitor...")
print("Press 'Stop' button or Ctrl+C to interrupt\n")

monitor_start_time = time.time()

try:
    while True:
        clear_output(wait=True)
        
        # Check process status
        running = sum(1 for p in processes if p.poll() is None)
        completed = NUM_NODES - running
        
        current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        elapsed_hours = (time.time() - start_time) / 3600
        
        print("="*70)
        print(f" SAPO Training Monitor - {EXPERIMENT_NAME}")
        print(f" Time: {current_time} | Elapsed: {elapsed_hours:.1f}h")
        print("="*70)
        print()
        
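        # Node status (an exited node may have finished or crashed)
        failed = sum(1 for p in processes if p.poll() is not None and p.returncode != 0)
        print(f"Nodes:")
        print(f"  Running:   {running}/{NUM_NODES}")
        print(f"  Exited:    {completed}/{NUM_NODES} ({failed} with non-zero exit code)")
        print()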
        
        # GPU memory: mem_get_info reports device-wide usage, including the node
        # subprocesses (memory_allocated/memory_reserved only cover this notebook process)
        if torch.cuda.is_available():
            free_bytes, total_bytes = torch.cuda.mem_get_info(0)
            total = total_bytes / 1e9
            used = (total_bytes - free_bytes) / 1e9
            
            utilization = (used / total) * 100
            
            print(f"GPU Memory ({torch.cuda.get_device_name(0)}, all processes):")
            print(f"  Used:      {used:5.1f} GB / {total:.1f} GB ({utilization:.1f}%)")
            print(f"  Free:      {free_bytes / 1e9:5.1f} GB")
            
            # Warning if memory is high
            if utilization > 90:
                print(f"  ⚠️  WARNING: High memory usage! May OOM soon.")
            elif utilization > 75:
                print(f"  ⚠️  Memory usage elevated. Monitoring closely.")
            
            print()
        
        # Training progress
        try:
            from rgym_exp.utils.experiment_manager import get_experiment_status
            status = get_experiment_status(GDRIVE_BASE_PATH, EXPERIMENT_NAME)
            
            if status:
                current_round = status.get('current_round', 0)
                progress_pct = (current_round / MAX_ROUNDS) * 100
                
                print(f"Training Progress:")
                print(f"  Round:     {current_round:4d} / {MAX_ROUNDS} ({progress_pct:5.1f}%)")
                print(f"  Stage:     {status.get('current_stage', 0)}")
                print(f"  Active peers: {status.get('active_peers', 0)}")
                
                # ETA calculation
                if current_round > 10:  # Wait for stable estimate
                    hours_per_round = elapsed_hours / current_round
                    remaining_rounds = MAX_ROUNDS - current_round
                    eta_hours = remaining_rounds * hours_per_round
                    
                    print(f"  ETA:       {eta_hours:.1f} hours (~{eta_hours/24:.1f} days)")
                    
                    # Progress bar
                    bar_length = 40
                    filled = int(bar_length * progress_pct / 100)
                    bar = '█' * filled + '░' * (bar_length - filled)
                    print(f"  [{bar}]")
                
                print()
                
                # Recent performance
                try:
                    from rgym_exp.utils.experiment_manager import get_experiment_metrics
                    df = get_experiment_metrics(GDRIVE_BASE_PATH, EXPERIMENT_NAME)
                    
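                    if not df.empty:
                        cumulative_reward = df['my_reward'].sum()
                        # Mean per-round reward over the last 10 distinct rounds
                        # (rows are per node, so tail(10) rows would cover fewer than 10 rounds)
                        recent_reward = df.groupby('round')['my_reward'].mean().tail(10).mean()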
                        
                        print(f"Rewards:")
                        print(f"  Cumulative: {cumulative_reward:6.2f}")
                        print(f"  Recent avg: {recent_reward:6.2f} (last 10 rounds)")
                        print()
                except Exception:
                    pass  # Metrics not available yet
                    
        except Exception as e:
            print(f"Progress: Unable to load status ({e})")
            print()
        
        # Instructions
        print("-"*70)
        print("Press 'Stop' button or Ctrl+C to halt training")
        print(f"Next update in 60 seconds...")
        
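        # Exit once every node process has exited, reporting any failures
        if running == 0:
            failed_nodes = [f'node_{i}' for i, p in enumerate(processes) if p.returncode != 0]
            print()
            print("="*70)
            if failed_nodes:
                print(f"⚠️  All nodes exited; non-zero exit code from: {', '.join(failed_nodes)}")
            else:
                print("✅ All nodes completed successfully!")
            print("="*70)
            break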
        
        time.sleep(60)  # Update every minute

except KeyboardInterrupt:
    print("\n" + "="*70)
    print("⚠️  Training interrupted by user")
    print("="*70)
    print("\nTerminating all node processes...")
    
    for i, p in enumerate(processes):
        if p.poll() is None:
            print(f"  Stopping node_{i}... (PID: {p.pid})")
            p.terminate()
    
    time.sleep(5)
    
    # Force kill if still running
    for i, p in enumerate(processes):
        if p.poll() is None:
            print(f"  Force killing node_{i}... (PID: {p.pid})")
            p.kill()
    
    print("\n✓ All processes terminated")
    print("\n💾 Note: Training state is checkpointed.")
    print("   Re-run this notebook to resume from last checkpoint.")

finally:
    # Close the per-node log files opened by the launch cell, if it defined any
    for log_file in globals().get('log_files', []):
        try:
            log_file.close()
        except Exception:
            pass

## 8. View Results & Analysis

**After training completes, run this cell to:**
- Load all metrics
- Calculate cumulative rewards per node
- Compare to paper's results
- Generate plots
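The percentages quoted for the paper are relative to its baseline reward of 562. The "Improvement vs Our Baseline" block applies the same formula to the GPT-2 numbers. Here is a quick check that the quoted paper figures follow from the rewards used in the comparison code. `improvement_pct` is an illustrative helper.

```python
PAPER_BASELINE = 562  # paper baseline cumulative reward (Qwen2.5-0.5B)
PAPER_CONFIGS = {"Config 1 (6/2)": 854, "Config 2 (4/4)": 1093, "Config 3 (2/6)": 946}

def improvement_pct(reward: float, baseline: float) -> float:
    """Relative improvement over a baseline, in percent."""
    return (reward - baseline) / baseline * 100

for name, reward in PAPER_CONFIGS.items():
    print(f"{name}: {improvement_pct(reward, PAPER_BASELINE):+.1f}%")
```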

In [None]:
from rgym_exp.utils.experiment_manager import get_experiment_metrics
import pandas as pd
import matplotlib.pyplot as plt

print("="*70)
print(f"Results: {EXPERIMENT_NAME}")
print("="*70)
print(f"Config: I={NUM_TRAIN_SAMPLES}, J={NUM_TRANSPLANT_TREES}, G={NUM_GENERATIONS}")
print(f"Model: {MODEL_NAME}")
print(f"Nodes: {NUM_NODES}")
print()

# Load metrics
df = get_experiment_metrics(GDRIVE_BASE_PATH, EXPERIMENT_NAME)

if not df.empty:
    # Calculate cumulative reward per node
    node_rewards = df.groupby('node_id')['my_reward'].sum().sort_values(ascending=False)
    total_reward = node_rewards.sum()
    
    print("Cumulative Rewards by Node:")
    for node_id, reward in node_rewards.items():
        print(f"  {node_id:10s}: {reward:7.2f}")
    print(f"  {'TOTAL':10s}: {total_reward:7.2f}")
    print()
    
    # Compare to paper's results
    print("Comparison to Paper (Qwen2.5-0.5B):")
    print("-"*70)
    
    if NUM_TRANSPLANT_TREES == 0:
        paper_reward = 562
        config_name = "Baseline (8/0)"
        paper_improvement = "baseline"
    else:
        # Match on the external share J/(I+J) rather than on absolute I/J values:
        # this notebook scales the paper's configs down (paper Config 3 is 2/6, ours is 1/3)
        external_pct = 100 * NUM_TRANSPLANT_TREES / (NUM_TRAIN_SAMPLES + NUM_TRANSPLANT_TREES)
        paper_configs = {
            25: (854, "Config 1 (25% external, paper 6/2)", "+52%"),
            50: (1093, "Config 2 (50% external, paper 4/4) **BEST**", "+94%"),
            75: (946, "Config 3 (75% external, paper 2/6)", "+68%"),
        }
        paper_reward, config_name, paper_improvement = paper_configs.get(
            external_pct,
            (None, f"Custom ({NUM_TRAIN_SAMPLES}/{NUM_TRANSPLANT_TREES})", "N/A"),
        )
    
    print(f"  Configuration: {config_name}")
    if paper_reward:
        print(f"  Paper (Qwen2.5):  {paper_reward:7.2f} ({paper_improvement})")
        print(f"  Ours (GPT-2):     {total_reward:7.2f} (~{total_reward/paper_reward*100:.1f}% of paper)")
    else:
        print(f"  Ours (GPT-2):     {total_reward:7.2f}")
    print()
    
    # If the GPT-2 baseline (I=8, J=0) has been run, compute the improvement over it
    try:
        baseline_df = get_experiment_metrics(GDRIVE_BASE_PATH, 'sapo_gpt2_baseline_8loc0ext')
        if not baseline_df.empty and NUM_TRANSPLANT_TREES > 0:
            baseline_reward = baseline_df['my_reward'].sum()
            improvement = ((total_reward - baseline_reward) / baseline_reward) * 100
            print(f"Improvement vs Our Baseline:")
            print(f"  Baseline (GPT-2): {baseline_reward:7.2f}")
            print(f"  This config:      {total_reward:7.2f}")
            print(f"  Improvement:      {improvement:+7.1f}%")
            print()
            
            if improvement > 94:
                print("✅ Consistent with the hypothesis: GPT-2 improvement exceeds the paper's best (+94%)")
                print("   Weaker models benefit MORE from swarm (as predicted)")
            elif improvement > 50:
                print("✅ Strong swarm effect demonstrated (+{:.1f}%)".format(improvement))
            else:
                print("⚠️  Lower improvement than expected")
            print()
    except Exception:
        pass  # Baseline not run yet (or its metrics could not be loaded)
    
    # Plot rewards over time
    print("Generating plots...")
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
    
    # Plot 1: Cumulative reward per node
    for node_id in df['node_id'].unique():
        node_df = df[df['node_id'] == node_id].sort_values('round')
        ax1.plot(node_df['round'], node_df['my_reward'].cumsum(), 
                 label=node_id, alpha=0.7)
    
    ax1.set_xlabel('Round', fontsize=12)
    ax1.set_ylabel('Cumulative Reward', fontsize=12)
    ax1.set_title(f'{config_name} - Cumulative Rewards Over Time', fontsize=14, fontweight='bold')
    ax1.legend(loc='upper left', fontsize=8)
    ax1.grid(True, alpha=0.3)
    
    # Plot 2: Average reward per round (smoothed)
    round_avg = df.groupby('round')['my_reward'].mean()
    # Apply moving average
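    window_size = 100
    # min_periods=1 keeps the curve defined near the edges and for runs shorter than the window
    smoothed = round_avg.rolling(window=window_size, center=True, min_periods=1).mean()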
    
    ax2.plot(round_avg.index, round_avg.values, alpha=0.3, color='gray', label='Raw')
    ax2.plot(smoothed.index, smoothed.values, linewidth=2, color='blue', label=f'Smoothed ({window_size}-round MA)')
    
    ax2.set_xlabel('Round', fontsize=12)
    ax2.set_ylabel('Average Reward', fontsize=12)
    ax2.set_title(f'{config_name} - Average Reward per Round', fontsize=14, fontweight='bold')
    ax2.legend(fontsize=10)
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    
    # Save plot
    plot_path = f'{GDRIVE_BASE_PATH}/results_{EXPERIMENT_NAME}.png'
    plt.savefig(plot_path, dpi=150, bbox_inches='tight')
    print(f"✓ Plot saved to: {plot_path}")
    
    plt.show()
    
else:
    print("❌ No metrics available yet")
    print("   Training may not have started or metrics file is missing")