# In [None]:
# ====================== MAIN: Train + Evaluate(7 regions) + Report ======================
import os
import time
import json
import random
import multiprocessing as mp
from pathlib import Path
from importlib import reload

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt

# ---- Project modules (不要 reload(config)，避免 OUTPUT_DIR 时间戳漂移)
import config
import evaluate
import housegymrl
from evaluate import create_unified_ramp, create_tasks_from_real_config
from housegymrl import RLEnv, BaselineEnv  # 需已按我们约定实现：两个环境类

# ---- SB3
from stable_baselines3 import SAC
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv, VecMonitor, VecNormalize
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback, BaseCallback
from stable_baselines3.common.utils import set_random_seed

# Hot-reload the non-config project modules. config is deliberately NOT reloaded:
# that would shift the timestamp baked into config.OUTPUT_DIR.
reload(evaluate)
reload(housegymrl)
# reload() does not update names bound earlier with `from module import name`.
# Without re-binding, RLEnv / BaselineEnv and the evaluate helpers would keep
# pointing at the pre-reload objects, and code edits would be silently ignored.
from evaluate import create_unified_ramp, create_tasks_from_real_config
from housegymrl import RLEnv, BaselineEnv

# ----------------- Global parameters -----------------
SEEDS = [42]
N_ENVS = 8                 # number of parallel training envs
TOTAL_STEPS = 300_000
EVAL_FREQ = 10_000         # in env timesteps
CKPT_FREQ = 50_000         # in env timesteps
TRAIN_REGION_KEYS = list(config.REGION_CONFIG.keys())  # train (and evaluate) on every region

# Device preference: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Device:", DEVICE)

ROOT_RUNS = Path("runs")
ROOT_RUNS.mkdir(exist_ok=True, parents=True)

# Use "spawn" for worker processes (SubprocVecEnv is also created with start_method="spawn").
try:
    mp.set_start_method("spawn", force=True)
except RuntimeError:
    pass

# ----------------- TensorBoard 记录器 -----------------
class CompletionTBCallback(BaseCallback):
    """SB3 callback that logs environment progress statistics to TensorBoard.

    On every step it logs episode-end statistics for sub-envs whose ``info``
    has ``done is True`` (count, and the minimum ``info["day"]`` as makespan).

    Every ``tb_every`` timesteps it also logs the following across the
    sub-envs' infos (``tb_every <= 0`` means every step):
      * distributions of ``completion``, ``idle_workers``,
        ``allocated_workers`` and ``K_effective``;
      * the actor learning rate and the SAC entropy coefficient, when present.
    """

    def __init__(self, tb_every: int = 200, verbose: int = 0):
        super().__init__(verbose)
        self.tb_every = int(tb_every)
        # num_timesteps at the most recent periodic log.
        self._last_tb_step = 0

    def _on_step(self) -> bool:
        infos = [i for i in (self.locals.get("infos", None) or []) if isinstance(i, dict)]

        # Episode-level stats: the env writes info['day'] when it sets done=True.
        finished = [i for i in infos if i.get("done") is True]
        if finished:
            self.logger.record("env/episodes_finished", float(len(finished)))
            # Skip missing days so np.min never sees None.
            days = [i["day"] for i in finished if i.get("day") is not None]
            if days:
                self.logger.record("env/makespan_episode", float(np.min(days)))

        if self.tb_every > 0:
            # Gate on "timesteps since the last log" rather than
            # `num_timesteps % tb_every`. num_timesteps advances by n_envs per
            # call, so the modulo test only fires at multiples of
            # lcm(n_envs, tb_every).
            if self.num_timesteps - self._last_tb_step < self.tb_every:
                return True

        vals = [i["completion"] for i in infos if "completion" in i]
        if not vals:
            return True
        self._last_tb_step = self.num_timesteps

        idle_vals = [i.get("idle_workers", 0) for i in infos]
        K_vals = [i.get("K_effective", 0) for i in infos]
        alloc_vals = [i.get("allocated_workers", 0) for i in infos]
        util_vals = [a / max(1, k) for a, k in zip(alloc_vals, K_vals)] if K_vals else []

        v = np.asarray(vals, dtype=float)
        self.logger.record("env/completion_household_mean", float(np.nanmean(v)))
        self.logger.record("env/completion_household_min", float(np.nanmin(v)))
        self.logger.record("env/completion_household_max", float(np.nanmax(v)))
        self.logger.record("env/completion_household_p50", float(np.nanpercentile(v, 50)))
        self.logger.record("env/completion_household_p90", float(np.nanpercentile(v, 90)))

        if idle_vals:
            self.logger.record("env/idle_workers_mean", float(np.mean(idle_vals)))
        if util_vals:
            self.logger.record("env/utilization_mean", float(np.mean(util_vals)))
        if alloc_vals:
            alloc_arr = np.asarray(alloc_vals, dtype=float)
            self.logger.record("env/allocated_workers_p50", float(np.percentile(alloc_arr, 50)))
            self.logger.record("env/allocated_workers_p90", float(np.percentile(alloc_arr, 90)))
        if K_vals:
            K_arr = np.asarray(K_vals, dtype=float)
            self.logger.record("env/K_effective_p50", float(np.percentile(K_arr, 50)))
            self.logger.record("env/K_effective_p90", float(np.percentile(K_arr, 90)))

        actor = getattr(self.model, "actor", None)
        optimizer = getattr(actor, "optimizer", None)
        if optimizer is not None:
            self.logger.record("train/lr", float(optimizer.param_groups[0]["lr"]))
        # SB3's SAC sets log_ent_coef to None when ent_coef is a fixed number
        # (only ent_coef="auto..." creates the learnable tensor), so hasattr()
        # alone is not enough to guard the .exp() call.
        log_ent_coef = getattr(self.model, "log_ent_coef", None)
        if log_ent_coef is not None:
            self.logger.record("train/ent_coef", float(log_ent_coef.exp().item()))
        return True

