In [12]:
# -*- coding: utf-8 -*-
from __future__ import annotations

# ── 标准库 ──────────────────────────────────────────────────────────────────
import os
import time
from pathlib import Path
from collections import defaultdict
from datetime import datetime

# ── 第三方 ──────────────────────────────────────────────────────────────────
import numpy as np
import pandas as pd
import polars as pl

import torch
import torch.backends.cudnn as cudnn
import lightning as L
import lightning.pytorch as lp
from torch.utils.data import DataLoader

from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import DeviceStatsMonitor
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer, Baseline
from pytorch_forecasting.metrics import MAE, RMSE
from pytorch_forecasting.data.encoders import NaNLabelEncoder
from pytorch_forecasting.data import TorchNormalizer, GroupNormalizer


# 你的工程工具
from pipeline.io import cfg, P, fs, storage_options, ensure_dir_local
from pipeline.stream_input_local import ShardedBatchStream  # 使用下方给你的新版类
from pipeline.wr2 import WR2

# ---- 性能/兼容开关（仅一次）----
os.environ.setdefault("POLARS_MAX_THREADS", str(max(1, os.cpu_count() // 2)))
pl.enable_string_cache()
cudnn.benchmark = True
torch.set_float32_matmul_precision("high")


def _now() -> str:
    import time as _t
    return _t.strftime("%Y-%m-%d %H:%M:%S")


# ───────────────────────────────────────────────────────────────────────────
# 滑动窗划分
def make_sliding_cv_by_days(all_days: np.ndarray, *, n_splits: int, gap_days: int, train_to_val: int):
    all_days = np.asarray(all_days).ravel()
    K, R, G = n_splits, train_to_val, gap_days
    usable = len(all_days) - G
    if usable <= 0 or K <= 0 or R <= 0:
        return []
    V_base, rem = divmod(usable, R + K)
    if V_base <= 0:
        return []
    T = R * V_base
    v_lens = [V_base + 1 if i < rem else V_base for i in range(K)]
    folds, v_lo = [], T + G
    for V_i in v_lens:
        v_hi, tr_hi, tr_lo = v_lo + V_i, v_lo - G, v_lo - G - T
        if tr_lo < 0 or v_hi > len(all_days):
            break
        folds.append((all_days[tr_lo:tr_hi], all_days[v_lo:v_hi]))
        v_lo = v_hi
    return folds


In [2]:
print(f"[{_now()}] imports ok")

# ========== 1) 统一配置 ==========
G_SYM, G_DATE, G_TIME = cfg["keys"]          # e.g. ("symbol_id","date_id","time_id")
TARGET_COL = cfg["target"]                   # e.g. "responder_6"
WEIGHT_COL = cfg["weight"]                   # 允许为 None

# 已知时间特征（给 encoder+decoder）
TIME_FEATURES = ["time_bucket"]

# 连续特征（示例）
BASIC_FEATURES = [f"feature_{i:02d}" for i in range(79)]
#BASIC_FEATURES = ["feature_36", "feature_06", "feature_04", "feature_16", "feature_69", "feature_22","feature_20", "feature_58", "feature_24", "feature_27","feature_37"]

RAW_FEATURES = BASIC_FEATURES 

# 训练 & CV 超参
N_SPLITS     = 1
GAP_DAYS     = 5
TRAIN_TO_VAL = 5
ENC_LEN      = 30
DEC_LEN      = 1
PRED_LEN     = DEC_LEN
BATCH_SIZE   = 1024
LR           = 1e-3
HIDDEN       = 32
HEADS        = 2
DROPOUT      = 0.2
MAX_EPOCHS   = 30
CHUNK_DAYS   = 30

# 数据路径
PANEL_DIR_AZ   = P("az", cfg["paths"].get("panel_shards", "panel_shards"))
TFT_LOCAL_ROOT = P("local", "tft")

# 目录准备
ensure_dir_local(TFT_LOCAL_ROOT)
LOCAL_CLEAN_DIR = f"{TFT_LOCAL_ROOT}/clean"; ensure_dir_local(LOCAL_CLEAN_DIR)
CKPTS_DIR = Path(TFT_LOCAL_ROOT) / "ckpts"; ensure_dir_local(CKPTS_DIR.as_posix())
LOGS_DIR  = Path(TFT_LOCAL_ROOT) / "logs";  ensure_dir_local(LOGS_DIR.as_posix())

print("[config] ready")


[2025-10-10 13:55:18] imports ok
[config] ready


In [3]:
# ========== 2) 读取原始面板 & 建全局 time_idx ==========
#data_paths = fs.glob(f"{PANEL_DIR_AZ}/*.parquet")
#data_paths = [p if p.startswith("az://") else f"az://{p}" for p in data_paths]

data_paths = fs.glob("az://jackson/js_exp/exp/v1/clean_shards/*.parquet")
data_paths =[f"az://{p}" for p in data_paths]

lf_data = pl.scan_parquet(data_paths, storage_options=storage_options)
lf_data = lf_data.filter(pl.col(G_DATE).is_between(1600, 1690, closed="both"))

lf_grid = (
    lf_data.select([G_DATE, G_TIME]).unique()
        .sort([G_DATE, G_TIME])
        .with_row_index("time_idx")
        .with_columns(pl.col("time_idx").cast(pl.Int64))
)
grid_path_local = P("local", "tft/panel/grid_timeidx.parquet")
Path(grid_path_local).parent.mkdir(parents=True, exist_ok=True)
lf_grid.collect(streaming=True).write_parquet(grid_path_local, compression="zstd")
grid_lazy = pl.scan_parquet(grid_path_local)

NEED_COLS = list(dict.fromkeys([G_SYM, G_DATE, G_TIME, WEIGHT_COL, TARGET_COL] + TIME_FEATURES + RAW_FEATURES))
lf0 = (
    lf_data.join(grid_lazy, on=[G_DATE, G_TIME], how="left")
        .select(NEED_COLS + ["time_idx"])
        .sort([G_DATE, G_TIME, G_SYM])
)
print(f"[{_now()}] lazyframe ready")

# ========== 3) CV 划分 ==========
all_days = (
    lf0.select(pl.col(G_DATE)).unique().sort(by=G_DATE)
    .collect(streaming=True).get_column(G_DATE).to_numpy()
)
folds_by_day = make_sliding_cv_by_days(all_days, n_splits=N_SPLITS, gap_days=GAP_DAYS, train_to_val=TRAIN_TO_VAL)
assert len(folds_by_day) > 0, "no CV folds constructed"

[2025-10-10 13:55:18] lazyframe ready


In [4]:


# 取第一个 fold 的训练集最后一天，作为统计 z-score 的上界
stats_hi = int(folds_by_day[0][0][-1])
print(f"[stats] upper bound day for z-score = {stats_hi}")

# ========== 4) 连续特征清洗 + Z-score ==========
inf2null_exprs  = [pl.when(pl.col(c).is_infinite()).then(None).otherwise(pl.col(c)).alias(c) for c in RAW_FEATURES]
isna_flag_exprs = [pl.col(c).is_null().cast(pl.Int8).alias(f"{c}__isna") for c in RAW_FEATURES]
ffill_exprs     = [pl.col(c).forward_fill().over(G_SYM).fill_null(0.0).alias(c) for c in RAW_FEATURES]

lf_clean = (
    lf0.with_columns(inf2null_exprs)
    .with_columns(isna_flag_exprs)
    .with_columns(ffill_exprs)
)

lf_stats_sym = (
    lf_clean.filter(pl.col(G_DATE) <= stats_hi)
            .group_by(G_SYM)
            .agg([pl.col(c).mean().alias(f"mu_{c}") for c in RAW_FEATURES] +
                [pl.col(c).std(ddof=0).alias(f"std_{c}") for c in RAW_FEATURES])
)
lf_stats_glb = (
    lf_clean.filter(pl.col(G_DATE) <= stats_hi)
            .select([pl.col(c).mean().alias(f"mu_{c}_glb") for c in RAW_FEATURES] +
                    [pl.col(c).std(ddof=0).alias(f"std_{c}_glb") for c in RAW_FEATURES])
)

lf_z = lf_clean.join(lf_stats_glb, how="cross").join(lf_stats_sym, on=G_SYM, how="left")

eps = 1e-6
Z_COLS, NAMARK_COLS = [], [f"{c}__isna" for c in RAW_FEATURES]
for c in RAW_FEATURES:
    mu_sym, std_sym = f"mu_{c}", f"std_{c}"
    mu_glb, std_glb = f"mu_{c}_glb", f"std_{c}_glb"
    mu_use, std_use = f"{c}_mu_use", f"{c}_std_use"
    z_name = f"{c}_z"

    lf_z = lf_z.with_columns(
        pl.when(pl.col(mu_sym).is_null()).then(pl.col(mu_glb)).otherwise(pl.col(mu_sym)).alias(mu_use),
        pl.when(pl.col(std_sym).is_null() | (pl.col(std_sym) == 0)).then(pl.col(std_glb)).otherwise(pl.col(std_sym)).alias(std_use),
    ).with_columns(
        ((pl.col(c) - pl.col(mu_use)) / (pl.col(std_use) + eps)).alias(z_name)
    ).drop([mu_glb, std_glb, mu_sym, std_sym, mu_use, std_use])

    Z_COLS.append(z_name)

OUT_COLS = [G_SYM, G_DATE, G_TIME, "time_idx", WEIGHT_COL, TARGET_COL] + TIME_FEATURES + Z_COLS #+ NAMARK_COLS
lf_out = lf_z.select(OUT_COLS).sort([G_DATE, G_TIME, G_SYM])

# 保存lf_out到本地
clean_path_local = P("local", "tft/panel/cleaned.parquet")
Path(clean_path_local).parent.mkdir(parents=True, exist_ok=True)
lf_out.collect(streaming=True).write_parquet(clean_path_local, compression="zstd")
print(f"[{_now()}] cleaned data saved to {clean_path_local}")


[stats] upper bound day for z-score = 1669
[2025-10-10 13:55:46] cleaned data saved to /mnt/data/js/exp/v1/tft/panel/cleaned.parquet


In [5]:
UNKNOWN_REALS =  None
KNOWN_REALS = TIME_FEATURES + Z_COLS

UNSCALE_COLS = KNOWN_REALS

#  对target col 和 unknown_reals去重
TRAIN_COLS = list(dict.fromkeys([G_SYM, "time_idx", WEIGHT_COL, TARGET_COL] + KNOWN_REALS))

# 定义 identity scalers
identity_scalers = {name: None for name in UNSCALE_COLS}

## try

In [None]:
# 取第一折先试探一下
best_ckpt_paths, fold_metrics = [], []
fold_id = 0
train_days, val_days = folds_by_day[0]

print(f"[fold {fold_id}] train {train_days[0]}..{train_days[-1]} ({len(train_days)} days), "
    f"val {val_days[0]}..{val_days[-1]} ({len(val_days)} days)")

# 明确日期：
train_start_date = int(train_days[0])
train_end_date   = int(train_days[-1])
val_start_date   = int(val_days[0])
val_end_date     = int(val_days[-1])      


# 提取数据
date_range = (train_start_date, val_end_date)
pdf_data = (
    pl.scan_parquet(clean_path_local)
    .filter(pl.col(G_DATE).between(train_start_date, val_end_date, closed="both"))
    .collect(streaming=True)
    .to_pandas()
    .sort_values([G_SYM, "time_idx"])
)
pdf_data[G_SYM] = pdf_data[G_SYM].astype("str")   

# 明确 indexes:
train_end_idx = pdf_data.loc[pdf_data[G_DATE] == train_end_date, "time_idx"].max()
val_start_idx = pdf_data.loc[pdf_data[G_DATE] == val_start_date, "time_idx"].min()
val_end_idx   = pdf_data.loc[pdf_data[G_DATE] == val_end_date, "time_idx"].max()
assert pd.notna(train_end_idx) and pd.notna(val_start_idx) and pd.notna(val_end_idx), "train/val idx not found"
train_end_idx, val_start_idx, val_end_idx = int(train_end_idx), int(val_start_idx), int(val_end_idx)
print(f"[fold {fold_id}] train idx up to {train_end_idx}, val idx {val_start_idx}..{val_end_idx}")  



In [None]:
t_data = pdf_data[TRAIN_COLS].copy()

identity_scalers = {name: None for name in UNSCALE_COLS}
base_ds = TimeSeriesDataSet(
    t_data,
    time_idx="time_idx",
    target=TARGET_COL,
    group_ids=[G_SYM],
    weight=WEIGHT_COL,
    max_encoder_length=ENC_LEN, 
    min_encoder_length=ENC_LEN,
    max_prediction_length=PRED_LEN, 
    min_prediction_length=PRED_LEN,
    
    static_categoricals=[G_SYM],
    time_varying_known_reals =KNOWN_REALS,
    time_varying_unknown_reals=UNKNOWN_REALS,
    
    lags=None,
    categorical_encoders={G_SYM: NaNLabelEncoder(add_nan=True)},
    add_relative_time_idx=False,
    add_target_scales=False,
    add_encoder_length=False,
    allow_missing_timesteps=True,
    target_normalizer=TorchNormalizer(method="identity"),
    scalers=identity_scalers,
)

In [None]:
# 划分训练集，验证集
train_ds = base_ds.filter(
    lambda idx: (
        idx.time_idx_last <= train_end_idx
    ),
    copy=True
)

val_ds = base_ds.filter(
    lambda idx: (
        (idx.time_idx_first_prediction >= val_start_idx) &
        
        (idx.time_idx_last <= val_end_idx)
    ),
    copy=True
)

In [None]:
# 数据集加载

train_loader = train_ds.to_dataloader(
    train=True, 
    batch_size=BATCH_SIZE, 
    num_workers=14,
    pin_memory=True,
    persistent_workers=False,
    prefetch_factor=8,
)

n_train_batches = len(train_loader)
print(f"[debug] train_loader batches = {n_train_batches}")
assert n_train_batches > 0, "Empty train dataloader. Check min_prediction_idx/ENC_LEN/date windows."

val_loader = val_ds.to_dataloader(
    train=False,
    batch_size=BATCH_SIZE,
    num_workers=14,
    pin_memory=True,
    persistent_workers=False,
    prefetch_factor=8,
)

n_val_batches = len(val_loader)
print(f"[debug] val_loader batches = {n_val_batches}")
assert n_val_batches > 0, "Empty val dataloader. Check min_prediction_idx/ENC_LEN/date windows."

In [None]:
lp.seed_everything(42)
trainer = lp.Trainer(
    accelerator="gpu",
    # clipping gradients is a hyperparameter and important to prevent divergance
    # of the gradient for recurrent neural networks
    gradient_clip_val=0.1,
)

tft = TemporalFusionTransformer.from_dataset(
    train_ds,
    # not meaningful for finding the learning rate but otherwise very important
    learning_rate=LR,
    hidden_size=HIDDEN,  # most important hyperparameter apart from learning rate
    # number of attention heads. Set to up to 4 for large datasets
    attention_head_size=HEADS,
    dropout=DROPOUT,  # between 0.1 and 0.3 are good values
    hidden_continuous_size=HIDDEN // 2,  # set to <= hidden_size
    loss=RMSE(),
    optimizer=torch.optim.Adam,
    # reduce learning rate if no improvement in validation loss after x epochs
    # reduce_on_plateau_patience=1000,
)
print(f"Number of parameters in network: {tft.size() / 1e3:.1f}k")
    
# find optimal learning rate
from lightning.pytorch.tuner import Tuner

res = Tuner(trainer).lr_find(
    tft,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
    max_lr=10.0,
    min_lr=1e-6,
)

print(f"suggested learning rate: {res.suggestion()}")
fig = res.plot(show=True, suggest=True)
fig.show()

## main

In [None]:
lp.seed_everything(42) 

# ========== 8) 训练（按 CV 折） ========== 先取第一折
best_ckpt_paths, fold_metrics = [], []

#for fold_id, (train_days, val_days) in enumerate(folds_by_day, start=1):
####################################
fold_id = 0
train_days, val_days = folds_by_day[fold_id]
####################################

print(f"[fold {fold_id}] train {train_days[0]}..{train_days[-1]} ({len(train_days)} days), "
    f"val {val_days[0]}..{val_days[-1]} ({len(val_days)} days)")

# 明确日期：
train_start_date = int(train_days[0])
train_end_date   = int(train_days[-1])
val_start_date   = int(val_days[0])
val_end_date     = int(val_days[-1])      

# 提取数据
date_range = (train_start_date, val_end_date)
pdf_data = (
    pl.scan_parquet(clean_path_local)
    .filter(pl.col(G_DATE).is_between(train_start_date, val_end_date, closed="both"))
    .collect(streaming=True)
    .to_pandas()
) 

pdf_data[G_SYM] = pdf_data[G_SYM].astype("str")
pdf_data.sort_values([G_SYM, "time_idx"], inplace=True)        
# 明确 indexes:
train_end_idx = pdf_data.loc[pdf_data[G_DATE] == train_end_date, "time_idx"].max()
val_start_idx = pdf_data.loc[pdf_data[G_DATE] == val_start_date, "time_idx"].min()
val_end_idx   = pdf_data.loc[pdf_data[G_DATE] == val_end_date, "time_idx"].max()
assert pd.notna(train_end_idx) and pd.notna(val_start_idx) and pd.notna(val_end_idx), "train/val idx not found"
train_end_idx, val_start_idx, val_end_idx = int(train_end_idx), int(val_start_idx), int(val_end_idx)
print(f"[fold {fold_id}] train idx up to {train_end_idx}, val idx {val_start_idx}..{val_end_idx}")  


Seed set to 42


[fold 0] train 1600..1669 (70 days), val 1675..1689 (15 days)
[fold 0] train idx up to 67759, val idx 72600..87119


In [None]:

# 构建训练集 timeseries dataset

t_data = pdf_data[TRAIN_COLS]
train_df = t_data.loc[t_data["time_idx"] <= train_end_idx]

train_ds = TimeSeriesDataSet(
    train_df,
    time_idx="time_idx",
    target=TARGET_COL,
    group_ids=[G_SYM],
    weight=WEIGHT_COL,
    max_encoder_length=ENC_LEN, 
    min_encoder_length=ENC_LEN,
    max_prediction_length=PRED_LEN, 
    min_prediction_length=PRED_LEN,
    
    static_categoricals=[G_SYM],
    time_varying_known_reals =KNOWN_REALS,
    time_varying_unknown_reals=None,
    lags=None,
    categorical_encoders={G_SYM: NaNLabelEncoder(add_nan=True)},
    add_relative_time_idx=False,
    add_target_scales=True,
    add_encoder_length=False,
    allow_missing_timesteps=True,
    target_normalizer=GroupNormalizer(
        method="standard", groups=[G_SYM], center=True, scale_by_group=False),
    scalers=identity_scalers,
)

# 验证集复用 train_ds 的所有 encoders/normalizer（不泄漏）

val_df = t_data.loc[t_data["time_idx"].between(val_start_idx, val_end_idx, inclusive="both")]
val_ds = TimeSeriesDataSet.from_dataset(
    train_ds,
    val_df,
    predict=False,
    stop_randomization=True,
)

# 数据集加载

train_loader = train_ds.to_dataloader(
    train=True, 
    batch_size=BATCH_SIZE, 
    num_workers=14,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=4,
)

n_train_batches = len(train_loader)
print(f"[debug] train_loader batches = {n_train_batches}")
assert n_train_batches > 0, "Empty train dataloader. Check min_prediction_idx/ENC_LEN/date windows."

val_loader = val_ds.to_dataloader(
    train=False,
    batch_size=BATCH_SIZE*2,
    num_workers=14,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=4,
)

n_val_batches = len(val_loader)
print(f"[debug] val_loader batches = {n_val_batches}")
assert n_val_batches > 0, "Empty val dataloader. Check min_prediction_idx/ENC_LEN/date windows."

# 8.6 callbacks/logger/trainer
ckpt_dir_fold = Path(CKPTS_DIR) / f"fold_{fold_id}"
ckpt_dir_fold.mkdir(parents=True, exist_ok=True)

callbacks = [EarlyStopping(monitor="val_loss", mode="min", patience=5, check_on_train_epoch_end=False),
            ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1, dirpath=ckpt_dir_fold.as_posix(), filename=f"fold{fold_id}-tft-best-{{epoch:02d}}-{{val_loss:.5f}}", save_on_train_epoch_end=False),
            LearningRateMonitor(logging_interval="step"),]
RUN_NAME = (f"f{fold_id}"f"_E{MAX_EPOCHS}"f"_lr{LR}"f"_bs{BATCH_SIZE}"f"_enc{ENC_LEN}_dec{DEC_LEN}"f"_{datetime.now().strftime('%Y%m%d-%H%M%S')}")
logger = TensorBoardLogger(save_dir=LOGS_DIR.as_posix(),name="tft",version=RUN_NAME,default_hp_metric=False)

trainer = lp.Trainer(max_epochs=MAX_EPOCHS,
                    accelerator="gpu",
                    precision="bf16-mixed",
                    enable_model_summary=True,
                    gradient_clip_val=1.0,
                    #fast_dev_run=True,
                    #limit_train_batches=50,  # coment in for training, running valiation every 30 batches
                    log_every_n_steps=100,
                    callbacks=callbacks,
                    logger=logger,
                    )

tft = TemporalFusionTransformer.from_dataset(
    train_ds,
    learning_rate=LR,
    hidden_size=HIDDEN,
    attention_head_size=HEADS,
    dropout=DROPOUT,
    hidden_continuous_size=HIDDEN // 2,
    loss=RMSE(),
    logging_metrics=[RMSE()],
    optimizer=torch.optim.AdamW,
    optimizer_params={"weight_decay": 1e-4},
    reduce_on_plateau_patience=3,
)
trainer.fit(
    tft,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
    )

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


[debug] train_loader batches = 2534
[debug] val_loader batches = 274


/home/admin_ml/Jackson/projects/js/JS/.venv/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:210: Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`.
/home/admin_ml/Jackson/projects/js/JS/.venv/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:751: Checkpoint directory /mnt/data/js/exp/v1/tft/ckpts/fold_0 exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/admin_ml/Jackson/projects/js/JS/.venv/lib/python3.10/site-packages/lightning/pytorch/utilities/model_summary/model_summary.py:231: Precision bf16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.

   | Name                               | Type                            | Params | Mode 
-----------------------------------------------------------------------------------------------

Epoch 4:  27%|██▋       | 690/2534 [35:27<1:34:46,  0.32it/s, v_num=0906, train_loss_step=1.200, val_loss=0.973, train_loss_epoch=1.210]   

In [None]:
ckpt_cb = next(cb for cb in callbacks if isinstance(cb, ModelCheckpoint))
best_path = ckpt_cb.best_model_path
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_path)

predictions = best_tft.predict(
    val_loader,
    return_y=True,
    trainer_kwargs=dict(accelerator="gpu")
)
y_pred = predictions.output
y_true, w = predictions.y



In [None]:
num = (w * (y_true - y_pred).pow(2)).sum()
den = (w * y_true.pow(2)).sum()

wr2 = 1.0 - num / (den + eps)
print(f"wr2 after training: {wr2.item():.6f}")

In [None]:
num = (torch.square(y_true - y_pred) * w).sum()
den = (torch.square(y_true) * w).sum()  
wr2 = 1 - num / den
print(f"wr2 after training: {wr2}")