# HouseGym RL Training with New Architecture

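This notebook trains RL agents on **synthetic data**. Evaluation on **real data** is run separately via `evaluate.py`; the quick evaluation in Cell 8 is only a sanity check on a synthetic environment.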

Key features:
- Pure random candidate selection (M=512)
- Batch arrival system
- Capacity ramp system
- Focus on robustness, not optimization

## Cell 1: Imports and Configuration

In [None]:
from __future__ import annotations

import os
import time
import multiprocessing as mp
from pathlib import Path
from functools import partial
from importlib import reload
from typing import Dict, Optional

import numpy as np
import pandas as pd
import torch

# Reload project modules to get latest changes
import evaluate, baseline, housegymrl, config
reload(evaluate); reload(baseline); reload(housegymrl); reload(config)

# Import new environment classes
from housegymrl import RLEnv, BaselineEnv
from baseline import create_baseline_env
from evaluate import make_synth_env
from config import M_CANDIDATES, MAX_STEPS

# SB3 imports
from stable_baselines3 import SAC
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv, VecMonitor, VecNormalize
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback, BaseCallback
from stable_baselines3.common.utils import set_random_seed

# Startup banner: confirms the imports ran and echoes the two config constants
# pulled in from the project's config module.
print(
    "=" * 60,
    "IMPORTS COMPLETE",
    "=" * 60,
    f"M_CANDIDATES = {M_CANDIDATES}",
    f"MAX_STEPS = {MAX_STEPS}",
    "Environment architecture: New (pure random candidate selection)",
    "Training data: Synthetic",
    "Testing data: Real regions",
    sep="\n",
)

## Cell 2: Training Parameters

In [None]:
# =========================== Training Parameters ===========================
# For quick testing, use small values. For real training, increase these.

# Quick test mode (set to False for real training)
# Master switch: True selects the tiny smoke-test preset, False the full run.
QUICK_TEST = True

# Each preset defines SEEDS, N_ENVS, TOTAL_STEPS, EVAL_FREQ and CKPT_FREQ.
if QUICK_TEST:
    print("üîß QUICK TEST MODE - Using minimal settings for functionality verification")
    SEEDS, N_ENVS = [42], 2           # one seed, two parallel environments
    TOTAL_STEPS = 1000                # total training timesteps
    EVAL_FREQ = CKPT_FREQ = 500       # evaluate and checkpoint every 500 steps
else:
    print("üöÄ FULL TRAINING MODE")
    SEEDS, N_ENVS = [42], 10          # more seeds can be appended for robustness
    TOTAL_STEPS = 300_000             # total training timesteps
    EVAL_FREQ, CKPT_FREQ = 10_000, 50_000

# Sampling ranges handed to the synthetic environment factory.
H_RANGE = (10_000, 100_000)    # house count range
WORKER_RATIO = (0.10, 0.25)    # contractor/house ratio range

# Pick the compute device: Apple MPS first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# All run artifacts are written below this directory (created if missing).
ROOT_RUNS = Path("runs")
ROOT_RUNS.mkdir(parents=True, exist_ok=True)

print(
    "\nSettings:",
    f"  Seeds: {SEEDS}",
    f"  Parallel envs: {N_ENVS}",
    f"  Total steps: {TOTAL_STEPS:,}",
    f"  Device: {DEVICE}",
    f"  House range: {H_RANGE}",
    f"  Worker ratio: {WORKER_RATIO}",
    sep="\n",
)

## Cell 3: Callback Definitions

In [None]:
# ======================== Custom Callbacks ========================

class CompletionTBCallback(BaseCallback):
    """Record mean/min/max of the envs' ``completion`` info value to the SB3 logger.

    On a logging step the callback collects ``info["completion"]`` from every
    info dict of the current vectorized step and records three scalars under
    ``env/completion_household_{mean,min,max}``.

    Args:
        tb_every: Minimum number of timesteps between two log records. A value
            <= 0 logs on every callback step.
        verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(self, tb_every: int = 200, verbose: int = 0):
        super().__init__(verbose)
        self.tb_every = int(tb_every)
        # Timestep at which metrics were last recorded.
        self._last_logged = 0

    def _on_training_start(self) -> None:
        # Re-anchor the interval so a fresh or resumed ``learn()`` call starts
        # counting from the model's current timestep.
        self._last_logged = self.num_timesteps

    def _on_step(self) -> bool:
        # ``num_timesteps`` grows by the number of parallel envs per call, so an
        # exact ``% tb_every == 0`` test can be skipped over; gate on the number
        # of timesteps elapsed since the last record instead.
        if self.tb_every > 0 and (self.num_timesteps - self._last_logged) < self.tb_every:
            return True

        infos = self.locals.get("infos")
        if not infos:
            return True

        vals = [
            info["completion"]
            for info in infos
            if isinstance(info, dict) and "completion" in info
        ]
        if not vals:
            return True

        completion = np.asarray(vals, dtype=float)
        if np.isnan(completion).all():
            # Nothing meaningful to aggregate; retry on the next step.
            return True

        self.logger.record("env/completion_household_mean", float(np.nanmean(completion)))
        self.logger.record("env/completion_household_min", float(np.nanmin(completion)))
        self.logger.record("env/completion_household_max", float(np.nanmax(completion)))
        self._last_logged = self.num_timesteps
        return True

class SyncVecNormCallback(BaseCallback):
    """Point the target VecNormalize at the source's running statistics.

    On every callback step the ``obs_rms`` and ``ret_rms`` attributes of the
    source wrapper are assigned (by reference, not copied) onto the target
    wrapper, provided both objects expose the attribute.

    Args:
        src_vecnorm: Wrapper whose statistics are read (the training env).
        tgt_vecnorm: Wrapper that receives the statistics (the eval env).
    """

    def __init__(self, src_vecnorm: VecNormalize, tgt_vecnorm: VecNormalize):
        super().__init__(verbose=0)
        self.src, self.tgt = src_vecnorm, tgt_vecnorm

    def _on_step(self) -> bool:
        for stat_name in ("obs_rms", "ret_rms"):
            if hasattr(self.src, stat_name) and hasattr(self.tgt, stat_name):
                setattr(self.tgt, stat_name, getattr(self.src, stat_name))
        # Always continue training.
        return True

# Confirmation banner listing the callbacks defined in this cell.
print(
    "‚úì Callbacks defined",
    "  - CompletionTBCallback: Tracks completion metrics",
    "  - SyncVecNormCallback: Syncs normalization between train/eval",
    sep="\n",
)

## Cell 4: Diagnostic Test - Candidate Selection & Queue Ordering

This cell creates a test environment to verify:
1. Pure random candidate selection
2. Different ordering strategies (LJF, SJF, Random)

In [None]:
# Diagnostic walkthrough: build a small synthetic env, step it with random
# actions, then print the sampled candidates and how each baseline ordering
# (LJF / SJF / Random) would rank them.
print(
    "=" * 60,
    "DIAGNOSTIC: Testing Candidate Selection & Queue Ordering",
    "=" * 60,
    sep="\n",
)

# Small synthetic instance with batch arrival and capacity ramp switched on.
test_env = make_synth_env(
    H_min=1000,
    H_max=2000,
    worker_ratio=(0.15, 0.20),
    seed=42,
    verbose=True,
    use_batch_arrival=True,
    use_capacity_ramp=True,
)

obs, info = test_env.reset()

# The env attributes are printed one statement at a time on purpose: the env
# is verbose, so each accessor is evaluated right before its own line is shown.
print("\nüìä Initial State:")
print(f"  Day: {test_env.day}")
print(f"  Queue size: {test_env.waiting_queue.size()}")
print(f"  M candidates: {test_env.M}")
print(f"  Current capacity: {test_env._effective_capacity()}")

# Take up to 50 random steps, stopping early if the episode ends.
print("\n‚è© Advancing to day 50...")
for _ in range(50):
    obs, r, done, trunc, info = test_env.step(test_env.action_space.sample())
    if done or trunc:
        break

print("\nüìä Day 50 State:")
print(f"  Day: {test_env.day}")
print(f"  Queue size: {test_env.waiting_queue.size()}")
print(f"  Current capacity: {test_env._effective_capacity()}")
print(f"  Completion: {info['completion']:.2%}")

queue_ids = test_env.waiting_queue.get_all()
if len(queue_ids) == 0:
    print("\n‚ö†Ô∏è Queue is empty, cannot test ordering")
else:
    candidates = test_env._select_candidates(queue_ids)

    print("\nüé≤ Candidate Selection (Pure Random):")
    print(f"  Queue has {len(queue_ids)} houses")
    print(f"  Selected {len(candidates)} candidates")

    # Show at most the first ten sampled candidates with their work figures.
    print("\n  First 10 candidates:")
    for rank, cid in zip(range(10), candidates):
        print(
            f"    {rank:2d}: ID={cid:4d}, "
            f"Total={test_env._arr_total[cid]:3.0f}, "
            f"Remain={test_env._arr_rem[cid]:3.0f}, "
            f"Damage={test_env._arr_dmg[cid]}"
        )

    print("\n" + "=" * 60, "Testing Different Policy Orderings", "=" * 60, sep="\n")

    import random

    # Build the three reference orderings of the same candidate sample.
    ljf_sorted = sorted(candidates, key=lambda h: -test_env._arr_total[h])
    sjf_sorted = sorted(candidates, key=lambda h: test_env._arr_total[h])
    random_sorted = candidates.copy()
    random.shuffle(random_sorted)

    # Print the top five entries of every ordering under its own heading.
    for heading, ordering in (
        ("\nüìâ LJF Ordering (should be DESCENDING by total work):", ljf_sorted),
        ("\nüìà SJF Ordering (should be ASCENDING by total work):", sjf_sorted),
        ("\nüîÄ Random Ordering (should be RANDOM):", random_sorted),
    ):
        print(heading)
        for rank, cid in zip(range(5), ordering):
            print(f"    {rank}: ID={cid:4d}, Total={test_env._arr_total[cid]:3.0f} days")

    # The RL policy is only described here; no model is involved in this cell.
    print(
        "\nü§ñ RL Ordering (would sort by learned scores):",
        "    RL assigns scores to each candidate and sorts by score",
        "    The scoring function is learned to maximize long-term reward",
        sep="\n",
    )

print("\n" + "=" * 60, "‚úÖ Diagnostic test complete", "=" * 60, sep="\n")

## Cell 5: Environment Factory Functions

In [None]:
# ======================== Environment Factories ========================

# Use the "spawn" start method for worker processes (macOS compatibility),
# overriding any method that was configured earlier in this session.
try:
    mp.set_start_method(method="spawn", force=True)
except RuntimeError:
    # Best effort: keep whatever start method is already in place.
    pass

def _make_env_worker(H_min: int, H_max: int, worker_ratio, seed: int, rank: int):
    """Return a zero-argument factory for one Monitor-wrapped synthetic env.

    Nothing is constructed until the returned callable is invoked, which lets
    vectorized-env wrappers build the environment themselves.

    Args:
        H_min: Forwarded to ``make_synth_env`` as ``H_min``.
        H_max: Forwarded to ``make_synth_env`` as ``H_max``.
        worker_ratio: Forwarded unchanged to ``make_synth_env``.
        seed: Base seed shared by all workers of one vectorized env.
        rank: Worker index; the env is seeded with ``seed + rank``.

    Returns:
        A callable that builds the env (batch arrival and capacity ramp
        enabled, non-verbose) and returns it wrapped in ``Monitor``.
    """
    def _build():
        return Monitor(
            make_synth_env(
                H_min=H_min,
                H_max=H_max,
                worker_ratio=worker_ratio,
                seed=seed + rank,  # evaluated lazily, when the env is built
                verbose=False,
                use_batch_arrival=True,
                use_capacity_ramp=True,
            )
        )

    return _build

def make_train_vec_env(seed: int, n_envs: int) -> VecNormalize:
    """Create the normalized, vectorized training environment.

    Builds ``n_envs`` synthetic-env factories (worker ``i`` is seeded with
    ``seed + i``), runs them in subprocesses when possible and falls back to
    an in-process ``DummyVecEnv`` otherwise, then wraps the result in
    ``VecNormalize`` with observation and reward normalization.

    Args:
        seed: Base seed for the worker environments.
        n_envs: Number of parallel environments; must be at least 1.

    Returns:
        The ``VecNormalize``-wrapped vectorized environment.

    Raises:
        ValueError: If ``n_envs`` is smaller than 1.
    """
    if n_envs < 1:
        raise ValueError(f"n_envs must be >= 1, got {n_envs}")

    env_fns = [
        _make_env_worker(
            H_min=H_RANGE[0],
            H_max=H_RANGE[1],
            worker_ratio=WORKER_RATIO,
            seed=seed,
            rank=rank,
        )
        for rank in range(n_envs)
    ]

    # Try SubprocVecEnv first, fall back to DummyVecEnv on any failure.
    try:
        # Smoke test: build and close one env in this process before spawning
        # workers, so construction problems surface here with a readable error.
        probe_env = env_fns[0]()
        probe_env.close()
        vec = SubprocVecEnv(env_fns, start_method="spawn")
        print(f"  ‚úì Using SubprocVecEnv with {n_envs} workers")
    except Exception as e:
        print(f"  ‚ö†Ô∏è SubprocVecEnv failed, using DummyVecEnv: {e}")
        vec = DummyVecEnv(env_fns)

    # No VecMonitor here: every sub-env is already wrapped in Monitor by
    # _make_env_worker, and stacking VecMonitor on top makes SB3 warn that the
    # Monitor statistics are being overwritten.
    return VecNormalize(vec, norm_obs=True, norm_reward=True, clip_obs=10.0, clip_reward=10.0)

def make_eval_vec_env(seed: int) -> VecNormalize:
    """Create the single-environment evaluation VecNormalize wrapper.

    The eval env uses the same synthetic settings as training but is seeded
    with ``seed + 9999``. The wrapper is created with ``training=False``, so it
    does not update its own running statistics.

    Args:
        seed: Base seed of the run; offset by 9999 for the eval environment.

    Returns:
        The ``VecNormalize``-wrapped evaluation environment.
    """
    eval_vec = DummyVecEnv(
        [
            _make_env_worker(
                H_min=H_RANGE[0],
                H_max=H_RANGE[1],
                worker_ratio=WORKER_RATIO,
                seed=seed + 9999,
                rank=0,
            )
        ]
    )

    # No VecMonitor here: the sub-env is already wrapped in Monitor by
    # _make_env_worker, and stacking VecMonitor on top makes SB3 warn that the
    # Monitor statistics are being overwritten.
    return VecNormalize(
        eval_vec,
        norm_obs=True,
        norm_reward=True,
        clip_obs=10.0,
        clip_reward=10.0,
        training=False,
    )

# Confirmation banner for the environment factory cell.
print(
    "‚úì Environment factory functions defined",
    "  - Training: Synthetic environments with batch arrival & capacity ramp",
    "  - Evaluation: Same setup for consistent comparison",
    sep="\n",
)

## Cell 6: Training Setup & Execution

In [None]:
# ======================== Main Training Loop ========================

def train_agent(seed: int, quick_test: bool = False):
    """Train a single SAC agent on the synthetic environments and save it.

    Reads the notebook-level settings ``ROOT_RUNS``, ``N_ENVS``,
    ``TOTAL_STEPS``, ``EVAL_FREQ``, ``CKPT_FREQ`` and ``DEVICE``. The
    environments are always closed when the function exits, including when
    training is interrupted or raises.

    Args:
        seed: Seed for the global RNGs, the environments and the SAC model;
            also part of the run directory name.
        quick_test: If True, the run name is prefixed with ``TEST_`` and a
            constant learning rate, a small replay buffer and a batch size of
            256 are used.

    Returns:
        Path of the saved final model (``<run dir>/sac_model.zip``).
    """
    print(f"\n{'='*60}")
    print(f"Training with seed={seed}")
    print(f"{'='*60}")

    set_random_seed(seed)
    ts_tag = time.strftime("%Y-%m-%d_%H-%M-%S")
    run_name = f"sac_synth_seed{seed}_{ts_tag}"
    if quick_test:
        run_name = f"TEST_{run_name}"

    RUN_DIR = ROOT_RUNS / run_name
    # NOTE(review): SB3 appears to create its own "<tb_log_name>_<n>" folder
    # below `tensorboard_log`, so the event files presumably end up one level
    # deeper than "SAC_1" — confirm whether that nesting is intended.
    TB_LOG = RUN_DIR / "tensorboard_logs" / "SAC_1"
    RUN_DIR.mkdir(parents=True, exist_ok=True)

    print(f"üìÅ Run directory: {RUN_DIR}")

    # Create environments
    print(f"\nüèóÔ∏è Creating environments...")
    train_vec = make_train_vec_env(seed, N_ENVS)
    eval_vec = None
    try:
        eval_vec = make_eval_vec_env(seed)

        # Learning rate schedule
        TOTAL = int(TOTAL_STEPS)
        if quick_test:
            # Simple constant learning rate for quick test
            def lr_schedule(progress_remaining: float) -> float:
                return 3e-4
        else:
            # Staged learning rate: LR1 for the first 60% of the steps, LR2 up
            # to 85%, LR3 for the remainder.
            B1, B2 = int(0.60 * TOTAL), int(0.85 * TOTAL)
            LR1, LR2, LR3 = 3e-4, 1e-4, 5e-5

            def lr_schedule(progress_remaining: float) -> float:
                # `progress_remaining` is converted back into a step count.
                step_done = int((1.0 - progress_remaining) * TOTAL)
                if step_done < B1:
                    return LR1
                if step_done < B2:
                    return LR2
                return LR3

        # Policy network architecture
        # Increased size for M=512 (2054-dim observations)
        policy_kwargs = dict(
            net_arch=dict(
                pi=[512, 512],  # Actor network
                qf=[512, 512],  # Critic network
            )
        )

        batch_size = 256 if quick_test else (1024 if DEVICE == "mps" else 512)

        print(f"\nü§ñ Creating SAC model...")
        print(f"  Observation space: {train_vec.observation_space.shape}")
        print(f"  Action space: {train_vec.action_space.shape}")
        print(f"  Policy network: {policy_kwargs}")
        print(f"  Batch size: {batch_size}")
        print(f"  Device: {DEVICE}")

        model = SAC(
            "MlpPolicy",
            train_vec,
            verbose=1,
            device=DEVICE,
            policy_kwargs=policy_kwargs,
            learning_rate=lr_schedule,
            buffer_size=min(10_000, TOTAL) if quick_test else max(300_000, TOTAL),
            batch_size=batch_size,
            gamma=0.95,
            tau=0.01,
            train_freq=(1, "step"),
            gradient_steps=1,
            ent_coef="auto",
            tensorboard_log=str(TB_LOG),
            seed=seed,
        )

        # Setup callbacks. The frequencies are divided by N_ENVS because they
        # are passed to the callbacks per callback call, not per timestep.
        sync_cb = SyncVecNormCallback(src_vecnorm=train_vec, tgt_vecnorm=eval_vec)
        eval_cb = EvalCallback(
            eval_vec,
            best_model_save_path=str(RUN_DIR / "best"),
            log_path=str(RUN_DIR / "eval"),
            eval_freq=max(1, EVAL_FREQ // max(1, N_ENVS)),
            deterministic=True,
            render=False,
        )
        ckpt_cb = CheckpointCallback(
            save_freq=max(1, CKPT_FREQ // max(1, N_ENVS)),
            save_path=str(RUN_DIR / "ckpt"),
            name_prefix="sac",
        )
        completion_cb = CompletionTBCallback(tb_every=200)

        # Start training
        print(f"\nüöÄ Starting training for {TOTAL_STEPS:,} steps...")
        if quick_test:
            print("  (Quick test mode - just verifying functionality)")

        model.learn(
            total_timesteps=TOTAL_STEPS,
            callback=[sync_cb, eval_cb, ckpt_cb, completion_cb],
            progress_bar=True,
        )

        # Save model and normalization statistics side by side.
        model_path = RUN_DIR / "sac_model.zip"
        model.save(str(model_path))

        vecnorm_path = RUN_DIR / "vecnormalize.pkl"
        train_vec.save(str(vecnorm_path))

        print(f"\n‚úÖ Training complete!")
        print(f"  Model saved: {model_path}")
        print(f"  VecNormalize stats: {vecnorm_path}")
        print(f"  Best model: {RUN_DIR / 'best'}")
    finally:
        # Cleanup runs on success and on failure so worker processes are not
        # left behind; each env is closed independently and close errors are
        # reported instead of being silently discarded.
        for label, venv in (("train", train_vec), ("eval", eval_vec)):
            if venv is None:
                continue
            try:
                venv.close()
            except Exception as exc:
                print(f"  Warning: failed to close {label} env: {exc}")

    return model_path

# Echo the active settings; training itself is started by calling train_agent.
print(
    "Ready to train!",
    "\nSettings:",
    f"  Quick test mode: {QUICK_TEST}",
    f"  Seeds: {SEEDS}",
    f"  Total steps per seed: {TOTAL_STEPS:,}",
    f"  Parallel environments: {N_ENVS}",
    "\nCall train_agent(seed) to start training.",
    sep="\n",
)

if QUICK_TEST:
    print(
        "\nüí° TIP: Since QUICK_TEST=True, training will be very fast (1-2 episodes)",
        "     This is just to verify the system works correctly.",
        sep="\n",
    )

## Cell 7: Run Training

In [None]:
# ======================== Execute Training ========================

# Train one agent per configured seed; remember where each final model landed.
# The list is filled incrementally so earlier results survive a later failure.
saved_models = []

for seed in SEEDS:
    model_path = train_agent(seed, quick_test=QUICK_TEST)
    saved_models.append(model_path)

print(
    "\n" + "=" * 60,
    "ALL TRAINING COMPLETE",
    "=" * 60,
    f"\nSaved {len(saved_models)} model(s):",
    sep="\n",
)
if saved_models:
    print("\n".join(f"  - {path}" for path in saved_models))

if not QUICK_TEST:
    print(
        "\nüí° Next steps:",
        "   1. Check tensorboard logs: tensorboard --logdir runs/",
        "   2. Evaluate on real data: python evaluate.py",
        "   3. Load model for inference:",
        "      model = SAC.load('runs/.../sac_model.zip')",
        "      vecnorm = VecNormalize.load('runs/.../vecnormalize.pkl')",
        sep="\n",
    )
else:
    print(
        "\n‚ö†Ô∏è Note: This was a quick test run.",
        "   Set QUICK_TEST=False in Cell 2 for full training.",
        sep="\n",
    )

## Cell 8: Quick Evaluation Test

Test the trained model on a synthetic environment to verify it learned something.

In [None]:
# ======================== Quick Evaluation ========================

# Sanity check: load the most recently saved model, roll it out for up to 100
# days on a fresh synthetic env, and plot it against the LJF / SJF / Random
# baselines. Any failure is reported with its full traceback.
if saved_models:
    print("="*60)
    print("QUICK EVALUATION TEST")
    print("="*60)

    # Load the last trained model
    model_path = saved_models[-1]
    print(f"\nüìÇ Loading model: {model_path}")

    try:
        # NOTE(review): the model was trained on VecNormalize-d observations,
        # but only the model is loaded here (not vecnormalize.pkl) — confirm
        # that `rollout` applies the saved normalization, otherwise the policy
        # sees raw observations.
        model = SAC.load(str(model_path))
        print("‚úì Model loaded successfully")

        # Create test environment
        # NOTE(review): `worker_ratio` is a scalar here but a (low, high) tuple
        # everywhere else in this notebook — verify make_synth_env accepts both.
        test_env = make_synth_env(
            H_min=2000,
            H_max=3000,
            worker_ratio=0.15,
            seed=999,
            verbose=True
        )

        # Compare RL vs baselines
        from evaluate import rollout

        print("\nüèÉ Running rollouts (100 days each)...")

        policies = ["SAC", "LJF", "SJF", "Random"]
        results = {}

        for policy in policies:
            if policy == "SAC":
                traj = rollout(test_env, model=model, max_days=100)
            else:
                # Baselines get their own env built from the test env's region
                # key and contractor count.
                baseline_env = create_baseline_env(
                    region_key=test_env.region_key,
                    policy=policy,
                    num_contractors=test_env.num_contractors,
                    seed=999
                )
                traj = rollout(baseline_env, max_days=100)

            results[policy] = traj
            # The last trajectory entry is reported; an empty trajectory
            # is shown as 0%.
            final = traj[-1] if len(traj) > 0 else 0.0
            print(f"  {policy:8s}: Day 100 completion = {final:.2%}")

        # Plot comparison
        import matplotlib.pyplot as plt

        plt.figure(figsize=(10, 6))
        for policy, traj in results.items():
            plt.plot(traj, label=policy, alpha=0.8)

        plt.xlabel("Day")
        plt.ylabel("Completion")
        plt.title("Policy Comparison on Synthetic Test")
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.show()

        print("\n‚úÖ Evaluation complete!")

        if QUICK_TEST:
            print("\n‚ö†Ô∏è Note: Model was only trained for 1000 steps.")
            print("   Don't expect good performance yet!")
            print("   This just verifies the system works.")

    except Exception as e:
        # Keep the cell from aborting the notebook, but show where the failure
        # happened: the message alone is rarely enough to debug it.
        import traceback

        print(f"‚ùå Error during evaluation: {e}")
        traceback.print_exc()
else:
    print("No models to evaluate. Run training first!")

## Summary

This notebook demonstrates the new HouseGym RL architecture with:

✅ **Pure random candidate selection** (M=512)  
✅ **Batch arrival system** (houses revealed over time)  
✅ **Capacity ramp system** (contractors mobilize gradually)  
✅ **Training on synthetic data**  
✅ **Testing on real data** (via evaluate.py)  

The diagnostic tests show that:
- Candidates are selected randomly (no bias)
- LJF sorts by descending total work
- SJF sorts by ascending total work  
- RL learns to score and rank candidates

The focus is on **robustness**, not optimization!