# 🚀 Quick Start: Notebook Training

This notebook provides a minimal working example to get you started with training in Jupyter.

**Goal**: Run a short 20k-step training session to verify everything works!

**Time**: ~5-10 minutes on Apple Silicon MPS

---

## Instructions:
1. Run each cell in order (Shift+Enter)
2. Modify the CONFIG cell to experiment
3. Check the results in the evaluation cells

Let's go! 🎮


In [None]:
# ============================================================================
# CELL 1: Setup (Run this first!)
# ============================================================================

# Set MPS fallback BEFORE importing torch (important for Apple Silicon!)
import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

# Core imports
import sys
from pathlib import Path
import torch
import numpy as np
import matplotlib.pyplot as plt
from functools import partial

# Configure matplotlib for inline display
%matplotlib inline
plt.rcParams['figure.figsize'] = (12, 6)

# Add the project root and its user/ folder to the Python path
PROJECT_ROOT = Path.cwd().parent  # Assumes this notebook lives in the project's guides/ folder
for extra_path in (PROJECT_ROOT, PROJECT_ROOT / "user"):
    if str(extra_path) not in sys.path:
        sys.path.insert(0, str(extra_path))

print(f"✓ Imports complete")
print(f"✓ Project root: {PROJECT_ROOT}")
print(f"✓ Python: {sys.version.split()[0]}")



In [None]:
# ============================================================================
# CELL 2: Import Training Components
# ============================================================================

# Import everything from your existing training script
# This avoids code duplication!
from train_agent import (
    TransformerStrategyAgent,
    BasedAgent,
    gen_reward_manager,
    build_self_play_components,
    run_training_loop,
    run_eval_match,
    get_torch_device,
    CameraResolution,
    TrainLogging,
    Result,
    SaveHandlerMode,
    SelfPlayRandom,
)

# Detect and configure device (MPS/CUDA/CPU)
TORCH_DEVICE = get_torch_device()

print(f"✓ Training components imported")
print(f"✓ Device: {TORCH_DEVICE}")

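`get_torch_device()` lives in `train_agent.py` and its body isn't shown here. A common selection order is CUDA first, then Apple MPS, then CPU. The sketch below shows that order under the assumption that the project's helper behaves similarly; `choose_device` is a hypothetical name, not the project's function.

```python
def choose_device(cuda_available: bool, mps_available: bool) -> str:
    """Pick a torch device string: CUDA if present, else Apple MPS, else CPU.

    Hypothetical helper; the project's get_torch_device() may differ.
    """
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"

# With torch installed you would pass the real availability checks:
#   import torch
#   device = torch.device(choose_device(torch.cuda.is_available(),
#                                       torch.backends.mps.is_available()))
print(choose_device(cuda_available=False, mps_available=True))  # mps
```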

---
## 📋 Configuration

**This is where you experiment!** Modify these values and re-run this cell to try different settings.


In [None]:
# ============================================================================
# CELL 3: Configuration (MODIFY THIS TO EXPERIMENT!)
# ============================================================================

CONFIG = {
    # Training parameters
    "total_timesteps": 20_000,        # Short test run (~5-10 min on MPS)
    "save_freq": 5_000,               # Save checkpoint every 5k steps
    "eval_freq": 2_500,               # Evaluate every 2.5k steps (not used by the cells below)
    "eval_episodes": 3,               # Number of evaluation matches played in Cell 8
    
    # Transformer hyperparameters
    "latent_dim": 256,                # Strategy embedding dimension
    "num_heads": 8,                   # Attention heads (standard multi-head attention needs latent_dim % num_heads == 0)
    "num_layers": 6,                  # Transformer depth
    "sequence_length": 90,            # Frames to analyze (3 sec at 30 FPS)
    
    # PPO hyperparameters
    "n_steps": 30 * 90 * 20,          # Steps per rollout (54,000). This exceeds total_timesteps:
                                      # if training uses SB3's learn(), the first full rollout
                                      # still completes, so expect ~54k steps and a single update.
    "batch_size": 128,                # Minibatch size (powers of 2 for MPS)
    "learning_rate": 2.5e-4,          # Learning rate (Cell 4 does not pass this to the agent)
    "ent_coef": 0.10,                 # Entropy coefficient (exploration)
    "lstm_hidden_size": 512,          # LSTM hidden units
    
    # Experiment tracking
    "run_name": "notebook_quick_test", # Name for this run
    "load_checkpoint": None,          # Path to checkpoint or None
}

# Print configuration
print("=" * 70)
print("📋 Training Configuration")
print("=" * 70)
for key, value in CONFIG.items():
    print(f"  {key:25s}: {value}")
print("=" * 70)
print(f"\n💡 TIP: Modify values above and re-run this cell to experiment!")

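Some CONFIG values interact. In a Stable-Baselines3-style on-policy loop, `learn()` keeps collecting *full* rollouts of `n_steps` transitions (per environment) until `total_timesteps` is reached, so a rollout larger than the run budget still finishes. The sketch below derives a few quantities from a CONFIG-style dict; it assumes an SB3-style loop, a single environment, and 30 FPS, and `summarize_config` is a hypothetical helper.

```python
import math

def summarize_config(cfg: dict, fps: int = 30, n_envs: int = 1) -> dict:
    """Derive run-level quantities from a CONFIG-style dict.

    Assumes an SB3-style on-policy loop: full rollouts of n_steps * n_envs
    transitions are collected until total_timesteps is reached or exceeded.
    """
    rollout_size = cfg["n_steps"] * n_envs
    rollouts = math.ceil(cfg["total_timesteps"] / rollout_size)
    return {
        "context_seconds": cfg["sequence_length"] / fps,  # transformer look-back window
        "rollouts": rollouts,                               # one training phase per rollout
        "steps_actually_collected": rollouts * rollout_size,
    }

example = {"total_timesteps": 20_000, "n_steps": 30 * 90 * 20, "sequence_length": 90}
print(summarize_config(example))
# {'context_seconds': 3.0, 'rollouts': 1, 'steps_actually_collected': 54000}
```

In the notebook itself you could call `summarize_config(CONFIG)` right after this cell.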

In [None]:
# ============================================================================
# CELL 4: Create Transformer Strategy Agent
# ============================================================================

print("Creating TransformerStrategyAgent...")

# Create agent with transformer-based strategy recognition
learning_agent = TransformerStrategyAgent(
    file_path=CONFIG["load_checkpoint"],
    latent_dim=CONFIG["latent_dim"],
    num_heads=CONFIG["num_heads"],
    num_layers=CONFIG["num_layers"],
    sequence_length=CONFIG["sequence_length"],
    opponent_obs_dim=None  # Will be auto-detected from environment
)

# Set RecurrentPPO policy architecture and rollout hyperparameters
learning_agent.default_policy_kwargs = {
    'activation_fn': torch.nn.ReLU,
    'lstm_hidden_size': CONFIG["lstm_hidden_size"],
    'net_arch': dict(pi=[96, 96], vf=[96, 96]),
    'shared_lstm': True,
    'enable_critic_lstm': False,
    'share_features_extractor': True,
}
learning_agent.default_n_steps = CONFIG["n_steps"]
learning_agent.default_batch_size = CONFIG["batch_size"]
learning_agent.default_ent_coef = CONFIG["ent_coef"]
# NOTE: CONFIG["learning_rate"] is not applied here. Check whether
# TransformerStrategyAgent exposes a learning-rate default before relying on it.

print("=" * 70)
print("✓ Agent Created Successfully!")
print("=" * 70)
print(f"  Architecture: Transformer Strategy Agent")
print(f"  Latent dimension: {CONFIG['latent_dim']}")
print(f"  Attention heads: {CONFIG['num_heads']}")
print(f"  Transformer layers: {CONFIG['num_layers']}")
print(f"  Sequence length: {CONFIG['sequence_length']} frames")
print(f"  Device: {TORCH_DEVICE}")
print("=" * 70)

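Multi-head attention splits the embedding across heads, so `latent_dim` must be divisible by `num_heads` (`torch.nn.MultiheadAttention` enforces `embed_dim % num_heads == 0`). Assuming the strategy encoder uses `latent_dim` as its attention width, a quick check before building the agent looks like this; `attention_head_dim` is a hypothetical helper.

```python
def attention_head_dim(latent_dim: int, num_heads: int) -> int:
    """Return the per-head width, applying the same divisibility rule as
    torch.nn.MultiheadAttention (embed_dim must be divisible by num_heads)."""
    if num_heads <= 0:
        raise ValueError("num_heads must be positive")
    if latent_dim % num_heads != 0:
        raise ValueError(
            f"latent_dim={latent_dim} is not divisible by num_heads={num_heads}"
        )
    return latent_dim // num_heads

print(attention_head_dim(256, 8))  # 32 dimensions per head
```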

In [None]:
# ============================================================================
# CELL 5: Setup Reward Manager
# ============================================================================

print("Setting up reward functions...")

# Create reward manager with all reward terms
reward_manager = gen_reward_manager()

# Display active reward functions
print("=" * 70)
print("📊 Active Reward Functions")
print("=" * 70)

if reward_manager.reward_functions:
    for name, term in reward_manager.reward_functions.items():
        print(f"  {name:35s} weight: {term.weight:+.3f}")

print("\n🎯 Signal-Based Rewards")
if reward_manager.signal_subscriptions:
    for name, (signal_name, term) in reward_manager.signal_subscriptions.items():
        print(f"  {name:35s} weight: {term.weight:+.3f}")

print("=" * 70)
print("✓ Reward manager configured")
print("=" * 70)

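The listing above prints a weight for every term, which suggests each term's raw value is scaled by its weight and summed into the step reward. A minimal sketch of that idea, assuming a plain weighted sum (the real `gen_reward_manager` may normalize, clip, or handle signal-based terms differently); `RewardTerm`, `total_reward`, and both term names are hypothetical.

```python
from dataclasses import dataclass
from typing import Callable, Dict

@dataclass
class RewardTerm:
    func: Callable[[dict], float]  # maps a per-step state snapshot to a raw reward
    weight: float                  # scale applied to that raw reward

def total_reward(terms: Dict[str, RewardTerm], state: dict) -> float:
    """Weighted sum of all term outputs for one environment step."""
    return sum(term.weight * term.func(state) for term in terms.values())

terms = {
    "damage_diff": RewardTerm(lambda s: s["damage_dealt"] - s["damage_taken"], 1.0),
    "off_stage_penalty": RewardTerm(lambda s: -1.0 if s["off_stage"] else 0.0, 0.5),
}
step_state = {"damage_dealt": 12.0, "damage_taken": 4.0, "off_stage": True}
print(total_reward(terms, step_state))  # 1.0 * 8.0 + 0.5 * (-1.0) = 7.5
```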

In [None]:
# ============================================================================
# CELL 6: Setup Opponents & Checkpointing
# ============================================================================

print("Configuring self-play and opponents...")

# Define opponent mix (who the agent trains against)
opponent_mix = {
    'based_agent': (1.0, partial(BasedAgent)),  # Scripted heuristic opponent
}

# Build self-play infrastructure
selfplay_handler, save_handler, opponent_cfg = build_self_play_components(
    learning_agent,
    run_name=CONFIG["run_name"],
    save_freq=CONFIG["save_freq"],
    max_saved=10,  # Keep last 10 checkpoints
    mode=SaveHandlerMode.FORCE,
    opponent_mix=opponent_mix,
    selfplay_handler_cls=SelfPlayRandom,
)

print("=" * 70)
print("✓ Training Infrastructure Ready")
print("=" * 70)
print(f"  Run name: {CONFIG['run_name']}")
print(f"  Checkpoint frequency: every {CONFIG['save_freq']:,} steps")
print(f"  Checkpoint directory: checkpoints/{CONFIG['run_name']}/")
print(f"  Opponents: {list(opponent_mix.keys())}")
print("=" * 70)

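Each `opponent_mix` entry pairs a weight with an opponent factory. With a single entry the weight is irrelevant, but once you add more opponents the weights presumably act as relative sampling probabilities. The sketch below shows weighted sampling with the standard library's `random.choices`; it is an assumption about how `SelfPlayRandom` uses the weights, and `sample_opponent` and the `self_play` key are hypothetical.

```python
import random
from typing import Callable, Dict, Tuple

def sample_opponent(mix: Dict[str, Tuple[float, Callable]], rng: random.Random) -> str:
    """Pick an opponent key with probability proportional to its weight."""
    names = list(mix)
    weights = [mix[name][0] for name in names]
    if sum(weights) <= 0:
        raise ValueError("opponent weights must sum to a positive value")
    return rng.choices(names, weights=weights, k=1)[0]

mix = {
    "based_agent": (0.7, object),  # `object` stands in for partial(BasedAgent)
    "self_play": (0.3, object),    # hypothetical second opponent
}
rng = random.Random(0)
counts = {name: 0 for name in mix}
for _ in range(10_000):
    counts[sample_opponent(mix, rng)] += 1
print(counts)  # roughly a 70% / 30% split
```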

---
## 🚀 Training

**Run this cell to start training!** 

**Note**: You can interrupt training anytime by clicking the ⏹ stop button.


In [None]:
# ============================================================================
# CELL 7: START TRAINING! 🚀
# ============================================================================

print("\n" + "=" * 70)
print("🚀 STARTING TRAINING")
print("=" * 70)
print(f"  Total timesteps: {CONFIG['total_timesteps']:,}")
print(f"  Device: {TORCH_DEVICE}")
print(f"  Run name: {CONFIG['run_name']}")
print(f"  Estimated time: ~{CONFIG['total_timesteps'] / 4000:.0f}-{CONFIG['total_timesteps'] / 2000:.0f} minutes (rough guess for MPS)")
print("=" * 70)
print("\n💡 TIP: You can interrupt training with the ⏹ button\n")

# Run training loop!
result = run_training_loop(
    agent=learning_agent,
    reward_manager=reward_manager,
    save_handler=save_handler,
    opponent_cfg=opponent_cfg,
    resolution=CameraResolution.LOW,
    train_timesteps=CONFIG["total_timesteps"],
    train_logging=TrainLogging.PLOT,
)

print("\n" + "=" * 70)
print("✓ TRAINING COMPLETE!")
print("=" * 70)
print(f"  Checkpoints saved to: checkpoints/{CONFIG['run_name']}/")
print("=" * 70)

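Pressing ⏹ raises `KeyboardInterrupt` inside the running cell, so the "TRAINING COMPLETE" summary above never prints. One way to keep the rest of the cell running is to wrap the training call. A sketch, where `run_interruptible` is a hypothetical helper; presumably the agent keeps the weights from any updates that already finished, but that depends on how `run_training_loop` manages the model.

```python
from typing import Any, Callable

def run_interruptible(train_fn: Callable[[], Any]) -> bool:
    """Run train_fn and report whether it finished (True) or was interrupted (False).

    Jupyter's stop button raises KeyboardInterrupt in the running cell; catching
    it lets the code after the call (e.g. the summary prints) still execute.
    """
    try:
        train_fn()
        return True
    except KeyboardInterrupt:
        print("⏹ Training interrupted.")
        return False

# Hypothetical usage in Cell 7:
# completed = run_interruptible(lambda: run_training_loop(agent=learning_agent, ...))
```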

---
## 📊 Evaluation & Results

Let's see how well your agent performs!


In [None]:
# ============================================================================
# CELL 8: Run Evaluation Matches
# ============================================================================

print("🎯 Running evaluation matches...")
print("=" * 70)

results = []
for i in range(CONFIG["eval_episodes"]):
    print(f"  Match {i+1}/{CONFIG['eval_episodes']}...", end=" ", flush=True)
    
    # Run match against BasedAgent
    match_stats = run_eval_match(
        learning_agent,
        partial(BasedAgent),
        max_timesteps=30*90,  # 90 seconds
        resolution=CameraResolution.LOW,
        train_mode=True
    )
    
    # Extract results
    won = match_stats.player1_result == Result.WIN
    damage_dealt = match_stats.player2.total_damage  # Damage to opponent
    damage_taken = match_stats.player1.total_damage
    
    results.append({
        "won": won,
        "damage_dealt": damage_dealt,
        "damage_taken": damage_taken
    })
    
    print(f"{'✓ WIN' if won else '✗ LOSS'} "
          f"(Damage: {damage_dealt:.0f} dealt / {damage_taken:.0f} taken)")

# Calculate summary statistics
win_rate = np.mean([r["won"] for r in results]) * 100
avg_damage_dealt = np.mean([r["damage_dealt"] for r in results])
avg_damage_taken = np.mean([r["damage_taken"] for r in results])
damage_ratio = avg_damage_dealt / max(avg_damage_taken, 1)

print("=" * 70)
print("📈 EVALUATION RESULTS")
print("=" * 70)
print(f"  Win Rate:        {win_rate:.1f}%")
print(f"  Avg Damage Dealt: {avg_damage_dealt:.1f}")
print(f"  Avg Damage Taken: {avg_damage_taken:.1f}")
print(f"  Damage Ratio:    {damage_ratio:.2f}x")
print("=" * 70)

# Interpretation
if win_rate >= 70:
    print("🎉 Excellent! Agent is performing very well!")
elif win_rate >= 50:
    print("👍 Good! Agent is learning effectively.")
elif win_rate >= 30:
    print("📈 Progress! Agent needs more training.")
else:
    print("⚠️  Agent needs more training or config tuning.")

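With `eval_episodes = 3`, the win rate can only be 0%, 33%, 67%, or 100%, so the thresholds above are coarse. A Wilson score interval makes that uncertainty explicit. This is the standard binomial-proportion interval, not something `run_eval_match` reports; `wilson_interval` is a hypothetical helper you could feed with `sum(r["won"] for r in results)` and `len(results)`.

```python
import math

def wilson_interval(wins: int, n: int, z: float = 1.96) -> tuple:
    """Wilson score interval for a win rate (z = 1.96 gives ~95% coverage)."""
    if n == 0:
        return (0.0, 1.0)  # no matches played: no information
    p = wins / n
    denom = 1 + z**2 / n
    center = (p + z**2 / (2 * n)) / denom
    half = z * math.sqrt(p * (1 - p) / n + z**2 / (4 * n**2)) / denom
    return (max(0.0, center - half), min(1.0, center + half))

low, high = wilson_interval(wins=2, n=3)
print(f"2/3 wins -> 95% interval {low:.0%} to {high:.0%}")  # roughly 21% to 94%
```

Even a 2-of-3 result is compatible with a true win rate anywhere from about one in five to nearly always, which is why longer evaluations are worth the time.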

In [None]:
# ============================================================================
# CELL 9: Plot Training Curves
# ============================================================================

import pandas as pd

# Load training logs
log_path = f"checkpoints/{CONFIG['run_name']}/monitor.csv"

try:
    # Read CSV (skip metadata row)
    df = pd.read_csv(log_path, skiprows=1)
    
    # Create plots
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    # 1. Episode Reward
    axes[0, 0].plot(df['r'], alpha=0.3, label='Raw')
    axes[0, 0].plot(df['r'].rolling(10).mean(), linewidth=2, label='10-ep MA')
    axes[0, 0].set_title('Episode Reward', fontsize=14, fontweight='bold')
    axes[0, 0].set_xlabel('Episode')
    axes[0, 0].set_ylabel('Reward')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)
    
    # 2. Episode Length
    axes[0, 1].plot(df['l'], alpha=0.3, color='orange', label='Raw')
    axes[0, 1].plot(df['l'].rolling(10).mean(), linewidth=2, color='darkorange', label='10-ep MA')
    axes[0, 1].set_title('Episode Length', fontsize=14, fontweight='bold')
    axes[0, 1].set_xlabel('Episode')
    axes[0, 1].set_ylabel('Steps')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)
    
    # 3. Reward Distribution (last 50 episodes)
    axes[1, 0].hist(df['r'].tail(50), bins=15, alpha=0.7, color='green', edgecolor='black')
    axes[1, 0].axvline(df['r'].tail(50).mean(), color='red', linestyle='--', 
                       linewidth=2, label=f'Mean: {df["r"].tail(50).mean():.2f}')
    axes[1, 0].set_title('Recent Reward Distribution (Last 50 Episodes)', 
                         fontsize=14, fontweight='bold')
    axes[1, 0].set_xlabel('Reward')
    axes[1, 0].set_ylabel('Frequency')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3, axis='y')
    
    # 4. Training Time ('t' is seconds elapsed since the run started, already cumulative)
    axes[1, 1].plot(df['t'] / 60, color='purple', linewidth=2)
    axes[1, 1].set_title('Cumulative Training Time', fontsize=14, fontweight='bold')
    axes[1, 1].set_xlabel('Episode')
    axes[1, 1].set_ylabel('Time (minutes)')
    axes[1, 1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Print summary statistics
    print("=" * 70)
    print("📊 TRAINING STATISTICS")
    print("=" * 70)
    print(f"  Total episodes: {len(df)}")
    print(f"  Latest reward: {df['r'].iloc[-1]:.2f}")
    print(f"  Average (last 20): {df['r'].tail(20).mean():.2f}")
    print(f"  Best reward: {df['r'].max():.2f}")
    print(f"  Total training time: {df['t'].iloc[-1] / 60:.1f} minutes")
    print("=" * 70)
    
except FileNotFoundError:
    print(f"⚠️  Training logs not found at: {log_path}")
    print("   Make sure training has completed (run Cell 7 first)")

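Assuming `monitor.csv` is written by Stable-Baselines3's `Monitor` wrapper, it starts with a `#`-prefixed JSON line, followed by the columns `r` (episode reward), `l` (episode length in steps), and `t` (seconds elapsed since the monitor started, already cumulative). A stdlib-only sketch for reading it without pandas; the sample rows are made up, and `read_monitor` and `moving_average` are hypothetical helpers.

```python
import csv
import json
from collections import deque

SAMPLE_MONITOR = """#{"t_start": 1700000000.0}
r,l,t
-3.5,2700,41.2
1.0,2100,73.9
4.25,1800,101.5
"""

def read_monitor(text: str):
    """Parse Monitor output: a '#'-prefixed JSON header line, then r,l,t rows."""
    lines = text.splitlines()
    meta = json.loads(lines[0].lstrip("#"))
    episodes = [
        {"r": float(row["r"]), "l": int(row["l"]), "t": float(row["t"])}
        for row in csv.DictReader(lines[1:])
    ]
    return meta, episodes

def moving_average(values, window):
    """Trailing mean; unlike pandas rolling(), early points use a partial window."""
    buf, out = deque(maxlen=window), []
    for v in values:
        buf.append(v)
        out.append(sum(buf) / len(buf))
    return out

meta, eps = read_monitor(SAMPLE_MONITOR)
total_minutes = eps[-1]["t"] / 60  # 't' is cumulative, so take the last value
print(f"{len(eps)} episodes, {total_minutes:.2f} min elapsed")
print(moving_average([e["r"] for e in eps], window=2))  # [-3.5, -1.25, 2.625]
```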

---
## 💾 Save Model

Your model was automatically saved during training, but you can manually save here too.


In [None]:
# ============================================================================
# CELL 10: Manual Save (Optional)
# ============================================================================

# Training auto-saves, but you can manually save here
final_save_path = f"checkpoints/{CONFIG['run_name']}/final_model.zip"

learning_agent.save(final_save_path)

print("=" * 70)
print("✓ Model saved successfully!")
print("=" * 70)
print(f"  Path: {final_save_path}")
print(f"  Includes: RecurrentPPO policy + Transformer encoder")
print("=" * 70)

# Show all saved checkpoints
import os
checkpoint_dir = f"checkpoints/{CONFIG['run_name']}/"
if os.path.exists(checkpoint_dir):
    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith('.zip')]
    print(f"\n📁 All checkpoints in {checkpoint_dir}:")
    for cp in sorted(checkpoints):
        print(f"  - {cp}")

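`sorted()` compares filenames as strings, so `rl_model_10000_steps.zip` lands before `rl_model_5000_steps.zip`. If the save handler names checkpoints `rl_model_<N>_steps.zip` (the pattern used in the Next Steps example; an assumption about the save handler), ordering by the parsed step count is easier to read. `sort_checkpoints_by_step` is a hypothetical helper.

```python
import re
from typing import List

STEP_PATTERN = re.compile(r"_(\d+)_steps\.zip$")

def sort_checkpoints_by_step(filenames: List[str]) -> List[str]:
    """Order checkpoints by training step; files without a step count go last."""
    def key(name: str):
        match = STEP_PATTERN.search(name)
        return (0, int(match.group(1)), name) if match else (1, 0, name)
    return sorted(filenames, key=key)

files = ["rl_model_5000_steps.zip", "final_model.zip",
         "rl_model_20000_steps.zip", "rl_model_10000_steps.zip"]
print(sorted(files))                    # string order: 10000 comes before 5000
print(sort_checkpoints_by_step(files))  # 5000, 10000, 20000, then final_model.zip
```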

---
## 🎓 Next Steps

Congratulations! You've successfully trained an AI agent in a notebook! 🎉

### What to try next:

1. **Experiment with hyperparameters**:
   - Go back to Cell 3 (CONFIG)
   - Modify values (e.g., `learning_rate`, `ent_coef`, `batch_size`)
   - Re-run training and compare results

2. **Longer training**:
   - Increase `total_timesteps` to 50,000 or 100,000
   - Performance generally improves with longer runs

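3. **Load and continue training**: run the lines below after Cell 3 (or edit Cell 3 directly), then re-run Cells 4-7 so the agent is rebuilt from the checkpoint. Check the checkpoint directory for the exact filename:
   ```python
   CONFIG["load_checkpoint"] = "checkpoints/notebook_quick_test/rl_model_20000_steps.zip"
   CONFIG["total_timesteps"] = 50_000  # 30k more steps if run_training_loop counts from the checkpoint
   ```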

4. **Visualize attention patterns**:
   ```python
   # See which frames the transformer focuses on.
   # `obs` must be an observation from the environment; it is not defined in this notebook.
   attention_info = learning_agent.visualize_attention(obs)
   ```

5. **Compare experiments**:
   - Run multiple training sessions with different configs
   - Plot them together to compare

6. **Read the full guide**:
   - Check out `NOTEBOOK_TRAINING_GUIDE.md` for advanced tips

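For comparing experiments, one simple summary is each run's mean reward over its last N episodes, for example taken from the `r` column of each run's `monitor.csv`. A sketch with made-up reward lists; `rank_runs` and the `high_entropy` run name are hypothetical.

```python
from statistics import mean
from typing import Dict, List, Tuple

def rank_runs(run_rewards: Dict[str, List[float]], last_n: int = 20) -> List[Tuple[str, float]]:
    """Rank runs by mean episode reward over their last `last_n` episodes."""
    scores = []
    for run_name, rewards in run_rewards.items():
        if not rewards:
            continue  # skip runs that never finished an episode
        scores.append((run_name, mean(rewards[-last_n:])))
    return sorted(scores, key=lambda item: item[1], reverse=True)

runs = {
    "notebook_quick_test": [-2.0, -1.0, 0.5, 1.5],
    "high_entropy": [-3.0, -2.5, -1.0, 0.0],
}
for name, score in rank_runs(runs, last_n=2):
    print(f"{name:25s} {score:+.2f}")
```

Averaging a recent window rather than taking the best single episode keeps one lucky match from dominating the comparison.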
---

### Key Benefits You Just Experienced:

✅ **Easy Configuration**: Changed settings in one cell  
✅ **Interactive**: Ran evaluation and plotting without restarting  
✅ **Fast Iteration**: Quick feedback loop for experimentation  
✅ **Visual Feedback**: Saw training curves immediately  

### When to use Scripts vs Notebooks:

- **Notebooks**: Experimentation, debugging, short runs, learning
- **Scripts** (`train_agent.py`): Long overnight runs, production, automation

Happy training! 🚀