# ----------------- Ramp -----------------
# Workforce ramp passed as k_ramp to both the training and the evaluation envs.
ramp_fn = create_unified_ramp()

# ----------------- Training scenario sampler -----------------
# Opt-in switch (default False). When config.USE_SYNTHETIC_TRAIN is truthy,
# make_sampler builds training tasks with evaluate.create_tasks_from_synthetic,
# provided that function exists.
USE_SYNTHETIC_TRAIN = getattr(config, "USE_SYNTHETIC_TRAIN", False)

def make_sampler(seed: int):
    """Return a zero-argument scenario sampler for training.

    Each call picks a region uniformly from TRAIN_REGION_KEYS using an RNG
    seeded once with ``seed``. It then builds that region's task table and
    returns ``(tasks_df, resources, meta)``:
      * ``resources`` is ``{"workers": num_contractors, "region_name": key}``;
      * ``meta`` is ``{"region": key}``.

    Synthetic tasks are used when USE_SYNTHETIC_TRAIN is set and
    ``evaluate.create_tasks_from_synthetic`` exists. Otherwise
    ``create_tasks_from_real_config`` is used.
    """
    generator = np.random.default_rng(seed)

    def sampler():
        key = generator.choice(TRAIN_REGION_KEYS)
        region_cfg = config.REGION_CONFIG[key]
        use_synthetic = USE_SYNTHETIC_TRAIN and hasattr(evaluate, "create_tasks_from_synthetic")
        build_tasks = (
            evaluate.create_tasks_from_synthetic if use_synthetic else create_tasks_from_real_config
        )
        tasks = build_tasks(region_cfg, generator)
        workforce = {"workers": int(region_cfg["num_contractors"]), "region_name": key}
        meta = {"region": key}
        return tasks, workforce, meta

    return sampler


def make_train_env(seed: int):
    """Return a thunk that builds one Monitor-wrapped RLEnv for training.

    The scenario sampler is created once per make_train_env call, outside the
    thunk. The env is constructed with the same ``seed``, the shared
    ``ramp_fn`` and batch arrivals enabled. Candidate count and step limit
    are read from config when the thunk runs.
    """
    scenario_sampler = make_sampler(seed)

    def _init():
        return Monitor(
            RLEnv(
                scenario_sampler=scenario_sampler,
                M=config.M_CANDIDATES,
                max_steps=config.MAX_STEPS,
                seed=seed,
                k_ramp=ramp_fn,
                batch_arrival=True,
            )
        )

    return _init

def make_eval_env(region_key: str, seed: int, env_cls):
    """Return a thunk that builds a Monitor-wrapped evaluation env for ``region_key``.

    The env comes from ``evaluate.make_region_env`` with ``env_cls``, the
    shared ``ramp_fn`` and batch arrivals enabled.

    NOTE(review): ``seed`` is accepted but never used. It is not forwarded to
    make_region_env or to reset(), so the per-region eval seeds recorded in
    seed_manifest.json have no effect here. Confirm whether make_region_env
    accepts a seed.
    """
    def _build():
        region_env = evaluate.make_region_env(region_key, env_cls, k_ramp=ramp_fn, batch_arrival=True)
        return Monitor(region_env)

    return _build

# ----------------- 使用训练统计包裹单环境（推理） -----------------
def wrap_single_env_with_vecnorm(env, src_vecnorm):
    """Wrap one env in DummyVecEnv + VecNormalize using frozen stats from ``src_vecnorm``.

    The returned wrapper:
      * normalises observations with the training observation statistics and
        clip range;
      * never updates those statistics (training=False);
      * leaves rewards unnormalised.

    The statistics are deep-copied, as SB3's ``sync_envs_normalization``
    does, instead of being shared by reference. A later update of
    ``src_vecnorm`` (e.g. continued training) therefore cannot silently change
    the normalisation of an eval env that was already built.
    """
    import copy

    single = DummyVecEnv([lambda: env])
    wrapped = VecNormalize(
        single, norm_obs=True, norm_reward=False, training=False, clip_obs=src_vecnorm.clip_obs
    )
    wrapped.obs_rms = copy.deepcopy(src_vecnorm.obs_rms)
    wrapped.ret_rms = copy.deepcopy(src_vecnorm.ret_rms)
    wrapped.clip_obs = src_vecnorm.clip_obs
    wrapped.training = False
    wrapped.norm_reward = False
    return wrapped

