In [23]:
# %%
# ============================================================
# Block A: Config for RECENT/POST runs
# ============================================================
import os

# 新结果根目录（里面子目录名是 long_term_forecast_RECENT10_* 或 long_term_forecast_POST_*）
RESULTS_ROOT_RECENT_POST = "./results_MultiPeriod_20251202_234234"
# RESULTS_ROOT_RECENT_POST = "./results_WTI_trading_tuned_20251202_175241"

# 使用哪个 horizon（true_h / pred_h）
HORIZON_RECENT_POST = 0

# 在这个根目录下存图
PLOT_RECENT_POST_DIR = os.path.join(RESULTS_ROOT_RECENT_POST, "plots_recent_post")
os.makedirs(PLOT_RECENT_POST_DIR, exist_ok=True)


In [24]:
# %%
# ============================================================
# Block B: Helpers for RECENT/POST plotting
# ============================================================

import os
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


def infer_window_type_from_model_id(model_id: str) -> str | None:
    """
    从 model_id 判断是 RECENT 还是 POST。
    例如:
      long_term_forecast_RECENT10_TimeMixer_... -> RECENT
      long_term_forecast_POST_TimeMixer_...    -> POST
    """
    s = model_id.lower()
    if "recent" in s:
        return "RECENT"
    # if "post" in s:
    #     return "POST"
    # return None
    return "POST"


def infer_arch_name_from_model_id(model_id: str) -> str:
    """
    直接从 model_id 里猜架构名（Autoformer / PatchTST / DLinear / Informer / TimeMixer 等）。
    如果你前面已经定义过同名函数，可以保留任意一个版本（实现一样即可）。
    """
    s = model_id.lower()

    if "autoformer" in s:
        return "Autoformer"
    if "patchtst" in s:
        return "PatchTST"
    if "dlinear" in s:
        return "DLinear"
    if "informer" in s:
        return "Informer"
    if "timesnet" in s:
        return "TimesNet"
    if "reformer" in s:
        return "Reformer"
    if "fedformer" in s:
        return "FedFormer"
    if "tft" in s:
        return "TFT"
    if "psformer" in s:
        return "PSFormer"
    if "timemixer" in s:
        return "TimeMixer"

    # fallback：尽量短一点
    short = model_id
    short = re.sub(r'^long_term_forecast_', '', short, flags=re.IGNORECASE)
    return short[:20]

def position_from_prediction(pred: np.ndarray, threshold: float = 0.0) -> np.ndarray:
    return np.where(pred > threshold, 1.0, -1.0)

def model_id_to_pred_path(model_id: str, results_root: str) -> str:
    return os.path.join(results_root, model_id, "data_table.csv")

def load_prediction_table(model_id: str, results_root: str) -> pd.DataFrame | None:
    path = model_id_to_pred_path(model_id, results_root)
    if not os.path.exists(path):
        print(f"[WARN] data_table not found for model_id={model_id}")
        return None
    df = pd.read_csv(path)
    if "date" in df.columns:
        df["date"] = pd.to_datetime(df["date"])
    df = df.sort_values("date").reset_index(drop=True)
    return df


def compute_timeseries_full_sample(
    pred_df: pd.DataFrame,
    horizon: int,
    threshold: float = 0.0,
) -> pd.DataFrame:
    """
    不切 regime，直接对整个样本期计算：
      date, true_ret, pred, position, strat_daily, bh_daily, strat_cum, bh_cum
    """
    true_col = f"true_{horizon}"
    pred_col = f"pred_{horizon}"

    if true_col not in pred_df.columns or pred_col not in pred_df.columns:
        return pd.DataFrame()

    df = pred_df[["date", true_col, pred_col]].dropna().sort_values("date").reset_index(drop=True)
    df = df.rename(columns={true_col: "true_ret", pred_col: "pred"})
    df["position"] = position_from_prediction(df["pred"].values, threshold=threshold)
    df["strat_daily"] = df["position"] * df["true_ret"]
    df["bh_daily"] = df["true_ret"]
    df["strat_cum"] = df["strat_daily"].cumsum()
    df["bh_cum"] = df["bh_daily"].cumsum()
    return df


In [25]:
# %%
# ============================================================
# Block C: Scan RESULTS_ROOT_RECENT_POST and group models
# ============================================================

