## 导入库

In [1]:
# -*- coding: utf-8 -*-
from __future__ import annotations

# ── 标准库 ──────────────────────────────────────────────────────────────────
import os
import time
from pathlib import Path
from collections import defaultdict
from datetime import datetime

# ── 第三方 ──────────────────────────────────────────────────────────────────
import numpy as np
import pandas as pd
import polars as pl

import gc

import torch
import torch.backends.cudnn as cudnn
import lightning as L
import lightning.pytorch as lp
from torch.utils.data import DataLoader

from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import DeviceStatsMonitor
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer, Baseline
from pytorch_forecasting.metrics import MAE, RMSE
from pytorch_forecasting.data.encoders import NaNLabelEncoder
from pytorch_forecasting.data import TorchNormalizer, GroupNormalizer


# 你的工程工具
from pipeline.io import cfg, P, fs, storage_options, ensure_dir_local, ensure_dir_az
from pipeline.stream_input_local import ShardedBatchStream  
from pipeline.wr2 import WR2

# ---- Performance / compatibility switches (set once per session) ----
# os.cpu_count() may return None when the core count cannot be determined;
# fall back to 2 so the computed thread cap is still at least 1.
# NOTE(review): polars is already imported above, so this variable may not
# affect the current session's thread pool — confirm against the polars docs.
os.environ.setdefault("POLARS_MAX_THREADS", str(max(1, (os.cpu_count() or 2) // 2)))
pl.enable_string_cache()
cudnn.benchmark = True
torch.set_float32_matmul_precision("high")

import time as _t


import warnings

# Silence all warnings for the rest of the session
# (original intent: avoid printing out absolute paths).
warnings.filterwarnings("ignore")


def _now() -> str:
    """Return the current local time formatted as ``YYYY-MM-DD HH:MM:SS``."""
    return _t.strftime("%Y-%m-%d %H:%M:%S", _t.localtime())


print(f"[{_now()}] imports ok")

[2025-10-15 23:24:01] imports ok


## 定义工具函数

In [2]:
# ───────────────────────────────────────────────────────────────────────────
# 滑动窗划分
def make_sliding_cv_by_days(all_days: np.ndarray, *, n_splits: int, gap_days: int, train_to_val: int):
    """Build sliding-window (train, validation) folds over an ordered day axis.

    The ``len(all_days) - gap_days`` usable days are divided into
    ``train_to_val + n_splits`` equal units of ``base_val_len`` days. Every
    training window spans ``train_to_val`` units; each validation window
    spans one unit, with the first ``remainder`` folds receiving one extra
    day. Consecutive folds slide forward by the previous validation length,
    and ``gap_days`` days are left unused between the end of a training
    window and the start of its validation window.

    Args:
        all_days: Day identifiers; flattened to 1-D and used in the given order.
        n_splits: Number of folds requested.
        gap_days: Days skipped between training end and validation start.
        train_to_val: Training length expressed in validation units.

    Returns:
        A list of ``(train_days, val_days)`` array pairs. The list is empty
        when the inputs cannot produce a fold, and may be shorter than
        ``n_splits`` if a window would run outside the available days.
    """
    days = np.asarray(all_days).ravel()
    n_days = len(days)

    budget = n_days - gap_days
    if budget <= 0 or n_splits <= 0 or train_to_val <= 0:
        return []

    base_val_len, remainder = divmod(budget, train_to_val + n_splits)
    if base_val_len <= 0:
        return []

    train_len = train_to_val * base_val_len
    result = []
    val_start = train_len + gap_days
    for fold_no in range(n_splits):
        # The first `remainder` folds absorb one leftover day each.
        val_len = base_val_len + (1 if fold_no < remainder else 0)
        val_stop = val_start + val_len
        train_stop = val_start - gap_days
        train_start = train_stop - train_len
        if train_start < 0 or val_stop > n_days:
            break
        result.append((days[train_start:train_stop], days[val_start:val_stop]))
        val_start = val_stop
    return result


## 初始化参数

In [3]:
# Load the ranked feature list and the de-correlated feature selection.
df_ranking_features = pd.read_csv(
    "/mnt/data/js/exp/v1/models/tune/feature_importance__fixed__fixed__mm_full_train__features__fs__1300-1500__cv3-g7-r4__seed42__top1000__1760299442__range1000-1600__range1000-1600__cv2-g7-r4__1760347190.csv"
)
de_corr_features = pd.read_csv(
    "/mnt/data/js/exp/v1/tft/selected_features/selected_features__decorr__tau0.95__1400-1600.csv"
)

# Keep only ranked features that also appear in the de-correlated list,
# then renumber the rows from 0.
_in_decorr = df_ranking_features['feature'].isin(de_corr_features['feature'])
df_ranking_decorr_features = df_ranking_features[_in_decorr].copy()
df_e_features = df_ranking_decorr_features.reset_index(drop=True)

# Re-normalise so the gains of the retained features sum to 1.
_gain = df_e_features['mean_gain']
df_e_features['mean_gain'] = (_gain / _gain.sum()).astype(np.float32)

In [None]:
# Share of the normalised gain carried by the first 150 rows.
df_e_features['mean_gain'].iloc[:150].sum()

In [4]:
# ========== 1) Shared configuration ==========
# Key columns, target and weight column names are read from the project config.
G_SYM, G_DATE, G_TIME = cfg["keys"]          # e.g. ("symbol_id","date_id","time_id")
TARGET_COL = cfg["target"]                   # e.g. "responder_6"
WEIGHT_COL = cfg["weight"]                   # may be None (NOTE(review): later cells put it in select lists — confirm None is handled)

TIME_FEATURES = ["time_bucket", "time_pos", "time_sin", "time_cos"]
COV_FEATURES = df_e_features['feature'].tolist()
# Columns that are explicitly z-scored
z_cols = ["time_pos"] + COV_FEATURES


# Training & CV hyper-parameters
N_SPLITS     = 1
GAP_DAYS     = 0
TRAIN_TO_VAL = 8
ENC_LEN      = 10
DEC_LEN      = 1
PRED_LEN     = DEC_LEN
BATCH_SIZE   = 512
LR           = 1e-4
HIDDEN       = 64
HEADS        = 1
DROPOUT      = 0.1
MAX_EPOCHS   = 30

# Data paths
PANEL_DIR_AZ   = P("az", cfg["paths"].get("panel_shards", "panel_shards"))

TFT_LOCAL_ROOT = P("local", "tft"); ensure_dir_local(TFT_LOCAL_ROOT)

LOCAL_CLEAN_DIR = f"{TFT_LOCAL_ROOT}/clean"; ensure_dir_local(LOCAL_CLEAN_DIR)
CKPTS_DIR = Path(TFT_LOCAL_ROOT) / "ckpts"; ensure_dir_local(CKPTS_DIR.as_posix())
LOGS_DIR  = Path(TFT_LOCAL_ROOT) / "logs";  ensure_dir_local(LOGS_DIR.as_posix())


# Inclusive date-id window used by the data-loading cells
start_date, end_date = (1500, 1600)


print("[config] ready")

[config] ready


## 数据导入

In [None]:
# Enumerate the panel shards and put the az:// scheme in front of each globbed path.
data_paths = [f"az://{p}" for p in fs.glob("az://jackson/js_exp/exp/v1/panel_shards/*.parquet")]

# Lazy scan limited to the needed columns and the [start_date, end_date] window,
# ordered by symbol, date and intraday time.
_wanted_cols = [*cfg['keys'], WEIGHT_COL, TARGET_COL, *TIME_FEATURES, *COV_FEATURES]
_in_window = pl.col(G_DATE).is_between(start_date, end_date, closed="both")
lf_data = (
    pl.scan_parquet(data_paths, storage_options=storage_options)
    .select(_wanted_cols)
    .filter(_in_window)
    .sort([G_SYM, G_DATE, G_TIME])
)


In [None]:
# Check for missing values: one null count per column.
_null_count_expr = pl.all().is_null().sum()
df_null_case = lf_data.select(_null_count_expr).collect()

In [None]:
# Transpose so that each original column becomes one row.
df_null = df_null_case.to_pandas().transpose()

In [None]:
# Rename: the index becomes a "feature" column, the single value column "null_count".
df_null = (
    df_null
    .rename_axis('feature')
    .rename(columns={df_null.columns[0]: "null_count"})
    .reset_index()
)

In [None]:
# Most-missing columns first; the bare name on the last line displays the table.
df_null = df_null.sort_values("null_count", ascending=False)
df_null

In [None]:
# Total number of rows in the lazily-filtered panel.
lf_data.select(pl.len()).collect()

## 数据处理

### 添加全局时间序列号

In [None]:
# One row per distinct (date, time) pair, numbered in chronological order;
# that running number is the global "time_idx" (cast to Int64).
_grid_lf = lf_data.select([G_DATE, G_TIME]).unique().sort([G_DATE, G_TIME])
_grid_lf = _grid_lf.with_row_index("time_idx").with_columns(pl.col("time_idx").cast(pl.Int64))
grid_df = _grid_lf.collect(streaming=True)

In [None]:
# Attach the global time_idx to the panel and write it out in chunks of 30 date ids.
container_prefix = "az://jackson/js_exp/exp/v1/tft/panel_clean_shards"
ensure_dir_az(container_prefix)
chunk_size = 30
for lo in range(start_date, end_date + 1, chunk_size):
    hi = min(lo + chunk_size - 1, end_date)
    print(f"processing date range: {lo} ~ {hi}")

    # The same inclusive date filter is applied to the panel and to the grid.
    _in_chunk = pl.col(G_DATE).is_between(lo, hi, closed="both")
    lf_chunk = lf_data.filter(_in_chunk)
    lf_grid_chunk = grid_df.lazy().filter(_in_chunk)

    lf_joined = lf_chunk.join(lf_grid_chunk, on=[G_DATE, G_TIME], how="left")
    lf_joined = lf_joined.sort([G_SYM, "time_idx"])

    out_path = f"{container_prefix}/panel_clean_{lo:04d}_{hi:04d}.parquet"
    print(f"writing to: {out_path}")

    lf_joined.sink_parquet(out_path, storage_options=storage_options, compression="zstd")
print(f"[{_now()}] all done")

## 导入新数据

In [5]:
# Re-read the data from the time-indexed shard directory.
container_prefix = "az://jackson/js_exp/exp/v1/tft/panel_clean_shards"
ensure_dir_az(container_prefix)
data_paths = [f"az://{p}" for p in fs.glob(f"{container_prefix}/*.parquet")]
lf_with_idx = pl.scan_parquet(data_paths, storage_options=storage_options)
lf_with_idx = lf_with_idx.sort([G_SYM, G_DATE, G_TIME])

## CV 划分

In [6]:
# ==========  CV split ==========
# Sorted distinct date ids present in the data.
_unique_days = lf_with_idx.select(pl.col(G_DATE)).unique().sort([G_DATE]).collect(streaming=True)
all_days = _unique_days[G_DATE].to_numpy()

folds_by_day = make_sliding_cv_by_days(
    all_days,
    n_splits=N_SPLITS,
    gap_days=GAP_DAYS,
    train_to_val=TRAIN_TO_VAL,
)

print(f"[cv] total {len(folds_by_day)} folds")

assert len(folds_by_day) > 0, "no CV folds constructed"
print(folds_by_day)

[cv] total 1 folds
[(array([1500, 1501, 1502, 1503, 1504, 1505, 1506, 1507, 1508, 1509, 1510,
       1511, 1512, 1513, 1514, 1515, 1516, 1517, 1518, 1519, 1520, 1521,
       1522, 1523, 1524, 1525, 1526, 1527, 1528, 1529, 1530, 1531, 1532,
       1533, 1534, 1535, 1536, 1537, 1538, 1539, 1540, 1541, 1542, 1543,
       1544, 1545, 1546, 1547, 1548, 1549, 1550, 1551, 1552, 1553, 1554,
       1555, 1556, 1557, 1558, 1559, 1560, 1561, 1562, 1563, 1564, 1565,
       1566, 1567, 1568, 1569, 1570, 1571, 1572, 1573, 1574, 1575, 1576,
       1577, 1578, 1579, 1580, 1581, 1582, 1583, 1584, 1585, 1586, 1587],
      dtype=int32), array([1588, 1589, 1590, 1591, 1592, 1593, 1594, 1595, 1596, 1597, 1598,
       1599], dtype=int32))]


## 数据标准化处理

下面我们以一折为例，验证成功后，再使用多折

In [7]:
fold_id = 0
# First and last day of this fold's training and validation windows; the
# training window bounds the data used for the normalisation statistics.
_fold_train_days, _fold_val_days = folds_by_day[fold_id]
train_lo, train_hi = _fold_train_days[0], _fold_train_days[-1]
val_lo, val_hi = _fold_val_days[0], _fold_val_days[-1]

print(f"统计标准化使用训练集的日期范围 = {train_lo} ~ {train_hi}")

统计标准化使用训练集的日期范围 = 1500 ~ 1587


In [8]:
# Per-feature medians computed on the training window only.
lf_tr = lf_with_idx.filter(pl.col(G_DATE).is_between(train_lo, train_hi, closed="both"))

# Median per (symbol, time_pos) group ...
_median_st_exprs = [pl.col(col).median().alias(f"{col}_median_st") for col in COV_FEATURES]
grp_median = lf_tr.group_by([G_SYM, "time_pos"]).agg(_median_st_exprs).sort([G_SYM, "time_pos"])

# ... and per time_pos across all symbols, used as the second-level fallback.
_median_s_exprs = [pl.col(col).median().alias(f"{col}_median_s") for col in COV_FEATURES]
glb_median = lf_tr.group_by("time_pos").agg(_median_s_exprs).sort("time_pos")

In [None]:
# Apply to all of this fold's data (train + val).
lf_all = (
    lf_with_idx
    .filter(pl.col(G_DATE).is_between(train_lo, val_hi, closed="both"))
    .join(grp_median, on=[G_SYM, "time_pos"], how="left")
    .join(glb_median, on=["time_pos"], how="left")
    .sort([G_SYM, "time_idx"])
)

# Fill nulls column by column: own value, else the (symbol, time_pos) median,
# else the time_pos median.
fill_exprs = [
    pl.coalesce([pl.col(col), pl.col(f"{col}_median_st"), pl.col(f"{col}_median_s")]).alias(col)
    for col in COV_FEATURES
]
lf_all_imputed = lf_all.with_columns(fill_exprs)

# Drop the helper median columns.
drop_cols = [f"{col}{suffix}" for suffix in ("_median_st", "_median_s") for col in COV_FEATURES]
lf_all_imputed = lf_all_imputed.drop(drop_cols)

# Explicit check: fail loudly if any nulls are left.
remain = lf_all_imputed.select([pl.col(c).is_null().sum().alias(c) for c in COV_FEATURES]).collect()
bad = {}
for c in COV_FEATURES:
    n_missing = remain[c][0]
    if n_missing > 0:
        bad[c] = n_missing
if bad:
    raise ValueError(f"补值失败：以下列在 (symbol_id,time_pos) 和 time_pos 层级都无中位数 -> {bad}")

In [None]:
# ========== Z-score ==========
# Statistics are computed on the imputed training window only.
lf_tr_imputed = lf_all_imputed.filter(pl.col(G_DATE).is_between(train_lo, train_hi, closed="both"))

# Per-symbol mean and population std (ddof=0) for every z-scored column.
_sym_stats = [pl.col(c).mean().alias(f"mu_{c}") for c in z_cols]
_sym_stats += [pl.col(c).std(ddof=0).alias(f"std_{c}") for c in z_cols]
grp_mean = lf_tr_imputed.group_by(G_SYM).agg(_sym_stats).collect(streaming=True)

# Global (all-symbol) mean and std, as a single-row frame.
_glb_stats = [pl.col(c).mean().alias(f"mu_{c}_glb") for c in z_cols]
_glb_stats += [pl.col(c).std(ddof=0).alias(f"std_{c}_glb") for c in z_cols]
glb_mean = lf_tr_imputed.select(_glb_stats).collect(streaming=True)


In [None]:
# Standardise this fold's data one day at a time and write one shard per day.

z_prefix = "az://jackson/js_exp/exp/v1/tft/z_shards"
ensure_dir_az(z_prefix)

eps = 1e-6
# Output columns; z_cols contains "time_pos", so "z_time_pos" is part of this list.
out_cols = [G_SYM, G_DATE, G_TIME, "time_idx", WEIGHT_COL, TARGET_COL, "time_bucket", "time_sin", "time_cos", *[f"z_{c}" for c in z_cols]]

for d in range(train_lo, val_hi + 1):
    print(f"processing date: {d}")
    lf_day = lf_all_imputed.filter(pl.col(G_DATE) == d)

    # Join the statistics inside the loop (original note: avoids OOM).
    lf_day_z = (
        lf_day
        .join(glb_mean.lazy(), how="cross")
        .join(grp_mean.lazy(), on=G_SYM, how="left")
    )

    z_exprs = []
    for c in z_cols:
        mu_sym_col, std_sym_col = pl.col(f"mu_{c}"), pl.col(f"std_{c}")
        # Use the global statistic when the per-symbol one is null
        # (or, for the std, exactly zero).
        mu_pick = pl.when(mu_sym_col.is_null()).then(pl.col(f"mu_{c}_glb")).otherwise(mu_sym_col)
        std_pick = (
            pl.when(std_sym_col.is_null() | (std_sym_col == 0))
            .then(pl.col(f"std_{c}_glb"))
            .otherwise(std_sym_col)
        )
        z_exprs.append(((pl.col(c) - mu_pick) / (std_pick + eps)).alias(f"z_{c}"))
    lf_day_z = lf_day_z.with_columns(z_exprs)

    # Keep only the output columns that exist; the statistic columns are dropped here.
    available = lf_day_z.collect_schema().names()
    lf_day_z = lf_day_z.select([c for c in out_cols if c in available])
    lf_day_z = lf_day_z.sort([G_SYM, "time_idx"])

    # Write out
    out_path = f"{z_prefix}/z_shard_{d:04d}.parquet"

    lf_day_z.collect(streaming=True).write_parquet(
        out_path,
        storage_options=storage_options,
        compression="zstd")
    


## PCA 降维

In [None]:
# Fit the PCA on the training window only.

from sklearn.decomposition import IncrementalPCA
import joblib


# Features after the first 150 are compressed by PCA; the first 150 are kept as-is.
pca_cols = [f"z_{c}" for c in COV_FEATURES[150:]]
keep_cols = [f"z_{c}" for c in COV_FEATURES[:150]]

max_components = min(100, len(pca_cols))

ipca = IncrementalPCA(n_components=max_components)

for d in range(train_lo, train_hi + 1):
    path = f"{z_prefix}/z_shard_{d:04d}.parquet"
    print(f"fitting IPCA on date: {d} from the path: {path}")

    df_day = pl.read_parquet(path, storage_options=storage_options).select(pca_cols)
    X = df_day.to_numpy()
    if X.size == 0:
        # No rows for this day: skip it, as the transform loop does,
        # instead of passing an empty batch to partial_fit.
        continue
    ipca.partial_fit(X)  # incremental fit


In [None]:
# After fitting, use the cumulative explained-variance ratio to decide how many
# components to keep: the smallest k whose cumulative ratio reaches tau.
cum = ipca.explained_variance_ratio_.cumsum()
tau =0.5
# searchsorted returns len(cum) when tau is never reached; clamp so that k is
# at most the number of fitted components and cum[k-1] stays in range.
k = min(int(np.searchsorted(cum, tau)) + 1, len(cum))

print(f"{k} components explain {cum[k-1]:.4f} of variance")

# Save the model
# -- truncate the fitted parameters to the first k components (do not create a new model) --
ipca.components_              = ipca.components_[:k]
ipca.explained_variance_      = ipca.explained_variance_[:k]
ipca.explained_variance_ratio_= ipca.explained_variance_ratio_[:k]
ipca.singular_values_         = ipca.singular_values_[:k]
ipca.n_components_            = k


pca_path = f"{TFT_LOCAL_ROOT}/fold_{fold_id}_ipca_{train_lo}_{train_hi}_k{k}_{_now()}.joblib"
joblib.dump(
    {"ipca": ipca, "pca_cols": pca_cols, "cum_ratio": cum, "chosen_k": k}, 
    pca_path
)
print(f"IPCA 模型已保存至: {pca_path}")
print(f"保留主成分{tau}维度 k = {k}")

In [None]:
# Apply the fitted PCA to every day of this fold (train + val).

pca_prefix = "az://jackson/js_exp/exp/v1/tft/feat_pca_shards"
ensure_dir_az(pca_prefix)
meta_cols = [G_SYM, G_DATE, G_TIME, "time_idx", WEIGHT_COL, TARGET_COL, "time_bucket", "time_sin", "time_cos", "z_time_pos"]

pca_k = ipca.n_components_
pca_out_cols = [f"PC{i}" for i in range(1, pca_k + 1)]

for d in range(train_lo, val_hi + 1):
    in_path = f"{z_prefix}/z_shard_{d:04d}.parquet"
    out_path = f"{pca_prefix}/pca_shard_{d:04d}.parquet"

    lf = pl.read_parquet(in_path, storage_options=storage_options)

    X = lf.select([pl.col(c).cast(pl.Float32) for c in pca_cols]).to_numpy()
    if X.size == 0:
        # No rows for this day: nothing to write.
        continue
    if np.isnan(X).any():
        raise ValueError("PCA transform 输入里仍有 NaN，请检查上游补值/标准化。")

    Z = ipca.transform(X).astype(np.float32)

    # Assemble the output: metadata columns, untouched z-features, then the components.
    df_pca = pl.DataFrame(Z, schema=pca_out_cols)
    meta_part = lf.select([c for c in meta_cols if c in lf.columns])
    keep_part = lf.select([c for c in keep_cols if c in lf.columns])
    lf_out = pl.concat([meta_part, keep_part, df_pca], how="horizontal")
    lf_out = lf_out.sort([G_SYM, "time_idx"])

    lf_out.write_parquet(out_path, storage_options=storage_options, compression="zstd")
    print(f"written PCA features to: {out_path}")
print(f"[{_now()}] all done")


In [None]:
# Merge the per-day PCA shards into chunks of 10 date ids.
tft_fold_root = f"az://jackson/js_exp/exp/v1/tft/fold_{fold_id}"; ensure_dir_az(tft_fold_root)
ten_days_prefix = f"{tft_fold_root}/feat_pca_ten_days_shards"; ensure_dir_az(ten_days_prefix)

chunk_size = 10

# The PCA step skips days without rows, so a day in the range may have no
# shard. List what exists once (same glob + "az://" convention as the other
# cells) and only scan those files.
available_shards = {f"az://{p}" for p in fs.glob(f"{pca_prefix}/*.parquet")}

for lo in range(train_lo, val_hi + 1, chunk_size):
    hi = min(lo + chunk_size - 1, folds_by_day[fold_id][1][-1])
    paths = [
        p
        for p in (f"{pca_prefix}/pca_shard_{d:04d}.parquet" for d in range(lo, hi + 1))
        if p in available_shards
    ]

    print(f"merging date range: {lo} ~ {hi}, num files = {len(paths)}")
    if not paths:
        # Nothing was written for this whole range; there is no chunk to produce.
        continue

    lf = pl.scan_parquet(paths, storage_options=storage_options)

    df = lf.collect(streaming=True).sort([G_SYM, "time_idx"]).rechunk()

    out_path = f"{ten_days_prefix}/feat_pca_ten_days_chunk_{lo:04d}_{hi:04d}.parquet"
    print(f"writing to: {out_path}")
    df.write_parquet(
        out_path,
        storage_options=storage_options,
        compression="zstd",
    )
print(f"[{_now()}] all done")

## 训练模型

In [9]:
# Load the processed data for this fold.
tft_fold_root = f"az://jackson/js_exp/exp/v1/tft/fold_{fold_id}"
ensure_dir_az(tft_fold_root)
ten_days_prefix = f"{tft_fold_root}/feat_pca_ten_days_shards"
ensure_dir_az(ten_days_prefix)


lf_clean = (
    pl.scan_parquet(f"{ten_days_prefix}/*.parquet", storage_options=storage_options)
    .sort([G_SYM, "time_idx"])
)


In [10]:
# Resolve the column names once and derive every column-role list from them.
_clean_columns = lf_clean.collect_schema().names()

# Real-valued inputs: everything except keys, time_idx, the categorical
# time_bucket, the weight and the target.
_non_real_cols = (G_SYM, G_DATE, G_TIME, "time_idx", "time_bucket", WEIGHT_COL, TARGET_COL)
KNOWN_REALS = [c for c in _clean_columns if c not in _non_real_cols]

KNOWN_CATEGORIES = ["time_bucket"]

UNSCALE_COLS = KNOWN_REALS

# Columns handed to the dataset: all except the date and time keys.
TRAIN_COLS = [c for c in _clean_columns if c not in (G_DATE, G_TIME)]

# Identity scalers: a None entry for every column in UNSCALE_COLS.
identity_scalers = dict.fromkeys(UNSCALE_COLS)

## try

In [None]:
# Dry run on the first fold.
best_ckpt_paths, fold_metrics = [], []
fold_id = 0
train_days, val_days = folds_by_day[0]

print(f"[fold {fold_id}] train {train_days[0]}..{train_days[-1]} ({len(train_days)} days), "
    f"val {val_days[0]}..{val_days[-1]} ({len(val_days)} days)")

# Date bounds as plain ints:
train_start_date = int(train_days[0])
train_end_date   = int(train_days[-1])
val_start_date   = int(val_days[0])
val_end_date     = int(val_days[-1])


# Pull the data. This reads from lf_clean, the same source as the main
# training cell; the previously referenced clean_path_local is not defined
# anywhere in this notebook.
date_range = (train_start_date, val_end_date)
pdf_data = (
    lf_clean
    .filter(pl.col(G_DATE).is_between(train_start_date, val_end_date, closed="both"))
    .collect(streaming=True)
    .to_pandas()
    .sort_values([G_SYM, "time_idx"])
)
pdf_data[G_SYM] = pdf_data[G_SYM].astype("str")
if "time_bucket" in pdf_data.columns:
    pdf_data["time_bucket"] = pdf_data["time_bucket"].astype("str")

# time_idx bounds of the training and validation windows:
train_end_idx = pdf_data.loc[pdf_data[G_DATE] == train_end_date, "time_idx"].max()
val_start_idx = pdf_data.loc[pdf_data[G_DATE] == val_start_date, "time_idx"].min()
val_end_idx   = pdf_data.loc[pdf_data[G_DATE] == val_end_date, "time_idx"].max()
assert pd.notna(train_end_idx) and pd.notna(val_start_idx) and pd.notna(val_end_idx), "train/val idx not found"
train_end_idx, val_start_idx, val_end_idx = int(train_end_idx), int(val_start_idx), int(val_end_idx)
print(f"[fold {fold_id}] train idx up to {train_end_idx}, val idx {val_start_idx}..{val_end_idx}")



In [None]:
# Build one dataset over train + val; it is split by time_idx in the next cell.
t_data = pdf_data[TRAIN_COLS]

identity_scalers = dict.fromkeys(UNSCALE_COLS)

_cat_encoders = {
    G_SYM: NaNLabelEncoder(add_nan=True),
    "time_bucket": NaNLabelEncoder(add_nan=True) if "time_bucket" in KNOWN_CATEGORIES else None,
}
# Target is standardised per symbol.
_target_norm = GroupNormalizer(
    method="standard", groups=[G_SYM], center=True, scale_by_group=False)

base_ds = TimeSeriesDataSet(
    t_data,
    time_idx="time_idx",
    target=TARGET_COL,
    group_ids=[G_SYM],
    weight=WEIGHT_COL,
    # fixed-length encoder and decoder windows
    min_encoder_length=ENC_LEN,
    max_encoder_length=ENC_LEN,
    min_prediction_length=PRED_LEN,
    max_prediction_length=PRED_LEN,
    static_categoricals=[G_SYM],
    time_varying_known_categoricals=KNOWN_CATEGORIES,
    time_varying_known_reals=KNOWN_REALS,
    categorical_encoders=_cat_encoders,
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
    allow_missing_timesteps=True,
    target_normalizer=_target_norm,
    scalers=identity_scalers,
)

In [None]:
# Split into training and validation sets.
# Training: samples whose last time step is inside the training window.
train_ds = base_ds.filter(
    lambda idx: (
        idx.time_idx_last <= train_end_idx
    ),
    copy=True
)

# Validation: samples whose first predicted step is at or after
# val_start_idx + ENC_LEN and whose last step is inside the validation window.
# This matches min_prediction_idx in the main training cell; the earlier `==`
# kept only the samples predicting one single time step.
val_ds = base_ds.filter(
    lambda idx: (
        (idx.time_idx_first_prediction >= val_start_idx + ENC_LEN) &

        (idx.time_idx_last <= val_end_idx)
    ),
    copy=True
)

In [None]:
# Build the dataloaders.

# Options shared by the training and validation loaders.
_loader_opts = dict(
    batch_size=BATCH_SIZE,
    num_workers=14,
    pin_memory=True,
    prefetch_factor=8,
)

train_loader = train_ds.to_dataloader(train=True, persistent_workers=False, **_loader_opts)

n_train_batches = len(train_loader)
print(f"[debug] train_loader batches = {n_train_batches}")
assert n_train_batches > 0, "Empty train dataloader. Check min_prediction_idx/ENC_LEN/date windows."

val_loader = val_ds.to_dataloader(train=False, persistent_workers=True, **_loader_opts)

n_val_batches = len(val_loader)
print(f"[debug] val_loader batches = {n_val_batches}")
assert n_val_batches > 0, "Empty val dataloader. Check min_prediction_idx/ENC_LEN/date windows."

In [None]:
from lightning.pytorch.tuner import Tuner

lp.seed_everything(42)

# Gradient clipping is a hyperparameter and helps keep recurrent-network
# gradients from diverging.
trainer = lp.Trainer(accelerator="gpu", gradient_clip_val=0.1)

tft = TemporalFusionTransformer.from_dataset(
    train_ds,
    learning_rate=LR,                    # starting value; the finder sweeps over a range
    hidden_size=HIDDEN,                  # most important hyperparameter besides the learning rate
    attention_head_size=HEADS,           # number of attention heads
    dropout=DROPOUT,
    hidden_continuous_size=HIDDEN // 2,  # kept <= hidden_size
    loss=RMSE(),
    optimizer=torch.optim.Adam,
)
print(f"Number of parameters in network: {tft.size() / 1e3:.1f}k")

# Sweep learning rates between min_lr and max_lr and report the suggestion.
lr_tuner = Tuner(trainer)
res = lr_tuner.lr_find(
    tft,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
    min_lr=1e-6,
    max_lr=10.0,
)

print(f"suggested learning rate: {res.suggestion()}")
fig = res.plot(show=True, suggest=True)
fig.show()

## main

In [11]:
lp.seed_everything(42)

# ========== Training (per CV fold) ========== first fold only for now
best_ckpt_paths, fold_metrics = [], []
# TODO: loop over every fold, e.g.
#   for fold_id, (train_days, val_days) in enumerate(folds_by_day, start=1):

####################################
fold_id = 0
train_days, val_days = folds_by_day[fold_id]
####################################

print(f"[fold {fold_id}] train {train_days[0]}..{train_days[-1]} ({len(train_days)} days), "
    f"val {val_days[0]}..{val_days[-1]} ({len(val_days)} days)")

# Date bounds as plain ints:
train_start_date, train_end_date = int(train_days[0]), int(train_days[-1])
val_start_date, val_end_date = int(val_days[0]), int(val_days[-1])

# Pull this fold's rows (train + val) into pandas.
date_range = (train_start_date, val_end_date)
_in_fold = pl.col(G_DATE).is_between(train_start_date, val_end_date, closed="both")
pdf_data = lf_clean.filter(_in_fold).collect(streaming=True).to_pandas()

# Group id and time bucket become string-valued categoricals.
for _cat_col in (G_SYM, "time_bucket"):
    pdf_data[_cat_col] = pdf_data[_cat_col].astype(str).astype("category")
pdf_data.sort_values([G_SYM, "time_idx"], inplace=True)

Seed set to 42


[fold 0] train 1500..1587 (88 days), val 1588..1599 (12 days)


In [12]:
# time_idx bounds: last index of the final training day, first index of the
# first validation day, last index of the final validation day.
_date_col, _idx_col = pdf_data[G_DATE], pdf_data["time_idx"]
train_end_idx = _idx_col[_date_col == train_end_date].max()
val_start_idx = _idx_col[_date_col == val_start_date].min()
val_end_idx = _idx_col[_date_col == val_end_date].max()

assert all(pd.notna(v) for v in (train_end_idx, val_start_idx, val_end_idx)), "train/val idx not found"
train_end_idx, val_start_idx, val_end_idx = (
    int(v) for v in (train_end_idx, val_start_idx, val_end_idx)
)
print(f"[fold {fold_id}] train idx up to {train_end_idx}, val idx {val_start_idx}..{val_end_idx}")


[fold 0] train idx up to 85183, val idx 85184..96799


In [13]:
pdf_data = pdf_data[TRAIN_COLS]

# Training dataset: rows up to and including the last training time_idx.
_train_rows = pdf_data["time_idx"] <= train_end_idx
train_ds = TimeSeriesDataSet(
    pdf_data[_train_rows],
    time_idx="time_idx",
    target=TARGET_COL,
    group_ids=[G_SYM],
    weight=WEIGHT_COL,
    # fixed-length encoder and decoder windows
    min_encoder_length=ENC_LEN,
    max_encoder_length=ENC_LEN,
    min_prediction_length=PRED_LEN,
    max_prediction_length=PRED_LEN,
    static_categoricals=[G_SYM],
    time_varying_known_categoricals=KNOWN_CATEGORIES,
    time_varying_known_reals=KNOWN_REALS,
    categorical_encoders={
        G_SYM: NaNLabelEncoder(add_nan=True),
        "time_bucket": NaNLabelEncoder(add_nan=True),
    },
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
    allow_missing_timesteps=True,
    target_normalizer=GroupNormalizer(
        method="standard", groups=[G_SYM], center=True, scale_by_group=False),
    scalers=identity_scalers,
)

# The validation set reuses train_ds's encoders/normaliser (original note: no leakage)
# and only sees rows inside the validation time_idx range.
_val_rows = pdf_data["time_idx"].between(val_start_idx, val_end_idx, inclusive="both")
val_ds = TimeSeriesDataSet.from_dataset(
    train_ds,
    pdf_data[_val_rows],
    min_prediction_idx=val_start_idx+ENC_LEN,
    stop_randomization=True,
    predict=False,
)

# Release the large frames before training.
del pdf_data, lf_clean, lf_with_idx
gc.collect()

80

In [15]:
# Build the dataloaders.

# In-process loading: no worker processes, no pinned memory.
_inproc_opts = dict(num_workers=0, pin_memory=False, persistent_workers=False)

train_loader = train_ds.to_dataloader(
    train=True,
    batch_size=BATCH_SIZE,
    shuffle=True,
    **_inproc_opts,
)

n_train_batches = len(train_loader)
print(f"[debug] train_loader batches = {n_train_batches}")
assert n_train_batches > 0, "Empty train dataloader. Check min_prediction_idx/ENC_LEN/date windows."

# Validation uses double-size batches.
# NOTE(review): shuffle=True is passed for the validation loader as well —
# confirm this is intended.
val_loader = val_ds.to_dataloader(
    train=False,
    batch_size=BATCH_SIZE*2,
    shuffle=True,
    **_inproc_opts,
)

n_val_batches = len(val_loader)
print(f"[debug] val_loader batches = {n_val_batches}")
assert n_val_batches > 0, "Empty val dataloader. Check min_prediction_idx/ENC_LEN/date windows."

[debug] train_loader batches = 6394
[debug] val_loader batches = 433


In [None]:
# 8.6 callbacks/logger/trainer
ckpt_dir_fold = Path(CKPTS_DIR) / f"fold_{fold_id}"
ckpt_dir_fold.mkdir(parents=True, exist_ok=True)

# Stop after 5 validation checks without val_loss improvement and keep the
# single best checkpoint by val_loss.
callbacks = [EarlyStopping(monitor="val_loss", mode="min", patience=5),
            ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1, dirpath=ckpt_dir_fold.as_posix(), filename=f"fold{fold_id}-tft-best-{{epoch:02d}}-{{val_loss:.5f}}", save_on_train_epoch_end=False),
            ] # LearningRateMonitor(logging_interval="step"),
RUN_NAME = f"quick_check_f{fold_id}_E{MAX_EPOCHS}_lr{LR:g}_bs{BATCH_SIZE}_enc{ENC_LEN}_dec{DEC_LEN}_{datetime.now():%Y%m%d-%H%M%S}"

TEMP_LOG_DIR = Path("./tft_logs")
TEMP_LOG_DIR.mkdir(parents=True, exist_ok=True)

logger = TensorBoardLogger(save_dir=TEMP_LOG_DIR.as_posix(),name="tft",version=RUN_NAME,default_hp_metric=False)

# max_epochs comes from MAX_EPOCHS so the epoch budget matches the value
# written into RUN_NAME (it was hard-coded to a different number before).
trainer = lp.Trainer(max_epochs=MAX_EPOCHS,
                    accelerator="gpu",
                    devices=1,
                    precision="bf16-mixed",
                    enable_model_summary=True,
                    gradient_clip_val=1.0,
                    gradient_clip_algorithm="norm",
                    #fast_dev_run=1,
                    limit_train_batches=0.3,
                    limit_val_batches=0.2,
                    val_check_interval= 1.0,
                    num_sanity_val_steps=0,
                    log_every_n_steps=1000,
                    callbacks=callbacks,
                    logger=logger,
                    #accumulate_grad_batches=1,
                    )

tft = TemporalFusionTransformer.from_dataset(
    train_ds,
    learning_rate=LR,
    hidden_size=HIDDEN,
    attention_head_size=HEADS,
    dropout=DROPOUT,
    hidden_continuous_size=HIDDEN // 2,
    loss=RMSE(),
    logging_metrics=[],
    optimizer=torch.optim.AdamW,
    optimizer_params={"weight_decay": 1e-4},
    reduce_on_plateau_patience=3, 
)
trainer.fit(
    tft,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
    )

In [None]:
# Reload the best checkpoint recorded by the ModelCheckpoint callback.
ckpt_cb = next(cb for cb in callbacks if isinstance(cb, ModelCheckpoint))
best_path = ckpt_cb.best_model_path
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_path)

# Predict over the validation loader; return_y=True also returns the targets
# and weights from the same pass.
predictions = best_tft.predict(
    val_loader,
    return_y=True,
    trainer_kwargs={"accelerator": "gpu"},
)
y_pred = predictions.output
y_true, w = predictions.y

In [None]:
# Weighted R^2 with a zero baseline: 1 - sum(w * err^2) / sum(w * y^2).
# eps is defined here so the cell does not depend on the preprocessing cell
# that also sets it (same value); it keeps the denominator non-zero.
eps = 1e-6
num = (w * (y_true - y_pred).pow(2)).sum()
den = (w * y_true.pow(2)).sum()

wr2 = 1.0 - num / (den + eps)
print(f"wr2 after training: {wr2.item():.6f}")

In [None]:
# Show the parameters the validation dataset was built with. The method has
# to be called; the bare attribute only displays the bound method object.
val_ds.get_parameters()

In [None]:
# Same weighted R^2, without the epsilon in the denominator.
_sq_err = (y_true - y_pred) ** 2
num = (_sq_err * w).sum()
den = ((y_true ** 2) * w).sum()
wr2 = 1 - num / den
print(f"wr2 after training: {wr2}")