In [None]:
# ====================== MAIN: Train + Evaluate(7 regions) + Report ======================
import os
import time
import json
import random
import multiprocessing as mp
from pathlib import Path
from importlib import reload

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from IPython.display import display

# ---- Project modules（不要 reload(config)，避免 OUTPUT_DIR 时间戳漂移）
import config
import evaluate
import housegymrl
from evaluate import create_unified_ramp, create_tasks_from_real_config
from housegymrl import RLEnv, BaselineEnv  # 需已按约定实现：两个环境类

# ---- SB3
from stable_baselines3 import SAC
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv, VecMonitor, VecNormalize
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback, BaseCallback
from stable_baselines3.common.utils import set_random_seed

# Hot-reload the non-config project modules so edits made since the first
# import are picked up without restarting the interpreter.
reload(evaluate)
reload(housegymrl)
# reload() re-executes a module in place but does NOT rebind names that were
# pulled in earlier with ``from module import name``; re-import them so the
# code below uses the reloaded definitions rather than the stale ones.
# NOTE(review): if evaluate itself does ``from housegymrl import ...`` the
# reload order above leaves evaluate holding pre-reload classes -- confirm the
# dependency direction and swap the two reload() calls if so.
from evaluate import create_unified_ramp, create_tasks_from_real_config
from housegymrl import RLEnv, BaselineEnv

# ----------------- Global parameters -----------------
SEEDS = [42]
N_ENVS = 8
TOTAL_STEPS = 500_000
EVAL_FREQ = 10_000
CKPT_FREQ = 50_000
# Train on every region declared in the config.
TRAIN_REGION_KEYS = list(config.REGION_CONFIG.keys())

# Device preference order: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Device:", DEVICE)

# Root directory that receives one sub-directory per training run.
ROOT_RUNS = Path("runs")
ROOT_RUNS.mkdir(parents=True, exist_ok=True)

# Request the "spawn" start method; a RuntimeError from the call is ignored
# and execution simply continues.
try:
    mp.set_start_method("spawn", force=True)
except RuntimeError:
    pass

