# Total Cumulative Reward Monitor - SAPO Experiments

This notebook provides **real-time monitoring of total_cumulative_reward** (the primary SAPO paper metric) across all your experiments.

**Features:**
- Track total cumulative reward (sum across all nodes)
- Per-node breakdown
- Comparison to baseline
- Multi-experiment comparison
- Live progress graphs

**Usage:**
1. Mount Google Drive (Cell 1)
2. Configure experiment names (Cell 2)
3. Run the monitoring cells (Cells 3–7)

## 1. Mount Google Drive & Setup

In [None]:
from google.colab import drive
import os
import sys

# Mount Google Drive
drive.mount('/content/drive')

# Set base path (MUST MATCH YOUR EXPERIMENT PATH)
GDRIVE_BASE_PATH = '/content/drive/MyDrive/rl-swarm'

# Clone repository if needed
if not os.path.exists('/content/rl-swarm'):
    print("Cloning repository...")
    !git clone https://github.com/Elrashid/rl-swarm.git /content/rl-swarm
    
# Add to path
sys.path.append('/content/rl-swarm')

print(f"✓ Setup complete!")
print(f"  Base path: {GDRIVE_BASE_PATH}")

## 2. Configure Experiments to Monitor

In [None]:
# Experiments included in the multi-experiment views (edit to match your runs)
EXPERIMENTS = [
    'sapo_gpt2_baseline_4loc0ext',     # Baseline
    'sapo_gpt2_config2_4loc4ext',      # Config 2 (best)
    'sapo_gpt2_adaptive_ij',           # Adaptive I/J
]

# Experiment shown in the detailed, single-experiment cells
PRIMARY_EXPERIMENT = 'sapo_gpt2_config2_4loc4ext'

# Total number of rounds; only used to turn the current round into a progress %
MAX_ROUNDS = 2000

# Baseline cumulative reward used for the improvement calculation.
# Leave as None until a baseline value is available (e.g., 250.0).
BASELINE_CUMULATIVE_REWARD = None

# Echo the active configuration in one go
print("\n".join([
    "Configuration:",
    f"  Primary experiment: {PRIMARY_EXPERIMENT}",
    f"  Monitoring {len(EXPERIMENTS)} experiments total",
    f"  Max rounds: {MAX_ROUNDS}",
]))

## 3. Quick Status - Total Cumulative Reward

**Run this cell anytime to see current total cumulative reward!**

In [None]:
from rgym_exp.utils.cumulative_reward_monitor import display_cumulative_progress

# Report the current status of PRIMARY_EXPERIMENT only.
# max_rounds and baseline_reward are passed through unchanged from the
# configuration values; baseline_reward is None unless
# BASELINE_CUMULATIVE_REWARD has been set.
# NOTE(review): what gets displayed is decided inside
# rgym_exp.utils.cumulative_reward_monitor, not in this notebook.
display_cumulative_progress(
    gdrive_base_path=GDRIVE_BASE_PATH,
    experiment_name=PRIMARY_EXPERIMENT,
    max_rounds=MAX_ROUNDS,
    baseline_reward=BASELINE_CUMULATIVE_REWARD
)

## 4. Compare All Experiments

**See total cumulative rewards across all your experiments side-by-side**

In [None]:
from rgym_exp.utils.cumulative_reward_monitor import display_experiment_comparison

# Compare every experiment listed in EXPERIMENTS, all read from the same
# Drive base path.
# NOTE(review): the comparison layout is produced by
# rgym_exp.utils.cumulative_reward_monitor; nothing is formatted here.
display_experiment_comparison(
    gdrive_base_path=GDRIVE_BASE_PATH,
    experiment_names=EXPERIMENTS
)

## 5. Plot Cumulative Reward Over Time

**Visualize how total cumulative reward grows during training**

In [None]:
from rgym_exp.utils.cumulative_reward_monitor import get_cumulative_history
import matplotlib.pyplot as plt

# Load the recorded per-round history of the primary experiment
history = get_cumulative_history(GDRIVE_BASE_PATH, PRIMARY_EXPERIMENT)

if not history:
    # Nothing has been returned yet, so there is nothing to draw
    print("⚠️  No history data available yet")
    print("   Training may not have started or logs are empty")
else:
    # Split the records into the x (round) and y (total reward) series
    rounds = [entry['round'] for entry in history]
    totals = [entry['total_cumulative_reward'] for entry in history]

    plt.figure(figsize=(12, 6))
    plt.plot(rounds, totals, linewidth=2, marker='o', markersize=4)
    plt.xlabel('Round', fontsize=12)
    plt.ylabel('Total Cumulative Reward', fontsize=12)
    plt.title(
        f'Total Cumulative Reward Progress - {PRIMARY_EXPERIMENT}',
        fontsize=14,
        fontweight='bold',
    )
    plt.grid(True, alpha=0.3)

    # Label the most recent point with its value
    if totals:
        plt.annotate(
            f'{totals[-1]:.2f}',
            xy=(rounds[-1], totals[-1]),
            xytext=(10, 10),
            textcoords='offset points',
            fontsize=10,
            fontweight='bold',
            bbox={'boxstyle': 'round,pad=0.5', 'facecolor': 'yellow', 'alpha': 0.7},
        )

    plt.tight_layout()
    plt.show()

    print(f"✓ Plotted {len(history)} rounds")
    print(f"  Latest total cumulative reward: {totals[-1]:.2f}")

