# 🧪 TEST MODE - Baseline Configuration (1-2 minutes)

This notebook runs a **QUICK TEST** of the **Baseline SAPO configuration** (100% local rollouts, no sharing).

**Configuration:**
- **I=4** (local rollouts per round)
- **J=0** (external rollouts per round - NO SHARING)
- **G=4** (completions per question - REDUCED for speed)
- **Model**: GPT-2 (124M params)
- **Hardware**: 5 nodes (1 coordinator + 4 workers) on 1× A100 80GB
- **Rounds**: 3 only (for quick testing)

**What This Tests:**
- ✅ Coordinator round advancement
- ✅ Worker training and reward submission
- ✅ Logging infrastructure
- ✅ GDrive state management
- ⊘ Rollout sharing (not tested - use Config1/2/3 for that)

**Other Test Configurations:**
- **Config1**: `TEST_MODE_Config1.ipynb` - I=3, J=1 (75%/25% split, tests rollout sharing)
- **Config2**: `TEST_MODE_Config2.ipynb` - I=2, J=2 (50%/50% split, balanced sharing)
- **Config3**: `TEST_MODE_Config3.ipynb` - I=1, J=3 (25%/75% split, heavy sharing)

**Expected Results:**
- Completes in 1-2 minutes
- 3 rounds executed
- All 4 workers submit rewards
- State persisted to GDrive
- No rollouts published (J=0)

**Memory Usage:** ~33 GB peak VRAM (5 nodes × ~6.5 GB each, matching the estimate checked in Section 3)

**Note:** This is a TEST run with reduced parameters. Not for scientific experiments.

## 1. Configuration

**This notebook is pre-configured for TEST MODE.**

Just run all cells - no changes needed!

In [None]:
# TEST MODE - Baseline Configuration
# Quick validation run with minimal parameters.
# Later cells read these module-level names directly, so keep their spelling.

# ============================================
# TEST MODE SETTINGS
# ============================================
EXPERIMENT_NAME = 'test_baseline_4loc0ext'
NUM_TRAIN_SAMPLES = 4        # I: local rollouts per round
NUM_TRANSPLANT_TREES = 0     # J: external rollouts (none - baseline config)
NUM_GENERATIONS = 4          # G: completions per question (reduced for speed)
MAX_ROUNDS = 3               # only 3 rounds for a quick test

# ============================================
# FIXED SETTINGS
# ============================================
NUM_NODES = 5                # 1 coordinator + 4 workers, all on one GPU
MODEL_NAME = 'gpt2'          # GPT-2 (124M params)
SEED = 42                    # base seed; each node offsets it by its index

# Rollout sharing configuration (forwarded to every node as env vars)
ROLLOUT_PUBLISH_FREQUENCY = 'stage'  # when to share rollouts
ROLLOUT_CLEANUP_ENABLED = True       # prune old rollouts to save space
ROLLOUT_KEEP_LAST_N_ROUNDS = 20      # how many recent rounds to keep
ROLLOUT_ARCHIVE_OLD = False          # drop rather than archive pruned rollouts

# Optional HuggingFace token; leave as None when not needed.
HUGGINGFACE_TOKEN = None

_banner = "=" * 60
_summary = [
    _banner,
    "TEST MODE - Baseline Configuration (100% local)",
    _banner,
    f"✓ Nodes: {NUM_NODES} (1 coordinator + 4 workers on single A100 80GB)",
    f"✓ Model: {MODEL_NAME}",
    f"✓ Config: I={NUM_TRAIN_SAMPLES}, J={NUM_TRANSPLANT_TREES}, G={NUM_GENERATIONS}",
    f"✓ Experiment: {EXPERIMENT_NAME}",
    f"✓ Max Rounds: {MAX_ROUNDS} (TEST MODE)",
    "",
    "Expected VRAM: ~33 GB (80 GB available)",
    "Expected Time: ~2 minutes",
    "",
    "🧪 TEST MODE - Baseline Configuration",
    "   Tests: Coordinator, workers, submissions, logs",
    "   No rollout sharing (J=0)",
]
print("\n".join(_summary))

## 2. Mount Google Drive

In [None]:
import os
from google.colab import drive

