# LOWO (Leave-One-World-Out) Training - Google Colab

This notebook trains all 6 LOWO models for Sprint 5 generalization experiments.

**Requirements:**
- Enable GPU: Runtime > Change runtime type > T4 GPU (or A100 for faster training)
- Mount Google Drive for checkpoint persistence

**Estimated Time:**
- T4 GPU: ~45-60 min per model (~5-6 hours total)
- A100 GPU: ~15-20 min per model (~2 hours total)

## 1. Setup Environment

In [None]:
# Check GPU availability: nvidia-smi reports the GPU attached to this runtime.
# If the command fails, no GPU runtime is attached (Runtime > Change runtime type).
!nvidia-smi

In [None]:
# Mount Google Drive for persistent storage
from google.colab import drive
drive.mount('/content/drive')

# Create project directory in Drive
!mkdir -p /content/drive/MyDrive/macro_emulator/runs
!mkdir -p /content/drive/MyDrive/macro_emulator/datasets

In [None]:
# Clone the repository (UPDATE THIS URL to your repo).
# Uncomment exactly ONE of the options below. Whichever you choose, the repo must
# end up at /content/macro_simulator, because the install cell changes into that directory.
%cd /content

# Option 1: Clone from GitHub (if public or with token)
# !git clone https://github.com/YOUR_USERNAME/macro_simulator.git

# Option 2: Upload from local machine
# Use the file browser on the left to upload a zip of your repo
# Then uncomment:
# !unzip macro_simulator.zip

# Option 3: Copy from Drive (if you've synced your repo there)
# !cp -r /content/drive/MyDrive/macro_simulator /content/macro_simulator

# Reminder only: this cell does nothing else until one option above is uncommented
print("Choose one of the options above and uncomment it")

In [None]:
# Install dependencies
%cd /content/macro_simulator
!pip install -e ".[dev]" -q

# Verify installation
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Quick sanity check - run only the tests selected by the "fast" pytest marker.
# stderr is merged into stdout and only the last 20 lines of output are shown.
!python -m pytest -m "fast" -v --tb=line -q 2>&1 | tail -20

## 2. Generate or Load Dataset

The cells below choose between two options automatically:
1. Generate a fresh dataset (takes ~30-60 min) if none is found in Google Drive
2. Load a pre-generated dataset from Google Drive if one exists there

In [None]:
import os

DATASET_PATH = "datasets/v1.0"
DRIVE_DATASET_PATH = "/content/drive/MyDrive/macro_emulator/datasets/v1.0"

# Check if dataset exists in Drive
if os.path.exists(DRIVE_DATASET_PATH) and os.path.exists(os.path.join(DRIVE_DATASET_PATH, "manifest.json")):
    print("Found existing dataset in Google Drive!")
    print("Symlinking to local path...")
    !ln -sf {DRIVE_DATASET_PATH} {DATASET_PATH}
    print("Dataset ready.")
else:
    print("No dataset found in Drive. Will generate new dataset.")
    print("This will take ~30-60 minutes...")

In [None]:
# Generate dataset if not found (skip if already exists)
import os

if not os.path.exists(os.path.join(DATASET_PATH, "manifest.json")):
    print("Generating dataset with 10k samples per world...")
    print("This will take ~30-60 minutes. Go grab a coffee!")
    
    !python -m data.scripts.generate_dataset \
        --world all \
        --n_samples 10000 \
        --seed 42 \
        --output {DATASET_PATH}
    
    # Copy to Drive for persistence
    print("\nCopying dataset to Google Drive for future use...")
    !cp -r {DATASET_PATH} {DRIVE_DATASET_PATH}
    print("Dataset saved to Drive!")
else:
    print("Dataset already exists. Skipping generation.")

In [None]:
# Validate dataset: run the project's validation script against the local dataset path
# (DATASET_PATH is defined in the dataset-detection cell above)
!python -m data.scripts.validate_dataset --path {DATASET_PATH}

## 3. LOWO Training

Train 6 models, each with one simulator family held out.

In [None]:
# Define LOWO experiments
LOWO_WORLDS = ['lss', 'var', 'nk', 'rbc', 'switching', 'zlb']

