# In [None]:
# ====================== MAIN: Train + Evaluate(7 regions) + Report ======================
import os
import time
import json
import random
import multiprocessing as mp
from pathlib import Path
from importlib import reload

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt

# ---- Project modules
import config
import evaluate
import housegymrl
from evaluate import create_unified_ramp, create_tasks_from_real_config
from housegymrl import RLEnv, BaselineEnv

# ---- SB3
from stable_baselines3 import SAC
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv, VecMonitor, VecNormalize
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback, BaseCallback
from stable_baselines3.common.utils import set_random_seed
from torch.utils.tensorboard import SummaryWriter

# Hot-reload the non-config project modules so that edits made on disk are
# picked up without restarting the interpreter/kernel.
# NOTE(review): if `evaluate` itself imports names from `housegymrl`, the
# reload order may need to be reversed so `evaluate` sees the fresh objects — confirm.
reload(evaluate)
reload(housegymrl)

# reload() re-executes the module but does NOT rebind names that were pulled in
# earlier with `from module import name`; those would keep pointing at the
# pre-reload functions/classes. Re-import them so the rest of this script uses
# the reloaded definitions.
from evaluate import create_unified_ramp, create_tasks_from_real_config
from housegymrl import RLEnv, BaselineEnv

# ----------------- Global parameters -----------------
SEEDS = [42]
N_ENVS = 1  # was 8
TOTAL_STEPS = 50_000  # was 300_000
EVAL_FREQ = 10_000
CKPT_FREQ = 50_000
TRAIN_REGION_KEYS = list(config.REGION_CONFIG)

# Preferred accelerator: Apple MPS first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Device:", DEVICE)

# Root directory under which every training run gets its own sub-directory.
ROOT_RUNS = Path("runs")
ROOT_RUNS.mkdir(parents=True, exist_ok=True)

# Request the "spawn" start method for worker processes; a RuntimeError from
# the multiprocessing module is deliberately ignored.
try:
    mp.set_start_method("spawn", force=True)
except RuntimeError:
    pass