# Attach Google Drive to this Colab VM; the experiment's state, logs and
# reward files are all read from / written under it by later cells.
drive.mount('/content/drive')

# Shared root of the experiment tree (MUST BE SAME ACROSS ALL NODES).
GDRIVE_BASE_PATH = '/content/drive/MyDrive/rl-swarm'
os.makedirs(GDRIVE_BASE_PATH, exist_ok=True)  # no-op when the folder already exists

print(f"✓ Google Drive mounted at: {GDRIVE_BASE_PATH}")

## 3. System Setup & GPU Verification

In [None]:
import torch

_rule = "=" * 60
print(_rule)
print("GPU Verification")
print(_rule)

# Fail fast: every node launched later shares CUDA device 0.
if not torch.cuda.is_available():
    raise RuntimeError("No GPU detected! Select A100 GPU runtime: Runtime > Change runtime type > A100 GPU")

gpu_name = torch.cuda.get_device_name(0)
total_memory = torch.cuda.get_device_properties(0).total_memory / 1e9

print(f"✓ GPU: {gpu_name}")
print(f"✓ Total VRAM: {total_memory:.1f} GB")
print()

# Rough budget: 6.5 GB for each GPT-2 node process.
_gb_per_node = 6.5
required_memory = NUM_NODES * _gb_per_node
_headroom = total_memory - required_memory

print("Memory Requirements:")
print(f"  Required: {required_memory:.1f} GB ({NUM_NODES} nodes × 6.5 GB)")
print(f"  Available: {total_memory:.1f} GB")
print(f"  Margin: {_headroom:.1f} GB")
print()

# Hard stop when the budget cannot fit at all.
if total_memory < required_memory:
    print("⚠️  WARNING: Insufficient VRAM!")
    print(f"   Need at least {required_memory:.0f} GB, but have {total_memory:.1f} GB")
    print(f"   Consider reducing NUM_NODES to {int(total_memory / _gb_per_node)}")
    raise RuntimeError("Insufficient GPU memory")

# Fits, but warn when the card is smaller than the expected A100 80GB.
if total_memory < 75:
    print("⚠️  WARNING: Tight fit! Expected A100 80GB.")
    print(f"   Have {total_memory:.1f} GB. May still work, but monitor memory closely.")
else:
    print(f"✅ Sufficient VRAM for {NUM_NODES} GPT-2 nodes")

## 4. Clone Repository & Install Dependencies

In [None]:
# Clone a fresh copy of the rl-swarm repository and install its Python dependencies.
# NOTE: this cell uses IPython magics (%cd, !cmd) and only runs inside Jupyter/Colab.
%cd /content

# Remove existing directory if it exists, so a stale checkout from an earlier
# session is never reused. (`os` comes from the Drive-mount cell's import.)
if os.path.exists('/content/rl-swarm'):
    print("Removing existing repository...")
    !rm -rf /content/rl-swarm

# Clone fresh copy
print("Cloning repository...")
!git clone https://github.com/Elrashid/rl-swarm.git /content/rl-swarm

# Change to repo directory; the relative 'requirements.txt' paths below depend on this.
%cd /content/rl-swarm

# Verify clone worked (`!git clone` failures do not raise in the notebook by themselves)
if not os.path.exists('requirements.txt'):
    raise FileNotFoundError("Repository clone failed - requirements.txt not found")

print("✓ Repository cloned successfully")
print()

# Install dependencies
print("Installing dependencies (this may take 3-5 minutes)...")
print("Note: Warnings about protobuf versions can be ignored")
print()

# Install main dependencies (without -q to show errors)
!pip install -r requirements.txt

# Install GenRL explicitly, pinned to a specific release
!pip install gensyn-genrl==0.1.9

# Fix protobuf version explicitly to avoid conflicts; this runs last so it
# overrides whatever protobuf version the installs above resolved.
!pip install 'protobuf>=4.25.0,<5.0'

# Verify reasoning-gym was installed by importing it; re-raise the original
# ImportError after printing guidance (`e` itself is not used).
try:
    import reasoning_gym
    print()
    print("✓ Dependencies installed successfully")
    print("✓ reasoning-gym verified")
