# 导入库

In [1]:
# -*- coding: utf-8 -*-
from __future__ import annotations

# ── 标准库 ──────────────────────────────────────────────────────────────────
import os
import time
from pathlib import Path
from collections import defaultdict
from datetime import datetime
import re

# ── 第三方 ──────────────────────────────────────────────────────────────────
import numpy as np
import pandas as pd
import polars as pl

import gc

import torch
import torch.backends.cudnn as cudnn
import lightning as L
import lightning.pytorch as lp
from torch.utils.data import DataLoader

from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import DeviceStatsMonitor
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer, Baseline
from pytorch_forecasting.metrics import MAE, RMSE
from pytorch_forecasting.data.encoders import NaNLabelEncoder
from pytorch_forecasting.data import TorchNormalizer, GroupNormalizer


# 你的工程工具
from pipeline.io import cfg, P, fs, storage_options, ensure_dir_local, ensure_dir_az
from pipeline.stream_input_local import ShardedBatchStream  
from pipeline.wr2 import WR2

# ---- 性能/兼容开关（仅一次）----
os.environ.setdefault("POLARS_MAX_THREADS", str(max(1, os.cpu_count() // 2)))
pl.enable_string_cache()
cudnn.benchmark = True
torch.set_float32_matmul_precision("high")

import time as _t


import warnings
warnings.filterwarnings("ignore")  # avoid printing out absolute paths

def _now() -> str:
    return _t.strftime("%Y-%m-%d %H:%M:%S")
print(f"[{_now()}] imports ok")

[2025-10-21 17:35:55] imports ok


# 定义工具函数

In [2]:
# ───────────────────────────────────────────────────────────────────────────
# 滑动窗划分
def make_sliding_cv_by_days(all_days: np.ndarray, *, n_splits: int, gap_days: int, train_to_val: int):
    all_days = np.asarray(all_days).ravel()
    K, R, G = n_splits, train_to_val, gap_days
    usable = len(all_days) - G
    if usable <= 0 or K <= 0 or R <= 0:
        return []
    V_base, rem = divmod(usable, R + K)
    if V_base <= 0:
        return []
    T = R * V_base
    v_lens = [V_base + 1 if i < rem else V_base for i in range(K)]
    folds, v_lo = [], T + G
    for V_i in v_lens:
        v_hi, tr_hi, tr_lo = v_lo + V_i, v_lo - G, v_lo - G - T
        if tr_lo < 0 or v_hi > len(all_days):
            break
        folds.append((all_days[tr_lo:tr_hi], all_days[v_lo:v_hi]))
        v_lo = v_hi
    return folds


# 初始化参数

In [3]:
start_date, end_date = (1400, 1600)  # 仅用于本次实验

# 读入筛选的所有特征列

path = Path("/mnt/data/js/exp/v1/models/tune/selected_covariant_features.txt")
filted_cols = path.read_text(encoding="utf-8").splitlines()

selected_features = pd.read_csv("/mnt/data/js/exp/v1/models/tune/feature_importance__fixed__fixed__mm_full_train__features__fs__1400-1698__cv2-g7-r4__seed42__top1000__1760906660__range830-1698__range830-1698__cv2-g7-r4__1760912739.csv")

df_cov_cols = selected_features[selected_features['feature'].isin(filted_cols)].copy()

# 我们这里重新归一化一下
df_cov_cols["mean_gain"] = (df_cov_cols['mean_gain'] / df_cov_cols['mean_gain'].sum()).astype(np.float32)
df_e_features = df_cov_cols.reset_index(drop=True)





# ========== 1) 初始化配置 ==========

# 所有列
G_SYM, G_DATE, G_TIME = cfg["keys"]          # e.g. ("symbol_id","date_id","time_id")
TARGET_COL = cfg["target"]                   # e.g. "responder_6"
WEIGHT_COL = cfg["weight"]                   # 允许为 None

TIME_FEATURES = ["time_bucket", "time_pos", "time_sin", "time_cos"]
COV_FEATURES = df_e_features['feature'].iloc[:100].tolist() # 可能含有组内常数列

STATIC_FEATURES = [c for c in COV_FEATURES if c in ["feature_09", "feature_10", "feature_11"]]

CS_RANK_FEATURES = [c for c in COV_FEATURES if c.endswith("__csrank")]

CS_R_Z_FEATURES = [c for c in COV_FEATURES if c.endswith("__cs_z") or c.endswith("__rz")]


do_z_co_cols = [c for c in COV_FEATURES if c not in STATIC_FEATURES + CS_RANK_FEATURES + CS_R_Z_FEATURES]


In [4]:
# 训练 & CV 超参
N_SPLITS     = 1
GAP_DAYS     = 0
TRAIN_TO_VAL = 8
ENC_LEN      = 50
DEC_LEN      = 1
PRED_LEN     = DEC_LEN
BATCH_SIZE   = 512
LR           = 2e-3
HIDDEN       = 64
HEADS        = 2
DROPOUT      = 0.2
MAX_EPOCHS   = 30

# 数据路径
PANEL_DIR_AZ   = P("az", cfg["paths"].get("panel_shards", "panel_shards"))

TFT_LOCAL_ROOT = P("local", "tft"); ensure_dir_local(TFT_LOCAL_ROOT)

LOCAL_CLEAN_DIR = f"{TFT_LOCAL_ROOT}/clean"; ensure_dir_local(LOCAL_CLEAN_DIR)
CKPTS_DIR = Path(TFT_LOCAL_ROOT) / "ckpts"; ensure_dir_local(CKPTS_DIR.as_posix())
LOGS_DIR  = Path(TFT_LOCAL_ROOT) / "logs";  ensure_dir_local(LOGS_DIR.as_posix())


print("[config] ready")

[config] ready


# 导入数据，添加全局time_idx

In [None]:
data_paths = fs.glob("az://jackson/js_exp/exp/v1/panel_shards/*.parquet")
data_paths =[f"az://{p}" for p in data_paths]

lf_data = (
    pl.scan_parquet(data_paths, storage_options=storage_options)
    .select([*cfg['keys'], WEIGHT_COL, TARGET_COL, *TIME_FEATURES, *COV_FEATURES])
    .filter(pl.col(G_DATE).is_between(start_date, end_date, closed="both"))
)
lf_data = lf_data.sort([G_SYM, G_DATE, G_TIME])


# 先归一化time_pos
lf_data = lf_data.with_columns(
    (pl.col("time_pos") / pl.lit(968)).cast(pl.Float32).alias("time_pos")
)

In [None]:
grid_df = (
    lf_data.select([G_DATE, G_TIME]).unique()
        .sort([G_DATE, G_TIME])
        .with_row_index("time_idx")
        .with_columns(pl.col("time_idx").cast(pl.Int64))
        .collect(streaming=True)
)

container_prefix = "az://jackson/js_exp/exp/v1/tft/panel_clean_shards"; ensure_dir_az(container_prefix)
chunk_size = 30
for lo in range(start_date, end_date + 1, chunk_size):
    hi = min(lo + chunk_size - 1, end_date)
    print(f"processing date range: {lo} ~ {hi}")
    
    lf_chunk = lf_data.filter(pl.col(G_DATE).is_between(lo, hi, closed="both"))
    
    lf_grid_chunk = (
        grid_df.lazy().filter(pl.col(G_DATE).is_between(lo, hi, closed="both"))
    )
    
    lf_joined = (
        lf_chunk.join(lf_grid_chunk, on=[G_DATE, G_TIME], how="left").sort([G_SYM, "time_idx"])
    )
    
    out_path = f"{container_prefix}/panel_clean_{lo:04d}_{hi:04d}.parquet"
    print(f"writing to: {out_path}")
    
    lf_joined.sink_parquet(
        out_path,
        storage_options=storage_options,
        compression="zstd",
    )
    
del lf_data, lf_chunk, lf_grid_chunk, lf_joined; gc.collect()
print(f"[{_now()}] all done")

# 重新导入含有time_idx的数据 & CV 划分

In [5]:
container_prefix = "az://jackson/js_exp/exp/v1/tft/panel_clean_shards"
data_paths = fs.glob(f"{container_prefix}/*.parquet")
data_paths = [f"az://{p}" for p in data_paths] 
lf_with_idx = pl.scan_parquet(data_paths, storage_options=storage_options).filter(pl.col(G_DATE).is_between(start_date, end_date, closed="both")).sort([G_SYM, "time_idx"])

In [6]:
# ==========  CV 划分 ==========
all_days = (
    lf_with_idx.select(pl.col(G_DATE)).unique().sort([G_DATE])
    .collect(streaming=True)[G_DATE].to_numpy()
)
folds_by_day = make_sliding_cv_by_days(all_days, n_splits=N_SPLITS, gap_days=GAP_DAYS, train_to_val=TRAIN_TO_VAL)

print(f"[cv] total {len(folds_by_day)} folds")

assert len(folds_by_day) > 0, "no CV folds constructed"
print(folds_by_day)

[cv] total 1 folds
[(array([1400, 1401, 1402, 1403, 1404, 1405, 1406, 1407, 1408, 1409, 1410,
       1411, 1412, 1413, 1414, 1415, 1416, 1417, 1418, 1419, 1420, 1421,
       1422, 1423, 1424, 1425, 1426, 1427, 1428, 1429, 1430, 1431, 1432,
       1433, 1434, 1435, 1436, 1437, 1438, 1439, 1440, 1441, 1442, 1443,
       1444, 1445, 1446, 1447, 1448, 1449, 1450, 1451, 1452, 1453, 1454,
       1455, 1456, 1457, 1458, 1459, 1460, 1461, 1462, 1463, 1464, 1465,
       1466, 1467, 1468, 1469, 1470, 1471, 1472, 1473, 1474, 1475, 1476,
       1477, 1478, 1479, 1480, 1481, 1482, 1483, 1484, 1485, 1486, 1487,
       1488, 1489, 1490, 1491, 1492, 1493, 1494, 1495, 1496, 1497, 1498,
       1499, 1500, 1501, 1502, 1503, 1504, 1505, 1506, 1507, 1508, 1509,
       1510, 1511, 1512, 1513, 1514, 1515, 1516, 1517, 1518, 1519, 1520,
       1521, 1522, 1523, 1524, 1525, 1526, 1527, 1528, 1529, 1530, 1531,
       1532, 1533, 1534, 1535, 1536, 1537, 1538, 1539, 1540, 1541, 1542,
       1543, 1544, 1545, 1546,

# 接下来均已第一折为例

In [7]:
fold_id = 0
# 取第一个 fold 的训练集最后一天，作为本折统计 z-score 的上界
train_lo, train_hi = folds_by_day[fold_id][0][0], folds_by_day[fold_id][0][-1]
val_lo, val_hi = folds_by_day[fold_id][1][0], folds_by_day[fold_id][1][-1]

print(f"统计标准化使用训练集的日期范围 = {train_lo} ~ {train_hi}")
print(f"验证集日期范围 = {val_lo} ~ {val_hi}")

统计标准化使用训练集的日期范围 = 1400 ~ 1575
验证集日期范围 = 1576 ~ 1598


# 数据处理

## 填补因为特征工程产生的缺失列

In [None]:
# 计算训练集的各特征量的中位数
lf_tr = lf_with_idx.filter(pl.col(G_DATE).is_between(train_lo, train_hi, closed="both"))

grp_median = (
    lf_tr
    .group_by([G_SYM, "time_pos"])
    .agg([
        pl.col(col).median().alias(f"{col}_median_st") for col in COV_FEATURES
    ])
    .sort([G_SYM, "time_pos"])
)

glb_median = (
    lf_tr
    .group_by("time_pos")
    .agg([
        pl.col(col).median().alias(f"{col}_median_t") for col in COV_FEATURES
    ])
    .sort("time_pos")
)

# 应用于本折全部数据 trian + val
lf_all = lf_with_idx.join(grp_median, on=[G_SYM, "time_pos"], how="left").join(glb_median, on=["time_pos"], how="left").sort([G_SYM, "time_idx"])

# 逐列用中位数替换缺失值
fill_exprs = []

for col in COV_FEATURES:
    fill_exprs.append(
        pl.coalesce([
            pl.col(col),
            pl.col(f"{col}_median_st"),
            pl.col(f"{col}_median_t")
        ]).alias(col)
    )

lf_all_imputed = lf_all.with_columns(fill_exprs)

# 去掉中位数列
drop_cols = [f"{col}_median_st" for col in COV_FEATURES]
drop_cols += [f"{col}_median_t" for col in COV_FEATURES]
lf_all_imputed = lf_all_imputed.drop(drop_cols)

df_null_lf_all_imputed = lf_all_imputed.select(
    pl.all().null_count()
).collect().to_pandas().T.sort_values(by=0, ascending=False)

assert df_null_lf_all_imputed.iloc[0,0] == 0, "仍有缺失值未填补完成！"

del lf_tr, grp_median, glb_median, lf_all, df_null_lf_all_imputed; gc.collect()

## 静态特征归一化

In [None]:
lf_tr_imputed = lf_all_imputed.filter(pl.col(G_DATE).is_between(train_lo, train_hi, closed="both"))

static_glb_minmax = (
    lf_tr_imputed.select(
        *[pl.col(c).min().alias(f"{c}_min") for c in STATIC_FEATURES],
        *[pl.col(c).max().alias(f"{c}_max") for c in STATIC_FEATURES],
    ).collect().to_dicts()[0]   # ← 用 to_dicts()[0]
)

eps = 1e-8

lf_all = lf_all_imputed.with_columns([
    (
        ((pl.col(c) - pl.lit(static_glb_minmax[f"{c}_min"])) /
         (pl.lit(static_glb_minmax[f"{c}_max"] - static_glb_minmax[f"{c}_min"]) + eps))
        .clip(0.0, 1.0)
    ).cast(pl.Float32).alias(c)   # ← cast/alias 作用于“整个结果”
    for c in STATIC_FEATURES
])

# 不需要 drop（你没创建 *_glb_min/_glb_max 列）
del lf_tr_imputed, static_glb_minmax; gc.collect()


## 标准化

In [None]:
# ========== Z-score ==========
# 计算 训练集 stats

lf_tr = lf_all.filter(pl.col(G_DATE).is_between(train_lo, train_hi, closed="both"))
grp_stats = (
    lf_tr
    .group_by(G_SYM)
    .agg([pl.col(c).mean().alias(f"mu_grp_{c}") for c in do_z_co_cols] +
        [pl.col(c).std().alias(f"std_grp_{c}") for c in do_z_co_cols])
).collect(streaming=True)

glb_stats = (
    lf_tr
    .select([pl.col(c).mean().alias(f"mu_glb_{c}") for c in do_z_co_cols] +
            [pl.col(c).std().alias(f"std_glb_{c}") for c in do_z_co_cols])
).collect(streaming=True)


In [None]:
import math

glb_row = glb_stats.to_dicts()[0]

# 逐日标准化本折所有数据并保存
fold_root = f"az://jackson/js_exp/exp/v1/tft/fold_{fold_id}"; ensure_dir_az(fold_root)
z_prefix = f"{fold_root}/z_shards"; ensure_dir_az(z_prefix)

z_done_cols = [f"z_{c}" for c in (do_z_co_cols)]
min_std = 1e-5
eps = 1e-8

# 预处理全局均值/方差的兜底，避免在行级判断
glb_mu = {c: glb_row[f"mu_glb_{c}"] for c in do_z_co_cols}
glb_std = {}
for c in do_z_co_cols:
    s = glb_row[f"std_glb_{c}"]
    if s is None or s <= 0:
        s = min_std
    glb_std[c] = s

for d in range(train_lo, val_hi + 1):
    lf_day = lf_all.filter(pl.col(G_DATE) == d)

    lf_day_z = lf_day.join(grp_stats.lazy(), on=G_SYM, how="left").sort([G_SYM, "time_idx"])

    # 1) 先把所有“常量”变成具名列，避免 anonymous literal
    const_exprs = []
    for c in do_z_co_cols:
        const_exprs += [
            pl.lit(glb_mu[c]).alias(f"__mu_glb_{c}"),
            pl.lit(glb_std[c] if glb_std[c] > 0 else min_std).alias(f"__std_glb_{c}"),
        ]
    lf_day_z = lf_day_z.with_columns(const_exprs)

    # 2) 计算 z，并用 clip 裁剪
    exprs = []
    for c in do_z_co_cols:
        mu_grp  = pl.col(f"mu_grp_{c}")
        std_grp = pl.col(f"std_grp_{c}")
        mu_use = pl.coalesce([mu_grp, pl.col(f"__mu_glb_{c}")]).cast(pl.Float32)

        std_tmp = pl.when(std_grp.is_null() | (std_grp <= 0))\
                    .then(pl.col(f"__std_glb_{c}"))\
                    .otherwise(std_grp)
        std_use = pl.when(std_tmp < min_std).then(min_std).otherwise(std_tmp).cast(pl.Float32)

        z = ((pl.col(c).cast(pl.Float32) - mu_use) / (std_use + eps)).clip(-3.0, 3.0).alias(f"z_{c}")
        exprs.append(z)

    lf_day_z = lf_day_z.with_columns(exprs)
    
    keep = [
        "time_idx", G_SYM, G_DATE, G_TIME, WEIGHT_COL, TARGET_COL
    ] + TIME_FEATURES + STATIC_FEATURES + CS_RANK_FEATURES + CS_R_Z_FEATURES + z_done_cols
        

    # 去重但保持顺序（避免 DuplicateError）
    keep = list(dict.fromkeys(keep))
    
    lf_out = lf_day_z.select(keep).sort([G_SYM, "time_idx"])

    out_path = f"{z_prefix}/z_{d:04d}.parquet"
    lf_out.collect(streaming=True).write_parquet(
        out_path, storage_options=storage_options, compression="zstd"
    )
    print(f"wrote z-scored data for day {d} to {out_path}")

# 模型训练

In [8]:
lp.seed_everything(42) 

z_done_cols = [f"z_{c}" for c in do_z_co_cols]

# 明确输入的各列
known_varying_categorical_cols = ["time_bucket", "gap_flag"] # gap_flag 需要后续添加

known_varying_reals_cols = ["time_pos", "time_sin", "time_cos"] + STATIC_FEATURES + CS_RANK_FEATURES + CS_R_Z_FEATURES + z_done_cols
unscaler_cols = known_varying_reals_cols

TRAIN_COLS = [G_SYM, G_DATE, G_TIME, "time_idx", WEIGHT_COL, TARGET_COL] + known_varying_categorical_cols + known_varying_reals_cols 

print(f"weight col: {WEIGHT_COL}")
print(f"target cols: {TARGET_COL}")
print(f"encode length: {ENC_LEN}, pred length: {PRED_LEN}")
print(f"input varying_reals_cols: {known_varying_reals_cols}")
print(f"input varying_categorical_cols: {known_varying_categorical_cols}")


# 导入 本折z-score 后的数据进行后续处理
fold_root = f"az://jackson/js_exp/exp/v1/tft/fold_{fold_id}"
z_prefix = f"{fold_root}/z_shards"

data_paths = fs.glob(f"{z_prefix}/*.parquet")
data_paths = [f"az://{p}" for p in data_paths]
lf = pl.scan_parquet(data_paths, storage_options=storage_options).sort([G_SYM, "time_idx"])

# 添加 gap_flag 列
lf = lf.with_columns(
    pl.col("time_idx").diff().over(G_SYM).fill_null(1).alias("time_delta")
).with_columns(
    (pl.col("time_delta") > 1).alias("gap_flag")
).drop("time_delta")

lf = lf.select(TRAIN_COLS)

df = lf.collect(streaming=True).to_pandas()
df[G_SYM] = df[G_SYM].astype(str).astype("category")

# 确保 pandas dtypes
for c in known_varying_categorical_cols:
    df[c] = df[c].astype(str).astype("category")

for c in known_varying_reals_cols:
    df[c] = df[c].astype(np.float32)

    
df[TARGET_COL] = df[TARGET_COL].astype(np.float32)
df[WEIGHT_COL] = df[WEIGHT_COL].astype(np.float32)

df.sort_values(by=[G_SYM, "time_idx"], inplace=True)


train_start_idx = df[df[G_DATE] == train_lo]["time_idx"].min()
train_end_idx   = df[df[G_DATE] == train_hi]["time_idx"].max()
val_start_idx   = df[df[G_DATE] == val_lo]["time_idx"].min()
val_end_idx     = df[df[G_DATE] == val_hi]["time_idx"].max()

df_train = df.loc[df.time_idx.between(train_start_idx, train_end_idx)].copy()

warmup_len = ENC_LEN
df_val_ctx   = df.loc[df.time_idx.between(val_start_idx-warmup_len, val_end_idx)].copy()

df_train["time_idx"] = df_train["time_idx"].astype("int64")
df_val_ctx["time_idx"]   = df_val_ctx["time_idx"].astype("int64")

Seed set to 42


weight col: weight
target cols: responder_6
encode length: 50, pred length: 1
input varying_reals_cols: ['time_pos', 'time_sin', 'time_cos', 'feature_09', 'feature_60__csrank', 'feature_26__csrank', 'feature_60__cs_z', 'feature_59__cs_z', 'z_feature_06', 'z_feature_36', 'z_feature_04', 'z_feature_59__rstd30', 'z_feature_07', 'z_feature_61__ret50', 'z_feature_31__lag4840', 'z_feature_61__lag7744', 'z_feature_61__diff50', 'z_feature_51__ewm10', 'z_feature_61__lag1936', 'z_feature_59__ewm5', 'z_feature_48', 'z_feature_61__lag900', 'z_feature_59', 'z_feature_51__rmean14', 'z_feature_31__diff50', 'z_feature_04__ewm5', 'z_feature_61__lag3872', 'z_feature_04__ewm10', 'z_feature_06__rz30', 'z_feature_60__rstd30', 'z_feature_54__ewm10', 'z_feature_37__rstd30', 'z_feature_22__diff50', 'z_feature_30__lag50', 'z_feature_01__ewm50', 'z_feature_61__lag4840', 'z_feature_61__lag2904', 'z_responder_2_prev_tail_lag10', 'z_feature_26__lag900', 'z_feature_08__ewm50', 'z_feature_20__diff50', 'z_feature_04_

In [9]:
train_ds = TimeSeriesDataSet(
    df_train,
    time_idx="time_idx",
    target=TARGET_COL,
    group_ids=[G_SYM],
    weight=WEIGHT_COL,
    min_encoder_length=ENC_LEN // 2,
    max_encoder_length=ENC_LEN,
    min_prediction_length=PRED_LEN,
    max_prediction_length=PRED_LEN,
    
    static_categoricals=[],
    static_reals=[],
    
    time_varying_unknown_categoricals = [],
    time_varying_unknown_reals = [],


    time_varying_known_categoricals=known_varying_categorical_cols,
    time_varying_known_reals=known_varying_reals_cols,
    
    target_normalizer=TorchNormalizer(method="standard"),
    
    allow_missing_timesteps=True,
    
    add_relative_time_idx=False,
    add_target_scales=False,
    add_encoder_length=False,
    
    scalers={name: None for name in (unscaler_cols)}
)


val_ds = TimeSeriesDataSet.from_dataset(
    train_ds,
    df_val_ctx,
    predict=False,
    stop_randomization=True,
)

train_loader = train_ds.to_dataloader(train=True, batch_size=BATCH_SIZE, num_workers=8)
val_loader   = val_ds.to_dataloader(train=False, batch_size=BATCH_SIZE, num_workers=8)

del df, df_train, df_val_ctx; gc.collect()
print(f"[{_now()}] data loaders ready")

[2025-10-21 17:37:15] data loaders ready


## TFT 模型

## 优化学习率

In [None]:
from lightning.pytorch.tuner import Tuner


# 1) 运行前清理缓存（可选）
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

# 2) 给 LR finder 专用的小 batch loader
#    不要动正式的 train_loader/val_loader
BATCH_LR_FIND = max(1, BATCH_SIZE // 4)  # 比正式 batch 小 4 倍
train_loader_lr = train_ds.to_dataloader(train=True,  batch_size=BATCH_LR_FIND, num_workers=4)
val_loader_lr   = val_ds.to_dataloader(  train=False, batch_size=BATCH_LR_FIND*2, num_workers=4)

# 3) 混合精度（优先 bf16，其次 fp16）
trainer_lr = lp.Trainer(
    accelerator="gpu",
    precision=32,
    gradient_clip_val=0.1,
    # 只跑很小的子集，避免显存飙升
    limit_train_batches=50,   # 只取 50 个 batch（可再小）
    limit_val_batches=10,
    enable_progress_bar=False,
)

# 4) 可选：只为 LR finder 构建一个“轻量化”的 TFT（不影响正式模型）
tft_lr = TemporalFusionTransformer.from_dataset(
    train_ds,
    learning_rate=LR,
    hidden_size=min(HIDDEN, 64),             # 降低一点隐藏维度
    attention_head_size=min(HEADS, 2),       # 降低头数
    hidden_continuous_size=min(HIDDEN//2, 32),
    dropout=DROPOUT,
    loss=RMSE(),
    logging_metrics=[],
    optimizer=torch.optim.AdamW,
    optimizer_params={"weight_decay": 1e-4},
    reduce_on_plateau_patience=3,
    # 如果你不让模型看历史 y：
    # use_encoder_target=False,
)

# 5) 运行 LR finder，限制搜索步数
res = Tuner(trainer_lr).lr_find(
    tft_lr,
    train_dataloaders=train_loader_lr,
    val_dataloaders=val_loader_lr,
    min_lr=1e-6,
    max_lr=1e-2,           # 上界别设到 10，TFT 通常 1e-3~1e-2
    num_training=150,      # 最多 150 个 step（再小也可以）
    mode="exponential",
    early_stop_threshold=4.0,
)



In [None]:

print(f"suggested learning rate: {res.suggestion()}")
fig = res.plot(show=True, suggest=True)
fig.show()

## 模型训练

In [None]:
ckpt_dir_fold = Path(CKPTS_DIR) / f"fold_{fold_id}"
ckpt_dir_fold.mkdir(parents=True, exist_ok=True)


early_stop = EarlyStopping(
    monitor="val_loss", mode="min", patience=10, min_delta=1e-5, verbose=False
)
ckpt_cb = ModelCheckpoint(
    monitor="val_loss", mode="min", save_top_k=1,
    dirpath=ckpt_dir_fold.as_posix(),
    filename=f"fold{fold_id}-tft-best-{{epoch:02d}}-{{val_loss:.5f}}",
    save_on_train_epoch_end=False,
)
lr_logger = LearningRateMonitor()
logger = TensorBoardLogger("lighting_logs")

trainer = lp.Trainer(
    fast_dev_run=1,
    max_epochs=MAX_EPOCHS,
    accelerator="gpu",
    devices=1,
    enable_model_summary=True,
    gradient_clip_val=0.1,          # 更稳
    gradient_clip_algorithm="norm",
    limit_train_batches=1.0,
    limit_val_batches=1.0,
    val_check_interval=0.5,         # 或换成 check_val_every_n_epoch=1
    num_sanity_val_steps=0,
    log_every_n_steps=200,
    callbacks=[early_stop, ckpt_cb, lr_logger],   # ← 记得把 lr_logger 放进来
    logger=logger,
    # accumulate_grad_batches=2,     # 显存紧张时再开
)

tft = TemporalFusionTransformer.from_dataset(
    train_ds,
    learning_rate=LR,
    hidden_size=HIDDEN,
    attention_head_size=HEADS,
    dropout=DROPOUT,
    hidden_continuous_size=HIDDEN // 2,
    loss=RMSE(),                            # 点预测 OK，会自动用 weight
    logging_metrics=[],                     # 你自定义的 WR2 可另加
    optimizer=torch.optim.AdamW,
    optimizer_params={"weight_decay": 1e-4},
    reduce_on_plateau_patience=3,
)

trainer.fit(tft, train_dataloaders=train_loader, val_dataloaders=val_loader)


Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                               | Type                            | Params | Mode 
------------------------------------------------------------------------------------------------
0  | loss                               | RMSE                            | 0      | train
1  | logging_metrics                    | ModuleList                      | 0      | train
2  | input_embeddings                   | MultiEmbedding                  | 26     | train
3  | prescalers                         | ModuleDict                      | 6.6 K  | train
4  | static_variable_selection          | VariableSelectionNetwork        | 0      | train
5  | encoder_variable_selection         | VariableSelectionNetwork        | 926 K  | train
6  | decoder_variable_selection         | VariableSelectionNetwork        | 926 K  | train
7  | static_context_variable_selection  | GatedResidualNetwork            | 16.8 K | train
8  | static_context_initial_hidden_lstm |

Training: |          | 0/? [00:00<?, ?it/s]

RuntimeError: value cannot be converted to type at::Half without overflow