In [None]:
from __future__ import annotations


import time, numpy as np, polars as pl, pandas as pd, torch, lightning as L
from pathlib import Path

from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.data import NaNLabelEncoder, GroupNormalizer
from pytorch_forecasting.metrics import MAE
from pytorch_forecasting.data.encoders import TorchNormalizer
# Project-local utilities
from pipeline.io import cfg, P, fs, storage_options, ensure_dir_local

def _now(): return time.strftime("%Y-%m-%d %H:%M:%S")
print(f"[{_now()}] imports ok")

# ---- Key configuration (edit as needed) ----
target_col = cfg["target"]           # e.g. "responder_6"
g_sym, g_date, g_time = cfg["keys"]  # e.g. ("symbol_id", "date_id", "time_id")
weight_col = cfg["weight"]

time_features = [
    "time_pos",
    "time_sin",
    "time_cos",
    "time_bucket",
]
base_features = [
    "feature_36",
    "feature_06",
]
resp_his_feats = [
    "responder_5_prevday_std",
    "responder_3_prevday_std",
    "responder_4_prev_tail_d1",
]
feat_his_feats = [
    "feature_08__ewm5",
    "feature_53__rstd3",
]
# dict.fromkeys() removes duplicates while keeping first-seen order.
feature_cols = list(dict.fromkeys([*base_features, *resp_his_feats, *feat_his_feats]))

# Keys, weight, target, time features and continuous features, de-duplicated in that order.
need_cols = list(
    dict.fromkeys([g_sym, g_date, g_time, weight_col, target_col, *time_features, *feature_cols])
)

# Cross-validation & training
N_SPLITS, GAP_DAYS = 5, 7
TRAIN_TO_VAL = 4                 # train : validation = 4 : 1
ENC_LEN, PRED_LEN = 36, 1
BATCH_SIZE = 1024
LR = 1e-3
HIDDEN, HEADS = 128, 4
DROPOUT = 0.2
MAX_EPOCHS_PER_SHARD = 1
CHUNK_DAYS = 15                  # training shards: number of days per shard

print("config ready")

# Locate the panel shards and fail fast when the glob matches nothing.
# NOTE(review): P("az", ...) presumably resolves to an "az://..." URL (the prefix is
# stripped before calling fs.glob) — confirm against pipeline.io.
panel_dir = P("az", cfg["paths"].get("panel_shards", "panel_shards"))
glob_pat  = f"{panel_dir}/*.parquet"
if not fs.glob(glob_pat.replace("az://", "")):
    raise FileNotFoundError(f"No parquet shards under: {glob_pat}")

lf_raw = pl.scan_parquet(glob_pat, storage_options=storage_options)

# Global time_idx: one consecutive integer per distinct (date, time) pair, in sorted order.
# It is computed once and cached as a local parquet file.
# NOTE(review): an existing grid file is reused as-is. If the shards gained new
# (date, time) pairs after it was written, the left join below leaves time_idx null
# for those rows — delete the file to force a rebuild.
grid_path = P("local", "tft/panel/grid_timeidx.parquet")
if not Path(grid_path).exists():
    lf_grid = (
        lf_raw.select([g_date, g_time]).unique()
              .sort([g_date, g_time])
              .with_row_index("time_idx")
              .with_columns(pl.col("time_idx").cast(pl.Int64))
    )
    ensure_dir_local(Path(grid_path).parent.as_posix())
    lf_grid.collect(streaming=True).write_parquet(grid_path, compression="zstd")

grid_lazy = pl.scan_parquet(grid_path)

# Attach time_idx and keep only the required columns (still lazy).
lf0 = (
    lf_raw.join(grid_lazy, on=[g_date, g_time], how="left")
          .select(need_cols + ["time_idx"])
          .sort([g_date, g_time, g_sym])
)

print(f"[{_now()}] lazyframe ready")

# Sorted list of all distinct days.
# Read from lf_raw rather than lf0: the left join keeps every row of lf_raw, so the set
# of days is identical, and this avoids executing the join and the full three-key sort
# of lf0 just to enumerate days.
all_days = (
    lf_raw.select(pl.col(g_date)).unique().sort(by=g_date)
          .collect(streaming=True).get_column(g_date).to_numpy()
)

