In [None]:
from __future__ import annotations

import time
from pathlib import Path

import lightning as L
import numpy as np
import pandas as pd
import polars as pl
import torch
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.data import NaNLabelEncoder
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.data.encoders import TorchNormalizer
from pytorch_forecasting.metrics import MAE  # train with MAE

# Project-local helpers
from pipeline.io import cfg, P, fs, storage_options, ensure_dir_local

def _now(): 
    return time.strftime("%Y-%m-%d %H:%M:%S")

In [None]:
# ---------- Utilities ----------
def add_missing_flags_and_fill(df: pd.DataFrame, group_col: str, cont_cols: list[str]) -> tuple[list[str], pd.DataFrame]:
    """Flag and impute missing values in continuous feature columns.

    For every column in ``cont_cols``:
      * ``+inf`` / ``-inf`` are first treated as missing;
      * an int8 indicator column ``<col>__isna`` (1 = missing) is appended;
      * the column is forward-filled within each ``group_col`` group, and any gap
        still left (e.g. leading NaNs of a group) is set to 0.0.

    Forward-fill follows the frame's current row order, so ``df`` should already be
    sorted chronologically within each group.

    Args:
        df: Input frame. It is copied, never modified.
        group_col: Column whose groups bound the forward-fill.
        cont_cols: Continuous columns to process.

    Returns:
        ``(flag_column_names, new_frame)``. When ``cont_cols`` is empty, ``([], df)``
        is returned with the original frame object.
    """
    if not cont_cols:
        return [], df

    out = df.copy()
    out[cont_cols] = out[cont_cols].replace([np.inf, -np.inf], np.nan)
    grouped = out.groupby(group_col, observed=False)

    flag_names: list[str] = []
    for col in cont_cols:
        flag_name = f"{col}__isna"
        flag_names.append(flag_name)
        # The flag is taken before filling, so it records the original missingness.
        out[flag_name] = out[col].isna().astype("int8")
        out[col] = grouped[col].ffill().fillna(0.0)
    return flag_names, out


In [None]:

# ---------- Main pipeline ----------

print(f"[{_now()}][tft] ===== start =====")
target_col = cfg["target"]                 # e.g. "responder_6"
key_cols = list(cfg["keys"])               # normalise: cfg may store the keys as a tuple or a list
g_sym, g_date, g_time = key_cols           # e.g. ("symbol_id", "date_id", "time_id")
weight_col = cfg["weight"]
TIME_SORT = cfg["sorts"].get("time_major", [g_date, g_time, g_sym])

# 1) Feature selection (placeholders: replace with the real feature lists)
time_features = ["time_pos", "time_sin", "time_cos", "time_bucket"]
base_features   = ["feature_36", "feature_06"]                 # TODO: put the real feature set here
resp_his_feats  = ["responder_5_prevday_std", "responder_3_prevday_std", "responder_4_prev_tail_d1"]  # example
feat_his_feats  = ["feature_08__ewm5", "feature_53__rstd3"]           # example
feature_cols = list(dict.fromkeys(base_features + resp_his_feats + feat_his_feats))

# dict.fromkeys de-duplicates while keeping first-seen order. key_cols is a list, so this
# concatenation also works when cfg["keys"] is a tuple (tuple + list raises TypeError).
need_cols = list(dict.fromkeys(key_cols + [weight_col, target_col] + time_features + feature_cols))

# 2) Read the panel lazily and build the global (date, time) -> time_idx grid
panel_dir = P("az", cfg["paths"].get("panel_shards", "panel_shards"))
glob_pat = f"{panel_dir}/*.parquet"
if not fs.glob(glob_pat.replace("az://", "")):
    raise FileNotFoundError(f"No parquet shards under: {glob_pat}")
lf = pl.scan_parquet(glob_pat, storage_options=storage_options)

# The grid is cached on local disk and only rebuilt when the file is absent.
grid_path = P("local", "tft/panel/grid_timeidx.parquet")
if not Path(grid_path).exists():
    lf_grid = (
        lf.select([g_date, g_time]).unique()
        .sort([g_date, g_time])
        .with_row_index("time_idx")
        .with_columns(pl.col("time_idx").cast(pl.Int64))
    )
    ensure_dir_local(Path(grid_path).parent.as_posix())
    lf_grid.collect(streaming=True).write_parquet(grid_path, compression="zstd")
    print(f"[{_now()}][tft] grid saved -> {grid_path}")
grid_lazy = pl.scan_parquet(grid_path)

# Safety check: the global time_idx must be unique and contiguous.
grid_df = grid_lazy.select([g_date, g_time, "time_idx"]).collect()
ti = grid_df["time_idx"]
assert grid_df.select(pl.col("time_idx").is_duplicated().any()).item() is False
assert ti.max() - ti.min() + 1 == len(ti), "全局 time_idx 不连续"

# 3) Date window + join time_idx + column selection
tft_dates = cfg["dates"]["tft_dates"]
lo, hi = tft_dates["date_lo"], tft_dates["date_hi"]
lw = lf.filter(pl.col(g_date).is_between(lo, hi, closed="both"))
lw_with_idx = (
    lw.join(grid_lazy, on=[g_date, g_time], how="left")
    .select(need_cols + ["time_idx"])
    .sort(TIME_SORT)
)
print(f"[{_now()}][tft] schema -> {lw_with_idx.collect_schema().names()}")