# Symlink runs directory to Drive for persistence
DRIVE_RUNS_PATH = "/content/drive/MyDrive/macro_emulator/runs"
!rm -rf runs  # Remove local runs if exists
!ln -sf {DRIVE_RUNS_PATH} runs

print(f"Runs will be saved to: {DRIVE_RUNS_PATH}")
print(f"\nWill train {len(LOWO_WORLDS)} LOWO models:")
for world in LOWO_WORLDS:
    print(f"  - lowo_exclude_{world}")

In [None]:
# Work out which LOWO models still need training.
# A model counts as trained once its best_checkpoint.pt exists under runs/.
import os

remaining_worlds = [
    world for world in LOWO_WORLDS
    if not os.path.exists(f"runs/lowo_exclude_{world}/best_checkpoint.pt")
]

# Report status in the original LOWO_WORLDS order
for world in LOWO_WORLDS:
    if world in remaining_worlds:
        print(f"[TODO] lowo_exclude_{world}")
    else:
        print(f"[DONE] lowo_exclude_{world}")

print(f"\n{len(remaining_worlds)} models remaining to train.")

In [None]:
# Train all remaining LOWO models
import subprocess
import time

for i, world in enumerate(remaining_worlds):
    config_path = f"configs/lowo_exclude_{world}.yaml"
    
    print(f"\n{'='*60}")
    print(f"Training LOWO model {i+1}/{len(remaining_worlds)}: exclude_{world}")
    print(f"Config: {config_path}")
    print(f"{'='*60}\n")
    
    start_time = time.time()
    
    # Run training
    !python -m emulator.training.trainer --config {config_path}
    
    elapsed = time.time() - start_time
    print(f"\nCompleted exclude_{world} in {elapsed/60:.1f} minutes")
    print(f"Checkpoint saved to: runs/lowo_exclude_{world}/")

print("\n" + "="*60)
print("ALL LOWO MODELS TRAINED!")
print("="*60)

## 4. Evaluate LOWO Transfer

Evaluate each model on its held-out world to measure generalization, and on its in-domain (training) worlds as the baseline for comparison.

In [None]:
# Evaluate all LOWO models
# For each model, evaluate on BOTH the held-out world (transfer) and in-domain worlds
import json
import os

results = {}

for world in LOWO_WORLDS:
    checkpoint_path = f"runs/lowo_exclude_{world}/best_checkpoint.pt"
    
    if not os.path.exists(checkpoint_path):
        print(f"[SKIP] No checkpoint for lowo_exclude_{world}")
        continue
    
    print(f"\n{'='*50}")
    print(f"Evaluating lowo_exclude_{world}")
    print(f"{'='*50}")
    
    # Evaluate on held-out world (transfer performance)
    held_out_output = f"runs/lowo_exclude_{world}/eval_held_out.json"
    print(f"\n1. Evaluating on HELD-OUT world ({world})...")
    !python -m emulator.eval.evaluate \
        --checkpoint {checkpoint_path} \
        --dataset {DATASET_PATH} \
        --worlds {world} \
        --output {held_out_output}
    
    # Evaluate on in-domain worlds (trained worlds)
    in_domain_worlds = [w for w in LOWO_WORLDS if w != world]
    in_domain_str = ",".join(in_domain_worlds)
    in_domain_output = f"runs/lowo_exclude_{world}/eval_in_domain.json"
    print(f"\n2. Evaluating on IN-DOMAIN worlds ({in_domain_str})...")
    !python -m emulator.eval.evaluate \
        --checkpoint {checkpoint_path} \
        --dataset {DATASET_PATH} \
        --worlds {in_domain_str} \
        --output {in_domain_output}
    
    # Load and combine results
    results[world] = {}
    if os.path.exists(held_out_output):
        with open(held_out_output) as f:
            results[world]['held_out'] = json.load(f)
    if os.path.exists(in_domain_output):
        with open(in_domain_output) as f:
            results[world]['in_domain'] = json.load(f)
    
    print(f"Results saved for lowo_exclude_{world}")

In [None]:
# Generate LOWO comparison table
import json
import math

import pandas as pd