def make_sliding_cv_by_days(all_days: np.ndarray, *, n_splits: int, gap_days: int, train_to_val: int):
    """Build sliding-window (train, validation) folds over an ordered array of days.

    The input is flattened and split positionally. With ``usable = len(days) - gap_days``,
    the base validation length is ``usable // (train_to_val + n_splits)`` and the training
    window has the fixed length ``train_to_val * base``. The first
    ``usable % (train_to_val + n_splits)`` folds get one extra validation day. Validation
    blocks are consecutive; each training window ends ``gap_days`` positions before its
    validation block starts.

    Args:
        all_days: Array-like of day identifiers, used in the order given.
        n_splits: Number of folds requested.
        gap_days: Number of positions skipped between the end of training and the start
            of validation.
        train_to_val: Ratio of training length to base validation length.

    Returns:
        A list of ``(train_days, val_days)`` tuples of array slices. The list is empty when
        any of the sizes is non-positive or there are too few days for a single base
        validation day; it may be shorter than ``n_splits`` if a fold would run past the
        available range.
    """
    days = np.asarray(all_days).ravel()
    n_days = len(days)

    usable = n_days - gap_days
    if usable <= 0 or n_splits <= 0 or train_to_val <= 0:
        return []

    val_base, extra = divmod(usable, train_to_val + n_splits)
    if val_base <= 0:
        return []

    train_len = train_to_val * val_base
    folds = []
    val_start = train_len + gap_days
    for fold_no in range(n_splits):
        # The leftover days are handed out one per fold, starting with the first fold.
        val_len = val_base + 1 if fold_no < extra else val_base
        val_stop = val_start + val_len
        train_stop = val_start - gap_days
        train_start = train_stop - train_len
        if train_start < 0 or val_stop > n_days:
            break
        folds.append((days[train_start:train_stop], days[val_start:val_stop]))
        val_start = val_stop
    return folds

# Build the day-level CV folds from the CV settings defined in the config section.
folds_by_day = make_sliding_cv_by_days(
    all_days,
    n_splits=N_SPLITS,
    gap_days=GAP_DAYS,
    train_to_val=TRAIN_TO_VAL,
)
assert folds_by_day, "no CV folds constructed"

# Upper bound of the window used for the standardization statistics:
# the last training day of the first fold (validation days are not included).
stats_hi = int(folds_by_day[0][0][-1])
print(f"stats_hi (for global z-score) = {stats_hi}; first-fold train days end at this day.")

clean_dir = P("local", "tft/clean_by_day")
ensure_dir_local(clean_dir)

# ========== 1) Continuous features: inf -> null, missing flags, per-symbol forward fill, 0 fallback ==========
# The flags are computed after the inf -> null step and before any filling, so they mark
# every value that was null or infinite in the input. The three steps are applied as three
# separate with_columns calls so that each one sees the result of the previous one.
inf2null_exprs, flags_exprs, fill_exprs = [], [], []
for c in feature_cols:
    feat = pl.col(c)
    inf2null_exprs.append(pl.when(feat.is_infinite()).then(None).otherwise(feat).alias(c))
    flags_exprs.append(feat.is_null().cast(pl.Int8).alias(f"{c}__isna"))
    # Forward fill within each symbol; values still missing afterwards become 0.
    fill_exprs.append(feat.forward_fill().over(g_sym).fill_null(0.0).alias(c))

lf_clean = (
    lf0.with_columns(inf2null_exprs)
       .with_columns(flags_exprs)
       .with_columns(fill_exprs)
)

# ========== 2) Statistics over the training window (date <= stats_hi) ==========
lf_stats_src = lf_clean.filter(pl.col(g_date) <= stats_hi)

# Per-symbol mean / std of every feature.
lf_stats_sym = lf_stats_src.group_by(g_sym).agg(
    [pl.col(c).mean().alias(f"mu_{c}") for c in feature_cols]
    + [pl.col(c).std().alias(f"std_{c}") for c in feature_cols]
)

# Global mean / std over the same window, used as a fallback.
lf_stats_glb = lf_stats_src.select(
    [pl.col(c).mean().alias(f"mu_{c}_glb") for c in feature_cols]
    + [pl.col(c).std().alias(f"std_{c}_glb") for c in feature_cols]
)

# ========== 3) Attach the statistics and compute z-scores ==========
# The global statistics form a single row, so the cross join adds them as constant
# columns; the per-symbol statistics are attached with a left join on the symbol key.
lf_z = (
    lf_clean.join(lf_stats_glb, how="cross")
            .join(lf_stats_sym, on=g_sym, how="left")
)