def scan_recent_post_models(results_root: str) -> pd.DataFrame:
    """
    扫描 results_root 下的子目录，把 model_id / window_type / base_key / arch_name 收集成一个 DataFrame。
    base_key：去掉最后 _0/_1/_2 的前缀，代表“同一个模型配置”的 3 次 run。
    """
    rows = []
    for name in os.listdir(results_root):
        full_path = os.path.join(results_root, name)
        if not os.path.isdir(full_path):
            continue

        model_id = name
        window_type = infer_window_type_from_model_id(model_id)
        if window_type is None:
            continue  # 既不是 RECENT 也不是 POST，就跳过

        # base_key：去掉结尾的 _0 / _1 / _2
        base_key = re.sub(r"_\d+$", "", model_id)
        arch_name = infer_arch_name_from_model_id(model_id)

        rows.append(
            {
                "model_id": model_id,
                "window_type": window_type,
                "base_key": base_key,
                "arch_name": arch_name,
            }
        )

    df = pd.DataFrame(rows)
    print("Scanned RECENT/POST models:", df.shape)
    print(df.head())
    return df


recent_post_index_df = scan_recent_post_models(RESULTS_ROOT_RECENT_POST)


Scanned RECENT/POST models: (87, 4)
                                            model_id window_type  \
0                                         evaluation        POST   
1  long_term_forecast_RECENT10_Naive_pl5_Naive_cu...      RECENT   
2  long_term_forecast_RECENT10_Naive_pl5_Naive_cu...      RECENT   
3  long_term_forecast_RECENT10_Naive_pl5_Naive_cu...      RECENT   
4  long_term_forecast_RECENT10_DLinear_pl5_DLinea...      RECENT   

                                            base_key             arch_name  
0                                         evaluation            evaluation  
1  long_term_forecast_RECENT10_Naive_pl5_Naive_cu...  RECENT10_Naive_pl5_N  
2  long_term_forecast_RECENT10_Naive_pl5_Naive_cu...  RECENT10_Naive_pl5_N  
3  long_term_forecast_RECENT10_Naive_pl5_Naive_cu...  RECENT10_Naive_pl5_N  
4  long_term_forecast_RECENT10_DLinear_pl5_DLinea...               DLinear  


In [26]:
# %%
# ============================================================
# Block D: Plot RECENT vs POST – each as one big figure
# ============================================================