# ----------------- TensorBoard记录器 -----------------
class CompletionTBCallback(BaseCallback):
    """Record environment progress metrics and SAC training scalars to the SB3 logger.

    On every step the callback inspects the per-env ``infos`` dicts found in
    ``self.locals``. Episode-level values (``env/makespan_episode``,
    ``env/episodes_finished``) are recorded whenever an info dict carries
    ``done is True``. The remaining metrics are throttled and only recorded
    when ``num_timesteps`` is a multiple of ``tb_every``.

    Info keys read: ``done``, ``day``, ``completion``, ``idle_workers``,
    ``K_effective``, ``allocated_workers``. Their exact meaning is defined by
    the environment, not by this callback.

    Args:
        tb_every: Record the throttled metrics every this many timesteps;
            a value <= 0 disables throttling (record on every step).
        verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(self, tb_every: int = 200, verbose: int = 0):
        super().__init__(verbose)
        self.tb_every = int(tb_every)

    def _record_percentiles(self, tag: str, values, percentiles=(50, 90)) -> None:
        """Record ``{tag}_p{q}`` for each requested percentile of ``values``."""
        arr = np.asarray(values, dtype=float)
        for q in percentiles:
            self.logger.record(f"{tag}_p{q}", float(np.percentile(arr, q)))

    def _on_step(self) -> bool:
        """Collect metrics for the current step; always returns True (never stops training)."""
        raw_infos = self.locals.get("infos", None) or []
        # Only dict-shaped infos can be queried; anything else is ignored.
        infos = [i for i in raw_infos if isinstance(i, dict)]

        # Episode-level makespan (recorded on every step, not throttled).
        finished = [i for i in infos if i.get("done") is True]
        if finished:
            # A finished episode without a "day" entry must not reach np.min():
            # a None value there would raise and abort training.
            days = [i["day"] for i in finished if i.get("day") is not None]
            if days:
                self.logger.record("env/makespan_episode", float(np.min(days)))
            self.logger.record("env/episodes_finished", float(len(finished)))

        # Throttle everything below to once every `tb_every` timesteps.
        if self.tb_every > 0 and (self.num_timesteps % self.tb_every != 0):
            return True

        vals = [i.get("completion") for i in infos if "completion" in i]
        if not vals:
            return True

        idle_vals = [i.get("idle_workers", 0) for i in infos]
        K_vals = [i.get("K_effective", 0) for i in infos]
        alloc_vals = [i.get("allocated_workers", 0) for i in infos]
        # Allocated / effective capacity, with the divisor clamped to >= 1 to avoid division by zero.
        util_vals = [a / max(1, k) for a, k in zip(alloc_vals, K_vals)]

        # NaN-aware statistics: a None completion value becomes NaN in the float array.
        v = np.asarray(vals, dtype=float)
        self.logger.record("env/completion_household_mean", float(np.nanmean(v)))
        self.logger.record("env/completion_household_min", float(np.nanmin(v)))
        self.logger.record("env/completion_household_max", float(np.nanmax(v)))
        self.logger.record("env/completion_household_p50", float(np.nanpercentile(v, 50)))
        self.logger.record("env/completion_household_p90", float(np.nanpercentile(v, 90)))

        if idle_vals:
            self.logger.record("env/idle_workers_mean", float(np.mean(idle_vals)))
        if util_vals:
            self.logger.record("env/utilization_mean", float(np.mean(util_vals)))
        if alloc_vals:
            self._record_percentiles("env/allocated_workers", alloc_vals)
        if K_vals:
            self._record_percentiles("env/K_effective", K_vals)

        # Training scalars, recorded only when the model exposes them.
        if hasattr(self.model, "actor") and getattr(self.model.actor, "optimizer", None) is not None:
            lr = self.model.actor.optimizer.param_groups[0]["lr"]
            self.logger.record("train/lr", float(lr))
        if hasattr(self.model, "log_ent_coef"):
            ent_coef = float(self.model.log_ent_coef.exp().item())
            self.logger.record("train/ent_coef", ent_coef)
        return True

# ----------------- Ramp -----------------
# Single ramp object shared by every training and evaluation environment
# (passed to them as `k_ramp`).
ramp_fn = create_unified_ramp()

# ----------------- Training sampler -----------------
# Optional switch read from config; defaults to False when config does not define it.
USE_SYNTHETIC_TRAIN = getattr(config, "USE_SYNTHETIC_TRAIN", False)

def make_sampler(seed: int):
    """Build a scenario sampler backed by its own seeded NumPy generator.

    Args:
        seed: Seed for the ``np.random.default_rng`` instance owned by the sampler.

    Returns:
        A zero-argument callable. Each call picks a region key with
        ``rng.choice`` and returns ``(tasks_df, resources, meta)`` where
        ``resources`` holds ``workers`` and ``region_name`` and ``meta`` is
        ``{"region": region_key}``.
    """
    rng = np.random.default_rng(seed)

    def sampler():
        region_key = rng.choice(TRAIN_REGION_KEYS)
        region_cfg = config.REGION_CONFIG[region_key]

        # Synthetic tasks are used only when requested AND the evaluate module
        # actually provides the generator; otherwise fall back to the real config.
        synthetic_ok = USE_SYNTHETIC_TRAIN and hasattr(evaluate, "create_tasks_from_synthetic")
        build_tasks = evaluate.create_tasks_from_synthetic if synthetic_ok else create_tasks_from_real_config
        tasks_df = build_tasks(region_cfg, rng)

        resources = {
            "workers": int(region_cfg["num_contractors"]),
            "region_name": region_key,
        }
        meta = {"region": region_key}
        return tasks_df, resources, meta

    return sampler

def make_train_env(seed: int):
    """Return a thunk that builds one Monitor-wrapped training environment.

    The scenario sampler is created once here (seeded with ``seed``) and
    captured by the returned factory.

    Args:
        seed: Seed used for both the scenario sampler and the ``RLEnv`` itself.

    Returns:
        A zero-argument callable suitable for ``SubprocVecEnv`` / ``DummyVecEnv``.
    """
    sampler = make_sampler(seed)

    def _init():
        env_kwargs = {
            "scenario_sampler": sampler,
            "max_steps": config.MAX_STEPS,
            "seed": seed,
            "k_ramp": ramp_fn,
            "batch_arrival": True,
        }
        return Monitor(RLEnv(**env_kwargs))

    return _init

def make_eval_env(region_key: str, seed: int, env_cls):
    """Return a thunk that builds one Monitor-wrapped evaluation env for a region.

    Args:
        region_key: Region identifier forwarded to ``evaluate.make_region_env``.
        seed: Accepted for interface symmetry with ``make_train_env``.
            NOTE(review): the value is currently not forwarded to the
            environment, so it has no effect here.
        env_cls: Environment class forwarded to ``evaluate.make_region_env``.

    Returns:
        A zero-argument callable suitable for a vectorized env constructor.
    """
    def _init():
        return Monitor(
            evaluate.make_region_env(region_key, env_cls, k_ramp=ramp_fn, batch_arrival=True)
        )

    return _init

# =================== Train + Eval ===================
# Paths of the final model archives, one per trained seed.
saved_models = []

for SEED in SEEDS:
    print(f"========== Training seed={SEED} ==========")
    set_random_seed(SEED)

    # One timestamp tags both the run directory and the evaluation output tree.
    ts_tag = time.strftime("%Y-%m-%d_%H-%M-%S")
    run_dir = ROOT_RUNS / f"sac_seed{SEED}_{ts_tag}"
    tb_dir = run_dir / "tb"
    run_dir.mkdir(parents=True, exist_ok=True)

    # Important: give this run a dedicated output directory and override the
    # default paths stored in config so downstream code writes there.
    eval_output_dir = Path("output") / f"exp_{ts_tag}"
    fig_dir = eval_output_dir / "fig"
    tab_dir = eval_output_dir / "tab"
    for out_sub in (fig_dir, tab_dir):
        out_sub.mkdir(parents=True, exist_ok=True)
    config.OUTPUT_DIR = eval_output_dir
    config.FIG_DIR = fig_dir
    config.TAB_DIR = tab_dir

    # ---- Build train vec env ----
    # One factory per parallel env, seeded SEED, SEED+1, ...
    env_fns = [make_train_env(env_seed) for env_seed in range(SEED, SEED + N_ENVS)]

    # Smoke test: build and immediately close one env in this process so that
    # construction errors surface here rather than inside a worker.
    env_fns[0]().close()

    # Prefer subprocess workers; fall back to an in-process vec env on any failure.
    try:
        vec_env = SubprocVecEnv(env_fns, start_method="spawn")
    except Exception as exc:
        print("[WARN] SubprocVecEnv spawn failed, fallback to DummyVecEnv:", repr(exc))
        vec_env = DummyVecEnv(env_fns)

    # Episode statistics first, then observation/reward normalization on top.
    vec_env = VecNormalize(
        VecMonitor(vec_env),
        norm_obs=True,
        norm_reward=True,
        clip_obs=10.0,
    )

    # ---- Eval env for callback (multi region, RL) ----
    def make_eval_envs(seed: int):
        """Build a frozen, normalized vec env with one RLEnv per training region."""
        region_factories = [
            make_eval_env(region_key, seed + 9000 + idx, env_cls=RLEnv)
            for idx, region_key in enumerate(TRAIN_REGION_KEYS)
        ]
        vec = VecNormalize(
            VecMonitor(DummyVecEnv(region_factories)),
            norm_obs=True,
            norm_reward=False,
            training=False,
            clip_obs=10.0,
        )
        # Re-assert evaluation mode explicitly: no statistics updates, raw rewards.
        vec.training = False
        vec.norm_reward = False
        return vec

    eval_env = make_eval_envs(SEED)
    # Align normalization with the training env: share its running statistics
    # and clipping so evaluation observations are scaled the same way.
    eval_env.obs_rms = vec_env.obs_rms
    eval_env.ret_rms = vec_env.ret_rms
    eval_env.clip_obs = vec_env.clip_obs
    eval_env.training = False
    eval_env.norm_reward = False

    # ---- manifest ----
    # Record the seeds and paths used by this run next to its artifacts.
    # NOTE(review): "eval_seeds" lists the values handed to make_eval_env, which
    # does not currently forward its seed to the environment.
    eval_seed_base = SEED + 9000
    manifest = {
        "global_seed": SEED,
        "train_env_seeds": list(range(SEED, SEED + N_ENVS)),
        "eval_seeds": {
            region: region_seed
            for region_seed, region in enumerate(TRAIN_REGION_KEYS, start=eval_seed_base)
        },
        "train_regions": TRAIN_REGION_KEYS,
        "timestamp": ts_tag,
        "output_dir": str(config.OUTPUT_DIR),
        "tensorboard_dir": str(tb_dir),
    }
    manifest_path = run_dir / "seed_manifest.json"
    manifest_path.write_text(json.dumps(manifest, indent=2))

    # ---- Policy/LR ----
    # Piecewise-constant learning rate: lr1 for the first 60% of training,
    # lr2 up to 85%, lr3 for the remainder.
    total = int(TOTAL_STEPS)
    b1 = int(0.60 * total)
    b2 = int(0.85 * total)
    lr1, lr2, lr3 = 3e-4, 1e-4, 5e-5

    def lr_schedule(progress_remaining: float) -> float:
        """Map the fraction of training remaining to the current learning rate."""
        step_done = int((1.0 - progress_remaining) * total)
        for boundary, rate in ((b1, lr1), (b2, lr2)):
            if step_done < boundary:
                return rate
        return lr3

    # Full-observation architecture: larger network intended (per the original
    # note) for the ~400k-dimensional observation. Currently NOT passed to SAC;
    # the reduced settings below are used instead.
    # With limited resources this can be lowered to pi/qf=[768] or [512].
    policy_kwargs = {
        "net_arch": {
            "pi": [1024, 512],  # policy network hidden layers
            "qf": [1024, 512],  # Q-function network hidden layers
        }
    }
    # Full-scale batch size (also currently unused, see batch_size=64 below).
    batch_size = 1024 if DEVICE == "mps" else 512

    model = SAC(
        "MlpPolicy",
        vec_env,
        verbose=1,
        device="cpu",  # was DEVICE
        policy_kwargs={"net_arch": {"pi": [512], "qf": [512]}},  # minimal network
        # policy_kwargs=policy_kwargs,
        learning_rate=lr_schedule,
        buffer_size=5000,  # was max(300_000, TOTAL_STEPS)
        batch_size=64,  # was batch_size
        gamma=0.95,
        tau=0.01,
        train_freq=(1, "step"),
        gradient_steps=1,
        ent_coef="auto",
        tensorboard_log=str(tb_dir),
        seed=SEED,
    )

    # Periodic evaluation on the multi-region eval env. Frequencies are divided
    # by N_ENVS because callbacks count vectorized steps, not individual env steps.
    eval_cb = EvalCallback(
        eval_env,
        best_model_save_path=str(run_dir / "best"),
        log_path=str(run_dir / "eval"),
        eval_freq=max(1, EVAL_FREQ // max(1, N_ENVS)),
        # The eval vec env holds one env per region. With the default of 5
        # episodes, SB3 spreads episodes unevenly over the sub-envs and some
        # regions would receive none; request one episode per region instead.
        n_eval_episodes=len(TRAIN_REGION_KEYS),
        deterministic=True,
        render=False,
    )

    # Periodic checkpoints of the model under <run_dir>/ckpt.
    ckpt_cb = CheckpointCallback(
        save_freq=max(1, CKPT_FREQ // max(1, N_ENVS)),
        save_path=str(run_dir / "ckpt"),
        name_prefix="sac",
    )

    # Custom environment/training scalars for TensorBoard.
    completion_cb = CompletionTBCallback(tb_every=200)

    print(f"[Train] Starting training for {TOTAL_STEPS} steps...")
    model.learn(
        total_timesteps=TOTAL_STEPS,
        callback=[eval_cb, ckpt_cb, completion_cb],
        progress_bar=True
    )

    # ---- Save artifacts ----
    # Persist the final policy together with the VecNormalize statistics it was
    # trained with.
    model_path = run_dir / "sac_model.zip"
    vecnorm_path = run_dir / "vecnormalize.pkl"
    model.save(str(model_path))
    vec_env.save(str(vecnorm_path))
    saved_models.append(model_path)

    print(f"Saved: {model_path}")
    print(f"VecNormalize stats: {vecnorm_path}")
    print(f"Best checkpoint dir: {run_dir / 'best'}")

    # =====================================================================
    # Post-training evaluation via evaluate.evaluate_region.
    # Per the original note, that function handles both the baseline path
    # (simplified) and the RL path (environment stepping).
    # =====================================================================

    rule = "=" * 70
    print("\n" + rule)
    print("开始训练后评估（使用新的baseline评估路径）")
    print(rule + "\n")

    # Hard check: the policy's action dimension must equal config.MAX_HOUSES.
    model_M = model.policy.action_space.shape[0]
    assert model_M == config.MAX_HOUSES, (
        f"Model action dim {model_M} != MAX_HOUSES {config.MAX_HOUSES}"
    )

    # Make sure the output directories exist.
    OUT_DIR = config.OUTPUT_DIR
    for out_name in ("tab", "fig"):
        (OUT_DIR / out_name).mkdir(parents=True, exist_ok=True)

    # TensorBoard writer dedicated to evaluation output.
    eval_tb_dir = OUT_DIR / "tb"
    eval_tb_dir.mkdir(parents=True, exist_ok=True)
    writer = SummaryWriter(log_dir=str(eval_tb_dir))

    # Best effort: load the observed data; on any failure continue with an
    # empty mapping so evaluation still runs (without RMSE metrics).
    try:
        observed = evaluate.load_observed(config.OBSERVED_DATA_PATH)
        print(f"[Eval] 成功加载观测数据，包含 {len(observed)} 个地区")
    except Exception as e:
        print(f"[Eval] 无法加载观测数据: {e}")
        print("[Eval] 将继续评估但不计算RMSE指标")
        observed = {}
    
    # Evaluate every region and collect the per-strategy metric dicts.
    all_metrics = []

    for region_key in TRAIN_REGION_KEYS:
        print(f"\n[Eval] Region = {region_key}")

        # Fall back to an empty float series when the region has no observed data.
        obs_series = observed.get(region_key, pd.Series(dtype=float))
        has_observations = obs_series.size > 0
        print(f"  观测数据: {len(obs_series)} 天" if has_observations else "  无观测数据")

        # Important: the training vec_env is passed as `vecnorm_src` so the
        # evaluation can reuse the VecNormalize environment used in training.
        region_metrics = evaluate.evaluate_region(
            region=region_key,
            obs_series=obs_series,
            ramp_fn=ramp_fn,
            model=model,
            writer=writer,
            vecnorm_src=vec_env,
        )
        all_metrics.extend(region_metrics)

        # Preview of this region's results, one line per strategy.
        for metric in region_metrics:
            print(
                f"    {metric['strategy']:>8s}: makespan={metric['makespan']:>6.1f}, "
                f"util={metric['utilization']:.3f}, final={metric['final_completion']:.3f}"
            )
    
    # Convert the collected metrics to a DataFrame and save them as CSV.
    df = pd.DataFrame(all_metrics)
    df.sort_values(["region", "strategy"], inplace=True)

    metrics_path = config.TAB_DIR / "metrics_eval.csv"
    df.to_csv(metrics_path, index=False)
    print(f"\n[Eval] 评估指标已保存到: {metrics_path}")

    # Build the plain-text summary report: a header, then one section per
    # region with one line per strategy.
    summary_lines = [
        "="*70,
        "训练后评估汇总",
        "="*70,
        "",
    ]

    for region in df["region"].unique():
        summary_lines.append(f"\n地区: {region}")
        summary_lines.append("-" * 50)
        sub = df[df["region"] == region]

        for _, row in sub.iterrows():
            line = (
                f"  {row['strategy']:>8s} | "
                f"makespan={row['makespan']:>6.1f} | "
                f"util={row['utilization']:.3f} | "
                f"final={row['final_completion']:.3f}"
            )

            # The aligned RMSE is appended only when the column exists and has a value.
            if 'rmse_aligned' in row and pd.notna(row['rmse_aligned']):
                line += f" | rmse_aligned={row['rmse_aligned']:.4f}"

            summary_lines.append(line)

    summary_lines.append("\n" + "="*70)

    summary_path = config.TAB_DIR / "evaluation_summary.txt"
    # The report contains non-ASCII text, so the encoding is given explicitly:
    # relying on the platform default raises UnicodeEncodeError on systems whose
    # locale encoding cannot represent these characters.
    summary_path.write_text("\n".join(summary_lines), encoding="utf-8")
    print(f"[Eval] 汇总报告已保存到: {summary_path}")
    
    # Strategy comparison chart.
    print("\n[Eval] 生成策略对比可视化...")

    # Mean makespan per strategy across all regions, ascending.
    strategy_means = df.groupby("strategy")["makespan"].mean().sort_values()

    fig, ax = plt.subplots(figsize=(10, 6))
    strategy_means.plot.bar(ax=ax)
    ax.set_ylabel("Average Makespan (days)")
    ax.set_title("平均Makespan对比（所有地区）")
    ax.set_xlabel("策略")
    ax.grid(axis="y", alpha=0.3)
    fig.tight_layout()

    overall_fig_path = config.FIG_DIR / "overall_makespan_comparison.png"
    fig.savefig(overall_fig_path, dpi=200)
    plt.close(fig)
    print(f"  保存: {overall_fig_path}")
    
    # Close the TensorBoard writer used for evaluation output.
    writer.close()

    # ---- Cleanup ----
    # Close each env on its own so that a failure while closing the training
    # env cannot skip closing the eval env, and report failures instead of
    # silently discarding them. Cleanup remains best effort (never raises).
    for env_label, env_obj in (("vec_env", vec_env), ("eval_env", eval_env)):
        try:
            env_obj.close()
        except Exception as exc:
            print(f"[WARN] Failed to close {env_label}: {exc!r}")

    print(f"\n{'='*70}")
    print(f"种子 {SEED} 的训练和评估完成！")
    print(f"{'='*70}\n")

# =================== Final summary ===================
# NOTE: config.OUTPUT_DIR / TAB_DIR / FIG_DIR are overwritten for each seed in
# the loop above, so the locations printed here are those of the last seed.
banner = "=" * 70

print(f"\n{banner}")
print("所有种子完成！")
print(banner)

print("\n已保存的模型:")
for path in saved_models:
    print("  - {}".format(path))

print("\n评估报告保存位置: {}".format(config.OUTPUT_DIR))
print("  - 指标CSV: {}".format(config.TAB_DIR / "metrics_eval.csv"))
print("  - 汇总文本: {}".format(config.TAB_DIR / "evaluation_summary.txt"))
print("  - 图表: {}".format(config.FIG_DIR))
print("  - TensorBoard日志: {}".format(config.OUTPUT_DIR / "tb"))

print(f"\n{banner}")
print("训练和评估流程全部完成！")
print(banner)