eps = 1e-6
z_cols = [f"{c}_z" for c in feature_cols]
use_exprs, z_exprs, stat_cols = [], [], []
for c in feature_cols:
    mu_sym, std_sym = pl.col(f"mu_{c}"), pl.col(f"std_{c}")
    mu_glb, std_glb = pl.col(f"mu_{c}_glb"), pl.col(f"std_{c}_glb")
    # Fall back to the global statistic when the per-symbol one is missing
    # (or, for std, equal to zero).
    use_exprs.append(
        pl.when(mu_sym.is_null()).then(mu_glb).otherwise(mu_sym).alias(f"{c}_mu_use")
    )
    use_exprs.append(
        pl.when(std_sym.is_null() | (std_sym == 0)).then(std_glb).otherwise(std_sym).alias(f"{c}_std_use")
    )
    z_exprs.append(
        ((pl.col(c) - pl.col(f"{c}_mu_use")) / (pl.col(f"{c}_std_use") + eps)).alias(f"{c}_z")
    )
    # Helper columns that are removed again once the z-score exists.
    stat_cols.extend(
        [f"mu_{c}_glb", f"std_{c}_glb", f"mu_{c}", f"std_{c}", f"{c}_mu_use", f"{c}_std_use"]
    )

lf_z = lf_z.with_columns(use_exprs).with_columns(z_exprs).drop(stat_cols)

# ========== 4) Output columns: keys, time_idx, weight, target, time features, z-scores, missing flags ==========
namark_cols = [f"{c}__isna" for c in feature_cols]
out_cols = [
    g_sym, g_date, g_time, "time_idx", weight_col, target_col,
    *time_features, *z_cols, *namark_cols,
]

lf_out = lf_z.select(out_cols).sort([g_date, g_time, g_sym])

# ========== 4) Collect in chunks + write partitioned output ==========
# Collect a batch of days at a time and write the whole batch partitioned by day, instead
# of collecting day by day; this reduces the number of collect / IO round trips.
# NOTE(review): this overrides the CHUNK_DAYS value set in the config section.
CHUNK_DAYS = 20  # tune to the machine's memory / speed, e.g. 10-30 days per chunk
day_list = list(map(int, all_days))
day_chunks = [day_list[i:i+CHUNK_DAYS] for i in range(0, len(day_list), CHUNK_DAYS)]

for ci, chunk in enumerate(day_chunks, 1):
    # A single collect for the whole chunk.
    df_chunk = (
        lf_out.filter(pl.col(g_date).is_in(chunk))
              .collect(streaming=True)
    )

    try:
        # Partitioned write: one "<g_date>=<day>/" sub-directory per day in the chunk.
        # Requires a polars version whose write_parquet accepts partition_by.
        df_chunk.write_parquet(
            Path(clean_dir).as_posix(),
            compression="zstd",
            partition_by=[g_date],
        )
    except TypeError:
        # Fallback for polars versions without partition_by: write one file per day.
        # The files go into the same "<g_date>=<day>/" directory layout as the
        # partitioned write, because the downstream readers glob
        # "<clean_dir>/<key>=<day>/*.parquet"; flat "day=<d>.parquet" files would
        # never be picked up.
        for d in chunk:
            day_dir = Path(clean_dir) / f"{g_date}={d}"
            day_dir.mkdir(parents=True, exist_ok=True)
            out_path = day_dir / "0.parquet"
            df_chunk.filter(pl.col(g_date) == d).write_parquet(out_path.as_posix(), compression="zstd")
    print(f"[{_now()}] chunk {ci}/{len(day_chunks)} -> days {chunk[0]}..{chunk[-1]} written")

print(f"[{_now()}] cleaned & standardized shards written to {clean_dir}")

# === Cell 1: Template & Validation ===
import glob, numpy as np, polars as pl, pandas as pd
from sklearn.preprocessing import FunctionTransformer
from pytorch_forecasting import TimeSeriesDataSet
from pytorch_forecasting.data import NaNLabelEncoder

# The following names are expected to be defined by the preceding code:
# cfg, P, clean_dir, folds_by_day, g_sym, g_date, g_time, target_col, weight_col
# time_features, z_cols, namark_cols, ENC_LEN, PRED_LEN, BATCH_SIZE

fold_idx = 1
train_days, val_days = folds_by_day[fold_idx-1]
days_sorted = np.sort(train_days)