# 4) Materialise a demo window into pandas (swap for the full range when ready)
demo_lo, demo_hi = 1610, 1698
df = (
    lw_with_idx
    .filter(pl.col(g_date).is_between(demo_lo, demo_hi, closed="both"))
    .collect(streaming=True)
    .to_pandas()
).sort_values([g_sym, "time_idx"])

if df.empty:
    raise ValueError(
        f"demo window [{demo_lo}, {demo_hi}] selected no rows; "
        f"check it overlaps the configured tft_dates [{lo}, {hi}]"
    )

# Print the date range actually present in df
print(f"实际日期范围: {df[g_date].min()} 到 {df[g_date].max()}")

# The left join leaves time_idx null for (date, time) pairs missing from a stale cached grid.
# Report that explicitly instead of letting the int64 cast below fail with an opaque error.
n_missing_idx = int(df["time_idx"].isna().sum())
if n_missing_idx:
    raise ValueError(
        f"{n_missing_idx} rows have no time_idx; the cached grid at {grid_path} "
        "appears stale - delete it and re-run to rebuild."
    )

# dtypes
df[g_sym] = df[g_sym].astype("string").astype("category")
df["time_idx"] = df["time_idx"].astype("int64")


In [None]:
# Summary statistics of the target distribution.
df.loc[:, target_col].describe()

In [None]:
# Summary statistics of the per-sample weights.
df.loc[:, weight_col].describe()

In [None]:
# Dtypes, non-null counts and memory footprint before filling/downcasting.
# memory_usage="deep" measures object/category payloads instead of estimating them.
df.info(memory_usage="deep")

In [None]:
# Per-column count of missing values.
df.isna().sum()

In [None]:
# Missing-value handling applies to continuous features only; target and weight are untouched.
miss_flags, df = add_missing_flags_and_fill(df, g_sym, feature_cols)

# Reduce float precision to save memory; only feature columns present in df are converted.
float_cols = [col for col in feature_cols if col in df.columns]
for col in float_cols:
    df[col] = pd.to_numeric(df[col], downcast="float")



In [None]:
# Dtypes and memory footprint after filling/downcasting. "deep" makes the
# before/after memory comparison a measurement rather than an estimate.
df.info(memory_usage="deep")

In [None]:
# Numeric summary, transposed so that each feature/flag column is a row instead of a very wide table.
df.describe().T

In [None]:
# Causal (time-ordered) split: rows with time_idx above the 90th percentile go to validation.
VAL_QUANTILE = 0.9
cutoff = int(df["time_idx"].quantile(VAL_QUANTILE))
train_df = df[df["time_idx"] <= cutoff].copy()
val_df = df[df["time_idx"] > cutoff].copy()
# Fail fast here rather than deep inside TimeSeriesDataSet with an opaque error.
assert not train_df.empty and not val_df.empty, f"empty split at cutoff={cutoff}"

In [None]:
# (rows, columns) of both splits, to sanity-check the causal cut.
{"train": train_df.shape, "val": val_df.shape}

In [None]:
# First five validation rows.
val_df.iloc[:5]

In [None]:
# Number of distinct symbols in the validation window (uses the configured key, not a hard-coded name).
val_df[g_sym].nunique()

In [None]:
# unknown_reals: continuous features + their missing-value flags (the weight column is NOT a feature).
unknown_reals = list(dict.fromkeys(feature_cols + miss_flags))
batch_size = int(cfg.get("tft", {}).get("batch_size", 1024))

# Training dataset. The weight column goes in via `weight=` (per-sample loss weighting) and is
# deliberately kept out of unknown_reals. Continuous features are standardised per group.
training = TimeSeriesDataSet(
    train_df.sort_values([g_sym, "time_idx"]),
    time_idx="time_idx",
    target=target_col,
    group_ids=[g_sym],
    weight=weight_col,       # enables sample weighting

    max_encoder_length=36,
    max_prediction_length=1,
    static_categoricals=[g_sym],
    time_varying_known_categoricals=[],
    time_varying_known_reals=[],
    time_varying_unknown_categoricals=[],
    time_varying_unknown_reals=unknown_reals,   # excludes weight

    # NOTE(review): with target_normalizer=None the target is presumably left unscaled, so the
    # center/scale features added by add_target_scales=True would be constant - confirm against
    # the installed pytorch_forecasting version.
    target_normalizer=None,
    scalers={c: GroupNormalizer(groups=[g_sym]) for c in feature_cols},

    categorical_encoders={g_sym: NaNLabelEncoder(add_nan=True)},
    add_relative_time_idx=True, add_target_scales=True, add_encoder_length=True,
    allow_missing_timesteps=True,
)