def _overall_nrmse(world_results, split):
    """Return `world_results[split]['overall']['nrmse']`, or 'N/A' when any level is missing."""
    split_results = world_results.get(split)
    if not isinstance(split_results, dict):
        return "N/A"
    overall = split_results.get("overall")
    if not isinstance(overall, dict):
        return "N/A"
    return overall.get("nrmse", "N/A")


def _is_number(value):
    """True for int/float values; bool is excluded even though it subclasses int."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def _format_nrmse(value):
    """Format numeric NRMSE values with 4 decimals; pass anything else through unchanged."""
    return f"{value:.4f}" if _is_number(value) else value


def summarize_lowo_results(results):
    """Build one summary row per held-out world.

    Args:
        results: dict mapping world name -> {'held_out': report, 'in_domain': report},
            where each report is the parsed evaluation JSON and either key may be absent.

    Returns:
        List of dicts with the keys 'Held-Out World', 'NRMSE (held-out)',
        'NRMSE (in-domain)' and 'Transfer Gap'. Missing values are reported as 'N/A'.
    """
    rows = []
    for world, res in results.items():
        held_out_nrmse = _overall_nrmse(res, 'held_out')
        in_domain_nrmse = _overall_nrmse(res, 'in_domain')

        # Transfer gap = held-out error relative to in-domain error. It is only
        # defined for finite values and a strictly positive denominator.
        transfer_gap = "N/A"
        if _is_number(held_out_nrmse) and _is_number(in_domain_nrmse):
            if math.isfinite(held_out_nrmse) and math.isfinite(in_domain_nrmse) and in_domain_nrmse > 0:
                transfer_gap = f"{held_out_nrmse / in_domain_nrmse:.2f}x"

        rows.append({
            'Held-Out World': world.upper(),
            'NRMSE (held-out)': _format_nrmse(held_out_nrmse),
            'NRMSE (in-domain)': _format_nrmse(in_domain_nrmse),
            'Transfer Gap': transfer_gap,
        })
    return rows


# On a fresh kernel `results` does not exist yet; treat that like "no results"
# so the hint below is shown instead of a NameError.
try:
    lowo_results = results
except NameError:
    lowo_results = {}

if lowo_results:
    rows = summarize_lowo_results(lowo_results)

    df = pd.DataFrame(rows)
    print("\n" + "="*60)
    print("LOWO Transfer Results Summary")
    print("="*60)
    print("\nTransfer Gap = (held-out NRMSE) / (in-domain NRMSE)")
    print("  - Gap ~ 1.0: Good generalization")
    print("  - Gap > 1.5: Poor generalization (model overfits to training worlds)")
    print()
    print(df.to_string(index=False))

    # Save to CSV
    csv_path = 'runs/lowo_transfer_results.csv'
    df.to_csv(csv_path, index=False)
    print(f"\nResults saved to {csv_path}")

    # Also save full results as JSON (default=str keeps non-serializable values from aborting the dump)
    json_path = 'runs/lowo_all_results.json'
    with open(json_path, 'w') as f:
        json.dump(lowo_results, f, indent=2, default=str)
    print(f"Full results saved to {json_path}")
else:
    print("No results to display yet. Run the evaluation cell first.")

## 5. Download Results

All checkpoints and results are automatically saved to Google Drive.

Location: `/content/drive/MyDrive/macro_emulator/runs/`

In [None]:
# List all saved checkpoints: show the contents of every lowo_* run directory on Drive
# (DRIVE_RUNS_PATH is defined in the LOWO setup cell)
!ls -la {DRIVE_RUNS_PATH}/lowo_*/

print("\n" + "="*60)
print("All results saved to Google Drive!")
print(f"Path: {DRIVE_RUNS_PATH}")
print("="*60)

In [None]:
# Optional: Create a zip of all LOWO results for download.
# The archive is written inside the Drive runs directory, next to the run folders.
!cd {DRIVE_RUNS_PATH} && zip -r lowo_results.zip lowo_exclude_*

# NOTE(review): these messages are printed regardless of whether the zip command succeeded
print(f"\nZip file created: {DRIVE_RUNS_PATH}/lowo_results.zip")
print("You can download this file from Google Drive.")
print("You can download this file from Google Drive.")