except ImportError as e:
    print()
    print("❌ ERROR: reasoning-gym failed to install!")
    print("   Please report this issue with the error above")
    raise

## 5. Initialize Experiment

**Note:** Only the coordinator (node_0) creates the experiment structure. Workers will join it.

In [None]:
from rgym_exp.utils.experiment_manager import init_experiment

# Dotted-key overrides handed to init_experiment, which creates the experiment
# structure in Google Drive for the nodes launched later to join.
config_overrides = dict([
    ('training.max_round', MAX_ROUNDS),
    ('training.num_generations', NUM_GENERATIONS),
    ('training.num_transplant_trees', NUM_TRANSPLANT_TREES),
    ('training.num_train_samples', NUM_TRAIN_SAMPLES),
    ('training.seed', SEED),
    # Passed as a string, matching the TEST_MODE env var given to each node.
    ('TEST_MODE', 'True'),
])

init_experiment(
    gdrive_base_path=GDRIVE_BASE_PATH,
    experiment_name=EXPERIMENT_NAME,
    config_overrides=config_overrides,
)

_experiment_dir = f"{GDRIVE_BASE_PATH}/experiments/{EXPERIMENT_NAME}"
print(f"✓ Experiment initialized: {EXPERIMENT_NAME}")
print(f"  Path: {_experiment_dir}")
print(f"  Config: I={NUM_TRAIN_SAMPLES}, J={NUM_TRANSPLANT_TREES}, G={NUM_GENERATIONS}")
print("  TEST_MODE: Enabled")

## 6. Launch 5-Node Swarm (KEY CELL)

**This cell:**
1. Spawns 5 separate Python processes
2. Each process runs `swarm_launcher.py` with unique NODE_ID
3. All processes share GPU 0 (CUDA_VISIBLE_DEVICES=0)
4. Coordinator (node_0) manages round progression
5. Workers (node_1-4) poll state and train

**Logs:** Each node writes to GDrive at `{GDRIVE_BASE_PATH}/experiments/{EXPERIMENT_NAME}/logs/node_<id>/`

**Monitor:** Use the Section 7 monitor cell below (it follows the optional Section 7.5 progress viewer) to track progress in real time.

**TEST MODE:** Only 3 rounds will be executed for quick validation (30s per round)

In [None]:
# Launch every swarm node as its own Python subprocess on the shared GPU.
# Defines `processes` (one Popen per node, index == node id) and `start_time`,
# both read by the monitor cell below.
import subprocess
import os
import sys
import time
from datetime import datetime

# Seconds between consecutive node launches (stagger startup to avoid race conditions).
STARTUP_STAGGER_SECONDS = 10

print("="*60)
print(f"Launching {NUM_NODES}-Node TEST MODE Swarm")
print("="*60)
print(f"Experiment: {EXPERIMENT_NAME}")
print(f"Model: {MODEL_NAME}")
print(f"Config: I={NUM_TRAIN_SAMPLES}, J={NUM_TRANSPLANT_TREES}, G={NUM_GENERATIONS}")
print(f"Hardware: All {NUM_NODES} nodes on single GPU (A100 80GB)")
print(f"TEST MODE: {MAX_ROUNDS} rounds only")
print("="*60)
print()

processes = []
start_time = time.time()  # For ETA calculation (read by the monitor cell)