# ---- Validation set ----
# The shard directories are named "<g_date>=<day>" (the writer partitions by g_date),
# so the pattern is built from g_date instead of a hard-coded key name. Days are cast
# to int so the directory name matches the integer days used when writing.
val_paths = []
for d in val_days:
    val_paths.extend(glob.glob(f"{clean_dir}/{g_date}={int(d)}/*.parquet"))
if not val_paths:
    raise FileNotFoundError(
        f"No cleaned validation shards found under {clean_dir} for fold {fold_idx}"
    )

pdf_val = pl.scan_parquet(val_paths).collect(streaming=True).to_pandas()
pdf_val[g_sym] = pdf_val[g_sym].astype("str").astype("category")

# ---- Template (built from the first training day; fixes the encoders / scalers) ----
first_train_day = int(days_sorted[0])
tmpl_paths = glob.glob(f"{clean_dir}/{g_date}={first_train_day}/*.parquet")
if not tmpl_paths:
    raise FileNotFoundError(
        f"No cleaned shard found under {clean_dir} for template day {first_train_day}"
    )
pdf_tmpl = pl.scan_parquet(tmpl_paths).collect(streaming=True).to_pandas()
pdf_tmpl[g_sym] = pdf_tmpl[g_sym].astype("str").astype("category")

unknown_reals = time_features + z_cols + namark_cols
# sklearn's FunctionTransformer is the identity transform when func=None, and it can be
# pickled. The features were already standardized upstream, so no further scaling is
# requested for them here.
identity_scalers = {c: FunctionTransformer(validate=False) for c in unknown_reals}

template = TimeSeriesDataSet(
    pdf_tmpl.sort_values([g_sym, "time_idx"]),
    time_idx="time_idx",
    target=target_col,
    group_ids=[g_sym],
    weight=weight_col,
    max_encoder_length=ENC_LEN,
    max_prediction_length=PRED_LEN,
    static_categoricals=[g_sym],
    time_varying_unknown_reals=unknown_reals,
    target_normalizer=None,
    # add_nan=True reserves a category for symbols that are absent from the template day.
    categorical_encoders={g_sym: NaNLabelEncoder(add_nan=True)},
    add_relative_time_idx=True, add_target_scales=True, add_encoder_length=True,
    allow_missing_timesteps=True,
    scalers=identity_scalers,
)

# The validation dataset reuses the template's parameters (encoders, scalers, lengths).
validation = TimeSeriesDataSet.from_dataset(
    template, data=pdf_val.sort_values([g_sym, "time_idx"]), stop_randomization=True
)
val_loader = validation.to_dataloader(
    train=False, batch_size=BATCH_SIZE, num_workers=4, pin_memory=True, persistent_workers=True
)

len(val_loader), pdf_tmpl.shape, pdf_val.shape

# === Cell 2: IterableDataset, truly streaming (yields batches shard by shard) + weighted WR² ===
import random, torch, polars as pl, glob
from torch.utils.data import IterableDataset
from pytorch_forecasting import TimeSeriesDataSet