# ----------------- Baseline 分配策略（整数动作） -----------------
def _baseline_order(view, policy: str, rng: random.Random):
    """Return the order in which a baseline policy visits candidate slots.

    Args:
        view: dict returned by ``env.last_candidate_view()``. Keys read:
            * ``idx``, ``cmax``, ``mask``;
            * optional ``damage_level`` (value 2 is treated as major damage);
            * remaining work from the first present of ``arr_rem`` /
              ``rem_days`` / ``remaining_days``. ``cmax`` is used as a proxy
              when none is present, and a constant when ``cmax`` is missing too.
        policy: one of
            * "FIFO": slot order;
            * "RANDOM": shuffled;
            * "SJF": least remaining work first;
            * "LJF": major-damage slots first by most remaining work, the
              rest in random order.
            Any other value falls back to FIFO.
        rng: random source for RANDOM and for the non-major tail of LJF.

    Returns:
        np.ndarray of slot indices. When both ``idx`` and ``mask`` are given,
        only slots with ``mask > 0`` and ``idx >= 0`` are included. When
        ``idx`` is None the result is empty.
    """
    idx = view.get("idx")
    cmax = view.get("cmax")
    mask = view.get("mask")

    if idx is not None and mask is not None:
        order = np.flatnonzero((np.asarray(mask) > 0) & (np.asarray(idx) >= 0))
    else:
        order = np.arange(len(idx) if idx is not None else 0)

    # Slate length: the constant fallback below must be indexable by slot index.
    # The old np.ones_like(order) was only len(order) long, so it raised
    # IndexError as soon as any slot had been filtered out.
    n_slots = len(idx) if idx is not None else len(order)

    rem = None
    for key in ("arr_rem", "rem_days", "remaining_days"):
        # Explicit None checks: `or` on ndarrays is ambiguous.
        rem = view.get(key)
        if rem is not None:
            break
    if rem is None:
        rem = cmax if cmax is not None else np.ones(n_slots)

    def _remaining(fill):
        # Remaining work as floats, with NaN/inf replaced by `fill` so such slots sort last.
        values = np.asarray(rem, dtype=float)
        return np.where(np.isfinite(values), values, fill)

    if policy == "FIFO":
        return order

    if policy == "RANDOM":
        shuffled = order.copy()
        rng.shuffle(shuffled)
        return shuffled

    if policy == "SJF":
        remv = _remaining(np.inf)
        # Stable sort: ties keep slot (FIFO) order.
        return order[np.argsort(remv[order], kind="stable")]

    if policy == "LJF":
        remv = _remaining(-np.inf)
        damage = view.get("damage_level", None)
        if damage is None:
            # No damage information: everything by descending remaining work.
            return order[np.argsort(-remv[order], kind="stable")]
        damage = np.asarray(damage)
        is_major = damage[order] == 2
        major, others = order[is_major], order[~is_major]
        major_sorted = major[np.argsort(-remv[major], kind="stable")]
        others_shuffled = others.copy()
        rng.shuffle(others_shuffled)
        return np.concatenate([major_sorted, others_shuffled])

    # Unknown policy: fall back to FIFO order.
    return order

def _baseline_alloc(env, policy: str, rng: random.Random):
    """Greedy integer worker allocation for a baseline policy (BaselineEnv action).

    Slots are visited in the order given by ``_baseline_order``. Each slot
    receives ``min(cmax[slot], remaining workers)`` until the effective
    workforce is exhausted.

    The effective workforce is ``int(env.k_ramp(env.day) * env.K)`` when the
    env has a truthy ``k_ramp``, else ``env.K``.

    Returns an int32 array shaped like ``view["idx"]``.
    """
    view = env.last_candidate_view()
    idx, cmax = view["idx"], view["cmax"]
    order = _baseline_order(view, policy, rng)

    ramp = getattr(env, "k_ramp", None)
    remaining = int(ramp(env.day) * env.K) if ramp else env.K

    alloc = np.zeros_like(idx, dtype=np.int32)
    for slot in order:
        if remaining <= 0:
            break
        cap = float(cmax[slot])
        # Skip NaN or non-positive caps. A negative cap would make `give`
        # negative and *increase* the remaining workforce; NaN would crash int().
        if np.isnan(cap) or cap <= 0:
            continue
        give = int(min(cap, remaining))
        alloc[slot] = give
        remaining -= give
    return alloc

# ----------------- 指标函数 -----------------
def compute_auc(curve: np.ndarray, T: int) -> float:
    """Normalised area under a completion curve over the first ``T`` days.

    Args:
        curve: per-day completion fractions (index 0 is the first simulated
            day), e.g. as returned by rollout_region. Lists are accepted.
        T: horizon in days.

    Returns:
        The trapezoidal integral (unit spacing) of the first ``T`` points,
        divided by ``T`` and clipped to [0, 1]. Returns 0.0 when ``T <= 0`` or
        the curve is empty.

    The rollout stops as soon as the episode is done, so a curve may end
    before ``T``. It is then extended with its last value, on the assumption
    that completion stays at its final level. Without the padding, a policy
    that finishes early was integrated over fewer days and scored *lower*
    than a slower one. Results for curves with at least ``T`` points are
    unchanged.
    """
    T = int(T)
    values = np.asarray(curve, dtype=float)
    if T <= 0 or values.size == 0:
        return 0.0
    window = values[:T]
    if window.size < T:
        window = np.concatenate([window, np.full(T - window.size, window[-1])])
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid.
    integrate = getattr(np, "trapezoid", None) or np.trapz
    return float(np.clip(integrate(window, dx=1) / T, 0.0, 1.0))

