# RL Swarm - Experiment Monitoring Dashboard

This notebook provides monitoring and analysis tools for your RL Swarm experiments.

**Features:**
- 📊 Real-time experiment status
- 📈 Training metrics visualization
- 👥 Peer activity monitoring
- 🔍 Rollout inspection
- ⚠️ Error detection

**Usage:**
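1. Mount Google Drive (Section 1).
2. Set `EXPERIMENT_NAME` in the Configuration cell (Section 2) to the experiment you want to monitor.
3. Run the monitoring cells (Sections 3–8) in order; later cells reuse variables defined by earlier ones.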

---

## 1. Setup

In [None]:
# Mount Google Drive at /content/drive; the experiment paths configured below
# (GDRIVE_BASE) live under this mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install dependencies used for loading and plotting metrics
# (-q keeps pip's output quiet)
!pip install -q pandas matplotlib seaborn plotly

## 2. Configuration

In [None]:
import os

# =======================
# CONFIGURE YOUR EXPERIMENT
# =======================

# Must match the experiment you want to monitor
EXPERIMENT_NAME = 'my_first_experiment'

# Base path (usually don't need to change)
GDRIVE_BASE = '/content/drive/MyDrive/rl-swarm'

# =======================
# Derived paths
EXPERIMENT_PATH = f'{GDRIVE_BASE}/experiments/{EXPERIMENT_NAME}'

# Verify experiment exists. When it does not, list the experiments that are
# present so a typo in EXPERIMENT_NAME is easy to spot.
if not os.path.exists(EXPERIMENT_PATH):
    print(f"❌ Experiment '{EXPERIMENT_NAME}' not found!")
    print(f"   Path checked: {EXPERIMENT_PATH}")
    print("\nAvailable experiments:")
    exp_dir = f'{GDRIVE_BASE}/experiments'
    exps = []
    if os.path.isdir(exp_dir):
        # Only directories count as experiments; sorted for a stable listing.
        exps = sorted(
            d for d in os.listdir(exp_dir)
            if os.path.isdir(os.path.join(exp_dir, d))
        )
    for exp in exps:
        print(f"  - {exp}")
    # Covers both a missing experiments directory and one that exists but
    # holds no experiment folders.
    if not exps:
        print("  (No experiments found)")
else:
    print(f"✅ Monitoring experiment: {EXPERIMENT_NAME}")
    print(f"   Path: {EXPERIMENT_PATH}")

## 3. Experiment Status

In [None]:
import json
from datetime import datetime

def _read_json_object(path):
    """Return the JSON object stored at ``path``, or None if it cannot be used.

    None is returned when the file is missing, is a directory, cannot be
    read, is not valid JSON, or holds a JSON value that is not an object.
    """
    try:
        with open(path) as f:
            data = json.load(f)
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and text decoding errors.
        return None
    return data if isinstance(data, dict) else None


def get_experiment_status(exp_path):
    """Get current experiment status.

    Reads ``state/current_state.json`` and every entry of ``peers/`` under
    ``exp_path``. The files may be rewritten while this dashboard reads them,
    so anything missing, unreadable, or not a JSON object is skipped instead
    of aborting the whole status report.

    Args:
        exp_path: Root directory of the experiment.

    Returns:
        dict with the keys:
            'round', 'stage', 'last_updated': the state file's ``round``,
                ``stage`` and ``timestamp`` values, or 0 when unavailable.
            'num_peers': number of entries in the peers directory.
            'peers': list of ``{'peer_id', 'registered_at'}`` dicts, one per
                peer file that could be parsed, ordered by file name.
    """
    status = {}

    # Read current round/stage; fall back to zeros when the state is unusable.
    state = _read_json_object(f"{exp_path}/state/current_state.json") or {}
    status['round'] = state.get('round', 0)
    status['stage'] = state.get('stage', 0)
    status['last_updated'] = state.get('timestamp', 0)

    # List the peers directory once; a missing path or a non-directory both
    # count as "no peers".
    peers_dir = f"{exp_path}/peers"
    try:
        peer_files = sorted(os.listdir(peers_dir))
    except OSError:
        peer_files = []

    # Count active peers (every directory entry, parseable or not).
    status['num_peers'] = len(peer_files)

    # List peer IDs
    peers = []
    for peer_file in peer_files:
        peer_data = _read_json_object(f"{peers_dir}/{peer_file}")
        if peer_data is None:
            continue
        peers.append({
            'peer_id': peer_data.get('peer_id', 'unknown'),
            'registered_at': peer_data.get('timestamp', 0)
        })
    status['peers'] = peers

    return status

# Get and display status
status = get_experiment_status(EXPERIMENT_PATH)

print("="*60)
print(f"📊 EXPERIMENT STATUS: {EXPERIMENT_NAME}")
print("="*60)
print(f"Current Round:  {status['round']}")
print(f"Current Stage:  {status['stage']}")
print(f"Active Peers:   {status['num_peers']}")

# A timestamp of 0 means "not recorded", so no date is shown for it.
if status['last_updated'] > 0:
    last_update = datetime.fromtimestamp(status['last_updated'])
    print(f"Last Updated:   {last_update.strftime('%Y-%m-%d %H:%M:%S')}")

print("\n👥 Registered Peers:")
if status['peers']:
    for peer in status['peers']:
        peer_id = peer['peer_id']
        # Same convention as above: 0 means the peer file had no timestamp,
        # so show "unknown" rather than a 1970 date.
        if peer['registered_at'] > 0:
            registered = datetime.fromtimestamp(peer['registered_at']).strftime('%Y-%m-%d %H:%M:%S')
        else:
            registered = 'unknown'
        # Only mark the id as truncated when it actually was.
        short_id = f"{peer_id[:16]}..." if len(peer_id) > 16 else peer_id
        print(f"  - {short_id} (registered: {registered})")
else:
    print("  (No peers registered yet)")

print("="*60)

## 4. Training Metrics

In [None]:
import pandas as pd

def load_all_metrics(exp_path):
    """Load metrics from all nodes.

    Reads ``logs/<node_id>/metrics.jsonl`` for every entry under
    ``<exp_path>/logs``. Each line is expected to hold one JSON object; lines
    that are blank, not valid JSON, or not a JSON object are skipped so one
    bad record cannot abort the whole load.

    Args:
        exp_path: Root directory of the experiment.

    Returns:
        pandas.DataFrame with one row per metric record plus a ``node_id``
        column naming the directory the record came from. When a
        ``timestamp`` column is present, a ``datetime`` column is added by
        interpreting it as seconds since the epoch. An empty DataFrame is
        returned when there is no logs directory or no usable record.
    """
    logs_dir = f"{exp_path}/logs"

    if not os.path.isdir(logs_dir):
        return pd.DataFrame()

    all_metrics = []

    # Sorted so row order does not depend on the filesystem's listing order.
    for node_id in sorted(os.listdir(logs_dir)):
        metrics_file = f"{logs_dir}/{node_id}/metrics.jsonl"

        if not os.path.isfile(metrics_file):
            continue

        # Read JSONL file
        with open(metrics_file) as f:
            for line in f:
                try:
                    metric = json.loads(line.strip())
                except json.JSONDecodeError:
                    continue
                # Valid JSON that is not an object (a list, a number, ...)
                # cannot carry named metric fields.
                if not isinstance(metric, dict):
                    continue
                metric['node_id'] = node_id
                all_metrics.append(metric)

    if not all_metrics:
        return pd.DataFrame()

    df = pd.DataFrame(all_metrics)

    # Convert timestamp to datetime
    if 'timestamp' in df.columns:
        df['datetime'] = pd.to_datetime(df['timestamp'], unit='s')

    return df

# Load the metrics of every node into the shared DataFrame `df`, then report
# how much was found.
print("Loading metrics...")
df = load_all_metrics(EXPERIMENT_PATH)

if not df.empty:
    print(f"✅ Loaded {len(df)} metric records from {df['node_id'].nunique()} nodes")
    print(f"   Rounds: {df['round'].min()} - {df['round'].max()}")
    print(f"\nSample metrics:")
    print(df.head())
else:
    print("❌ No metrics found. Training may not have started yet.")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Draw a 2x2 overview figure from the metrics DataFrame `df`. Each panel is
# only filled in when the column it needs exists; otherwise its axes stay empty.
if not df.empty:
    # Plot average reward per round
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # 1. Average reward per round (all nodes)
    ax = axes[0, 0]
    if 'my_reward' in df.columns:
        # Per-round mean, standard deviation and record count over all nodes.
        round_rewards = df.groupby('round')['my_reward'].agg(['mean', 'std', 'count'])
        ax.plot(round_rewards.index, round_rewards['mean'], marker='o', label='Mean Reward')
        # Shaded band spans mean - std to mean + std for each round.
        ax.fill_between(
            round_rewards.index, 
            round_rewards['mean'] - round_rewards['std'],
            round_rewards['mean'] + round_rewards['std'],
            alpha=0.3
        )
        ax.set_xlabel('Round')
        ax.set_ylabel('Reward')
        ax.set_title('Average Reward per Round (All Nodes)')
        ax.legend()
        ax.grid(True, alpha=0.3)

    # 2. Reward per node
    ax = axes[0, 1]
    if 'my_reward' in df.columns:
        # Mean reward of each node over all its records, sorted ascending.
        node_rewards = df.groupby('node_id')['my_reward'].mean().sort_values()
        node_rewards.plot(kind='barh', ax=ax)
        ax.set_xlabel('Average Reward')
        ax.set_ylabel('Node ID')
        ax.set_title('Average Reward by Node')
        ax.grid(True, alpha=0.3, axis='x')

    # 3. Total agents per round
    ax = axes[1, 0]
    if 'total_agents' in df.columns:
        # Mean of the reported 'total_agents' values within each round.
        agents_per_round = df.groupby('round')['total_agents'].mean()
        ax.plot(agents_per_round.index, agents_per_round.values, marker='o', color='green')
        ax.set_xlabel('Round')
        ax.set_ylabel('Number of Agents')
        ax.set_title('Active Agents per Round')
        ax.grid(True, alpha=0.3)

    # 4. Reward distribution
    ax = axes[1, 1]
    if 'my_reward' in df.columns:
        # Histogram over every reward record, regardless of node or round.
        df['my_reward'].hist(bins=30, ax=ax, edgecolor='black')
        ax.set_xlabel('Reward')
        ax.set_ylabel('Frequency')
        ax.set_title('Reward Distribution')
        ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.show()
else:
    print("⚠️ No metrics to plot")

In [None]:
# Plot rewards over time for each node: one line per node_id, with the round
# number on the x axis.
if not df.empty and 'my_reward' in df.columns:
    plt.figure(figsize=(14, 6))

    for node_id in df['node_id'].unique():
        # Rows of this node, ordered by round so the line runs left to right.
        node_data = df[df['node_id'] == node_id].sort_values('round')
        # The legend label is the first 16 characters of the node id.
        plt.plot(node_data['round'], node_data['my_reward'], 
                marker='o', label=node_id[:16], alpha=0.7)

    plt.xlabel('Round')
    plt.ylabel('Reward')
    plt.title('Reward Over Time by Node')
    # Anchor the legend outside the axes, to the right of the plot area.
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
else:
    print("⚠️ No reward data to plot")

## 5. Rollout Inspection

In [None]:
def count_rollouts_by_round(exp_path):
    """Count rollout files by round.

    Scans ``<exp_path>/rollouts`` for directories named ``round_<number>``
    and, inside each, counts the ``.json`` files found directly in its
    sub-directories.

    Entries that start with ``round_`` but are not directories, or whose
    second ``_``-separated field is not an integer (for example
    ``round_backup``), are ignored rather than raising.

    Args:
        exp_path: Root directory of the experiment.

    Returns:
        dict mapping round number (int) to the number of ``.json`` files in
        that round; empty when there is no rollouts directory.
    """
    rollouts_dir = f"{exp_path}/rollouts"

    if not os.path.isdir(rollouts_dir):
        return {}

    rollout_counts = {}

    for round_dir in os.listdir(rollouts_dir):
        if not round_dir.startswith('round_'):
            continue

        round_path = f"{rollouts_dir}/{round_dir}"
        if not os.path.isdir(round_path):
            # A stray file such as "round_2.tmp" is not a round.
            continue

        try:
            round_num = int(round_dir.split('_')[1])
        except ValueError:
            # Not a numbered round directory.
            continue

        total_files = 0
        for stage_dir in os.listdir(round_path):
            stage_path = f"{round_path}/{stage_dir}"
            if os.path.isdir(stage_path):
                total_files += sum(
                    1 for f in os.listdir(stage_path) if f.endswith('.json')
                )

        rollout_counts[round_num] = total_files

    return rollout_counts

# Tally rollout files per round, print the table, then draw it as a bar chart.
rollout_counts = count_rollouts_by_round(EXPERIMENT_PATH)

if not rollout_counts:
    print("❌ No rollout files found")
else:
    print("📁 Rollout Files by Round:")
    print("="*40)
    # Round numbers are unique keys, so sorting the items orders by round.
    for round_num, count in sorted(rollout_counts.items()):
        print(f"  Round {round_num:3d}: {count:3d} files")
    print("="*40)
    print(f"Total: {sum(rollout_counts.values())} rollout files")

    # Bar chart: one bar per round, in ascending round order.
    plt.figure(figsize=(10, 4))
    rounds = sorted(rollout_counts)
    counts = list(map(rollout_counts.get, rounds))
    plt.bar(rounds, counts, color='steelblue', edgecolor='black')
    plt.xlabel('Round')
    plt.ylabel('Number of Rollout Files')
    plt.title('Rollout Files per Round')
    plt.grid(True, alpha=0.3, axis='y')
    plt.tight_layout()
    plt.show()

In [None]:
# Inspect a specific rollout file
def get_latest_rollout(exp_path):
    """Return the path of the most recently modified rollout file.

    Walks ``<exp_path>/rollouts`` recursively and compares the modification
    time of every ``.json`` file.

    Args:
        exp_path: Root directory of the experiment.

    Returns:
        Path of the newest ``.json`` file, or None when the rollouts
        directory does not exist or contains no such file.
    """
    rollouts_dir = f"{exp_path}/rollouts"
    if not os.path.exists(rollouts_dir):
        return None

    newest_path, newest_mtime = None, 0

    for dirpath, _subdirs, filenames in os.walk(rollouts_dir):
        json_paths = (
            os.path.join(dirpath, name)
            for name in filenames
            if name.endswith('.json')
        )
        for candidate in json_paths:
            modified = os.path.getmtime(candidate)
            # Strictly newer only: on a tie the first file seen is kept.
            if modified > newest_mtime:
                newest_path, newest_mtime = candidate, modified

    return newest_path

# Locate the newest rollout file and print a short preview of its contents.
latest_rollout = get_latest_rollout(EXPERIMENT_PATH)

if not latest_rollout:
    print("❌ No rollout files found")
else:
    print(f"📄 Latest Rollout File:")
    print(f"   {latest_rollout}")
    print("\nContent preview:")

    with open(latest_rollout) as f:
        rollout_data = json.load(f)

    print(f"\n  Peer ID: {rollout_data.get('peer_id', 'unknown')}")
    print(f"  Round: {rollout_data.get('round', '?')}")
    print(f"  Stage: {rollout_data.get('stage', '?')}")

    if 'rollouts' in rollout_data:
        # 'rollouts' is read as a mapping; the total is the summed length of its values.
        total_rollouts = sum(map(len, rollout_data['rollouts'].values()))
        print(f"  Total rollouts: {total_rollouts}")
        print(f"  Batches: {list(rollout_data['rollouts'])}")

## 6. Error Detection

In [None]:
def check_for_issues(exp_path):
    """Check for common issues.

    Looks at the experiment directory layout and reports, as human-readable
    strings, any of the following:

    * the state file ``state/current_state.json`` is missing;
    * the state file exists but cannot be read as a JSON object (reported as
      an issue instead of raising, since the file may be mid-write);
    * the state file's ``timestamp`` is more than 10 minutes old;
    * ``peers/``, ``logs/`` or ``rollouts/`` is missing, empty, or not a
      directory.

    Args:
        exp_path: Root directory of the experiment.

    Returns:
        list of issue messages; empty when nothing was detected.
    """
    def _has_entries(path):
        # True only for an existing directory with at least one entry.
        return os.path.isdir(path) and bool(os.listdir(path))

    issues = []

    # Check if state file exists
    state_file = f"{exp_path}/state/current_state.json"
    if not os.path.exists(state_file):
        issues.append("⚠️ State file missing - coordinator may not be running")

    # Check for peers
    if not _has_entries(f"{exp_path}/peers"):
        issues.append("⚠️ No peers registered - nodes may not have started")

    # Check for recent activity
    if os.path.exists(state_file):
        try:
            with open(state_file) as f:
                state = json.load(f)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and text decoding errors.
            state = None
        if not isinstance(state, dict):
            issues.append("⚠️ State file unreadable - it may be corrupt or partially written")
        else:
            last_update = state.get('timestamp', 0)
            if last_update > 0:
                time_since_update = datetime.now().timestamp() - last_update
                if time_since_update > 600:  # 10 minutes
                    issues.append(f"⚠️ No activity for {int(time_since_update/60)} minutes - nodes may be stuck")

    # Check for logs
    if not _has_entries(f"{exp_path}/logs"):
        issues.append("⚠️ No log files - training may not have started")

    # Check for rollouts
    if not _has_entries(f"{exp_path}/rollouts"):
        issues.append("⚠️ No rollout files - nodes may not be sharing data")

    return issues

# Run the health checks and print either the findings or an all-clear line.
print("🔍 Checking for issues...\n")
issues = check_for_issues(EXPERIMENT_PATH)

if not issues:
    print("✅ No issues detected - experiment running normally!")
else:
    print("Issues found:")
    for issue in issues:
        print(f"  {issue}")

## 7. Summary Statistics

In [None]:
# Print a text summary of the metrics DataFrame `df`.
if not df.empty:
    print("="*60)
    print("📈 TRAINING SUMMARY")
    print("="*60)

    print(f"\nExperiment: {EXPERIMENT_NAME}")
    # NOTE(review): the "+ 1" presumably assumes rounds are numbered from 0 —
    # confirm against the code that writes the metrics.
    print(f"Total Rounds: {df['round'].max() + 1}")
    print(f"Total Nodes: {df['node_id'].nunique()}")
    print(f"Total Metrics: {len(df)}")

    # Reward statistics are computed over every record, across all nodes and rounds.
    if 'my_reward' in df.columns:
        print(f"\nReward Statistics:")
        print(f"  Mean:   {df['my_reward'].mean():.4f}")
        print(f"  Median: {df['my_reward'].median():.4f}")
        print(f"  Std:    {df['my_reward'].std():.4f}")
        print(f"  Min:    {df['my_reward'].min():.4f}")
        print(f"  Max:    {df['my_reward'].max():.4f}")

    # Duration is the span between the earliest and latest record timestamps.
    if 'datetime' in df.columns:
        duration = df['datetime'].max() - df['datetime'].min()
        print(f"\nTraining Duration: {duration}")

    # The header is printed unconditionally; the table below it only appears
    # when reward data exists.
    print("\nPer-Node Statistics:")
    if 'my_reward' in df.columns:
        node_stats = df.groupby('node_id')['my_reward'].agg(['count', 'mean', 'std'])
        # Rename the aggregate columns to display headings.
        node_stats.columns = ['Metrics', 'Avg Reward', 'Std Dev']
        print(node_stats.to_string())

    print("="*60)
else:
    print("⚠️ No metrics available for summary")

## 8. Auto-Refresh Status (Optional)

In [None]:
# Auto-refresh status every 30 seconds
# WARNING: This will run indefinitely until you stop it manually

import time
from IPython.display import clear_output

REFRESH_INTERVAL = 30  # seconds

# Re-read and re-print the experiment status every REFRESH_INTERVAL seconds.
# The loop has no exit condition of its own; it ends only when a
# KeyboardInterrupt is raised (see the handler below).
try:
    while True:
        # Clear the previous snapshot before printing the new one.
        clear_output(wait=True)

        status = get_experiment_status(EXPERIMENT_PATH)

        print("="*60)
        print(f"📊 LIVE STATUS: {EXPERIMENT_NAME}")
        print(f"🔄 Auto-refreshing every {REFRESH_INTERVAL}s (Press ■ to stop)")
        print("="*60)
        print(f"Current Round:  {status['round']}")
        print(f"Current Stage:  {status['stage']}")
        print(f"Active Peers:   {status['num_peers']}")

        # A timestamp of 0 means none was recorded, so the line is omitted.
        if status['last_updated'] > 0:
            last_update = datetime.fromtimestamp(status['last_updated'])
            print(f"Last Updated:   {last_update.strftime('%Y-%m-%d %H:%M:%S')}")

        # Shows the current time plus the interval as text; the actual next
        # refresh time is not computed.
        print(f"\nNext refresh:   {datetime.now().strftime('%H:%M:%S')} + {REFRESH_INTERVAL}s")
        print("="*60)

        time.sleep(REFRESH_INTERVAL)

except KeyboardInterrupt:
    # Interrupting the cell lands here, so the loop ends with a message
    # instead of a traceback.
    print("\n✋ Auto-refresh stopped")