class ShardedBatchStream(IterableDataset):
    """Stream ready-made training batches, loading one day shard at a time.

    For every day the parquet files under ``<clean_dir>/<g_date>=<day>/`` are read, turned
    into a ``TimeSeriesDataSet`` via ``TimeSeriesDataSet.from_dataset(template_tsd, ...)``
    and iterated through that dataset's own dataloader. The items yielded by this object
    are therefore whole batches, not single samples.

    Every call to ``__iter__`` creates a fresh ``random.Random(seed)``, so the day order
    and the buffer picks are the same on each pass.

    Args:
        template_tsd: Dataset whose parameters are reused for every shard.
        shard_days: Iterable of day identifiers; each is converted with ``int``.
        clean_dir: Root directory of the per-day shard directories.
        g_sym: Name of the symbol column; it is cast to a string category and used,
            together with ``"time_idx"``, to sort each shard.
        batch_size: Batch size of the per-shard dataloader.
        num_workers: Worker count of the per-shard dataloader.
        shuffle_within_shard: Passed as ``shuffle`` to the per-shard dataloader.
        buffer_batches: Size of the batch shuffle buffer; ``0`` disables buffering and
            batches are yielded in dataloader order.
        seed: Seed for the day shuffle and the buffer picks.
        g_date: Partition key used in the shard directory names
            (``<g_date>=<day>``). Defaults to ``"date_id"``, the previously
            hard-coded value.
    """

    def __init__(
        self,
        template_tsd,
        shard_days,
        clean_dir: str,
        g_sym: str,
        batch_size: int = 1024,
        num_workers: int = 4,
        shuffle_within_shard: bool = True,
        buffer_batches: int = 0,
        seed: int = 42,
        g_date: str = "date_id",
    ):
        super().__init__()
        self.template = template_tsd
        self.days = list(map(int, shard_days))
        self.clean_dir = clean_dir
        self.g_sym = g_sym
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle_within_shard = shuffle_within_shard
        self.buffer_batches = buffer_batches
        self.seed = seed
        self.g_date = g_date

    def __iter__(self):
        """Yield batches shard by shard, optionally mixed through a shuffle buffer."""
        from collections import deque

        rng = random.Random(self.seed)
        days = self.days[:]  # copy, so the stored order is never modified
        rng.shuffle(days)

        buf = deque()

        for d in days:
            paths = glob.glob(f"{self.clean_dir}/{self.g_date}={d}/*.parquet")
            if not paths:
                # No shard on disk for this day: skip it.
                continue

            pdf = pl.scan_parquet(paths).collect(streaming=True).to_pandas()
            if pdf.empty:
                continue
            pdf[self.g_sym] = pdf[self.g_sym].astype("str").astype("category")

            tsds = TimeSeriesDataSet.from_dataset(
                self.template,
                data=pdf.sort_values([self.g_sym, "time_idx"]),
                stop_randomization=True,
            )

            dl = tsds.to_dataloader(
                train=True,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                shuffle=self.shuffle_within_shard,
                pin_memory=True,
                persistent_workers=self.num_workers > 0,
            )

            for batch in dl:
                if self.buffer_batches > 0:
                    buf.append(batch)
                    if len(buf) >= self.buffer_batches:
                        # Once the buffer is full, emit one randomly chosen batch:
                        # rotate it to the front, then pop it.
                        k = rng.randrange(len(buf))
                        if k:
                            buf.rotate(-k)
                        yield buf.popleft()
                else:
                    yield batch

        # Drain whatever is left in the buffer, in its current order.
        while buf:
            yield buf.popleft()




In [None]:
# Set the random seed.
L.seed_everything(int(cfg.get("seed", 42)), workers=True)
early_stop_callback = EarlyStopping(
    monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min"
)
lr_logger = LearningRateMonitor()  # log the learning rate
logger = TensorBoardLogger("lightning_logs")  # logging results to a tensorboard


# Create the model. Hyper-parameters come from cfg["tft"] when present; the defaults are
# the constants from the config section (LR, HIDDEN, HEADS, DROPOUT), which hold the same
# values that were previously repeated here as literals.
tft_cfg = cfg.get("tft", {})
tft = TemporalFusionTransformer.from_dataset(
    template,
    loss=MAE(),
    learning_rate=float(tft_cfg.get("lr", LR)),
    hidden_size=int(tft_cfg.get("hidden_size", HIDDEN)),
    attention_head_size=int(tft_cfg.get("heads", HEADS)),
    dropout=float(tft_cfg.get("dropout", DROPOUT)),
    reduce_on_plateau_patience=4,
)

# The training stream is consumed in a single epoch, so validation is triggered
# every VAL_EVERY_STEPS training batches.
VAL_EVERY_STEPS = 500
trainer = L.Trainer(
    accelerator="gpu", devices=1, precision=32,
    max_epochs=1,
    val_check_interval=VAL_EVERY_STEPS,
    num_sanity_val_steps=0,
    gradient_clip_val=0.5,
    log_every_n_steps=50,
    # lr_logger was created above but never registered; without it no learning rate
    # is logged.
    callbacks=[early_stop_callback, lr_logger],
    logger=logger,
    default_root_dir=P("local", "tft/ckpts"),
)


In [None]:
# === Cell 4: Build the streaming training set & run a single fit ===
# The first four arguments are, in order: template_tsd, shard_days, clean_dir, g_sym.
train_stream = ShardedBatchStream(
    template,
    days_sorted,
    clean_dir,
    g_sym,
    batch_size=BATCH_SIZE,
    num_workers=4,
    shuffle_within_shard=True,
    buffer_batches=16,   # 0 turns the cross-shard shuffle buffer off; 8-64 is a range to tune within
    seed=42,
)

# The stream already yields whole batches, so it is passed directly as the training loader.
trainer.fit(
    tft,
    train_dataloaders=train_stream,
    val_dataloaders=val_loader,
)