# Validation: build from the FULL frame and only start predicting after the cutoff.
# min_encoder_length defaults to max_encoder_length, so building from val_df alone would turn the
# first 36 post-cutoff steps of every symbol into encoder-only history, and would drop symbols with
# too few post-cutoff rows. Here the encoder reads pre-cutoff history instead. That history is
# already-observed data, so this is not leakage. Scalers/encoders come from `training`.
validation = TimeSeriesDataSet.from_dataset(
    training, df, min_prediction_idx=cutoff + 1, stop_randomization=True
)

train_loader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=4)
val_loader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=4)



In [None]:
# Model/trainer hyper-parameters come from cfg["tft"], with defaults.
tft_cfg = cfg.get("tft", {})

# Seed BEFORE building the model. TemporalFusionTransformer.from_dataset draws the initial weights,
# so seeding only right before fit() left the weight initialisation unseeded.
L.seed_everything(int(cfg.get("seed", 42)), workers=True)

tft = TemporalFusionTransformer.from_dataset(
    training,
    loss=MAE(),  # switch to MSE() if the installed pytorch_forecasting version provides it
    learning_rate=float(tft_cfg.get("lr", 1e-3)),
    hidden_size=int(tft_cfg.get("hidden_size", 128)),
    attention_head_size=int(tft_cfg.get("heads", 4)),
    dropout=float(tft_cfg.get("dropout", 0.2)),
    reduce_on_plateau_patience=4,
)

logger = TensorBoardLogger(save_dir=P("local", "tft/logs"), name="tft", default_hp_metric=False)
callbacks = [
    EarlyStopping(monitor="val_loss", mode="min", patience=5),
    # NOTE(review): without an explicit dirpath, Lightning places checkpoints under the logger's
    # directory rather than default_root_dir - confirm where they land.
    ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1, filename="tft-best-{epoch:02d}-{val_loss:.5f}"),
    LearningRateMonitor(logging_interval="step"),
]
trainer = L.Trainer(
    max_epochs=int(tft_cfg.get("max_epochs", 2)),
    accelerator="auto",
    precision=32,          # stability first; try bf16 once training runs cleanly
    gradient_clip_val=0.5,
    log_every_n_steps=50,
    callbacks=callbacks,
    logger=logger,
    default_root_dir=P("local", "tft/ckpts"),
)

trainer.fit(tft, train_loader, val_loader)


In [None]:
# Predict over the whole validation dataloader.
# After fit(), `tft` holds the LAST-epoch weights. Use the best-val_loss checkpoint when one was saved.
ckpt_cb = trainer.checkpoint_callback
best_ckpt = ckpt_cb.best_model_path if ckpt_cb is not None else ""
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_ckpt) if best_ckpt else tft

pred = best_tft.predict(val_loader, mode="prediction")
if isinstance(pred, torch.Tensor):
    pred = pred.detach().cpu().numpy()
pred = np.asarray(pred)
# pred is expected to be [N, max_prediction_length] (max_prediction_length=1 here).
# Select the last horizon step explicitly: `.squeeze()` would turn N == 1 into a 0-d value.
y_pred = pred[:, -1] if pred.ndim == 2 else pred.reshape(-1)


In [None]:
# Predict on the validation set, keeping the (group, time_idx) index of every prediction.
# After fit(), `tft` holds the LAST-epoch weights. Evaluate the best-val_loss checkpoint when one exists.
ckpt_cb = trainer.checkpoint_callback
best_ckpt = ckpt_cb.best_model_path if ckpt_cb is not None else ""
eval_model = TemporalFusionTransformer.load_from_checkpoint(best_ckpt) if best_ckpt else tft

res = eval_model.predict(val_loader, mode="prediction", return_index=True)
pred_out = res.output
if isinstance(pred_out, torch.Tensor):
    pred_out = pred_out.detach().cpu().numpy()
pred_out = np.asarray(pred_out)
# [N, max_prediction_length] -> [N]: keep the last horizon step (the only one when the length is 1).
pred_values = pred_out[:, -1] if pred_out.ndim == 2 else pred_out.reshape(-1)
y_pred = pd.DataFrame({"y_pred": pred_values})

# Join the ground truth from val_df. res.index is row-aligned with res.output, and a left merge
# keeps the left row order. validate="many_to_one" guarantees the merge cannot add rows, so the
# predictions can be attached positionally afterwards.
y_pred_index = res.index
y_actual = y_pred_index.merge(
    val_df[[g_sym, "time_idx", target_col]],
    on=[g_sym, "time_idx"],
    how="left",
    validate="many_to_one",
)
y_pred_actual = y_actual.assign(y_pred=pred_values)
assert y_pred_actual[target_col].notna().all(), "some predictions have no matching target in val_df"

mae = (y_pred_actual["y_pred"] - y_pred_actual[target_col]).abs().mean()
print(f"[{_now()}][tft] val MAE (unweighted) = {mae:.6f} over {len(y_pred_actual)} predictions")

# Rows cover many symbols per time_idx. A line through raw rows would zig-zag between symbols,
# so plot the cross-sectional mean per time step instead.
per_step = y_pred_actual.groupby("time_idx")[[target_col, "y_pred"]].mean()
ax = per_step.plot(kind="line", alpha=0.7)
ax.set(
    title="Validation: mean target vs. mean prediction per time_idx",
    xlabel="time_idx",
    ylabel=target_col,
);