# ----------------- TensorBoard logger -----------------
class CompletionTBCallback(BaseCallback):
    """Record environment and optimiser statistics read from the step ``infos``.

    Episode-level values (``env/makespan_episode``, ``env/episodes_finished``)
    are recorded on every call in which at least one info dict has
    ``done is True``.  All other statistics are throttled so that they are
    recorded at most once every ``tb_every`` timesteps.

    Args:
        tb_every: Minimum number of timesteps between two throttled records.
            A value <= 0 disables throttling (record on every call).
        verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(self, tb_every: int = 200, verbose: int = 0):
        super().__init__(verbose)
        self.tb_every = int(tb_every)
        # Timestep at which the throttled statistics were last recorded.
        self._last_tb_step = 0

    def _on_training_start(self) -> None:
        # Restart the throttle window from the current counter so a reused
        # callback keeps logging even if the timestep counter was reset.
        self._last_tb_step = self.num_timesteps

    def _on_step(self) -> bool:
        infos = self.locals.get("infos", None) or []
        dict_infos = [i for i in infos if isinstance(i, dict)]

        # Episode-level makespan: the env is expected to put 'day' into the
        # info dict of the step on which it reports done=True.
        finished = [i for i in dict_infos if i.get("done") is True]
        if finished:
            # Skip missing days instead of letting np.min fail on None.
            done_days = [i["day"] for i in finished if i.get("day") is not None]
            if done_days:
                self.logger.record("env/makespan_episode", float(np.min(done_days)))
            self.logger.record("env/episodes_finished", float(len(finished)))

        # Throttle on elapsed timesteps rather than ``num_timesteps % tb_every``:
        # with a vectorised env the counter advances by the number of envs per
        # call, so the modulo test only fires when tb_every is a multiple of it.
        if self.tb_every > 0 and (self.num_timesteps - self._last_tb_step) < self.tb_every:
            return True

        vals = [i.get("completion") for i in dict_infos if "completion" in i]
        if not vals:
            return True
        self._last_tb_step = self.num_timesteps

        idle_vals = [i.get("idle_workers", 0) for i in dict_infos]
        K_vals = [i.get("K_effective", 0) for i in dict_infos]
        alloc_vals = [i.get("allocated_workers", 0) for i in dict_infos]
        # Utilisation = allocated / K, with the divisor floored at 1.
        util_vals = [a / max(1, k) for a, k in zip(alloc_vals, K_vals)] if K_vals else []

        v = np.asarray(vals, dtype=float)
        # The nan-aware reducers warn and return nan on all-NaN input; skip then.
        if not np.all(np.isnan(v)):
            self.logger.record("env/completion_household_mean", float(np.nanmean(v)))
            self.logger.record("env/completion_household_min",  float(np.nanmin(v)))
            self.logger.record("env/completion_household_max",  float(np.nanmax(v)))
            self.logger.record("env/completion_household_p50",  float(np.nanpercentile(v, 50)))
            self.logger.record("env/completion_household_p90",  float(np.nanpercentile(v, 90)))

        if idle_vals:
            self.logger.record("env/idle_workers_mean", float(np.mean(idle_vals)))
        if util_vals:
            self.logger.record("env/utilization_mean", float(np.mean(util_vals)))
        if alloc_vals:
            alloc_arr = np.asarray(alloc_vals, dtype=float)
            self.logger.record("env/allocated_workers_p50", float(np.percentile(alloc_arr, 50)))
            self.logger.record("env/allocated_workers_p90", float(np.percentile(alloc_arr, 90)))
        if K_vals:
            K_arr = np.asarray(K_vals, dtype=float)
            self.logger.record("env/K_effective_p50", float(np.percentile(K_arr, 50)))
            self.logger.record("env/K_effective_p90", float(np.percentile(K_arr, 90)))

        # Current actor learning rate, when the model exposes an actor optimiser.
        if hasattr(self.model, "actor") and getattr(self.model.actor, "optimizer", None) is not None:
            lr = self.model.actor.optimizer.param_groups[0]["lr"]
            self.logger.record("train/lr", float(lr))
        # The attribute can exist but be None (no learned entropy coefficient),
        # so test the value rather than mere presence.
        log_ent_coef = getattr(self.model, "log_ent_coef", None)
        if log_ent_coef is not None:
            ent_coef = float(log_ent_coef.exp().item())
            self.logger.record("train/ent_coef", ent_coef)
        return True

# ----------------- Ramp -----------------
# Single ramp object, passed as ``k_ramp`` to the environments built below.
ramp_fn = create_unified_ramp()

# ----------------- Sampler used for training -----------------
# Flag read from config; falls back to False when config does not define it.
USE_SYNTHETIC_TRAIN = getattr(config, "USE_SYNTHETIC_TRAIN", False)

def make_sampler(seed: int):
    """Build a zero-argument scenario sampler with its own NumPy generator.

    Each call of the returned function picks one key from
    ``TRAIN_REGION_KEYS``, builds a task table for that region and returns
    ``(tasks_df, resources, meta)``.  The synthetic task builder is used only
    when ``USE_SYNTHETIC_TRAIN`` is truthy and ``evaluate`` provides
    ``create_tasks_from_synthetic``; otherwise the real-config builder is used.
    """
    generator = np.random.default_rng(seed)

    def sampler():
        region_key = generator.choice(TRAIN_REGION_KEYS)
        region_cfg = config.REGION_CONFIG[region_key]

        use_synthetic = USE_SYNTHETIC_TRAIN and hasattr(evaluate, "create_tasks_from_synthetic")
        if not use_synthetic:
            tasks_df = create_tasks_from_real_config(region_cfg, generator)
        else:
            tasks_df = evaluate.create_tasks_from_synthetic(region_cfg, generator)

        resources = {
            "workers": int(region_cfg["num_contractors"]),
            "region_name": region_key,
        }
        meta = {"region": region_key}
        return tasks_df, resources, meta

    return sampler

def make_train_env(seed: int):
    """Return a thunk that builds one ``Monitor``-wrapped training ``RLEnv``.

    The scenario sampler is created once here, so every env produced by the
    returned thunk shares it; ``seed`` is used both for the sampler and for
    the env itself.
    """
    scenario_source = make_sampler(seed)

    def _init():
        env_kwargs = dict(
            scenario_sampler=scenario_source,
            M=config.M_CANDIDATES,
            max_steps=config.MAX_STEPS,
            seed=seed,
            k_ramp=ramp_fn,
            batch_arrival=True,
        )
        return Monitor(RLEnv(**env_kwargs))

    return _init

def make_eval_env(region_key: str, seed: int, env_cls):
    """Return a thunk that builds a ``Monitor``-wrapped env for one region.

    The env is created by ``evaluate.make_region_env`` with the shared ramp
    and ``batch_arrival=True``.
    """
    # NOTE(review): ``seed`` is accepted but never forwarded to the env, so the
    # per-region eval seeds chosen by callers have no effect here -- confirm
    # whether make_region_env should receive it.
    def _init():
        return Monitor(
            evaluate.make_region_env(
                region_key,
                env_cls,
                k_ramp=ramp_fn,
                batch_arrival=True,
            )
        )

    return _init

# ----------------- Wrap a single env with the training statistics (inference) -----------------
def wrap_single_env_with_vecnorm(env, src_vecnorm):
    """Wrap ``env`` in a frozen ``VecNormalize`` that reuses ``src_vecnorm`` stats.

    The result normalises observations only (``norm_reward=False``) and has
    ``training=False``.  ``obs_rms`` and ``ret_rms`` are assigned by reference
    from ``src_vecnorm``, not copied.
    """
    def _existing_env():
        return env

    frozen = VecNormalize(
        DummyVecEnv([_existing_env]),
        norm_obs=True,
        norm_reward=False,
        training=False,
        clip_obs=src_vecnorm.clip_obs,
    )
    # Running statistics come straight from the source wrapper.
    frozen.obs_rms, frozen.ret_rms = src_vecnorm.obs_rms, src_vecnorm.ret_rms
    frozen.clip_obs = src_vecnorm.clip_obs
    # Re-assert inference mode explicitly.
    frozen.training, frozen.norm_reward = False, False
    return frozen

# ----------------- Report output -----------------
def write_summary_and_plots(df: pd.DataFrame, OUT_DIR: Path):
    """Write evaluation tables and comparison figures under ``OUT_DIR``.

    Outputs (created under ``OUT_DIR/tab`` and ``OUT_DIR/fig``):
      * ``tab/metrics_eval.csv``  -- ``df`` as given.
      * ``tab/summary_wide.csv``  -- one row per region, one
        ``<metric>__<strategy>`` column per metric/strategy pair.
      * ``fig/overall_*.png``     -- per-strategy mean of makespan and the two
        RMSE metrics.
      * ``fig/<region>_strategy_compare.png`` -- per-metric bars for the first
        region found in ``df``.
    Finally a region x strategy makespan table with a ``best_strategy`` column
    is shown via ``display``.

    Metrics that are absent from ``df`` (or hold no numeric value) are skipped
    with a warning instead of raising, and an empty ``df`` only produces the
    (empty) ``metrics_eval.csv``.

    Args:
        df: One row per (region, strategy) evaluation; must provide ``region``
            and ``strategy`` columns for anything beyond the raw CSV.
        OUT_DIR: Root output directory.
    """
    TAB_DIR = OUT_DIR / "tab"
    FIG_DIR = OUT_DIR / "fig"
    TAB_DIR.mkdir(parents=True, exist_ok=True)
    FIG_DIR.mkdir(parents=True, exist_ok=True)

    metrics_path = TAB_DIR / "metrics_eval.csv"
    df.to_csv(metrics_path, index=False)
    print("Saved:", metrics_path)

    # Everything below pivots/groups on these two columns.
    if df.empty or not {"region", "strategy"}.issubset(df.columns):
        print("[WARN] No evaluation rows with 'region'/'strategy' columns; skipping summary and plots.")
        return

    agg_cols = ["makespan","auc@200","auc@300","t80","t90","t95","utilization","rmse_aligned","rmse_overlap","final_completion"]
    present_cols = [c for c in agg_cols if c in df.columns]
    missing_cols = [c for c in agg_cols if c not in df.columns]
    if missing_cols:
        print("[WARN] Metrics missing from evaluation results:", missing_cols)

    if present_cols:
        pivot_table = df.pivot_table(index="region", columns="strategy", values=present_cols, aggfunc="first")
        pivot_flat = pivot_table.copy()
        # Flatten the (metric, strategy) column MultiIndex.
        pivot_flat.columns = [f"{m}__{s}" for m, s in pivot_flat.columns]
        pivot_flat = pivot_flat.reset_index()
        summary_csv = TAB_DIR / "summary_wide.csv"
        pivot_flat.to_csv(summary_csv, index=False)
        print("Saved:", summary_csv)

    def barplot_metric(metric, ylabel, save_name):
        """Save a bar chart of the per-strategy mean of ``metric``.

        Returns the saved path, or ``None`` when the metric is unavailable.
        """
        if metric not in df.columns:
            print(f"[WARN] Column '{metric}' not found; skipping {save_name}")
            return None
        plot_df = df[["region","strategy",metric]].copy()
        plot_df[metric] = pd.to_numeric(plot_df[metric], errors="coerce")
        means = plot_df.groupby("strategy")[metric].mean().sort_values(ascending=True)
        if means.dropna().empty:
            print(f"[WARN] Column '{metric}' has no numeric values; skipping {save_name}")
            return None
        out_path = FIG_DIR / save_name
        plt.figure(figsize=(10,6))
        try:
            means.plot.bar()
            plt.ylabel(ylabel)
            plt.title(f"Overall {metric} by Strategy")
            plt.grid(axis="y", alpha=0.3)
            plt.tight_layout()
            plt.savefig(out_path, dpi=200)
        finally:
            # Always release the figure, even if plotting/saving failed.
            plt.close()
        return out_path

    saved_figs = [
        p for p in (
            barplot_metric("makespan", "Days", "overall_makespan.png"),
            barplot_metric("rmse_aligned", "RMSE", "overall_rmse_aligned.png"),
            barplot_metric("rmse_overlap", "RMSE", "overall_rmse_overlap.png"),
        )
        if p is not None
    ]
    if saved_figs:
        print("Saved figures:", *saved_figs)

    # Single-region comparison (uses the first region present in df).
    uniq = df["region"].dropna().unique()
    if len(uniq):
        first_region = uniq[0]
        sub = df[df["region"]==first_region].set_index("strategy")
        lower_is_better = ["makespan","rmse_aligned","rmse_overlap"]
        metrics = [
            m for m in ["makespan","auc@200","utilization","rmse_aligned","rmse_overlap"]
            if m in sub.columns and pd.to_numeric(sub[m], errors="coerce").notna().any()
        ]
        if metrics:
            outp = (FIG_DIR / f"{first_region}_strategy_compare.png")
            plt.figure(figsize=(12,6))
            try:
                for i, m in enumerate(metrics, 1):
                    plt.subplot(1, len(metrics), i)
                    series = pd.to_numeric(sub[m], errors="coerce")
                    # Ascending order for metrics where smaller is better.
                    series.sort_values(ascending=(m in lower_is_better)).plot.bar()
                    plt.title(m); plt.xticks(rotation=45, ha="right"); plt.grid(axis="y", alpha=0.3)
                plt.tight_layout()
                plt.savefig(outp, dpi=200)
            finally:
                plt.close()
            print("Saved region figure:", outp)

    # Leaderboard: strategy with the smallest makespan per region.
    if "makespan" in df.columns:
        rank_df = df.pivot_table(index="region", columns="strategy", values="makespan", aggfunc="first")
        if not rank_df.empty:
            # idxmin is only evaluated on rows that hold at least one value.
            has_value = rank_df.notna().any(axis=1)
            best = rank_df.loc[has_value].idxmin(axis=1)
            rank_df["best_strategy"] = best
            display(rank_df.sort_index())

# =================== Train + Eval ===================
# For every seed: build the vectorised training env, train SAC with eval /
# checkpoint / TensorBoard callbacks, save the model and normalisation stats,
# evaluate every region through evaluate.evaluate_region and write the report.
# Envs and the TensorBoard writer are released in ``finally`` blocks so that a
# failure or an interrupt during training does not leave worker processes or
# open event files behind.
saved_models = []
for SEED in SEEDS:
    print(f"========== Training seed={SEED} ==========")
    set_random_seed(SEED)

    ts_tag = time.strftime("%Y-%m-%d_%H-%M-%S")
    run_dir = ROOT_RUNS / f"sac_seed{SEED}_{ts_tag}"
    tb_dir  = run_dir / "tb"
    run_dir.mkdir(parents=True, exist_ok=True)

    # ---- Build train vec env ----
    env_fns = [make_train_env(SEED + i) for i in range(N_ENVS)]
    # Build and immediately close one env in this process before creating the
    # vectorised env (presumably a smoke test of env construction).
    _tmp = env_fns[0](); _tmp.close()

    vec_env = None
    eval_env = None
    try:
        try:
            vec_env = SubprocVecEnv(env_fns, start_method="spawn")
        except Exception as exc:
            print("[WARN] SubprocVecEnv spawn failed, fallback to DummyVecEnv:", repr(exc))
            vec_env = DummyVecEnv(env_fns)
        vec_env = VecMonitor(vec_env)
        vec_env = VecNormalize(vec_env, norm_obs=True, norm_reward=True, clip_obs=10.0)

        # ---- Eval env for callback (multi region, RL) ----
        def make_eval_envs(seed: int):
            """Build a frozen, observation-normalised vec env with one RLEnv per region."""
            envs = []
            for i, region_key in enumerate(TRAIN_REGION_KEYS):
                envs.append(make_eval_env(region_key, seed + 9000 + i, env_cls=RLEnv))
            vec = DummyVecEnv(envs)
            vec = VecMonitor(vec)
            vec = VecNormalize(vec, norm_obs=True, norm_reward=False, training=False, clip_obs=10.0)
            vec.training = False; vec.norm_reward = False
            return vec

        eval_env = make_eval_envs(SEED)
        # Align the eval wrapper with the training statistics (shared by reference).
        eval_env.obs_rms = vec_env.obs_rms
        eval_env.clip_obs = vec_env.clip_obs
        eval_env.training = False
        eval_env.norm_reward = False
        eval_env.ret_rms = vec_env.ret_rms

        # ---- manifest ----
        manifest = {
            "global_seed": SEED,
            "train_env_seeds": [SEED + i for i in range(N_ENVS)],
            "eval_seeds": {region: SEED + 9000 + idx for idx, region in enumerate(TRAIN_REGION_KEYS)},
            "train_regions": TRAIN_REGION_KEYS,
            "timestamp": ts_tag,
            "output_dir": str(config.OUTPUT_DIR),
            "tensorboard_dir": str(tb_dir),
        }
        (run_dir / "seed_manifest.json").write_text(json.dumps(manifest, indent=2))

        # ---- Policy/LR ----
        # Piecewise-constant learning rate: lr1 until 60% of the steps,
        # lr2 until 85%, lr3 afterwards.
        total = int(TOTAL_STEPS)
        b1, b2 = int(0.60 * total), int(0.85 * total)
        lr1, lr2, lr3 = 3e-4, 1e-4, 5e-5
        def lr_schedule(progress_remaining: float) -> float:
            """Map remaining-progress (1 -> 0) to the learning rate of the current stage."""
            step_done = int((1.0 - progress_remaining) * total)
            if step_done < b1: return lr1
            if step_done < b2: return lr2
            return lr3

        policy_kwargs = dict(net_arch=dict(pi=[512], qf=[512]))
        batch_size = 1024 if DEVICE == "mps" else 512
        model = SAC(
            "MlpPolicy",
            vec_env,
            verbose=1,
            device=DEVICE,
            policy_kwargs=policy_kwargs,
            learning_rate=lr_schedule,
            buffer_size=max(300_000, TOTAL_STEPS),
            batch_size=batch_size,
            gamma=0.95,
            tau=0.01,
            train_freq=(1, "step"),
            gradient_steps=1,
            ent_coef="auto",
            tensorboard_log=str(tb_dir),
            seed=SEED,
        )

        # Callback frequencies are divided by N_ENVS because they are counted
        # in callback calls, each of which covers one step of every env.
        eval_cb = EvalCallback(
            eval_env,
            best_model_save_path=str(run_dir / "best"),
            log_path=str(run_dir / "eval"),
            eval_freq=max(1, EVAL_FREQ // max(1, N_ENVS)),
            deterministic=True,
            render=False,
        )
        ckpt_cb = CheckpointCallback(
            save_freq=max(1, CKPT_FREQ // max(1, N_ENVS)),
            save_path=str(run_dir / "ckpt"),
            name_prefix="sac",
        )
        completion_cb = CompletionTBCallback(tb_every=200)

        model.learn(total_timesteps=TOTAL_STEPS, callback=[eval_cb, ckpt_cb, completion_cb], progress_bar=True)

        # ---- Save artifacts ----
        model_path   = run_dir / "sac_model.zip"
        vecnorm_path = run_dir / "vecnormalize.pkl"
        model.save(str(model_path))
        vec_env.save(str(vecnorm_path))
        saved_models.append(model_path)
        print("Saved:", model_path)
        print("VecNormalize stats:", vecnorm_path)
        print("Best checkpoint dir:", run_dir / "best")

        # ================= Unified evaluation (evaluate.evaluate_region) =================
        print("\n" + "="*70)
        print("开始评估：使用统一评估方法")
        print("="*70)

        # Hard check: the policy's action dimension must match the config.
        model_M = model.policy.action_space.shape[0]
        assert model_M == config.M_CANDIDATES, f"Model action dim {model_M} != M_CANDIDATES {config.M_CANDIDATES}"

        # NOTE(review): OUT_DIR is the same for every seed and the report file
        # names are fixed, so with more than one seed later reports overwrite
        # earlier ones -- confirm this is intended.
        OUT_DIR = config.OUTPUT_DIR
        (OUT_DIR / "tab").mkdir(parents=True, exist_ok=True)
        (OUT_DIR / "fig").mkdir(parents=True, exist_ok=True)

        # Load observed data when available; evaluation continues without it.
        observed = {}
        try:
            observed = evaluate.load_observed(config.OBSERVED_DATA_PATH)
            print(f"✅ 加载观测数据：{len(observed)} 个区域")
        except Exception as e:
            print(f"⚠️ 无法加载观测数据: {e}")
            print("   将继续评估但不计算 RMSE")

        # TensorBoard writer for the evaluation phase.
        from torch.utils.tensorboard import SummaryWriter
        tb_eval_dir = OUT_DIR / "tb_eval"
        tb_eval_dir.mkdir(parents=True, exist_ok=True)
        tb_writer = SummaryWriter(log_dir=str(tb_eval_dir))

        all_rows = []
        try:
            for region_key in TRAIN_REGION_KEYS:
                print(f"\n[Eval] 正在评估区域: {region_key}")
                rows = evaluate.evaluate_region(
                    region=region_key,
                    # An empty Series is passed when the region has no observed data.
                    obs_series=observed.get(region_key, pd.Series(dtype=float)),
                    ramp_fn=ramp_fn,
                    model=model,
                    writer=tb_writer,        # TensorBoard writer
                    vecnorm_src=vec_env      # VecNormalize statistics from training
                )
                all_rows.extend(rows)
                print(f"  完成 {len(rows)} 个策略的评估")
        finally:
            # Flush and close the event file even if a region evaluation fails.
            tb_writer.close()

        print("\n" + "="*70)
        print(f"评估完成：共 {len(all_rows)} 条记录")
        print("="*70)

        # Build the report.
        df = pd.DataFrame(all_rows)
        write_summary_and_plots(df, OUT_DIR)
    finally:
        # ---- Cleanup ----
        # Close each env independently so a failure in one close() neither
        # hides the other nor masks an exception raised above.
        for _env_name, _env in (("vec_env", vec_env), ("eval_env", eval_env)):
            if _env is None:
                continue
            try:
                _env.close()
            except Exception as exc:
                print(f"[WARN] Failed to close {_env_name}:", repr(exc))

# Closing banner: list every saved model and where the reports were written.
_banner = "=" * 70
print("\n" + _banner)
print("所有种子训练完成！")
print(_banner)
print("保存的模型：")
for path in saved_models:
    print(f"  - {path}")
print(f"\n报表保存位置: {config.OUTPUT_DIR}")
print(_banner)
# ====================== END MAIN ======================