for node_id in range(NUM_NODES):
    # Sleep *before* every launch except the first: keeps the same spacing
    # between nodes without wasting a full stagger period after the last one.
    if node_id > 0:
        time.sleep(STARTUP_STAGGER_SECONDS)

    # Environment variables for this node (node_0 is the coordinator)
    env = os.environ.copy()
    env.update({
        'NODE_ID': f'node_{node_id}',
        'NODE_ROLE': 'coordinator' if node_id == 0 else 'worker',
        'MODEL_NAME': MODEL_NAME,
        'NUM_TRAIN_SAMPLES': str(NUM_TRAIN_SAMPLES),
        'NUM_TRANSPLANT_TREES': str(NUM_TRANSPLANT_TREES),
        'NUM_GENERATIONS': str(NUM_GENERATIONS),
        'MAX_ROUNDS': str(MAX_ROUNDS),
        'EXPERIMENT_NAME': EXPERIMENT_NAME,
        'GDRIVE_PATH': GDRIVE_BASE_PATH,
        'CUDA_VISIBLE_DEVICES': '0',          # All nodes share GPU 0
        'SEED': str(SEED + node_id),          # Different seed per node (diversity)
        'ROLLOUT_PUBLISH_FREQUENCY': ROLLOUT_PUBLISH_FREQUENCY,
        'ROLLOUT_CLEANUP_ENABLED': str(ROLLOUT_CLEANUP_ENABLED),
        'ROLLOUT_KEEP_LAST_N_ROUNDS': str(ROLLOUT_KEEP_LAST_N_ROUNDS),
        'ROLLOUT_ARCHIVE_OLD': str(ROLLOUT_ARCHIVE_OLD),
        'TEST_MODE': 'True',                  # Enable test mode
        'COORDINATOR_ROUND_INTERVAL': '30',   # Faster round advancement for test mode
    })

    if HUGGINGFACE_TOKEN:
        env['HUGGINGFACE_ACCESS_TOKEN'] = HUGGINGFACE_TOKEN

    # Launch process (logs will be written to GDrive by swarm_launcher)
    process = subprocess.Popen(
        [sys.executable, '-m', 'rgym_exp.runner.swarm_launcher'],
        env=env,
        cwd='/content/rl-swarm'
    )
    processes.append(process)

    role = "COORDINATOR" if node_id == 0 else "WORKER     "
    print(f"✓ Started node_{node_id} ({role}) - PID: {process.pid:5d}")

# A node whose process has already exited at this point most likely failed
# during start-up; report it instead of unconditionally claiming success.
exited_early = [(i, p.returncode) for i, p in enumerate(processes) if p.poll() is not None]

print()
if exited_early:
    print(f"❌ {len(exited_early)}/{NUM_NODES} node(s) exited during startup:")
    for i, rc in exited_early:
        print(f"   node_{i}: return code {rc}")
    print(f"   Check logs: {GDRIVE_BASE_PATH}/experiments/{EXPERIMENT_NAME}/logs/node_{{id}}/")
else:
    print(f"✅ All {NUM_NODES} nodes launched successfully!")
print(f"✓ TEST MODE: Will run for ~1-2 minutes ({MAX_ROUNDS} rounds)")
print(f"✓ Logs location: {GDRIVE_BASE_PATH}/experiments/{EXPERIMENT_NAME}/logs/node_*/")
print()
print("⚠️  Keep this notebook open (browser tab active)")
print()
print("Monitor progress in Cell 7 below...")

## 7. Monitor Training Progress

**The monitor cell (the second code cell below, after Section 7.5):**
- Shows real-time status of all 5 nodes
- Displays GPU memory usage
- Shows current round/stage progress
- Updates every 30 seconds (faster for test mode)

**To stop training:** Click the cell's "Stop" button (or use Runtime → Interrupt execution); the cell then terminates all node processes.

**Note:** You can re-run this cell anytime to check status

## 7.5. Check Real-Time Progress (Optional)

**You can run this anytime** to check training progress from GDrive:
- Shows current round for each node
- Displays elapsed time
- Updates independently of notebook state

This is useful if your notebook disconnects - progress is always saved to GDrive!

In [None]:
# === Real-time Progress Viewer (Optional) ===
# Run this cell anytime to check progress from GDrive. It reads what
# get_experiment_progress returns for this experiment, so it does not need
# the `processes` handles from the launch cell.

import sys
sys.path.append('/content/rl-swarm')

from rgym_exp.utils.progress_tracker import get_experiment_progress

progress = get_experiment_progress(GDRIVE_BASE_PATH, EXPERIMENT_NAME)
_rule = "=" * 60

print(_rule)
print("📊 REAL-TIME PROGRESS")
print(_rule)
print(f"Experiment: {progress.get('experiment')}")
print()

