## 导入库

In [1]:
# -*- coding: utf-8 -*-
from __future__ import annotations

# ── 标准库 ──────────────────────────────────────────────────────────────────
import os
import time
from pathlib import Path
from collections import defaultdict
from datetime import datetime

# ── 第三方 ──────────────────────────────────────────────────────────────────
import numpy as np
import pandas as pd
import polars as pl

import torch
import torch.backends.cudnn as cudnn
import lightning as L
import lightning.pytorch as lp
from torch.utils.data import DataLoader

from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import DeviceStatsMonitor
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer, Baseline
from pytorch_forecasting.metrics import MAE, RMSE
from pytorch_forecasting.data.encoders import NaNLabelEncoder
from pytorch_forecasting.data import TorchNormalizer, GroupNormalizer


# 你的工程工具
from pipeline.io import cfg, P, fs, storage_options, ensure_dir_local, ensure_dir_az
from pipeline.stream_input_local import ShardedBatchStream  
from pipeline.wr2 import WR2

# ---- Performance / compatibility switches (run once) ----
# os.cpu_count() may return None when the CPU count cannot be determined;
# fall back to 1 so the integer division never operates on None.
# NOTE(review): polars reads POLARS_MAX_THREADS when its thread pool is
# initialised, so setting it after `import polars` (as this cell does) may
# have no effect in this process - confirm, and move it above the import if
# the cap is required.
os.environ.setdefault("POLARS_MAX_THREADS", str(max(1, (os.cpu_count() or 1) // 2)))
pl.enable_string_cache()
cudnn.benchmark = True
torch.set_float32_matmul_precision("high")

import time as _t
import warnings

# Silence every warning so that notebook output does not print absolute paths.
warnings.filterwarnings("ignore")


def _now() -> str:
    """Return the current local time formatted as ``YYYY-MM-DD HH:MM:SS``."""
    stamp_format = "%Y-%m-%d %H:%M:%S"
    return _t.strftime(stamp_format, _t.localtime())


print(f"[{_now()}] imports ok")

[2025-10-14 07:49:55] imports ok


## 定义工具函数

In [2]:
# ───────────────────────────────────────────────────────────────────────────
# 滑动窗划分
def make_sliding_cv_by_days(all_days: np.ndarray, *, n_splits: int, gap_days: int, train_to_val: int):
    """Build sliding-window (train_days, val_days) splits over an ordered day axis.

    The days left after reserving ``gap_days`` are divided into
    ``train_to_val + n_splits`` equal units. The training window spans
    ``train_to_val`` units and keeps that length while it slides forward; each
    validation window spans one unit, and the first ``remainder`` folds get one
    extra day. A gap of ``gap_days`` separates every training window from its
    validation window.

    Args:
        all_days: Day identifiers in chronological order (flattened to 1-D).
        n_splits: Number of folds requested.
        gap_days: Days skipped between the end of train and the start of val.
        train_to_val: Training length expressed in base validation units.

    Returns:
        A list of ``(train_days, val_days)`` array pairs. The list is empty
        when the inputs cannot produce at least one day per unit, and may be
        shorter than ``n_splits`` if a window would fall outside the data.
    """
    days = np.asarray(all_days).ravel()
    n_days = len(days)
    budget = n_days - gap_days
    if budget <= 0 or n_splits <= 0 or train_to_val <= 0:
        return []

    val_base, extra = divmod(budget, train_to_val + n_splits)
    if val_base <= 0:
        return []

    train_len = train_to_val * val_base
    splits = []
    val_start = train_len + gap_days
    for fold_no in range(n_splits):
        # Early folds absorb the division remainder, one day each.
        val_len = val_base + (1 if fold_no < extra else 0)
        val_stop = val_start + val_len
        train_stop = val_start - gap_days
        train_start = train_stop - train_len
        if train_start < 0 or val_stop > n_days:
            break
        splits.append((days[train_start:train_stop], days[val_start:val_stop]))
        val_start = val_stop
    return splits


## 初始化参数

In [3]:
# Load the table of selected features (its `feature` and `mean_gain` columns are used below).
ranking_csv_path = "/mnt/data/js/exp/v1/models/tune/feature_importance__fixed__fixed__mm_full_train__features__fs__1300-1500__cv3-g7-r4__seed42__top1000__1760299442__range1000-1600__range1000-1600__cv2-g7-r4__1760347190.csv"
df_ranking_features = pd.read_csv(ranking_csv_path)

In [58]:
# Sum of `mean_gain` over the first 200 rows
# (presumably the file is sorted by importance - TODO confirm).
df_ranking_features["mean_gain"].iloc[:200].sum()

np.float64(0.5137157718121288)

In [60]:
# Sum of `mean_gain` over the first 150 rows, for comparison with the 200-row total.
df_ranking_features['mean_gain'].iloc[:150].sum()

np.float64(0.443650384606832)

In [4]:

# Every feature name listed in the ranking file, in file order.
ls_topk_features = df_ranking_features["feature"].to_list()

# Time-encoding columns are passed through as-is and excluded from z-scoring.
TIME_FEATURES = ["time_bucket", "time_pos", "time_sin", "time_cos"]
_time_feature_names = frozenset(TIME_FEATURES)
TO_Z_INPUTS = [name for name in ls_topk_features if name not in _time_feature_names]

In [None]:
# ========== 1) Shared configuration ==========
G_SYM, G_DATE, G_TIME = cfg["keys"]  # e.g. ("symbol_id", "date_id", "time_id")
TARGET_COL = cfg["target"]           # e.g. "responder_6"
WEIGHT_COL = cfg["weight"]           # may be None

# Training & CV hyperparameters
N_SPLITS = 1
GAP_DAYS = 0
TRAIN_TO_VAL = 8
ENC_LEN = 2
DEC_LEN = 1
PRED_LEN = DEC_LEN
BATCH_SIZE = 512
LR = 1e-5
HIDDEN = 16
HEADS = 1
DROPOUT = 0.1
MAX_EPOCHS = 30

# Data locations
PANEL_DIR_AZ = P("az", cfg["paths"].get("panel_shards", "panel_shards"))

TFT_LOCAL_ROOT = P("local", "tft")
ensure_dir_local(TFT_LOCAL_ROOT)

LOCAL_CLEAN_DIR = f"{TFT_LOCAL_ROOT}/clean"
ensure_dir_local(LOCAL_CLEAN_DIR)

# Checkpoint and log directories live under the local TFT root.
CKPTS_DIR = Path(TFT_LOCAL_ROOT) / "ckpts"
LOGS_DIR = Path(TFT_LOCAL_ROOT) / "logs"
for _out_dir in (CKPTS_DIR, LOGS_DIR):
    ensure_dir_local(_out_dir.as_posix())

print("[config] ready")

[config] ready


## 数据导入

In [None]:
data_start_date = 1500
data_end_date = 1600

# The glob results are re-prefixed with the az:// scheme before scanning.
data_paths = [f"az://{p}" for p in fs.glob("az://jackson/js_exp/exp/v1/panel_shards/*.parquet")]

# Lazy scan: keys, weight, target, time encodings and the ranked features,
# restricted to the inclusive date window and ordered by symbol/date/time.
_wanted_cols = [*cfg['keys'], WEIGHT_COL, TARGET_COL, *TIME_FEATURES, *TO_Z_INPUTS]
_in_date_window = pl.col(G_DATE).is_between(data_start_date, data_end_date, closed="both")
lf_data = (
    pl.scan_parquet(data_paths, storage_options=storage_options)
    .select(_wanted_cols)
    .filter(_in_date_window)
    .sort([G_SYM, G_DATE, G_TIME])
)


## 数据处理

### 添加全局时间序列号

In [7]:
# Global time axis: one row per distinct (date, time) pair, numbered in
# chronological order as an Int64 `time_idx`.
_time_keys = [G_DATE, G_TIME]
_time_grid = lf_data.select(_time_keys).unique().sort(_time_keys)
grid_df = (
    _time_grid
    .with_row_index("time_idx")
    .with_columns(pl.col("time_idx").cast(pl.Int64))
    .collect(streaming=True)
)

In [8]:
# Attach time_idx to the panel and write it back in chunks of `chunk_size` date ids.
container_prefix = "az://jackson/js_exp/exp/v1/tft/panel_clean_shards"
ensure_dir_az(container_prefix)
chunk_size = 30
# The stop value is data_end_date + 1 because range() excludes its stop:
# without the +1, a chunk that starts exactly on data_end_date is never
# visited and the last day is silently dropped.
for lo in range(data_start_date, data_end_date + 1, chunk_size):
    hi = min(lo + chunk_size - 1, data_end_date)
    print(f"processing date range: {lo} ~ {hi}")

    lf_chunk = lf_data.filter(pl.col(G_DATE).is_between(lo, hi, closed="both"))

    # Slice of the global time grid that covers the same dates.
    lf_grid_chunk = (
        grid_df.lazy().filter(pl.col(G_DATE).is_between(lo, hi, closed="both"))
    )

    # Left join keeps every panel row and adds its time_idx.
    lf_joined = (
        lf_chunk.join(lf_grid_chunk, on=[G_DATE, G_TIME], how="left").sort([G_SYM, "time_idx"])
    )

    out_path = f"{container_prefix}/panel_clean_{lo:04d}_{hi:04d}.parquet"
    print(f"writing to: {out_path}")

    lf_joined.sink_parquet(
        out_path,
        storage_options=storage_options,
        compression="zstd",
    )
print(f"[{_now()}] all done")

processing date range: 1500 ~ 1529
writing to: az://jackson/js_exp/exp/v1/tft/panel_clean_shards/panel_clean_1500_1529.parquet
processing date range: 1530 ~ 1559
writing to: az://jackson/js_exp/exp/v1/tft/panel_clean_shards/panel_clean_1530_1559.parquet
processing date range: 1560 ~ 1589
writing to: az://jackson/js_exp/exp/v1/tft/panel_clean_shards/panel_clean_1560_1589.parquet
processing date range: 1590 ~ 1600
writing to: az://jackson/js_exp/exp/v1/tft/panel_clean_shards/panel_clean_1590_1600.parquet
[2025-10-14 07:54:03] all done


### 导入新数据

In [9]:
# Re-open the shards written above; they now carry the time_idx column.
container_prefix = "az://jackson/js_exp/exp/v1/tft/panel_clean_shards"
ensure_dir_az(container_prefix)
data_paths = [f"az://{p}" for p in fs.glob(f"{container_prefix}/*.parquet")]
_sort_keys = [G_SYM, G_DATE, G_TIME]
lf_with_idx = pl.scan_parquet(data_paths, storage_options=storage_options).sort(_sort_keys)

In [10]:
# Peek at the first 10 rows to check the columns and the appended time_idx.
lf_with_idx.limit(10).collect()

symbol_id,date_id,time_id,weight,responder_6,time_bucket,time_pos,time_sin,time_cos,feature_06,feature_36,feature_04,feature_75__rstd14,feature_60__cs_z,feature_59,responder_0_close_roll30_std,feature_59__rstd30,feature_07,feature_61__lag900,feature_60,feature_61__lag1936,responder_6_prevday_std,responder_8_prev_tail_lag10,feature_61__ret50,feature_61__lag6776,feature_25__diff50,feature_76__rstd7,feature_48,responder_5_prevday_std,feature_60__rstd30,responder_5_prevday_mean,feature_51__rmean14,responder_1_close_roll14_std,feature_37__rstd30,responder_2_close_roll30_std,feature_31__diff50,feature_58,…,feature_12__lag2904,feature_67__lag4840,feature_70__lag6776,feature_21__lag968,feature_33__lag3872,feature_64__rstd30,feature_47__lag500,feature_67__lag7744,feature_75,responder_4_prev_tail_lag1,feature_72__lag50,feature_04__lag5808,feature_69__lag3872,feature_67__lag5808,feature_70__diff50,feature_01__lag3872,feature_27__rstd3,responder_7_same_t_last10_slope,feature_33__lag4840,feature_69__lag968,feature_57__diff50,feature_69__lag2904,feature_29__rz14,feature_27__diff3,feature_65__lag50,feature_12__ret50,feature_00__lag5808,feature_07__lag4840,responder_4_same_t_last10_slope,feature_42__lag100,feature_65__lag968,responder_5_prevday_close,responder_1_same_t_prev3,responder_2_prevday_close,feature_30__rstd14,responder_8_prev_tail_lag1,time_idx
i32,i32,i32,f32,f32,u8,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,…,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,i64
0,1500,0,4.426604,-1.314976,0,0.0,0.0,1.0,-0.159566,-2.922452,2.150813,0.649587,-0.850858,1.746195,0.908512,0.517499,-0.829518,-1.370043,0.641324,-1.257957,0.821905,-0.156409,-0.172714,0.898012,-0.224687,0.665439,1.725063,0.434996,0.832047,-0.01,-0.181106,0.580871,0.186617,1.597364,-0.026319,-1.144345,…,-0.665987,-0.39041,-1.242651,-0.050829,-1.686733,0.084058,-0.01854,-0.803619,2.612922,0.165235,-0.195847,-0.586194,0.25757,-0.84916,-0.659239,2.059735,0.000669,0.033831,-0.661166,-0.461735,2.671528,-0.460689,-0.000516,0.050152,0.564481,-25.604591,-0.473339,-2.192881,-0.053192,-1.107264,-1.738099,0.047745,0.186933,-1.005825,0.0,0.301072,0
0,1500,1,4.426604,-0.671698,0,1.0,0.006491,0.999979,1.069108,-2.794829,2.031462,0.836284,1.438851,1.23963,0.908512,0.65714,-0.089751,-1.370043,0.988123,-1.257957,0.821905,-0.156409,-0.188648,0.9046,-0.224687,0.759381,1.759171,0.434996,0.841773,-0.01,-0.07661,0.580871,0.227577,1.597364,-0.026319,-1.144345,…,-1.234799,-0.530497,-0.762113,-0.050829,-1.686733,0.084335,0.029866,-0.490434,2.648023,0.165235,-0.197776,-0.426806,0.752326,-0.7315,-0.906033,2.097947,0.028961,0.020192,-0.661166,-0.548755,1.858275,-0.358871,-3.2803,0.053028,0.827406,-13.701934,-1.003718,-1.360826,-0.04475,-0.782763,-1.273083,0.047745,-0.26652,-1.005825,0.011843,0.301072,1
0,1500,2,4.426604,-0.290113,0,2.0,0.012981,0.999916,1.138704,-2.443655,2.43745,0.953172,1.20742,2.128175,0.908512,0.721343,0.062938,-1.370043,1.320846,-1.257957,0.821905,-0.156409,-0.203366,0.911017,-0.224687,0.779975,1.77037,0.434996,0.868895,-0.01,0.06524,0.580871,0.258462,1.597364,-0.026319,-1.144345,…,-1.222237,-0.640656,-1.536348,-0.050829,-1.686733,0.084741,0.071782,-0.863921,2.688747,0.165235,-0.120331,-0.701958,0.650036,-0.79736,-0.655329,2.145121,0.029824,0.034855,-0.661166,-0.472839,2.921851,-0.397401,-2.28656,0.05582,0.347862,-16.002048,-0.64903,-1.178232,0.055282,-0.402001,-1.885703,0.047745,-0.275355,-1.005825,0.016092,0.301072,2
0,1500,3,4.426604,-0.338634,0,3.0,0.019471,0.99981,1.024881,-1.752336,2.53054,1.020966,1.261943,2.167717,0.908512,0.848355,0.152608,-1.370043,1.082904,-1.257957,0.821905,-0.156409,-0.217116,0.917273,-0.224687,0.786966,1.202953,0.434996,0.933084,-0.01,0.177734,0.580871,0.275996,1.597364,-0.026319,-1.144345,…,-0.552355,-0.420501,-1.087695,-0.050829,-1.686733,0.08198,0.075735,-0.724895,1.87707,0.165235,0.084185,-0.854242,0.853553,-0.628053,-0.861989,2.174752,0.002858,-0.052003,-0.661166,-0.430539,2.603309,-0.275922,-1.802911,0.008386,0.078301,-5.068503,-0.565976,-0.827309,-0.091356,-0.443677,-1.864025,0.047745,-0.336798,-1.005825,0.01887,0.301072,3
0,1500,4,4.426604,-0.530085,0,4.0,0.025961,0.999663,1.954541,-2.193815,2.395972,1.006895,0.862566,1.732435,0.908512,0.956367,0.468781,-1.370043,1.062898,-1.257957,0.821905,-0.156409,-0.23007,0.92338,-0.224687,0.70098,1.503926,0.434996,0.976965,-0.01,0.301686,0.580871,0.290789,1.597364,-0.026319,-1.144345,…,-0.8053,-0.526361,-0.917775,-0.050829,-1.686733,0.081711,0.085322,-0.837613,1.397569,0.165235,0.024061,-1.390298,0.704476,-0.919837,-0.93792,2.203427,0.002768,-0.05121,-0.661166,-0.445658,2.298639,-0.501758,-1.494783,0.008158,0.179187,-8.762663,-0.335041,-0.300222,-0.049125,-0.331161,-1.575063,0.047745,-0.254777,-1.005825,0.020776,0.301072,4
0,1500,5,4.426604,-0.523631,0,5.0,0.032449,0.999473,0.468671,-1.852185,2.583213,0.963734,0.692108,0.595211,0.908512,1.006524,0.264905,-1.370043,0.967309,-1.257957,0.821905,-0.156409,-0.24235,0.92935,-0.224687,0.777294,0.468923,0.434996,0.998191,-0.01,0.280917,0.580871,0.299506,1.597364,-0.026319,-1.144345,…,-0.689993,-0.419114,-1.344129,-0.050829,-1.686733,0.087196,0.188119,-0.593926,1.286738,0.165235,0.086693,-1.501489,0.681864,-0.95835,-0.933405,2.21269,0.002686,-0.022412,-0.661166,-0.477593,2.944041,-0.30584,-1.271117,0.007949,0.100722,-8.030293,-0.386803,-0.174156,-0.055867,-0.138473,-1.542874,0.047745,-0.482733,-1.005825,0.022037,0.301072,5
0,1500,6,4.426604,-0.56909,0,6.0,0.038936,0.999242,-0.078664,-1.991771,3.055325,0.898562,0.7424,0.231122,0.908512,1.007153,-0.005691,-1.370043,0.992348,-1.257957,0.821905,-0.156409,-0.254052,0.93519,-0.224687,0.763281,0.687261,0.434996,1.040608,-0.01,0.232937,0.580871,0.315011,1.597364,-0.026319,-1.144345,…,-0.809958,-0.295684,-0.737493,-0.050829,-1.686733,0.087121,0.232508,-0.928018,0.969025,0.165235,0.041407,-1.753176,0.946623,-0.902968,-0.586024,2.252421,0.002612,0.011225,-0.661166,-0.446725,1.62205,-0.207938,-1.09522,0.007752,0.390485,-8.977188,-0.426779,0.138501,0.041062,-0.145896,-1.682145,0.047745,0.0823,-1.005825,0.022759,0.301072,6
0,1500,7,4.426604,-0.729039,0,7.0,0.045421,0.998968,0.10524,-1.826996,2.184363,0.820525,0.790378,0.00997,0.908512,1.007391,0.099631,-1.370043,0.55643,-1.257957,0.821905,-0.156409,-0.265253,0.940908,-0.224687,0.821404,0.64151,0.434996,1.07784,-0.01,0.252241,0.580871,0.303542,1.597364,-0.026319,-1.144345,…,-0.923364,-0.409874,-0.965199,-0.050829,-1.686733,0.089139,0.114082,-0.596197,0.425887,0.165235,0.051658,-1.928109,1.060069,-0.798959,-0.629064,2.285988,0.002549,0.057432,-0.661166,-0.529096,2.309217,-0.141837,-0.948937,0.00757,1.109996,-6.574678,-0.868704,0.113869,0.13511,-0.948322,-1.574444,0.047745,-0.035614,-1.005825,0.022994,0.301072,7
0,1500,8,4.426604,-1.032072,0,8.0,0.051904,0.998652,-0.864718,-1.532152,2.492193,0.796218,0.390078,-0.010016,0.908512,1.005947,-0.286703,-1.370043,0.455362,-1.257957,0.821905,-0.156409,-0.27601,0.946512,-0.224687,0.900719,0.026692,0.434996,1.093246,-0.01,0.215117,0.580871,0.307478,1.597364,-0.026319,-1.144345,…,-0.821925,-0.329315,-1.005486,-0.050829,-1.686733,0.088686,0.080459,-0.823782,0.278089,0.165235,0.046512,-1.890853,1.151143,-0.588894,-0.576537,2.328777,0.00249,0.05335,-0.661166,-0.483563,1.832548,-0.109526,-0.821592,0.007398,0.43348,-12.403069,-0.34293,0.102463,0.089992,-0.154216,-1.617033,0.047745,-0.251343,-1.005825,0.022764,0.301072,8
0,1500,9,4.426604,-0.609382,0,9.0,0.058385,0.998294,0.573761,-1.496487,2.304881,0.840634,0.452221,0.011837,0.908512,1.003755,0.207682,-1.370043,0.542987,-1.257957,0.821905,-0.156409,-0.286369,0.952008,-0.224687,0.891581,-0.013964,0.434996,1.103102,-0.01,0.119246,0.580871,0.307522,1.597364,-0.026319,-1.144345,…,-0.758572,-0.375304,-0.763938,-0.050829,-1.686733,0.090107,-0.016491,-0.794234,0.120118,0.165235,0.011739,-2.448828,0.658324,-1.012865,-0.7372,2.357362,0.002435,0.0072,-0.661166,-0.37065,2.349816,-0.044383,-0.706429,0.007237,-0.352614,-5.183767,-0.508246,-0.043192,0.094173,-0.396566,-1.183666,0.047745,-0.18378,-1.005825,0.022042,0.301072,9


In [11]:
# Columns to standardise: requested inputs that actually exist in the shards.
# The schema is resolved once rather than once per candidate column.
_available_cols = set(lf_with_idx.collect_schema().names())
z_cols = [c for c in TO_Z_INPUTS if c in _available_cols]

# ========== Continuous-feature cleaning ==========
# Step 1: turn every non-finite value into null. is_finite() is false for
# +/-inf and for NaN; checking only is_infinite() lets NaN through, and NaN
# is not null, so forward_fill/fill_null below would never repair it and it
# would poison the mean/std statistics computed later.
inf2null_exprs = [
    pl.when(pl.col(c).is_finite()).then(pl.col(c)).otherwise(None).alias(c)
    for c in z_cols
]
# Step 2: forward-fill within each symbol, then zero-fill whatever is still
# null (e.g. leading gaps). Relies on lf_with_idx being sorted by symbol/date/time.
ffill_exprs = [pl.col(c).forward_fill().over(G_SYM).fill_null(0.0).alias(c) for c in z_cols]

lf_clean = lf_with_idx.with_columns(inf2null_exprs).with_columns(ffill_exprs)

In [26]:
# Count remaining nulls per column after cleaning.
# NOTE: is_null() does not flag floating-point NaN, so a zero here says nothing about NaN.
lf_clean.select(pl.all().is_null().sum()).collect()

symbol_id,date_id,time_id,weight,responder_6,time_bucket,time_pos,time_sin,time_cos,feature_06,feature_36,feature_04,feature_75__rstd14,feature_60__cs_z,feature_59,responder_0_close_roll30_std,feature_59__rstd30,feature_07,feature_61__lag900,feature_60,feature_61__lag1936,responder_6_prevday_std,responder_8_prev_tail_lag10,feature_61__ret50,feature_61__lag6776,feature_25__diff50,feature_76__rstd7,feature_48,responder_5_prevday_std,feature_60__rstd30,responder_5_prevday_mean,feature_51__rmean14,responder_1_close_roll14_std,feature_37__rstd30,responder_2_close_roll30_std,feature_31__diff50,feature_58,…,feature_12__lag2904,feature_67__lag4840,feature_70__lag6776,feature_21__lag968,feature_33__lag3872,feature_64__rstd30,feature_47__lag500,feature_67__lag7744,feature_75,responder_4_prev_tail_lag1,feature_72__lag50,feature_04__lag5808,feature_69__lag3872,feature_67__lag5808,feature_70__diff50,feature_01__lag3872,feature_27__rstd3,responder_7_same_t_last10_slope,feature_33__lag4840,feature_69__lag968,feature_57__diff50,feature_69__lag2904,feature_29__rz14,feature_27__diff3,feature_65__lag50,feature_12__ret50,feature_00__lag5808,feature_07__lag4840,responder_4_same_t_last10_slope,feature_42__lag100,feature_65__lag968,responder_5_prevday_close,responder_1_same_t_prev3,responder_2_prevday_close,feature_30__rstd14,responder_8_prev_tail_lag1,time_idx
u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,…,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


## CV 划分

In [12]:
# ========== CV split ==========
# Distinct day ids present in the cleaned data, ascending.
_unique_days = lf_clean.select(pl.col(G_DATE)).unique().sort([G_DATE]).collect(streaming=True)
all_days = _unique_days[G_DATE].to_numpy()

folds_by_day = make_sliding_cv_by_days(
    all_days,
    n_splits=N_SPLITS,
    gap_days=GAP_DAYS,
    train_to_val=TRAIN_TO_VAL,
)
assert len(folds_by_day) > 0, "no CV folds constructed"

## 数据标准化处理

In [14]:
fold_id = 0
# The last training day of the selected fold is the upper bound for the
# rows that feed the z-score statistics.
stats_hi = int(folds_by_day[fold_id][0][-1])
print(f"标准化使用训练集的上限日期 = {stats_hi}")

# ========== Z-score ==========
_lf_fit = lf_clean.filter(pl.col(G_DATE) <= stats_hi)

# Per-symbol mean and population std (ddof=0) of every column to standardise.
_sym_aggs = [pl.col(c).mean().alias(f"mu_{c}") for c in z_cols]
_sym_aggs += [pl.col(c).std(ddof=0).alias(f"std_{c}") for c in z_cols]
lf_stats_sym = _lf_fit.group_by(G_SYM).agg(_sym_aggs).collect(streaming=True)

# The same statistics over all symbols together (single-row frame).
_glb_aggs = [pl.col(c).mean().alias(f"mu_{c}_glb") for c in z_cols]
_glb_aggs += [pl.col(c).std(ddof=0).alias(f"std_{c}_glb") for c in z_cols]
lf_stats_glb = _lf_fit.select(_glb_aggs).collect(streaming=True)


标准化使用训练集的上限日期 = 1583


In [15]:
# Standardise day by day and write one z-score shard per day.

z_prefix = "az://jackson/js_exp/exp/v1/tft/z_shards"
ensure_dir_az(z_prefix)

eps = 1e-6


def _z_expr(c: str) -> pl.Expr:
    """Return the expression computing ``z_<c>`` for feature column ``c``.

    Uses the per-symbol mean/std when available and falls back to the global
    statistics when the per-symbol value is null (or, for std, exactly zero).
    ``eps`` is added to the denominator to avoid division by zero.
    """
    mu_sym, std_sym = pl.col(f"mu_{c}"), pl.col(f"std_{c}")
    mu_glb, std_glb = pl.col(f"mu_{c}_glb"), pl.col(f"std_{c}_glb")
    mu_use = pl.when(mu_sym.is_null()).then(mu_glb).otherwise(mu_sym)
    std_use = pl.when(std_sym.is_null() | (std_sym == 0)).then(std_glb).otherwise(std_sym)
    return ((pl.col(c) - mu_use) / (std_use + eps)).alias(f"z_{c}")


# Everything below is identical for every day, so it is built once instead
# of once per day (and per feature) inside the loop.
z_exprs = [_z_expr(c) for c in z_cols]
_clean_names = set(lf_clean.collect_schema().names())
_base_cols = [G_SYM, G_DATE, G_TIME, "time_idx", WEIGHT_COL, TARGET_COL, *TIME_FEATURES]
# Base columns missing from the data (e.g. a None weight column) are skipped.
out_cols = [c for c in _base_cols if c in _clean_names] + [f"z_{c}" for c in z_cols]
_stats_glb_lf = lf_stats_glb.lazy()
_stats_sym_lf = lf_stats_sym.lazy()

for d in all_days:
    print(f"processing date: {d}")

    lf_day_z = (
        lf_clean.filter(pl.col(G_DATE) == d)
        .join(_stats_glb_lf, how="cross")            # broadcast the global stats row
        .join(_stats_sym_lf, on=G_SYM, how="left")   # attach per-symbol stats
        .with_columns(z_exprs)                       # all z-scores in one pass
        .select(out_cols)                            # drops the helper mu_/std_ columns
        .sort([G_SYM, "time_idx"])
    )

    # Write the shard for this day.
    out_path = f"{z_prefix}/z_shard_{d:04d}.parquet"

    lf_day_z.collect(streaming=True).write_parquet(
        out_path,
        storage_options=storage_options,
        compression="zstd")
    


processing date: 1500
processing date: 1501
processing date: 1502
processing date: 1503
processing date: 1504
processing date: 1505
processing date: 1506
processing date: 1507
processing date: 1508
processing date: 1509
processing date: 1510
processing date: 1511
processing date: 1512
processing date: 1513
processing date: 1514
processing date: 1515
processing date: 1516
processing date: 1517
processing date: 1518
processing date: 1519
processing date: 1520
processing date: 1521
processing date: 1522
processing date: 1523
processing date: 1524
processing date: 1525
processing date: 1526
processing date: 1527
processing date: 1528
processing date: 1529
processing date: 1530
processing date: 1531
processing date: 1532
processing date: 1533
processing date: 1534
processing date: 1535
processing date: 1536
processing date: 1537
processing date: 1538
processing date: 1539
processing date: 1540
processing date: 1541
processing date: 1542
processing date: 1543
processing date: 1544
processing

In [16]:
# Date span covered by the training part of the selected fold.
_fold_train_days = folds_by_day[fold_id][0]
train_start_date, train_end_date = int(_fold_train_days[0]), int(_fold_train_days[-1])
print(f"第一折训练集日期区间 = [{train_start_date}, {train_end_date}]")

第一折训练集日期区间 = [1500, 1583]


In [None]:
# Incremental PCA over the standardised features ranked after the first 100.

from sklearn.decomposition import IncrementalPCA
import joblib

pca_cols = [f"z_{c}" for c in z_cols[100:]]
keep_cols = [f"z_{c}" for c in z_cols[:100]]
max_components = min(100, len(pca_cols))

ipca = IncrementalPCA(n_components=max_components)

# Iterate the days that are actually part of the training fold. Walking
# range(train_start_date, train_end_date + 1) assumes contiguous date ids
# and would request a shard that was never written if a day is missing.
for d in folds_by_day[fold_id][0]:
    d = int(d)
    path = f"{z_prefix}/z_shard_{d:04d}.parquet"
    print(f"fitting IPCA on date: {d}, path: {path}")

    df_day = pl.read_parquet(path, storage_options=storage_options).select(pca_cols).to_pandas()
    df_day.fillna(0.0, inplace=True)  # partial_fit rejects NaN input
    X = df_day.to_numpy()
    ipca.partial_fit(X)  # incremental fit, one day per batch


fitting IPCA on date: 1500, path: az://jackson/js_exp/exp/v1/tft/z_shards/z_shard_1500.parquet
fitting IPCA on date: 1501, path: az://jackson/js_exp/exp/v1/tft/z_shards/z_shard_1501.parquet
fitting IPCA on date: 1502, path: az://jackson/js_exp/exp/v1/tft/z_shards/z_shard_1502.parquet
fitting IPCA on date: 1503, path: az://jackson/js_exp/exp/v1/tft/z_shards/z_shard_1503.parquet
fitting IPCA on date: 1504, path: az://jackson/js_exp/exp/v1/tft/z_shards/z_shard_1504.parquet
fitting IPCA on date: 1505, path: az://jackson/js_exp/exp/v1/tft/z_shards/z_shard_1505.parquet
fitting IPCA on date: 1506, path: az://jackson/js_exp/exp/v1/tft/z_shards/z_shard_1506.parquet
fitting IPCA on date: 1507, path: az://jackson/js_exp/exp/v1/tft/z_shards/z_shard_1507.parquet
fitting IPCA on date: 1508, path: az://jackson/js_exp/exp/v1/tft/z_shards/z_shard_1508.parquet
fitting IPCA on date: 1509, path: az://jackson/js_exp/exp/v1/tft/z_shards/z_shard_1509.parquet
fitting IPCA on date: 1510, path: az://jackson/js_

In [49]:
# After fitting, inspect the cumulative explained-variance ratio and keep the
# smallest number of components k whose cumulative ratio reaches 50%.
cum = ipca.explained_variance_ratio_.cumsum()

# searchsorted returns len(cum) when the threshold is never reached, which
# would make k one larger than the number of fitted components; clamp it.
k = min(int(np.searchsorted(cum, 0.5)) + 1, len(cum))
k

38

In [38]:
# Number of components needed to reach 65% cumulative explained variance, for comparison.
int(np.searchsorted(cum, 0.65)) + 1

92

In [None]:
# Truncate the fitted model in place to its first k principal components.
# Every per-component attribute is sliced from its own fitted array; the
# earlier version assigned a fresh unfitted estimator to `components_` and the
# variance *ratio* to `explained_variance_`, which left the model unusable.
# `mean_` is per-feature rather than per-component, so it stays untouched.
# NOTE: after truncation the model should only be used for transform, not
# for further partial_fit calls.
ipca.components_ = ipca.components_[:k]
ipca.explained_variance_ = ipca.explained_variance_[:k]
ipca.explained_variance_ratio_ = ipca.explained_variance_ratio_[:k]
ipca.singular_values_ = ipca.singular_values_[:k]
ipca.n_components_ = k


In [None]:
# Persist the PCA model with the metadata needed to re-apply it:
# the input column list, the cumulative variance curve and the chosen k.
pca_dict = f"{TFT_LOCAL_ROOT}/ipca_{train_start_date}_{train_end_date}_k{k}_{_now()}.joblib"
_ipca_payload = {"ipca": ipca, "pca_cols": pca_cols, "cum_ratio": cum, "chosen_k": k}
joblib.dump(_ipca_payload, pca_dict)
print(f"IPCA 模型已保存至: {pca_dict}")
print(f"保留主成分维度 k = {k}")

In [None]:
# Merge the daily z-score shards into chunks of `chunk_size` consecutive date ids.
# NOTE(review): this assumes a shard exists under z_prefix for every date id
# in [data_start_date, data_end_date] - confirm for the 1400 start used here.
ten_days_prefix = "az://jackson/js_exp/exp/v1/tft/z_ten_chunks"
ensure_dir_az(ten_days_prefix)

chunk_size = 10
data_start_date = 1400
data_end_date = 1600
for lo in range(data_start_date, data_end_date + 1, chunk_size):
    hi = min(lo + chunk_size - 1, data_end_date)
    paths = [f"{z_prefix}/z_shard_{d:04d}.parquet" for d in range(lo, hi + 1)]
    print(f"merging date range: {lo} ~ {hi}, num files = {len(paths)}")

    # Read the daily shards together, order by symbol/time and compact the buffers.
    df = (
        pl.scan_parquet(paths, storage_options=storage_options)
        .collect(streaming=True)
        .sort([G_SYM, "time_idx"])
        .rechunk()
    )

    out_path = f"{ten_days_prefix}/z__chunk_{lo:04d}_{hi:04d}.parquet"
    print(f"writing to: {out_path}")
    df.write_parquet(out_path, storage_options=storage_options, compression="zstd")
print(f"[{_now()}] all done")
    


## 训练模型

In [None]:
# Resolve the cleaned schema once and derive every column role from it.
_clean_cols = lf_clean.collect_schema().names()

# Real-valued inputs: everything except keys, time_idx, the categorical
# time bucket, the weight and the target.
_not_real = {G_SYM, G_DATE, G_TIME, "time_idx", "time_bucket", WEIGHT_COL, TARGET_COL}
KNOWN_REALS = [c for c in _clean_cols if c not in _not_real]

KNOWN_CATEGORIES = ["time_bucket"] if "time_bucket" in _clean_cols else []

UNSCALE_COLS = KNOWN_REALS

# Columns handed to the dataset: all but the date/time keys.
TRAIN_COLS = [c for c in _clean_cols if c not in (G_DATE, G_TIME)]

# Identity scalers: a None scaler for each real input.
identity_scalers = dict.fromkeys(UNSCALE_COLS)

## try

In [None]:
# Trial run on the first fold only.
best_ckpt_paths, fold_metrics = [], []
fold_id = 0
train_days, val_days = folds_by_day[0]

print(f"[fold {fold_id}] train {train_days[0]}..{train_days[-1]} ({len(train_days)} days), "
    f"val {val_days[0]}..{val_days[-1]} ({len(val_days)} days)")

# Date boundaries of the fold.
train_start_date = int(train_days[0])
train_end_date   = int(train_days[-1])
val_start_date   = int(val_days[0])
val_end_date     = int(val_days[-1])


# Materialise the train + validation window as a pandas frame.
# `clean_path_local` is not defined anywhere in this notebook, so the frame
# is read from `lf_clean`, exactly as the main training cell does.
date_range = (train_start_date, val_end_date)
pdf_data = (
    lf_clean
    .filter(pl.col(G_DATE).is_between(train_start_date, val_end_date, closed="both"))
    .collect(streaming=True)
    .to_pandas()
    .sort_values([G_SYM, "time_idx"])
)
# Group ids and categoricals are passed to the dataset as strings.
pdf_data[G_SYM] = pdf_data[G_SYM].astype("str")
if "time_bucket" in pdf_data.columns:
    pdf_data["time_bucket"] = pdf_data["time_bucket"].astype("str")

# time_idx boundaries matching the date boundaries above.
train_end_idx = pdf_data.loc[pdf_data[G_DATE] == train_end_date, "time_idx"].max()
val_start_idx = pdf_data.loc[pdf_data[G_DATE] == val_start_date, "time_idx"].min()
val_end_idx   = pdf_data.loc[pdf_data[G_DATE] == val_end_date, "time_idx"].max()
assert pd.notna(train_end_idx) and pd.notna(val_start_idx) and pd.notna(val_end_idx), "train/val idx not found"
train_end_idx, val_start_idx, val_end_idx = int(train_end_idx), int(val_start_idx), int(val_end_idx)
print(f"[fold {fold_id}] train idx up to {train_end_idx}, val idx {val_start_idx}..{val_end_idx}")



In [None]:
t_data = pdf_data[TRAIN_COLS]

identity_scalers = dict.fromkeys(UNSCALE_COLS)

# Label encoders for the categorical inputs; the time_bucket entry is None
# when that column is not among the known categoricals.
_cat_encoders = {
    G_SYM: NaNLabelEncoder(add_nan=True),
    "time_bucket": NaNLabelEncoder(add_nan=True) if "time_bucket" in KNOWN_CATEGORIES else None,
}
# NOTE(review): t_data spans the training AND validation dates, so this
# normalizer is built from a frame that includes the validation window.
_target_norm = GroupNormalizer(method="standard", groups=[G_SYM], center=True, scale_by_group=False)

base_ds = TimeSeriesDataSet(
    t_data,
    # series layout
    time_idx="time_idx",
    group_ids=[G_SYM],
    target=TARGET_COL,
    weight=WEIGHT_COL,
    # fixed-length encoder / decoder windows
    min_encoder_length=ENC_LEN,
    max_encoder_length=ENC_LEN,
    min_prediction_length=PRED_LEN,
    max_prediction_length=PRED_LEN,
    # input roles
    static_categoricals=[G_SYM],
    time_varying_known_categoricals=KNOWN_CATEGORIES,
    time_varying_known_reals=KNOWN_REALS,
    # encoding / scaling
    categorical_encoders=_cat_encoders,
    target_normalizer=_target_norm,
    scalers=identity_scalers,
    # extra generated inputs
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
    allow_missing_timesteps=True,
)

In [None]:
# Split base_ds into training and validation windows.
# Training: every window that ends at or before the last training time_idx.
train_ds = base_ds.filter(
    lambda idx: idx.time_idx_last <= train_end_idx,
    copy=True,
)

# Validation: every window whose first predicted step is at or after
# val_start_idx + ENC_LEN and that ends inside the validation range.
# The comparison is `>=`, matching min_prediction_idx in the main training
# cell; `==` kept only windows starting at one single time_idx.
val_ds = base_ds.filter(
    lambda idx: (
        (idx.time_idx_first_prediction >= val_start_idx + ENC_LEN)
        & (idx.time_idx_last <= val_end_idx)
    ),
    copy=True,
)

In [None]:
# Build the data loaders for the trial run.

# Settings shared by both loaders; persistent_workers differs between them.
_try_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=14, pin_memory=True, prefetch_factor=8)

train_loader = train_ds.to_dataloader(train=True, persistent_workers=False, **_try_loader_kwargs)

n_train_batches = len(train_loader)
print(f"[debug] train_loader batches = {n_train_batches}")
assert n_train_batches > 0, "Empty train dataloader. Check min_prediction_idx/ENC_LEN/date windows."

val_loader = val_ds.to_dataloader(train=False, persistent_workers=True, **_try_loader_kwargs)

n_val_batches = len(val_loader)
print(f"[debug] val_loader batches = {n_val_batches}")
assert n_val_batches > 0, "Empty val dataloader. Check min_prediction_idx/ENC_LEN/date windows."

In [None]:
from lightning.pytorch.tuner import Tuner

lp.seed_everything(42)

# Gradient clipping is a hyperparameter; it guards against diverging
# gradients in the recurrent parts of the network.
trainer = lp.Trainer(accelerator="gpu", gradient_clip_val=0.1)

# Model hyperparameters. The learning rate given here is not what the finder
# evaluates, but it matters for real training runs.
_tft_kwargs = dict(
    learning_rate=LR,
    hidden_size=HIDDEN,                 # the key capacity knob besides the learning rate
    attention_head_size=HEADS,          # up to 4 heads is reasonable for large datasets
    dropout=DROPOUT,                    # 0.1-0.3 are typical values
    hidden_continuous_size=HIDDEN // 2, # should be <= hidden_size
    loss=RMSE(),
    optimizer=torch.optim.Adam,
)
tft = TemporalFusionTransformer.from_dataset(train_ds, **_tft_kwargs)
print(f"Number of parameters in network: {tft.size() / 1e3:.1f}k")

# Sweep learning rates between min_lr and max_lr and report the suggestion.
res = Tuner(trainer).lr_find(
    tft,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
    max_lr=10.0,
    min_lr=1e-6,
)

print(f"suggested learning rate: {res.suggestion()}")
fig = res.plot(show=True, suggest=True)
fig.show()

## main

In [None]:
lp.seed_everything(42)

# ========== 8) Training (per CV fold) ========== only the first fold for now
best_ckpt_paths, fold_metrics = [], []

# TODO: wrap the cells below in a loop over folds_by_day.
fold_id = 0
train_days, val_days = folds_by_day[fold_id]

print(f"[fold {fold_id}] train {train_days[0]}..{train_days[-1]} ({len(train_days)} days), "
    f"val {val_days[0]}..{val_days[-1]} ({len(val_days)} days)")

# Date boundaries of the fold.
train_start_date, train_end_date = int(train_days[0]), int(train_days[-1])
val_start_date, val_end_date = int(val_days[0]), int(val_days[-1])

# Materialise the train + validation window as a pandas frame.
date_range = (train_start_date, val_end_date)
_in_fold_window = pl.col(G_DATE).is_between(train_start_date, val_end_date, closed="both")
pdf_data = lf_clean.filter(_in_fold_window).collect(streaming=True).to_pandas()

# Group ids and categoricals are passed to the dataset as strings.
_text_cols = [G_SYM] + (["time_bucket"] if "time_bucket" in pdf_data.columns else [])
for _col in _text_cols:
    pdf_data[_col] = pdf_data[_col].astype("str")
pdf_data.sort_values([G_SYM, "time_idx"], inplace=True)

In [None]:

# Translate the date boundaries into time_idx boundaries.
def _time_idx_on(day):
    """Return the time_idx values of all rows that fall on ``day``."""
    return pdf_data.loc[pdf_data[G_DATE] == day, "time_idx"]


train_end_idx = _time_idx_on(train_end_date).max()
val_start_idx = _time_idx_on(val_start_date).min()
val_end_idx = _time_idx_on(val_end_date).max()
assert pd.notna(train_end_idx) and pd.notna(val_start_idx) and pd.notna(val_end_idx), "train/val idx not found"
train_end_idx, val_start_idx, val_end_idx = (int(v) for v in (train_end_idx, val_start_idx, val_end_idx))
print(f"[fold {fold_id}] train idx up to {train_end_idx}, val idx {val_start_idx}..{val_end_idx}")


In [None]:
# Build the training TimeSeriesDataSet from rows up to the last training time_idx.

t_data = pdf_data[TRAIN_COLS]
train_df = t_data.loc[t_data["time_idx"] <= train_end_idx]

# Label encoders for the categorical inputs; the time_bucket entry is None
# when the frame has no such column.
_cat_encoders = {
    G_SYM: NaNLabelEncoder(add_nan=True),
    "time_bucket": NaNLabelEncoder(add_nan=True) if "time_bucket" in pdf_data.columns else None,
}
_target_norm = GroupNormalizer(method="standard", groups=[G_SYM], center=True, scale_by_group=False)

train_ds = TimeSeriesDataSet(
    train_df,
    # series layout
    time_idx="time_idx",
    group_ids=[G_SYM],
    target=TARGET_COL,
    weight=WEIGHT_COL,
    # fixed-length encoder / decoder windows
    min_encoder_length=ENC_LEN,
    max_encoder_length=ENC_LEN,
    min_prediction_length=PRED_LEN,
    max_prediction_length=PRED_LEN,
    # input roles
    static_categoricals=[G_SYM],
    time_varying_known_categoricals=KNOWN_CATEGORIES,
    time_varying_known_reals=KNOWN_REALS,
    # encoding / scaling
    categorical_encoders=_cat_encoders,
    target_normalizer=_target_norm,
    scalers=identity_scalers,
    # extra generated inputs
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
    allow_missing_timesteps=True,
)

In [None]:

# The validation set reuses every encoder/normalizer from train_ds, so
# nothing is re-fitted on validation rows (no leakage).

_in_val_range = t_data["time_idx"].between(val_start_idx, val_end_idx, inclusive="both")
val_df = t_data.loc[_in_val_range]
val_ds = TimeSeriesDataSet.from_dataset(
    train_ds,
    val_df,
    predict=False,
    stop_randomization=True,
    # val_df starts at val_start_idx, so the first ENC_LEN steps serve as encoder context.
    min_prediction_idx=val_start_idx+ENC_LEN,
)

In [None]:
# Build the data loaders for the main run.

# Worker settings shared by both loaders.
_worker_kwargs = dict(num_workers=14, pin_memory=True, persistent_workers=True, prefetch_factor=4)

train_loader = train_ds.to_dataloader(
    train=True,
    batch_size=BATCH_SIZE,
    shuffle=True,
    **_worker_kwargs,
)

n_train_batches = len(train_loader)
print(f"[debug] train_loader batches = {n_train_batches}")
assert n_train_batches > 0, "Empty train dataloader. Check min_prediction_idx/ENC_LEN/date windows."

# Validation uses a batch four times larger and a fixed order.
val_loader = val_ds.to_dataloader(
    train=False,
    batch_size=BATCH_SIZE*4,
    shuffle=False,
    **_worker_kwargs,
)

n_val_batches = len(val_loader)
print(f"[debug] val_loader batches = {n_val_batches}")
assert n_val_batches > 0, "Empty val dataloader. Check min_prediction_idx/ENC_LEN/date windows."

In [None]:
# 8.6 callbacks / logger / trainer
ckpt_dir_fold = Path(CKPTS_DIR) / f"fold_{fold_id}"
ckpt_dir_fold.mkdir(parents=True, exist_ok=True)

# Early stopping and checkpointing both follow val_loss and are evaluated
# after validation rather than at the end of the training epoch.
callbacks = [EarlyStopping(monitor="val_loss", mode="min", patience=5, check_on_train_epoch_end=False),
            ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1, dirpath=ckpt_dir_fold.as_posix(), filename=f"fold{fold_id}-tft-best-{{epoch:02d}}-{{val_loss:.5f}}", save_on_train_epoch_end=False),
            LearningRateMonitor(logging_interval="step"),]
# The run name records the main hyperparameters plus a timestamp.
RUN_NAME = (f"f{fold_id}"f"_E{MAX_EPOCHS}"f"_lr{LR}"f"_bs{BATCH_SIZE}"f"_enc{ENC_LEN}_dec{DEC_LEN}"f"_{datetime.now().strftime('%Y%m%d-%H%M%S')}")
logger = TensorBoardLogger(save_dir=LOGS_DIR.as_posix(),name="tft",version=RUN_NAME,default_hp_metric=False)

# max_epochs comes from MAX_EPOCHS so that the epoch count encoded in
# RUN_NAME matches what the trainer actually runs (it was a hard-coded 3).
# NOTE(review): limit_train_batches / limit_val_batches still cap each epoch;
# remove them for a full run.
trainer = lp.Trainer(max_epochs=MAX_EPOCHS,
                    accelerator="gpu",
                    devices=1,
                    precision="bf16-mixed",
                    enable_model_summary=True,
                    gradient_clip_val=1.0,
                    gradient_clip_algorithm="norm",
                    fast_dev_run=False,
                    limit_train_batches=50,
                    limit_val_batches=25,
                    val_check_interval=0.25,
                    num_sanity_val_steps=2,
                    log_every_n_steps=10,
                    callbacks=callbacks,
                    logger=logger,
                    accumulate_grad_batches=1,
                    deterministic=False
                    )

# Model built from the training dataset definition.
tft = TemporalFusionTransformer.from_dataset(
    train_ds,
    learning_rate=LR,
    hidden_size=HIDDEN,
    attention_head_size=HEADS,
    dropout=DROPOUT,
    hidden_continuous_size=HIDDEN // 2,
    loss=RMSE(),
    logging_metrics=[RMSE()],
    optimizer=torch.optim.AdamW,
    optimizer_params={"weight_decay": 1e-4},
    reduce_on_plateau_patience=4,
)
trainer.fit(
    tft,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
    )

In [None]:
# Reload the best checkpoint recorded by the ModelCheckpoint callback and
# predict on the validation loader, keeping targets and weights.
ckpt_cb = next(filter(lambda cb: isinstance(cb, ModelCheckpoint), callbacks))
best_path = ckpt_cb.best_model_path
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_path)

_predict_kwargs = dict(return_y=True, trainer_kwargs=dict(accelerator="gpu"))
predictions = best_tft.predict(val_loader, **_predict_kwargs)

y_pred = predictions.output
y_true, w = predictions.y

In [None]:
# Weighted R^2 against a zero baseline: 1 - sum(w * err^2) / sum(w * y^2).
# eps keeps the denominator away from zero.
_sq_err = (y_true - y_pred).pow(2)
num = (w * _sq_err).sum()
den = (w * y_true.pow(2)).sum()

wr2 = 1.0 - num / (den + eps)
print(f"wr2 after training: {wr2.item():.6f}")

In [None]:
# Show the parameters the validation dataset was built with. get_parameters
# is a method, so it has to be called; without the parentheses the cell only
# displays the bound method object.
val_ds.get_parameters()

In [None]:
# Same weighted R^2 as above, without the eps guard in the denominator.
_residual = y_true - y_pred
num = (torch.square(_residual) * w).sum()
den = (torch.square(y_true) * w).sum()
wr2 = 1 - num / den
print(f"wr2 after training: {wr2}")