In [1]:
from __future__ import annotations
import time, numpy as np, polars as pl, pandas as pd, torch, lightning as L
from pathlib import Path

from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor, Callback
from lightning.pytorch.loggers import TensorBoardLogger
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.data import NaNLabelEncoder, GroupNormalizer
from pytorch_forecasting.metrics import MAE
from pytorch_forecasting.data.encoders import TorchNormalizer

# 你的工程工具
from pipeline.io import cfg, P, fs, storage_options, ensure_dir_local

def _now(): return time.strftime("%Y-%m-%d %H:%M:%S")
print(f"[{_now()}] imports ok")

  from tqdm.autonotebook import tqdm


[2025-10-03 07:47:42] imports ok


In [2]:
# ---- Key configuration (edit as needed) ----
target_col = cfg["target"]                       # e.g. "responder_6"
g_sym, g_date, g_time = cfg["keys"]              # e.g. ("symbol_id", "date_id", "time_id")
weight_col = cfg["weight"]

# Feature groups; feature_cols is their order-preserving, de-duplicated union.
time_features = ["time_pos", "time_sin", "time_cos", "time_bucket"]
base_features = ["feature_36"]
resp_his_feats = ["responder_3_prevday_std"]
feat_his_feats = ["feature_08__ewm5"]
feature_cols = list(dict.fromkeys([*base_features, *resp_his_feats, *feat_his_feats]))

# Every panel column the pipeline reads: keys, weight, target, time features, features.
key_and_label_cols = [g_sym, g_date, g_time, weight_col, target_col]
need_cols = list(dict.fromkeys([*key_and_label_cols, *time_features, *feature_cols]))

# ---- Cross-validation ----
N_SPLITS = 5
GAP_DAYS = 7                  # days skipped between each training window and its validation window
TRAIN_TO_VAL = 4              # train : validation length = 4 : 1

# ---- Model & training ----
ENC_LEN = 10
PRED_LEN = 1
BATCH_SIZE = 1024
LR = 1e-3
HIDDEN = 64
HEADS = 2
DROPOUT = 0.1
MAX_EPOCHS_PER_SHARD = 1
CHUNK_DAYS = 20               # days per shard / write chunk

print("config ready")

config ready


In [32]:
# Panel shards on Azure: list them with fsspec, re-attach the az:// scheme, scan lazily.
panel_dir = P("az", cfg["paths"].get("panel_shards", "panel_shards"))
shard_glob = f"{panel_dir}/*.parquet"
data_path = fs.glob(shard_glob)
az_path = list(map("az://{}".format, data_path))
lf_data = pl.scan_parquet(az_path, storage_options=storage_options)

In [None]:
# Global time_idx, built once: one row per distinct (date, time) pair, numbered in
# chronological order, persisted locally for later cells.
grid_path = P("local", "tft/panel/grid_timeidx.parquet")
ensure_dir_local(Path(grid_path).parent.as_posix())

lf_grid = (
    lf_data
    .select([g_date, g_time])
    .unique()
    .sort([g_date, g_time])
    .with_row_index("time_idx")
    .with_columns(pl.col("time_idx").cast(pl.Int64))
)
lf_grid.collect(streaming=True).write_parquet(grid_path, compression="zstd")

In [36]:
grid_lazy = pl.scan_parquet(grid_path)

# Attach the global time_idx and keep only the needed columns (still lazy),
# ordered time-major: date, then time, then symbol.
lf0 = lf_data.join(grid_lazy, on=[g_date, g_time], how="left")
lf0 = lf0.select([*need_cols, "time_idx"])
lf0 = lf0.sort([g_date, g_time, g_sym])

print(f"[{_now()}] lazyframe ready")

[2025-10-03 08:11:18] lazyframe ready


In [37]:
# Sorted numpy array of every distinct date in the panel.
all_days = (
    lf0.select(g_date)
    .unique()
    .sort(g_date)
    .collect(streaming=True)[g_date]
    .to_numpy()
)

In [39]:
# Sliding-window time-series cross-validation over whole days.
def make_sliding_cv_by_days(all_days: np.ndarray, *, n_splits: int, gap_days: int, train_to_val: int):
    """Build rolling ``(train_days, val_days)`` folds over a day axis.

    The usable span (all days minus one gap) is cut into ``train_to_val + n_splits`` units
    of ``V_base`` days. Every fold trains on a fixed window of ``T = train_to_val * V_base``
    consecutive days, skips ``gap_days`` days, then validates on the next ``V_base`` days
    (``V_base + 1`` for the first folds, absorbing the division remainder). Validation
    windows tile forward without overlapping each other.

    Args:
        all_days: day identifiers; de-duplicated and sorted ascending before slicing.
        n_splits: number of folds requested (K).
        gap_days: days dropped between the end of training and the start of validation (G).
        train_to_val: training length as a multiple of the base validation length (R).

    Returns:
        List of ``(train_days, val_days)`` numpy arrays; ``[]`` when the configuration is
        invalid (K <= 0, R <= 0, G < 0) or there are too few days.
    """
    # Slicing below assumes unique, ascending days; enforce it instead of trusting the caller.
    days = np.unique(np.asarray(all_days).ravel())
    K, R, G = n_splits, train_to_val, gap_days
    # A negative gap would make the training window run into the validation window.
    if K <= 0 or R <= 0 or G < 0:
        return []
    usable = len(days) - G
    if usable <= 0:
        return []
    V_base, rem = divmod(usable, R + K)
    if V_base <= 0:
        return []
    T = R * V_base
    # The first min(rem, K) folds get one extra validation day so leftover days are used.
    v_lens = [V_base + 1 if i < rem else V_base for i in range(K)]
    folds, v_lo = [], T + G
    for V_i in v_lens:
        v_hi = v_lo + V_i
        tr_hi = v_lo - G
        tr_lo = tr_hi - T
        if tr_lo < 0 or v_hi > len(days):
            break
        folds.append((days[tr_lo:tr_hi], days[v_lo:v_hi]))
        v_lo = v_hi
    return folds