for _node_name, _info in progress.get('nodes', {}).items():
    # Nodes whose progress could not be read carry an 'error' entry instead.
    if 'error' in _info:
        print(f"  {_node_name}: {_info['error']}")
        continue
    print(f"  {_node_name}:")
    print(f"    Event: {_info.get('latest_event')}")
    print(f"    Round: {_info.get('latest_round')}")
    print(f"    Elapsed: {_info.get('elapsed_seconds', 0)/60:.1f}m")
    print()

print(_rule)
print("Note: Progress updates every round. Logs flush every 30s.")

In [None]:
# Live dashboard for the launched swarm, followed by post-run validation of the
# files the nodes wrote to Google Drive. Relies on names from earlier cells:
# `processes`, `start_time`, `datetime` (launch cell), `torch` (GPU cell),
# GDRIVE_BASE_PATH (Drive cell) and the configuration constants.
import time
import subprocess
from IPython.display import clear_output
import pandas as pd

MONITOR_INTERVAL_SECONDS = 30   # Dashboard refresh period (short for test mode)
SHUTDOWN_GRACE_SECONDS = 5      # Time allowed after terminate() before kill()

print("Starting training monitor (TEST MODE)...")
print("Press 'Stop' button or Ctrl+C to interrupt\n")

monitor_start_time = time.time()

try:
    while True:
        clear_output(wait=True)

        # Check process status (poll() is None while a node is still running)
        running = sum(1 for p in processes if p.poll() is None)
        completed = NUM_NODES - running

        current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        elapsed_minutes = (time.time() - start_time) / 60

        print("="*70)
        print(f" TEST MODE Monitor - {EXPERIMENT_NAME}")
        print(f" Time: {current_time} | Elapsed: {elapsed_minutes:.1f}m")
        print("="*70)
        print()

        # Node status
        print("Nodes:")
        print(f"  Running:   {running}/{NUM_NODES}")
        print(f"  Completed: {completed}/{NUM_NODES}")
        print()

        # GPU memory. The nodes are separate subprocesses, so per-process
        # allocator counters (torch.cuda.memory_allocated/reserved) would only
        # report this notebook's own usage; mem_get_info() returns device-wide
        # free/total bytes covering all processes on the GPU.
        if torch.cuda.is_available():
            free_bytes, total_bytes = torch.cuda.mem_get_info(0)
            total = total_bytes / 1e9
            used = (total_bytes - free_bytes) / 1e9
            utilization = (used / total) * 100

            print(f"GPU Memory ({torch.cuda.get_device_name(0)}):")
            print(f"  Used:      {used:5.1f} GB / {total:.1f} GB ({utilization:.1f}%) [all processes]")
            print(f"  Free:      {free_bytes / 1e9:5.1f} GB")

            # Warning if memory is high
            if utilization > 90:
                print(f"  ⚠️  WARNING: High memory usage! May OOM soon.")
            elif utilization > 75:
                print(f"  ⚠️  Memory usage elevated. Monitoring closely.")

            print()

        # Training progress (best-effort: state may not be readable yet)
        try:
            from rgym_exp.utils.experiment_manager import get_experiment_status
            status = get_experiment_status(GDRIVE_BASE_PATH, EXPERIMENT_NAME)

            if status:
                current_round = status.get('current_round', 0)
                progress_pct = (current_round / MAX_ROUNDS) * 100

                print("Training Progress:")
                print(f"  Round:     {current_round:4d} / {MAX_ROUNDS} ({progress_pct:5.1f}%)")
                print(f"  Stage:     {status.get('current_stage', 0)}")
                print(f"  Active peers: {status.get('active_peers', 0)}")

                # Progress bar, clamped so an out-of-range round count never
                # draws a bar longer (or shorter) than bar_length.
                bar_length = 40
                filled = min(bar_length, max(0, int(bar_length * progress_pct / 100)))
                bar = '█' * filled + '░' * (bar_length - filled)
                print(f"  [{bar}]")

                print()

                # Recent performance
                try:
                    from rgym_exp.utils.experiment_manager import get_experiment_metrics
                    df = get_experiment_metrics(GDRIVE_BASE_PATH, EXPERIMENT_NAME)

                    if not df.empty:
                        cumulative_reward = df['my_reward'].sum()
                        # tail(5) is the last 5 rows of the metrics table, which
                        # need not correspond to 5 distinct rounds.
                        recent_reward = df.tail(5)['my_reward'].mean()

                        print("Rewards:")
                        print(f"  Cumulative: {cumulative_reward:6.2f}")
                        print(f"  Recent avg: {recent_reward:6.2f} (last 5 entries)")
                        print()
                except Exception:
                    pass  # Metrics not available yet; deliberately best-effort

        except Exception as e:
            print(f"Progress: Unable to load status ({e})")
            print()

        # Instructions
        print("-"*70)
        print("Press 'Stop' button or Ctrl+C to halt training")

        # Exit once every node has finished, reporting any that failed.
        if running == 0:
            failed = [(i, p.returncode) for i, p in enumerate(processes) if p.returncode != 0]
            print()
            print("="*70)
            if failed:
                print(f"❌ {len(failed)}/{NUM_NODES} node(s) exited with errors:")
                for i, rc in failed:
                    print(f"   node_{i}: return code {rc}")
                print(f"   Check logs: {GDRIVE_BASE_PATH}/experiments/{EXPERIMENT_NAME}/logs/")
            else:
                print("✅ All nodes completed successfully!")
            print("="*70)
            break

        print(f"Next update in {MONITOR_INTERVAL_SECONDS} seconds...")
        time.sleep(MONITOR_INTERVAL_SECONDS)

