# 导入库

In [None]:
# -*- coding: utf-8 -*-
from __future__ import annotations

# ── 标准库 ──────────────────────────────────────────────────────────────────
import os
import time
from pathlib import Path
from collections import defaultdict
from datetime import datetime
import re

# ── 第三方 ──────────────────────────────────────────────────────────────────
import numpy as np
import pandas as pd
import polars as pl

import gc

import torch
import torch.backends.cudnn as cudnn
import lightning as L
import lightning.pytorch as lp
from torch.utils.data import DataLoader

from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import DeviceStatsMonitor
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer, Baseline
from pytorch_forecasting.metrics import MAE, RMSE
from pytorch_forecasting.data.encoders import NaNLabelEncoder
from pytorch_forecasting.data import TorchNormalizer, GroupNormalizer


# 你的工程工具
from pipeline.io import cfg, P, fs, storage_options, ensure_dir_local, ensure_dir_az
from pipeline.stream_input_local import ShardedBatchStream  
from pipeline.wr2 import WR2

# ---- Performance / compatibility switches (applied once) ----
# os.cpu_count() returns None when the CPU count cannot be determined, which would
# make `// 2` raise TypeError; treat that case as a single core.
# NOTE(review): polars is already imported above; confirm that POLARS_MAX_THREADS is
# still honoured when it is set at this point rather than before the import.
os.environ.setdefault("POLARS_MAX_THREADS", str(max(1, (os.cpu_count() or 1) // 2)))
pl.enable_string_cache()
cudnn.benchmark = True
torch.set_float32_matmul_precision("high")

import time as _t


import warnings
# Silence every warning so that absolute paths are not printed.
warnings.filterwarnings("ignore")


def _now() -> str:
    """Return the current local time formatted as ``YYYY-MM-DD HH:MM:SS``."""
    stamp = _t.strftime("%Y-%m-%d %H:%M:%S", _t.localtime())
    return stamp


print(f"[{_now()}] imports ok")

# 定义工具函数

In [None]:
# ───────────────────────────────────────────────────────────────────────────
# 滑动窗划分
def make_sliding_cv_by_days(all_days: np.ndarray, *, n_splits: int, gap_days: int, train_to_val: int):
    """Build sliding-window (train, validation) day splits.

    The day axis is cut into one fixed-length training window followed by
    ``n_splits`` consecutive validation windows. For each fold the training
    window ends ``gap_days`` positions before the validation window starts,
    and it slides forward together with the validation window.

    Args:
        all_days: Day identifiers; flattened and used in the given order.
        n_splits: Number of validation windows (folds) requested.
        gap_days: Number of positions left out between the end of the
            training window and the start of the validation window.
        train_to_val: Training-window length expressed as a multiple of the
            base validation-window length.

    Returns:
        A list of ``(train_days, val_days)`` array pairs, one per fold. The
        list is empty when the inputs leave no room for a single fold.
    """
    days = np.asarray(all_days).ravel()
    n_days = len(days)

    usable = n_days - gap_days
    if usable <= 0 or n_splits <= 0 or train_to_val <= 0:
        return []

    # The usable days are shared between `train_to_val` training units and
    # `n_splits` validation units of equal base length.
    base_val_len, extra = divmod(usable, train_to_val + n_splits)
    if base_val_len <= 0:
        return []

    train_len = train_to_val * base_val_len
    folds = []
    val_start = train_len + gap_days
    for fold_no in range(n_splits):
        # The first `extra` folds absorb one leftover day each.
        val_len = base_val_len + 1 if fold_no < extra else base_val_len
        val_stop = val_start + val_len
        train_stop = val_start - gap_days
        train_start = train_stop - train_len
        if train_start < 0 or val_stop > n_days:
            break
        folds.append((days[train_start:train_stop], days[val_start:val_stop]))
        val_start = val_stop
    return folds


# 初始化参数

In [None]:
# Date-id window used for this experiment only.
start_date, end_date = 1400, 1600

# Screened covariate feature names, one per line of the text file.
path = Path("/mnt/data/js/exp/v1/models/tune/selected_covariant_features.txt")
filted_cols = path.read_text(encoding="utf-8").splitlines()

# Feature-importance table; it must provide `feature` and `mean_gain` columns.
selected_features = pd.read_csv(
    "/mnt/data/js/exp/v1/models/tune/feature_importance__fixed__fixed__mm_full_train__features__fs__1400-1698__cv2-g7-r4__seed42__top1000__1760906660__range830-1698__range830-1698__cv2-g7-r4__1760912739.csv"
)

# Keep only the importance rows whose feature is in the screened list.
df_cov_cols = selected_features.loc[selected_features["feature"].isin(filted_cols)].copy()

# Re-normalise the gains so that they sum to 1 over the retained rows.
df_cov_cols["mean_gain"] = (
    df_cov_cols["mean_gain"].div(df_cov_cols["mean_gain"].sum()).astype(np.float32)
)
df_e_features = df_cov_cols.reset_index(drop=True)





# ========== 1) Column configuration ==========

# Key, target and weight column names come from the project config.
G_SYM, G_DATE, G_TIME = cfg["keys"]          # e.g. ("symbol_id","date_id","time_id")
TARGET_COL = cfg["target"]                   # e.g. "responder_6"
WEIGHT_COL = cfg["weight"]                   # may be None

TIME_FEATURES = ["time_bucket", "time_pos", "time_sin", "time_cos"]

# First 100 rows of the feature table; may include columns constant within a group.
COV_FEATURES = df_e_features["feature"].head(100).tolist()

# Covariates kept as-is (no z-scoring): the three named static features, the
# cross-sectional rank features and the already standardised `__cs_z` / `__rz` ones.
STATIC_FEATURES = [c for c in COV_FEATURES if c in ("feature_09", "feature_10", "feature_11")]
CS_RANK_FEATURES = [c for c in COV_FEATURES if c.endswith("__csrank")]
CS_R_Z_FEATURES = [c for c in COV_FEATURES if c.endswith(("__cs_z", "__rz"))]

# Every remaining covariate is z-scored later.
_skip_z = {*STATIC_FEATURES, *CS_RANK_FEATURES, *CS_R_Z_FEATURES}
do_z_co_cols = [c for c in COV_FEATURES if c not in _skip_z]


In [None]:
# Training & CV hyper-parameters
N_SPLITS = 1        # number of CV folds
GAP_DAYS = 0        # days skipped between the training and validation windows
TRAIN_TO_VAL = 8    # training-window length as a multiple of the validation length
ENC_LEN = 50        # presumably the encoder (look-back) length — confirm at the dataset call
DEC_LEN = 1         # presumably the decoder (forecast) length — confirm at the dataset call
PRED_LEN = DEC_LEN
BATCH_SIZE = 1024
LR = 1e-2
HIDDEN = 64
HEADS = 2
DROPOUT = 0.2
MAX_EPOCHS = 5

# Data locations
PANEL_DIR_AZ = P("az", cfg["paths"].get("panel_shards", "panel_shards"))

TFT_LOCAL_ROOT = P("local", "tft")
ensure_dir_local(TFT_LOCAL_ROOT)

LOCAL_CLEAN_DIR = f"{TFT_LOCAL_ROOT}/clean"
ensure_dir_local(LOCAL_CLEAN_DIR)

CKPTS_DIR = Path(TFT_LOCAL_ROOT) / "ckpts"
ensure_dir_local(CKPTS_DIR.as_posix())

LOGS_DIR = Path(TFT_LOCAL_ROOT) / "logs"
ensure_dir_local(LOGS_DIR.as_posix())


print("[config] ready")

# 导入数据，添加全局time_idx

In [None]:
# List the raw panel shards and put the az:// scheme back in front of every hit.
data_paths = [
    f"az://{hit}"
    for hit in fs.glob("az://jackson/js_exp/exp/v1/panel_shards/*.parquet")
]

# Lazy plan: keep the needed columns, restrict to the experiment window, order rows.
lf_data = (
    pl.scan_parquet(data_paths, storage_options=storage_options)
    .select([*cfg['keys'], WEIGHT_COL, TARGET_COL, *TIME_FEATURES, *COV_FEATURES])
    .filter(pl.col(G_DATE).is_between(start_date, end_date, closed="both"))
    .sort([G_SYM, G_DATE, G_TIME])
)


In [None]:
# Normalise time_pos first by dividing it by the constant 968.
# NOTE(review): 968 is presumably the number of intraday time slots — confirm.
lf_data = lf_data.with_columns(
    time_pos=(pl.col("time_pos") / pl.lit(968)).cast(pl.Float32)
)

In [None]:
# Global time index: one consecutive Int64 `time_idx` per distinct (date, time)
# pair, numbered in ascending (date, time) order. Collected eagerly because it is
# reused by every chunk below.
grid_df = (
    lf_data.select([G_DATE, G_TIME]).unique()
        .sort([G_DATE, G_TIME])
        .with_row_index("time_idx")
        .with_columns(pl.col("time_idx").cast(pl.Int64))
        .collect(streaming=True)
)

# Output location for the panel shards that carry `time_idx`.
container_prefix = "az://jackson/js_exp/exp/v1/tft/panel_clean_shards"; ensure_dir_az(container_prefix)
# Number of date ids written into each output shard.
chunk_size = 30
for lo in range(start_date, end_date + 1, chunk_size):
    # Inclusive upper date id of this chunk, capped at the experiment end.
    hi = min(lo + chunk_size - 1, end_date)
    print(f"processing date range: {lo} ~ {hi}")

    # Panel rows of this date range (still lazy).
    lf_chunk = lf_data.filter(pl.col(G_DATE).is_between(lo, hi, closed="both"))

    # Matching slice of the time-index grid.
    lf_grid_chunk = (
        grid_df.lazy().filter(pl.col(G_DATE).is_between(lo, hi, closed="both"))
    )

    # Attach `time_idx` to every row and order rows by symbol, then time index.
    lf_joined = (
        lf_chunk.join(lf_grid_chunk, on=[G_DATE, G_TIME], how="left").sort([G_SYM, "time_idx"])
    )

    out_path = f"{container_prefix}/panel_clean_{lo:04d}_{hi:04d}.parquet"
    print(f"writing to: {out_path}")

    # Execute the lazy plan and write the result straight to the parquet file.
    lf_joined.sink_parquet(
        out_path,
        storage_options=storage_options,
        compression="zstd",
    )
print(f"[{_now()}] all done")

# 重新导入含有time_idx的数据

In [None]:
# Re-open the shards that now carry `time_idx`, limited to the experiment window
# and ordered by symbol, then time index.
container_prefix = "az://jackson/js_exp/exp/v1/tft/panel_clean_shards"
data_paths = [f"az://{hit}" for hit in fs.glob(f"{container_prefix}/*.parquet")]
lf_with_idx = (
    pl.scan_parquet(data_paths, storage_options=storage_options)
    .filter(pl.col(G_DATE).is_between(start_date, end_date, closed="both"))
    .sort([G_SYM, "time_idx"])
)

# CV 划分

In [None]:
# ========== CV split ==========
# Distinct date ids of the indexed panel in ascending order.
all_days = (
    lf_with_idx.select(pl.col(G_DATE))
    .unique()
    .sort([G_DATE])
    .collect(streaming=True)
    .get_column(G_DATE)
    .to_numpy()
)
folds_by_day = make_sliding_cv_by_days(
    all_days,
    n_splits=N_SPLITS,
    gap_days=GAP_DAYS,
    train_to_val=TRAIN_TO_VAL,
)

print(f"[cv] total {len(folds_by_day)} folds")

# Nothing downstream can run without at least one fold.
assert len(folds_by_day) > 0, "no CV folds constructed"
print(folds_by_day)

# 接下来均以第一折为例

# 数据处理

In [None]:


fold_id = 0
# First and last day of the training and validation parts of this fold; the
# training span is the range over which the normalisation statistics are fitted.
_train_days, _val_days = folds_by_day[fold_id]
train_lo, train_hi = _train_days[0], _train_days[-1]
val_lo, val_hi = _val_days[0], _val_days[-1]

print(f"统计标准化使用训练集的日期范围 = {train_lo} ~ {train_hi}")
print(f"验证集日期范围 = {val_lo} ~ {val_hi}")


## 填补因特征工程产生的缺失值

In [None]:
# Total number of rows in the indexed panel (shown as the cell output).
lf_with_idx.select(pl.len()).collect()

In [None]:
# Null count of every column before imputation, transposed to one row per column
# and sorted so the columns with the most nulls come first.
lf_with_idx.select(
    pl.all().null_count()
).collect().to_pandas().T.sort_values(by=0, ascending=False)

In [None]:
# Median imputation, fitted on the training span of this fold only.
lf_tr = lf_with_idx.filter(pl.col(G_DATE).is_between(train_lo, train_hi, closed="both"))

# Both median tables are materialised once and then re-wrapped as LazyFrames.
# If they stayed purely lazy, the group-by over the whole training span would be
# re-executed every time a frame derived from `lf_all_imputed` is collected
# (the null check, the z-score statistics and the per-day z-score loop).

# Median of every covariate per (symbol, time_pos).
grp_median = (
    lf_tr
    .group_by([G_SYM, "time_pos"])
    .agg([
        pl.col(col).median().alias(f"{col}_median_st") for col in COV_FEATURES
    ])
    .sort([G_SYM, "time_pos"])
    .collect(streaming=True)
    .lazy()
)

# Median of every covariate per time_pos across all symbols.
glb_median = (
    lf_tr
    .group_by("time_pos")
    .agg([
        pl.col(col).median().alias(f"{col}_median_t") for col in COV_FEATURES
    ])
    .sort("time_pos")
    .collect(streaming=True)
    .lazy()
)


# Attach both median tables to all rows of this fold (train + val).
lf_all = (
    lf_with_idx
    .join(grp_median, on=[G_SYM, "time_pos"], how="left")
    .join(glb_median, on=["time_pos"], how="left")
    .sort([G_SYM, "time_idx"])
)

# Fill order per covariate: own value, then the (symbol, time_pos) median, then
# the time_pos median. A value stays null only if both medians are null as well.
fill_exprs = [
    pl.coalesce([
        pl.col(col),
        pl.col(f"{col}_median_st"),
        pl.col(f"{col}_median_t"),
    ]).alias(col)
    for col in COV_FEATURES
]

lf_all_imputed = lf_all.with_columns(fill_exprs)

# Remove the helper median columns again.
drop_cols = [f"{col}_median_st" for col in COV_FEATURES]
drop_cols += [f"{col}_median_t" for col in COV_FEATURES]
lf_all_imputed = lf_all_imputed.drop(drop_cols)



In [None]:
# Recount the nulls of every column after imputation; after the descending sort
# the first row belongs to the column with the most nulls.
df_null_lf_all_imputed = (
    lf_all_imputed
    .select(pl.all().null_count())
    .collect()
    .to_pandas()
    .T
    .sort_values(by=0, ascending=False)
)

assert df_null_lf_all_imputed.iloc[0,0] == 0, "仍有缺失值未填补完成！"

## 标准化

In [None]:
# ========== Z-score statistics ==========
# Statistics are fitted on the imputed training span only.
lf_tr_imputed = lf_all_imputed.filter(pl.col(G_DATE).is_between(train_lo, train_hi, closed="both"))

# Mean and standard deviation of every z-scored column, per symbol.
grp_stats = (
    lf_tr_imputed
    .group_by(G_SYM)
    .agg([
        *[pl.col(c).mean().alias(f"mu_grp_{c}") for c in do_z_co_cols],
        *[pl.col(c).std().alias(f"std_grp_{c}") for c in do_z_co_cols],
    ])
    .collect(streaming=True)
)

# The same statistics over all symbols together (a single-row frame).
glb_stats = (
    lf_tr_imputed
    .select([
        *[pl.col(c).mean().alias(f"mu_glb_{c}") for c in do_z_co_cols],
        *[pl.col(c).std().alias(f"std_glb_{c}") for c in do_z_co_cols],
    ])
    .collect(streaming=True)
)


In [None]:
import math

glb_row = glb_stats.to_dicts()[0]

# Z-score every day of this fold and store one parquet file per day.
fold_root = f"az://jackson/js_exp/exp/v1/tft/fold_{fold_id}"
ensure_dir_az(fold_root)
z_prefix = f"{fold_root}/z_shards"
ensure_dir_az(z_prefix)

z_done_cols = [f"z_{c}" for c in do_z_co_cols]
min_std = 1e-5
eps = 1e-8

# Sanitise the global fallbacks once so that no row-level check is needed.
glb_mu = {c: glb_row[f"mu_glb_{c}"] for c in do_z_co_cols}
glb_std = {}
for c in do_z_co_cols:
    s = glb_row[f"std_glb_{c}"]
    # `not (s > 0)` also rejects NaN, which a plain `s <= 0` test lets through.
    if s is None or not (s > 0):
        s = min_std
    glb_std[c] = s

# The expressions below do not depend on the day, so they are built once here
# instead of once per loop iteration.

# 1) Global fallbacks as named constant columns (avoids anonymous literals).
const_exprs = []
for c in do_z_co_cols:
    const_exprs += [
        pl.lit(glb_mu[c]).alias(f"__mu_glb_{c}"),
        pl.lit(glb_std[c]).alias(f"__std_glb_{c}"),
    ]

# 2) z = (x - mu) / (std + eps), clipped to [-10, 10]. The per-symbol statistics
#    are used when available; otherwise the global ones take over.
exprs = []
for c in do_z_co_cols:
    mu_grp = pl.col(f"mu_grp_{c}")
    std_grp = pl.col(f"std_grp_{c}")
    mu_use = pl.coalesce([mu_grp, pl.col(f"__mu_glb_{c}")]).cast(pl.Float32)

    std_tmp = (
        pl.when(std_grp.is_null() | (std_grp <= 0))
        .then(pl.col(f"__std_glb_{c}"))
        .otherwise(std_grp)
    )
    # Floor the divisor so that near-constant columns cannot blow up.
    std_use = pl.when(std_tmp < min_std).then(min_std).otherwise(std_tmp).cast(pl.Float32)

    z = ((pl.col(c).cast(pl.Float32) - mu_use) / (std_use + eps)).clip(-10.0, 10.0).alias(f"z_{c}")
    exprs.append(z)

# Output columns, de-duplicated in order (a repeated name would raise DuplicateError).
keep = list(dict.fromkeys(
    ["time_idx", G_SYM, G_DATE, G_TIME, WEIGHT_COL, TARGET_COL]
    + TIME_FEATURES + STATIC_FEATURES + CS_RANK_FEATURES + CS_R_Z_FEATURES + z_done_cols
))

# Every integer date id from the first training day to the last validation day.
for d in range(int(train_lo), int(val_hi) + 1):
    lf_day = lf_all_imputed.filter(pl.col(G_DATE) == d)

    # Attach the per-symbol statistics and the constants, then compute the z columns.
    # No intermediate sort is needed: the result is sorted once just before writing.
    lf_day_z = (
        lf_day.join(grp_stats.lazy(), on=G_SYM, how="left")
        .with_columns(const_exprs)
        .with_columns(exprs)
    )

    lf_out = lf_day_z.select(keep).sort([G_SYM, "time_idx"])

    out_path = f"{z_prefix}/z_{d:04d}.parquet"
    lf_out.collect(streaming=True).write_parquet(
        out_path, storage_options=storage_options, compression="zstd"
    )
    print(f"wrote z-scored data for day {d} to {out_path}")

# 模型训练

In [None]:
lp.seed_everything(42)

# Column roles for the model.
# NOTE(review): "time_bucket" is declared categorical here and also ends up in
# KNOWN_REALS through TIME_FEATURES — confirm which role is intended.
KNOW_CATEGORICALS = ["time_bucket"]

STATIC_REALS = STATIC_FEATURES

KNOWN_REALS = [*TIME_FEATURES, *z_done_cols, *CS_RANK_FEATURES, *CS_R_Z_FEATURES]

TRAIN_COLS = [
    G_SYM, G_DATE, G_TIME, "time_idx", WEIGHT_COL, TARGET_COL,
    *KNOWN_REALS,
    *STATIC_REALS,
]


# Scan the z-scored shards of this fold, keeping only the training columns.
data_paths = [f"az://{hit}" for hit in fs.glob(f"{z_prefix}/*.parquet")]
lf = (
    pl.scan_parquet(data_paths, storage_options=storage_options)
    .select(TRAIN_COLS)
    .sort([G_SYM, "time_idx"])
)

In [None]:
# Execute the lazy plan and convert the whole fold into a pandas DataFrame.
df = lf.collect(streaming=True).to_pandas()

In [None]:
# Drop the pandas frame and run a garbage-collection pass right away.
del df
gc.collect()

In [None]:

# Two independent, initially empty accumulators; presumably filled per fold by the
# training code that follows.
best_ckpt_paths = []
fold_metrics = []




    