# 🧪 TEST MODE - Config 3 (1-2 minutes)

This notebook runs a **QUICK INFRASTRUCTURE TEST** of SAPO Config 3 (heavy rollout sharing).

**⚠️ Important: This is NOT a full training test!**

TEST_MODE validates the **system infrastructure** only:
- ✅ Coordinator round advancement
- ✅ Worker startup and state polling
- ✅ Log streaming to GDrive
- ✅ Progress tracking
- ✅ Rollout sharing (Config 3 has J=3)
- ⊘ **NOT tested**: Actual model training or reward submission (skipped for speed)

For full training validation, run a short experiment (e.g., 10 rounds) in a main experiment notebook.

**Configuration:**
- **I=1** (local rollouts per round - 25%)
- **J=3** (external rollouts per round - 75% HEAVY SHARING)
- **G=4** (completions per question - REDUCED for speed)
- **Model**: GPT-2 (124M params)
- **Hardware**: 5 nodes (1 coordinator + 4 workers) on 1× A100 80GB
- **Rounds**: 3 only (for quick testing)

**Other Test Configurations:**
- **Baseline**: `TEST_MODE_Baseline.ipynb` - I=4, J=0 (no rollout sharing)
- **Config1**: `TEST_MODE_Config1.ipynb` - I=3, J=1 (light rollout sharing)
- **Config2**: `TEST_MODE_Config2.ipynb` - I=2, J=2 (balanced rollout sharing)
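Each configuration splits a fixed budget of four rollouts per round between local (I) and external (J) samples, so the percentages quoted above are simply I/(I+J) and J/(I+J). The sketch below recomputes them. The `(I, J)` pairs are transcribed from this notebook; `rollout_split` and `CONFIGS` are illustrative names, not part of the repository.

```python
# Hypothetical helper: local/external rollout split for each TEST_MODE config.
# (I, J) pairs are transcribed from the configuration lists in this notebook.
CONFIGS = {
    "Baseline": (4, 0),
    "Config1": (3, 1),
    "Config2": (2, 2),
    "Config3": (1, 3),
}


def rollout_split(i: int, j: int) -> tuple[float, float]:
    """Return (local_fraction, external_fraction) for I local and J external rollouts."""
    total = i + j
    if total <= 0:
        raise ValueError("I + J must be positive")
    return i / total, j / total


for name, (i, j) in CONFIGS.items():
    local, external = rollout_split(i, j)
    # e.g. Config3 -> 25% local / 75% external
    print(f"{name}: I={i}, J={j} -> {local:.0%} local / {external:.0%} external")
```

Keeping I + J constant across configs means only the sharing ratio changes between runs, not the total number of rollouts per round.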

**Expected Results:**
- Swarm run completes in about 1-2 minutes (dependency installation in Section 4 adds another 3-5 minutes)
- 3 rounds executed
- State file advanced to round 3
- Logs created for all nodes
- Progress files created for all nodes
- Rollouts published (J=3 - heavy sharing)

**Memory Usage:** ~33 GB estimated peak VRAM (5 nodes × ~6.5 GB each)
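The ~33 GB figure follows from the 6.5 GB-per-node budget used by the GPU check in Section 3: 5 × 6.5 = 32.5 GB. A small sketch of that arithmetic; `vram_budget` is a hypothetical name, and 6.5 GB is the notebook's own per-node estimate rather than a measured value.

```python
import math

PER_NODE_GB = 6.5  # per-node GPT-2 budget used by the GPU check in Section 3


def vram_budget(num_nodes: int, total_gb: float = 80.0, per_node_gb: float = PER_NODE_GB) -> dict:
    """Estimate required VRAM, remaining headroom, and the largest node count that fits."""
    required = num_nodes * per_node_gb
    return {
        "required_gb": required,
        "margin_gb": total_gb - required,
        "max_nodes": math.floor(total_gb / per_node_gb),
    }


print(vram_budget(5))
# {'required_gb': 32.5, 'margin_gb': 47.5, 'max_nodes': 12}
```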

**Note:** This is an infrastructure test. Use main experiment notebooks for actual training.

## 1. Configuration

**This notebook is pre-configured for TEST MODE.**

Just run all cells - no changes needed!

In [None]:
# TEST MODE - Config3 (25% local / 75% external - heavy sharing)
# Quick validation run with minimal parameters

# ============================================
# TEST MODE SETTINGS
# ============================================
EXPERIMENT_NAME = 'test_config3_1loc3ext'
NUM_TRAIN_SAMPLES = 1        # I: Local rollouts per round
NUM_TRANSPLANT_TREES = 3     # J: External rollouts (75% sharing - heavy)
NUM_GENERATIONS = 4          # G: Completions per question (REDUCED for speed)
MAX_ROUNDS = 3               # Only 3 rounds for quick test

# ============================================
# FIXED SETTINGS
# ============================================
NUM_NODES = 5                # Run 5 nodes (1 coordinator + 4 workers)
MODEL_NAME = 'gpt2'          # GPT-2 (124M params, fits memory)
SEED = 42                    # For reproducibility

# Rollout Sharing Configuration
ROLLOUT_PUBLISH_FREQUENCY = 'stage'  # When to share rollouts
ROLLOUT_CLEANUP_ENABLED = True       # Enable cleanup to save space
ROLLOUT_KEEP_LAST_N_ROUNDS = 20      # Keep recent rollouts only
ROLLOUT_ARCHIVE_OLD = False          # Don't archive (saves space)

# Optional: HuggingFace Token
HUGGINGFACE_TOKEN = None  # Set to your token or keep None

print("="*60)
print(f"TEST MODE - Config3 (25% local / 75% external - heavy sharing)")
print("="*60)
print(f"✓ Nodes: {NUM_NODES} (1 coordinator + 4 workers on single A100 80GB)")
print(f"✓ Model: {MODEL_NAME}")
print(f"✓ Config: I={NUM_TRAIN_SAMPLES}, J={NUM_TRANSPLANT_TREES}, G={NUM_GENERATIONS}")
print(f"✓ Experiment: {EXPERIMENT_NAME}")
print(f"✓ Max Rounds: {MAX_ROUNDS} (TEST MODE)")
print()
print(f"Expected VRAM: ~33 GB (80 GB available)")
print(f"Expected Time: ~2 minutes")
print()
print("🧪 TEST MODE - Config3 Configuration")
print("   Tests: Coordinator, workers, logs, progress tracking")
print("   Tests heavy rollout sharing (25% local / 75% external)")

## 2. Mount Google Drive

In [None]:
from google.colab import drive
import os

# Mount Google Drive
drive.mount('/content/drive')

# Set base path (MUST BE SAME ACROSS ALL NODES)
GDRIVE_BASE_PATH = '/content/drive/MyDrive/rl-swarm'
os.makedirs(GDRIVE_BASE_PATH, exist_ok=True)

print(f"✓ Google Drive mounted at: {GDRIVE_BASE_PATH}")

## 3. System Setup & GPU Verification

In [None]:
import torch

print("="*60)
print("GPU Verification")
print("="*60)

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    total_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    
    print(f"✓ GPU: {gpu_name}")
    print(f"✓ Total VRAM: {total_memory:.1f} GB")
    print()
    
    # Check if we have enough memory
    required_memory = NUM_NODES * 6.5  # 6.5 GB per GPT-2 node
    print(f"Memory Requirements:")
    print(f"  Required: {required_memory:.1f} GB ({NUM_NODES} nodes × 6.5 GB)")
    print(f"  Available: {total_memory:.1f} GB")
    print(f"  Margin: {total_memory - required_memory:.1f} GB")
    print()
    
    if total_memory < required_memory:
        print("⚠️  WARNING: Insufficient VRAM!")
        print(f"   Need at least {required_memory:.0f} GB, but have {total_memory:.1f} GB")
        print(f"   Consider reducing NUM_NODES to {int(total_memory / 6.5)}")
        raise RuntimeError("Insufficient GPU memory")
    elif total_memory < 75:
        print("⚠️  WARNING: Tight fit! Expected A100 80GB.")
        print(f"   Have {total_memory:.1f} GB. May still work, but monitor memory closely.")
    else:
        print(f"✅ Sufficient VRAM for {NUM_NODES} GPT-2 nodes")
else:
    raise RuntimeError("No GPU detected! Select A100 GPU runtime: Runtime > Change runtime type > A100 GPU")

## 4. Clone Repository & Install Dependencies

In [None]:
%cd /content

# Remove existing directory if it exists
if os.path.exists('/content/rl-swarm'):
    print("Removing existing repository...")
    !rm -rf /content/rl-swarm

# Clone fresh copy
print("Cloning repository...")
!git clone https://github.com/Elrashid/rl-swarm.git /content/rl-swarm

# Change to repo directory
%cd /content/rl-swarm

# Verify clone worked
if not os.path.exists('requirements.txt'):
    raise FileNotFoundError("Repository clone failed - requirements.txt not found")

print("✓ Repository cloned successfully")
print()

# Install dependencies
print("Installing dependencies (this may take 3-5 minutes)...")
print("Note: Warnings about protobuf versions can be ignored")
print()

# Install main dependencies (without -q to show errors)
!pip install -r requirements.txt

# Install GenRL explicitly
!pip install gensyn-genrl==0.1.9

# Fix protobuf version explicitly to avoid conflicts
!pip install 'protobuf>=4.25.0,<5.0'

# Verify reasoning-gym was installed
try:
    import reasoning_gym
    print()
    print("✓ Dependencies installed successfully")
    print("✓ reasoning-gym verified")
except ImportError as e:
    print()
    print("❌ ERROR: reasoning-gym failed to install!")
    print("   Please report this issue with the error above")
    raise

## 5. Initialize Experiment

**Note:** Only the coordinator (node_0) creates the experiment structure. Workers will join it.
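All nodes and the validation cell below locate their files relative to the same experiment root on GDrive. The sketch below collects that layout from the paths this notebook reads and writes; `experiment_paths` is a hypothetical helper, not part of `rgym_exp`, and it lists only paths that appear in this notebook.

```python
from pathlib import Path


def experiment_paths(gdrive_base: str, experiment: str, num_nodes: int) -> dict:
    """Collect the per-experiment paths referenced in this notebook (illustrative only)."""
    root = Path(gdrive_base) / "experiments" / experiment
    return {
        "root": root,
        "state": root / "state" / "current_state.json",   # coordinator's round state
        "rollouts": root / "rollouts",                      # shared rollout files (*.pkl)
        "logs": [root / "logs" / f"node_{i}" / "stdout.log" for i in range(num_nodes)],
        "progress": [root / f"progress_node_{i}.jsonl" for i in range(num_nodes)],
    }


paths = experiment_paths("/content/drive/MyDrive/rl-swarm", "test_config3_1loc3ext", 5)
print(paths["state"].as_posix())
# /content/drive/MyDrive/rl-swarm/experiments/test_config3_1loc3ext/state/current_state.json
```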

In [None]:
from rgym_exp.utils.experiment_manager import init_experiment

# Initialize experiment structure in Google Drive (coordinator creates it)
config_overrides = {
    'training.max_round': MAX_ROUNDS,
    'training.num_generations': NUM_GENERATIONS,
    'training.num_transplant_trees': NUM_TRANSPLANT_TREES,
    'training.num_train_samples': NUM_TRAIN_SAMPLES,
    'training.seed': SEED,
    'TEST_MODE': 'True',
}

init_experiment(
    gdrive_base_path=GDRIVE_BASE_PATH,
    experiment_name=EXPERIMENT_NAME,
    config_overrides=config_overrides
)

print(f"✓ Experiment initialized: {EXPERIMENT_NAME}")
print(f"  Path: {GDRIVE_BASE_PATH}/experiments/{EXPERIMENT_NAME}")
print(f"  Config: I={NUM_TRAIN_SAMPLES}, J={NUM_TRANSPLANT_TREES}, G={NUM_GENERATIONS}")
print(f"  TEST_MODE: Enabled")

## 6. Launch 5-Node Swarm (KEY CELL)

**This cell:**
1. Spawns 5 separate Python processes
2. Each process runs `swarm_launcher.py` with unique NODE_ID
3. All processes share GPU 0 (CUDA_VISIBLE_DEVICES=0)
4. Coordinator (node_0) manages round progression
5. Workers (node_1-4) poll state and train
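The per-node environment built by the launch cell differs only in node identity, role, and seed. The sketch below isolates that logic as a pure function so it can be checked without spawning processes. `build_node_env` is a hypothetical name, and only a subset of the variables the launch cell sets is passed in the example.

```python
import os


def build_node_env(node_id: int, base_seed: int, shared: dict, base_env=None) -> dict:
    """Return the environment for one node, mirroring the launch cell's conventions."""
    env = dict(os.environ if base_env is None else base_env)
    # Shared settings are passed to every node as strings
    env.update({k: str(v) for k, v in shared.items()})
    env["NODE_ID"] = f"node_{node_id}"
    # node_0 coordinates round advancement; every other node is a worker
    env["NODE_ROLE"] = "coordinator" if node_id == 0 else "worker"
    # Offset the seed per node for diversity, as the launch cell does
    env["SEED"] = str(base_seed + node_id)
    env["CUDA_VISIBLE_DEVICES"] = "0"  # all nodes share GPU 0
    return env


shared = {"MODEL_NAME": "gpt2", "NUM_TRAIN_SAMPLES": 1, "NUM_TRANSPLANT_TREES": 3, "TEST_MODE": True}
envs = [build_node_env(i, 42, shared, base_env={}) for i in range(5)]
print([(e["NODE_ID"], e["NODE_ROLE"], e["SEED"]) for e in envs])
```

Passing `base_env={}` keeps the example deterministic; the launch cell itself starts from `os.environ.copy()` so that each node inherits the notebook's PATH and CUDA settings.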

**Logs:** Each node writes to GDrive at `{GDRIVE_BASE_PATH}/experiments/{EXPERIMENT_NAME}/logs/node_<id>/`

**Monitor:** Use the progress viewer in Section 7.5 to track progress in real time

**TEST MODE:** Only 3 rounds will be executed for quick validation (30s per round)

In [None]:
import subprocess
import os
import time
from datetime import datetime

print("="*60)
print(f"Launching {NUM_NODES}-Node TEST MODE Swarm")
print("="*60)
print(f"Experiment: {EXPERIMENT_NAME}")
print(f"Model: {MODEL_NAME}")
print(f"Config: I={NUM_TRAIN_SAMPLES}, J={NUM_TRANSPLANT_TREES}, G={NUM_GENERATIONS}")
print(f"Hardware: All {NUM_NODES} nodes on single GPU (A100 80GB)")
print(f"TEST MODE: {MAX_ROUNDS} rounds only")
print("="*60)
print()

processes = []
start_time = time.time()  # For ETA calculation

for node_id in range(NUM_NODES):
    # Environment variables for this node
    env = os.environ.copy()
    env['NODE_ID'] = f'node_{node_id}'
    env['NODE_ROLE'] = 'coordinator' if node_id == 0 else 'worker'
    env['MODEL_NAME'] = MODEL_NAME
    env['NUM_TRAIN_SAMPLES'] = str(NUM_TRAIN_SAMPLES)
    env['NUM_TRANSPLANT_TREES'] = str(NUM_TRANSPLANT_TREES)
    env['NUM_GENERATIONS'] = str(NUM_GENERATIONS)
    env['MAX_ROUNDS'] = str(MAX_ROUNDS)
    env['EXPERIMENT_NAME'] = EXPERIMENT_NAME
    env['GDRIVE_PATH'] = GDRIVE_BASE_PATH
    env['CUDA_VISIBLE_DEVICES'] = '0'  # All nodes share GPU 0
    env['SEED'] = str(SEED + node_id)  # Different seed per node (diversity)
    env['ROLLOUT_PUBLISH_FREQUENCY'] = ROLLOUT_PUBLISH_FREQUENCY
    env['ROLLOUT_CLEANUP_ENABLED'] = str(ROLLOUT_CLEANUP_ENABLED)
    env['ROLLOUT_KEEP_LAST_N_ROUNDS'] = str(ROLLOUT_KEEP_LAST_N_ROUNDS)
    env['ROLLOUT_ARCHIVE_OLD'] = str(ROLLOUT_ARCHIVE_OLD)
    env['TEST_MODE'] = 'True'  # Enable test mode
    env['COORDINATOR_ROUND_INTERVAL'] = '30'  # Faster round advancement for test mode
    
    if HUGGINGFACE_TOKEN:
        env['HUGGINGFACE_ACCESS_TOKEN'] = HUGGINGFACE_TOKEN
    
    # Launch process (logs will be written to GDrive by swarm_launcher)
    import sys
    process = subprocess.Popen(
        [sys.executable, '-m', 'rgym_exp.runner.swarm_launcher'],
        env=env,
        cwd='/content/rl-swarm'
    )
    processes.append(process)
    
    role = "COORDINATOR" if node_id == 0 else "WORKER     "
    print(f"✓ Started node_{node_id} ({role}) - PID: {process.pid:5d}")
    
    # Stagger startup to avoid race conditions
    time.sleep(10)

print()
print(f"✅ All {NUM_NODES} node processes started")
print(f"✓ TEST MODE: Will run for ~1-2 minutes ({MAX_ROUNDS} rounds)")
print(f"✓ Logs location: {GDRIVE_BASE_PATH}/experiments/{EXPERIMENT_NAME}/logs/node_*/")
print()
print("⚠️  Keep this notebook open (browser tab active)")
print()
print("Monitor progress in Section 7.5 below...")

## 7. Monitor Training Progress

**This cell:**
- Shows real-time status of all 5 nodes
- Displays GPU memory usage
- Shows current round/stage progress
- Updates every 30 seconds (faster for test mode)

**To stop training:** Click "Stop" button or press Ctrl+C

**Note:** You can re-run this cell anytime to check status
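A monitoring loop along the lines described above could look like the sketch below, assuming the `processes` list from Section 6 is still in memory. It polls each `subprocess.Popen` handle with `poll()` and, when `nvidia-smi` is on the PATH, reads device-wide memory for GPU 0 (which covers all five nodes) via its CSV query mode. The function names are hypothetical; this is an illustrative stand-in, not the repository's monitor.

```python
import shutil
import subprocess
import time


def process_status(processes) -> list:
    """Return (node label, status) for each launched node process."""
    rows = []
    for node_id, proc in enumerate(processes):
        code = proc.poll()  # None while the process is still running
        status = "running" if code is None else f"exited ({code})"
        rows.append((f"node_{node_id}", status))
    return rows


def gpu_memory_mib():
    """Return (used, total) MiB for GPU 0 via nvidia-smi, or None if unavailable."""
    if shutil.which("nvidia-smi") is None:
        return None
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=memory.used,memory.total",
         "--format=csv,noheader,nounits", "--id=0"],
        capture_output=True, text=True, check=False,
    )
    if out.returncode != 0 or not out.stdout.strip():
        return None
    try:
        used, total = (int(x) for x in out.stdout.strip().splitlines()[0].split(","))
    except ValueError:
        return None
    return used, total


def monitor(processes, interval: float = 30.0, max_checks=None) -> None:
    """Print node status every `interval` seconds until all nodes exit."""
    checks = 0
    while True:
        rows = process_status(processes)
        mem = gpu_memory_mib()
        print(time.strftime("%H:%M:%S"), "GPU:", f"{mem[0]}/{mem[1]} MiB" if mem else "n/a")
        for label, status in rows:
            print(f"  {label}: {status}")
        checks += 1
        if all(status != "running" for _, status in rows):
            break
        if max_checks is not None and checks >= max_checks:
            break
        time.sleep(interval)
```

Usage: `monitor(processes, interval=30)` prints a status block every 30 seconds and returns once every node process has exited.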

## 7.5. Check Real-Time Progress (Optional)

**You can run this anytime** to check training progress from GDrive:
- Shows current round for each node
- Displays elapsed time
- Updates independently of notebook state

This is useful if your notebook disconnects - progress is always saved to GDrive!
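Because each node appends its progress to a `progress_node_<id>.jsonl` file on GDrive, the latest state can be recovered from the last complete line even after a disconnect. A minimal sketch, assuming one JSON object per line; it does not assume any particular record fields, and `latest_record` is a hypothetical helper.

```python
import json
from pathlib import Path


def latest_record(path):
    """Return the last parseable JSON object in a JSONL file, or None."""
    p = Path(path)
    if not p.exists():
        return None
    last = None
    with p.open(encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                last = json.loads(line)
            except json.JSONDecodeError:
                # A final line may be cut off mid-write while a node is flushing; skip it
                continue
    return last
```

Usage: `latest_record(f"{GDRIVE_BASE_PATH}/experiments/{EXPERIMENT_NAME}/progress_node_0.jsonl")`.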

In [None]:
# === Real-time Progress Viewer (Optional) ===
# Run this cell anytime to check progress from GDrive

import sys
sys.path.append('/content/rl-swarm')

from rgym_exp.utils.progress_tracker import get_experiment_progress

progress = get_experiment_progress(GDRIVE_BASE_PATH, EXPERIMENT_NAME)

print("="*60)
print("📊 REAL-TIME PROGRESS")
print("="*60)
print(f"Experiment: {progress.get('experiment')}")
print()

for node_id, node_data in progress.get('nodes', {}).items():
    if 'error' in node_data:
        print(f"  {node_id}: {node_data['error']}")
    else:
        print(f"  {node_id}:")
        print(f"    Event: {node_data.get('latest_event')}")
        print(f"    Round: {node_data.get('latest_round')}")
        print(f"    Elapsed: {node_data.get('elapsed_seconds', 0)/60:.1f}m")
        print()

print("="*60)
print("Note: Progress updates every round. Logs flush every 30s.")

In [None]:
# === Post-Test Validation ===
print("\n" + "="*60)
print("🔍 TEST MODE VALIDATION (Infrastructure Only)")
print("="*60)
print()
print("ℹ️  TEST MODE validates infrastructure, not actual training")
print("   Reward submissions are intentionally skipped for speed")
print("   Use a full experiment notebook to test actual training")
print()

import os
import json
from pathlib import Path
import sys
sys.path.append('/content/rl-swarm')

from rgym_exp.utils.test_results import save_test_results

gdrive_path = GDRIVE_BASE_PATH  # same base path the nodes were launched with
exp_name = EXPERIMENT_NAME
exp_path = f"{gdrive_path}/experiments/{exp_name}"

# Check 1: GDrive state file (coordinator works)
state_check = False
state_file = f"{exp_path}/state/current_state.json"
if os.path.exists(state_file):
    with open(state_file) as f:
        state = json.load(f)
    print(f"✓ State file exists: Round {state.get('round', '?')}")
    if state.get('round') == MAX_ROUNDS:
        print(f"  ✓ Reached round {MAX_ROUNDS} as expected")
        state_check = True
    else:
        print(f"  ⚠ Expected round {MAX_ROUNDS}, got {state.get('round')}")
else:
    print("✗ State file not found")

# Check 2: Logs (log streaming works)
logs_check = True
log_dirs = [f"{exp_path}/logs/node_{i}" for i in range(NUM_NODES)]
for i, log_dir in enumerate(log_dirs):
    if os.path.exists(f"{log_dir}/stdout.log"):
        print(f"✓ Node {i}: Logs present")
    else:
        print(f"✗ Node {i}: Logs missing")
        logs_check = False

# Check 3: Progress tracking (monitoring works)
progress_check = True
for i in range(NUM_NODES):
    progress_file = f"{exp_path}/progress_node_{i}.jsonl"
    if os.path.exists(progress_file):
        print(f"✓ Node {i}: Progress tracking working")
    else:
        print(f"✗ Node {i}: Progress tracking failed")
        progress_check = False

# Check 4: Rollouts (Config3 has J=3)
rollouts_check = False
rollouts_dir = f"{exp_path}/rollouts"
if os.path.exists(rollouts_dir):
    # Check if any rollouts were published
    rollout_files = []
    for root, dirs, files in os.walk(rollouts_dir):
        rollout_files.extend([f for f in files if f.endswith('.pkl')])
    if rollout_files:
        print(f"✓ Rollouts: {len(rollout_files)} rollout files found")
        rollouts_check = True
    else:
        print("⚠ Rollouts: Directory exists but no rollout files found")
else:
    print("✗ Rollouts: Directory not found")

print()
print("⊘ Reward submissions: Not tested (infrastructure test only)")
print()

# Save results to GDrive
save_test_results(
    gdrive_path=gdrive_path,
    experiment_name=exp_name,
    state_check=state_check,
    submissions_check=None,  # Not tested in TEST_MODE
    logs_check=logs_check,
    coordinator_check=True,
    rollouts_check=rollouts_check,
    num_nodes=NUM_NODES,
    num_rounds=MAX_ROUNDS,
    num_train_samples=NUM_TRAIN_SAMPLES,
    num_transplants=NUM_TRANSPLANT_TREES
)

print("="*60)
if state_check and logs_check and progress_check and rollouts_check:
    print("✅ INFRASTRUCTURE TEST PASSED")
    print("   Coordinator, workers, logging, rollout sharing working!")
else:
    print("❌ INFRASTRUCTURE TEST FAILED")
    print("   Check the errors above")
print("="*60)
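The log and progress checks above follow the same per-node pattern: build an expected path for each node and report the ones that are missing. A compact sketch of that pattern with `pathlib`; `missing_node_files` is a hypothetical helper, and the path templates are the ones this cell already uses.

```python
from pathlib import Path


def missing_node_files(exp_path, template: str, num_nodes: int) -> list:
    """Return node ids whose expected file (template formatted with i) is absent."""
    root = Path(exp_path)
    return [i for i in range(num_nodes) if not (root / template.format(i=i)).exists()]


# Example usage inside the validation cell:
# missing_logs = missing_node_files(exp_path, "logs/node_{i}/stdout.log", NUM_NODES)
# missing_progress = missing_node_files(exp_path, "progress_node_{i}.jsonl", NUM_NODES)
# logs_check = not missing_logs
```

Returning the missing node ids, rather than a single boolean, makes it obvious which node failed to start or flush.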