def plot_window_group(
    index_df: pd.DataFrame,
    results_root: str,
    window_type: str,
    horizon: int,
    max_runs_per_model: int = 3,
):
    """
    对某一个 window_type（RECENT 或 POST）画一张大图：
      - 每行一个 base model（架构）
      - 每列一个 run（最多 3 个：_0/_1/_2）
      - 横轴整个样本期，不分 regime
    """
    sub = index_df[index_df["window_type"] == window_type].copy()
    if sub.empty:
        print(f"No models for window_type={window_type}")
        return

    # base_key → arch_name, [model_ids]
    mapping = {}
    for base_key, g in sub.groupby("base_key"):
        arch_name = g["arch_name"].iloc[0]
        model_ids = sorted(g["model_id"].unique().tolist())
        model_ids = model_ids[:max_runs_per_model]
        mapping[base_key] = {
            "arch_name": arch_name,
            "model_ids": model_ids,
        }

    base_keys = sorted(mapping.keys())
    n_models = len(base_keys)
    n_cols = max_runs_per_model
    n_rows = n_models

    if n_models == 0:
        print(f"No base models for window_type={window_type}")
        return

    # 每个子图横向拉长：和之前 true models 类似
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(7 * n_cols, 1.8 * n_rows),
        sharex=False,
        sharey=False,
    )

    # 统一 axes 形状
    if n_rows == 1 and n_cols == 1:
        axes = np.array([[axes]])
    elif n_rows == 1:
        axes = np.array([axes])
    elif n_cols == 1:
        axes = np.array([[ax] for ax in axes])

    for row_idx, base_key in enumerate(base_keys):
        arch_name = mapping[base_key]["arch_name"]
        model_ids = mapping[base_key]["model_ids"]

        for col_idx in range(n_cols):
            ax = axes[row_idx, col_idx]

            if col_idx < len(model_ids):
                model_id = model_ids[col_idx]
                pred_df = load_prediction_table(model_id, results_root)
                if pred_df is None:
                    ax.set_visible(False)
                    continue

                ts = compute_timeseries_full_sample(
                    pred_df=pred_df,
                    horizon=horizon,
                    threshold=0.0,
                )
                if ts.empty:
                    ax.set_visible(False)
                    continue

                ax.plot(ts["date"], ts["strat_cum"], label="Strategy")
                ax.plot(ts["date"], ts["bh_cum"], linestyle="--", label="B&H")

                # 第一行写列标题：Run 1/2/3
                if row_idx == 0:
                    ax.set_title(f"Run {col_idx+1}", fontsize=10)

                # 每行左边写架构名
                if col_idx == 0:
                    ax.set_ylabel(arch_name, fontsize=10)

                ax.tick_params(axis="x", labelrotation=30, labelsize=8)
            else:
                ax.set_visible(False)

    # 只拿第一个可见子图的 legend
    first_ax = None
    for row in axes:
        for ax in row:
            if ax.get_visible():
                first_ax = ax
                break
        if first_ax is not None:
            break

    if first_ax is not None:
        handles, labels = first_ax.get_legend_handles_labels()
        if handles:
            fig.legend(
                handles,
                labels,
                loc="upper right",
                bbox_to_anchor=(0.98, 0.98),
            )

    fig.suptitle(
        f"{window_type} – Cumulative PnL (horizon={horizon})\nFull sample",
        fontsize=14,
    )
    fig.tight_layout(rect=[0, 0, 0.97, 0.93])

    # 保存图片
    filename = f"{window_type}_h{horizon}_full_sample.png"
    save_path = os.path.join(PLOT_RECENT_POST_DIR, filename)
    plt.savefig(save_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved {window_type} figure to: {save_path}")


# 实际调用：各画一张 RECENT / POST
plot_window_group(
    index_df=recent_post_index_df,
    results_root=RESULTS_ROOT_RECENT_POST,
    window_type="RECENT",
    horizon=HORIZON_RECENT_POST,
    max_runs_per_model=3,
)

plot_window_group(
    index_df=recent_post_index_df,
    results_root=RESULTS_ROOT_RECENT_POST,
    window_type="POST",
    horizon=HORIZON_RECENT_POST,
    max_runs_per_model=3,
)


[WARN] data_table not found for model_id=plots_recent_post
Saved RECENT figure to: ./results_MultiPeriod_20251202_234234/plots_recent_post/RECENT_h0_full_sample.png
[WARN] data_table not found for model_id=analysis_plots
[WARN] data_table not found for model_id=evaluation
Saved POST figure to: ./results_MultiPeriod_20251202_234234/plots_recent_post/POST_h0_full_sample.png


In [27]:
# %%
# ============================================================
# Block 1: Config & performance metric helper
# ============================================================

# 输出的 CSV 路径
OUTPUT_RECENT_POST_RUNS = os.path.join(
    RESULTS_ROOT_RECENT_POST, "recent_post_backtest_runs.csv"
)
OUTPUT_RECENT_POST_SUMMARY = os.path.join(
    RESULTS_ROOT_RECENT_POST, "recent_post_backtest_summary.csv"
)

import numpy as np
import pandas as pd

def compute_perf_metrics_from_daily(daily_ret: np.ndarray) -> dict:
    """
    daily_ret: 每日策略收益（可以是 PnL 或 return）
    返回年化收益、年化波动、Sharpe。
    这里假设一年 252 个交易日。
    """
    res = {
        "ann_return": np.nan,
        "ann_vol": np.nan,
        "sharpe": np.nan,
    }
    if len(daily_ret) == 0:
        return res

    mean_d = daily_ret.mean()
    std_d = daily_ret.std(ddof=1) if len(daily_ret) > 1 else 0.0
    ann_ret = mean_d * 252
    ann_vol = std_d * np.sqrt(252) if std_d > 0 else np.nan
    sharpe = ann_ret / ann_vol if (ann_vol and ann_vol > 0) else np.nan

    res["ann_return"] = float(ann_ret)
    res["ann_vol"] = float(ann_vol)
    res["sharpe"] = float(sharpe)
    return res


In [28]:
# %%
# ============================================================
# Block 2: Build per-run table for RECENT/POST (full sample)
# ============================================================

def summarize_recent_post_runs(
    index_df: pd.DataFrame,
    results_root: str,
    horizon: int,
) -> pd.DataFrame:
    """
    对 RECENT/POST 这批 model，逐个 run 计算：
      - final_strat_pnl, final_bh_pnl
      - ann_return, ann_vol, sharpe (strategy)
    返回一个 DataFrame。
    """
    rows = []

    for _, row in index_df.iterrows():
        model_id = row["model_id"]
        window_type = row["window_type"]   # RECENT / POST
        base_key = row["base_key"]
        arch_name = row["arch_name"]

        # run_idx: 用最后的 _0/_1/_2 来标记 run
        m = re.search(r"_(\d+)$", model_id)
        run_idx = int(m.group(1)) if m else None

        pred_df = load_prediction_table(model_id, results_root)
        if pred_df is None:
            continue

        ts = compute_timeseries_full_sample(
            pred_df=pred_df,
            horizon=horizon,
            threshold=0.0,
        )
        if ts.empty:
            continue

        strat_daily = ts["strat_daily"].values
        bh_daily = ts["bh_daily"].values

        strat_metrics = compute_perf_metrics_from_daily(strat_daily)
        bh_metrics = compute_perf_metrics_from_daily(bh_daily)

        final_strat_pnl = float(ts["strat_cum"].iloc[-1])
        final_bh_pnl = float(ts["bh_cum"].iloc[-1])

        rows.append(
            {
                "window_type": window_type,     # RECENT / POST
                "arch_name": arch_name,         # TimeMixer / Autoformer / ...
                "base_key": base_key,           # 配置名（不含 _0/_1/_2）
                "model_id": model_id,           # 具体 run
                "run_idx": run_idx,             # 0/1/2
                "n_days": int(len(ts)),
                "final_strat_pnl": final_strat_pnl,
                "final_bh_pnl": final_bh_pnl,
                "strat_ann_return": strat_metrics["ann_return"],
                "strat_ann_vol": strat_metrics["ann_vol"],
                "strat_sharpe": strat_metrics["sharpe"],
                "bh_ann_return": bh_metrics["ann_return"],
                "bh_ann_vol": bh_metrics["ann_vol"],
                "bh_sharpe": bh_metrics["sharpe"],
            }
        )

    df = pd.DataFrame(rows)
    return df


recent_post_runs_df = summarize_recent_post_runs(
    index_df=recent_post_index_df,
    results_root=RESULTS_ROOT_RECENT_POST,
    horizon=HORIZON_RECENT_POST,
)

print("Per-run stats shape:", recent_post_runs_df.shape)
print(recent_post_runs_df.head())

# 保存 per-run 表
recent_post_runs_df.to_csv(OUTPUT_RECENT_POST_RUNS, index=False)
print(f"Saved per-run RECENT/POST stats to: {OUTPUT_RECENT_POST_RUNS}")


[WARN] data_table not found for model_id=evaluation
[WARN] data_table not found for model_id=analysis_plots
[WARN] data_table not found for model_id=plots_recent_post
Per-run stats shape: (84, 14)
  window_type             arch_name  \
0      RECENT  RECENT10_Naive_pl5_N   
1      RECENT  RECENT10_Naive_pl5_N   
2      RECENT  RECENT10_Naive_pl5_N   
3      RECENT               DLinear   
4      RECENT               DLinear   

                                            base_key  \
0  long_term_forecast_RECENT10_Naive_pl5_Naive_cu...   
1  long_term_forecast_RECENT10_Naive_pl5_Naive_cu...   
2  long_term_forecast_RECENT10_Naive_pl5_Naive_cu...   
3  long_term_forecast_RECENT10_DLinear_pl5_DLinea...   
4  long_term_forecast_RECENT10_DLinear_pl5_DLinea...   

                                            model_id  run_idx  n_days  \
0  long_term_forecast_RECENT10_Naive_pl5_Naive_cu...        0     539   
1  long_term_forecast_RECENT10_Naive_pl5_Naive_cu...        1     539   
2  long_term

In [29]:
# %%
# ============================================================
# Block 3: Aggregate per base model (mean across runs)
# ============================================================

if recent_post_runs_df.empty:
    print("No RECENT/POST run stats available.")
else:
    # 按 window_type + arch_name + base_key 聚合（跨 run 求平均）
    grp = (
        recent_post_runs_df
        .groupby(["window_type", "arch_name", "base_key"], as_index=False)
        .agg(
            mean_strat_sharpe=("strat_sharpe", "mean"),
            std_strat_sharpe=("strat_sharpe", "std"),
            mean_final_strat_pnl=("final_strat_pnl", "mean"),
            std_final_strat_pnl=("final_strat_pnl", "std"),
            mean_strat_ann_return=("strat_ann_return", "mean"),
            mean_strat_ann_vol=("strat_ann_vol", "mean"),
        )
        .sort_values(["window_type", "mean_strat_sharpe"], ascending=[True, False])
    )

    print("=== RECENT/POST base-model summary (mean over runs) ===")
    print(grp.head(20).to_string(index=False))

    grp.to_csv(OUTPUT_RECENT_POST_SUMMARY, index=False)
    print(f"Saved base-model summary to: {OUTPUT_RECENT_POST_SUMMARY}")


=== RECENT/POST base-model summary (mean over runs) ===
window_type            arch_name                                                                                                                                                                         base_key  mean_strat_sharpe  std_strat_sharpe  mean_final_strat_pnl  std_final_strat_pnl  mean_strat_ann_return  mean_strat_ann_vol
       POST            TimeMixer                                  long_term_forecast_POST_COVID_TimeMixer_pl5_TimeMixer_custom_ftMS_sl128_ll0_pl5_dm16_nh8_el2_dl1_df32_expand2_dc4_fc1_ebtimeF_dtTrue_POST_COVID           1.039857          0.080706              0.277227             0.021425               0.243419            0.234093
       POST POST_COVID_iTransfor                           long_term_forecast_POST_COVID_iTransformer_pl5_iTransformer_custom_ftMS_sl128_ll48_pl5_dm32_nh4_el2_dl1_df64_expand2_dc4_fc3_ebtimeF_dtTrue_POST_COVID           0.690172          0.762697              0.183885          