## 6. Compare Multiple Experiments (Plot)

**Side-by-side comparison of total cumulative reward growth**

In [None]:
from rgym_exp.utils.cumulative_reward_monitor import get_cumulative_history
import matplotlib.pyplot as plt

colors = ['blue', 'green', 'red', 'orange', 'purple']

# Gather the curves first so we know whether there is anything to draw.
# Each entry is (legend label, rounds, totals, color).
curves = []

for i, exp_name in enumerate(EXPERIMENTS):
    history = get_cumulative_history(GDRIVE_BASE_PATH, exp_name)

    if not history:
        # Say which experiment is missing instead of silently dropping it
        print(f"⚠️  No history data available yet for {exp_name} - skipping")
        continue

    rounds = [h['round'] for h in history]
    totals = [h['total_cumulative_reward'] for h in history]

    # Short label for legend, with the latest total appended
    label = exp_name.replace('sapo_gpt2_', '').replace('_', ' ')

    # The color is tied to the position in EXPERIMENTS, so an experiment keeps
    # its color even when an earlier one has no data yet.
    curves.append(
        (f"{label} ({totals[-1]:.0f})", rounds, totals, colors[i % len(colors)])
    )

if curves:
    plt.figure(figsize=(14, 7))

    for legend_label, rounds, totals, color in curves:
        plt.plot(rounds, totals, linewidth=2,
                 label=legend_label,
                 color=color,
                 alpha=0.8)

    plt.xlabel('Round', fontsize=12)
    plt.ylabel('Total Cumulative Reward', fontsize=12)
    plt.title('Total Cumulative Reward Comparison (SAPO Paper Metric)',
              fontsize=14, fontweight='bold')
    plt.legend(loc='upper left', fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    print("✓ Comparison plot complete")
else:
    # Without any curve, plt.legend() would only emit a "no labelled artists"
    # warning next to an empty figure, so skip plotting altogether.
    print("⚠️  No history data available yet for any experiment")
    print("   Training may not have started or logs are empty")

## 7. Live Monitoring Loop (Optional)

**Run this cell for continuous monitoring with auto-refresh**

Updates every 60 seconds. Press the 'Stop' button to halt.

In [None]:
from IPython.display import clear_output
from rgym_exp.utils.cumulative_reward_monitor import get_live_cumulative_rewards
import time
from datetime import datetime

# Seconds between refreshes. Used for both the sleep and the on-screen
# message so the two can never disagree.
REFRESH_SECONDS = 60

# Width of the text progress bar, in characters
BAR_WIDTH = 30

print("Starting live monitoring...")
print("Press 'Stop' button to halt\n")

try:
    while True:
        # wait=True defers clearing until new output arrives (less flicker)
        clear_output(wait=True)

        current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

        print("="*70)
        print(f"LIVE TOTAL CUMULATIVE REWARD MONITOR - {current_time}")
        print("="*70)
        print()

        for exp_name in EXPERIMENTS:
            # Unpacked as a 3-tuple; the format specs below require `total` to
            # be numeric, `current_round` an int and `per_node` to support len()
            total, per_node, current_round = get_live_cumulative_rewards(
                GDRIVE_BASE_PATH, exp_name
            )

            # Short name
            short_name = exp_name.replace('sapo_gpt2_', '')

            print(f"{short_name:25s}:")
            print(f"  Total Cumulative: {total:8.2f}")
            print(f"  Round: {current_round:4d} / {MAX_ROUNDS}")
            print(f"  Nodes: {len(per_node)}")

            # Progress bar (skipped when MAX_ROUNDS is not positive to avoid
            # dividing by zero)
            if MAX_ROUNDS > 0:
                # Clamp on both sides: a round past MAX_ROUNDS must not overflow
                # the bar, and a negative round would otherwise make the empty
                # part longer than BAR_WIDTH.
                progress = max(0.0, min(1.0, current_round / MAX_ROUNDS))
                filled = int(BAR_WIDTH * progress)
                bar = '█' * filled + '░' * (BAR_WIDTH - filled)
                print(f"  Progress: [{bar}] {progress*100:.1f}%")

            print()

        print("-"*70)
        print(f"Refreshing in {REFRESH_SECONDS} seconds... (Press 'Stop' to halt)")

        time.sleep(REFRESH_SECONDS)

except KeyboardInterrupt:
    # Raised when the notebook's Stop button interrupts the kernel
    print("\n✓ Monitoring stopped")