# Build the CV folds from the config-cell constants; an empty result means too few days.
folds_by_day = make_sliding_cv_by_days(
    all_days,
    n_splits=N_SPLITS,
    gap_days=GAP_DAYS,
    train_to_val=TRAIN_TO_VAL,
)
assert folds_by_day, "no CV folds constructed"

In [40]:
# Inspect the folds: each entry is a (train_days, val_days) pair of numpy day arrays.
folds_by_day

[(array([1605, 1606, 1607, 1608, 1609, 1610, 1611, 1612, 1613, 1614, 1615,
         1616, 1617, 1618, 1619, 1620, 1621, 1622, 1623, 1624, 1625, 1626,
         1627, 1628, 1629, 1630, 1631, 1632, 1633, 1634, 1635, 1636, 1637,
         1638, 1639, 1640], dtype=int32),
  array([1648, 1649, 1650, 1651, 1652, 1653, 1654, 1655, 1656, 1657],
        dtype=int32)),
 (array([1615, 1616, 1617, 1618, 1619, 1620, 1621, 1622, 1623, 1624, 1625,
         1626, 1627, 1628, 1629, 1630, 1631, 1632, 1633, 1634, 1635, 1636,
         1637, 1638, 1639, 1640, 1641, 1642, 1643, 1644, 1645, 1646, 1647,
         1648, 1649, 1650], dtype=int32),
  array([1658, 1659, 1660, 1661, 1662, 1663, 1664, 1665, 1666, 1667],
        dtype=int32)),
 (array([1625, 1626, 1627, 1628, 1629, 1630, 1631, 1632, 1633, 1634, 1635,
         1636, 1637, 1638, 1639, 1640, 1641, 1642, 1643, 1644, 1645, 1646,
         1647, 1648, 1649, 1650, 1651, 1652, 1653, 1654, 1655, 1656, 1657,
         1658, 1659, 1660], dtype=int32),
  array([1668

In [41]:
# Inclusive upper bound of the days used for the normalization statistics: the last
# training day of the first fold, so no validation day contributes to them.
first_fold_train_days = folds_by_day[0][0]
stats_hi = int(first_fold_train_days[-1])
print(f"stats_hi (for global z-score) = {stats_hi}; first-fold train days end at this day.")

stats_hi (for global z-score) = 1640; first-fold train days end at this day.


In [None]:

clean_dir = P("local", "tft/clean_by_day")
ensure_dir_local(clean_dir)

# ========== 1) Continuous features: non-finite -> null, missing flags, per-symbol ffill, fallback 0 ==========
# Flags are computed from the raw missingness, i.e. before filling.
# NaN is treated like ±inf: polars' is_null() does not see NaN, and a NaN surviving
# ffill/fill_null would poison the per-symbol mean/std and every z-score of that symbol.
# NOTE(review): is_nan() assumes float feature columns — confirm no integer features.
nonfinite2null_exprs = [
    pl.when(pl.col(c).is_infinite() | pl.col(c).is_nan()).then(None).otherwise(pl.col(c)).alias(c)
    for c in feature_cols
]  # overwrite in place, no new columns
flags_exprs = [pl.col(c).is_null().cast(pl.Int8).alias(f"{c}__isna") for c in feature_cols]  # new columns
# Forward fill within each symbol; relies on lf0 being sorted by (date, time, symbol).
fill_exprs = [pl.col(c).forward_fill().over(g_sym).fill_null(0.0).alias(c) for c in feature_cols]

lf_clean = (
    lf0.with_columns(nonfinite2null_exprs)   # inf / NaN -> null
    .with_columns(flags_exprs)               # missing flags (raw missingness)
    .with_columns(fill_exprs)                # per-symbol ffill + fallback 0
)

# ========== 2) Statistics on the training interval only (date <= stats_hi) ==========
lf_train_period = lf_clean.filter(pl.col(g_date) <= stats_hi)

# Per-symbol mean / std
lf_stats_sym = lf_train_period.group_by(g_sym).agg(
    [pl.col(c).mean().alias(f"mu_{c}") for c in feature_cols]
    + [pl.col(c).std().alias(f"std_{c}") for c in feature_cols]
)

# Global mean / std, used as a fallback
lf_stats_glb = lf_train_period.select(
    [pl.col(c).mean().alias(f"mu_{c}_glb") for c in feature_cols]
    + [pl.col(c).std().alias(f"std_{c}_glb") for c in feature_cols]
)

# ========== 3) Attach global stats (cross join, one row) and per-symbol stats (left join) ==========
lf_z = lf_clean.join(lf_stats_glb, how="cross").join(lf_stats_sym, on=g_sym, how="left")

# ========== 4) Fallback & z-score ==========
eps = 1e-6


def _zscore_expr(c: str) -> pl.Expr:
    """z-score of feature `c` with per-symbol stats, falling back to the global ones.

    The mean falls back when the symbol has no training-period rows (null). The std also
    falls back when it is 0 or NaN, e.g. a symbol with a single training row or a
    constant feature.
    """
    mu = pl.coalesce(pl.col(f"mu_{c}"), pl.col(f"mu_{c}_glb"))
    std_sym = pl.col(f"std_{c}")
    std_unusable = std_sym.is_null() | std_sym.is_nan() | (std_sym == 0)
    std = pl.when(std_unusable).then(pl.col(f"std_{c}_glb")).otherwise(std_sym)
    return ((pl.col(c) - mu) / (std + eps)).alias(f"{c}_z")


z_cols = [f"{c}_z" for c in feature_cols]
lf_z = lf_z.with_columns([_zscore_expr(c) for c in feature_cols])

# ========== 5) Output columns (z features + isna flags + time/key/target/weight) ==========
# The final select also drops the intermediate mu_* / std_* columns.
namark_cols = [f"{c}__isna" for c in feature_cols]
out_cols = [g_sym, g_date, g_time, "time_idx", weight_col, target_col] + time_features + z_cols + namark_cols

lf_out = lf_z.select(out_cols).sort([g_date, g_time, g_sym])

In [None]:

# ========== Chunked collect + per-day partitioned write ==========
# Collect a batch of days per query (instead of one collect per day) and write each day
# to `{clean_dir}/{g_date}=<day>/`, the hive layout the training cells glob for.
# Chunk size comes from CHUNK_DAYS in the config cell (e.g. 10~30 days depending on memory).
day_list = [int(d) for d in all_days]
day_chunks = [day_list[i:i + CHUNK_DAYS] for i in range(0, len(day_list), CHUNK_DAYS)]

for ci, chunk in enumerate(day_chunks, 1):
    df_chunk = lf_out.filter(pl.col(g_date).is_in(chunk)).collect(streaming=True)

    try:
        # Partitioned write: one sub-directory per day in this chunk.
        # Older polars versions reject `partition_by` with a TypeError.
        df_chunk.write_parquet(
            Path(clean_dir).as_posix(),
            compression="zstd",
            partition_by=[g_date],
        )
    except TypeError:
        # Fallback: write each day by hand using the SAME hive layout, so readers that glob
        # `{g_date}={day}/*.parquet` still find the files.
        for d in chunk:
            out_path = Path(clean_dir) / f"{g_date}={d}" / "part-0.parquet"
            out_path.parent.mkdir(parents=True, exist_ok=True)
            df_chunk.filter(pl.col(g_date) == d).write_parquet(out_path.as_posix(), compression="zstd")
    print(f"[{_now()}] chunk {ci}/{len(day_chunks)} -> days {chunk[0]}..{chunk[-1]} written")

print(f"[{_now()}] cleaned & standardized shards written to {clean_dir}")


In [None]:
# === Cell 1: Template & Validation ===
import glob, numpy as np, polars as pl, pandas as pd
from sklearn.preprocessing import FunctionTransformer
from pytorch_forecasting import TimeSeriesDataSet
from pytorch_forecasting.data import NaNLabelEncoder

# Relies on earlier cells for:
# cfg, P, clean_dir, folds_by_day, g_sym, g_date, g_time, target_col, weight_col,
# time_features, z_cols, namark_cols, ENC_LEN, PRED_LEN, BATCH_SIZE

fold_idx = 1                                        # 1-based fold number
train_days, val_days = folds_by_day[fold_idx - 1]
days_sorted = np.sort(train_days)


def _load_day_shards(days) -> pd.DataFrame:
    """Load the cleaned `{clean_dir}/{g_date}=<day>/*.parquet` shards for `days` into pandas.

    The symbol column is converted to a pandas category. Raises FileNotFoundError when no
    shard matches, so a missing or mis-named partition fails early with a clear message.
    """
    paths = [p for d in days for p in glob.glob(f"{clean_dir}/{g_date}={int(d)}/*.parquet")]
    if not paths:
        raise FileNotFoundError(f"no cleaned shards under {clean_dir} for {g_date} in {list(map(int, days))}")
    pdf = pl.scan_parquet(paths).collect(streaming=True).to_pandas()
    pdf[g_sym] = pdf[g_sym].astype("str").astype("category")
    return pdf


# ---- Validation set ----
pdf_val = _load_day_shards(val_days)

# ---- Template: built from the first training day; its encoders/scalers are reused by
#      the validation set and the streamed training shards via from_dataset ----
# NOTE(review): NaNLabelEncoder is fitted on the symbols present on this single day only;
# symbols first seen later presumably all map to the NaN class — confirm this is intended.
first_train_day = int(days_sorted[0])
pdf_tmpl = _load_day_shards([first_train_day])

unknown_reals = time_features + z_cols + namark_cols
# With func=None, sklearn's FunctionTransformer is the identity and is picklable: z_cols
# were standardised upstream, and the flags / time features are used as-is.
identity_scalers = {c: FunctionTransformer(validate=False) for c in unknown_reals}

template = TimeSeriesDataSet(
    pdf_tmpl.sort_values([g_sym, "time_idx"]),
    time_idx="time_idx",
    target=target_col,
    group_ids=[g_sym],
    weight=weight_col,
    max_encoder_length=ENC_LEN,
    max_prediction_length=PRED_LEN,
    static_categoricals=[g_sym],
    time_varying_unknown_reals=unknown_reals,
    target_normalizer=None,
    categorical_encoders={g_sym: NaNLabelEncoder(add_nan=True)},
    add_relative_time_idx=True, add_target_scales=True, add_encoder_length=True,
    allow_missing_timesteps=True,
    scalers=identity_scalers,
)

validation = TimeSeriesDataSet.from_dataset(
    template, data=pdf_val.sort_values([g_sym, "time_idx"]), stop_randomization=True
)
val_loader = validation.to_dataloader(
    train=False, batch_size=BATCH_SIZE, num_workers=4, pin_memory=True, persistent_workers=True
)

len(val_loader), pdf_tmpl.shape, pdf_val.shape


In [None]:
# === Cell 2: IterableDataset 真·流式（按分片产出 batch） + 加权 WR² ===
import random, torch, polars as pl, glob
from torch.utils.data import IterableDataset
from pytorch_forecasting import TimeSeriesDataSet

class ShardedBatchStream(IterableDataset):
    """Stream training batches one day-shard at a time.

    Each pass visits ``shard_days`` in a seeded random order. For every day it loads
    ``{clean_dir}/{date_col}={day}/*.parquet``, re-encodes ``g_sym`` as a pandas category,
    builds a ``TimeSeriesDataSet`` from ``template_tsd`` (reusing its encoders/scalers) and
    yields the batches of that day's DataLoader. With ``buffer_batches > 0`` batches go
    through a small buffer from which a random one is emitted, mixing batches of consecutive
    shards. Days with no files or no rows are skipped.

    The day order is seeded with ``seed + epoch``, where ``epoch`` counts passes started so
    far: the first pass uses exactly ``seed``, and later passes use a different seed instead
    of repeating the first order.
    """

    def __init__(
        self,
        template_tsd,
        shard_days,
        clean_dir: str,
        g_sym: str,
        batch_size: int = 1024,
        num_workers: int = 4,
        shuffle_within_shard: bool = True,
        buffer_batches: int = 0,
        seed: int = 42,
        date_col: str = "date_id",
    ):
        super().__init__()
        self.template = template_tsd
        self.days = list(map(int, shard_days))
        self.clean_dir = clean_dir
        self.g_sym = g_sym
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle_within_shard = shuffle_within_shard
        self.buffer_batches = buffer_batches
        self.seed = seed
        self.date_col = date_col  # hive partition column name in the shard directory names
        self._epoch = 0           # passes started so far; advances the day-order seed

    def __iter__(self):
        from collections import deque

        # Generator body: this only runs once iteration actually starts.
        rng = random.Random(self.seed + self._epoch)
        self._epoch += 1
        days = self.days[:]
        rng.shuffle(days)

        buf = deque()

        for d in days:
            paths = glob.glob(f"{self.clean_dir}/{self.date_col}={d}/*.parquet")
            if not paths:
                continue

            pdf = pl.scan_parquet(paths).collect(streaming=True).to_pandas()
            if pdf.empty:
                continue
            pdf[self.g_sym] = pdf[self.g_sym].astype("str").astype("category")

            tsds = TimeSeriesDataSet.from_dataset(
                self.template,
                data=pdf.sort_values([self.g_sym, "time_idx"]),
                stop_randomization=True,
            )

            # A fresh loader per shard is iterated exactly once, so persistent workers would
            # only keep idle worker processes alive after the shard is exhausted.
            dl = tsds.to_dataloader(
                train=True,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                shuffle=self.shuffle_within_shard,
                pin_memory=True,
                persistent_workers=False,
            )

            for batch in dl:
                if self.buffer_batches <= 0:
                    yield batch
                    continue
                buf.append(batch)
                if len(buf) >= self.buffer_batches:
                    # Emit a uniformly random buffered batch.
                    k = rng.randrange(len(buf))
                    if k:
                        buf.rotate(-k)
                    yield buf.popleft()

        # Drain whatever is left in the buffer at the end of the pass.
        while buf:
            yield buf.popleft()



In [None]:

def weighted_wr2_tensor(yhat: torch.Tensor, y: torch.Tensor, w: torch.Tensor, eps: float = 1e-9, mode: str = "corr2"):
    """Weighted goodness of fit between predictions and targets, returned as a 0-d tensor.

    All inputs are flattened. Entries where any of ``yhat``, ``y``, ``w`` is non-finite are
    dropped, and negative weights are clamped to 0.

    Args:
        yhat: predictions.
        y: targets (same number of elements as ``yhat``).
        w: per-element weights (same number of elements).
        eps: floor for the denominators.
        mode: ``"corr2"`` (default, the original behaviour) is the squared weighted Pearson
            correlation ``cov² / (var_x · var_y)``. It ignores sign and scale: a perfectly
            anti-correlated prediction also scores 1.
            ``"r2"`` is the zero-mean weighted R², ``1 - Σw(y-ŷ)² / Σw·y²``. It penalises
            wrong sign and scale; 0 is returned when ``Σw·y²`` is not above ``eps``.

    Raises:
        ValueError: for an unknown ``mode``.
    """
    if mode not in ("corr2", "r2"):
        raise ValueError(f"unknown mode {mode!r}; expected 'corr2' or 'r2'")
    x = yhat.reshape(-1)
    yt = y.reshape(-1)
    wt = w.reshape(-1)
    keep = torch.isfinite(x) & torch.isfinite(yt) & torch.isfinite(wt)
    x, yt, wt = x[keep], yt[keep], torch.clamp(wt[keep], min=0.0)

    if mode == "r2":
        num = (wt * (yt - x) ** 2).sum()
        den = (wt * yt ** 2).sum()
        r2 = 1.0 - num / torch.clamp(den, min=eps)
        # Undefined without weighted target energy: report 0 instead of a spurious 1.
        return torch.where(den > eps, r2, torch.zeros_like(r2))

    ws = torch.clamp(wt.sum(), min=eps)
    mx = (wt * x).sum() / ws
    my = (wt * yt).sum() / ws
    cov = (wt * (x - mx) * (yt - my)).sum() / ws
    vx = (wt * (x - mx) ** 2).sum() / ws
    vy = (wt * (yt - my) ** 2).sum() / ws
    return (cov * cov) / (torch.clamp(vx, min=eps) * torch.clamp(vy, min=eps) + eps)

class LogWeightedWR2Callback(Callback):
    """Compute the weighted WR² over the whole validation set and log it as ``val_WR2_weighted``.

    Validation batches are unpacked as ``(x, (y, w))``. For each batch the module is run
    forward on ``x`` once more to obtain ``out["prediction"]``. Predictions, targets and
    weights are buffered, and at the end of the validation epoch the metric is computed by
    ``weighted_wr2_tensor``. It is logged through ``pl_module.log`` so EarlyStopping /
    ModelCheckpoint monitoring ``val_WR2_weighted`` can see it.
    """

    def __init__(self):
        super().__init__()
        self._buf_yhat, self._buf_y, self._buf_w = [], [], []

    def on_validation_epoch_start(self, trainer, pl_module):
        self._buf_yhat.clear()
        self._buf_y.clear()
        self._buf_w.clear()

    @torch.no_grad()
    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        x, (y, w) = batch
        # Forward on the device Lightning already moved the batch to.
        yhat = pl_module(x)["prediction"]
        # A trailing size-1 output dim ([B, pred_len, 1]) is dropped so yhat lines up with y.
        if yhat.dim() == 3 and yhat.size(-1) == 1:
            yhat = yhat.squeeze(-1)
        y_flat = y.reshape(-1).detach()
        # Without a weight column the dataset yields w=None: fall back to uniform weights
        # instead of crashing on None.reshape.
        w_flat = torch.ones_like(y_flat) if w is None else w.reshape(-1).detach()
        self._buf_yhat.append(yhat.reshape(-1).detach())
        self._buf_y.append(y_flat)
        self._buf_w.append(w_flat)

    @torch.no_grad()
    def on_validation_epoch_end(self, trainer, pl_module):
        if not self._buf_yhat:
            return
        yhat = torch.cat(self._buf_yhat, dim=0)
        y = torch.cat(self._buf_y, dim=0)
        w = torch.cat(self._buf_w, dim=0)
        wr2_w = weighted_wr2_tensor(yhat, y, w).float().cpu()
        # Logging through the module (rather than writing trainer.callback_metrics by hand)
        # puts the value into callback_metrics for EarlyStopping/ModelCheckpoint and forwards
        # it to the attached logger.
        pl_module.log("val_WR2_weighted", wr2_w.item(), on_step=False, on_epoch=True, prog_bar=True)


In [None]:
# === Cell 3: Trainer + TFT model (monitors val_WR2_weighted logged by LogWeightedWR2Callback) ===
import torch, pandas as pd, lightning as L
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor, Callback
from lightning.pytorch.loggers import TensorBoardLogger
from pytorch_forecasting import TemporalFusionTransformer
from pytorch_forecasting.metrics import MAE

# Seed before the model is built so weight initialisation is reproducible.
L.seed_everything(int(cfg.get("seed", 42)), workers=True)

callbacks = [
    LogWeightedWR2Callback(),
    EarlyStopping(monitor="val_WR2_weighted", mode="max", patience=5),
    ModelCheckpoint(monitor="val_WR2_weighted", mode="max", save_top_k=1,
                    filename="tft-best-{epoch:02d}-{val_WR2_weighted:.5f}"),
    LearningRateMonitor(logging_interval="step"),
]

logger = TensorBoardLogger(save_dir=P("local", "tft/logs"), name="tft", default_hp_metric=False)

# Run validation every VAL_EVERY_STEPS training batches. With patience=5, EarlyStopping stops
# after 5 consecutive checks without improvement, so tune the two together.
VAL_EVERY_STEPS = 10
trainer = L.Trainer(
    accelerator="gpu", devices=1, precision=32,
    max_epochs=1,
    val_check_interval=VAL_EVERY_STEPS,
    num_sanity_val_steps=0,
    gradient_clip_val=0.5,
    log_every_n_steps=50,
    callbacks=callbacks,
    logger=logger,
    default_root_dir=P("local", "tft/ckpts"),
)

# Hyper-parameters: cfg["tft"] overrides; otherwise the LR/HIDDEN/HEADS/DROPOUT constants
# from the config cell are used.
tft_cfg = cfg.get("tft", {})
tft = TemporalFusionTransformer.from_dataset(
    template,
    loss=MAE(),
    # logging_metrics must be an nn.ModuleList of metric objects.
    logging_metrics=torch.nn.ModuleList([MAE()]),
    learning_rate=float(tft_cfg.get("lr", LR)),
    hidden_size=int(tft_cfg.get("hidden_size", HIDDEN)),
    attention_head_size=int(tft_cfg.get("heads", HEADS)),
    dropout=float(tft_cfg.get("dropout", DROPOUT)),
    reduce_on_plateau_patience=4,
)

print("model device (before fit):", next(tft.parameters()).device)




In [None]:
# === Cell 4: Build the streaming training set & run a single fit ===
train_stream = ShardedBatchStream(
    template_tsd=template,
    shard_days=days_sorted,
    clean_dir=clean_dir,
    g_sym=g_sym,
    batch_size=BATCH_SIZE,
    num_workers=4,
    shuffle_within_shard=True,
    buffer_batches=16,              # 0 disables the cross-shard shuffle buffer; 8~64 is a reasonable range
    seed=int(cfg.get("seed", 42)),  # same seed source as L.seed_everything in the trainer cell
)

trainer.fit(tft, train_dataloaders=train_stream, val_dataloaders=val_loader)


In [None]:
from __future__ import annotations
import time
from pathlib import Path
import numpy as np
import pandas as pd
import polars as pl
import torch
import lightning as L
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger

from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.data import NaNLabelEncoder
from pytorch_forecasting.metrics import MAE  # 用 MAE 训练
from pytorch_forecasting.data.encoders import TorchNormalizer

# 项目内工具
from pipeline.io import cfg, P, fs, storage_options, ensure_dir_local

def _now(): 
    return time.strftime("%Y-%m-%d %H:%M:%S")

In [None]:

print(f"[{_now()}][tft] ===== start =====")
target_col = cfg["target"]                 # e.g. "responder_6"
g_sym, g_date, g_time = cfg["keys"]        # e.g. ("symbol_id","date_id","time_id")
weight_col = cfg["weight"]
TIME_SORT = cfg["sorts"].get("time_major", [g_date, g_time, g_sym])

# 1) Feature columns (example lists — TODO: replace with the real feature set)
time_features = ["time_pos", "time_sin", "time_cos", "time_bucket"]
base_features = ["feature_36", "feature_06"]
resp_his_feats = ["responder_5_prevday_std", "responder_3_prevday_std", "responder_4_prev_tail_d1"]
feat_his_feats = ["feature_08__ewm5", "feature_53__rstd3"]
feature_cols = list(dict.fromkeys(base_features + resp_his_feats + feat_his_feats))

# cfg["keys"] may be a tuple; splat it so the concatenation works for list and tuple configs.
need_cols = list(dict.fromkeys([*cfg["keys"], weight_col, target_col, *time_features, *feature_cols]))

# 2) Read the panel (lazy) & build or reuse the time grid
panel_dir = P("az", cfg["paths"].get("panel_shards", "panel_shards"))
glob_pat = f"{panel_dir}/*.parquet"
shard_paths = fs.glob(glob_pat.replace("az://", ""))
if not shard_paths:
    raise FileNotFoundError(f"No parquet shards under: {glob_pat}")
# Re-attach the az:// scheme to the listed paths (same approach as the first pipeline) so
# polars reads from Azure whether or not panel_dir itself carried the prefix.
lf = pl.scan_parquet([f"az://{p}" for p in shard_paths], storage_options=storage_options)

grid_path = P("local", "tft/panel/grid_timeidx.parquet")
if not Path(grid_path).exists():
    lf_grid = (
        lf.select([g_date, g_time]).unique()
        .sort([g_date, g_time])
        .with_row_index("time_idx")
        .with_columns(pl.col("time_idx").cast(pl.Int64))
    )
    ensure_dir_local(Path(grid_path).parent.as_posix())
    lf_grid.collect(streaming=True).write_parquet(grid_path, compression="zstd")
    print(f"[{_now()}][tft] grid saved -> {grid_path}")
grid_lazy = pl.scan_parquet(grid_path)

# Sanity check: the global time_idx must be unique and contiguous.
grid_df = grid_lazy.select([g_date, g_time, "time_idx"]).collect()
ti = grid_df["time_idx"]
assert not ti.is_duplicated().any(), "duplicate global time_idx"
assert ti.max() - ti.min() + 1 == len(ti), "全局 time_idx 不连续"

# 3) Date window + join time_idx + column selection
lo = cfg["dates"]["tft_dates"]["date_lo"]; hi = cfg["dates"]["tft_dates"]["date_hi"]
lw = lf.filter(pl.col(g_date).is_between(lo, hi, closed="both"))
lf0 = (
    lw.join(grid_lazy, on=[g_date, g_time], how="left")
    .select(need_cols + ["time_idx"])
    .sort(TIME_SORT)
)
print(f"[{_now()}][tft] schema -> {lf0.collect_schema().names()}")

In [None]:



# Clean the windowed panel. Source is lf0 from the previous cell: date window + global
# time_idx, sorted by TIME_SORT (time-major), which the per-symbol forward fill relies on.
# Built as one chain from lf0, so re-running this cell is idempotent.
# NOTE(review): is_nan() assumes float feature columns — confirm no integer features.
miss_flags = [f"{c}__isna" for c in feature_cols]

lw_with_idx = (
    lf0
    # symbol id -> categorical (via string)
    .with_columns(pl.col(g_sym).cast(str).cast(pl.Categorical))
    # 1) ±inf and NaN -> null (feature_cols only); polars' is_null() does not see NaN
    .with_columns([
        pl.when(pl.col(c).is_infinite() | pl.col(c).is_nan()).then(None).otherwise(pl.col(c)).alias(c)
        for c in feature_cols
    ])
    # 2) missing flags, taken from the nulls of step 1 (before any filling)
    .with_columns([
        pl.col(c).is_null().cast(pl.Int8).alias(f"{c}__isna")
        for c in feature_cols
    ])
    # 3) per-symbol forward fill, then 0 for leading gaps (feature_cols only)
    .with_columns([
        pl.col(c).fill_null(strategy="forward").over(g_sym).fill_null(0.0).alias(c)
        for c in feature_cols
    ])
)

In [None]:
def make_sliding_cv_by_days(days: np.ndarray, *, n_splits: int, gap_days: int = 5, train_to_val: int = 9):
    """Build rolling ``(train_days, val_days)`` folds over a day axis.

    The usable span (all days minus one gap) is cut into ``train_to_val + n_splits`` units
    of ``V_base`` days. Every fold trains on a fixed window of ``T = train_to_val * V_base``
    consecutive days, skips ``gap_days`` days, then validates on the next ``V_base`` days
    (``V_base + 1`` for the first folds, absorbing the division remainder). Validation
    windows tile forward without overlapping each other.

    Args:
        days: day identifiers; de-duplicated and sorted ascending before slicing.
        n_splits: number of folds requested (K).
        gap_days: days dropped between the end of training and the start of validation (G).
        train_to_val: training length as a multiple of the base validation length (R).

    Returns:
        List of ``(train_days, val_days)`` numpy arrays; ``[]`` when the configuration is
        invalid (K <= 0, R <= 0, G < 0) or there are too few days.
    """
    # Slicing below assumes unique, ascending days; enforce it instead of trusting the caller.
    days = np.unique(np.asarray(days).ravel())
    N, R, K, G = len(days), int(train_to_val), int(n_splits), int(gap_days)
    # A negative gap would make the training window run into the validation window.
    if K <= 0 or R <= 0 or G < 0:
        return []
    usable = N - G
    if usable <= 0:
        return []

    V_base, rem = divmod(usable, R + K)
    if V_base <= 0:
        return []
    T = R * V_base
    # The first min(rem, K) folds get one extra validation day so leftover days are used.
    v_lens = [V_base + 1 if i < rem else V_base for i in range(K)]

    v_lo = T + G
    folds = []
    for V_i in v_lens:
        v_hi = v_lo + V_i
        tr_hi = v_lo - G
        tr_lo = tr_hi - T
        if tr_lo < 0 or v_hi > N:
            break
        train_days = days[tr_lo:tr_hi]   # contiguous block of days
        val_days = days[v_lo:v_hi]
        folds.append((train_days, val_days))
        v_lo = v_hi
    return folds

# Full sorted day list of the window
all_days = (
    lw_with_idx.select(pl.col(g_date)).unique().sort(by=g_date)
    .collect().get_column(g_date).to_numpy()
)

folds_by_day = make_sliding_cv_by_days(all_days, n_splits=5, gap_days=7, train_to_val=4)
assert folds_by_day, "no CV folds constructed"

# Select one fold explicitly. A loop that rebinds lf_tr / lf_va per fold keeps only the
# last fold anyway, so the default stays on the last fold.
FOLD_IDX = len(folds_by_day) - 1   # set to 0..len(folds_by_day)-1 to evaluate another fold
tr_days, va_days = folds_by_day[FOLD_IDX]
lf_tr = lw_with_idx.filter(pl.col(g_date).is_in(tr_days.tolist()))
lf_va = lw_with_idx.filter(pl.col(g_date).is_in(va_days.tolist()))


def _to_tsd_frame(lf: pl.LazyFrame) -> pd.DataFrame:
    """Collect a fold to pandas with the dtypes used downstream (category symbol, int64 time_idx), sorted by symbol/time."""
    pdf = lf.collect(streaming=True).to_pandas()
    pdf[g_sym] = pdf[g_sym].astype("string").astype("category")
    pdf["time_idx"] = pdf["time_idx"].astype("int64")
    return pdf.sort_values([g_sym, "time_idx"])


# train_df / val_df are consumed by the TimeSeriesDataSet training cell.
train_df = _to_tsd_frame(lf_tr)
val_df = _to_tsd_frame(lf_va)

In [None]:


# 4) Materialise a demo window to pandas (swap in the full range if needed).
demo_lo, demo_hi = 1610, 1698
df = (
    lw_with_idx
    .filter(pl.col(g_date).is_between(demo_lo, demo_hi, closed="both"))
    .collect(streaming=True)
    .to_pandas()
    .sort_values([g_sym, "time_idx"])
)

# Show the date range actually present in df.
print(f"实际日期范围: {df[g_date].min()} 到 {df[g_date].max()}")

# dtypes expected downstream: symbol as category (via string), time_idx as int64.
df = df.astype({g_sym: "string"}).astype({g_sym: "category", "time_idx": "int64"})


In [None]:
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.data import NaNLabelEncoder, GroupNormalizer
from pytorch_forecasting.metrics import MAE  # swap in MSE() if your version provides it
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
import lightning as L

# Seed BEFORE building the dataloaders and the model: TFT weights are initialised inside
# from_dataset below, so seeding only right before fit would leave them unseeded.
L.seed_everything(int(cfg.get("seed", 42)), workers=True)

tft_cfg = cfg.get("tft", {})
tft_batch_size = int(tft_cfg.get("batch_size", 1024))

# unknown_reals: features + missing flags (the weight column is NOT included here)
unknown_reals = list(dict.fromkeys(time_features + feature_cols + miss_flags))

# The weight column is passed as `weight` (sample weighting), not as a feature;
# feature columns are normalised per symbol with GroupNormalizer.
training = TimeSeriesDataSet(
    train_df.sort_values([g_sym, "time_idx"]),
    time_idx="time_idx",
    target=target_col,
    group_ids=[g_sym],
    weight=weight_col,

    max_encoder_length=36,
    max_prediction_length=1,
    static_categoricals=[g_sym],
    time_varying_known_categoricals=[],
    time_varying_known_reals=[],
    time_varying_unknown_categoricals=[],
    time_varying_unknown_reals=unknown_reals,

    target_normalizer=None,
    scalers={c: GroupNormalizer(groups=[g_sym]) for c in feature_cols},

    categorical_encoders={g_sym: NaNLabelEncoder(add_nan=True)},
    add_relative_time_idx=True, add_target_scales=True, add_encoder_length=True,
    allow_missing_timesteps=True,
)

validation = TimeSeriesDataSet.from_dataset(training, val_df, stop_randomization=True)

train_loader = training.to_dataloader(train=True, batch_size=tft_batch_size, num_workers=4)
val_loader = validation.to_dataloader(train=False, batch_size=tft_batch_size, num_workers=4)

tft = TemporalFusionTransformer.from_dataset(
    training,
    loss=MAE(),
    learning_rate=float(tft_cfg.get("lr", 1e-3)),
    hidden_size=int(tft_cfg.get("hidden_size", 128)),
    attention_head_size=int(tft_cfg.get("heads", 4)),
    dropout=float(tft_cfg.get("dropout", 0.2)),
    reduce_on_plateau_patience=4,
)

logger = TensorBoardLogger(save_dir=P("local", "tft/logs"), name="tft", default_hp_metric=False)
callbacks = [
    EarlyStopping(monitor="val_loss", mode="min", patience=5),
    ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1, filename="tft-best-{epoch:02d}-{val_loss:.5f}"),
    LearningRateMonitor(logging_interval="step"),
]
trainer = L.Trainer(
    max_epochs=int(tft_cfg.get("max_epochs", 2)),
    accelerator="auto",
    precision=32,          # stability first; try bf16 once things run smoothly
    gradient_clip_val=0.5,
    log_every_n_steps=50,
    callbacks=callbacks,
    logger=logger,
    default_root_dir=P("local", "tft/ckpts"),
)

trainer.fit(tft, train_loader, val_loader)


In [None]:
# Point predictions over the whole validation dataloader.
import torch, numpy as np

pred = tft.predict(val_loader, mode="prediction")
if isinstance(pred, torch.Tensor):
    pred = pred.detach().cpu().numpy()
pred = np.asarray(pred)
# pred is typically [N, max_prediction_length]. Keep the last horizon step so y_pred is
# always 1-D [N]: squeeze() would give a 0-d array for N == 1 and stay 2-D for horizons > 1.
y_pred = pred.reshape(pred.shape[0], -1)[:, -1]


In [None]:
# Predictions with their (time_idx, symbol) index, joined to the true targets.
res = tft.predict(val_loader, mode="prediction", return_index=True)
y_hat = res.output
if isinstance(y_hat, torch.Tensor):
    y_hat = y_hat.detach().cpu().numpy()
y_hat = np.asarray(y_hat)
# Last horizon step -> 1-D [N] (max_prediction_length = 1 here).
y_pred = pd.DataFrame({"y_pred": y_hat.reshape(y_hat.shape[0], -1)[:, -1]})

y_pred_index = res.index
# Left-join the true target. validate="many_to_one" fails loudly if (symbol, time_idx) is not
# unique in val_df, which would otherwise duplicate rows and misalign the predictions below.
y_actual = y_pred_index.merge(
    val_df[[g_sym, "time_idx", target_col]],
    on=[g_sym, "time_idx"],
    how="left",
    validate="many_to_one",
)
# Positional side-by-side join: y_actual keeps one row per prediction, in prediction order.
y_pred_actual = pd.concat([y_actual, y_pred], axis=1)

# Plot one symbol: a single line over all rows would zig-zag between interleaved symbols.
plot_sym = y_pred_actual[g_sym].iloc[0]
one_sym = y_pred_actual[y_pred_actual[g_sym] == plot_sym].sort_values("time_idx")
ax = one_sym.plot(x="time_idx", y=[target_col, "y_pred"], kind="line", alpha=0.7)
ax.set(title=f"{g_sym}={plot_sym}: actual vs predicted", xlabel="time_idx", ylabel=target_col);