# Parallel Multi-Map Training on Google Colab (L4 GPU Optimized)

This notebook trains a **generalized policy** that works across **all maps** using parallel environment execution.

## Key Features:
- **Parallel Training**: Multiple environments run simultaneously
- **Multi-Map Learning**: Trains on all maps at once for generalization
- **Curriculum Learning**: Starts easy, gradually increases difficulty
- **L4 GPU Optimized**: Batched operations for maximum GPU utilization
- **Easy to Extend**: Simple to add new maps

## Setup Instructions:
1. **Enable L4 GPU**: Runtime → Change runtime type → Hardware accelerator → GPU → L4
2. **Run the cells in order** (sections 5 and 6 are optional; skip them when starting fresh without W&B)
3. **Training saves checkpoints** automatically
4. **Download checkpoints** when complete

---

## 1. Check GPU and Clone Repository

In [None]:
import torch

# Check GPU
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

if torch.cuda.is_available():
    print("\n✅ GPU DETECTED")
    print("GPU Device:", torch.cuda.get_device_name(0))
    total_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU Memory: {total_mem:.1f} GB")
    
    # Check if L4
    gpu_name = torch.cuda.get_device_name(0)
    if 'L4' in gpu_name:
        print("\n🚀 L4 GPU DETECTED - Optimal configuration!")
    else:
        print(f"\n⚠️  GPU is {gpu_name}, not L4")
        print("For best performance, select L4 GPU in Runtime settings")
else:
    print("\n❌ NO GPU DETECTED")
    print("Enable GPU: Runtime → Change runtime type → Hardware accelerator → GPU")

In [None]:
# Clone repository
!git clone https://github.com/Ben-jpg-del/CalHacks.git
%cd CalHacks

print("\n✅ Repository cloned!")

## 2. Install Dependencies

In [None]:
# Install required packages
!pip install torch numpy pygame wandb -q

print("✅ Dependencies installed!")

## 3. Verify Environment and Maps

In [None]:
from parallel_multi_map_env import MapRegistry, ParallelMultiMapEnv
import torch

# Check available maps
print("Available Maps:")
print("=" * 60)
for map_name in MapRegistry.get_map_names():
    level = MapRegistry.get_map(map_name)
    print(f"  {map_name.upper()}")
    print(f"    - Dimensions: {level.width}x{level.height}")
    print(f"    - Platforms: {len(level.base_solids)}")
    print(f"    - Hazards: {list(level.get_hazards().keys())}")
    print()

# Test parallel environment
print("Testing Parallel Environment:")
print("=" * 60)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
env = ParallelMultiMapEnv(num_envs_per_map=4, device=device)
print(env.get_statistics())

# Test reset and step
fire_obs, water_obs = env.reset()
print(f"\nObservation shape: {fire_obs.shape}")
print(f"Observation device: {fire_obs.device}")
print("\n✅ Environment test passed!")

env.close()

## 4. Configure Training

**Adjust these parameters as needed:**

In [None]:
# ============================================
# TRAINING CONFIGURATION
# ============================================

# Training settings
NUM_EPISODES = 5000          # Total episodes to train
NUM_ENVS_PER_MAP = 8         # Parallel environments per map (L4 can handle 8-16)
USE_CURRICULUM = True        # Start easy, gradually add harder maps
REWARD_TYPE = 'dense'        # Options: 'sparse', 'dense', 'cooperation', 'safety'

# Curriculum schedule (episode -> map distribution)
# Customize this to control learning difficulty progression
CURRICULUM_SCHEDULE = {
    0: {'tutorial': 1.0},                          # Start: 100% tutorial
    500: {'tutorial': 0.7, 'tower': 0.3},         # Episode 500: Add 30% tower
    1000: {'tutorial': 0.5, 'tower': 0.5},        # Episode 1000: Equal split
    2000: {'tutorial': 0.3, 'tower': 0.7},        # Episode 2000: Focus on tower
    # Add new maps here:
    # 3000: {'tutorial': 0.2, 'tower': 0.3, 'new_map': 0.5},
}

# Alternative: Fixed distribution (if USE_CURRICULUM = False)
MAP_DISTRIBUTION = {
    'tutorial': 0.5,
    'tower': 0.5,
}

# Checkpoint settings
SAVE_DIR = 'checkpoints_multimap'
SAVE_FREQ = 100              # Save every N episodes
LOG_FREQ = 10                # Log every N episodes

# Resume training (optional)
RESUME_FROM = None           # e.g., 'checkpoints_multimap/checkpoint_ep1000'

# Weights & Biases (optional)
USE_WANDB = False
WANDB_PROJECT = 'firewater-multimap'

# ============================================

print("Training Configuration:")
print("=" * 60)
print(f"Total Episodes: {NUM_EPISODES}")
print(f"Environments per Map: {NUM_ENVS_PER_MAP}")
print(f"Curriculum Learning: {USE_CURRICULUM}")
print(f"Reward Type: {REWARD_TYPE}")
print(f"Save Directory: {SAVE_DIR}")
if RESUME_FROM:
    print(f"\n⚠️  RESUMING from {RESUME_FROM}")
else:
    print(f"\nStarting fresh training")
print("=" * 60)

# Estimate parallel environments
num_maps = len(MapRegistry.get_map_names())
total_envs = NUM_ENVS_PER_MAP * num_maps
print(f"\nEstimated parallel environments: ~{total_envs}")
print(f"GPU utilization: {'High ✅' if total_envs >= 8 else 'Medium ⚠️'}")
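
To make the schedule semantics concrete, here is a minimal sketch of how an episode number could be resolved to a map distribution: pick the entry with the largest start episode that does not exceed the current episode. The helper name `distribution_for_episode` is hypothetical, and the step-wise lookup is an assumption; `train_parallel_multimap` may resolve the schedule differently.

```python
def distribution_for_episode(schedule, episode):
    """Return the map distribution active at `episode`.

    Hypothetical helper: assumes each key in `schedule` is the episode at
    which that distribution starts to apply.
    """
    # Stage starts that have already been reached, e.g. [0, 500] at episode 750
    reached = [start for start in schedule if start <= episode]
    if not reached:
        raise ValueError(f"No curriculum stage starts at or before episode {episode}")
    return schedule[max(reached)]


schedule = {
    0: {'tutorial': 1.0},
    500: {'tutorial': 0.7, 'tower': 0.3},
    1000: {'tutorial': 0.5, 'tower': 0.5},
    2000: {'tutorial': 0.3, 'tower': 0.7},
}

for episode in (0, 499, 500, 1500, 4999):
    print(episode, distribution_for_episode(schedule, episode))
```

Under this reading, the last stage stays active until training ends, so episodes 2000 through `NUM_EPISODES` all use the final distribution.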

## 5. Upload Existing Checkpoints (OPTIONAL - For Resuming)

In [None]:
from google.colab import files
import os
import zipfile

print("📤 Upload checkpoints_multimap.zip (or checkpoints.zip) to resume training")
print("(Skip this cell if starting fresh)\n")

uploaded = files.upload()

if 'checkpoints.zip' in uploaded or 'checkpoints_multimap.zip' in uploaded:
    zip_name = 'checkpoints.zip' if 'checkpoints.zip' in uploaded else 'checkpoints_multimap.zip'
    
    # Extract
    with zipfile.ZipFile(zip_name, 'r') as zip_ref:
        zip_ref.extractall('.')
    
    print(f"\n✅ Extracted {zip_name}")
    
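    # List checkpoints, sorted by episode number (a plain string sort would
    # place 'checkpoint_ep1000' before 'checkpoint_ep200')
    if os.path.exists(SAVE_DIR):
        def episode_number(name):
            suffix = name[len('checkpoint_ep'):]
            return int(suffix) if suffix.isdigit() else -1

        checkpoints = sorted(
            (d for d in os.listdir(SAVE_DIR) if d.startswith('checkpoint_ep')),
            key=episode_number,
        )
        if checkpoints:
            print(f"\nFound {len(checkpoints)} checkpoint(s)")
            for cp in checkpoints[-5:]:
                print(f"  - {cp}")
            
            # Suggest latest
            latest = checkpoints[-1]
            print(f"\n💡 Latest checkpoint: {latest}")
            print(f"💡 Set RESUME_FROM = '{SAVE_DIR}/{latest}' in config cell")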
else:
    print("No checkpoints uploaded. Starting fresh.")

## 6. Setup Weights & Biases (OPTIONAL)

In [None]:
if USE_WANDB:
    import wandb
    wandb.login()
    print("✅ Logged in to W&B")
else:
    print("W&B disabled. Set USE_WANDB=True to enable.")

## 7. Train Agents

This will train a **generalized policy** across all maps using parallel environments.

**Training Features:**
- Parallel environment execution
- Multi-map simultaneous training
- Curriculum learning (optional)
- Automatic checkpointing
- Per-map success tracking

**Note**: Training may take several hours depending on NUM_EPISODES and GPU.
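
As an illustration of multi-map parallel training, the sketch below splits a fixed pool of environment slots across maps in proportion to a distribution such as `MAP_DISTRIBUTION`. The function `allocate_envs` and its largest-remainder rounding are assumptions made for this example, not necessarily how `ParallelMultiMapEnv` assigns environments.

```python
import math

def allocate_envs(distribution, total_envs):
    """Split `total_envs` slots across maps proportionally to their weights.

    Hypothetical helper using largest-remainder rounding so the counts
    always add up to `total_envs`.
    """
    total_weight = sum(distribution.values())
    if total_weight <= 0:
        raise ValueError("Distribution weights must sum to a positive number")

    # Ideal (fractional) share of slots for each map
    shares = {name: total_envs * w / total_weight for name, w in distribution.items()}
    counts = {name: math.floor(share) for name, share in shares.items()}

    # Hand the leftover slots to the maps with the largest fractional parts
    leftover = total_envs - sum(counts.values())
    by_remainder = sorted(shares, key=lambda name: shares[name] - counts[name], reverse=True)
    for name in by_remainder[:leftover]:
        counts[name] += 1
    return counts


print(allocate_envs({'tutorial': 0.7, 'tower': 0.3}, 16))
print(allocate_envs({'tutorial': 0.5, 'tower': 0.5}, 16))
```

Rounding each share independently could leave slots unused or oversubscribed; the largest-remainder step is one simple way to keep the total fixed.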

In [None]:
# Reload module to get latest changes (in case it was imported before)
import importlib
import sys
if 'train_parallel_multimap' in sys.modules:
    import train_parallel_multimap
    importlib.reload(train_parallel_multimap)
    
from train_parallel_multimap import train_parallel_multimap

print("🚀 Starting Parallel Multi-Map Training...")
print("=" * 80)

train_parallel_multimap(
    num_episodes=NUM_EPISODES,
    num_envs_per_map=NUM_ENVS_PER_MAP,
    map_distribution=MAP_DISTRIBUTION if not USE_CURRICULUM else None,
    use_curriculum=USE_CURRICULUM,
    curriculum_schedule=CURRICULUM_SCHEDULE if USE_CURRICULUM else None,
    reward_type=REWARD_TYPE,
    save_dir=SAVE_DIR,
    save_freq=SAVE_FREQ,
    log_freq=LOG_FREQ,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    resume_from=RESUME_FROM,
    use_wandb=USE_WANDB,
    wandb_project=WANDB_PROJECT
)

print("\n✅ Training complete!")

## 8. List Saved Checkpoints

In [None]:
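import os
import re

def checkpoint_sort_key(name):
    # Sort by trailing episode number so 'checkpoint_ep1000' follows 'checkpoint_ep200'
    match = re.search(r'(\d+)$', name)
    return int(match.group(1)) if match else -1

if os.path.exists(SAVE_DIR):
    checkpoints = sorted(
        (d for d in os.listdir(SAVE_DIR) if d.startswith('checkpoint_')),
        key=checkpoint_sort_key,
    )
    
    print(f"Checkpoints in {SAVE_DIR}:")
    print("=" * 60)
    
    for cp in checkpoints:
        cp_dir = os.path.join(SAVE_DIR, cp)
        if os.path.isdir(cp_dir):
            # Use cp_files, not files, to avoid shadowing google.colab's files module
            cp_files = os.listdir(cp_dir)
            fire_exists = 'fire_agent.pth' in cp_files
            water_exists = 'water_agent.pth' in cp_files
            status = "✅" if (fire_exists and water_exists) else "⚠️"
            print(f"{status} {cp}")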
    
    # Check final
    final_dir = os.path.join(SAVE_DIR, 'final')
    if os.path.exists(final_dir):
        print(f"\n✅ Final checkpoint saved")
    
    print("\n" + "=" * 60)
    print(f"Total checkpoints: {len(checkpoints)}")
else:
    print(f"No checkpoints found in {SAVE_DIR}")

## 9. Evaluate Trained Agent on All Maps

Test how well the generalized policy works on each map individually:

In [None]:
import os  # needed for os.path.join below if earlier cells were skipped

from train_parallel_multimap import ParallelDQNAgent
from parallel_multi_map_env import MapRegistry
from game_environment import FireWaterEnv
import numpy as np
import torch

# Select checkpoint to evaluate
EVAL_CHECKPOINT = 'final'  # or 'checkpoint_ep1000', etc.
NUM_EVAL_EPISODES = 100    # Episodes per map

print(f"Evaluating checkpoint: {EVAL_CHECKPOINT}")
print("=" * 80)

# Load agents
device = 'cuda' if torch.cuda.is_available() else 'cpu'
fire_agent = ParallelDQNAgent(device=device)
water_agent = ParallelDQNAgent(device=device)

checkpoint_dir = os.path.join(SAVE_DIR, EVAL_CHECKPOINT)
fire_agent.load(os.path.join(checkpoint_dir, 'fire_agent.pth'))
water_agent.load(os.path.join(checkpoint_dir, 'water_agent.pth'))

# Set to evaluation mode
fire_agent.epsilon = 0.0
water_agent.epsilon = 0.0

print(f"✅ Agents loaded from {checkpoint_dir}\n")

# Evaluate on each map
all_results = {}

for map_name in MapRegistry.get_map_names():
    print(f"Testing on {map_name.upper()} map...")
    
    # Create environment for this map
    level = MapRegistry.get_map(map_name)
    env = FireWaterEnv(level=level)
    
    successes = 0
    total_rewards = []
    episode_lengths = []
    
    for ep in range(NUM_EVAL_EPISODES):
        fire_obs, water_obs = env.reset()
        fire_obs = torch.FloatTensor(fire_obs).unsqueeze(0).to(device)
        water_obs = torch.FloatTensor(water_obs).unsqueeze(0).to(device)
        
        done = False
        episode_reward = 0
        steps = 0
        
        while not done and steps < 3000:
            # Get actions
            fire_action = fire_agent.select_actions(fire_obs, training=False).item()
            water_action = water_agent.select_actions(water_obs, training=False).item()
            
            # Step
            (fire_obs_np, water_obs_np), (fire_reward, water_reward), \
            (fire_done, water_done), info = env.step(fire_action, water_action)
            
            fire_obs = torch.FloatTensor(fire_obs_np).unsqueeze(0).to(device)
            water_obs = torch.FloatTensor(water_obs_np).unsqueeze(0).to(device)
            
            episode_reward += fire_reward + water_reward
            steps += 1
            done = fire_done or water_done
        
        if info.get('both_won', False):
            successes += 1
        
        total_rewards.append(episode_reward)
        episode_lengths.append(steps)
    
    # Store results
    all_results[map_name] = {
        'success_rate': successes / NUM_EVAL_EPISODES * 100,
        'avg_reward': np.mean(total_rewards),
        'avg_length': np.mean(episode_lengths),
        'successes': successes
    }
    
    print(f"  Success Rate: {all_results[map_name]['success_rate']:.1f}%")
    print(f"  Avg Reward: {all_results[map_name]['avg_reward']:.2f}")
    print(f"  Avg Length: {all_results[map_name]['avg_length']:.1f} steps")
    print()
    
    env.close()

# Summary
print("=" * 80)
print("EVALUATION SUMMARY")
print("=" * 80)
overall_success = np.mean([r['success_rate'] for r in all_results.values()])
print(f"Overall Success Rate: {overall_success:.1f}%")
print(f"\nPer-Map Results:")
for map_name, results in all_results.items():
    print(f"  {map_name.upper()}: {results['success_rate']:.1f}% ({results['successes']}/{NUM_EVAL_EPISODES})")
print("=" * 80)
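
A success rate measured over `NUM_EVAL_EPISODES` episodes is only an estimate, so it can help to report an interval alongside it. The sketch below computes a 95% Wilson score interval from a success count; the helper `wilson_interval` is an addition for illustration and is not part of the repository.

```python
import math

def wilson_interval(successes, trials, z=1.96):
    """Wilson score interval for a binomial proportion (z=1.96 is about 95%)."""
    if trials <= 0:
        raise ValueError("trials must be positive")
    p = successes / trials
    denom = 1 + z**2 / trials
    center = (p + z**2 / (2 * trials)) / denom
    half_width = z * math.sqrt(p * (1 - p) / trials + z**2 / (4 * trials**2)) / denom
    # Clamp to [0, 1] to absorb floating-point rounding at the extremes
    return max(0.0, center - half_width), min(1.0, center + half_width)


low, high = wilson_interval(successes=50, trials=100)
print(f"Success rate: 50.0% (95% interval: {low * 100:.1f}% to {high * 100:.1f}%)")
```

With 100 episodes per map and a rate near 50%, the interval spans close to ten percentage points on either side, so small differences between checkpoints may not be meaningful.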

## 10. Download Checkpoints

In [None]:
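import os
import shutil

# Zip the checkpoint folder itself. base_dir keeps the 'checkpoints_multimap/'
# prefix inside the archive, which the upload cell (section 5) and the local
# visualization commands (section 11) expect after extraction.
shutil.make_archive('checkpoints_multimap', 'zip', root_dir='.', base_dir=SAVE_DIR)

print("✅ Checkpoints zipped!")
print(f"\nFile: checkpoints_multimap.zip")
print(f"Size: {os.path.getsize('checkpoints_multimap.zip') / 1e6:.1f} MB")
print("\nDownload from the Files panel (left sidebar)")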

# Alternative: Auto-download
from google.colab import files
# files.download('checkpoints_multimap.zip')  # Uncomment to auto-download
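
Before downloading, it can be worth confirming that the archive contains both agent files for the checkpoint you plan to use. The sketch below relies only on the standard `zipfile` module; the helper name `missing_agent_files` is hypothetical, while `fire_agent.pth` and `water_agent.pth` are the file names checked in section 8.

```python
import os
import tempfile
import zipfile

def missing_agent_files(zip_path, checkpoint='final'):
    """Return the agent files for `checkpoint` that are absent from the archive."""
    expected = ['fire_agent.pth', 'water_agent.pth']
    with zipfile.ZipFile(zip_path) as archive:
        names = archive.namelist()
    missing = []
    for filename in expected:
        suffix = f"{checkpoint}/{filename}"
        # Match with or without a leading folder such as 'checkpoints_multimap/'
        if not any(n == suffix or n.endswith('/' + suffix) for n in names):
            missing.append(filename)
    return missing


# Demo on a throwaway archive that contains only the fire agent
with tempfile.TemporaryDirectory() as tmp:
    demo_zip = os.path.join(tmp, 'demo.zip')
    with zipfile.ZipFile(demo_zip, 'w') as archive:
        archive.writestr('checkpoints_multimap/final/fire_agent.pth', b'')
    print(missing_agent_files(demo_zip))  # → ['water_agent.pth']
```

In the notebook, calling `missing_agent_files('checkpoints_multimap.zip')` after zipping should return an empty list when both files are present.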

## 11. Visualize Locally (Instructions)

To visualize your trained generalized policy:

1. **Download** `checkpoints_multimap.zip` from Colab
2. **Extract** to your local repository
3. **Run visualization** on any map:

```bash
# On your local machine:

# Tutorial map
python visualize.py trained checkpoints_multimap/final/fire_agent.pth checkpoints_multimap/final/water_agent.pth --map tutorial

# Tower map
python visualize.py trained checkpoints_multimap/final/fire_agent.pth checkpoints_multimap/final/water_agent.pth --map tower

# Any new map you add (register it and retrain first, see section 12)
python visualize.py trained checkpoints_multimap/final/fire_agent.pth checkpoints_multimap/final/water_agent.pth --map your_new_map
```

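**Note**: The policy is trained on every registered map, so it is expected to handle each of them. Check the per-map success rates from section 9 to confirm how well it generalizes.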

## 12. Adding New Maps (Instructions)

To add a new map to the training:

### Step 1: Create Map File
Create `map_3.py` (or any name) with your level definition:

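```python
# map_3.py
from physics_engine import Rect  # used when defining platforms and hazards

class LevelConfig:
    def __init__(self, level_name="My New Map"):
        self.name = level_name
        self.width = 960
        self.height = 540
        # Section 3 reads `base_solids` and `get_hazards()`, so define both.
        self.base_solids = []  # fill with your platform Rects
        # ... define platforms, hazards, etc.

    def get_hazards(self):
        # Return a dict keyed by hazard name (section 3 calls .keys() on it)
        return {}

class LevelLibrary:
    @staticmethod
    def get_my_map():
        return LevelConfig("My New Map")
```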

### Step 2: Register Map
Edit `parallel_multi_map_env.py`, add to `MapRegistry.get_all_maps()`:

```python
# parallel_multi_map_env.py
from map_3 import LevelLibrary as Map3Library  # add next to the existing map imports

class MapRegistry:
    # ... existing code ...

    @staticmethod
    def get_all_maps():
        return {
            'tutorial': LevelLibrary.get_tutorial_level(),
            'tower': Map1Library.get_tower_level(),
            'my_new_map': Map3Library.get_my_map(),  # Add this line
        }
```

### Step 3: Update Curriculum (Optional)
In the config cell (section 4), add the new map to `CURRICULUM_SCHEDULE`:

```python
CURRICULUM_SCHEDULE = {
    0: {'tutorial': 1.0},
    500: {'tutorial': 0.7, 'tower': 0.3},
    1000: {'tutorial': 0.5, 'tower': 0.5},
    2000: {'tutorial': 0.3, 'tower': 0.4, 'my_new_map': 0.3},  # Add new map
}
```
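
A typo in a map name, or weights that do not add up, is easy to miss when editing the schedule by hand. The sketch below checks a schedule against a list of known map names. `validate_schedule` is a hypothetical helper, and both rules it enforces (a stage starting at episode 0, weights summing to 1.0) are assumptions based on the example entries; the trainer may be more lenient.

```python
import math

def validate_schedule(schedule, known_maps):
    """Return a list of human-readable problems found in `schedule`."""
    problems = []
    if 0 not in schedule:
        problems.append("no stage starts at episode 0")
    for start, distribution in sorted(schedule.items()):
        # Map names that are not registered
        unknown = sorted(set(distribution) - set(known_maps))
        if unknown:
            problems.append(f"episode {start}: unknown map(s) {unknown}")
        # Weights are assumed to form a probability distribution
        total = sum(distribution.values())
        if not math.isclose(total, 1.0, abs_tol=1e-6):
            problems.append(f"episode {start}: weights sum to {total:.3f}, expected 1.0")
    return problems


schedule = {
    0: {'tutorial': 1.0},
    2000: {'tutorial': 0.3, 'tower': 0.4, 'my_new_map': 0.3},
}
print(validate_schedule(schedule, known_maps=['tutorial', 'tower', 'my_new_map']))
print(validate_schedule(schedule, known_maps=['tutorial', 'tower']))
```

In the notebook, `known_maps` could be `MapRegistry.get_map_names()`, so a map that was added to the schedule but not registered in Step 2 is reported before training starts.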

### Step 4: Re-run Training
Re-run the training cell (section 7). The system will automatically include your new map; the parallel training system handles everything else.