In [2]:
from __future__ import annotations

# ── 标准库（stdlib） ─────────────────────────────────────────────────────────────
import os
import random
import re
import time
from collections import deque
from pathlib import Path

# ── 第三方（third-party） ───────────────────────────────────────────────────────
import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import pyarrow.dataset as ds
import torch
import lightning as L
from torch.utils.data import DataLoader, IterableDataset, get_worker_info

from lightning.pytorch.callbacks import (
    EarlyStopping,
    ModelCheckpoint,
    LearningRateMonitor,
)
from lightning.pytorch.loggers import TensorBoardLogger

from pytorch_forecasting import (
    TimeSeriesDataSet,
    TemporalFusionTransformer,
)
from pytorch_forecasting.metrics import MAE, RMSE
from pytorch_forecasting.data.encoders import (
    NaNLabelEncoder,
)

# Cap Polars' thread pool at half of the logical cores (never below 1), unless
# the variable is already set in the environment.
# os.cpu_count() may return None when the core count cannot be determined;
# fall back to 1 so that `None // 2` can never raise TypeError.
# NOTE(review): Polars documents that POLARS_MAX_THREADS has to be set before
# its thread pool is created. `import polars` has already run above, so this
# assignment may have no effect — TODO move it ahead of the polars import.
os.environ.setdefault("POLARS_MAX_THREADS", str(max(1, (os.cpu_count() or 1) // 2)))
# Turn on Polars' global string cache (intended to speed up joins / categorical handling).
pl.enable_string_cache()

# 你的工程工具
from pipeline.io import cfg, P, fs, storage_options, ensure_dir_local, ensure_dir_az

# 全局一次性设置（不要放到训练/epoch/折叠循环里）
# Process-wide, one-time settings (do not place these inside training / epoch / fold loops).
import torch.backends.cudnn as cudnn
cudnn.benchmark = True  # cuDNN autotuner switch (see the torch.backends.cudnn docs)
torch.set_float32_matmul_precision("high")  # float32 matmul precision mode: "high"


def _now():
    """Return the current local time formatted as ``YYYY-MM-DD HH:MM:SS``."""
    stamp = time.strftime("%Y-%m-%d %H:%M:%S")
    return stamp


print(f"[{_now()}] imports ok")


[2025-10-07 13:37:57] imports ok


In [None]:
# Load the ranked feature list written by the feature-selection run.
# NOTE(review): absolute, machine-specific path — consider deriving it from cfg.
input_file = "/mnt/data/js/exp/v1/reports/fi/features__fs__1600-1690__cv2-g2-r4__seed42__top1000__1758551023.txt"

with open(input_file, "r") as file:
    data = file.read()
lines = data.splitlines()

# The first whitespace-separated token of each line is the feature name.
# Blank lines and '#' comment lines are skipped; the check runs on the stripped
# line so that indented comments are skipped as well.
features = [
    stripped.split()[0]
    for stripped in (line.strip() for line in lines)
    if stripped and not stripped.startswith("#")
]

print(f"[{_now()}] loaded {len(features)} features from {input_file}")

# Keep the first TOP_K entries of the file
# (presumably the file is ordered by importance — TODO confirm).
TOP_K = 50
topK_features = features[:TOP_K]
# The message reports the real count (it used to say "top 10" while printing 50 names).
print(f"[{_now()}] top {len(topK_features)} features: {topK_features}")

# Sets give O(1) membership tests for the filters below.
top_k_set = set(topK_features)

# Base features: exactly "feature_" followed by two digits, no suffix.
# Requires the `re` module (imported in the notebook's import cell).
base_features = [f for f in features if re.fullmatch(r"feature_\d{2}", f)]
base_set = set(base_features)
# Responder history features: names starting with "responder_".
resp_his_feats = [f for f in features if f.startswith("responder_")]
# Derived feature_* columns (with a suffix), excluding the base features.
feat_his_feats = [f for f in features if f.startswith("feature_") and f not in base_set]

# Restrict every group to the top-K selection, preserving file order.
base_features = [f for f in base_features if f in top_k_set]
resp_his_feats = [f for f in resp_his_feats if f in top_k_set]
feat_his_feats = [f for f in feat_his_feats if f in top_k_set]


In [3]:

# ---- Key configuration (edit as needed) ----
target_col = cfg["target"]             # e.g. "responder_6"
g_sym, g_date, g_time = cfg["keys"]    # e.g. ("symbol_id", "date_id", "time_id")
weight_col = cfg["weight"]

time_features = ["time_pos", "time_sin", "time_cos", "time_bucket"]
feature_cols = ["feature_36"]

# Columns to read from the panel: keys, weight, target, time features and model
# features, de-duplicated while keeping first-seen order.
need_cols = list(
    dict.fromkeys(
        [g_sym, g_date, g_time, weight_col, target_col, *time_features, *feature_cols]
    )
)

# ---- Cross-validation & training ----
N_SPLITS = 2
GAP_DAYS = 7
TRAIN_TO_VAL = 4            # train : validation = 4 : 1
ENC_LEN = 10
PRED_LEN = 1
BATCH_SIZE = 256
LR = 1e-3
HIDDEN = 64
HEADS = 2
DROPOUT = 0.1
MAX_EPOCHS_PER_SHARD = 1
CHUNK_DAYS = 20             # training shard size, in days

print("config ready")

config ready


In [4]:
# Locate the panel shards (directory name comes from cfg["paths"]["panel_shards"],
# defaulting to "panel_shards") and open them as one lazy Polars frame.
panel_dir = P("az", cfg["paths"].get("panel_shards", "panel_shards"))
data_path = fs.glob(f"{panel_dir}/*.parquet")
if not data_path:
    # Fail early with a clear message rather than an opaque scan error later on.
    raise FileNotFoundError(f"no parquet shards found under {panel_dir}")
# The glob results are prefixed with the az:// scheme before being handed to Polars.
az_path = [f"az://{p}" for p in data_path]
lf_data = pl.scan_parquet(az_path, storage_options=storage_options)

In [8]:
# Connectivity smoke test: materialise the first 5 rows.
lf_data.limit(5).collect()

symbol_id,date_id,time_id,time_bucket,responder_6,weight,feature_00,feature_01,feature_02,feature_03,feature_04,feature_05,feature_06,feature_07,feature_08,feature_09,feature_10,feature_11,feature_12,feature_13,feature_14,feature_15,feature_16,feature_17,feature_18,feature_19,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29,feature_30,…,feature_66__rz3,feature_67__rmean3,feature_67__rstd3,feature_67__rz3,feature_68__rmean3,feature_68__rstd3,feature_68__rz3,feature_69__rmean3,feature_69__rstd3,feature_69__rz3,feature_70__rmean3,feature_70__rstd3,feature_70__rz3,feature_71__rmean3,feature_71__rstd3,feature_71__rz3,feature_72__rmean3,feature_72__rstd3,feature_72__rz3,feature_73__rmean3,feature_73__rstd3,feature_73__rz3,feature_74__rmean3,feature_74__rstd3,feature_74__rz3,feature_75__rmean3,feature_75__rstd3,feature_75__rz3,feature_76__rmean3,feature_76__rstd3,feature_76__rz3,feature_77__rmean3,feature_77__rstd3,feature_77__rz3,feature_78__rmean3,feature_78__rstd3,feature_78__rz3
i32,i32,i32,u8,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,…,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
0,1605,0,0,-1.338004,3.771476,3.249785,1.644559,2.951776,3.072733,3.042695,-1.182882,-0.097529,-0.417959,0.358264,11.0,7.0,76.0,-0.815249,1.583994,-0.189974,-0.677662,-0.494887,-0.662308,-1.601268,-1.917711,0.896151,-0.038347,1.120629,0.518504,1.476256,0.785076,1.017946,0.941139,0.361583,-0.63828,-0.815442,…,-1.15466,-0.071019,0.005233,-1.098445,-0.207571,0.093457,0.952124,-0.091556,0.026293,0.366903,-0.300884,0.050481,-0.177875,-0.176997,0.187346,1.119846,-0.173731,0.040365,1.040592,-0.256137,0.034156,1.097758,-0.320625,0.115785,-0.962262,-0.107168,0.113762,-1.130666,-0.208219,0.124273,-0.990429,-0.189998,0.01129,-0.528344,-0.234332,0.015547,-0.780572
1,1605,0,0,-1.572615,3.858156,2.95633,1.300913,2.780389,3.074109,3.066704,-0.744921,-0.076711,-0.488637,0.391012,11.0,7.0,76.0,-1.138502,1.157193,-0.209437,-0.806402,-0.339077,-0.592733,-1.848096,-1.581626,0.612074,0.191684,0.837864,0.397742,2.657752,1.168224,-1.7849,-0.085103,0.907257,-0.541719,-0.956384,…,0.731997,-0.083682,0.021103,-1.000173,-0.151283,0.119696,-0.735771,0.096701,0.054554,-0.734044,-0.371035,0.101811,0.345227,-0.079026,0.216957,1.114831,-0.168288,0.036253,1.041875,0.433767,0.088799,-0.848983,0.302826,0.125203,-0.598476,0.401475,0.168661,0.730511,0.391299,0.083686,-0.803683,0.450136,0.019665,-0.626704,0.366767,0.044257,1.154521
2,1605,0,0,0.085172,4.029465,3.216702,1.9241,2.972498,3.181072,2.676115,-0.88276,-0.098824,-0.782345,0.471035,81.0,2.0,59.0,-1.232034,1.410958,-0.425981,-1.075019,-0.15005,-0.752805,-1.431982,-1.390779,-0.092707,-0.093736,0.566463,0.213935,1.357846,-0.06643,0.772687,0.473541,-0.005689,-0.39475,-0.57903,…,1.091695,-0.40541,0.220284,-1.140669,-0.129095,0.120208,-0.018042,-0.344662,0.088683,0.421717,-0.470735,0.082466,0.22053,0.158388,0.230158,1.153774,-0.106315,0.084159,1.153407,-0.363056,0.066604,0.935868,-0.409656,0.047955,0.591057,0.094667,0.340072,0.623284,0.144877,0.459288,0.603794,-0.157004,0.090291,0.747794,-0.203354,0.072811,0.847102
3,1605,0,0,-2.768082,2.377184,2.948765,1.438616,2.949908,3.298159,3.010296,-1.109763,-0.115104,-0.513872,0.451119,4.0,3.0,11.0,-1.054552,1.535755,-0.179677,-0.487148,-0.487206,-0.704949,-1.559941,-1.154272,-0.168588,-0.023594,0.106641,-0.863546,2.006068,0.233557,-0.011306,0.072626,-0.502054,-0.889185,-0.516596,…,-1.072401,0.106596,0.03842,0.653794,0.19051,0.344181,0.238918,-0.123472,0.065196,1.008725,-0.001479,0.015971,1.107286,0.20157,0.144305,0.988813,-0.054001,0.050241,1.15088,0.72496,0.18746,0.809649,0.697034,0.089986,-0.271157,2.011532,1.252324,0.60383,2.291722,1.300281,0.435139,0.967457,0.231634,1.126723,0.808233,0.150006,0.518416
4,1605,0,0,2.090984,2.563155,3.026436,1.184749,3.041878,3.604426,2.657762,-0.970559,-0.139817,-0.532435,0.612716,15.0,1.0,9.0,-0.667211,3.580876,0.087301,-0.706757,-0.242437,-0.567562,-1.178818,-1.936775,-0.751342,0.169354,-0.066423,-0.4728,0.777907,-0.851054,-1.80886,-1.056851,-0.143766,-0.846094,-0.693096,…,-1.058791,1.214794,0.327573,0.919609,1.683343,0.643852,0.594364,1.945899,0.376164,0.128812,1.440363,0.586327,0.149032,1.035848,0.712859,1.13505,2.101918,0.296111,-0.786615,-0.504521,0.068624,-0.771453,-0.372165,0.028806,1.084732,0.246521,1.099255,1.127171,0.377151,1.200843,1.124583,-0.104984,0.145686,1.136897,0.08527,0.210976,1.154264


In [13]:
# Per-symbol summary statistics of the target for date ids 1610..1630 (both
# bounds included), ordered by mean, then std, then count.
(
    lf_data
    .filter((pl.col(g_date) >= 1610) & (pl.col(g_date) <= 1630))
    .group_by(g_sym)
    .agg(
        [
            pl.col(target_col).min().alias("min"),
            pl.col(target_col).max().alias("max"),
            pl.col(target_col).mean().alias("mean"),
            pl.col(target_col).std().alias("std"),
            pl.col(target_col).count().alias("count"),
        ]
    )
    .sort("mean", "std", "count")
    .collect()
)

symbol_id,min,max,mean,std,count
i32,f32,f32,f32,f32,u32
30,-4.125087,5.0,-0.081503,0.659414,20328
25,-5.0,4.943763,-0.0727,0.808597,20328
27,-5.0,5.0,-0.064174,0.682261,19360
12,-5.0,5.0,-0.048375,0.80594,20328
2,-5.0,5.0,-0.048374,0.771517,20328
…,…,…,…,…,…
26,-5.0,5.0,0.019668,0.786453,20328
4,-5.0,5.0,0.025079,0.952476,18392
17,-5.0,5.0,0.031603,0.657436,20328
23,-5.0,5.0,0.04033,0.877483,20328


In [None]:
# Per-symbol mean / std / count of the target over the whole dataset.
# The chain is wrapped in parentheses: continuing a method chain on indented
# lines without them is a syntax error (unexpected indent).
(
    lf_data.select([g_sym, target_col])
    .group_by(g_sym)
    .agg([
        pl.col(target_col).mean().alias("mean"),
        pl.col(target_col).std().alias("std"),
        pl.col(target_col).count().alias("count"),
    ])
    .collect()
)

In [None]:
# Display the LazyFrame (the notebook renders the cell's last expression).
lf_data

In [None]:


# Build the global time_idx exactly once: one row per distinct (date, time)
# pair, numbered in sorted order and stored as Int64.
lf_grid = (
    lf_data
    .select([g_date, g_time])
    .unique()
    .sort([g_date, g_time])
    .with_row_index("time_idx")
    .with_columns(pl.col("time_idx").cast(pl.Int64))
)

# Materialise the grid to a local parquet file and read it back lazily.
grid_path = P("local", "tft/panel/grid_timeidx.parquet")
ensure_dir_local(Path(grid_path).parent.as_posix())
lf_grid.collect(streaming=True).write_parquet(grid_path, compression="zstd")
grid_lazy = pl.scan_parquet(grid_path)

# Attach time_idx and keep only the required columns (still lazy).
lf0 = (
    lf_data
    .join(grid_lazy, on=[g_date, g_time], how="left")
    .select([*need_cols, "time_idx"])
    .sort([g_date, g_time, g_sym])
)

print(f"[{_now()}] lazyframe ready")


# Sorted NumPy array of every distinct day id present in the data.
all_days = (
    lf0.select(g_date)
    .unique()
    .sort(g_date)
    .collect(streaming=True)[g_date]
    .to_numpy()
)


# Build sliding-window cross-validation folds over days
def make_sliding_cv_by_days(all_days: np.ndarray, *, n_splits: int, gap_days: int, train_to_val: int):
    """Split a sequence of days into sliding (train, validation) windows.

    The ``len(all_days) - gap_days`` usable days are divided into
    ``train_to_val + n_splits`` equal units of ``val_base`` days. Every fold
    trains on ``train_to_val * val_base`` consecutive days, skips ``gap_days``
    days, then validates on the next window. Validation windows are
    back-to-back, and the first ``remainder`` folds get one extra day.

    Args:
        all_days: Day identifiers; flattened with ``np.asarray(...).ravel()``
            and used in the given order.
        n_splits: Number of folds requested.
        gap_days: Days left out between the end of training and the start of
            validation.
        train_to_val: Training length as a multiple of the base validation length.

    Returns:
        A list of ``(train_days, val_days)`` array pairs. Empty when any size
        argument is non-positive or there are too few days for one base unit.
    """
    days = np.asarray(all_days).ravel()
    n_days = len(days)

    usable = n_days - gap_days
    if usable <= 0 or n_splits <= 0 or train_to_val <= 0:
        return []

    val_base, remainder = divmod(usable, train_to_val + n_splits)
    if val_base <= 0:
        return []

    train_len = train_to_val * val_base
    folds = []
    val_start = train_len + gap_days
    for fold_idx in range(n_splits):
        # The first `remainder` folds absorb the leftover days, one each.
        val_len = val_base + 1 if fold_idx < remainder else val_base
        val_stop = val_start + val_len
        train_stop = val_start - gap_days
        train_start = train_stop - train_len
        if train_start < 0 or val_stop > n_days:
            break
        folds.append((days[train_start:train_stop], days[val_start:val_stop]))
        # The next validation window starts where this one ended.
        val_start = val_stop
    return folds

folds_by_day = make_sliding_cv_by_days(
    all_days,
    n_splits=N_SPLITS,
    gap_days=GAP_DAYS,
    train_to_val=TRAIN_TO_VAL,
)
assert len(folds_by_day) > 0, "no CV folds constructed"


# Upper bound of the standardisation window: the last training day of the
# first fold (validation days are not included).
stats_hi = int(folds_by_day[0][0][-1])
print(f"stats_hi (for global z-score) = {stats_hi}; first-fold train days end at this day.")


# ========== 1) Continuous features: inf -> null, missing flags, per-symbol ffill, fallback 0 ==========
# The flags are computed after inf -> null but before filling, so they mark
# values that were null or infinite in the input. The work is done in three
# with_columns passes instead of per-column updates inside a loop.
# NOTE(review): NaN is not null in Polars, so NaN values are neither flagged
# nor filled here — confirm that the inputs never contain NaN.
inf2null_exprs = [
    pl.when(pl.col(c).is_infinite()).then(None).otherwise(pl.col(c)).alias(c)
    for c in feature_cols
]  # overwrites the column, no new columns
flags_exprs = [
    pl.col(c).is_null().cast(pl.Int8).alias(f"{c}__isna")
    for c in feature_cols
]  # adds one "<feature>__isna" column per feature
fill_exprs = [
    pl.col(c).forward_fill().over(g_sym).fill_null(0.0).alias(c)
    for c in feature_cols
]  # forward-fill within each symbol, then 0.0 for what is still null

lf_clean = lf0.with_columns(inf2null_exprs)      # inf -> null
lf_clean = lf_clean.with_columns(flags_exprs)    # missing-value flags
lf_clean = lf_clean.with_columns(fill_exprs)     # per-symbol ffill + fallback 0

# Preparation for standardisation.

# ========== 2) Per-symbol mean / std over the stats window (date <= stats_hi), computed once ==========
lf_stats_sym = (
    lf_clean
    .filter(pl.col(g_date) <= stats_hi)
    .group_by(g_sym)
    .agg(
        [pl.col(c).mean().alias(f"mu_{c}") for c in feature_cols]
        + [pl.col(c).std(ddof=0).alias(f"std_{c}") for c in feature_cols]
    )
)

# Global statistics over the same window, used as a fallback.
lf_stats_glb = (
    lf_clean
    .filter(pl.col(g_date) <= stats_hi)
    .select(
        [pl.col(c).mean().alias(f"mu_{c}_glb") for c in feature_cols]
        + [pl.col(c).std(ddof=0).alias(f"std_{c}_glb") for c in feature_cols]
    )
)

# 3) Attach the global statistics to every row as constant columns (cross join).
lf_z = lf_clean.join(lf_stats_glb, how="cross")

# 4) Attach the per-symbol statistics; symbols without rows in the stats
#    window get nulls and fall back to the global values below.
lf_z = lf_z.join(lf_stats_sym, on=g_sym, how="left")

# Build every z-score expression first, then apply them in a single
# with_columns followed by a single drop. This yields the same columns as
# updating the frame once per feature, but keeps the lazy plan flat when
# feature_cols is long.
eps = 1e-6          # keeps the division finite when the chosen std is 0
z_cols = []         # names of the produced z-score columns
z_exprs = []        # one expression per feature
stat_cols = []      # helper statistic columns to drop afterwards
for c in feature_cols:
    mu_c_sym, std_c_sym = f"mu_{c}", f"std_{c}"
    mu_c_glb, std_c_glb = f"mu_{c}_glb", f"std_{c}_glb"
    c_z = f"{c}_z"

    # Mean: per-symbol value, or the global one when it is missing.
    mu_use = (
        pl.when(pl.col(mu_c_sym).is_null())
        .then(pl.col(mu_c_glb))
        .otherwise(pl.col(mu_c_sym))
    )
    # Std: per-symbol value, or the global one when it is missing or exactly 0.
    std_use = (
        pl.when(pl.col(std_c_sym).is_null() | (pl.col(std_c_sym) == 0))
        .then(pl.col(std_c_glb))
        .otherwise(pl.col(std_c_sym))
    )

    z_exprs.append(((pl.col(c) - mu_use) / (std_use + eps)).alias(c_z))
    stat_cols.extend([mu_c_glb, std_c_glb, mu_c_sym, std_c_sym])
    z_cols.append(c_z)

lf_z = lf_z.with_columns(z_exprs).drop(stat_cols)
    
    
# 5) Output columns: keys, time_idx, weight, target, time features, z-scored
#    features and the missing-value flags.
namark_cols = [f"{c}__isna" for c in feature_cols]
out_cols = [
    g_sym, g_date, g_time, "time_idx", weight_col, target_col,
    *time_features,
    *z_cols,
    *namark_cols,
]

lf_out = lf_z.select(out_cols)
lf_out = lf_out.sort([g_date, g_time, g_sym])


# 
# Key point: do not collect day by day. Collect a batch of days at a time and
# write it in one go, which greatly reduces the number of I/O operations.

# Remote output locations for the cleaned data.
tft_root = P("az", "tft")
ensure_dir_az(tft_root)

clean_dir = f"{tft_root}/clean"
ensure_dir_az(clean_dir)



In [None]:
# Number of days per written chunk.
CHUNK_DAYS = 30
day_list = [int(d) for d in all_days]
day_chunks = [
    day_list[start:start + CHUNK_DAYS]
    for start in range(0, len(day_list), CHUNK_DAYS)
]

for ci, chunk in enumerate(day_chunks, start=1):
    # Materialise only the rows belonging to this chunk's days.
    # NOTE(review): every collect() evaluates the lazy plan behind lf_out again.
    df_chunk = lf_out.filter(pl.col(g_date).is_in(chunk)).collect()
    table = df_chunk.to_arrow()

    # Directory name carries the first and last day id, zero-padded to at
    # least 4 digits, e.g. chunk_1600_1629.
    chunk_dir = f"{clean_dir}/chunk_{chunk[0]:04d}_{chunk[-1]:04d}"
    ds.write_dataset(
        data=table,
        base_dir=chunk_dir,
        filesystem=fs,
        format="parquet",
        partitioning=None,                             # no per-day partitioning
        existing_data_behavior="overwrite_or_ignore",
        basename_template="data-{i}.parquet",          # a few large files
        max_rows_per_file=50_000_000,                  # split by row count to bound file size
    )
    print(f"[{_now()}] chunk {ci}/{len(day_chunks)} -> days {chunk[0]}..{chunk[-1]} written")

print(fs.ls(clean_dir)[:5])


In [None]:
from collections import defaultdict

# ========= Pre-grouping: gather every parquet file per chunk directory =========
# 1) List the chunk directories under clean_dir, normalise them to az:// URLs
#    and sort them.
entries = fs.ls(clean_dir)
chunk_dirs = []
for e in entries:
    # An entry is either a plain path string or a dict-like record; for the
    # latter, try the "name", "path" and "Key" fields before falling back to str().
    path = e if isinstance(e, str) else (e.get("name") or e.get("path") or e.get("Key") or str(e))
    if not path.rstrip("/").rsplit("/", 1)[-1].startswith("chunk_"):
        continue
    chunk_dirs.append(path if path.startswith("az://") else f"az://{path}")
chunk_dirs.sort()


# 2) Build {chunk_dir: [all parquet files in that directory]}.
chunk2paths = defaultdict(list)
for cdir in chunk_dirs:
    paths = [
        p if p.startswith("az://") else f"az://{p}"
        for p in fs.glob(f"{cdir}/*.parquet")
    ]
    if len(paths) == 0:
        print(f"[WARN] empty chunk: {cdir}")
    chunk2paths[cdir] = paths


# 3) Flatten into all_paths (used for the lazy scans of template / validation data).
all_paths = []
for plist in chunk2paths.values():
    all_paths.extend(plist)
print(f"[prep] {len(chunk_dirs)} chunks; {len(all_paths)} files total")


# Time-varying real inputs: z-scored features, missing flags and time features
# (a provisional choice for quick validation).
unknown_reals = [*z_cols, *namark_cols, *time_features]
# Columns read for modelling: group id, time index, weight, target, then the inputs.
cols = [g_sym, "time_idx", weight_col, target_col] + unknown_reals

from pipeline.stream_input import ShardedBatchStream

# ========= Training =========
# Accumulators filled by the fold loop: best checkpoint path and best score per fold.
best_ckpt_paths, fold_metrics = [], []

# Local directories for checkpoints and TensorBoard logs.
# NOTE(review): tft_root pointed at the "az" location in the cleaning cell and
# is rebound to the "local" location here.
tft_root = P("local", "tft")
ensure_dir_local(tft_root)

ckpts_root = Path(tft_root) / "ckpts"
ensure_dir_local(ckpts_root.as_posix())

logs_root = Path(tft_root) / "logs"
ensure_dir_local(logs_root.as_posix())



# ---- Loop-invariant settings, resolved once for all folds ----
# os.cpu_count() may return None; treat that as a single core so the
# arithmetic below cannot raise TypeError.
n_workers = min(8, max(1, (os.cpu_count() or 1) - 2))

# Hyper-parameters: values under cfg["tft"] override the notebook constants
# (LR / HIDDEN / HEADS / DROPOUT). The same resolved values feed both the model
# and run_tag, so the tag always describes the model that was actually trained.
tft_cfg = cfg.get("tft", {})
lr = float(tft_cfg.get("lr", LR))
hidden_size = int(tft_cfg.get("hidden_size", HIDDEN))
heads = int(tft_cfg.get("heads", HEADS))
dropout = float(tft_cfg.get("dropout", DROPOUT))


def _load_days_pdf(days):
    """Load the rows of the given days as a pandas frame sorted by (symbol, time_idx).

    Scans ``all_paths`` lazily, keeps only ``cols``, and casts the symbol
    column to ``str`` (it is used as a categorical group id downstream).
    """
    pdf = (
        pl.scan_parquet(all_paths, storage_options=storage_options)
        .filter(pl.col(g_date).is_in([int(d) for d in days]))
        .select(cols)
        .collect(streaming=True)
        .to_pandas()
    )
    pdf[g_sym] = pdf[g_sym].astype("str")
    return pdf.sort_values([g_sym, "time_idx"])


for fold_id, (train_days, val_days) in enumerate(folds_by_day, start=1):
    print(f"[fold {fold_id}] train {train_days[0]}..{train_days[-1]} ({len(train_days)} days), "
        f"val {val_days[0]}..{val_days[-1]} ({len(val_days)} days)")

    # ---- Template: the first N training days define the encoders / scalers ----
    days_sorted = np.sort(train_days)
    TEMPLATE_DAYS = min(5, len(days_sorted))
    tmpl_days = list(map(int, days_sorted[:TEMPLATE_DAYS]))

    pdf_tmpl = _load_days_pdf(tmpl_days)
    print(f"[fold {fold_id}] template days={tmpl_days}, template shape={pdf_tmpl.shape}")

    # ---- Validation set ----
    pdf_val = _load_days_pdf(val_days)
    print(f"template {pdf_tmpl.shape}, val {pdf_val.shape}")

    # ---- TimeSeriesDataSet template ----
    # Continuous inputs are passed through unscaled (they were z-scored upstream).
    identity_scalers = {name: None for name in unknown_reals}
    template = TimeSeriesDataSet(
        pdf_tmpl,
        time_idx="time_idx",
        target=target_col,
        group_ids=[g_sym],
        weight=weight_col,
        max_encoder_length=ENC_LEN,
        max_prediction_length=PRED_LEN,
        static_categoricals=[g_sym],
        time_varying_unknown_reals=unknown_reals,
        lags=None,
        categorical_encoders={g_sym: NaNLabelEncoder(add_nan=True)},
        add_relative_time_idx=False,
        add_target_scales=False,
        add_encoder_length=False,
        allow_missing_timesteps=True,
        target_normalizer=None,
        scalers=identity_scalers,
    )

    # ---- Validation loader ----
    validation = TimeSeriesDataSet.from_dataset(template, data=pdf_val, stop_randomization=True)
    val_loader = validation.to_dataloader(
        train=False,
        batch_size=BATCH_SIZE * 2,
        num_workers=n_workers,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=4,
    )

    # ---- Training stream (reads one chunk at a time) ----
    # NOTE(review): the stream is given every chunk (chunk_dirs / chunk2paths),
    # not only the chunks covering train_days. Unless ShardedBatchStream
    # restricts the days internally (not visible here), each fold also trains
    # on its validation days and on later days — TODO confirm / filter.
    # NOTE(review): batch_size=1024 and seed=42 are literals; BATCH_SIZE and
    # cfg['seed'] are used elsewhere in this notebook — confirm this is intended.
    train_stream = ShardedBatchStream(
        template_tsd=template,
        chunk_dirs=chunk_dirs,
        chunk2paths=chunk2paths,
        g_sym=g_sym,
        batch_size=1024,
        buffer_batches=0,
        seed=42,
        cols=cols,
        print_every_chunks=1,   # report after every processed chunk
    )
    train_loader = DataLoader(
        train_stream,
        batch_size=None,                   # the stream already yields full batches
        num_workers=n_workers,
        persistent_workers=True,
        prefetch_factor=2,
        pin_memory=True,
        multiprocessing_context="spawn",   # safer with a remote filesystem + multiprocessing
    )

    # ---- Callbacks / logger / trainer ----
    ckpt_dir_fold = Path(ckpts_root) / f"fold_{fold_id}"
    ensure_dir_local(ckpt_dir_fold.as_posix())

    es_cb = EarlyStopping(monitor="val_RMSE", mode="min", patience=3)
    ckpt_cb = ModelCheckpoint(
        monitor="val_RMSE",
        mode="min",
        save_top_k=1,
        dirpath=ckpt_dir_fold.as_posix(),
        filename=f"fold{fold_id}-tft-best-{{epoch:02d}}-{{val_RMSE:.5f}}",
    )
    callbacks = [es_cb, ckpt_cb, LearningRateMonitor(logging_interval="step")]

    run_tag = (
        f"seed{cfg['seed']}"
        f"_enc{ENC_LEN}_pred{PRED_LEN}"
        f"_lr{lr}_hid{hidden_size}_hd{heads}_do{dropout}"
        f"_fold{fold_id}"
        f"_{time.strftime('%m%d-%H%M')}"
    )
    # run_tag was previously computed but never used; it now names the logger
    # version so every run gets a self-describing log directory.
    logger = TensorBoardLogger(
        save_dir=logs_root,
        name=f"tft_f{fold_id}",
        version=run_tag,
        default_hp_metric=False,
    )

    trainer = L.Trainer(
        accelerator="gpu", devices=1, precision="bf16-mixed",
        max_epochs=30,
        num_sanity_val_steps=0,
        gradient_clip_val=0.5,
        log_every_n_steps=50,
        enable_progress_bar=True,
        enable_model_summary=False,
        callbacks=callbacks,
        logger=logger,
        default_root_dir=ckpts_root,
    )

    # ---- Model and training ----
    tft = TemporalFusionTransformer.from_dataset(
        template,
        loss=MAE(),
        logging_metrics=[RMSE()],
        learning_rate=lr,
        hidden_size=hidden_size,
        attention_head_size=heads,
        dropout=dropout,
        reduce_on_plateau_patience=4,
    )

    trainer.fit(tft, train_dataloaders=train_loader, val_dataloaders=val_loader)

    # ---- Record results ----
    best_score = ckpt_cb.best_model_score
    if best_score is None:
        # Without this check float(None) would raise an opaque TypeError.
        raise RuntimeError(
            f"fold {fold_id}: no val_RMSE checkpoint score was recorded; "
            "check that validation ran and logged val_RMSE"
        )
    best_score = float(best_score)

    print("epoch_end_at   :", trainer.current_epoch)
    print("global_step    :", trainer.global_step)
    print("val_best_score :", best_score)
    print("es_stopped_ep  :", getattr(es_cb, "stopped_epoch", None))
    print("es_wait_count  :", getattr(es_cb, "wait_count", None))

    best_ckpt_paths.append(ckpt_cb.best_model_path)
    fold_metrics.append(best_score)
    # Running mean over the folds finished so far.
    cv_rmse = np.mean(fold_metrics)
    print(f"[CV] mean val_RMSE = {cv_rmse:.6f}")

print("[done] best_ckpts =", best_ckpt_paths)