except KeyboardInterrupt:
    print("\n" + "="*70)
    print("⚠️  Training interrupted by user")
    print("="*70)
    print("\nTerminating all node processes...")

    for i, p in enumerate(processes):
        if p.poll() is None:
            print(f"  Stopping node_{i}... (PID: {p.pid})")
            p.terminate()

    # Shared grace period for all nodes, then force kill stragglers. wait()
    # also reaps exited children so they do not linger as zombies.
    deadline = time.time() + SHUTDOWN_GRACE_SECONDS
    for i, p in enumerate(processes):
        try:
            p.wait(timeout=max(0.0, deadline - time.time()))
        except subprocess.TimeoutExpired:
            print(f"  Force killing node_{i}... (PID: {p.pid})")
            p.kill()
            p.wait()

    print("\n✓ All processes terminated")
    print("\n💾 Note: Training state is checkpointed.")
    print("   Re-run this notebook to resume from last checkpoint.")

# === Post-Test Validation ===
print("\n" + "="*60)
print("🔍 TEST MODE VALIDATION")
print("="*60)

import os
import json
from pathlib import Path
import sys
sys.path.append('/content/rl-swarm')

from rgym_exp.utils.test_results import save_test_results

# Use the notebook's configured values instead of re-hard-coding them, so the
# checks always inspect the experiment this notebook actually ran.
gdrive_path = GDRIVE_BASE_PATH
exp_name = EXPERIMENT_NAME
exp_path = f"{gdrive_path}/experiments/{exp_name}"

# Check 1: GDrive state file
state_check = False
state = None
state_file = f"{exp_path}/state/current_state.json"
if os.path.exists(state_file):
    try:
        with open(state_file, encoding='utf-8') as f:
            state = json.load(f)
    except (OSError, json.JSONDecodeError) as e:
        # Drive sync can leave a partially written file; report, don't crash.
        print(f"✗ State file unreadable: {e}")
else:
    print("✗ State file not found")

if state is not None:
    print(f"✓ State file exists: Round {state.get('round', '?')}")
    if state.get('round') == MAX_ROUNDS:
        print(f"  ✓ Reached round {MAX_ROUNDS} as expected")
        state_check = True
    else:
        print(f"  ⚠ Expected round {MAX_ROUNDS}, got {state.get('round')}")

# Check 2: Worker submissions (stage 0 of every round)
submissions_check = True
for round_num in range(MAX_ROUNDS):
    submissions_dir = f"{exp_path}/rewards/round_{round_num}/stage_0"
    if os.path.exists(submissions_dir):
        submissions = [f for f in os.listdir(submissions_dir) if f.endswith('.json')]
        print(f"✓ Round {round_num}: {len(submissions)} worker submissions")
        if not submissions:
            submissions_check = False
    else:
        print(f"✗ Round {round_num}: No submissions directory")
        submissions_check = False