def compute_t_percentile(curve: np.ndarray, p: float) -> int:
    if curve.size == 0: return np.inf
    idx = np.where(curve >= p)[0]
    return int(idx[0]) if idx.size > 0 else np.inf

def compute_utilization(idle_list, K_total) -> float:
    """Average busy fraction of a constant workforce of size ``K_total``.

    Computes ``1 - mean(idle_list) / K_total`` clipped to [0, 1]. Returns 0.0
    when ``idle_list`` is empty or ``K_total <= 0``.
    """
    if idle_list and not K_total <= 0:
        busy_share = 1.0 - np.mean(idle_list) / K_total
        return float(np.clip(busy_share, 0.0, 1.0))
    return 0.0

def rmse_pair(a: np.ndarray, b: np.ndarray) -> float:
    """RMSE between two day-indexed curves over their common prefix.

    The curves are compared position by position over the first
    ``min(len(a), len(b))`` entries. Positions where either value is
    non-finite are skipped. Returns NaN when there is no overlap or no
    position with both values finite. Lists are accepted.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    n = min(len(a), len(b))
    if n == 0:
        return np.nan
    aa, bb = a[:n], b[:n]
    # Mask both sides. Previously only `b` was masked, so any NaN/inf in the
    # simulated curve turned the whole RMSE into NaN.
    keep = np.isfinite(aa) & np.isfinite(bb)
    if not keep.any():
        return np.nan
    diff = aa[keep] - bb[keep]
    return float(np.sqrt(np.mean(diff ** 2)))

# ----------------- 新增：逐日 K_effective 的利用率（与训练端口径一致） -----------------
def compute_utilization_from_series(idle_list, k_eff_list, fallback_K: int | None = None) -> float:
    if not idle_list or not k_eff_list:
        return 0.0 if fallback_K is None else compute_utilization(idle_list, fallback_K)
    n = min(len(idle_list), len(k_eff_list))
    if n == 0:
        return 0.0 if fallback_K is None else compute_utilization(idle_list, fallback_K)
    # 若全部 k_eff<=0，回退到常数 K 口径（保证结果稳定而不是硬 0）
    if all((int(k) if np.isfinite(k) else 0) <= 0 for k in k_eff_list[:n]):
        return 0.0 if fallback_K is None else compute_utilization(idle_list[:n], fallback_K)
    vals = []
    for i in range(n):
        k = int(k_eff_list[i]) if np.isfinite(k_eff_list[i]) else 0
        k = max(1, k)
        vals.append(1.0 - (float(idle_list[i]) / k))
    return float(np.clip(np.mean(vals), 0.0, 1.0))


# ----------------- rollout -----------------
def _resolve_base_env(env, vec_mode: bool):
    """Return the single (non-vectorised) env that baseline policies should query.

    Baselines call ``last_candidate_view()`` and read ``k_ramp``, ``day`` and
    ``K`` directly, which only the underlying env provides. For a VecEnv
    stack (e.g. VecNormalize -> DummyVecEnv -> Monitor -> env) that means
    descending into ``envs[0]``. SB3 VecEnvWrappers forward unknown
    attributes to the wrapped VecEnv, so ``envs`` is reachable from the
    outermost wrapper. ``env.unwrapped`` on a VecEnv only yields the
    DummyVecEnv, which lacks those methods.
    """
    if vec_mode:
        sub_envs = getattr(env, "envs", None)
        if sub_envs:
            first = sub_envs[0]
            return getattr(first, "unwrapped", first)
    return getattr(env, "unwrapped", env)


def rollout_region(env, model=None, policy="SAC", max_days=2000, seed=0):
    """Run one episode and return its per-day trajectory.

    Supports two env APIs:
    - Gymnasium single env: reset() -> (obs, info);
      step() -> (obs, reward, terminated, truncated, info).
    - SB3 VecEnv (DummyVecEnv / VecNormalize): reset() -> obs;
      step() -> (obs, rewards, dones, infos). Only the first sub-env is used.

    ``policy == "SAC"`` uses ``model.predict(..., deterministic=True)``. Any
    other value is passed to ``_baseline_alloc`` as a baseline name, with a
    ``random.Random(seed)`` source.

    The loop stops at the first of:
      * ``info["done"] is True`` — done_day is ``info["day"]``;
      * env termination/truncation;
      * ``max_days`` steps.
    In the last two cases done_day falls back to ``len(curve) - 1``.

    Returns:
        (completion curve as float ndarray, idle_workers list,
         K_effective list, done_day as int)
    """
    rng = random.Random(seed)
    vec_mode = hasattr(env, "num_envs") or hasattr(env, "venv")

    reset_out = env.reset()
    if isinstance(reset_out, tuple) and len(reset_out) == 2:
        obs, info = reset_out
    else:
        obs, info = reset_out, {}

    curve, idle_hist, k_eff_hist = [], [], []
    done_day = None

    def _record(step_info):
        curve.append(step_info.get("completion", 0.0))
        idle_hist.append(step_info.get("idle_workers", 0))
        k_eff_hist.append(step_info.get("K_effective", 0))

    for _ in range(max_days):
        if policy == "SAC":
            act, _ = model.predict(obs, deterministic=True)
            # A VecEnv expects a batch dimension.
            if vec_mode and isinstance(act, np.ndarray) and act.ndim == 1:
                act = act[None, :]
        else:
            act = _baseline_alloc(_resolve_base_env(env, vec_mode), policy, rng)
            if vec_mode:
                act = np.asarray(act)[None, :]

        if vec_mode:
            obs, _rewards, dones, infos = env.step(act)  # obs keeps its (1, obs_dim) batch shape
            info = infos[0] if isinstance(infos, (list, tuple)) and len(infos) else {}
            _record(info)
            if info.get("done") is True or bool(np.asarray(dones)[0]):
                done_day = info.get("day", len(curve) - 1)
                break
            continue

        obs, _reward, terminated, truncated, info = env.step(act)
        _record(info)
        if info.get("done") is True:
            done_day = info.get("day")
            break
        if terminated or truncated:
            break

    done_day = done_day if done_day is not None else len(curve) - 1
    return np.array(curve, dtype=float), idle_hist, k_eff_hist, int(done_day)


# ----------------- [DEBUG FIX] 环境属性安全读取工具 -----------------
def _safe_total_houses(e):
    """Best-effort lookup of the number of houses/tasks in an env.

    Works on ``e.unwrapped`` when present, otherwise on ``e``.
      1. Scalar attributes are tried in a fixed order; the first that
         ``int()`` accepts wins.
      2. Otherwise, table-like attributes are tried; a ``.shape[0]`` is used
         when available, else ``len()``.
      3. Otherwise NaN is returned, which later CSV output tolerates.
    """
    base = getattr(e, "unwrapped", e)

    scalar_names = ("H_total", "n_tasks", "num_tasks", "num_houses", "n_houses", "N", "H", "total_houses")
    for attr in scalar_names:
        if not hasattr(base, attr):
            continue
        try:
            return int(getattr(base, attr))
        except Exception:
            continue

    table_names = ("tasks", "houses", "tasks_df", "houses_df")
    for attr in table_names:
        if not hasattr(base, attr):
            continue
        table = getattr(base, attr)
        try:
            return int(table.shape[0] if hasattr(table, "shape") else len(table))
        except Exception:
            continue

    return np.nan

def _safe_K(e):
    """Best-effort lookup of an env's base workforce size.

    Works on ``e.unwrapped`` when present, otherwise on ``e``.
      1. Tries the attributes K, K_base, num_workers, num_contractors and
         workers in that order; the first one ``int()`` accepts wins.
      2. Otherwise uses ``resources["workers"]`` when ``resources`` is a dict.
      3. Otherwise returns 0.
    """
    base = getattr(e, "unwrapped", e)

    for attr in ("K", "K_base", "num_workers", "num_contractors", "workers"):
        if not hasattr(base, attr):
            continue
        try:
            return int(getattr(base, attr))
        except Exception:
            continue

    resources = getattr(base, "resources", None)
    if isinstance(resources, dict) and "workers" in resources:
        try:
            return int(resources["workers"])
        except Exception:
            pass

    return 0

# ----------------- 新增：评估时核对户数与配置一致性（只告警） -----------------
def _warn_h_total_mismatch(region_key: str, H_env: int):
    """Print a warning when an env's house count differs from its region config.

    The expected count is ``sum(config.REGION_CONFIG[region_key]["damage_dist"])``.
    Nothing is printed when ``H_env`` is NaN (house count unknown) or the
    expected count is not positive. Any error during the check is swallowed,
    because this is a best-effort diagnostic only.
    """
    try:
        region_cfg = config.REGION_CONFIG[region_key]
        expected = int(np.sum(region_cfg.get("damage_dist", [])))
        is_known = H_env == H_env  # False only for NaN
        if is_known and expected > 0 and H_env != expected:
            print(f"[WARN] {region_key}: env houses={H_env} != sum(config.damage_dist)={expected}")
    except Exception:
        pass



# ----------------- 单 region 评估（含 observed 曲线的 RMSE，若可用） -----------------
def _load_observed_curve(region_key):
    """Load the observed completion curve for ``region_key``, if one is available.

    Uses ``evaluate.process_observed_with_nan`` when that function exists. A
    dict result is unpacked via its first present key among "curve",
    "completion" and "y". Returns a float ndarray of at least one dimension,
    or None when:
      * the hook is missing;
      * it returns None or an empty curve;
      * it fails (a warning is printed in that case).
    """
    if not hasattr(evaluate, "process_observed_with_nan"):
        return None
    try:
        raw = evaluate.process_observed_with_nan(region_key)
        if isinstance(raw, dict):
            for key in ("curve", "completion", "y"):
                if key in raw:
                    raw = raw[key]
                    break
        if raw is None:
            # np.asarray(None, dtype=float) would be a 0-d NaN array, on which
            # len() raises outside any try block.
            return None
        curve = np.atleast_1d(np.asarray(raw, dtype=float))
    except Exception as e:
        print(f"[WARN] observed load failed for {region_key}: {e}")
        return None
    return curve if len(curve) > 0 else None


def evaluate_region(region_key, model, vecnorm_src, max_days=2000, seed=1234):
    """Evaluate the trained SAC policy and the SJF/LJF/RANDOM baselines on one region.

    * SAC runs on RLEnv, normalised with the training VecNormalize statistics
      from ``vecnorm_src``.
    * Each baseline runs on a fresh BaselineEnv built with ``seed + 111``.
    * If an observed curve is available, RMSE metrics against it are included.

    Every env is closed after its rollout.

    Returns:
        list of dicts, one per strategy, with keys:
        region, strategy, makespan, auc@200, auc@300, t80, t90, t95,
        utilization, final_completion, rmse_aligned, rmse_overlap, H_total,
        K, obs_days, obs_final.
    """
    result_rows = []

    # 1) RL policy (RLEnv + training VecNormalize stats)
    rl_env = make_eval_env(region_key, seed=seed, env_cls=RLEnv)()
    rl_env_v = wrap_single_env_with_vecnorm(rl_env, vecnorm_src)
    try:
        curve, idle, k_eff, done_day = rollout_region(
            rl_env_v, model=model, policy="SAC", max_days=max_days, seed=seed
        )
        # NOTE(review): an SB3 DummyVecEnv resets its env when the episode ends,
        # so these attributes may be read after that reset. That is fine if K
        # and the house count are fixed per region — confirm.
        rl_K = _safe_K(rl_env)
        rl_H = _safe_total_houses(rl_env)
    finally:
        rl_env_v.close()
    _warn_h_total_mismatch(region_key, rl_H)
    result_rows.append(("SAC", curve, idle, k_eff, done_day, rl_K, rl_H))

    # 2) Baselines: SJF / LJF / RANDOM, each on a fresh env with the same seed
    for pol in ("SJF", "LJF", "RANDOM"):
        bl_env = make_eval_env(region_key, seed=seed + 111, env_cls=BaselineEnv)()
        try:
            c2, idle2, k2, done2 = rollout_region(
                bl_env, model=None, policy=pol, max_days=max_days, seed=seed + 111
            )
            bl_K = _safe_K(bl_env)
            bl_H = _safe_total_houses(bl_env)
        finally:
            bl_env.close()
        _warn_h_total_mismatch(region_key, bl_H)
        result_rows.append((pol, c2, idle2, k2, done2, bl_K, bl_H))

    # 3) Observed curve for RMSE, if evaluate.py provides one
    obs_curve = _load_observed_curve(region_key)

    # 4) Metrics
    out = []
    for policy, c, idle_hist, k_eff_hist, done, K, H_total in result_rows:
        row = {
            "region": region_key,
            "strategy": policy,
            "makespan": int(done),
            "auc@200": compute_auc(c, 200),
            "auc@300": compute_auc(c, 300),
            "t80": compute_t_percentile(c, 0.80),
            "t90": compute_t_percentile(c, 0.90),
            "t95": compute_t_percentile(c, 0.95),
            "utilization": compute_utilization_from_series(idle_hist, k_eff_hist, fallback_K=K),
            "final_completion": float(c[-1]) if len(c) > 0 else 0.0,
            "rmse_aligned": np.nan,
            "rmse_overlap": np.nan,
            "H_total": int(H_total) if (H_total == H_total) else np.nan,  # NaN guard
            "K": int(K),
            "obs_days": np.nan,
            "obs_final": np.nan,
        }
        if obs_curve is not None:
            finite = np.isfinite(obs_curve)
            # Day-aligned comparison: NaN observation days are skipped in place.
            row["rmse_aligned"] = rmse_pair(c, obs_curve)
            if finite.any():
                # NOTE(review): obs_curve[finite] drops the NaN days and re-packs the
                # remaining observations from index 0. With leading or interior NaN
                # gaps they are then compared with the wrong simulated days. Kept
                # as-is because the intended meaning of "overlap" is not visible
                # here — confirm.
                row["rmse_overlap"] = rmse_pair(c, obs_curve[finite])
                row["obs_final"] = float(obs_curve[finite][-1])
            row["obs_days"] = int(np.sum(finite))
        out.append(row)
    return out

# ----------------- 报表输出 -----------------
def write_summary_and_plots(df: pd.DataFrame, OUT_DIR: Path):
    """Write evaluation tables and figures for ``df`` under ``OUT_DIR``.

    Outputs:
      - tab/metrics_eval.csv: ``df`` as-is (one row per region x strategy).
      - tab/summary_wide.csv: one row per region, with columns
        ``<metric>__<strategy>``.
      - fig/overall_{makespan,rmse_aligned,rmse_overlap}.png: per-strategy means.
      - fig/<first region>_strategy_compare.png: strategy comparison for the
        first region.

    Finally it shows the per-region makespan table with the best (lowest
    makespan) strategy. ``display`` is used inside IPython/Jupyter; ``print``
    is used elsewhere.
    """
    TAB_DIR = OUT_DIR / "tab"
    FIG_DIR = OUT_DIR / "fig"
    TAB_DIR.mkdir(parents=True, exist_ok=True)
    FIG_DIR.mkdir(parents=True, exist_ok=True)

    metrics_path = TAB_DIR / "metrics_eval.csv"
    df.to_csv(metrics_path, index=False)
    print("Saved:", metrics_path)

    agg_cols = ["makespan", "auc@200", "auc@300", "t80", "t90", "t95", "utilization",
                "rmse_aligned", "rmse_overlap", "final_completion"]
    # dropna=False keeps all-NaN metric columns (e.g. RMSE when no observed data
    # exists), so summary_wide.csv has the same columns on every run.
    pivot_table = df.pivot_table(index="region", columns="strategy", values=agg_cols,
                                 aggfunc="first", dropna=False)
    pivot_flat = pivot_table.copy()
    pivot_flat.columns = [f"{m}__{s}" for m, s in pivot_flat.columns]
    pivot_flat = pivot_flat.reset_index()
    summary_csv = TAB_DIR / "summary_wide.csv"
    pivot_flat.to_csv(summary_csv, index=False)
    print("Saved:", summary_csv)

    def barplot_metric(metric, ylabel, save_name):
        # Mean of `metric` per strategy across all regions, ascending.
        plt.figure(figsize=(10, 6))
        means = df.groupby("strategy")[metric].mean().sort_values(ascending=True)
        means.plot.bar()
        plt.ylabel(ylabel)
        plt.title(f"Overall {metric} by Strategy")
        plt.grid(axis="y", alpha=0.3)
        plt.tight_layout()
        plt.savefig(FIG_DIR / save_name, dpi=200)
        plt.close()

    barplot_metric("makespan", "Days", "overall_makespan.png")
    barplot_metric("rmse_aligned", "RMSE", "overall_rmse_aligned.png")
    barplot_metric("rmse_overlap", "RMSE", "overall_rmse_overlap.png")
    print("Saved figures:",
          FIG_DIR / "overall_makespan.png",
          FIG_DIR / "overall_rmse_aligned.png",
          FIG_DIR / "overall_rmse_overlap.png")

    # Per-strategy comparison for the first region only
    uniq = df["region"].dropna().unique()
    if len(uniq):
        first_region = uniq[0]
        sub = df[df["region"] == first_region].set_index("strategy")
        metrics = ["makespan", "auc@200", "utilization", "rmse_aligned", "rmse_overlap"]
        lower_is_better = {"makespan", "rmse_aligned", "rmse_overlap"}
        plt.figure(figsize=(12, 6))
        for i, m in enumerate(metrics, 1):
            plt.subplot(1, len(metrics), i)
            sub[m].sort_values(ascending=(m in lower_is_better)).plot.bar()
            plt.title(m)
            plt.xticks(rotation=45, ha="right")
            plt.grid(axis="y", alpha=0.3)
        plt.tight_layout()
        outp = FIG_DIR / f"{first_region}_strategy_compare.png"
        plt.savefig(outp, dpi=200)
        plt.close()
        print("Saved region figure:", outp)

    # Leaderboard: strategy with the lowest makespan per region
    rank_df = df.pivot_table(index="region", columns="strategy", values="makespan", aggfunc="first")
    if not rank_df.empty:
        rank_df["best_strategy"] = rank_df.idxmin(axis=1)
        try:
            display(rank_df.sort_index())  # IPython/Jupyter builtin
        except NameError:
            # Plain Python: `display` is not defined outside IPython.
            print(rank_df.sort_index().to_string())

# =================== Train + Eval ===================
# For every seed:
#   1. train SAC on all regions;
#   2. save the model and VecNormalize statistics;
#   3. evaluate the final model and the baselines on every region;
#   4. write the metric tables and figures.
# Environments are closed in a `finally` block, so SubprocVecEnv workers do not
# outlive an exception or an interrupted training run.
saved_models = []
for SEED in SEEDS:
    print(f"========== Training seed={SEED} ==========")
    set_random_seed(SEED)

    ts_tag = time.strftime("%Y-%m-%d_%H-%M-%S")
    run_dir = ROOT_RUNS / f"sac_seed{SEED}_{ts_tag}"
    tb_dir  = run_dir / "tb"
    run_dir.mkdir(parents=True, exist_ok=True)

    # Assigned inside the try below; tracked here so `finally` can close
    # whatever was actually built.
    vec_env = None
    eval_env = None
    try:
        # ---- Build train vec env ----
        env_fns = [make_train_env(SEED + i) for i in range(N_ENVS)]
        # Build and close one env in-process first, so construction errors
        # surface here instead of inside a worker subprocess.
        _tmp = env_fns[0]()
        _tmp.close()

        try:
            vec_env = SubprocVecEnv(env_fns, start_method="spawn")
        except Exception as exc:
            print("[WARN] SubprocVecEnv spawn failed, fallback to DummyVecEnv:", repr(exc))
            vec_env = DummyVecEnv(env_fns)
        vec_env = VecMonitor(vec_env)
        vec_env = VecNormalize(vec_env, norm_obs=True, norm_reward=True, clip_obs=10.0)

        # ---- Eval env for the callback: one RLEnv per region ----
        def make_eval_envs(seed: int):
            envs = [
                make_eval_env(region_key, seed + 9000 + i, env_cls=RLEnv)
                for i, region_key in enumerate(TRAIN_REGION_KEYS)
            ]
            vec = DummyVecEnv(envs)
            vec = VecMonitor(vec)
            vec = VecNormalize(vec, norm_obs=True, norm_reward=False, training=False, clip_obs=10.0)
            vec.training = False
            vec.norm_reward = False
            return vec

        eval_env = make_eval_envs(SEED)
        # Start from the training normalisation statistics
        # (EvalCallback re-syncs them before each evaluation).
        eval_env.obs_rms = vec_env.obs_rms
        eval_env.clip_obs = vec_env.clip_obs
        eval_env.training = False
        eval_env.norm_reward = False
        eval_env.ret_rms = vec_env.ret_rms

        # ---- manifest ----
        manifest = {
            "global_seed": SEED,
            "train_env_seeds": [SEED + i for i in range(N_ENVS)],
            "eval_seeds": {region: SEED + 9000 + idx for idx, region in enumerate(TRAIN_REGION_KEYS)},
            "train_regions": TRAIN_REGION_KEYS,
            "timestamp": ts_tag,
            "output_dir": str(config.OUTPUT_DIR),
            "tensorboard_dir": str(tb_dir),
        }
        (run_dir / "seed_manifest.json").write_text(json.dumps(manifest, indent=2))

        # ---- Policy / LR ----
        # Piecewise-constant learning rate: 3e-4, then 1e-4 after 60% of
        # training, then 5e-5 after 85%.
        total = int(TOTAL_STEPS)
        b1, b2 = int(0.60 * total), int(0.85 * total)
        lr1, lr2, lr3 = 3e-4, 1e-4, 5e-5

        def lr_schedule(progress_remaining: float) -> float:
            # SB3 passes progress_remaining, which goes from 1.0 (start) to 0.0 (end).
            step_done = int((1.0 - progress_remaining) * total)
            if step_done < b1:
                return lr1
            if step_done < b2:
                return lr2
            return lr3

        policy_kwargs = dict(net_arch=dict(pi=[512], qf=[512]))
        batch_size = 1024 if DEVICE == "mps" else 512
        model = SAC(
            "MlpPolicy",
            vec_env,
            verbose=1,
            device=DEVICE,
            policy_kwargs=policy_kwargs,
            learning_rate=lr_schedule,
            buffer_size=max(300_000, TOTAL_STEPS),
            batch_size=batch_size,
            gamma=0.95,
            tau=0.01,
            train_freq=(1, "step"),
            gradient_steps=1,
            ent_coef="auto",
            tensorboard_log=str(tb_dir),
            seed=SEED,
        )

        # Callback frequencies count callback calls; each call is N_ENVS timesteps.
        eval_cb = EvalCallback(
            eval_env,
            best_model_save_path=str(run_dir / "best"),
            log_path=str(run_dir / "eval"),
            eval_freq=max(1, EVAL_FREQ // max(1, N_ENVS)),
            deterministic=True,
            render=False,
        )
        ckpt_cb = CheckpointCallback(
            save_freq=max(1, CKPT_FREQ // max(1, N_ENVS)),
            save_path=str(run_dir / "ckpt"),
            name_prefix="sac",
        )
        completion_cb = CompletionTBCallback(tb_every=200)

        model.learn(total_timesteps=TOTAL_STEPS, callback=[eval_cb, ckpt_cb, completion_cb], progress_bar=True)

        # ---- Save artifacts ----
        model_path   = run_dir / "sac_model.zip"
        vecnorm_path = run_dir / "vecnormalize.pkl"
        model.save(str(model_path))
        vec_env.save(str(vecnorm_path))
        saved_models.append(model_path)
        print("Saved:", model_path)
        print("VecNormalize stats:", vecnorm_path)
        print("Best checkpoint dir:", run_dir / "best")

        # ---- After training: evaluate every region (RL + baselines) and write reports ----
        # Sanity check on the action dimension
        model_M = model.policy.action_space.shape[0]
        assert model_M == config.M_CANDIDATES, f"Model action dim {model_M} != M_CANDIDATES {config.M_CANDIDATES}"

        # With a single seed, write straight into config.OUTPUT_DIR (unchanged
        # layout). With several seeds, use one subdirectory per seed, so a later
        # seed does not overwrite an earlier seed's tables and figures.
        OUT_DIR = config.OUTPUT_DIR if len(SEEDS) == 1 else config.OUTPUT_DIR / f"seed{SEED}"
        (OUT_DIR / "tab").mkdir(parents=True, exist_ok=True)
        (OUT_DIR / "fig").mkdir(parents=True, exist_ok=True)

        all_rows = []
        for region_key in TRAIN_REGION_KEYS:
            print(f"[Eval] Region = {region_key}")
            rows = evaluate_region(region_key, model=model, vecnorm_src=vec_env, max_days=config.MAX_STEPS, seed=SEED+2024)
            all_rows.extend(rows)

        df = pd.DataFrame(all_rows)
        write_summary_and_plots(df, OUT_DIR)
    finally:
        # ---- Cleanup ----
        # Close each env independently, so a failing close() on one still
        # closes the other.
        for _env in (vec_env, eval_env):
            if _env is None:
                continue
            try:
                _env.close()
            except Exception as exc:
                print("[WARN] env close failed:", repr(exc))

print("All seeds finished. Saved models:")
for path in saved_models:
    print(" -", path)
print("Reports saved under:", config.OUTPUT_DIR)
# ====================== END MAIN ======================
