## 导入库

In [1]:
# -*- coding: utf-8 -*-
from __future__ import annotations

# ── 标准库 ──────────────────────────────────────────────────────────────────
import os
import time
from pathlib import Path
from collections import defaultdict
from datetime import datetime

# ── 第三方 ──────────────────────────────────────────────────────────────────
import numpy as np
import pandas as pd
import polars as pl

import torch
import torch.backends.cudnn as cudnn
import lightning as L
import lightning.pytorch as lp
from torch.utils.data import DataLoader

from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import DeviceStatsMonitor
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer, Baseline
from pytorch_forecasting.metrics import MAE, RMSE
from pytorch_forecasting.data.encoders import NaNLabelEncoder
from pytorch_forecasting.data import TorchNormalizer, GroupNormalizer


# 你的工程工具
from pipeline.io import cfg, P, fs, storage_options, ensure_dir_local
from pipeline.stream_input_local import ShardedBatchStream  # 使用下方给你的新版类
from pipeline.wr2 import WR2

# ---- 性能/兼容开关（仅一次）----
os.environ.setdefault("POLARS_MAX_THREADS", str(max(1, os.cpu_count() // 2)))
pl.enable_string_cache()
cudnn.benchmark = True
torch.set_float32_matmul_precision("high")

import time as _t


import warnings
warnings.filterwarnings("ignore")  # avoid printing out absolute paths

def _now() -> str:
    return _t.strftime("%Y-%m-%d %H:%M:%S")
print(f"[{_now()}] imports ok")

[2025-10-12 08:45:32] imports ok


## 定义工具函数

In [2]:
# ───────────────────────────────────────────────────────────────────────────
# 滑动窗划分
def make_sliding_cv_by_days(all_days: np.ndarray, *, n_splits: int, gap_days: int, train_to_val: int):
    all_days = np.asarray(all_days).ravel()
    K, R, G = n_splits, train_to_val, gap_days
    usable = len(all_days) - G
    if usable <= 0 or K <= 0 or R <= 0:
        return []
    V_base, rem = divmod(usable, R + K)
    if V_base <= 0:
        return []
    T = R * V_base
    v_lens = [V_base + 1 if i < rem else V_base for i in range(K)]
    folds, v_lo = [], T + G
    for V_i in v_lens:
        v_hi, tr_hi, tr_lo = v_lo + V_i, v_lo - G, v_lo - G - T
        if tr_lo < 0 or v_hi > len(all_days):
            break
        folds.append((all_days[tr_lo:tr_hi], all_days[v_lo:v_hi]))
        v_lo = v_hi
    return folds


## 初始化参数

In [3]:


# ========== 1) 统一配置 ==========
G_SYM, G_DATE, G_TIME = cfg["keys"]          # e.g. ("symbol_id","date_id","time_id")
TARGET_COL = cfg["target"]                   # e.g. "responder_6"
WEIGHT_COL = cfg["weight"]                   # 允许为 None

# 已知时间特征（给 encoder+decoder）
TIME_FEATURES = ["time_bucket"]

# 连续特征（示例）
#BASIC_FEATURES = [f"feature_{i:02d}" for i in range(79)]
BASIC_FEATURES = [
    "feature_01",
    "feature_04",
    "feature_05",
    "feature_06",
    "feature_07",
    "feature_08",
    "feature_16",
    "feature_20",
    "feature_22",
    "feature_24",
    "feature_27",
    "feature_36",
    "feature_37",
    "feature_38",
    "feature_58",
    "feature_69",
]
BASIC_RESPONDERS = [f"responder_{i}" for i in range(0, 9)]

RAW_FEATURES = BASIC_FEATURES 

# 训练 & CV 超参
N_SPLITS     = 1
GAP_DAYS     = 0
TRAIN_TO_VAL = 8
ENC_LEN      = 256
DEC_LEN      = 1
PRED_LEN     = DEC_LEN
BATCH_SIZE   = 64
LR           = 1e-2
HIDDEN       = 128
HEADS        = 4
DROPOUT      = 0.1
MAX_EPOCHS   = 30
CHUNK_DAYS   = 30

# 数据路径
PANEL_DIR_AZ   = P("az", cfg["paths"].get("panel_shards", "panel_shards"))
TFT_LOCAL_ROOT = P("local", "tft")

# 目录准备
ensure_dir_local(TFT_LOCAL_ROOT)
LOCAL_CLEAN_DIR = f"{TFT_LOCAL_ROOT}/clean"; ensure_dir_local(LOCAL_CLEAN_DIR)
CKPTS_DIR = Path(TFT_LOCAL_ROOT) / "ckpts"; ensure_dir_local(CKPTS_DIR.as_posix())
LOGS_DIR  = Path(TFT_LOCAL_ROOT) / "logs";  ensure_dir_local(LOGS_DIR.as_posix())

print("[config] ready")


[config] ready


## 数据导入

In [4]:
#data_paths = fs.glob(f"{PANEL_DIR_AZ}/*.parquet")
#data_paths = [p if p.startswith("az://") else f"az://{p}" for p in data_paths]

data_start_date = 1500
data_end_date = 1600


data_paths = fs.glob("az://jackson/js_exp/exp/v1/clean_shards/*.parquet")
data_paths =[f"az://{p}" for p in data_paths]

lf_data = pl.scan_parquet(data_paths, storage_options=storage_options).filter(pl.col(G_DATE).is_between(data_start_date, data_end_date, closed="both"))
lf_data = lf_data.sort([G_SYM, G_DATE, G_TIME])

In [5]:
lf_data.limit(5).collect()

symbol_id,date_id,time_id,feature_00,feature_01,feature_02,feature_03,feature_04,feature_05,feature_06,feature_07,feature_08,feature_09,feature_10,feature_11,feature_12,feature_13,feature_14,feature_15,feature_16,feature_17,feature_18,feature_19,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29,feature_30,feature_31,feature_32,feature_33,…,feature_53,feature_54,feature_55,feature_56,feature_57,feature_58,feature_59,feature_60,feature_61,feature_62,feature_63,feature_64,feature_65,feature_66,feature_67,feature_68,feature_69,feature_70,feature_71,feature_72,feature_73,feature_74,feature_75,feature_76,feature_77,feature_78,weight,time_bucket,responder_0,responder_1,responder_2,responder_3,responder_4,responder_5,responder_6,responder_7,responder_8
i32,i32,i32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,…,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,u8,f32,f32,f32,f32,f32,f32,f32,f32,f32
0,1500,0,0.299048,-2.586622,0.30078,0.685828,2.150813,-0.692489,-0.159566,-0.829518,0.505951,11.0,7.0,76.0,-0.973271,2.296871,-0.123932,0.0,-0.526217,0.0,-1.656028,-1.476867,0.454376,-0.049191,0.897194,0.927829,3.011863,0.280865,0.985659,0.529341,0.195879,-0.90816,-0.999317,-0.057682,0.0,0.0,…,0.0,0.449019,0.0,-0.548544,2.261573,0.0,1.746195,0.641324,-0.644595,-0.527839,-0.293633,-0.495875,-1.291651,-1.418662,-0.738298,0.7786,-0.316974,-0.652625,2.652362,0.134199,0.0,0.0,2.72201,2.666984,0.572996,0.603494,4.426604,0,-0.433041,0.073844,-0.385438,-1.170643,0.676503,-0.609011,-1.314976,0.819128,-0.747843
0,1500,1,0.172375,-2.373858,0.088191,0.471747,2.031462,-0.838825,1.069108,-0.089751,0.346822,11.0,7.0,76.0,-0.927401,2.378906,-0.006752,0.0,-0.556284,0.0,-1.945728,-0.880582,0.454376,-0.049191,0.897194,0.927829,3.011863,0.280865,0.985659,0.529341,0.195879,-0.90816,-0.999317,-0.057682,0.0,0.0,…,0.0,0.258817,0.0,-0.856843,1.747195,0.0,1.23963,0.988123,-0.644595,-0.384864,-0.248221,-0.418614,-1.07414,-1.668083,-1.008423,0.954679,-0.163946,-0.903738,2.904016,0.338279,0.0,0.0,2.925982,2.56315,0.473738,0.788878,4.426604,0,0.024755,0.119062,0.01866,-0.633278,0.898226,-0.027628,-0.671698,0.784823,-0.031254
0,1500,2,0.515038,-1.997973,0.454253,0.389625,2.43745,-0.761205,1.138704,0.062938,0.35563,11.0,7.0,76.0,-1.125957,1.938568,0.093153,0.0,-0.635094,0.0,-2.103003,-1.594332,0.454376,-0.049191,0.897194,0.927829,3.011863,0.280865,0.985659,0.529341,0.195879,-0.90816,-0.999317,-0.057682,0.0,0.0,…,0.0,0.011891,0.0,-0.870259,2.39707,0.0,2.746681,1.320846,-0.644595,-0.529806,-0.32931,-0.456595,-1.676041,-1.45232,-0.734143,0.657185,-0.157156,-0.63778,2.749268,0.541907,0.0,0.0,2.762261,2.599435,0.740204,0.86289,4.426604,0,-0.023262,0.108246,-0.37168,-0.296779,1.057237,0.052471,-0.290113,1.000917,0.263851
0,1500,3,0.222942,-2.744012,0.244848,0.465595,2.53054,-1.040504,1.024881,0.152608,0.354988,11.0,7.0,76.0,-0.353538,1.761467,0.11914,0.0,-0.615177,0.0,-1.664223,-1.812094,0.454376,-0.049191,0.897194,0.927829,3.011863,0.280865,0.985659,0.529341,0.195879,-0.90816,-0.999317,-0.057682,0.0,0.0,…,0.0,0.195378,0.0,-1.467419,2.111593,0.0,2.352883,1.082904,-0.644595,-0.398067,-0.315122,-0.32617,-1.410547,-1.23672,-0.702807,0.563949,-0.123951,-0.759042,2.106186,0.470093,0.0,0.0,1.87707,2.036223,0.785085,0.570411,4.426604,0,-0.169049,0.016484,-0.58039,-0.285512,0.818561,0.154935,-0.338634,0.75396,0.476659
0,1500,4,0.037762,-2.040448,0.580622,0.619683,2.395972,-0.553257,1.954541,0.468781,0.328159,11.0,7.0,76.0,-0.835098,0.907326,0.108191,0.0,-0.736728,-0.39812,-2.305459,-0.923875,0.454376,-0.049191,0.897194,0.927829,3.011863,0.280865,0.985659,0.529341,0.195879,-0.90816,-0.999317,-0.057682,0.0,0.0,…,0.0,0.003471,0.0,-0.518384,1.701863,0.0,1.732435,1.062898,-0.644595,-0.330009,-0.361604,-0.541628,-1.684552,-0.913574,-0.921793,0.439233,-0.170603,-0.868025,1.830886,0.413593,0.0,0.0,1.397569,0.994369,0.757005,0.610781,4.426604,0,0.177259,-0.129031,-0.582565,-0.452665,0.911334,0.018704,-0.530085,1.117608,0.402173


## 数据处理

### 添加全局时间序列号

In [6]:
lf_grid = (
    lf_data.select([G_DATE, G_TIME]).unique()
        .sort([G_DATE, G_TIME])
        .with_row_index("time_idx")
        .with_columns(pl.col("time_idx").cast(pl.Int64))
)

NEED_COLS = list(dict.fromkeys([G_SYM, G_DATE, G_TIME, WEIGHT_COL, TARGET_COL] + TIME_FEATURES + RAW_FEATURES + BASIC_RESPONDERS))
lf_with_time_idx = (
    lf_data.join(lf_grid, on=[G_DATE, G_TIME], how="left")
        .select(["time_idx"] + NEED_COLS)
        .sort([G_SYM, "time_idx"])
)

### 简单的特征工程

*完成初步验证后就删掉，用回由LGBM选中的特征

In [7]:
# 添加每天的时间正余弦

T = 968
twopi_over_T = np.float32(2.0*np.pi) / T
ti_f = pl.col(G_TIME).cast(pl.Float32)

# 基表筛列 + 时间特征
lf_with_sincos = (
    lf_with_time_idx
    .with_columns([
        ti_f.alias("time_pos"),
        (ti_f * pl.lit(twopi_over_T, dtype=pl.Float32)).alias("_phase_"),
        ])
    .with_columns([
        pl.col("_phase_").sin().cast(pl.Float32).alias("time_sin"),
        pl.col("_phase_").cos().cast(pl.Float32).alias("time_cos"),
        ])
    .drop(["_phase_"])
).sort([G_SYM, "time_idx"])


In [8]:
lf_with_sincos.limit(5).collect()   

time_idx,symbol_id,date_id,time_id,weight,responder_6,time_bucket,feature_01,feature_04,feature_05,feature_06,feature_07,feature_08,feature_16,feature_20,feature_22,feature_24,feature_27,feature_36,feature_37,feature_38,feature_58,feature_69,responder_0,responder_1,responder_2,responder_3,responder_4,responder_5,responder_7,responder_8,time_pos,time_sin,time_cos
i64,i32,i32,i32,f32,f32,u8,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
0,0,1500,0,4.426604,-1.314976,0,-2.586622,2.150813,-0.692489,-0.159566,-0.829518,0.505951,-0.526217,0.454376,0.897194,3.011863,0.529341,-2.922452,-0.019737,0.283299,0.0,-0.316974,-0.433041,0.073844,-0.385438,-1.170643,0.676503,-0.609011,0.819128,-0.747843,0.0,0.0,1.0
1,0,1500,1,4.426604,-0.671698,0,-2.373858,2.031462,-0.838825,1.069108,-0.089751,0.346822,-0.556284,0.454376,0.897194,3.011863,0.529341,-2.794829,-0.05791,0.255081,0.0,-0.163946,0.024755,0.119062,0.01866,-0.633278,0.898226,-0.027628,0.784823,-0.031254,1.0,0.006491,0.999979
2,0,1500,2,4.426604,-0.290113,0,-1.997973,2.43745,-0.761205,1.138704,0.062938,0.35563,-0.635094,0.454376,0.897194,3.011863,0.529341,-2.443655,-0.0985,0.163584,0.0,-0.157156,-0.023262,0.108246,-0.37168,-0.296779,1.057237,0.052471,1.000917,0.263851,2.0,0.012981,0.999916
3,0,1500,3,4.426604,-0.338634,0,-2.744012,2.53054,-1.040504,1.024881,0.152608,0.354988,-0.615177,0.454376,0.897194,3.011863,0.529341,-1.752336,-0.134653,0.143469,0.0,-0.123951,-0.169049,0.016484,-0.58039,-0.285512,0.818561,0.154935,0.75396,0.476659,3.0,0.019471,0.99981
4,0,1500,4,4.426604,-0.530085,0,-2.040448,2.395972,-0.553257,1.954541,0.468781,0.328159,-0.736728,0.454376,0.897194,3.011863,0.529341,-2.193815,-0.13513,0.176539,0.0,-0.170603,0.177259,-0.129031,-0.582565,-0.452665,0.911334,0.018704,1.117608,0.402173,4.0,0.025961,0.999663


In [9]:

#先做responders日级摘要

# 日频统计聚合
aggs = []
for r in BASIC_RESPONDERS:
    aggs.extend([
        pl.col(r).min().alias(f"{r}_prevday_min"),
        pl.col(r).max().alias(f"{r}_prevday_max"),
        pl.col(r).mean().alias(f"{r}_prevday_mean"),
        pl.col(r).std(ddof=0).alias(f"{r}_prevday_std"),
        pl.col(r).first().alias(f"{r}_prevday_open"),
        pl.col(r).last().alias(f"{r}_prevday_close"),
        (pl.col(r).last() / (pl.col(r).first() + 1e-8) - 1).alias(f"{r}_prevday_chg"),
    ])

lf_daily = (
    lf_with_sincos
    .group_by([G_SYM, G_DATE])
    .agg(aggs)
    .sort([G_SYM, G_DATE])
)
lf_daily.limit(5).collect()


symbol_id,date_id,responder_0_prevday_min,responder_0_prevday_max,responder_0_prevday_mean,responder_0_prevday_std,responder_0_prevday_open,responder_0_prevday_close,responder_0_prevday_chg,responder_1_prevday_min,responder_1_prevday_max,responder_1_prevday_mean,responder_1_prevday_std,responder_1_prevday_open,responder_1_prevday_close,responder_1_prevday_chg,responder_2_prevday_min,responder_2_prevday_max,responder_2_prevday_mean,responder_2_prevday_std,responder_2_prevday_open,responder_2_prevday_close,responder_2_prevday_chg,responder_3_prevday_min,responder_3_prevday_max,responder_3_prevday_mean,responder_3_prevday_std,responder_3_prevday_open,responder_3_prevday_close,responder_3_prevday_chg,responder_4_prevday_min,responder_4_prevday_max,responder_4_prevday_mean,responder_4_prevday_std,responder_4_prevday_open,responder_4_prevday_close,responder_4_prevday_chg,responder_5_prevday_min,responder_5_prevday_max,responder_5_prevday_mean,responder_5_prevday_std,responder_5_prevday_open,responder_5_prevday_close,responder_5_prevday_chg,responder_6_prevday_min,responder_6_prevday_max,responder_6_prevday_mean,responder_6_prevday_std,responder_6_prevday_open,responder_6_prevday_close,responder_6_prevday_chg,responder_7_prevday_min,responder_7_prevday_max,responder_7_prevday_mean,responder_7_prevday_std,responder_7_prevday_open,responder_7_prevday_close,responder_7_prevday_chg,responder_8_prevday_min,responder_8_prevday_max,responder_8_prevday_mean,responder_8_prevday_std,responder_8_prevday_open,responder_8_prevday_close,responder_8_prevday_chg
i32,i32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
0,1500,-1.171982,1.638323,0.056894,0.298141,-0.433041,1.203615,-3.779448,-0.401366,1.4132,0.167174,0.279106,0.073844,0.780577,9.570622,-1.830436,1.922101,0.018848,0.360276,-0.385438,0.671207,-2.741417,-1.763377,5.0,0.116752,0.673045,-1.170643,0.040914,-1.03495,-1.092831,3.623858,0.345153,0.899939,0.676503,0.024834,-0.96329,-1.862766,2.397358,0.035964,0.454305,-0.609011,0.096066,-1.157741,-2.117248,5.0,0.115151,0.769077,-1.314976,0.020436,-1.015541,-1.263776,4.049618,0.314416,0.921207,0.819128,0.052956,-0.935351,-2.596081,4.449569,0.055386,0.6885,-0.747843,0.032687,-1.043708
0,1501,-1.882106,1.383512,-0.051139,0.302527,-0.529877,-1.033358,0.950185,-1.210065,0.788575,-0.045482,0.201285,-0.552915,-0.050895,-0.907951,-2.251715,1.340849,-0.016244,0.292461,-0.199285,-0.784148,2.934811,-1.597219,2.102368,-0.019878,0.473129,-0.989756,0.211412,-1.2136,-1.923811,1.10415,-0.125298,0.342903,0.255516,0.145771,-0.429505,-1.724785,1.525229,-0.019144,0.355175,-0.456207,0.181569,-1.397996,-2.195617,2.677674,0.001764,0.591589,-0.893377,0.269446,-1.301605,-1.505133,1.422712,-0.091674,0.369853,0.77804,0.122917,-0.842017,-2.382517,2.502072,-0.01253,0.53136,-0.416162,0.445585,-2.070701
0,1502,-2.787402,1.610666,-0.012667,0.292577,-1.050984,-0.226043,-0.784922,-0.887468,1.156111,-0.003481,0.221107,-0.405282,0.019758,-1.04875,-4.331849,3.669092,-0.005124,0.37926,-0.935808,-4.331849,3.628991,-2.506835,5.0,0.187834,0.800285,2.186259,-0.315718,-1.14441,-1.456912,4.194223,0.445855,0.910699,1.067704,-0.148497,-1.139081,-2.921699,5.0,0.060492,0.582812,1.161225,-0.769456,-1.662624,-2.804317,5.0,0.241391,0.954522,2.545646,0.137866,-0.945842,-1.104346,4.581786,0.551449,1.038392,2.04183,0.098275,-0.951869,-5.0,5.0,0.088386,0.825934,2.180924,0.298897,-0.862949
0,1503,-5.0,5.0,0.026813,1.047621,5.0,0.635051,-0.87299,-5.0,1.217449,-0.123636,1.034799,-5.0,0.405424,-1.081085,-5.0,5.0,0.055588,0.897193,-2.914867,0.750888,-1.257606,-2.191018,4.92077,0.085254,0.784194,1.657291,-0.149106,-1.08997,-1.614851,2.996498,0.16453,0.741197,1.52167,-0.059996,-1.039428,-2.210441,3.587258,0.025708,0.549387,-0.966156,0.048235,-1.049925,-2.911035,5.0,0.245749,1.308598,5.0,-0.1034,-1.02068,-1.947205,5.0,0.298861,1.139277,3.707324,-0.018738,-1.005054,-3.26426,5.0,0.225112,1.22695,5.0,-0.237162,-1.047432
0,1504,-2.384349,2.60567,-0.011362,0.279052,-0.018264,2.474159,-136.468826,-1.374218,2.588008,0.142681,0.522564,-0.205888,1.451404,-8.049498,-0.964826,1.132586,-0.000717,0.292545,-0.119845,0.523505,-5.368203,-1.805182,2.241076,-0.044319,0.459643,0.757367,-0.424547,-1.560557,-1.586859,2.381602,-0.13546,0.601128,0.308057,-0.231752,-1.752302,-1.545074,1.907666,-0.005967,0.376329,1.164267,-0.708024,-1.608129,-2.465675,3.208672,-0.044635,0.550831,0.688926,-0.69712,-2.011894,-1.773896,2.619181,-0.123631,0.640969,0.361072,-0.296074,-1.819986,-2.583882,2.728426,-0.011358,0.551581,1.020304,-1.521858,-2.491573


In [10]:
# responders ema
R_EMA_SPANS = [5,]

# 取列名（LazyFrame）
schema_cols = lf_daily.collect_schema().names()  # 或 lf_daily.schema.names()
r_cols = [c for c in schema_cols if c.startswith("responder_")]

In [11]:
r_ewm_exprs = []
for c in r_cols:
    for s in R_EMA_SPANS:
        r_ewm_exprs.append(
            pl.col(c)
            .ewm_mean(span=int(s), adjust=False, ignore_nulls=True, min_periods=1)
            .over(G_SYM)
            .cast(pl.Float32)
            .alias(f"{c}__ewm{s}")
        )

lf_daily = lf_daily.with_columns(r_ewm_exprs)


In [12]:
# 判断日频表的每个组的日期是否连续
lf_daily_with_diff =lf_daily.with_columns(
    [pl.col(G_DATE).diff().over(G_SYM).alias("date_diff"),
    ]
)

# 关键，shift为前一日
shifted_cols = [c for c in lf_daily.collect_schema().names() if c not in (G_SYM, G_DATE, "date_diff")]
lf_prevday_daily = (
    lf_daily_with_diff
    .with_columns([
        pl.when(pl.col("date_diff") <= 5)
        .then(pl.col(c).shift(1).over([G_SYM]))  # ← 按组上一天
        .otherwise(pl.lit(None))                # 超过阈值置空
        .alias(c)                               # 覆盖原列（你当前的语义）
        for c in shifted_cols
    ])
    .drop("date_diff")
)

# 回拼到 tick 级
lf_extended = (
    lf_with_sincos.join(lf_prevday_daily, on=[G_SYM, G_DATE], how="left")
).sort([G_SYM, "time_idx"])
to_drop = [c for c in BASIC_RESPONDERS if c != TARGET_COL]
lf_extended = lf_extended.drop(pl.col(to_drop))


In [13]:
pl.Config.set_tbl_cols(-1)      # 显示所有列
pl.Config.set_tbl_rows(-1)      # 显示所有行
lf_extended.select(pl.all().is_null().sum()).collect()

time_idx,symbol_id,date_id,time_id,weight,responder_6,time_bucket,feature_01,feature_04,feature_05,feature_06,feature_07,feature_08,feature_16,feature_20,feature_22,feature_24,feature_27,feature_36,feature_37,feature_38,feature_58,feature_69,time_pos,time_sin,time_cos,responder_0_prevday_min,responder_0_prevday_max,responder_0_prevday_mean,responder_0_prevday_std,responder_0_prevday_open,responder_0_prevday_close,responder_0_prevday_chg,responder_1_prevday_min,responder_1_prevday_max,responder_1_prevday_mean,responder_1_prevday_std,responder_1_prevday_open,responder_1_prevday_close,responder_1_prevday_chg,responder_2_prevday_min,responder_2_prevday_max,responder_2_prevday_mean,responder_2_prevday_std,responder_2_prevday_open,responder_2_prevday_close,responder_2_prevday_chg,responder_3_prevday_min,responder_3_prevday_max,responder_3_prevday_mean,responder_3_prevday_std,responder_3_prevday_open,responder_3_prevday_close,responder_3_prevday_chg,responder_4_prevday_min,responder_4_prevday_max,responder_4_prevday_mean,responder_4_prevday_std,responder_4_prevday_open,responder_4_prevday_close,responder_4_prevday_chg,responder_5_prevday_min,responder_5_prevday_max,responder_5_prevday_mean,responder_5_prevday_std,responder_5_prevday_open,responder_5_prevday_close,responder_5_prevday_chg,responder_6_prevday_min,responder_6_prevday_max,responder_6_prevday_mean,responder_6_prevday_std,responder_6_prevday_open,responder_6_prevday_close,responder_6_prevday_chg,responder_7_prevday_min,responder_7_prevday_max,responder_7_prevday_mean,responder_7_prevday_std,responder_7_prevday_open,responder_7_prevday_close,responder_7_prevday_chg,responder_8_prevday_min,responder_8_prevday_max,responder_8_prevday_mean,responder_8_prevday_std,responder_8_prevday_open,responder_8_prevday_close,responder_8_prevday_chg,responder_0_prevday_min__ewm5,responder_0_prevday_max__ewm5,responder_0_prevday_mean__ewm5,responder_0_prevday_std__ewm5,responder_0_prevday_open__ewm5,responder_0_prevday_close__ewm5,responder_0_prevday_chg__ewm5,responder_1_prevday_min__ewm5,responder_1_prevday_max__ewm5,responder_1_prevday_mean__ewm5,responder_1_prevday_std__ewm5,responder_1_prevday_open__ewm5,responder_1_prevday_close__ewm5,responder_1_prevday_chg__ewm5,responder_2_prevday_min__ewm5,responder_2_prevday_max__ewm5,responder_2_prevday_mean__ewm5,responder_2_prevday_std__ewm5,responder_2_prevday_open__ewm5,responder_2_prevday_close__ewm5,responder_2_prevday_chg__ewm5,responder_3_prevday_min__ewm5,responder_3_prevday_max__ewm5,responder_3_prevday_mean__ewm5,responder_3_prevday_std__ewm5,responder_3_prevday_open__ewm5,responder_3_prevday_close__ewm5,responder_3_prevday_chg__ewm5,responder_4_prevday_min__ewm5,responder_4_prevday_max__ewm5,responder_4_prevday_mean__ewm5,responder_4_prevday_std__ewm5,responder_4_prevday_open__ewm5,responder_4_prevday_close__ewm5,responder_4_prevday_chg__ewm5,responder_5_prevday_min__ewm5,responder_5_prevday_max__ewm5,responder_5_prevday_mean__ewm5,responder_5_prevday_std__ewm5,responder_5_prevday_open__ewm5,responder_5_prevday_close__ewm5,responder_5_prevday_chg__ewm5,responder_6_prevday_min__ewm5,responder_6_prevday_max__ewm5,responder_6_prevday_mean__ewm5,responder_6_prevday_std__ewm5,responder_6_prevday_open__ewm5,responder_6_prevday_close__ewm5,responder_6_prevday_chg__ewm5,responder_7_prevday_min__ewm5,responder_7_prevday_max__ewm5,responder_7_prevday_mean__ewm5,responder_7_prevday_std__ewm5,responder_7_prevday_open__ewm5,responder_7_prevday_close__ewm5,responder_7_prevday_chg__ewm5,responder_8_prevday_min__ewm5,responder_8_prevday_max__ewm5,responder_8_prevday_mean__ewm5,responder_8_prevday_std__ewm5,responder_8_prevday_open__ewm5,responder_8_prevday_close__ewm5,responder_8_prevday_chg__ewm5
u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752,37752


In [14]:
# 添加特征量的ewa

feature_ewm_exprs = []
feature_ewm_cols = []
for c in BASIC_FEATURES:
    for s in R_EMA_SPANS:
        col_name = f"{c}__ewm{s}"
        feature_ewm_cols.append(col_name)
        feature_ewm_exprs.append(
            pl.col(c)
            .ewm_mean(span=int(s), adjust=False, ignore_nulls=True, min_periods=1)
            .over([G_SYM, G_DATE])
            .cast(pl.Float32)
            .alias(col_name)
        )
lf_feature_ewm = lf_with_sincos.with_columns(feature_ewm_exprs)
lf_feature_ewm = lf_feature_ewm.drop(pl.col(to_drop))

In [15]:
lf_combined = (
    lf_extended
    .join(lf_feature_ewm.select([G_SYM, G_DATE, G_TIME] + feature_ewm_cols), on=[G_SYM, G_DATE, G_TIME], how="left")
).sort([G_SYM, "time_idx"])

In [16]:
# 过滤掉第一天的数据
lf_ready = lf_combined.filter(pl.col(G_DATE)>1500) #  去掉全是null 的第一天数据
print(f"[{_now()}] lazyframe ready")

[2025-10-12 08:45:50] lazyframe ready


In [17]:
missing_r_cols = [c for c in lf_ready.collect_schema().names() if c.startswith("responder_")]
fill_exprs = [pl.col(c).fill_null(0).cast(pl.Float32).alias(c) for c in missing_r_cols]
    
lf_ready = lf_ready.with_columns(fill_exprs)


In [18]:
pl.Config.set_tbl_cols(-1)      # 显示所有列
pl.Config.set_tbl_rows(-1)      # 显示所有行
lf_ready.select(pl.all().is_null().sum()).collect() 

time_idx,symbol_id,date_id,time_id,weight,responder_6,time_bucket,feature_01,feature_04,feature_05,feature_06,feature_07,feature_08,feature_16,feature_20,feature_22,feature_24,feature_27,feature_36,feature_37,feature_38,feature_58,feature_69,time_pos,time_sin,time_cos,responder_0_prevday_min,responder_0_prevday_max,responder_0_prevday_mean,responder_0_prevday_std,responder_0_prevday_open,responder_0_prevday_close,responder_0_prevday_chg,responder_1_prevday_min,responder_1_prevday_max,responder_1_prevday_mean,responder_1_prevday_std,responder_1_prevday_open,responder_1_prevday_close,responder_1_prevday_chg,responder_2_prevday_min,responder_2_prevday_max,responder_2_prevday_mean,responder_2_prevday_std,responder_2_prevday_open,responder_2_prevday_close,responder_2_prevday_chg,responder_3_prevday_min,responder_3_prevday_max,responder_3_prevday_mean,responder_3_prevday_std,responder_3_prevday_open,responder_3_prevday_close,responder_3_prevday_chg,responder_4_prevday_min,responder_4_prevday_max,responder_4_prevday_mean,responder_4_prevday_std,responder_4_prevday_open,responder_4_prevday_close,responder_4_prevday_chg,responder_5_prevday_min,responder_5_prevday_max,responder_5_prevday_mean,responder_5_prevday_std,responder_5_prevday_open,responder_5_prevday_close,responder_5_prevday_chg,responder_6_prevday_min,responder_6_prevday_max,responder_6_prevday_mean,responder_6_prevday_std,responder_6_prevday_open,responder_6_prevday_close,responder_6_prevday_chg,responder_7_prevday_min,responder_7_prevday_max,responder_7_prevday_mean,responder_7_prevday_std,responder_7_prevday_open,responder_7_prevday_close,responder_7_prevday_chg,responder_8_prevday_min,responder_8_prevday_max,responder_8_prevday_mean,responder_8_prevday_std,responder_8_prevday_open,responder_8_prevday_close,responder_8_prevday_chg,responder_0_prevday_min__ewm5,responder_0_prevday_max__ewm5,responder_0_prevday_mean__ewm5,responder_0_prevday_std__ewm5,responder_0_prevday_open__ewm5,responder_0_prevday_close__ewm5,responder_0_prevday_chg__ewm5,responder_1_prevday_min__ewm5,responder_1_prevday_max__ewm5,responder_1_prevday_mean__ewm5,responder_1_prevday_std__ewm5,responder_1_prevday_open__ewm5,responder_1_prevday_close__ewm5,responder_1_prevday_chg__ewm5,responder_2_prevday_min__ewm5,responder_2_prevday_max__ewm5,responder_2_prevday_mean__ewm5,responder_2_prevday_std__ewm5,responder_2_prevday_open__ewm5,responder_2_prevday_close__ewm5,responder_2_prevday_chg__ewm5,responder_3_prevday_min__ewm5,responder_3_prevday_max__ewm5,responder_3_prevday_mean__ewm5,responder_3_prevday_std__ewm5,responder_3_prevday_open__ewm5,responder_3_prevday_close__ewm5,responder_3_prevday_chg__ewm5,responder_4_prevday_min__ewm5,responder_4_prevday_max__ewm5,responder_4_prevday_mean__ewm5,responder_4_prevday_std__ewm5,responder_4_prevday_open__ewm5,responder_4_prevday_close__ewm5,responder_4_prevday_chg__ewm5,responder_5_prevday_min__ewm5,responder_5_prevday_max__ewm5,responder_5_prevday_mean__ewm5,responder_5_prevday_std__ewm5,responder_5_prevday_open__ewm5,responder_5_prevday_close__ewm5,responder_5_prevday_chg__ewm5,responder_6_prevday_min__ewm5,responder_6_prevday_max__ewm5,responder_6_prevday_mean__ewm5,responder_6_prevday_std__ewm5,responder_6_prevday_open__ewm5,responder_6_prevday_close__ewm5,responder_6_prevday_chg__ewm5,responder_7_prevday_min__ewm5,responder_7_prevday_max__ewm5,responder_7_prevday_mean__ewm5,responder_7_prevday_std__ewm5,responder_7_prevday_open__ewm5,responder_7_prevday_close__ewm5,responder_7_prevday_chg__ewm5,responder_8_prevday_min__ewm5,responder_8_prevday_max__ewm5,responder_8_prevday_mean__ewm5,responder_8_prevday_std__ewm5,responder_8_prevday_open__ewm5,responder_8_prevday_close__ewm5,responder_8_prevday_chg__ewm5,feature_01__ewm5,feature_04__ewm5,feature_05__ewm5,feature_06__ewm5,feature_07__ewm5,feature_08__ewm5,feature_16__ewm5,feature_20__ewm5,feature_22__ewm5,feature_24__ewm5,feature_27__ewm5,feature_36__ewm5,feature_37__ewm5,feature_38__ewm5,feature_58__ewm5,feature_69__ewm5
u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


## CV 划分

In [19]:
# ==========  CV 划分 ==========
all_days = (
    lf_ready.select(pl.col(G_DATE)).unique().sort(by=G_DATE)
    .collect(streaming=True).get_column(G_DATE).to_numpy()
)
folds_by_day = make_sliding_cv_by_days(all_days, n_splits=N_SPLITS, gap_days=GAP_DAYS, train_to_val=TRAIN_TO_VAL)
assert len(folds_by_day) > 0, "no CV folds constructed"

## 折内数据标准化处理

In [20]:
lf_ready.collect_schema().names()

['time_idx',
 'symbol_id',
 'date_id',
 'time_id',
 'weight',
 'responder_6',
 'time_bucket',
 'feature_01',
 'feature_04',
 'feature_05',
 'feature_06',
 'feature_07',
 'feature_08',
 'feature_16',
 'feature_20',
 'feature_22',
 'feature_24',
 'feature_27',
 'feature_36',
 'feature_37',
 'feature_38',
 'feature_58',
 'feature_69',
 'time_pos',
 'time_sin',
 'time_cos',
 'responder_0_prevday_min',
 'responder_0_prevday_max',
 'responder_0_prevday_mean',
 'responder_0_prevday_std',
 'responder_0_prevday_open',
 'responder_0_prevday_close',
 'responder_0_prevday_chg',
 'responder_1_prevday_min',
 'responder_1_prevday_max',
 'responder_1_prevday_mean',
 'responder_1_prevday_std',
 'responder_1_prevday_open',
 'responder_1_prevday_close',
 'responder_1_prevday_chg',
 'responder_2_prevday_min',
 'responder_2_prevday_max',
 'responder_2_prevday_mean',
 'responder_2_prevday_std',
 'responder_2_prevday_open',
 'responder_2_prevday_close',
 'responder_2_prevday_chg',
 'responder_3_prevday_min',

In [21]:
# 明确要标准化的列
z_cols = [c for c in lf_ready.collect_schema().names() if c not in (G_SYM, G_DATE, G_TIME, "time_idx", "time_pos", "time_sin", "time_cos", "time_bucket", WEIGHT_COL, TARGET_COL)]


# 取第一个 fold 的训练集最后一天，作为统计 z-score 的上界
stats_hi = int(folds_by_day[0][0][-1])
print(f"标准化使用训练集的上限日期 = {stats_hi}")

# ========== 4) 连续特征清洗 + Z-score ==========
inf2null_exprs  = [pl.when(pl.col(c).is_infinite()).then(None).otherwise(pl.col(c)).alias(c) for c in z_cols]
#isna_flag_exprs = [pl.col(c).is_null().cast(pl.Int8).alias(f"{c}__isna") for c in z_cols]

ffill_exprs = [pl.col(c).forward_fill().over(G_SYM).fill_null(0.0).alias(c) for c in z_cols]
lf_clean = (
    lf_ready.with_columns(inf2null_exprs).with_columns(ffill_exprs) #.with_columns(isna_flag_exprs)
)


# 开始计算 z-score
lf_stats_sym = (
    lf_clean.filter(pl.col(G_DATE) <= stats_hi)
            .group_by(G_SYM)
            .agg([pl.col(c).mean().alias(f"mu_{c}") for c in z_cols] +
                [pl.col(c).std(ddof=0).alias(f"std_{c}") for c in z_cols])
)
lf_stats_glb = (
    lf_clean.filter(pl.col(G_DATE) <= stats_hi)
            .select([pl.col(c).mean().alias(f"mu_{c}_glb") for c in z_cols] +
                    [pl.col(c).std(ddof=0).alias(f"std_{c}_glb") for c in z_cols])
)


标准化使用训练集的上限日期 = 1588


In [22]:
lf_z = lf_clean.join(lf_stats_glb, how="cross").join(lf_stats_sym, on=G_SYM, how="left").sort([G_SYM, "time_idx"])

eps = 1e-6
#Z_COLS, NAMARK_COLS = [], [f"{c}__isna" for c in RAW_FEATURES]

Z_COLS = []
for c in z_cols:
    mu_sym, std_sym = f"mu_{c}", f"std_{c}"
    mu_glb, std_glb = f"mu_{c}_glb", f"std_{c}_glb"
    mu_use, std_use = f"{c}_mu_use", f"{c}_std_use"
    z_name = f"{c}_z"

    lf_z = lf_z.with_columns(
        pl.when(pl.col(mu_sym).is_null()).then(pl.col(mu_glb)).otherwise(pl.col(mu_sym)).alias(mu_use),
        pl.when(pl.col(std_sym).is_null() | (pl.col(std_sym) == 0)).then(pl.col(std_glb)).otherwise(pl.col(std_sym)).alias(std_use),
    ).with_columns(
        ((pl.col(c) - pl.col(mu_use)) / (pl.col(std_use) + eps)).alias(z_name)
    ).drop([mu_glb, std_glb, mu_sym, std_sym, mu_use, std_use])

    Z_COLS.append(z_name)

In [23]:
OUT_COLS = [G_SYM, G_DATE, G_TIME, "time_idx", WEIGHT_COL, TARGET_COL, "time_pos", "time_sin", "time_cos", "time_bucket"] + Z_COLS 
lf_out = lf_z.select(OUT_COLS)

In [24]:
# 保存lf_out到本地
clean_path_local = P("local", "tft/panel/clean.parquet")
Path(clean_path_local).parent.mkdir(parents=True, exist_ok=True)
lf_out.collect(streaming=True).write_parquet(clean_path_local, compression="zstd")
print(f"[{_now()}] cleaned data saved to {clean_path_local}")


[2025-10-12 08:46:39] cleaned data saved to /mnt/data/js/exp/v1/tft/panel/clean.parquet


## 训练模型

In [25]:
# clean 数据导入
clean_path_local = P("local", "tft/panel/clean.parquet")
lf_clean = pl.scan_parquet(clean_path_local)
lf_clean.limit(5).collect()

symbol_id,date_id,time_id,time_idx,weight,responder_6,time_pos,time_sin,time_cos,time_bucket,feature_01_z,feature_04_z,feature_05_z,feature_06_z,feature_07_z,feature_08_z,feature_16_z,feature_20_z,feature_22_z,feature_24_z,feature_27_z,feature_36_z,feature_37_z,feature_38_z,feature_58_z,feature_69_z,responder_0_prevday_min_z,responder_0_prevday_max_z,responder_0_prevday_mean_z,responder_0_prevday_std_z,responder_0_prevday_open_z,responder_0_prevday_close_z,responder_0_prevday_chg_z,responder_1_prevday_min_z,responder_1_prevday_max_z,responder_1_prevday_mean_z,responder_1_prevday_std_z,responder_1_prevday_open_z,responder_1_prevday_close_z,responder_1_prevday_chg_z,responder_2_prevday_min_z,responder_2_prevday_max_z,responder_2_prevday_mean_z,responder_2_prevday_std_z,responder_2_prevday_open_z,responder_2_prevday_close_z,responder_2_prevday_chg_z,responder_3_prevday_min_z,responder_3_prevday_max_z,responder_3_prevday_mean_z,responder_3_prevday_std_z,responder_3_prevday_open_z,responder_3_prevday_close_z,responder_3_prevday_chg_z,responder_4_prevday_min_z,responder_4_prevday_max_z,responder_4_prevday_mean_z,responder_4_prevday_std_z,responder_4_prevday_open_z,responder_4_prevday_close_z,responder_4_prevday_chg_z,responder_5_prevday_min_z,responder_5_prevday_max_z,responder_5_prevday_mean_z,responder_5_prevday_std_z,responder_5_prevday_open_z,responder_5_prevday_close_z,responder_5_prevday_chg_z,responder_6_prevday_min_z,responder_6_prevday_max_z,responder_6_prevday_mean_z,responder_6_prevday_std_z,responder_6_prevday_open_z,responder_6_prevday_close_z,responder_6_prevday_chg_z,responder_7_prevday_min_z,responder_7_prevday_max_z,responder_7_prevday_mean_z,responder_7_prevday_std_z,responder_7_prevday_open_z,responder_7_prevday_close_z,responder_7_prevday_chg_z,responder_8_prevday_min_z,responder_8_prevday_max_z,responder_8_prevday_mean_z,responder_8_prevday_std_z,responder_8_prevday_open_z,responder_8_prevday_close_z,responder_8_prevday_chg_z,responder_0_prevday_min__ewm5_z,responder_0_prevday_max__ewm5_z,responder_0_prevday_mean__ewm5_z,responder_0_prevday_std__ewm5_z,responder_0_prevday_open__ewm5_z,responder_0_prevday_close__ewm5_z,responder_0_prevday_chg__ewm5_z,responder_1_prevday_min__ewm5_z,responder_1_prevday_max__ewm5_z,responder_1_prevday_mean__ewm5_z,responder_1_prevday_std__ewm5_z,responder_1_prevday_open__ewm5_z,responder_1_prevday_close__ewm5_z,responder_1_prevday_chg__ewm5_z,responder_2_prevday_min__ewm5_z,responder_2_prevday_max__ewm5_z,responder_2_prevday_mean__ewm5_z,responder_2_prevday_std__ewm5_z,responder_2_prevday_open__ewm5_z,responder_2_prevday_close__ewm5_z,responder_2_prevday_chg__ewm5_z,responder_3_prevday_min__ewm5_z,responder_3_prevday_max__ewm5_z,responder_3_prevday_mean__ewm5_z,responder_3_prevday_std__ewm5_z,responder_3_prevday_open__ewm5_z,responder_3_prevday_close__ewm5_z,responder_3_prevday_chg__ewm5_z,responder_4_prevday_min__ewm5_z,responder_4_prevday_max__ewm5_z,responder_4_prevday_mean__ewm5_z,responder_4_prevday_std__ewm5_z,responder_4_prevday_open__ewm5_z,responder_4_prevday_close__ewm5_z,responder_4_prevday_chg__ewm5_z,responder_5_prevday_min__ewm5_z,responder_5_prevday_max__ewm5_z,responder_5_prevday_mean__ewm5_z,responder_5_prevday_std__ewm5_z,responder_5_prevday_open__ewm5_z,responder_5_prevday_close__ewm5_z,responder_5_prevday_chg__ewm5_z,responder_6_prevday_min__ewm5_z,responder_6_prevday_max__ewm5_z,responder_6_prevday_mean__ewm5_z,responder_6_prevday_std__ewm5_z,responder_6_prevday_open__ewm5_z,responder_6_prevday_close__ewm5_z,responder_6_prevday_chg__ewm5_z,responder_7_prevday_min__ewm5_z,responder_7_prevday_max__ewm5_z,responder_7_prevday_mean__ewm5_z,responder_7_prevday_std__ewm5_z,responder_7_prevday_open__ewm5_z,responder_7_prevday_close__ewm5_z,responder_7_prevday_chg__ewm5_z,responder_8_prevday_min__ewm5_z,responder_8_prevday_max__ewm5_z,responder_8_prevday_mean__ewm5_z,responder_8_prevday_std__ewm5_z,responder_8_prevday_open__ewm5_z,responder_8_prevday_close__ewm5_z,responder_8_prevday_chg__ewm5_z,feature_01__ewm5_z,feature_04__ewm5_z,feature_05__ewm5_z,feature_06__ewm5_z,feature_07__ewm5_z,feature_08__ewm5_z,feature_16__ewm5_z,feature_20__ewm5_z,feature_22__ewm5_z,feature_24__ewm5_z,feature_27__ewm5_z,feature_36__ewm5_z,feature_37__ewm5_z,feature_38__ewm5_z,feature_58__ewm5_z,feature_69__ewm5_z
i32,i32,i32,i64,f32,f32,f32,f32,f32,u8,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
0,1501,0,968,3.968951,-0.893377,0.0,0.0,1.0,0,0.36037,2.922099,-0.277876,0.470961,-0.114962,0.176417,0.731475,1.084337,-1.514505,0.687002,-0.512623,2.951518,0.243665,0.473736,-0.20947,0.834224,0.677889,-0.232953,1.545844,-0.400375,-0.097581,1.078562,-0.027449,0.988952,0.101866,1.33037,-0.239898,0.578574,1.035723,1.038067,0.121361,-0.020703,1.268103,-0.015387,-0.27449,0.443679,-0.032835,0.081685,2.162956,0.1744,1.311636,-0.793525,-0.053723,0.116561,0.430962,1.398267,0.531677,1.566557,0.573476,0.005053,-0.223328,-0.041665,0.217608,0.088718,0.785765,-0.643146,0.09318,-0.022753,0.306412,1.706059,0.678425,0.74031,-0.865631,0.241867,-0.162584,0.454686,1.425066,1.044657,1.13246,0.498803,0.421872,0.108536,0.338878,1.015033,0.290111,0.447772,-0.686286,0.260635,-0.094769,1.316718,-0.462847,3.548594,-0.686931,-0.189506,2.466642,-0.049417,2.098807,0.266972,3.230078,-0.469843,1.069641,2.191558,2.277283,0.244191,0.018034,2.799317,-0.005885,-0.543912,1.109625,-0.094759,0.187813,3.20105,0.385604,2.143462,-1.802714,-0.10693,0.257528,0.970331,2.467541,1.174929,2.814839,1.27707,0.010928,-0.492258,-0.12023,0.386016,0.198266,1.38043,-1.389658,0.189961,-0.050142,0.603808,2.986171,1.487867,1.274851,-2.146081,0.444071,-0.337496,0.926739,2.61325,2.263304,2.073606,1.104406,0.818616,0.248765,0.666315,1.849567,0.658265,0.818544,-1.665889,0.50366,-0.176947,0.3669,3.214045,-0.285213,0.736654,-0.125655,0.179569,0.802153,1.0845,-1.515161,0.687269,-0.512673,3.347593,0.248371,0.482636,-0.215882,0.857978
0,1501,1,969,3.968951,-0.751932,1.0,0.006491,0.999979,0,0.531538,2.926264,-0.159044,2.084368,0.420909,0.226627,1.085696,1.05904,-1.367033,0.734336,-0.492465,3.116041,0.224212,0.337318,-0.20947,1.196961,0.677889,-0.232953,1.545844,-0.400375,-0.097581,1.078562,-0.027449,0.988952,0.101866,1.33037,-0.239898,0.578574,1.035723,1.038067,0.121361,-0.020703,1.268103,-0.015387,-0.27449,0.443679,-0.032835,0.081685,2.162956,0.1744,1.311636,-0.793525,-0.053723,0.116561,0.430962,1.398267,0.531677,1.566557,0.573476,0.005053,-0.223328,-0.041665,0.217608,0.088718,0.785765,-0.643146,0.09318,-0.022753,0.306412,1.706059,0.678425,0.74031,-0.865631,0.241867,-0.162584,0.454686,1.425066,1.044657,1.13246,0.498803,0.421872,0.108536,0.338878,1.015033,0.290111,0.447772,-0.686286,0.260635,-0.094769,1.316718,-0.462847,3.548594,-0.686931,-0.189506,2.466642,-0.049417,2.098807,0.266972,3.230078,-0.469843,1.069641,2.191558,2.277283,0.244191,0.018034,2.799317,-0.005885,-0.543912,1.109625,-0.094759,0.187813,3.20105,0.385604,2.143462,-1.802714,-0.10693,0.257528,0.970331,2.467541,1.174929,2.814839,1.27707,0.010928,-0.492258,-0.12023,0.386016,0.198266,1.38043,-1.389658,0.189961,-0.050142,0.603808,2.986171,1.487867,1.274851,-2.146081,0.444071,-0.337496,0.926739,2.61325,2.263304,2.073606,1.104406,0.818616,0.248765,0.666315,1.849567,0.658265,0.818544,-1.665889,0.50366,-0.176947,0.425458,3.215574,-0.244473,1.578465,0.070105,0.196616,0.931865,1.076067,-1.46598,0.703053,-0.505952,3.409801,0.241762,0.436309,-0.215882,0.982013
0,1501,2,970,3.968951,-0.793123,2.0,0.012981,0.999916,0,0.639744,2.516991,-0.147293,1.764247,0.475685,0.286951,-0.075824,1.048066,-1.304686,0.755189,-0.483549,3.356715,0.242508,0.392942,-0.20947,1.229797,0.677889,-0.232953,1.545844,-0.400375,-0.097581,1.078562,-0.027449,0.988952,0.101866,1.33037,-0.239898,0.578574,1.035723,1.038067,0.121361,-0.020703,1.268103,-0.015387,-0.27449,0.443679,-0.032835,0.081685,2.162956,0.1744,1.311636,-0.793525,-0.053723,0.116561,0.430962,1.398267,0.531677,1.566557,0.573476,0.005053,-0.223328,-0.041665,0.217608,0.088718,0.785765,-0.643146,0.09318,-0.022753,0.306412,1.706059,0.678425,0.74031,-0.865631,0.241867,-0.162584,0.454686,1.425066,1.044657,1.13246,0.498803,0.421872,0.108536,0.338878,1.015033,0.290111,0.447772,-0.686286,0.260635,-0.094769,1.316718,-0.462847,3.548594,-0.686931,-0.189506,2.466642,-0.049417,2.098807,0.266972,3.230078,-0.469843,1.069641,2.191558,2.277283,0.244191,0.018034,2.799317,-0.005885,-0.543912,1.109625,-0.094759,0.187813,3.20105,0.385604,2.143462,-1.802714,-0.10693,0.257528,0.970331,2.467541,1.174929,2.814839,1.27707,0.010928,-0.492258,-0.12023,0.386016,0.198266,1.38043,-1.389658,0.189961,-0.050142,0.603808,2.986171,1.487867,1.274851,-2.146081,0.444071,-0.337496,0.926739,2.61325,2.263304,2.073606,1.104406,0.818616,0.248765,0.666315,1.849567,0.658265,0.818544,-1.665889,0.50366,-0.176947,0.501514,3.066347,-0.213284,1.972646,0.220622,0.228462,0.593003,1.066786,-1.4124,0.72053,-0.498499,3.542273,0.243572,0.424315,-0.215882,1.075931
0,1501,3,971,3.968951,-0.293851,3.0,0.019471,0.99981,0,0.71152,2.691642,-0.065101,1.261698,0.540966,0.262779,0.439099,1.039667,-1.256896,0.771072,-0.476793,3.127824,0.200186,0.519877,-0.20947,1.62082,0.677889,-0.232953,1.545844,-0.400375,-0.097581,1.078562,-0.027449,0.988952,0.101866,1.33037,-0.239898,0.578574,1.035723,1.038067,0.121361,-0.020703,1.268103,-0.015387,-0.27449,0.443679,-0.032835,0.081685,2.162956,0.1744,1.311636,-0.793525,-0.053723,0.116561,0.430962,1.398267,0.531677,1.566557,0.573476,0.005053,-0.223328,-0.041665,0.217608,0.088718,0.785765,-0.643146,0.09318,-0.022753,0.306412,1.706059,0.678425,0.74031,-0.865631,0.241867,-0.162584,0.454686,1.425066,1.044657,1.13246,0.498803,0.421872,0.108536,0.338878,1.015033,0.290111,0.447772,-0.686286,0.260635,-0.094769,1.316718,-0.462847,3.548594,-0.686931,-0.189506,2.466642,-0.049417,2.098807,0.266972,3.230078,-0.469843,1.069641,2.191558,2.277283,0.244191,0.018034,2.799317,-0.005885,-0.543912,1.109625,-0.094759,0.187813,3.20105,0.385604,2.143462,-1.802714,-0.10693,0.257528,0.970331,2.467541,1.174929,2.814839,1.27707,0.010928,-0.492258,-0.12023,0.386016,0.198266,1.38043,-1.389658,0.189961,-0.050142,0.603808,2.986171,1.487867,1.274851,-2.146081,0.444071,-0.337496,0.926739,2.61325,2.263304,2.073606,1.104406,0.818616,0.248765,0.666315,1.849567,0.658265,0.818544,-1.665889,0.50366,-0.176947,0.576773,3.030977,-0.164314,1.973223,0.344814,0.241486,0.555655,1.057799,-1.360743,0.737478,-0.491277,3.544042,0.230399,0.459425,-0.215882,1.272251
0,1501,4,972,3.968951,-0.095761,4.0,0.025961,0.999663,0,0.887092,2.937386,-0.182443,0.543732,0.442426,0.110936,0.220615,1.03265,-1.216774,0.784338,-0.471199,2.819678,0.085479,0.320036,-0.20947,2.727845,0.677889,-0.232953,1.545844,-0.400375,-0.097581,1.078562,-0.027449,0.988952,0.101866,1.33037,-0.239898,0.578574,1.035723,1.038067,0.121361,-0.020703,1.268103,-0.015387,-0.27449,0.443679,-0.032835,0.081685,2.162956,0.1744,1.311636,-0.793525,-0.053723,0.116561,0.430962,1.398267,0.531677,1.566557,0.573476,0.005053,-0.223328,-0.041665,0.217608,0.088718,0.785765,-0.643146,0.09318,-0.022753,0.306412,1.706059,0.678425,0.74031,-0.865631,0.241867,-0.162584,0.454686,1.425066,1.044657,1.13246,0.498803,0.421872,0.108536,0.338878,1.015033,0.290111,0.447772,-0.686286,0.260635,-0.094769,1.316718,-0.462847,3.548594,-0.686931,-0.189506,2.466642,-0.049417,2.098807,0.266972,3.230078,-0.469843,1.069641,2.191558,2.277283,0.244191,0.018034,2.799317,-0.005885,-0.543912,1.109625,-0.094759,0.187813,3.20105,0.385604,2.143462,-1.802714,-0.10693,0.257528,0.970331,2.467541,1.174929,2.814839,1.27707,0.010928,-0.492258,-0.12023,0.386016,0.198266,1.38043,-1.389658,0.189961,-0.050142,0.603808,2.986171,1.487867,1.274851,-2.146081,0.444071,-0.337496,0.926739,2.61325,2.263304,2.073606,1.104406,0.818616,0.248765,0.666315,1.849567,0.658265,0.818544,-1.665889,0.50366,-0.176947,0.68701,3.097612,-0.171895,1.599002,0.391612,0.198615,0.450749,1.049469,-1.312924,0.7532,-0.484597,3.428709,0.182646,0.414966,-0.215882,1.781669


In [26]:
KNOWN_REALS = [c for c in lf_clean.collect_schema().names() if c not in (G_SYM, G_DATE, G_TIME, "time_idx", "time_bucket", WEIGHT_COL, TARGET_COL)]

KNOWN_CATEGORIES = ["time_bucket"]

UNSCALE_COLS = KNOWN_REALS

TRAIN_COLS = [c for c in lf_clean.collect_schema().names() if c not in (G_DATE, G_TIME)]

# 定义 identity scalers
identity_scalers = {name: None for name in UNSCALE_COLS}

## try

In [None]:
# 取第一折先试探一下
best_ckpt_paths, fold_metrics = [], []
fold_id = 0
train_days, val_days = folds_by_day[0]

print(f"[fold {fold_id}] train {train_days[0]}..{train_days[-1]} ({len(train_days)} days), "
    f"val {val_days[0]}..{val_days[-1]} ({len(val_days)} days)")

# 明确日期：
train_start_date = int(train_days[0])
train_end_date   = int(train_days[-1])
val_start_date   = int(val_days[0])
val_end_date     = int(val_days[-1])      


# 提取数据
date_range = (train_start_date, val_end_date)
pdf_data = (
    pl.scan_parquet(clean_path_local)
    .filter(pl.col(G_DATE).between(train_start_date, val_end_date, closed="both"))
    .collect(streaming=True)
    .to_pandas()
    .sort_values([G_SYM, "time_idx"])
)
pdf_data[G_SYM] = pdf_data[G_SYM].astype("str")   

# 明确 indexes:
train_end_idx = pdf_data.loc[pdf_data[G_DATE] == train_end_date, "time_idx"].max()
val_start_idx = pdf_data.loc[pdf_data[G_DATE] == val_start_date, "time_idx"].min()
val_end_idx   = pdf_data.loc[pdf_data[G_DATE] == val_end_date, "time_idx"].max()
assert pd.notna(train_end_idx) and pd.notna(val_start_idx) and pd.notna(val_end_idx), "train/val idx not found"
train_end_idx, val_start_idx, val_end_idx = int(train_end_idx), int(val_start_idx), int(val_end_idx)
print(f"[fold {fold_id}] train idx up to {train_end_idx}, val idx {val_start_idx}..{val_end_idx}")  



In [None]:
t_data = pdf_data[TRAIN_COLS].copy()

identity_scalers = {name: None for name in UNSCALE_COLS}
base_ds = TimeSeriesDataSet(
    t_data,
    time_idx="time_idx",
    target=TARGET_COL,
    group_ids=[G_SYM],
    weight=WEIGHT_COL,
    max_encoder_length=ENC_LEN, 
    min_encoder_length=ENC_LEN,
    max_prediction_length=PRED_LEN, 
    min_prediction_length=PRED_LEN,
    
    static_categoricals=[G_SYM],
    time_varying_known_reals =KNOWN_REALS,
    time_varying_unknown_reals=UNKNOWN_REALS,
    
    lags=None,
    categorical_encoders={G_SYM: NaNLabelEncoder(add_nan=True)},
    add_relative_time_idx=False,
    add_target_scales=False,
    add_encoder_length=False,
    allow_missing_timesteps=True,
    target_normalizer=TorchNormalizer(method="identity"),
    scalers=identity_scalers,
)

In [None]:
# 划分训练集，验证集
train_ds = base_ds.filter(
    lambda idx: (
        idx.time_idx_last <= train_end_idx
    ),
    copy=True
)

val_ds = base_ds.filter(
    lambda idx: (
        (idx.time_idx_first_prediction >= val_start_idx) &
        
        (idx.time_idx_last <= val_end_idx)
    ),
    copy=True
)

In [None]:
# 数据集加载

train_loader = train_ds.to_dataloader(
    train=True, 
    batch_size=BATCH_SIZE, 
    num_workers=14,
    pin_memory=True,
    persistent_workers=False,
    prefetch_factor=8,
)

n_train_batches = len(train_loader)
print(f"[debug] train_loader batches = {n_train_batches}")
assert n_train_batches > 0, "Empty train dataloader. Check min_prediction_idx/ENC_LEN/date windows."

val_loader = val_ds.to_dataloader(
    train=False,
    batch_size=BATCH_SIZE,
    num_workers=14,
    pin_memory=True,
    persistent_workers=False,
    prefetch_factor=8,
)

n_val_batches = len(val_loader)
print(f"[debug] val_loader batches = {n_val_batches}")
assert n_val_batches > 0, "Empty val dataloader. Check min_prediction_idx/ENC_LEN/date windows."

In [None]:
lp.seed_everything(42)
trainer = lp.Trainer(
    accelerator="gpu",
    # clipping gradients is a hyperparameter and important to prevent divergance
    # of the gradient for recurrent neural networks
    gradient_clip_val=0.1,
)

tft = TemporalFusionTransformer.from_dataset(
    train_ds,
    # not meaningful for finding the learning rate but otherwise very important
    learning_rate=LR,
    hidden_size=HIDDEN,  # most important hyperparameter apart from learning rate
    # number of attention heads. Set to up to 4 for large datasets
    attention_head_size=HEADS,
    dropout=DROPOUT,  # between 0.1 and 0.3 are good values
    hidden_continuous_size=HIDDEN // 2,  # set to <= hidden_size
    loss=RMSE(),
    optimizer=torch.optim.Adam,
    # reduce learning rate if no improvement in validation loss after x epochs
    # reduce_on_plateau_patience=1000,
)
print(f"Number of parameters in network: {tft.size() / 1e3:.1f}k")
    
# find optimal learning rate
from lightning.pytorch.tuner import Tuner

res = Tuner(trainer).lr_find(
    tft,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
    max_lr=10.0,
    min_lr=1e-6,
)

print(f"suggested learning rate: {res.suggestion()}")
fig = res.plot(show=True, suggest=True)
fig.show()

## main

In [27]:
lp.seed_everything(42) 

# ========== 8) 训练（按 CV 折） ========== 先取第一折
best_ckpt_paths, fold_metrics = [], []

#for fold_id, (train_days, val_days) in enumerate(folds_by_day, start=1):
####################################
fold_id = 0
train_days, val_days = folds_by_day[fold_id]
####################################

print(f"[fold {fold_id}] train {train_days[0]}..{train_days[-1]} ({len(train_days)} days), "
    f"val {val_days[0]}..{val_days[-1]} ({len(val_days)} days)")

# 明确日期：
train_start_date = int(train_days[0])
train_end_date   = int(train_days[-1])
val_start_date   = int(val_days[0])
val_end_date     = int(val_days[-1])      

# 提取数据
date_range = (train_start_date, val_end_date)
pdf_data = (
    lf_clean
    .filter(pl.col(G_DATE).is_between(train_start_date, val_end_date, closed="both"))
    .collect(streaming=True)
    .to_pandas()
) 

pdf_data[G_SYM] = pdf_data[G_SYM].astype("str")
pdf_data["time_bucket"] = pdf_data["time_bucket"].astype("str")
pdf_data.sort_values([G_SYM, "time_idx"], inplace=True)

Seed set to 42


[fold 0] train 1501..1588 (88 days), val 1589..1600 (12 days)


In [28]:

# 明确 indexes:
train_end_idx = pdf_data.loc[pdf_data[G_DATE] == train_end_date, "time_idx"].max()
val_start_idx = pdf_data.loc[pdf_data[G_DATE] == val_start_date, "time_idx"].min()
val_end_idx   = pdf_data.loc[pdf_data[G_DATE] == val_end_date, "time_idx"].max()
assert pd.notna(train_end_idx) and pd.notna(val_start_idx) and pd.notna(val_end_idx), "train/val idx not found"
train_end_idx, val_start_idx, val_end_idx = int(train_end_idx), int(val_start_idx), int(val_end_idx)
print(f"[fold {fold_id}] train idx up to {train_end_idx}, val idx {val_start_idx}..{val_end_idx}")  


[fold 0] train idx up to 86151, val idx 86152..97767


In [29]:
# 构建训练集 timeseries dataset

t_data = pdf_data[TRAIN_COLS]
train_df = t_data.loc[t_data["time_idx"] <= train_end_idx]

train_ds = TimeSeriesDataSet(
    train_df,
    time_idx="time_idx",
    target=TARGET_COL,
    group_ids=[G_SYM],
    weight=WEIGHT_COL,
    max_encoder_length=ENC_LEN, min_encoder_length=ENC_LEN,
    max_prediction_length=PRED_LEN, min_prediction_length=PRED_LEN,
    time_varying_known_reals =KNOWN_REALS,
    time_varying_known_categoricals=KNOWN_CATEGORIES,
    static_categoricals=[G_SYM],
    categorical_encoders={G_SYM: NaNLabelEncoder(add_nan=True),
                        "time_bucket": NaNLabelEncoder(add_nan=True),},
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
    allow_missing_timesteps=True,
    target_normalizer=GroupNormalizer(
        method="standard", groups=[G_SYM], center=True, scale_by_group=False),
    scalers=identity_scalers,
)

In [30]:

# 验证集复用 train_ds 的所有 encoders/normalizer（不泄漏）

val_df = t_data.loc[t_data["time_idx"].between(val_start_idx, val_end_idx, inclusive="both")]
val_ds = TimeSeriesDataSet.from_dataset(
    train_ds,
    val_df,
    min_prediction_idx=val_start_idx+ENC_LEN+PRED_LEN-1,
    stop_randomization=True,
    predict=True
)

In [31]:
# 数据集加载

train_loader = train_ds.to_dataloader(
    train=True, 
    batch_size=BATCH_SIZE, 
    num_workers=14,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=4,
    shuffle=True
)

n_train_batches = len(train_loader)
print(f"[debug] train_loader batches = {n_train_batches}")
assert n_train_batches > 0, "Empty train dataloader. Check min_prediction_idx/ENC_LEN/date windows."

val_loader = val_ds.to_dataloader(
    train=False,
    batch_size=BATCH_SIZE*4,
    num_workers=14,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=4,
    shuffle=True,
)

n_val_batches = len(val_loader)
print(f"[debug] val_loader batches = {n_val_batches}")
assert n_val_batches > 0, "Empty val dataloader. Check min_prediction_idx/ENC_LEN/date windows."

[debug] train_loader batches = 50847
[debug] val_loader batches = 1


In [33]:
# 8.6 callbacks/logger/trainer
ckpt_dir_fold = Path(CKPTS_DIR) / f"fold_{fold_id}"
ckpt_dir_fold.mkdir(parents=True, exist_ok=True)

callbacks = [EarlyStopping(monitor="val_loss", mode="min", patience=5, check_on_train_epoch_end=False),
            ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1, dirpath=ckpt_dir_fold.as_posix(), filename=f"fold{fold_id}-tft-best-{{epoch:02d}}-{{val_loss:.5f}}", save_on_train_epoch_end=False),
            LearningRateMonitor(logging_interval="step"),]
RUN_NAME = (f"f{fold_id}"f"_E{MAX_EPOCHS}"f"_lr{LR}"f"_bs{BATCH_SIZE}"f"_enc{ENC_LEN}_dec{DEC_LEN}"f"_{datetime.now().strftime('%Y%m%d-%H%M%S')}")
logger = TensorBoardLogger(save_dir=LOGS_DIR.as_posix(),name="tft",version=RUN_NAME,default_hp_metric=False)

trainer = lp.Trainer(max_epochs=MAX_EPOCHS,
                    accelerator="gpu",
                    precision="bf16-mixed",
                    enable_model_summary=True,
                    gradient_clip_val=1.0,
                    #fast_dev_run=True,
                    limit_train_batches=200,  # coment in for training, running valiation every 30 batches
                    log_every_n_steps=20,
                    callbacks=callbacks,
                    logger=logger,
                    accumulate_grad_batches=1
                    )

tft = TemporalFusionTransformer.from_dataset(
    train_ds,
    learning_rate=LR,
    hidden_size=HIDDEN,
    attention_head_size=HEADS,
    dropout=DROPOUT,
    hidden_continuous_size=HIDDEN // 2,
    loss=RMSE(),
    logging_metrics=[RMSE()],
    optimizer=torch.optim.AdamW,
    optimizer_params={"weight_decay": 1e-4},
    reduce_on_plateau_patience=4,
)
trainer.fit(
    tft,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
    )

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                               | Type                            | Params | Mode 
------------------------------------------------------------------------------------------------
0  | loss                               | RMSE                            | 0      | train
1  | logging_metrics                    | ModuleList                      | 0      | train
2  | input_embeddings                   | MultiEmbedding                  | 555    | train
3  | prescalers                         | ModuleDict                      | 21.1 K | train
4  | static_variable_selection          | VariableSelectionNetwork        | 78.5 K | train
5  | encoder_variable_selection         | VariableSelectionNetwork        | 5.6 M  | train
6  | decoder_variable_selection         | VariableSelectionN

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [None]:
ckpt_cb = next(cb for cb in callbacks if isinstance(cb, ModelCheckpoint))
best_path = ckpt_cb.best_model_path
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_path)

predictions = best_tft.predict(
    val_loader,
    return_y=True,
    trainer_kwargs=dict(accelerator="gpu")
)
y_pred = predictions.output
y_true, w = predictions.y

In [None]:
num = (w * (y_true - y_pred).pow(2)).sum()
den = (w * y_true.pow(2)).sum()

wr2 = 1.0 - num / (den + eps)
print(f"wr2 after training: {wr2.item():.6f}")

In [None]:
val_ds.get_parameters

In [None]:
num = (torch.square(y_true - y_pred) * w).sum()
den = (torch.square(y_true) * w).sum()  
wr2 = 1 - num / den
print(f"wr2 after training: {wr2}")