# Check 3: Logs (one stdout.log per node)
logs_check = True
log_dirs = [f"{exp_path}/logs/node_{i}" for i in range(NUM_NODES)]
for i, log_dir in enumerate(log_dirs):
    if os.path.exists(f"{log_dir}/stdout.log"):
        print(f"✓ Node {i}: Logs present")
    else:
        print(f"✗ Node {i}: Logs missing")
        logs_check = False

# Save results to GDrive
save_test_results(
    gdrive_path=gdrive_path,
    experiment_name=exp_name,
    state_check=state_check,
    submissions_check=submissions_check,
    logs_check=logs_check,
    # NOTE(review): always reported as passing; nothing here measures it
    # separately. Confirm what save_test_results expects for this flag.
    coordinator_check=True,
    rollouts_check=None,  # Baseline has no rollout sharing (J=0)
    num_nodes=NUM_NODES,
    num_rounds=MAX_ROUNDS,
    num_train_samples=NUM_TRAIN_SAMPLES,
    num_transplants=NUM_TRANSPLANT_TREES
)

failed_checks = [
    name for name, ok in (
        ('state', state_check),
        ('submissions', submissions_check),
        ('logs', logs_check),
    ) if not ok
]

print("="*60)
print("✅ TEST MODE VALIDATION COMPLETE")
if failed_checks:
    print(f"   ⚠ Failed checks: {', '.join(failed_checks)}")
else:
    print("   All checks passed")
print("="*60)

In [None]:
# === Post-Test Validation ===
# Standalone re-check of the files this notebook's experiment wrote to Google
# Drive (does not save results). Uses the configured experiment name, path,
# round count and node count rather than hard-coded values, so it inspects
# the run launched above.
print("\n" + "="*60)
print("🔍 TEST MODE VALIDATION")
print("="*60)

import os
import json
from pathlib import Path

gdrive_path = GDRIVE_BASE_PATH
exp_name = EXPERIMENT_NAME
exp_path = f"{gdrive_path}/experiments/{exp_name}"
failures = []  # names of checks that did not pass

# Check 1: GDrive state file
state = None
state_file = f"{exp_path}/state/current_state.json"
if os.path.exists(state_file):
    try:
        with open(state_file, encoding='utf-8') as f:
            state = json.load(f)
    except (OSError, json.JSONDecodeError) as e:
        # Drive sync can leave a partially written file; report, don't crash.
        print(f"✗ State file unreadable: {e}")
        failures.append('state')
else:
    print("✗ State file not found")
    failures.append('state')

if state is not None:
    print(f"✓ State file exists: Round {state.get('round', '?')}")
    if state.get('round') == MAX_ROUNDS:
        print(f"  ✓ Reached round {MAX_ROUNDS} as expected")
    else:
        print(f"  ⚠ Expected round {MAX_ROUNDS}, got {state.get('round')}")
        failures.append('state')

# Check 2: Worker submissions (stage 0 of every round)
for round_num in range(MAX_ROUNDS):
    submissions_dir = f"{exp_path}/rewards/round_{round_num}/stage_0"
    if os.path.exists(submissions_dir):
        submissions = [f for f in os.listdir(submissions_dir) if f.endswith('.json')]
        print(f"✓ Round {round_num}: {len(submissions)} worker submissions")
        if not submissions:
            failures.append(f'submissions(round {round_num})')
    else:
        print(f"✗ Round {round_num}: No submissions directory")
        failures.append(f'submissions(round {round_num})')

# Check 3: Logs (one stdout.log per node)
log_dirs = [f"{exp_path}/logs/node_{i}" for i in range(NUM_NODES)]
for i, log_dir in enumerate(log_dirs):
    if os.path.exists(f"{log_dir}/stdout.log"):
        print(f"✓ Node {i}: Logs present")
    else:
        print(f"✗ Node {i}: Logs missing")
        failures.append(f'logs(node_{i})')

print("="*60)
print("✅ TEST MODE VALIDATION COMPLETE")
if failures:
    print(f"   ⚠ Failed checks: {', '.join(failures)}")
else:
    print("   All checks passed")
print("="*60)