# 初始化

In [1]:
# 环境与依赖

# 基础包
import tempfile

import os, gc, glob, json, yaml, time
from pathlib import Path
import numpy as np, pandas as pd, polars as pl
import lightgbm as lgb
from dataclasses import dataclass
import pyarrow.parquet as pq
from typing import Sequence, Optional, Union, List, Tuple, Iterable, Mapping

# Azure & 文件系统
import fsspec
from getpass import getpass

# Report the versions of the two libraries this notebook leans on most.
print("Py versions:", "polars", pl.__version__, "lightgbm", lgb.__version__)

# Load environment variables from a .env file (looked up from the current
# directory by default); the return value is shown as the cell output.
from dotenv import load_dotenv
load_dotenv()


Py versions: polars 1.21.0 lightgbm 4.6.0


True

In [2]:
# Connect to the cloud storage account.

# Credentials come from the environment (populated by load_dotenv above).
ACC = os.environ.get("AZURE_STORAGE_ACCOUNT_NAME")
KEY = os.environ.get("AZURE_STORAGE_ACCOUNT_KEY")

# Fail fast when either value is missing or empty.
if not (ACC and KEY):
    raise RuntimeError("Azure credentials not found. Please set them in .env")

# The same option dict is what polars readers expect as `storage_options`.
storage_options = dict(account_name=ACC, account_key=KEY)
fs = fsspec.filesystem("az", **storage_options)

# Load the experiment configuration that drives every path below.
# A context manager is used so the file handle is closed deterministically
# instead of being left open for the lifetime of the kernel.
with open("exp/v1/config/data.yaml") as cfg_file:
    cfg = yaml.safe_load(cfg_file)

# Blob container name and the key prefix inside it (surrounding "/" removed).
CONTAINER = cfg["blob"]["container"]
PREFIX    = cfg["blob"]["prefix"].strip("/")

def BLOB(*parts):
    """Build an ``az://`` URL under CONTAINER/PREFIX (the form polars readers take)."""
    joined = "/".join([CONTAINER, PREFIX, *parts])
    # Collapse accidental double slashes in the path part only; the scheme's
    # "//" is prepended afterwards and therefore left intact.
    return "az://" + joined.replace("//", "/")


def NOPRO(*parts):
    """Build the same blob path without a protocol (for fs.glob / fs.exists / fs.open)."""
    segments = [CONTAINER, PREFIX]
    segments.extend(parts)
    return "/".join(segments).replace("//", "/")


def LOC(*parts):
    """Build a local filesystem path under the configured local root."""
    local_root = cfg["local"]["root"]
    return str(Path(local_root, *parts))

# Sanity check: confirm the raw data location exists in the blob container.
print("raw exists? ", fs.exists(NOPRO(cfg["paths"]["raw"])))

# Global path variables. Suffix convention: _NP = protocol-less path (for fs.*),
# _B = az:// URL (for polars). Trailing comments show example values.
RAW_SHA_DIR_NP = NOPRO(cfg["paths"]["raw_shards"])       # e.g. exp/v1/raw_shards
FE_SHA_DIR_B   = BLOB(cfg["paths"]["fe_shards"])         # e.g. az://.../exp/v1/fe_shards
FE_SHA_DIR_NP  = NOPRO(cfg["paths"]["fe_shards"])        # e.g. jackson/js_exp/exp/v1/fe_shards


# Make sure the shard directories exist (no error if they already do).
fs.mkdirs(RAW_SHA_DIR_NP, exist_ok=True)
fs.mkdirs(FE_SHA_DIR_NP, exist_ok=True)  


raw exists?  True


In [3]:
# Key columns, target/weight names, date range and local working directories.

KEYS = cfg["keys"]
TARGET = cfg["target"]
WEIGHT = cfg["weight"]
DATE_LO = cfg["date_lo"]
DATE_HI = cfg["date_hi"]
FEATURE_COLS = [f"feature_{i:02d}" for i in range(79)]

# Resolved locations used by later cells.
MM_DIR = LOC(cfg["paths"]["mm_local"])        # e.g. /mnt/data/mm/js_exp
TMP_ROOT = LOC(cfg["paths"]["tmp_local"])     # e.g. /mnt/data/tmp/js_exp
BASE_DIR = BLOB(cfg['paths']['clean'])

# Create the local fast-storage directories (parents included, idempotent).
os.makedirs(MM_DIR, exist_ok=True)
os.makedirs(TMP_ROOT, exist_ok=True)
os.makedirs(LOC(cfg["paths"]["cache"]), exist_ok=True)

In [4]:
# Data loading and a quick look at the rows.
# Training scan: the glob selects partition directories whose id ends in a digit 4-8.
pattern = f"{BLOB(cfg['paths']['raw'])}/train.parquet/partition_id=*[4-8]/*.parquet"
lf = pl.scan_parquet(
    pattern,
    storage_options={"account_name": ACC, "account_key": KEY},
)


# Test scan: partition ids ending in 8 or 9.
test_pattern  = f"{BLOB(cfg['paths']['raw'])}/train.parquet/partition_id=*[8-9]/*.parquet" # partition 8 is only used here for test-side preprocessing and history-based feature engineering

test_data = pl.scan_parquet(
    test_pattern,
    storage_options={"account_name": ACC, "account_key": KEY},
)

# Preview the first rows of the training scan (limit() uses its default row count);
# this must stay the last expression so the frame is displayed as the cell output.
lf.limit().collect()

date_id,time_id,symbol_id,weight,feature_00,feature_01,feature_02,feature_03,feature_04,feature_05,feature_06,feature_07,feature_08,feature_09,feature_10,feature_11,feature_12,feature_13,feature_14,feature_15,feature_16,feature_17,feature_18,feature_19,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29,feature_30,feature_31,feature_32,…,feature_51,feature_52,feature_53,feature_54,feature_55,feature_56,feature_57,feature_58,feature_59,feature_60,feature_61,feature_62,feature_63,feature_64,feature_65,feature_66,feature_67,feature_68,feature_69,feature_70,feature_71,feature_72,feature_73,feature_74,feature_75,feature_76,feature_77,feature_78,responder_0,responder_1,responder_2,responder_3,responder_4,responder_5,responder_6,responder_7,responder_8
i16,i16,i8,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,i8,i8,i16,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,…,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
680,0,0,2.29816,0.851814,1.197591,0.219422,0.411698,2.057359,-0.542597,-3.4331,-1.090165,0.151888,11,7,76,-0.97142,0.670215,-0.502896,,-0.070096,,-1.308236,-2.120128,1.747068,-0.219661,0.791602,0.114922,-0.311672,0.156548,1.476346,1.301341,1.235173,-0.259882,-0.30125,-0.264237,,…,0.322135,,,-2.273168,,-1.627507,1.155271,,-3.847984,-2.230864,-0.171579,-0.273791,-0.301876,-0.432933,-1.215778,-1.670469,-0.637963,0.803874,-0.21269,-0.764702,0.435278,-0.619145,,,-0.021034,-0.045094,-0.178144,-0.1951,-0.304665,0.164485,-0.205231,0.191064,-1.413209,0.375675,0.929775,-1.574939,1.101371
680,0,1,3.928745,0.534441,1.07974,0.038748,0.275343,2.135057,-0.541966,-2.774344,-1.048089,0.163768,11,7,76,-0.873847,1.794426,-0.226819,,-0.328627,,-0.870575,-1.204292,0.935869,-0.064365,1.273109,0.752816,-0.062281,0.687036,0.546603,1.229373,1.535235,-0.483758,-0.431013,-0.058963,,…,0.505224,,,-1.074216,,-1.909605,1.566016,,-1.044804,-0.807186,-0.171579,-0.466983,-0.250124,-0.342297,-1.896279,-2.157645,-0.698755,1.743311,-0.069648,-0.933445,1.653637,-0.348816,,,-0.154413,-0.301091,-0.266495,-0.470271,0.089769,0.011395,0.092348,0.473781,0.397024,0.777026,0.826995,0.569681,1.986971
680,0,2,1.340433,-0.227643,0.764146,-0.243349,0.247027,2.347248,-0.478477,-2.660244,-1.261613,0.234425,81,2,59,-0.952889,-0.04806,-0.791763,,-0.140953,,-1.691419,-2.242023,-0.459649,-0.241993,-0.47108,-1.056603,-0.387841,-0.408962,0.359514,0.707161,0.669771,-0.641928,-0.645097,-0.362472,,…,2.115924,,,0.095438,,-0.664823,2.016616,,2.085358,0.690842,-0.171579,0.022096,-0.018111,-0.015769,-2.270972,-1.826189,-0.908704,-0.051331,-0.658539,-1.011692,-0.008993,-0.363811,,,1.677642,1.705228,0.198109,0.152837,0.218281,0.060373,-0.164715,-0.132612,0.543831,-0.123519,-0.296969,0.547286,-0.049303
680,0,3,1.695526,0.267686,1.193612,-0.388798,0.030673,2.175273,-0.408371,-1.859344,-0.771972,0.104885,4,3,11,-1.005184,0.546772,-0.587481,,-0.628245,,-1.243775,-2.907238,-0.098398,0.004987,-0.320296,-1.193256,1.260476,1.621998,-0.255902,-1.644356,-1.292619,-0.541355,-0.703121,0.00681,,…,-0.442253,,,-1.021894,,-3.088617,1.440495,,0.19264,0.270458,-0.171579,-0.387987,-0.30065,-0.169365,-1.324387,-2.20877,-0.975838,0.605086,-0.303872,-1.071495,0.445026,-0.465118,,,3.820025,4.335268,9.818627,11.179185,-0.012298,1.047678,-0.696032,0.960062,2.32889,0.718955,2.047506,3.691308,3.031337
680,0,5,2.700766,0.952372,0.861269,-0.375405,0.259099,2.497325,-0.618828,-2.754378,-0.479992,0.108627,2,10,171,-0.892601,0.207765,-0.420839,,-0.522783,,-1.711776,-2.078692,0.163208,0.023584,0.598875,-0.497226,0.231989,1.630302,-1.233744,-0.434498,0.37159,-0.586595,-0.512672,0.017691,,…,1.003178,,,-0.92965,,-1.621122,1.520507,,-0.532552,-0.482004,-0.171579,-0.164008,-0.194068,-0.257875,-1.623473,-1.959616,-0.748592,0.022282,-0.561959,-0.995984,0.385183,-0.312552,,,-0.552556,-0.460608,-0.700707,-0.770234,-0.229585,-0.240741,-0.887929,-0.061485,0.691569,1.016049,0.103898,0.814866,2.07328


In [5]:
# The real test period starts at date_id 1530; roughly 100 earlier days are kept
# so that history-based features have enough lookback.
TEST_DATA_START_DATE_ID = 1429
# Lazy view of the test scan restricted to date_id >= TEST_DATA_START_DATE_ID.
lf_test = test_data.filter(pl.col("date_id") >= TEST_DATA_START_DATE_ID)

# EDA

In [None]:
# EDA: how many distinct symbols trade on each date, and which dates fall
# short of the largest universe observed.
lf_s_d = lf.select(["date_id", "symbol_id"])

# One row per date with its distinct-symbol count, ordered by date.
per_date = (
    lf_s_d
    .group_by("date_id")
    .agg(pl.col("symbol_id").n_unique().alias("n_symbols"))
    .sort("date_id")
)

# Largest per-day symbol count (materialises the aggregation once).
max_n = per_date.select(pl.col("n_symbols").max()).collect().item()

# Tag each date with the reference maximum and whether it reaches it.
summary = per_date.with_columns(
    max_n=pl.lit(max_n),
    is_full_universe=pl.col("n_symbols").eq(max_n),
)

# Dates on which at least one symbol of the full universe is absent.
dates_missing = summary.filter(~pl.col("is_full_universe")).select("date_id")

dates_missing.collect()

# 数据预处理

In [None]:
def rolling_sigma_clip_polars_plan(
    source: Union[pl.LazyFrame, List[str]],
    features: Sequence[str],
    group_cols: Sequence[str] = ("symbol_id",),
    sort_cols: Sequence[str] = ("symbol_id","date_id","time_id"),
    window: int = 7,
    k: float = 3.0,
    ddof: int = 0,
    min_valid: int = 2,
    batch: int = 10,
    cast_float32: bool = True,
    ensure_sorted: bool = False,
) -> pl.LazyFrame:
    """Build a LazyFrame plan that clips each feature to a rolling mean +/- k*sigma band.

    Nothing is collected or written; the caller materialises the returned plan.

    Args:
        source: A LazyFrame, or parquet path(s) handed to ``pl.scan_parquet``.
        features: Columns to clip.
        group_cols: Columns defining the groups the rolling windows run within.
        sort_cols: Columns kept in the output and, when ``ensure_sorted`` is
            True, used to order the rows before the rolling windows are built.
        window: Rolling window length in rows.
        k: Half-width of the clipping band in rolling standard deviations.
        ddof: Delta degrees of freedom for the rolling standard deviation.
        min_valid: Minimum number of non-null values in the window for the
            clip to be applied; otherwise the value passes through unchanged.
        batch: Number of features added per ``with_columns`` call.
        cast_float32: Cast the feature and the band limits to Float32.
        ensure_sorted: Sort by ``sort_cols`` before computing the windows.
            Defaults to False (the original behaviour): the windows are
            row-based, so with False the input MUST already be ordered
            chronologically within each group.

    Returns:
        LazyFrame with columns ``group_cols + sort_cols + features``
        (de-duplicated, in that order).

    Raises:
        ValueError: If ``window`` or ``batch`` is smaller than 1.
        KeyError: If any required column is missing from the source schema.
    """
    # A non-positive batch would previously skip the loop (negative) or fail
    # with an unrelated range() error (zero); reject both explicitly.
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")
    if batch < 1:
        raise ValueError(f"batch must be >= 1, got {batch}")

    # 0) Input: accept an existing plan or scan parquet lazily.
    lf = source if isinstance(source, pl.LazyFrame) else pl.scan_parquet(source)
    features = list(features)
    group_cols = list(group_cols)
    sort_cols = list(sort_cols)

    # 1) Keep only the columns that are needed (order-preserving de-duplication).
    need_cols = list(dict.fromkeys([*group_cols, *sort_cols, *features]))
    schema = lf.collect_schema()
    missing = [c for c in need_cols if c not in schema]
    if missing:
        raise KeyError(f"Missing columns: {missing}")
    lf = lf.select(need_cols)

    # 2) Optional ordering: rolling_* operates on row order inside each group.
    if ensure_sorted and sort_cols:
        lf = lf.sort(sort_cols)

    # 3) Add the clipping expressions a few features at a time.
    for start in range(0, len(features), batch):
        exprs = []
        for c in features[start:start + batch]:
            # Number of non-null observations inside the window.
            cnt = (
                pl.col(c).is_not_null().cast(pl.Int32)
                        .rolling_sum(window_size=window, min_samples=1)
                        .over(group_cols)
            )
            mu  = pl.col(c).rolling_mean(window_size=window, min_samples=1).over(group_cols)
            # A null sigma is treated as 0.
            sd  = (pl.col(c).rolling_std(window_size=window, ddof=ddof, min_samples=1)
                        .over(group_cols).fill_null(0.0))
            lo, hi = (mu - k*sd), (mu + k*sd)

            base = pl.col(c).cast(pl.Float32) if cast_float32 else pl.col(c)
            if cast_float32:
                lo, hi = lo.cast(pl.Float32), hi.cast(pl.Float32)

            # Clip only when the window holds enough valid observations.
            exprs.append(
                pl.when(cnt >= min_valid).then(base.clip(lo, hi)).otherwise(base).alias(c)
            )
        lf = lf.with_columns(exprs)

    return lf

In [None]:
# 1. Clip - train
OUT_DIR = LOC(cfg["paths"]["cache"])
FEATURES_ALL = [f"feature_{i:02d}" for i in range(79)]
RESP_ALL = [f"responder_{i}" for i in range(9)]

# Build the rolling k-sigma clipping plan on the training scan, then order
# the result by the key columns.
lf_clip = rolling_sigma_clip_polars_plan(
    source=lf,
    features=FEATURES_ALL,
    group_cols=("symbol_id",),
    sort_cols=("symbol_id","date_id","time_id"),
    window=40,
    k=3.0,
    ddof=0,
    min_valid=8,
    batch=5,
    cast_float32=True,
).sort(KEYS)

# Materialise the plan and persist it to the local cache.
lf_clip.collect(streaming=True).write_parquet(f"{OUT_DIR}/clip.parquet", compression="zstd")

In [None]:
# 1. Clip - test
OUT_DIR = LOC(cfg["paths"]["cache"])
FEATURES_ALL = [f"feature_{i:02d}" for i in range(79)]
RESP_ALL = [f"responder_{i}" for i in range(9)]

# Same clipping parameters as the training run, applied to the test window,
# then ordered by the key columns.
lf_test_clip = rolling_sigma_clip_polars_plan(
    source=lf_test,
    features=FEATURES_ALL,
    group_cols=("symbol_id",),
    sort_cols=("symbol_id","date_id","time_id"),
    window=40,
    k=3.0,
    ddof=0,
    min_valid=8,
    batch=5,
    cast_float32=True,
).sort(KEYS)

# Materialise the plan and persist it to the local cache.
lf_test_clip.collect(streaming=True).write_parquet(f"{OUT_DIR}/test_clip.parquet", compression="zstd")

In [None]:
def causal_impute_polars_no_na(
    lf: pl.LazyFrame,
    features: Optional[Sequence[str]] = None,            # defaults to the feature_* columns of the schema
    keys: Tuple[str,str,str] = ("symbol_id","date_id","time_id"),
    open_ticks: Tuple[int,int] = (0, 5),                 # opening window [t0, t1]
    carry_days: int = 5,                                  # max age (days) of a value carried across days at the open
    intraday_limit: Optional[int] = None,                 # max time_id distance for intraday ffill (None = unlimited)
    crossday_same_tick_limit: Optional[int] = 2,          # max age (days) for same-time_id carry (None = skip step)
    train_end: Optional[int] = None,                      # last date_id used for statistics (None = all rows)
    fallback: Optional[str] = "mean",                     # 'median' / 'mean'; any other value skips the opening fallback
    ensure_sorted: bool = False,                          # sort by keys inside the function
    global_means: Optional[Mapping[str, float]] = None,   # optional externally supplied global means
    int_features: Optional[Sequence[str]] = None          # optional features to round and cast back to Int32
) -> pl.LazyFrame:
    """Forward-looking-free imputation of feature columns (no ``*_na`` indicator columns).

    Stages, applied in order:
      1) Opening window (``time_id`` within ``open_ticks``): forward-fill per
         symbol across the opening rows, accepting a carried value only when
         it is at most ``carry_days`` days old; remaining nulls in the window
         get the ``fallback`` statistic computed on the statistics source.
      2) Intraday: forward-fill per (symbol, date), optionally limited to
         ``intraday_limit`` time_id units.
      3) Optional: forward-fill per (symbol, time_id) across days, limited to
         ``crossday_same_tick_limit`` days.
      4) Last resort: per-symbol mean, then global mean.

    All forward fills rely on row order; pass ``ensure_sorted=True`` unless the
    input is already ordered by ``keys``.

    Args:
        lf: Input plan containing the key and feature columns.
        features: Feature columns to impute; defaults to all ``feature_*``.
        keys: (symbol, date, time) column names.
        train_end: If given, statistics use only rows with date <= train_end.
        global_means: Mapping feature -> mean used for the final fallback.
            Features absent from the mapping (or mapped to None) fall back to
            the mean computed from the statistics source.
        int_features: Columns rounded and cast to Int32 at the very end.

    Returns:
        LazyFrame with columns ``keys + features``; features are Float32
        (except ``int_features``).

    Note:
        Statistics are collected eagerly while the plan is being built.
    """
    gcol, dcol, tcol = keys
    f32 = pl.Float32

    # 0) Column selection, dtype unification and optional sort.
    if features is None:
        schema = lf.collect_schema()
        features = sorted([c for c in schema if c.startswith("feature_")])
    features = list(features)
    need_cols = list(dict.fromkeys([gcol, dcol, tcol, *features]))
    lf = lf.select(need_cols).with_columns([pl.col(c).cast(f32) for c in features])
    if ensure_sorted:
        lf = lf.sort([gcol, dcol, tcol])

    # Nothing to impute: return the key columns instead of failing later on
    # an empty aggregation result.
    if not features:
        return lf.select([gcol, dcol, tcol])

    # Source used for every statistic (restricted to the training period if requested).
    SRC_train = lf if train_end is None else lf.filter(pl.col(dcol) <= train_end)

    # 1) Opening-window subset.
    t0, t1 = open_ticks
    lf_open = (
        lf.select([gcol, dcol, tcol, *features])
        .filter(pl.col(tcol).is_between(t0, t1))
    )

    # 2) Opening window: carry the last seen value forward, subject to the day limit.
    open_exprs = []
    for c in features:
        last_non_null_date = (
            pl.when(pl.col(c).is_not_null()).then(pl.col(dcol)).otherwise(None)
            .forward_fill().over(gcol)
        )
        cand = pl.col(c).forward_fill().over(gcol)
        gap  = (pl.col(dcol) - last_non_null_date).cast(pl.Int32)
        open_exprs.append(
            pl.when(pl.col(c).is_null() & cand.is_not_null()
                    & last_non_null_date.is_not_null() & (gap <= carry_days))
            .then(cand).otherwise(pl.col(c)).alias(c)
        )
    lf_open_ff = lf_open.with_columns(open_exprs)

    # 3) Still null at the open -> statistic from the statistics source (Float32).
    mean_cache = None  # reused in step 8 to avoid a second full pass
    if fallback in ("median", "mean"):
        aggs = [
            (pl.col(c).median().cast(f32).alias(c) if fallback == "median"
            else pl.col(c).mean().cast(f32).alias(c))
            for c in features
        ]
        stats = SRC_train.select(aggs).collect(streaming=True).to_dicts()[0]
        if fallback == "mean":
            mean_cache = stats
        lf_open_filled = lf_open_ff.with_columns([
            pl.when(pl.col(c).is_null()).then(pl.lit(stats[c]).cast(f32))
            .otherwise(pl.col(c)).alias(c)
            for c in features
        ])
    else:
        lf_open_filled = lf_open_ff

    # 4) Write the opening-window result back onto the full frame.
    # NOTE(review): steps 5-6 depend on the row order after this left join;
    # confirm the polars version in use keeps left-frame order, or request it
    # explicitly.
    suf = "__open"
    lf_after_open = (
        lf.join(lf_open_filled.select([gcol, dcol, tcol, *features]),
                on=[gcol, dcol, tcol], how="left", suffix=suf)
        .with_columns([
            pl.coalesce([pl.col(f"{c}{suf}"), pl.col(c)]).cast(f32).alias(c)
            for c in features
        ])
        .drop([f"{c}{suf}" for c in features])
    )

    # 5) Intraday forward fill.
    if intraday_limit is None:
        lf_intra = lf_after_open.with_columns([
            pl.col(c).forward_fill().over([gcol, dcol]).alias(c)
            for c in features
        ])
    else:
        intra_exprs = []
        for c in features:
            cand = pl.col(c).forward_fill().over([gcol, dcol])
            last_non_null_time = (
                pl.when(pl.col(c).is_not_null()).then(pl.col(tcol)).otherwise(None)
                .forward_fill().over([gcol, dcol])
            )
            step_gap = (pl.col(tcol) - last_non_null_time).cast(pl.Int32)
            intra_exprs.append(
                pl.when(pl.col(c).is_null() & cand.is_not_null()
                        & last_non_null_time.is_not_null()
                        & (step_gap <= intraday_limit))
                .then(cand).otherwise(pl.col(c)).alias(c)
            )
        lf_intra = lf_after_open.with_columns(intra_exprs)

    # 6) Same time_id across days, subject to the day limit.
    lf_after_same = lf_intra
    if crossday_same_tick_limit is not None:
        exprs = []
        for c in features:
            last_non_null_date_same = (
                pl.when(pl.col(c).is_not_null()).then(pl.col(dcol)).otherwise(None)
                .forward_fill().over([gcol, tcol])
            )
            cand_same = pl.col(c).forward_fill().over([gcol, tcol])
            gap2 = (pl.col(dcol) - last_non_null_date_same).cast(pl.Int32)
            exprs.append(
                pl.when(pl.col(c).is_null() & cand_same.is_not_null()
                        & last_non_null_date_same.is_not_null()
                        & (gap2 <= crossday_same_tick_limit))
                .then(cand_same).otherwise(pl.col(c)).alias(c)
            )
        lf_after_same = lf_intra.with_columns(exprs)

    # 7) Fallback 1: per-symbol mean (Float32).
    sym_stats = (
        SRC_train.select([gcol] + features)
                .group_by(gcol)
                .agg([pl.col(c).mean().cast(f32).alias(c) for c in features])
    )
    lf_lvl1 = (
        lf_after_same.join(sym_stats, on=[gcol], how="left", suffix="__sym")
                    .with_columns([
                        pl.coalesce([pl.col(c), pl.col(f"{c}__sym")]).cast(f32).alias(c)
                        for c in features
                    ])
                    .drop([f"{c}__sym" for c in features])
    )

    # 8) Fallback 2: global mean. Externally supplied values win; any feature
    #    the mapping does not cover (or maps to None) uses the computed mean
    #    instead of raising KeyError/TypeError.
    glob_stats = {}
    if global_means is not None:
        glob_stats = {
            c: float(global_means[c])
            for c in features
            if c in global_means and global_means[c] is not None
        }
    uncovered = [c for c in features if c not in glob_stats]
    if uncovered:
        if mean_cache is not None:
            computed = {c: mean_cache[c] for c in uncovered}
        else:
            computed = (
                SRC_train.select([pl.col(c).mean().cast(f32).alias(c) for c in uncovered])
                        .collect(streaming=True).to_dicts()[0]
            )
        glob_stats.update(computed)
    lf_final = lf_lvl1.with_columns([
        pl.coalesce([pl.col(c), pl.lit(glob_stats[c]).cast(f32)]).alias(c)
        for c in features
    ])

    # 9) Fixed output column order; optionally cast selected columns back to integers.
    lf_final = lf_final.select([gcol, dcol, tcol, *features])
    if int_features:
        lf_final = lf_final.with_columns([
            pl.col(c).round().cast(pl.Int32).alias(c) for c in int_features
        ])

    return lf_final

In [None]:
# Train

# I/O: both the clipped input and the imputed output live in the local cache.
SRC_DIR = OUT_DIR = LOC(cfg["paths"]["cache"])
SRC = f"{SRC_DIR}/clip.parquet"
OUT = f"{OUT_DIR}/impute.parquet"

KEYS = ["symbol_id", "date_id", "time_id"]
BATCH = 5                      # symbols per batch; lower to 2-3 when memory is tight
ENSURE_SORTED = False          # forwarded to the imputer's ensure_sorted flag
EXPECTED = list(KEYS) + list(FEATURES_ALL)


In [None]:


# 2 Impute
# Runs the causal imputer over the clipped training file in batches of symbols
# and appends each batch to a single local parquet file (OUT).

# 0) Global per-feature means as Float32; an all-null feature falls back to 0.0.
GLOBAL_MEAN = (
    pl.scan_parquet(str(SRC))
      .select([pl.col(c).mean().fill_null(0.0).cast(pl.Float32).alias(c)
               for c in FEATURES_ALL])
      .collect(streaming=True)
      .to_dicts()[0]
)

# 1) List of distinct symbols.
# NOTE(review): unique() is not followed by a sort, so batch order (and thus
# row-group order in the output) may vary between runs - confirm this is fine.
symbols = (
    pl.scan_parquet(str(SRC))
      .select(pl.col("symbol_id").unique())
      .collect(streaming=True)["symbol_id"]
      .to_list()
)

# The writer is created lazily from the first batch's schema and always
# closed in the finally block.
writer = None
try:
    for i in range(0, len(symbols), BATCH):
        ids = symbols[i:i+BATCH]

        # Batch data (column pruning + fixed column order).
        lf_b = (
            pl.scan_parquet(str(SRC))
              .filter(pl.col("symbol_id").is_in(ids))
              .select(EXPECTED)
        )
        # If per-batch ordering must be enforced (stage 1 output not sorted), enable the next lines:
        # if ENSURE_SORTED:
        #     lf_b = lf_b.sort(KEYS)

        # Causal imputation (reuses causal_impute_polars_no_na).
        lf_imp_b = causal_impute_polars_no_na(
            lf=lf_b,
            features=FEATURES_ALL,
            keys=tuple(KEYS),
            open_ticks=(0, 10),
            carry_days=5,
            intraday_limit= 50,
            crossday_same_tick_limit=5,
            train_end=None,
            fallback="mean",
            ensure_sorted=ENSURE_SORTED,
            global_means=GLOBAL_MEAN,
            int_features=None
        )

        # Pin column order and dtype again so every batch shares one schema.
        lf_imp_b = lf_imp_b.select(EXPECTED).with_columns(
            [pl.col(c).cast(pl.Float32) for c in FEATURES_ALL]
        )

        # Materialise the current batch.
        df_b = lf_imp_b.collect(streaming=True)
        tbl  = df_b.to_arrow()

        if writer is None:
            writer = pq.ParquetWriter(
                str(OUT), tbl.schema,
                compression="zstd",
                use_dictionary=True,
                write_statistics=True
            )

        writer.write_table(tbl, row_group_size=300_000)  # row groups of 200k-500k rows keep memory lower
        # Release the batch before loading the next one.
        del df_b, tbl
        gc.collect()
finally:
    if writer is not None:
        writer.close()


# 3. Join responders -> incremental write into a single file --------------------
# ==== Paths ====
SRC_IMP_DIR = LOC(cfg["paths"]["cache"])                  # local cache directory
SRC_IMP = f"{SRC_IMP_DIR}/impute.parquet"      # local path (output of step 2)
CLEAN_DIR_NP = NOPRO(cfg["paths"]["clean"])                  # e.g. "jackson/js_exp/clean"
OUT_NP = f"{CLEAN_DIR_NP}/final_clean.parquet"               # protocol-less path (for fs / arrow)


# Right-hand table: weight + responders taken from the raw training scan `lf`.

# ==== Parameters ====
EXPECTED = [*KEYS, WEIGHT, *FEATURES_ALL, *RESP_ALL]
BATCH = 2  # number of symbols per batch

# ==== Prepare the output directory; an existing output file is deleted ====
fs.mkdirs(CLEAN_DIR_NP, exist_ok=True)
if fs.exists(OUT_NP):
    fs.rm(OUT_NP)   # alternatively comment this out to avoid overwriting by accident

# ==== Align key dtypes with the imputed file ====
imp_schema = pl.scan_parquet(str(SRC_IMP)).collect_schema()
key_casts = {k: imp_schema[k] for k in KEYS}

# Full list of distinct symbols.
symbols = (pl.scan_parquet(str(SRC_IMP))
             .select(pl.col("symbol_id").unique())
             .collect(streaming=True)["symbol_id"].to_list())


# ==== Open a blob file handle and append row groups batch by batch ====
# The writer is closed (finally) before the `with` block closes the handle.
writer = None
with fs.open(OUT_NP, "wb") as fout:
    try:
        for i in range(0, len(symbols), BATCH):
            ids = symbols[i:i+BATCH]

            # Left table: imputed features (local file).
            lf_left = (pl.scan_parquet(str(SRC_IMP))
                         .filter(pl.col("symbol_id").is_in(ids))
                         .select([*KEYS, *FEATURES_ALL]))

            # Right table: weight + responders from the blob-backed scan,
            # with key columns cast to the left table's dtypes.
            lf_right = (
                lf.filter(pl.col("symbol_id").is_in(ids))
                .select([*KEYS, WEIGHT, *RESP_ALL])
                .with_columns(
                    *[pl.col(k).cast(key_casts[k]) for k in KEYS],
                    pl.col(WEIGHT).cast(pl.Float32)
                )
            )

            # Join + fixed column order + forced dtypes (avoids a Null dtype
            # when a column is entirely empty in some batch).
            lf_joined = lf_left.join(lf_right, on=KEYS, how="left").select(EXPECTED)
            df_b = (
                lf_joined
                .with_columns(
                    pl.col(WEIGHT).cast(pl.Float32),
                    *[pl.col(c).cast(pl.Float32) for c in FEATURES_ALL + RESP_ALL]
                )
                .collect(streaming=True)
            )


            tbl = df_b.to_arrow()
            # Create the writer from the first batch's schema.
            if writer is None:
                writer = pq.ParquetWriter(
                    fout, tbl.schema,
                    compression="zstd",
                    use_dictionary=True,
                    write_statistics=True,
                )
            writer.write_table(tbl, row_group_size=300_000)
            del df_b, tbl; gc.collect()
    finally:
        if writer is not None:
            writer.close()

print("✅ 写到 Blob：", "az://" + OUT_NP)  
    

In [None]:
# Test

# I/O: both the clipped test input and the imputed test output live in the local cache.
TEST_SRC_DIR = TEST_OUT_DIR = LOC(cfg["paths"]["cache"])
TEST_SRC = f"{TEST_SRC_DIR}/test_clip.parquet"
TEST_OUT = f"{TEST_OUT_DIR}/test_impute.parquet"

KEYS = ["symbol_id", "date_id", "time_id"]
BATCH = 5                      # symbols per batch; lower to 2-3 when memory is tight
ENSURE_SORTED = False          # forwarded to the imputer's ensure_sorted flag
EXPECTED = list(KEYS) + list(FEATURES_ALL)


# 2 Impute
# Same batched imputation as the training run, applied to the clipped test
# file and written to a single local parquet file (TEST_OUT).

# 0) Global per-feature means as Float32; an all-null feature falls back to 0.0.
GLOBAL_MEAN = (
    pl.scan_parquet(str(SRC)) # still the training-set means: SRC is the train clip file set in the train I/O cell
      .select([pl.col(c).mean().fill_null(0.0).cast(pl.Float32).alias(c)
               for c in FEATURES_ALL])
      .collect(streaming=True)
      .to_dicts()[0]
)

# 1) List of distinct symbols in the test file.
symbols = (
    pl.scan_parquet(str(TEST_SRC))
      .select(pl.col("symbol_id").unique())
      .collect(streaming=True)["symbol_id"]
      .to_list()
)

# The writer is created lazily from the first batch's schema and always
# closed in the finally block.
writer = None
try:
    for i in range(0, len(symbols), BATCH):
        ids = symbols[i:i+BATCH]

        # Batch data (column pruning + fixed column order).
        lf_b = (
            pl.scan_parquet(str(TEST_SRC))
              .filter(pl.col("symbol_id").is_in(ids))
              .select(EXPECTED)
        )
        # If per-batch ordering must be enforced (stage 1 output not sorted), enable the next lines:
        # if ENSURE_SORTED:
        #     lf_b = lf_b.sort(KEYS)

        # Causal imputation (reuses causal_impute_polars_no_na).
        lf_imp_b = causal_impute_polars_no_na(
            lf=lf_b,
            features=FEATURES_ALL,
            keys=tuple(KEYS),
            open_ticks=(0, 10),
            carry_days=5,
            intraday_limit= 50,
            crossday_same_tick_limit=5,
            train_end=None,
            fallback="mean",
            ensure_sorted=ENSURE_SORTED,
            global_means=GLOBAL_MEAN, # still the training-set means
            int_features=None
        )

        # Pin column order and dtype again so every batch shares one schema.
        lf_imp_b = lf_imp_b.select(EXPECTED).with_columns(
            [pl.col(c).cast(pl.Float32) for c in FEATURES_ALL]
        )

        # Materialise the current batch.
        df_b = lf_imp_b.collect(streaming=True)
        tbl  = df_b.to_arrow()

        if writer is None:
            writer = pq.ParquetWriter(
                str(TEST_OUT), tbl.schema,
                compression="zstd",
                use_dictionary=True,
                write_statistics=True
            )

        writer.write_table(tbl, row_group_size=300_000)  # row groups of 200k-500k rows keep memory lower
        # Release the batch before loading the next one.
        del df_b, tbl
        gc.collect()
finally:
    if writer is not None:
        writer.close()


# 3. Join responders -> incremental write into a single file --------------------
# ==== Paths ====
TEST_SRC_IMP_DIR = LOC(cfg["paths"]["cache"])                  # local cache directory
TEST_SRC_IMP = f"{TEST_SRC_IMP_DIR}/test_impute.parquet"      # local path (output of step 2)
TEST_CLEAN_DIR_NP = NOPRO(cfg["paths"]["clean"])                  # e.g. "jackson/js_exp/clean"
TEST_OUT_NP = f"{TEST_CLEAN_DIR_NP}/test_final_clean.parquet"               # protocol-less path (for fs / arrow)


# Right-hand table: weight + responders taken from the filtered test scan `lf_test`.

# ==== Parameters ====
EXPECTED = [*KEYS, WEIGHT, *FEATURES_ALL, *RESP_ALL]
BATCH = 2  # number of symbols per batch

# ==== Prepare the output directory; an existing output file is deleted ====
fs.mkdirs(TEST_CLEAN_DIR_NP, exist_ok=True)
if fs.exists(TEST_OUT_NP):
    fs.rm(TEST_OUT_NP)   # alternatively comment this out to avoid overwriting by accident

# ==== Align key dtypes with the imputed file ====
imp_schema = pl.scan_parquet(str(TEST_SRC_IMP)).collect_schema()
key_casts = {k: imp_schema[k] for k in KEYS}

# Full list of distinct symbols.
symbols = (pl.scan_parquet(str(TEST_SRC_IMP))
             .select(pl.col("symbol_id").unique())
             .collect(streaming=True)["symbol_id"].to_list())


# ==== Open a blob file handle and append row groups batch by batch ====
# The writer is closed (finally) before the `with` block closes the handle.
writer = None
with fs.open(TEST_OUT_NP, "wb") as fout:
    try:
        for i in range(0, len(symbols), BATCH):
            ids = symbols[i:i+BATCH]

            # Left table: imputed features (local file).
            lf_left = (pl.scan_parquet(str(TEST_SRC_IMP))
                         .filter(pl.col("symbol_id").is_in(ids))
                         .select([*KEYS, *FEATURES_ALL]))

            # Right table: weight + responders from the blob-backed scan,
            # with key columns cast to the left table's dtypes.
            lf_right = (
                lf_test.filter(pl.col("symbol_id").is_in(ids))
                .select([*KEYS, WEIGHT, *RESP_ALL])
                .with_columns(
                    *[pl.col(k).cast(key_casts[k]) for k in KEYS],
                    pl.col(WEIGHT).cast(pl.Float32)
                )
            )

            # Join + fixed column order + forced dtypes (avoids a Null dtype
            # when a column is entirely empty in some batch).
            lf_joined = lf_left.join(lf_right, on=KEYS, how="left").select(EXPECTED)
            df_b = (
                lf_joined
                .with_columns(
                    pl.col(WEIGHT).cast(pl.Float32),
                    *[pl.col(c).cast(pl.Float32) for c in FEATURES_ALL + RESP_ALL]
                )
                .collect(streaming=True)
            )


            tbl = df_b.to_arrow()
            # Create the writer from the first batch's schema.
            if writer is None:
                writer = pq.ParquetWriter(
                    fout, tbl.schema,
                    compression="zstd",
                    use_dictionary=True,
                    write_statistics=True,
                )
            writer.write_table(tbl, row_group_size=300_000)
            del df_b, tbl; gc.collect()
    finally:
        if writer is not None:
            writer.close()

print("✅ 写到 Blob：", "az://" + TEST_OUT_NP)  
    


# 特征工程函数

In [6]:
# 特征工程

# -----------------------------
# A. prev-day tails + daily summaries
# -----------------------------
import polars as pl
from typing import Sequence, Optional, Tuple

def _resolve_prev_for_daily(
    daily: pl.LazyFrame,
    *,
    keys: Tuple[str,str] = ("symbol_id","date_id"),
    cols: Sequence[str],
    prev_soft_days: Optional[int],
    cast_f32: bool = True,
) -> pl.LazyFrame:
    """Replace each daily column by the value that was known before the current day.

    For every column in ``cols`` the result on day ``d`` is:
      * the previous row's value when that row is exactly day ``d - 1``
        (strict previous day); otherwise
      * if ``prev_soft_days`` is given, the most recent non-null value from an
        earlier row of the same symbol, provided it is at most
        ``prev_soft_days`` days old; otherwise null.

    Args:
        daily: One row per (symbol, date). Assumed to be ordered by date
            within each symbol, since all lookups are row-shift based.
        keys: (symbol, date) column names.
        cols: Columns to resolve; they are overwritten in place.
        prev_soft_days: Maximum age in days for the soft fallback, or None to
            allow only the strict previous day.
        cast_f32: Cast the resolved columns to Float32.

    Returns:
        ``daily`` with ``cols`` replaced by their previous-day resolved values.
    """
    g_symbol, g_date = keys
    soft_days = None if prev_soft_days is None else int(prev_soft_days)

    exprs = []
    for c in cols:
        prev_row_val = pl.col(c).shift(1).over(g_symbol)
        prev_row_day = pl.col(g_date).shift(1).over(g_symbol)

        # Strict d-1: accept the previous row only if it is the calendar-adjacent day.
        prev_strict = pl.when(pl.col(g_date) == (prev_row_day + 1)).then(prev_row_val).otherwise(None)

        if soft_days is None:
            resolved = prev_strict
        else:
            # Last non-null day/value strictly before the current row.
            # The shift must be evaluated INSIDE the per-symbol window: shifting
            # after .over() moves the whole column and leaks the last value of
            # the preceding symbol into the first row of the next one (where
            # the negative day gap would then pass the `gap <= K` check).
            last_non_null_day = (
                pl.when(pl.col(c).is_not_null()).then(pl.col(g_date))
                  .otherwise(None)
                  .forward_fill()
                  .shift(1)
                  .over(g_symbol)
            )
            last_non_null_val = (
                pl.col(c).forward_fill().shift(1).over(g_symbol)
            )
            gap = (pl.col(g_date) - last_non_null_day).cast(pl.Int32)

            resolved = pl.coalesce([
                prev_strict,
                pl.when((gap <= soft_days) & last_non_null_val.is_not_null())
                  .then(last_non_null_val),
            ])

        exprs.append((resolved.cast(pl.Float32) if cast_f32 else resolved).alias(c))

    return daily.with_columns(exprs)


def fe_prevday_tail_and_summaries(
    lf: pl.LazyFrame,
    *,
    rep_cols: Sequence[str],                           # responder columns to summarise
    tail_lags: Sequence[int] = (1,),                   # L-th value from the end of the day
    tail_diffs: Sequence[int] = (1,),                 # dK = lag1 - lag(K+1)
    rolling_windows: Sequence[int] | None = (3,5,20), # rolling windows over the daily close
    keys: Tuple[str,str,str] = ("symbol_id","date_id","time_id"),
    assume_sorted: bool = True,
    cast_f32: bool = True,
    prev_soft_days: Optional[int] = None,             # None = strict d-1 only; K = allow fallback up to K days
) -> pl.LazyFrame:
    """Attach previous-day tail values and daily summaries to every tick row.

    The frame is aggregated once to (symbol, date) level, derived daily
    columns are added, every non-key daily column is converted to its
    previous-day value via ``_resolve_prev_for_daily``, and the result is
    left-joined back onto the tick-level frame.

    Daily columns produced per responder ``r``:
      * ``{r}_prev_tail_lag{L}`` for every required L (requested lags, the
        lags needed by ``tail_diffs``, and always 1);
      * ``{r}_prevday_close`` / ``_mean`` / ``_std``;
      * ``{r}_prev_tail_d{K}``, ``{r}_prev2day_close``,
        ``{r}_prevday_close_minus_mean``, ``{r}_overnight_gap``;
      * ``{r}_close_roll{w}_mean`` / ``_std`` for each window w > 1.
    """
    sym_key, day_key, tick_key = keys
    if not assume_sorted:
        lf = lf.sort(list(keys))

    # Tail depths to extract: requested lags, those needed by the diffs, and lag 1.
    lag_depths = sorted({1, *tail_lags, *(k + 1 for k in tail_diffs)})

    # --- single daily aggregation ---
    daily_aggs = []
    for rep in rep_cols:
        by_tick = pl.col(rep).sort_by(pl.col(tick_key))
        # Tail values: the depth-th element counted from the end of the day.
        daily_aggs.extend(
            by_tick.slice(-depth, 1).first().alias(f"{rep}_prev_tail_lag{depth}")
            for depth in lag_depths
        )
        # Same-day statistics.
        daily_aggs.append(by_tick.last().alias(f"{rep}_prevday_close"))
        daily_aggs.append(pl.col(rep).mean().alias(f"{rep}_prevday_mean"))
        daily_aggs.append(pl.col(rep).std(ddof=0).alias(f"{rep}_prevday_std"))

    daily = (
        lf.group_by([sym_key, day_key])
          .agg(daily_aggs)
          .sort([sym_key, day_key])
    )

    # Tail differences dK = lag1 - lag(K+1).
    diff_exprs = []
    for rep in rep_cols:
        for span in tail_diffs:
            diff_exprs.append(
                (pl.col(f"{rep}_prev_tail_lag1") - pl.col(f"{rep}_prev_tail_lag{span+1}"))
                .alias(f"{rep}_prev_tail_d{span}")
            )
    daily = daily.with_columns(diff_exprs)

    # Close of the preceding daily row, per symbol.
    prev2_exprs = [
        pl.col(f"{rep}_prevday_close").shift(1).over(sym_key).alias(f"{rep}_prev2day_close")
        for rep in rep_cols
    ]
    daily = daily.with_columns(prev2_exprs)

    # Close relative to the day's mean, then close-to-close change.
    close_vs_mean = [
        (pl.col(f"{rep}_prevday_close") - pl.col(f"{rep}_prevday_mean")).alias(f"{rep}_prevday_close_minus_mean")
        for rep in rep_cols
    ]
    overnight = [
        (pl.col(f"{rep}_prevday_close") - pl.col(f"{rep}_prev2day_close")).alias(f"{rep}_overnight_gap")
        for rep in rep_cols
    ]
    daily = daily.with_columns(close_vs_mean + overnight)

    # Rolling statistics of the daily close (windows of size <= 1 are ignored).
    if rolling_windows:
        widths = sorted({int(w) for w in rolling_windows if int(w) > 1})
        rolling_exprs = []
        for rep in rep_cols:
            close = pl.col(f"{rep}_prevday_close")
            for width in widths:
                rolling_exprs.append(
                    close.rolling_mean(window_size=width, min_samples=1)
                         .over(sym_key)
                         .alias(f"{rep}_close_roll{width}_mean")
                )
                rolling_exprs.append(
                    close.rolling_std(window_size=width, ddof=0, min_samples=1)
                         .over(sym_key)
                         .alias(f"{rep}_close_roll{width}_std")
                )
        daily = daily.with_columns(rolling_exprs)

    # Key step: turn every non-key daily column into the value that applies to
    # day d from an earlier day (strict d-1, plus the optional <=K-day fallback).
    non_key_cols = [
        name for name in daily.collect_schema().names()
        if name not in (sym_key, day_key)
    ]
    daily_prev = _resolve_prev_for_daily(
        daily,
        keys=(sym_key, day_key),
        cols=non_key_cols,
        prev_soft_days=prev_soft_days,
        cast_f32=cast_f32,
    )

    # Broadcast the daily values back onto tick-level rows.
    return lf.join(daily_prev, on=[sym_key, day_key], how="left")


# -----------------------------
# B: same time_id cross-day
# -----------------------------
import polars as pl
import numpy as np
from typing import Sequence, Iterable, Optional, Tuple

def fe_same_timeid_crossday(
    lf: pl.LazyFrame,
    rep_cols: Sequence[str],
    ndays: int = 3,
    stats_rep_cols: Optional[Sequence[str]] = None,
    add_prev1_multirep: bool = True,
    batch_size: int = 5,
    keys: Tuple[str,str,str] = ("symbol_id","date_id","time_id"),
    assume_sorted: bool = True,
    cast_f32: bool = True,
    # Controls for "strict d-k" vs. "soft K-day" history filtering
    prev_soft_days: Optional[int] = None,   # None = no limit; e.g. 5 keeps prev{k} only when the date gap is <= 5
    strict_k: bool = False,                 # True = keep prev{k} only when the date gap == k; prev_soft_days is ignored
) -> pl.LazyFrame:
    """
    Same-`time_id` cross-day history: prev{k} values plus mean / std / slope.

    For every column in `rep_cols` and k = 1..ndays, adds `{r}_same_t_prev{k}`:
    the value shifted k rows back within each (symbol, time_id) group.
    For every column in `stats_rep_cols` that is also in `rep_cols`, adds
    `{r}_same_t_last{ndays}_mean`, `_std` (nulls ignored, ddof=0) and `_slope`.
    Optionally adds the row-wise mean / std across all `{r}_same_t_prev1`.

    Args:
        lf: input frame containing the key columns and `rep_cols`.
        rep_cols: columns to build history for.
        ndays: number of previous same-time observations to expose.
        stats_rep_cols: subset to compute statistics for (default: `rep_cols`).
        add_prev1_multirep: add `prev1_same_t_mean_{n}rep` / `prev1_same_t_std_{n}rep`.
        batch_size: number of columns handled per `with_columns` call.
        keys: (symbol, date, time) column names, in that order.
        assume_sorted: if False, the frame is sorted by (symbol, time, date)
            first so that the row shift walks back through dates.
        cast_f32: cast every generated column to Float32.
        prev_soft_days: if set, prev{k} is nulled when
            `date - date.shift(k)` exceeds this many days.
        strict_k: if True, prev{k} is kept only when the gap equals k.

    Returns:
        The input LazyFrame with the new columns appended.
    """
    sym_key, date_key, time_key = keys
    if not assume_sorted:
        # Dates must increase within each (symbol, time) group for the shift to be causal.
        lf = lf.sort([sym_key, time_key, date_key])

    if stats_rep_cols is None:
        stats_rep_cols = list(rep_cols)

    slot = [sym_key, time_key]
    day_offsets = range(1, ndays + 1)

    def _batched(items):
        # Consecutive slices of at most `batch_size` items.
        return (items[pos:pos + batch_size] for pos in range(0, len(items), batch_size))

    def _maybe_f32(expr: pl.Expr) -> pl.Expr:
        return expr.cast(pl.Float32) if cast_f32 else expr

    def _prev_value(col: str, k: int) -> pl.Expr:
        # Value k rows back in the same (symbol, time) slot, optionally gap-filtered.
        shifted = pl.col(col).shift(k).over(slot)
        if strict_k or prev_soft_days is not None:
            gap = (pl.col(date_key) - pl.col(date_key).shift(k).over(slot)).cast(pl.Int32)
            keep = (gap == k) if strict_k else (gap <= int(prev_soft_days))
            shifted = pl.when(keep).then(shifted).otherwise(None)
        return _maybe_f32(shifted)

    out = lf

    # 1) prev{k} columns
    for group in _batched(list(rep_cols)):
        out = out.with_columns([
            _prev_value(r, k).alias(f"{r}_same_t_prev{k}")
            for r in group
            for k in day_offsets
        ])

    stat_targets = [r for r in stats_rep_cols if r in rep_cols]

    # 2) mean / std over the available (non-null) prev values
    for group in _batched(stat_targets):
        stat_exprs = []
        for r in group:
            history = pl.concat_list(
                [pl.col(f"{r}_same_t_prev{k}") for k in day_offsets]
            ).list.drop_nulls()
            stat_exprs.append(_maybe_f32(history.list.mean()).alias(f"{r}_same_t_last{ndays}_mean"))
            stat_exprs.append(_maybe_f32(history.list.std(ddof=0)).alias(f"{r}_same_t_last{ndays}_std"))
        out = out.with_columns(stat_exprs)

    # 3) slope: standardized values dotted with the standardized lag index,
    #    divided by the number of non-null prev values (1 when there are none)
    lag_axis = np.arange(1, ndays+1, dtype=np.float64)
    lag_axis = (lag_axis - lag_axis.mean()) / (lag_axis.std() + 1e-9)
    axis_lits = [pl.lit(float(v)) for v in lag_axis]

    for group in _batched(stat_targets):
        slope_exprs = []
        for r in group:
            prev_names = [f"{r}_same_t_prev{k}" for k in day_offsets]
            mu = pl.col(f"{r}_same_t_last{ndays}_mean")
            sigma = pl.col(f"{r}_same_t_last{ndays}_std")
            weighted = [
                ((pl.col(name) - mu) / (sigma + 1e-9)) * weight
                for name, weight in zip(prev_names, axis_lits)
            ]
            valid_n = pl.sum_horizontal(
                [pl.col(name).is_not_null().cast(pl.Int32) for name in prev_names]
            ).cast(pl.Float32)
            divisor = pl.when(valid_n > 0).then(valid_n).otherwise(pl.lit(1.0))
            slope_exprs.append(
                _maybe_f32(pl.sum_horizontal(weighted) / divisor).alias(f"{r}_same_t_last{ndays}_slope")
            )
        out = out.with_columns(slope_exprs)

    # 4) row-wise statistics of prev1 across all responders
    if add_prev1_multirep and len(rep_cols) > 0:
        n_rep = len(rep_cols)
        latest = pl.concat_list(
            [pl.col(f"{r}_same_t_prev1") for r in rep_cols]
        ).list.drop_nulls()
        out = out.with_columns([
            _maybe_f32(latest.list.mean()).alias(f"prev1_same_t_mean_{n_rep}rep"),
            _maybe_f32(latest.list.std(ddof=0)).alias(f"prev1_same_t_std_{n_rep}rep"),
        ])

    return out



# C 系列：历史值、收益率、差分、极值归一化、指数加权均值等

def build_history_features_polars(
    *,
    lf: Optional[pl.LazyFrame] = None,
    paths: Optional[Sequence[str]] = None,
    feature_cols: Sequence[str],
    keys: Tuple[str,str,str] = ("symbol_id","date_id","time_id"),
    group_cols: Sequence[str] = ("symbol_id",),
    assume_sorted: bool = False,
    cast_f32: bool = True,
    batch_size: Optional[int] = 10,
    lags: Iterable[int] = (1, 3),
    ret_periods: Iterable[int] = (1,),
    diff_periods: Iterable[int] = (1,),
    rz_windows: Iterable[int] = (5,),
    keep_rmean_rstd: bool = True,
    ewm_spans: Iterable[int] = (10,),
    cs_cols: Optional[Sequence[str]] = None,
    cs_by: Sequence[str] = ("date_id","time_id"),
    prev_soft_days: Optional[int] = None,
) -> pl.LazyFrame:
    """
    Build row-history features (lags, returns, diffs, rolling z-score, EWM,
    cross-sectional rank) for `feature_cols`.

    All shifts are row shifts within `group_cols`; every rolling / EWM / rank
    statistic is computed on the lag-1 value, not the current row.

    Generated columns (each family is skipped when its parameter is None / empty):
        {c}__lag{L}      value shifted L rows back
        {c}__ret{k}      cur / prev_k - 1, null when prev_k is null or 0
        {c}__diff{k}     cur - prev_k
        {c}__rmean{w}, {c}__rstd{w}, {c}__rz{w}
                         rolling mean / std (ddof=0) / z-score of the lag-1 value;
                         mean and std are emitted only if `keep_rmean_rstd`
        {c}__ewm{s}      exponentially weighted mean (adjust=False) of the lag-1 value
        {c}__csrank      average rank of the lag-1 value inside each `cs_by`
                         group divided by the number of rows in that group

    Args:
        lf: source frame; if None, `paths` is scanned as parquet instead.
        paths: parquet paths, used only when `lf` is None.
        feature_cols: columns to derive features from.
        keys: (symbol, date, time) column names; also the sort order.
        group_cols: grouping columns for all shifts and rolling windows.
        assume_sorted: if False, the frame is sorted by `keys` first.
        cast_f32: cast generated columns to Float32.
        batch_size: number of feature columns per `with_columns` call;
            None or a value < 1 processes all columns in one batch.
        lags, ret_periods, diff_periods, rz_windows, ewm_spans: positive ints;
            duplicates and values < 1 are dropped.
        cs_cols: columns to rank cross-sectionally (those not in
            `feature_cols` are ignored).
        cs_by: grouping columns for the cross-sectional rank.
        prev_soft_days: if set, a shifted value is nulled when the date of the
            current row minus the date of the shifted row exceeds this value.

    Returns:
        LazyFrame with `keys`, `feature_cols` and the generated columns.

    Raises:
        AssertionError: neither `lf` nor `paths` was given.
        KeyError: a key or feature column is missing from the source schema.
    """
    assert lf is not None or paths is not None, "provide either lf or paths"
    if lf is None:
        lf = pl.scan_parquet(list(paths))

    # Unpacking also enforces that `keys` has exactly three entries.
    _, g_date, _ = keys
    by_grp = list(group_cols)
    by_cs  = list(cs_by)
    feature_cols = list(feature_cols)

    need_cols = [*keys, *feature_cols]
    schema = lf.collect_schema()
    miss = [c for c in need_cols if c not in schema]
    if miss:
        raise KeyError(f"Columns not found: {miss}")

    lf_out = lf.select(need_cols)
    if not assume_sorted:
        lf_out = lf_out.sort(list(keys))

    # StageC declares batch_size as Optional; None / non-positive means "one batch".
    if batch_size is None or int(batch_size) < 1:
        chunk = max(len(feature_cols), 1)
    else:
        chunk = int(batch_size)

    def _chunks(lst, k):
        for i in range(0, len(lst), k):
            yield lst[i:i+k]

    # Normalize parameters: None/[] -> (), cast to int, keep positives, dedupe, sort.
    def _clean_pos_sorted_unique(x):
        if not x:
            return tuple()
        return tuple(sorted({int(v) for v in x if int(v) >= 1}))

    LAGS   = _clean_pos_sorted_unique(lags)
    K_RET  = _clean_pos_sorted_unique(ret_periods)
    K_DIFF = _clean_pos_sorted_unique(diff_periods)
    RZW    = _clean_pos_sorted_unique(rz_windows)
    SPANS  = _clean_pos_sorted_unique(ewm_spans)

    # ---- TTL helpers ----
    def _ttl_mask(k: int) -> pl.Expr:
        # True where the row k steps back is at most `prev_soft_days` dates old.
        if prev_soft_days is None:
            return pl.lit(True)
        return (pl.col(g_date) - pl.col(g_date).shift(k).over(by_grp)) <= int(prev_soft_days)

    def _gate(expr: pl.Expr, k: int) -> pl.Expr:
        # Null out `expr` where the TTL mask for shift k fails.
        if prev_soft_days is None:
            return expr
        return pl.when(_ttl_mask(k)).then(expr).otherwise(None)

    # C1 lags (optional)
    if LAGS:
        for batch in _chunks(feature_cols, chunk):
            exprs = []
            for L in LAGS:
                for c in batch:
                    e = pl.col(c).shift(L).over(by_grp)
                    e = _gate(e, L)
                    if cast_f32: e = e.cast(pl.Float32)
                    exprs.append(e.alias(f"{c}__lag{L}"))
            lf_out = lf_out.with_columns(exprs)

    # C2 returns (optional)
    if K_RET:
        for batch in _chunks(feature_cols, chunk):
            exprs = []
            for c in batch:
                cur = pl.col(c)
                for k in K_RET:
                    # Reuse the lag column when C1 already produced it.
                    prev = (pl.col(f"{c}__lag{k}") if k in LAGS
                            else pl.col(c).shift(k).over(by_grp))
                    mask = _ttl_mask(k)
                    ret = pl.when(mask & prev.is_not_null() & (prev != 0)).then(cur / prev - 1.0).otherwise(None)
                    if cast_f32: ret = ret.cast(pl.Float32)
                    exprs.append(ret.alias(f"{c}__ret{k}"))
            lf_out = lf_out.with_columns(exprs)

    # C3 diffs (optional)
    if K_DIFF:
        for batch in _chunks(feature_cols, chunk):
            exprs = []
            for c in batch:
                cur = pl.col(c)
                for k in K_DIFF:
                    prevk = pl.col(c).shift(k).over(by_grp)
                    d = pl.when(_ttl_mask(k)).then(cur - prevk).otherwise(None)
                    if cast_f32: d = d.cast(pl.Float32)
                    exprs.append(d.alias(f"{c}__diff{k}"))
            lf_out = lf_out.with_columns(exprs)

    # A temporary lag-1 column is built only when RZ/EWM need it and C1 did not emit lag 1.
    need_tmp_lag1 = (bool(RZW) or bool(SPANS)) and (1 not in LAGS)
    if need_tmp_lag1:
        for batch in _chunks(feature_cols, chunk):
            lf_out = lf_out.with_columns([
                pl.col(c).shift(1).over(by_grp).alias(f"{c}__lag1_tmp") for c in batch
            ])

    def _lag1_col(c: str) -> pl.Expr:
        return pl.col(f"{c}__lag1") if 1 in LAGS else pl.col(f"{c}__lag1_tmp")

    # C4 rolling mean / std / z-score (optional)
    if RZW:
        for batch in _chunks(feature_cols, chunk):
            base_cols = []
            for c in batch:
                base_alias = f"{c}__tminus1_base"
                base_cols.append(_gate(_lag1_col(c), 1).alias(base_alias))
            lf_out = lf_out.with_columns(base_cols)

            roll_exprs = []
            for c in batch:
                base = pl.col(f"{c}__tminus1_base")
                for w in RZW:
                    m  = base.rolling_mean(window_size=w, min_samples=1).over(by_grp)
                    s  = base.rolling_std(window_size=w, ddof=0, min_samples=2).over(by_grp)
                    rz = (base - m) / (s + 1e-9)
                    if cast_f32:
                        m = m.cast(pl.Float32); s = s.cast(pl.Float32); rz = rz.cast(pl.Float32)
                    if keep_rmean_rstd:
                        roll_exprs += [
                            m.alias(f"{c}__rmean{w}"),
                            s.alias(f"{c}__rstd{w}"),
                            rz.alias(f"{c}__rz{w}"),
                        ]
                    else:
                        roll_exprs.append(rz.alias(f"{c}__rz{w}"))
            lf_out = lf_out.with_columns(roll_exprs)
            lf_out = lf_out.drop([f"{c}__tminus1_base" for c in batch])

    # C5 EWM (optional)
    if SPANS:
        for batch in _chunks(feature_cols, chunk):
            base_cols = []
            for c in batch:
                base_alias = f"{c}__tminus1_base"
                base_cols.append(_gate(_lag1_col(c), 1).alias(base_alias))
            lf_out = lf_out.with_columns(base_cols)

            ewm_exprs = []
            for c in batch:
                base = pl.col(f"{c}__tminus1_base")
                for s in SPANS:
                    ema = base.ewm_mean(span=int(s), adjust=False, ignore_nulls=True).over(by_grp)
                    if cast_f32: ema = ema.cast(pl.Float32)
                    ewm_exprs.append(ema.alias(f"{c}__ewm{s}"))
            lf_out = lf_out.with_columns(ewm_exprs)
            lf_out = lf_out.drop([f"{c}__tminus1_base" for c in batch])

    if need_tmp_lag1:
        lf_out = lf_out.drop([f"{c}__lag1_tmp" for c in feature_cols])

    # C6 cross-sectional rank (optional)
    if cs_cols:
        cs_cols = [c for c in cs_cols if c in feature_cols]
        if cs_cols:
            lf_out = lf_out.with_columns([
                _gate(pl.col(c).shift(1).over(by_grp), 1).alias(f"{c}__tminus1_base")
                for c in cs_cols
            ])
            # Row count of every cross-section (includes rows whose base is null).
            cs_n_tbl = (
                lf_out.select(list(by_cs))
                      .group_by(list(by_cs))
                      .len()
                      .rename({"len": "__cs_n"})
            )
            lf_out = lf_out.join(cs_n_tbl, on=list(by_cs), how="left")
            lf_out = lf_out.with_columns([
                pl.when(pl.col(f"{c}__tminus1_base").is_null())
                  .then(None)
                  .otherwise(pl.col(f"{c}__tminus1_base").rank(method="average"))
                  .over(list(by_cs))
                  .alias(f"{c}__cs_rank_raw")
                for c in cs_cols
            ])
            lf_out = lf_out.with_columns([
                pl.when(pl.col(f"{c}__tminus1_base").is_null())
                  .then(None)
                  .otherwise(pl.col(f"{c}__cs_rank_raw") / pl.col("__cs_n"))
                  .cast(pl.Float32 if cast_f32 else pl.Float64)
                  .alias(f"{c}__csrank")
                for c in cs_cols
            ])
            # Drop every helper column, including the joined group size, so
            # that only real features reach the output.
            lf_out = lf_out.drop(
                [f"{c}__tminus1_base" for c in cs_cols] +
                [f"{c}__cs_rank_raw"  for c in cs_cols] +
                ["__cs_n"]
            )

    return lf_out

def collect_write(lf: pl.LazyFrame, path: str, compression: str = "zstd"):
    """
    Materialize `lf` and write it to `path` as a parquet file.

    The parent directory of `path` is created if needed. Collection is tried
    with `streaming=True` first and falls back to a plain `collect()` when
    that call raises TypeError.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    try:
        frame = lf.collect(streaming=True)
    except TypeError:
        frame = lf.collect()
    frame.write_parquet(path, compression=compression)
    
@dataclass
class StageA:
    """Parameters for stage A; each field is forwarded by
    `run_staged_engineering` to `fe_prevday_tail_and_summaries`."""
    # Forwarded as `tail_lags`.
    # NOTE(review): presumably the tail lags behind the `{r}_prev_tail_lag{n}` columns — confirm.
    tail_lags: Sequence[int]
    # For each K, a `{r}_prev_tail_d{K}` column = prev_tail_lag1 - prev_tail_lag{K+1} is derived.
    tail_diffs: Sequence[int]
    # Windows (> 1) for the rolling mean/std of `{r}_prevday_close`; None/empty skips them.
    rolling_windows: Optional[Sequence[int]]
    # Forwarded to the previous-day resolution step.
    # NOTE(review): appears to cap how many days back a fallback value may come from — confirm.
    prev_soft_days: Optional[int] = None
    assume_sorted: bool = True
    cast_f32: bool = True

@dataclass
class StageB:
    """Parameters for stage B; each field is forwarded by
    `run_staged_engineering` to `fe_same_timeid_crossday` under the same name."""
    # Number of previous same-time_id observations (prev1..prev{ndays}).
    ndays: int
    # Columns to compute mean/std/slope for; None means all responder columns.
    stats_rep_cols: Optional[Sequence[str]] = None
    # Add the row-wise mean/std of prev1 across all responder columns.
    add_prev1_multirep: bool = True
    # Number of columns handled per with_columns call.
    batch_size: int = 5
    # If set, prev{k} is nulled when the date gap exceeds this value.
    prev_soft_days: Optional[int] = None
    # If True, prev{k} is kept only when the date gap equals k (prev_soft_days is ignored).
    strict_k: bool = False
    # If False, the frame is sorted by (symbol, time, date) before shifting.
    assume_sorted: bool = True
    # Cast generated columns to Float32.
    cast_f32: bool = True

# Every stage C operation is optional; None / [] skips that operation.
@dataclass
class StageC:
    """Parameters for stage C; `run_staged_engineering` runs
    `build_history_features_polars` once per non-empty operation below."""
    # Row shifts for the `__lag{L}` columns.
    lags: Optional[Iterable[int]] = None
    # Row shifts for the `__ret{k}` columns.
    ret_periods: Optional[Iterable[int]] = None
    # Row shifts for the `__diff{k}` columns.
    diff_periods: Optional[Iterable[int]] = None
    # Window sizes for the rolling `__rz{w}` (and optional `__rmean{w}` / `__rstd{w}`) columns.
    rz_windows: Optional[Iterable[int]] = None
    # Spans for the `__ewm{s}` columns.
    ewm_spans: Optional[Iterable[int]] = None
    # Also emit the rolling mean/std next to the z-score.
    keep_rmean_rstd: bool = True
    # Columns to rank cross-sectionally (`__csrank`).
    cs_cols: Optional[Sequence[str]] = None
    # Grouping columns for the cross-sectional rank.
    cs_by: Sequence[str] = ("date_id","time_id")
    # If set, shifted values older than this many dates are nulled.
    prev_soft_days: Optional[int] = None
    # Number of feature columns per with_columns call.
    batch_size: Optional[int] = 10
    assume_sorted: bool = True
    cast_f32: bool = True


def run_staged_engineering(
    lf_base: pl.LazyFrame,
    *,
    keys: Sequence[str],
    rep_cols: Sequence[str],
    feature_cols: Sequence[str],
    out_dir: str,
    A: StageA | None = None,
    B: StageB | None = None,
    C: StageC | None = None,
):
    """
    Run the requested feature-engineering stages and write each result to parquet.

    Every output file contains `keys` plus the columns generated by that stage;
    the source responder / feature columns are dropped.

    Args:
        lf_base: base frame with `keys`, `rep_cols` and `feature_cols`.
        keys: (symbol, date, time) column names, in that order.
        rep_cols: responder columns consumed by stages A and B.
        feature_cols: feature columns consumed by stage C.
        out_dir: local output directory (created if missing).
        A, B, C: stage parameters; a stage is skipped when its argument is None.

    Returns:
        dict with optional entries
            "A": path of stage_a.parquet
            "B": path of stage_b.parquet
            "C": {op_name: [path]} for each executed stage C operation
                 (lags, ret, diff, rz, ewm, csrank).
    """
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    paths = {}

    # ---------- A ----------
    if A is not None:
        lf_resp = lf_base.select([*keys, *rep_cols])
        lf_a_full = fe_prevday_tail_and_summaries(
            lf_resp,
            rep_cols=rep_cols,
            tail_lags=A.tail_lags,
            tail_diffs=A.tail_diffs,
            rolling_windows=A.rolling_windows,
            keys=tuple(keys),
            assume_sorted=A.assume_sorted,
            cast_f32=A.cast_f32,
            prev_soft_days=A.prev_soft_days,
        )
        drop = set(keys) | set(rep_cols)
        a_cols = [c for c in lf_a_full.collect_schema().names() if c not in drop]
        pA = f"{out_dir}/stage_a.parquet"
        collect_write(lf_a_full.select([*keys, *a_cols]), pA)
        paths["A"] = pA

    # ---------- B ----------
    if B is not None:
        lf_resp = lf_base.select([*keys, *rep_cols])
        lf_b_full = fe_same_timeid_crossday(
            lf_resp,
            rep_cols=rep_cols,
            ndays=B.ndays,
            stats_rep_cols=B.stats_rep_cols,
            add_prev1_multirep=B.add_prev1_multirep,
            batch_size=B.batch_size,
            keys=tuple(keys),
            assume_sorted=B.assume_sorted,
            cast_f32=B.cast_f32,
            prev_soft_days=B.prev_soft_days,
            strict_k=B.strict_k,
        )
        drop = set(keys) | set(rep_cols)
        b_cols = [c for c in lf_b_full.collect_schema().names() if c not in drop]
        pB = f"{out_dir}/stage_b.parquet"
        collect_write(lf_b_full.select([*keys, *b_cols]), pB)
        paths["B"] = pB

    # ---------- C: one output file per operation (not split by window or batch) ----------
    if C is not None:
        paths["C"] = {}

        def _do_op(op_name: str, **op_flags):
            # Only the operation named in op_flags is enabled; the others get None.
            lf_src = lf_base.select([*keys, *feature_cols])
            lf_d = build_history_features_polars(
                lf=lf_src,
                feature_cols=feature_cols,
                keys=tuple(keys),
                # Group by the symbol key taken from `keys`, as stages A and B do,
                # instead of a hard-coded column name.
                group_cols=(keys[0],),
                assume_sorted=C.assume_sorted,
                cast_f32=C.cast_f32,
                batch_size=C.batch_size,
                lags=op_flags.get("lags"),
                ret_periods=op_flags.get("ret_periods"),
                diff_periods=op_flags.get("diff_periods"),
                rz_windows=op_flags.get("rz_windows"),
                keep_rmean_rstd=C.keep_rmean_rstd,
                ewm_spans=op_flags.get("ewm_spans"),
                cs_cols=op_flags.get("cs_cols"),
                cs_by=C.cs_by,
                prev_soft_days=C.prev_soft_days
            ).drop(feature_cols)
            p = f"{out_dir}/stage_c_{op_name}.parquet"
            collect_write(lf_d, p)
            paths["C"][op_name] = [p]

        if C.lags:         _do_op("lags",   lags=C.lags)
        if C.ret_periods:  _do_op("ret",    ret_periods=C.ret_periods)
        if C.diff_periods: _do_op("diff",   diff_periods=C.diff_periods)
        if C.rz_windows:   _do_op("rz",     rz_windows=C.rz_windows)
        if C.ewm_spans:    _do_op("ewm",    ewm_spans=C.ewm_spans)
        if C.cs_cols:      _do_op("csrank", cs_cols=C.cs_cols)

    return paths

def _join_from(lf_left: pl.LazyFrame, path: str, lo: int, hi: int) -> pl.LazyFrame:
    """
    Left-join onto `lf_left` the columns registered for `path` in FILE2COLS,
    restricted to rows with `lo <= date_id <= hi`.

    Returns `lf_left` unchanged when no columns are registered for `path`.
    """
    wanted = FILE2COLS.get(path, [])
    if wanted:
        in_window = pl.col("date_id").is_between(lo, hi)
        addon = (
            pl.scan_parquet(path)
              .select([*KEYS, *wanted])
              .filter(in_window)
              .sort(KEYS)
        )
        return lf_left.join(addon, on=KEYS, how="left")
    return lf_left

def build_slice(lo: int, hi: int) -> pd.DataFrame:
    """
    Assemble the rows with `lo <= date_id <= hi` as a pandas DataFrame:
    keys and target from PARQUET_PATHS, then the columns of every file
    listed in FILE2COLS joined on KEYS.
    """
    in_window = pl.col("date_id").is_between(lo, hi)
    assembled = (
        pl.scan_parquet(PARQUET_PATHS)
          .select([*KEYS, TARGET])
          .filter(in_window)
          .sort(KEYS)
    )
    for source in FILE2COLS:
        assembled = _join_from(assembled, source, lo, hi)
    collected = assembled.collect(streaming=True)
    return collected.to_pandas()

def weighted_r2_zero_mean(y_true, y_pred, weight) -> float:
    """
    Sample-weighted zero-mean R^2:

        R^2 = 1 - sum_i w_i (y_i - yhat_i)^2 / sum_i w_i y_i^2

    Inputs are flattened to 1-D float64 arrays and must have equal length.
    Returns 0.0 when the denominator is not positive.
    """
    truth, pred, wts = (
        np.asarray(arr, dtype=np.float64).ravel()
        for arr in (y_true, y_pred, weight)
    )
    assert truth.shape == pred.shape == wts.shape

    residual = truth - pred
    weighted_sq_err = np.sum(wts * residual ** 2)
    weighted_sq_tot = np.sum(wts * (truth ** 2))
    if weighted_sq_tot <= 0:
        # Degenerate case: no weighted signal in y_true.
        return 0.0
    return 1.0 - (weighted_sq_err / weighted_sq_tot)

def lgb_wr2_eval(preds, train_data):
    """
    Custom evaluation function in the (name, value, is_higher_better) form:
    weighted zero-mean R^2 of `preds` against the dataset's labels.
    Unit weights are used when the dataset carries no weights.
    """
    labels = train_data.get_label()
    weights = train_data.get_weight()
    weights = np.ones_like(labels) if weights is None else weights
    # Third element True: a higher score is better.
    return ('wr2', weighted_r2_zero_mean(labels, preds, weights), True)


# 特征选择- 初选 (选特征，省略)

In [None]:
import os, gc, glob
import polars as pl
import numpy as np
import lightgbm as lgb
import pandas as pd
from pathlib import Path

# Configuration for the first-pass feature selection.
# NOTE: KEYS / TARGET / WEIGHT / FEATURE_COLS rebind names that were set from
# the YAML config at the top of the notebook.

# Local parquet file(s) holding the cleaned base data.
BASE_PATH = ["/mnt/data/js/clean/final_clean.parquet"]
# Row-identifying key columns (also the join keys between stage files).
KEYS = ["symbol_id","date_id","time_id"]
# Column used as the training label.
TARGET = "responder_6"
# Column used as the sample weight.
WEIGHT = 'weight'
FEATURE_COLS = [f"feature_{i:02d}" for i in range(79)]
REP_COLS = [f"responder_{i}" for i in range(9)]

# Directory that receives the stage_*.parquet outputs of the first selection.
OUT_DIR = "/mnt/data/js/cache/first_selection"
os.makedirs(OUT_DIR, exist_ok=True)


In [None]:

# --- A: prev-day tails + daily summaries ---
A = StageA(
    tail_lags=(1,),
    tail_diffs=(1,),
    rolling_windows=(5,),
    prev_soft_days=3,          # allow fallback up to 3 calendar days
)

# --- B: same time_id cross-day ---
B = StageB(
    ndays=3,                   # prev{1..3} at same time_id
    stats_rep_cols=None,       # default: use rep_cols
    add_prev1_multirep=True,
    batch_size=5,
    prev_soft_days=3,          # TTL for gaps
    strict_k=False,            # allow ≤K-day gaps instead of strict d-k
)

# --- C: history features (keep it tiny) ---
C = StageC(
    lags=(1, ),
    ret_periods=(1,),
    diff_periods=(1,),
    rz_windows=(5,),
    ewm_spans=(10,),
    keep_rmean_rstd=True,
    cs_cols=None,        # must be subset of feature_cols
    cs_by=("date_id","time_id"),
    prev_soft_days=3,
)

# Bind the base LazyFrame here so this cell runs on a fresh kernel; the
# notebook otherwise only assigns `lf_base` in a later cell.
lf_base = pl.scan_parquet(BASE_PATH)

# Run all three stages; `paths` maps stage name -> written parquet file(s).
paths = run_staged_engineering(
    lf_base=lf_base,
    keys=KEYS,
    rep_cols=REP_COLS,
    feature_cols=FEATURE_COLS,
    out_dir=OUT_DIR,
    A=A, B=B, C=C,
)


0. 准备与切分天数

In [None]:

# Stage outputs written by run_staged_engineering into the first_selection directory.
STAGE_PATHS = [
    "/mnt/data/js/cache/first_selection/stage_a.parquet",
    "/mnt/data/js/cache/first_selection/stage_b.parquet",
    "/mnt/data/js/cache/first_selection/stage_c_lags.parquet",
    "/mnt/data/js/cache/first_selection/stage_c_ret.parquet",
    "/mnt/data/js/cache/first_selection/stage_c_diff.parquet",
    "/mnt/data/js/cache/first_selection/stage_c_rz.parquet",
    "/mnt/data/js/cache/first_selection/stage_c_ewm.parquet",
]

# Inclusive date_id range used for this selection run.
DATE_LO, DATE_HI = 1200, 1400
# NOTE: OUT_DIR is rebound here; from this cell on it points at the run directory.
OUT_DIR = "/mnt/data/js/cache/first_selection/run_full"
SHARD_DIR = f"{OUT_DIR}/shards_all"
Path(SHARD_DIR).mkdir(parents=True, exist_ok=True)

lf_base = pl.scan_parquet(BASE_PATH)
# Restrict the base frame to the target date range
lf_range = lf_base.filter(pl.col("date_id").is_between(DATE_LO, DATE_HI))

# Distinct days in the range, split chronologically: first 80% train, rest validation
days = (lf_range.select(pl.col("date_id").unique().sort())
                .collect(streaming=True)["date_id"].to_list())
cut = int(len(days) * 0.8)
train_days, val_days = days[:cut], days[cut:]
# NOTE(review): days[0] / days[-1] raise IndexError when the range contains no rows.
print(f"[split] train {len(train_days)} days, val {len(val_days)} days, range={days[0]}..{days[-1]}")


1. 收集全量特征列名（并集）

In [None]:
# Start from the raw feature columns of the base table.
feat_cols = set(FEATURE_COLS)

# Add every column of every existing stage file, except keys, target and weight.
for p in STAGE_PATHS:
    if os.path.exists(p):
        names = pl.scan_parquet(p).collect_schema().names()
        feat_cols.update(
            c for c in names
            if c not in KEYS and c not in (TARGET, WEIGHT)
        )
    else:
        print(f"[skip] missing: {p}")

# Fix a deterministic (alphabetical) column order.
feat_cols = sorted(feat_cols)
print(f"[cols] total feature columns = {len(feat_cols)}")


2. 写“天片”—把所有列拼上并立刻落盘

In [None]:
DAYS_PER_SHARD = 16

# Left table: keys, target, weight and the raw FEATURE_COLS of the date range.
lf_left_base = lf_range.select([*KEYS, TARGET, WEIGHT, *FEATURE_COLS])

# Per-stage metadata (non-key column names + file size); smaller files are
# joined first to keep memory low.
stage_meta = []
for p in STAGE_PATHS:
    if os.path.exists(p):
        scan = pl.scan_parquet(p).filter(pl.col("date_id").is_between(DATE_LO, DATE_HI))
        cols = [c for c in scan.collect_schema().names() if c not in KEYS]
        if cols:
            stage_meta.append(dict(path=p, cols=cols, size=os.path.getsize(p)))
stage_meta.sort(key=lambda meta: meta["size"])


In [None]:
def write_shards(tag, days_list):
    """
    Write one parquet shard per DAYS_PER_SHARD consecutive days.

    For each shard, the base rows of those days are left-joined on KEYS with
    the columns of every stage file in `stage_meta` (a column name is taken
    from the first file that provides it). Weight, target and features are
    cast to Float32 and written, features in `feat_cols` order, to
    `{SHARD_DIR}/{tag}_shard_{index:04d}.parquet`. Columns of `feat_cols`
    that no source provides are left out of the shard.
    """
    ordered_days = sorted(days_list)
    for shard_no, start in enumerate(range(0, len(ordered_days), DAYS_PER_SHARD)):
        day_set = set(ordered_days[start:start + DAYS_PER_SHARD])
        in_shard = pl.col("date_id").is_in(day_set)

        # Base rows of this shard.
        lf_chunk = lf_left_base.filter(in_shard)
        have = set(lf_chunk.collect_schema().names())

        # Join each stage file, restricted to this shard's days and to columns not yet present.
        for meta in stage_meta:
            missing = [c for c in meta["cols"] if c not in have]
            if missing:
                extra = (
                    pl.scan_parquet(meta["path"])
                      .filter(in_shard)
                      .select([*KEYS, *missing])
                )
                lf_chunk = lf_chunk.join(extra, on=KEYS, how="left")
                have.update(missing)

        # Uniform Float32 output in feat_cols order.
        feature_exprs = [
            pl.col(c).cast(pl.Float32).alias(c)
            for c in feat_cols
            if c in have
        ]
        shard = lf_chunk.select([
            *KEYS,
            pl.col(WEIGHT).cast(pl.Float32).alias(WEIGHT),
            pl.col(TARGET).cast(pl.Float32).alias(TARGET),
            *feature_exprs,
        ])
        out_path = f"{SHARD_DIR}/{tag}_shard_{shard_no:04d}.parquet"
        shard.sink_parquet(out_path, compression="zstd")
        print(f"[{tag}] wrote {out_path}")
        gc.collect()

# Materialize the train and validation day shards under SHARD_DIR.
write_shards("train", train_days)
write_shards("val",   val_days)

3. 从 shards 构建 memmap 数组 （恒定内存）

In [None]:
def memmap_from_shards(glob_pat, feat_cols, prefix):
    """
    Copy parquet shards into float32 memory-mapped arrays, one shard at a time.

    Args:
        glob_pat: glob pattern of the shard files; files are processed in
            sorted path order.
        feat_cols: feature column names, defining the column order of X.
            Columns missing from a shard are filled with null.
        prefix: path prefix for the output files `{prefix}_X.float32.mmap`,
            `{prefix}_y.float32.mmap` and `{prefix}_w.float32.mmap`.

    Returns:
        (X, y, w) np.memmap arrays of shape (n_rows, n_feat), (n_rows,), (n_rows,),
        holding features, TARGET and WEIGHT.

    Raises:
        FileNotFoundError: no file matches `glob_pat`.
    """
    paths = sorted(glob.glob(glob_pat))
    if not paths:
        # Fail before any output file is created; a zero-row memmap cannot be
        # built and the underlying error would not point at the real cause.
        raise FileNotFoundError(f"no shard files match {glob_pat!r}")

    counts = [pl.scan_parquet(p).select(pl.len()).collect(streaming=True).item() for p in paths]
    n_rows, n_feat = int(sum(counts)), len(feat_cols)
    print(f"[memmap] {glob_pat}: {len(paths)} files, {n_rows} rows, {n_feat} features")

    X = np.memmap(f"{prefix}_X.float32.mmap", dtype="float32", mode="w+", shape=(n_rows, n_feat))
    y = np.memmap(f"{prefix}_y.float32.mmap", dtype="float32", mode="w+", shape=(n_rows,))
    w = np.memmap(f"{prefix}_w.float32.mmap", dtype="float32", mode="w+", shape=(n_rows,))

    i = 0  # write offset (row index) into the memmaps
    for p, k in zip(paths, counts):
        lf = pl.scan_parquet(p)
        names = set(lf.collect_schema().names())
        # Align every shard to the full feat_cols layout.
        exprs = [
            (pl.col(c).cast(pl.Float32).alias(c) if c in names
             else pl.lit(None, dtype=pl.Float32).alias(c))
            for c in feat_cols
        ]
        df = lf.select([
            pl.col(TARGET).cast(pl.Float32).alias(TARGET),
            pl.col(WEIGHT).cast(pl.Float32).alias(WEIGHT),
            *exprs
        ]).collect(streaming=True)

        X[i:i+k, :] = df.select(feat_cols).to_numpy()
        y[i:i+k]    = df.select(pl.col(TARGET)).to_numpy().ravel()
        w[i:i+k]    = df.select(pl.col(WEIGHT)).to_numpy().ravel()
        i += k
        del df; gc.collect()

    X.flush(); y.flush(); w.flush()
    return X, y, w

# Build the train / validation memmaps; backing files are written under OUT_DIR.
train_X, train_y, train_w = memmap_from_shards(f"{SHARD_DIR}/train_shard_*.parquet", feat_cols, f"{OUT_DIR}/train")
val_X,   val_y,   val_w   = memmap_from_shards(f"{SHARD_DIR}/val_shard_*.parquet",   feat_cols, f"{OUT_DIR}/val")

print("train shapes:", train_X.shape, train_y.shape, train_w.shape)
print("val   shapes:", val_X.shape,   val_y.shape,   val_w.shape)


4. LightGBM 训练 + 重要性 （一次性全列）

In [None]:
# Datasets backed by the memmap arrays; validation reuses the training bin mappers.
dtrain = lgb.Dataset(
    train_X, label=train_y, weight=train_w,
    feature_name=feat_cols, free_raw_data=True,
)
dval = lgb.Dataset(
    val_X, label=val_y, weight=val_w,
    feature_name=feat_cols, reference=dtrain, free_raw_data=True,
)

params = {
    "objective": "regression",
    "metric": "None",              # built-in metrics off; only the custom feval is reported
    "num_threads": 16,
    "seed": 42,
    "deterministic": True,
    "first_metric_only": True,
    "learning_rate": 0.05,
    "num_leaves": 31,
    "max_depth": -1,
    "min_data_in_leaf": 20,
    # memory-friendly binning
    "max_bin": 63,
    "bin_construct_sample_cnt": 100_000,
    "min_data_in_bin": 3,
}

model = lgb.train(
    params,
    dtrain,
    num_boost_round=1000,
    valid_sets=[dval, dtrain],
    valid_names=["val","train"],
    feval=lgb_wr2_eval,            # weighted zero-mean R^2
    callbacks=[lgb.early_stopping(50)],
)

# Per-feature importance, highest total gain first.
imp = (
    pd.DataFrame({
        "feature": model.feature_name(),
        "gain": model.feature_importance("gain"),
        "split": model.feature_importance("split"),
    })
    .sort_values(by="gain", ascending=False)
    .reset_index(drop=True)
)

print(imp.head(30))



In [None]:
# Persist the first-round importance table.
imp.to_csv(f"{OUT_DIR}/imp_1r.csv", index=False)

In [None]:
# Reload the saved importances and keep only features with positive gain.
imp = pd.read_csv(f"{OUT_DIR}/imp_1r.csv")
top_feats = imp.loc[imp.gain > 0]

In [None]:
# Family = the base column a derived feature comes from (`feature_NN` or
# `responder_N` name prefix); NaN when the name matches neither.
fam = top_feats['feature'].str.extract(r'^(feature_\d{2}|responder_\d)', expand=False)
# `top_feats` is a filtered slice of `imp`; build a new frame with `assign`
# instead of setting a column on the slice (avoids SettingWithCopyWarning and
# keeps the cell safe to re-run).
top_feats = top_feats.assign(family=fam)

In [None]:
# Display the positive-gain features together with their family label.
top_feats

In [None]:
# Aggregate importance per family: number of features, total gain, total splits;
# families are ordered by total gain, highest first.
fam_feats = (
    top_feats
    .groupby('family')
    .agg(
        n=pd.NamedAgg(column='feature', aggfunc='count'),
        gain=pd.NamedAgg(column='gain', aggfunc='sum'),
        split=pd.NamedAgg(column='split', aggfunc='sum'),
    )
    .reset_index()
    .sort_values(by='gain', ascending=False)
)

In [None]:
# (number of families, number of columns)
print(fam_feats.shape)

In [None]:
# Split the family table into feature-derived and responder-derived families,
# each ordered by total gain (descending).
family_names = fam_feats['family'].str
mask_feat = family_names.startswith('feature_', na=False)
mask_resp = family_names.startswith("responder_", na=False)
features_only = fam_feats.loc[mask_feat].sort_values(by="gain", ascending=False)
responders_only = fam_feats.loc[mask_resp].sort_values(by="gain", ascending=False)
del family_names

In [None]:
# Keep the top families by gain; 79 and 9 cover all features / responders.
selected_features = features_only['family'].iloc[:79]
selected_resps = responders_only['family'].iloc[:9]

# Save each Series as a single headerless column without the index.
for series, fname in (
    (selected_features, "selected_features_1r.csv"),
    (selected_resps, "selected_responders_1r.csv"),
):
    series.to_csv(f"{OUT_DIR}/{fname}", index=False, header=False)
del series, fname

# 特征工程

In [None]:
# Training-set parameters

# All 79 features and 9 responders are used; the trailing comments show how to
# load the first-round selection from CSV instead.
FEATURE_COLS = [f"feature_{i:02d}" for i in range(79)] #FEATURE_COLS = pd.read_csv(f"{INPUT_DIR}/selected_features_1r.csv", header=None).squeeze().tolist()
REP_COLS = [f"responder_{i}" for i in range(9)] #REP_COLS = pd.read_csv(f"{INPUT_DIR}/selected_responders_1r.csv", header=None).squeeze().tolist()

# az:// directory of the cleaned data, taken from the YAML config.
BASE_DIR = BLOB(cfg['paths']['clean'])


# NOTE: BASE_PATH and lf_base are rebound here from the local file to the blob copy.
BASE_PATH = [f"{BASE_DIR}/final_clean.parquet"]
lf_base = pl.scan_parquet(BASE_PATH, storage_options=fs.storage_options)

In [None]:
# Test-set parameters

# Same column lists as for the training set (rebinding to identical values).
FEATURE_COLS = [f"feature_{i:02d}" for i in range(79)] #FEATURE_COLS = pd.read_csv(f"{INPUT_DIR}/selected_features_1r.csv", header=None).squeeze().tolist()
REP_COLS = [f"responder_{i}" for i in range(9)] #REP_COLS = pd.read_csv(f"{INPUT_DIR}/selected_responders_1r.csv", header=None).squeeze().tolist()

# az:// directory of the cleaned data, taken from the YAML config.
TEST_BASE_DIR = BLOB(cfg['paths']['clean'])

TEST_BASE_PATH = [f"{TEST_BASE_DIR}/test_final_clean.parquet"]
lf_test_base = pl.scan_parquet(TEST_BASE_PATH, storage_options=fs.storage_options)

In [None]:
# ------- step 1: shard raw -------

CORE_DAYS = 30                 # number of core days kept per shard
# Example: A's largest rolling window = 60 days; C's largest tick window = 968 -> 1 day
PAD_DAYS = 60   # warm-up days placed before each core range; adjust to the actual A/B/C settings

# All distinct date_id values of the training base, ascending
days = (lf_base.select(pl.col("date_id").unique().sort())
            .collect(streaming=True)["date_id"].to_list())

In [None]:
# TEST 

# ------- step 1: shard raw -------

CORE_DAYS = 30                 # number of core days kept per shard
# Example: A's largest rolling window = 60 days; C's largest tick window = 968 -> 1 day
PAD_DAYS = 60   # warm-up days placed before each core range; adjust to the actual A/B/C settings

# All distinct date_id values of the test base, ascending
test_days = (lf_test_base.select(pl.col("date_id").unique().sort())
            .collect(streaming=True)["date_id"].to_list())

In [None]:
# Write the raw training shards: each file holds CORE_DAYS core days plus the
# PAD_DAYS days before them as warm-up history.

# az:// output directory for raw shards. Bound here because the visible
# notebook never defines `RAW_SHA_DIR`; this is the same directory that
# RAW_SHA_DIR_NP (protocol-less form) points to.
RAW_SHA_DIR_B = BLOB(cfg['paths']['raw_shards'])

# The first PAD_DAYS days only ever serve as padding, so cores start at index PAD_DAYS.
for start_idx in range(PAD_DAYS, len(days), CORE_DAYS):
    core_lo_idx = start_idx
    core_hi_idx = min(start_idx + CORE_DAYS - 1, len(days) - 1)
    pad_lo_idx  = core_lo_idx - PAD_DAYS

    core_lo, core_hi = days[core_lo_idx], days[core_hi_idx]
    pad_lo           = days[pad_lo_idx]

    # The file name carries the core range only; the content starts at pad_lo.
    out_path = f"{RAW_SHA_DIR_B}/raw_{core_lo}_{core_hi}.parquet"

    # Remove an existing file first so the shard is overwritten cleanly.
    if fs.exists(out_path.replace("az://","")):
        fs.rm(out_path.replace("az://",""))

    (lf_base
        .filter(pl.col("date_id").is_between(pad_lo, core_hi))
        .select([*KEYS, *FEATURE_COLS, *REP_COLS, WEIGHT])
        .sink_parquet(
            out_path,
            compression="zstd",
            storage_options=storage_options,   # same dict the filesystem was built from
            statistics=True,                   # write column statistics
            maintain_order=True,               # keep the current row order
        )
    )


In [None]:
# TEST 

RAW_SHA_DIR_B = BLOB(cfg['paths']['raw_shards'])
# The target is a blob location: create it through the Azure filesystem.
# os.makedirs on an "az://..." string would only create a stray local directory.
fs.mkdirs(RAW_SHA_DIR_B.replace("az://", ""), exist_ok=True)


# Write the raw test shards: each file holds CORE_DAYS core days plus the
# PAD_DAYS days before them as warm-up history.
for start_idx in range(PAD_DAYS, len(test_days), CORE_DAYS):
    core_lo_idx = start_idx
    core_hi_idx = min(start_idx + CORE_DAYS - 1, len(test_days) - 1)
    pad_lo_idx  = core_lo_idx - PAD_DAYS

    core_lo, core_hi = test_days[core_lo_idx], test_days[core_hi_idx]
    pad_lo           = test_days[pad_lo_idx]

    # "test_" prefix keeps these files apart from the training shards.
    out_path = f"{RAW_SHA_DIR_B}/test_raw_{core_lo}_{core_hi}.parquet"

    # Remove an existing file first so the shard is overwritten cleanly.
    if fs.exists(out_path.replace("az://","")):
        fs.rm(out_path.replace("az://",""))

    (lf_test_base
        .filter(pl.col("date_id").is_between(pad_lo, core_hi))
        .select([*KEYS, *FEATURE_COLS, *REP_COLS, WEIGHT])
        .sink_parquet(
            out_path,
            compression="zstd",
            storage_options=storage_options,   # same dict the filesystem was built from
            statistics=True,                   # write column statistics
            maintain_order=True,               # keep the current row order
        )
    )


In [None]:
# ------- step 2: FE per raw shard (A+B once, C batched internally via C.batch_size) -------

# params (tweak as needed)
A = StageA(
    tail_lags=(1,2,3,5),
    tail_diffs=(1,2),
    rolling_windows=(5,20,60),
    prev_soft_days=7,
)
B = StageB(
    ndays=5, stats_rep_cols=None, add_prev1_multirep=True,
    batch_size=8, prev_soft_days=7, strict_k=False,
)
C = StageC(
    lags=(1,16,64,100,256,968),
    ret_periods=(1,16,64,100,256),
    diff_periods=(1,16,64,100,256,968),
    rz_windows=(16,64,100,256,968),
    ewm_spans=(16,64,100,256,968),
    keep_rmean_rstd=False,                 # recommended off when there are many columns; enable when needed
    cs_cols=None, cs_by=("date_id","time_id"),
    prev_soft_days=7,
    batch_size=4,                         # number of feature columns per internal batch; lower it further on OOM
    assume_sorted=True, cast_f32=True,
)

# ---- 数值排序，确保 raw_740_769 在 raw_1010_1039 之前 ----
def numeric_key(p):
    """Sort key for shard paths: the two trailing integers of the file stem.

    ``'.../raw_1010_1039.parquet'`` maps to ``(1010, 1039)``, so shards sort
    numerically rather than lexicographically.
    """
    lo_txt, hi_txt = Path(p).stem.rsplit("_", 2)[-2:]
    return (int(lo_txt), int(hi_txt))



# ---- 写出到 az://，去掉 mkdir；兼容 storage_options / cloud_options ----
def trim_core(in_path: str, out_path_blob: str, lo: int, hi: int):
    """Write the rows of ``in_path`` whose ``date_id`` lies in [lo, hi] to ``out_path_blob``.

    The input is scanned lazily without storage options (a local file in the
    callers of this notebook). The output is sunk as zstd parquet with column
    statistics and preserved row order, using the module-level
    ``storage_options`` as credentials.
    """
    sink_kwargs = {"compression": "zstd", "statistics": True, "maintain_order": True}
    core_rows = pl.scan_parquet(in_path).filter(pl.col("date_id").is_between(lo, hi))

    # The credentials keyword is usually storage_options; retry with
    # cloud_options when the first spelling is rejected with a TypeError.
    try:
        core_rows.sink_parquet(out_path_blob, storage_options=storage_options, **sink_kwargs)
    except TypeError:
        core_rows.sink_parquet(out_path_blob, cloud_options=storage_options, **sink_kwargs)


    
    
def fe_on_raw_shard(raw_path_blob: str, *, core_lo: int, core_hi: int,
                    A=None, B=None, C=None,
                    keys=KEYS, rep_cols=REP_COLS, feature_cols=FEATURE_COLS,
                    fe_dir_blob=FE_SHA_DIR_B, tmp_root=TMP_ROOT):
    """Run staged feature engineering on one raw shard and upload the core-day slices.

    The shard at ``raw_path_blob`` (an az:// path) is scanned lazily and handed to
    ``run_staged_engineering``, which writes into a temporary local directory that
    is deleted when this function returns. Each result file is trimmed to
    ``[core_lo, core_hi]`` by ``trim_core`` and written under ``fe_dir_blob`` as:

    * ``A_{lo}_{hi}.parquet`` / ``B_{lo}_{hi}.parquet``
    * ``C_{op}_{lo}_{hi}.parquet`` when an op produced a single file
    * ``C_{op}_b{idx}_{lo}_{hi}.parquet`` when an op produced several files, so
      that batches no longer overwrite each other. These names still match the
      ``C_*_{lo}_{hi}.parquet`` glob used when the shards are joined.

    NOTE(review): the result is assumed to be a mapping with optional keys
    "A", "B" (one path each) and "C" (op name -> path or list of paths), as
    implied by how it is consumed here; confirm against run_staged_engineering.
    """
    lf_slice = pl.scan_parquet(raw_path_blob, storage_options=storage_options)

    def _dest(stem: str) -> str:
        # Destination of one output shard on Blob.
        return f"{fe_dir_blob}/{stem}_{core_lo}_{core_hi}.parquet"

    # Intermediate files live in a local scratch directory that is removed on exit.
    with tempfile.TemporaryDirectory(dir=tmp_root) as tmp_dir:
        res = run_staged_engineering(
            lf_base=lf_slice,
            keys=keys, rep_cols=rep_cols, feature_cols=feature_cols,
            out_dir=tmp_dir, A=A, B=B, C=C
        )

        if "A" in res:
            trim_core(res["A"], _dest("A"), core_lo, core_hi)
        if "B" in res:
            trim_core(res["B"], _dest("B"), core_lo, core_hi)
        for op, paths in res.get("C", {}).items():
            # Accept a single path as well as a list of batch files; iterating a
            # bare string would walk over its characters.
            if isinstance(paths, (str, os.PathLike)):
                batch_files = [paths]
            else:
                batch_files = list(paths)
            for idx, p in enumerate(batch_files):
                stem = f"C_{op}" if len(batch_files) == 1 else f"C_{op}_b{idx:02d}"
                trim_core(p, _dest(stem), core_lo, core_hi)

        gc.collect()

        
# List the raw_* shards on Blob (fs.glob returns protocol-less paths) and run
# feature engineering on each one in numeric window order.
raw_list_np = sorted(fs.glob(f"{RAW_SHA_DIR_NP}/raw_*_*.parquet"), key=numeric_key)
for raw_np in raw_list_np:
    lo, hi = numeric_key(raw_np)          # core range encoded in the file name
    raw_blob = f"az://{raw_np}"
    fe_on_raw_shard(raw_blob, core_lo=lo, core_hi=hi, A=A, B=B, C=C)

In [None]:
# Test

# ------- step 2: FE per raw shard (A+B once, C batched internally via C.batch_size) -------

# Stage parameters for the test split (same values as the training run).
A = StageA(
    tail_lags=(1, 2, 3, 5),
    tail_diffs=(1, 2),
    rolling_windows=(5, 20, 60),
    prev_soft_days=7,
)
B = StageB(
    ndays=5,
    stats_rep_cols=None,
    add_prev1_multirep=True,
    batch_size=8,
    prev_soft_days=7,
    strict_k=False,
)
C = StageC(
    lags=(1, 16, 64, 100, 256, 968),
    ret_periods=(1, 16, 64, 100, 256),
    diff_periods=(1, 16, 64, 100, 256, 968),
    rz_windows=(16, 64, 100, 256, 968),
    ewm_spans=(16, 64, 100, 256, 968),
    keep_rmean_rstd=False,        # suggested off when there are many columns; enable when needed
    cs_cols=None,
    cs_by=("date_id", "time_id"),
    prev_soft_days=7,
    batch_size=4,                 # internal batch size of stage C; lower it further on OOM
    assume_sorted=True,
    cast_f32=True,
)

# ---- 数值排序，确保 raw_740_769 在 raw_1010_1039 之前 ----
def numeric_key(p):
    """Return the two trailing integers of a shard file stem as a sort key.

    Example: ``'test_raw_1010_1039.parquet'`` gives ``(1010, 1039)``.
    """
    tokens = Path(p).stem.split("_")
    first, second = (int(tok) for tok in tokens[-2:])
    return (first, second)



# ---- 写出到 az://，去掉 mkdir；兼容 storage_options / cloud_options ----
def trim_core(in_path: str, out_path_blob: str, lo: int, hi: int):
    """Write the rows of ``in_path`` whose ``date_id`` lies in [lo, hi] to ``out_path_blob``.

    The input is scanned lazily without storage options (a local file in the
    callers of this notebook). The output is sunk as zstd parquet with column
    statistics and preserved row order, using the module-level
    ``storage_options`` as credentials.
    """
    sink_kwargs = {"compression": "zstd", "statistics": True, "maintain_order": True}
    core_rows = pl.scan_parquet(in_path).filter(pl.col("date_id").is_between(lo, hi))

    # The credentials keyword is usually storage_options; retry with
    # cloud_options when the first spelling is rejected with a TypeError.
    try:
        core_rows.sink_parquet(out_path_blob, storage_options=storage_options, **sink_kwargs)
    except TypeError:
        core_rows.sink_parquet(out_path_blob, cloud_options=storage_options, **sink_kwargs)


    
    
def fe_on_raw_shard(raw_path_blob: str, *, core_lo: int, core_hi: int,
                    A=None, B=None, C=None,
                    keys=KEYS, rep_cols=REP_COLS, feature_cols=FEATURE_COLS,
                    fe_dir_blob=FE_SHA_DIR_B, tmp_root=TMP_ROOT,
                    out_prefix: str = "test_"):
    """Run staged feature engineering on one raw shard and upload the core-day slices.

    The shard at ``raw_path_blob`` (an az:// path) is scanned lazily and handed to
    ``run_staged_engineering``, which writes into a temporary local directory that
    is deleted when this function returns. Each result file is trimmed to
    ``[core_lo, core_hi]`` by ``trim_core`` and written under ``fe_dir_blob`` as:

    * ``{out_prefix}A_{lo}_{hi}.parquet`` / ``{out_prefix}B_{lo}_{hi}.parquet``
    * ``{out_prefix}C_{op}_{lo}_{hi}.parquet`` when an op produced a single file
    * ``{out_prefix}C_{op}_b{idx}_{lo}_{hi}.parquet`` when an op produced several
      files, so that batches no longer overwrite each other.

    ``out_prefix`` defaults to ``"test_"``, the prefix this cell always used;
    pass ``""`` to get the unprefixed names.

    NOTE(review): the result is assumed to be a mapping with optional keys
    "A", "B" (one path each) and "C" (op name -> path or list of paths), as
    implied by how it is consumed here; confirm against run_staged_engineering.
    """
    lf_slice = pl.scan_parquet(raw_path_blob, storage_options=storage_options)

    def _dest(stem: str) -> str:
        # Destination of one output shard on Blob.
        return f"{fe_dir_blob}/{out_prefix}{stem}_{core_lo}_{core_hi}.parquet"

    # Intermediate files live in a local scratch directory that is removed on exit.
    with tempfile.TemporaryDirectory(dir=tmp_root) as tmp_dir:
        res = run_staged_engineering(
            lf_base=lf_slice,
            keys=keys, rep_cols=rep_cols, feature_cols=feature_cols,
            out_dir=tmp_dir, A=A, B=B, C=C
        )

        if "A" in res:
            trim_core(res["A"], _dest("A"), core_lo, core_hi)
        if "B" in res:
            trim_core(res["B"], _dest("B"), core_lo, core_hi)
        for op, paths in res.get("C", {}).items():
            # Accept a single path as well as a list of batch files; iterating a
            # bare string would walk over its characters.
            if isinstance(paths, (str, os.PathLike)):
                batch_files = [paths]
            else:
                batch_files = list(paths)
            for idx, p in enumerate(batch_files):
                stem = f"C_{op}" if len(batch_files) == 1 else f"C_{op}_b{idx:02d}"
                trim_core(p, _dest(stem), core_lo, core_hi)

        gc.collect()

        
# List the test_raw_* shards on Blob (fs.glob returns protocol-less paths) and
# run feature engineering on each one in numeric window order.
raw_list_np = sorted(fs.glob(f"{RAW_SHA_DIR_NP}/test_raw_*_*.parquet"), key=numeric_key)
for raw_np in raw_list_np:
    lo, hi = numeric_key(raw_np)          # core range encoded in the file name
    raw_blob = f"az://{raw_np}"
    fe_on_raw_shard(raw_blob, core_lo=lo, core_hi=hi, A=A, B=B, C=C)

# 特征选择 二选（选参数）

至此，我们已经在满参数配置下得到了全量数据各分片的全部特征；接下来抽取一小段样本，用于特征（及其参数）的筛选。

样本集_特征选择

- 选择目标日期的数据集并合并
- 划分训练集，验证集
- 简单参数lgb训练
- 特征重要性排序

1. 基本配置

In [None]:
# Column names and date range of the feature-selection sample.
FEATURE_COLS = ["feature_%02d" % idx for idx in range(79)]
DATE_LO = 1200   # lower date_id bound of the train/validation sample
DATE_HI = 1400   # upper date_id bound; to be widened to the whole training set later
print("ready")

2.枚举窗口 + 拿到训练、验证天

In [None]:
# List every shard in the fe_shards directory. Listed paths may come back
# without the protocol, so az:// is re-attached where missing.
fe_all = fs.glob(f"{FE_SHA_DIR_B}/*_*.parquet")
fe_all = [p if p.startswith("az://") else f"az://{p}" for p in fe_all]

# Collect the (lo, hi) windows that overlap [DATE_LO, DATE_HI].
# Only the A_/B_/C_ feature shards define windows: the same directory also
# receives the train_*/val_* outputs and the test_* shards, and those must not
# be turned into training windows.
wins = set()
for p in fe_all:
    parts = Path(p).stem.split("_")        # e.g. ['C', 'lags', '1220', '1249']
    if len(parts) < 3 or parts[0] not in ("A", "B", "C"):
        continue
    try:
        lo, hi = int(parts[-2]), int(parts[-1])
    except ValueError:
        continue                           # not a <name>_<lo>_<hi>.parquet shard
    if hi >= DATE_LO and lo <= DATE_HI:
        wins.add((lo, hi))
wins = sorted(wins)
print(f"windows in range: {wins[:5]} ... (total {len(wins)})")


# Days actually present in the range, split 80/20 by day (not by row) so that
# no day contributes to both train and validation.
days = (pl.scan_parquet(BASE_PATH, storage_options=storage_options)
        .filter(pl.col("date_id").is_between(DATE_LO, DATE_HI))
        .select(pl.col("date_id").unique().sort())
        .collect(streaming=True)["date_id"].to_list())
assert days, "no days found in range"
cut = int(len(days)*0.8)
train_days, val_days = set(days[:cut]), set(days[cut:])
print(f"days total={len(days)}  train={len(train_days)}  val={len(val_days)}")

3. 按窗口拼接 (A + B + 所有 C_*) → 直接写 train/val 分片（无大表）

In [None]:
# Loop-invariant lookups, computed once instead of once per window.
protocol = FE_SHA_DIR_B.split("://", 1)[0] + "://"
# fe_all may hold protocol-less paths; normalise before membership tests.
fe_set = {p if p.startswith(protocol) else protocol + p for p in fe_all}
train_day_list = sorted(train_days)
val_day_list = sorted(val_days)
reserved_cols = {*KEYS, TARGET, WEIGHT}

for (lo, hi) in wins:
    # Intersect with the global range so edge windows do not spill outside it.
    w_lo, w_hi = max(lo, DATE_LO), min(hi, DATE_HI)

    # Base table: keys, target, weight and the base features.
    lf = (pl.scan_parquet(BASE_PATH, storage_options=storage_options)
            .filter(pl.col("date_id").is_between(w_lo, w_hi))
            .select([*KEYS, TARGET, WEIGHT, *[pl.col(c) for c in FEATURE_COLS]]))

    # A and B shards of this window (with protocol), when they exist.
    fe_files = [
        p for p in (f"{FE_SHA_DIR_B}/A_{lo}_{hi}.parquet", f"{FE_SHA_DIR_B}/B_{lo}_{hi}.parquet")
        if p in fe_set
    ]

    # Every C_* shard of the same window (glob returns protocol-less paths).
    c_files = fs.glob(f"{FE_SHA_DIR_NP}/C_*_{lo}_{hi}.parquet")
    fe_files += [p if p.startswith(protocol) else protocol + p for p in sorted(c_files)]

    # Left-join the shards one by one, adding only columns not seen yet.
    # `columns` records the order in which columns were added so that the output
    # schema is deterministic; `already` is only the fast membership test.
    columns = list(lf.collect_schema().names())
    already = set(columns)
    for fp in fe_files:
        ds = pl.scan_parquet(fp, storage_options=storage_options)
        add_cols = [c for c in ds.collect_schema().names() if c not in already]
        if add_cols:
            lf = lf.join(ds.select([*KEYS, *add_cols]), on=KEYS, how="left")
            columns.extend(add_cols)
            already.update(add_cols)

    # Cast everything to float32. Features missing from this window are simply
    # not written; shard2memmap fills them with nulls later.
    feat_present = [c for c in columns if c not in reserved_cols]
    select_exprs = [
        *KEYS,
        pl.col(TARGET).cast(pl.Float32).alias(TARGET),
        pl.col(WEIGHT).cast(pl.Float32).alias(WEIGHT),
        *[pl.col(c).cast(pl.Float32).alias(c) for c in feat_present],
    ]
    lf_win = lf.select(select_exprs)

    # Stream the train / validation rows straight to Blob.
    out_train = f"{FE_SHA_DIR_B}/train_{lo}_{hi}.parquet"
    out_val   = f"{FE_SHA_DIR_B}/val_{lo}_{hi}.parquet"
    (lf_win.filter(pl.col("date_id").is_in(train_day_list))
           .sink_parquet(out_train, compression="zstd", storage_options=storage_options))
    (lf_win.filter(pl.col("date_id").is_in(val_day_list))
           .sink_parquet(out_val, compression="zstd", storage_options=storage_options))

    print(f"[{lo}_{hi}] write: {os.path.basename(out_train)}, {os.path.basename(out_val)}")
    gc.collect()


4.生成最终特征清单

In [None]:
import glob, polars as pl

# Use one training shard as the column template.
ref = f"{FE_SHA_DIR_B}/train_1190_1219.parquet"
ref_schema = pl.scan_parquet(ref, storage_options=storage_options).collect_schema()
names = ref_schema.names()

# Everything in that shard except keys, target and weight is a feature
# (base and engineered columns alike).
non_feature = (*KEYS, TARGET, WEIGHT)
feat_cols = [col for col in names if col not in non_feature]
print(f"final feature list size = {len(feat_cols)}")

# Persist the list as a one-column CSV.
df_feat = pd.DataFrame({'feature': feat_cols})
df_feat.to_csv("exp/v1/config/input_sets/input_all.csv", index=False)

5.构建memmap

In [None]:
import json, time

def shard2memmap(glob_pat, feat_cols, prefix, fs=None, storage_options=None,
                 target_col=TARGET, weight_col=WEIGHT):
    """
    Concatenate parquet shards, in sorted path order, into three float32 memmaps.

    Files written: ``{prefix}_X.float32.mmap`` (n_rows x n_feat),
    ``{prefix}_y.float32.mmap``, ``{prefix}_w.float32.mmap`` and, last of all,
    ``{prefix}.meta.json``. Any existing meta file is removed before the data
    files are rewritten, so an interrupted build never leaves an old meta next
    to half-written data.

    Args:
        glob_pat: glob pattern of the shards. ``az://`` patterns are listed
            through ``fs``; anything else through the local ``glob`` module.
        feat_cols: feature column names; fixes the column order of X. A column
            missing from a shard is filled with nulls (NaN in the memmap).
        prefix: local path prefix of the output files; its directory is created
            when missing.
        fs: fsspec filesystem, required for ``az://`` patterns.
        storage_options: passed to ``pl.scan_parquet`` for ``az://`` shards.
        target_col, weight_col: columns copied into y and w.

    Returns:
        ``(X, y, w)`` as writable ``np.memmap`` objects.

    Raises:
        FileNotFoundError: no shard matches ``glob_pat``.
        RuntimeError: a shard yields a different number of rows than counted.
    """
    feat_cols = list(feat_cols)
    meta_path = f"{prefix}.meta.json"

    def _is_az(p: str) -> bool:
        return isinstance(p, str) and p.startswith("az://")

    def _glob_paths(pattern: str):
        # Remote: glob the protocol-less pattern through fs, then restore az://.
        if _is_az(pattern):
            assert fs is not None, "fs must be provided for az:// glob"
            nopro = pattern[len("az://"):]
            listed = fs.glob(nopro)
            return ["az://" + p for p in sorted(listed)]
        else:
            import glob as _glob
            return sorted(_glob.glob(pattern))

    def _scan(p: str):
        if _is_az(p):
            return pl.scan_parquet(p, storage_options=storage_options)
        else:
            return pl.scan_parquet(p)

    # 1) List the shards.
    paths = _glob_paths(glob_pat)
    if not paths:
        raise FileNotFoundError(f"No shards matched: {glob_pat}")

    # 2) Count the rows of every shard to size the memmaps.
    counts = [int(_scan(p).select(pl.len()).collect(streaming=True).item()) for p in paths]

    n_rows, n_feat = int(sum(counts)), len(feat_cols)
    print(f"[memmap] {glob_pat}: {len(paths)} files, {n_rows} rows, {n_feat} features")

    # 3) Prepare the output location and invalidate any previous build.
    out_dir = os.path.dirname(prefix)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    if os.path.exists(meta_path):
        os.remove(meta_path)

    X = np.memmap(f"{prefix}_X.float32.mmap", dtype="float32", mode="w+", shape=(n_rows, n_feat))
    y = np.memmap(f"{prefix}_y.float32.mmap", dtype="float32", mode="w+", shape=(n_rows,))
    w = np.memmap(f"{prefix}_w.float32.mmap", dtype="float32", mode="w+", shape=(n_rows,))

    # 4) Copy shard by shard into consecutive row ranges.
    i = 0
    for p, k in zip(paths, counts):
        lf = _scan(p)
        names = set(lf.collect_schema().names())
        exprs = [
            (pl.col(c).cast(pl.Float32).alias(c) if c in names
             else pl.lit(None, dtype=pl.Float32).alias(c))
            for c in feat_cols
        ]
        df = (lf.select([
                pl.col(target_col).cast(pl.Float32).alias(target_col),
                pl.col(weight_col).cast(pl.Float32).alias(weight_col),
                *exprs
             ])
             .collect(streaming=True))

        # The slice below was sized from the first pass; fail loudly if the
        # shard changed in between rather than on an opaque shape error.
        if df.height != k:
            raise RuntimeError(f"{p}: expected {k} rows, got {df.height}")

        X[i:i+k, :] = df.select(feat_cols).to_numpy()
        y[i:i+k]    = df.select(pl.col(target_col)).to_numpy().ravel()
        w[i:i+k]    = df.select(pl.col(weight_col)).to_numpy().ravel()
        i += k
        del df; gc.collect()

    X.flush(); y.flush(); w.flush()

    # 5) Write the meta file last: its presence marks a complete build.
    meta = {"n_rows": int(n_rows), "n_feat": int(n_feat), "dtype": "float32", "ts": time.time(),
            "features": list(feat_cols), "target": target_col, "weight": weight_col}
    with open(meta_path, "w") as f:
        json.dump(meta, f)

    return X, y, w


# Build the train and validation memmaps from the shards on Blob.
mm_kwargs = {"fs": fs, "storage_options": storage_options}

train_X, train_y, train_w = shard2memmap(
    f"{FE_SHA_DIR_B}/train_*.parquet", feat_cols, f"{MM_DIR}/train", **mm_kwargs)

val_X, val_y, val_w = shard2memmap(
    f"{FE_SHA_DIR_B}/val_*.parquet", feat_cols, f"{MM_DIR}/val", **mm_kwargs)

print("train shapes:", train_X.shape, train_y.shape, train_w.shape)
print("val   shapes:", val_X.shape,   val_y.shape,   val_w.shape)


6.训练LightGBM

In [None]:
import os, json, numpy as np

def load_memmap(prefix, readonly=True):
    """Open the X / y / w memmaps stored under ``prefix``.

    Shapes and dtype are read from ``{prefix}.meta.json``. The maps are opened
    in mode ``"r"`` when ``readonly`` is true and ``"r+"`` otherwise.

    Returns:
        ``(X, y, w, feat_cols, meta)`` where ``feat_cols`` is the feature list
        saved in the meta file (``None`` if absent) and ``meta`` the parsed dict.

    Raises:
        FileNotFoundError: the meta file does not exist.
    """
    meta_path = f"{prefix}.meta.json"
    if not os.path.exists(meta_path):
        raise FileNotFoundError(f"missing {meta_path}")
    with open(meta_path, "r") as fh:
        meta = json.load(fh)

    dtype = np.dtype(meta.get("dtype", "float32"))
    n_rows = int(meta["n_rows"])
    n_feat = int(meta["n_feat"])
    mode = "r+" if not readonly else "r"

    def _open(tag, shape):
        # All three files share the naming scheme {prefix}_{tag}.float32.mmap.
        return np.memmap(f"{prefix}_{tag}.float32.mmap", dtype=dtype, mode=mode, shape=shape)

    X = _open("X", (n_rows, n_feat))
    y = _open("y", (n_rows,))
    w = _open("w", (n_rows,))
    # Return the saved column order so names and X columns stay aligned.
    return X, y, w, meta.get("features"), meta

def ensure_memmap(prefix, glob_pat, feat_cols, *, fs=None, storage_options=None,
                  target_col="responder_6", weight_col="weight"):
    """Return the memmaps under ``prefix``, building them from the shards when needed.

    A cached build is reused only if its meta file records the same feature
    list (same names, same order), target and weight as requested. A cache made
    for different columns is rebuilt instead of being returned silently.
    The row set is not validated: a cache built from other shards with the same
    columns is still reused.

    Returns:
        ``(X, y, w, feat_cols, meta)`` as produced by ``load_memmap``.
    """
    wanted = list(feat_cols)

    try:
        cached = load_memmap(prefix)
    except FileNotFoundError:
        cached = None

    if cached is not None:
        meta = cached[4]
        same_columns = (
            meta.get("features") == wanted
            and meta.get("target", target_col) == target_col
            and meta.get("weight", weight_col) == weight_col
        )
        if same_columns:
            return cached
        print(f"[memmap] cache under {prefix} was built for different columns; rebuilding")
        del cached   # drop the stale maps before their files are overwritten

    shard2memmap(glob_pat, wanted, prefix, fs=fs, storage_options=storage_options,
                 target_col=target_col, weight_col=weight_col)
    # Reload to obtain the meta dict and the canonical feature list.
    return load_memmap(prefix)

# Usage: open the cached memmaps, building them from the shards when missing.
mm_kwargs = {"fs": fs, "storage_options": storage_options}

train_X, train_y, train_w, feat_cols, _ = ensure_memmap(
    f"{MM_DIR}/train", f"{FE_SHA_DIR_B}/train_*.parquet", feat_cols, **mm_kwargs)

val_X, val_y, val_w, _, _ = ensure_memmap(
    f"{MM_DIR}/val", f"{FE_SHA_DIR_B}/val_*.parquet", feat_cols, **mm_kwargs)



def weighted_r2_zero_mean(y_true, y_pred, weight):
    """Weighted R^2 against a zero baseline.

    Computes ``1 - sum(w * (y - p)^2) / sum(w * y^2)`` in float64 on the
    flattened inputs and returns ``0.0`` when the denominator is not positive.
    """
    truth, pred, wts = (
        np.asarray(arr, dtype=np.float64).ravel() for arr in (y_true, y_pred, weight)
    )
    residual_ss = np.sum(wts * (truth - pred) ** 2)
    total_ss = np.sum(wts * (truth ** 2))
    if total_ss <= 0:
        return 0.0
    return 1.0 - (residual_ss / total_ss)

def lgb_wr2_eval(preds, train_data):
    """Custom LightGBM metric: weighted zero-mean R^2 of ``preds`` on ``train_data``.

    Uses unit weights when the dataset carries none. Returns the
    ``("wr2", value, True)`` triple; the final ``True`` is the
    higher-is-better flag.
    """
    labels = train_data.get_label()
    weights = train_data.get_weight()
    if weights is None:
        weights = np.ones_like(labels)
    score = weighted_r2_zero_mean(labels, preds, weights)
    return ("wr2", score, True)

# LightGBM datasets built from the memmaps; the validation set reuses the
# training set as its reference.
dtrain = lgb.Dataset(
    train_X, label=train_y, weight=train_w,
    feature_name=feat_cols, free_raw_data=True,
)
dval = lgb.Dataset(
    val_X, label=val_y, weight=val_w,
    feature_name=feat_cols, reference=dtrain, free_raw_data=True,
)

params = {
    "objective": "regression",
    "metric": "None",                    # evaluation is done by the custom feval below
    "learning_rate": 0.05,
    "num_leaves": 63,
    "max_depth": 10,
    "min_data_in_leaf": 150,             # raised from 20: steadier leaves, lower GPU memory use
    "num_threads": 16,
    "seed": 42,
    "deterministic": True,
    "first_metric_only": True,
    "max_bin": 63,
    "bin_construct_sample_cnt": 200_000,
    "min_data_in_bin": 3,
    "device_type": "gpu",
    "gpu_device_id": 0,                  # pin the run to GPU 0
    "feature_fraction": 0.7,             # column subsampling: faster, less overfitting
    "bagging_fraction": 0.7,             # row subsampling
    "bagging_freq": 1,
    "lambda_l2": 3.0,                    # L2 regularisation for steadier splits
    # "extra_trees": True,               # optional, for extra stability
}

# Only the validation set is monitored; stop after 50 rounds without improvement.
model = lgb.train(
    params,
    dtrain,
    num_boost_round=1000,
    valid_sets=[dval],
    valid_names=["val"],
    feval=lgb_wr2_eval,
    callbacks=[lgb.early_stopping(50)],
)

# Feature importance table, highest gain first.
imp = (
    pd.DataFrame({
        "feature": model.feature_name(),
        "gain": model.feature_importance("gain"),
        "split": model.feature_importance("split"),
    })
    .sort_values("gain", ascending=False)
    .reset_index(drop=True)
)

print(imp.head(30))


In [None]:
# Keep features that contribute gain AND are split on more than five times.
# Both conditions go into one mask on `imp`; filtering `imp` twice in a row
# would let the second filter discard the first.
top_imp = imp[(imp.gain > 0) & (imp.split > 5)]
len(top_imp)

In [None]:
# Persist the selected features (plain string: the path has no placeholders).
top_imp.to_csv("exp/v1/config/input_sets/top_imp.csv", index=False)

# 去共线性

In [None]:
# Inputs of the de-correlation study, read from local disk.
# NOTE(review): this cell rebinds the notebook-wide KEYS / TARGET / FEATURE_COLS
# (and REP_COLS); cells run afterwards see these values.
PARQUET_PATHS = ["/mnt/data/js/final_clean.parquet"]
KEYS = ["symbol_id", "date_id", "time_id"]
TARGET = "responder_6"
FEATURE_COLS = pd.read_csv('/home/admin_ml/Jackson/projects/selected_features.csv')['family'].tolist()
REP_COLS = pd.read_csv('/home/admin_ml/Jackson/projects/selected_resps.csv')['family'].tolist()

# Lazy base table restricted to the columns needed, then to date_id 1200..1400 (inclusive).
lf_base = pl.scan_parquet(PARQUET_PATHS).select(KEYS+FEATURE_COLS+REP_COLS)
lf_slice = lf_base.filter(pl.col("date_id").is_between(1200, 1400))

# Feature-engineering parameters for this slice.
PARAMS = {
    "prev_soft_days": 7,
    "tail_lags": (2, 5, 20, 40, 100),
    "tail_diffs": (2, 5,),
    "rolling_windows": (5, 20),
    "same_time_ndays": 5,
    "strict_k": False,
    "hist_lags": (1, 2, 5, 10, 20, 100),
    "ret_periods": (1, 5, 20),
    "diff_periods": (1, 5),
    "rz_windows": (5, 20),
    "ewm_spans": (5, 40, 100),
    "cs_cols": None,       # keep this small to avoid blow-up
}

lf_eng = run_engineering_on_slice(lf_slice, **PARAMS)

# Feature names in the order stored in the importance CSV; only the first 500 are materialised.
feats = pd.read_csv("/home/admin_ml/Jackson/projects/final_fi_mean.csv")["feature"].tolist()

lf_small = lf_eng.select(feats[:500])
lf_small.collect(streaming=True).write_parquet("/mnt/data/js/X_ready.parquet", compression="zstd")


In [None]:

# Load the prepared feature matrix into pandas.
lf = pl.scan_parquet("/mnt/data/js/X_ready.parquet")

df = lf.collect(streaming=True).to_pandas()

# Absolute Pearson correlation on pairwise-complete observations. Pairs with
# fewer than min_obs shared rows come back NaN and are treated as uncorrelated.
# NOTE(review): C and A below reuse the names of the StageC / StageA parameter
# objects defined in the feature-engineering cells.
min_obs = max(50, int(0.3 * len(df)))  # tweak as you like
C = df.corr(method="pearson", min_periods=min_obs).abs().fillna(0.0)

# Priority order = order of `feats`, restricted to the columns that were
# actually correlated. Reindexing with the full `feats` list would add an
# all-zero row for every feature absent from X_ready.parquet, and each of those
# would end up in `keep` without ever having been compared.
present = set(C.columns)
order = [name for name in feats if name in present]
C = C.reindex(index=order, columns=order).fillna(0.0).copy()


# --- Prepare NumPy array for the greedy loop ---
A = C.values
np.fill_diagonal(A, 0.0)  # self-correlation is 1; zero it so a feature never drops itself
m = len(order)

# --- Greedy de-correlation (keep-by-priority, drop neighbors) ---
THRESH = 0.97
keep_mask = np.ones(m, dtype=bool)

for i in range(m):
    if not keep_mask[i]:
        continue  # already removed by a higher-priority pick
    # only check j > i (upper triangle) among still-active features
    active_slice = keep_mask[i+1:]
    drop = (A[i, i+1:] >= THRESH) & active_slice
    active_slice[drop] = False  # writes through to keep_mask[i+1:] (slice is a view)
keep = [order[i] for i in range(m) if keep_mask[i]]


pd.DataFrame({'feature': keep}).to_csv("final_selected_features.csv", index=False)

# 全数据训练

In [7]:
f"{BASE_DIR}/final_clean.parquet"

'az://jackson/js_exp/clean/final_clean.parquet'

In [8]:
# Lazily open the cleaned base table on Blob and preview its first five rows.
lx = pl.scan_parquet(
    f"{BASE_DIR}/final_clean.parquet",
    storage_options=fs.storage_options,
)

lx.limit(5).collect()

symbol_id,date_id,time_id,weight,feature_00,feature_01,feature_02,feature_03,feature_04,feature_05,feature_06,feature_07,feature_08,feature_09,feature_10,feature_11,feature_12,feature_13,feature_14,feature_15,feature_16,feature_17,feature_18,feature_19,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29,feature_30,feature_31,feature_32,…,feature_51,feature_52,feature_53,feature_54,feature_55,feature_56,feature_57,feature_58,feature_59,feature_60,feature_61,feature_62,feature_63,feature_64,feature_65,feature_66,feature_67,feature_68,feature_69,feature_70,feature_71,feature_72,feature_73,feature_74,feature_75,feature_76,feature_77,feature_78,responder_0,responder_1,responder_2,responder_3,responder_4,responder_5,responder_6,responder_7,responder_8
i8,i16,i16,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,…,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
0,680,0,2.29816,0.851814,1.197591,0.219422,0.411698,2.057359,-0.542597,-3.4331,-1.090165,0.151888,11.0,7.0,76.0,-0.97142,0.670215,-0.502896,-0.519751,-0.070096,-0.5044,-1.308236,-2.120128,1.747068,-0.219661,0.791602,0.114922,-0.311672,0.156548,1.476346,1.301341,1.235173,-0.259882,-0.30125,-0.264237,0.233658,…,0.322135,-0.244606,0.177445,-2.273168,0.211338,-1.627507,1.155271,-0.019618,-3.847984,-2.230864,-0.171579,-0.273791,-0.301876,-0.432933,-1.215778,-1.670469,-0.637963,0.803874,-0.21269,-0.764702,0.435278,-0.619145,-0.016643,-0.02328,-0.021034,-0.045094,-0.178144,-0.1951,-0.304665,0.164485,-0.205231,0.191064,-1.413209,0.375675,0.929775,-1.574939,1.101371
0,680,1,2.29816,0.344538,0.735081,0.013578,0.61738,1.956999,-0.371554,-0.526306,-0.684286,0.155305,11.0,7.0,76.0,-0.810169,1.103803,-0.268897,-0.519751,-0.320673,-0.5044,-0.762257,-2.215312,1.747068,-0.219661,0.791602,0.114922,-0.311672,0.156548,1.476346,1.301341,1.235173,-0.259882,-0.30125,-0.264237,0.233658,…,-0.512957,-0.244606,0.177445,-1.259565,0.211338,-2.49033,0.259405,-0.019618,-2.635671,-2.046605,-0.171579,-0.305304,-0.346962,-0.363018,-1.361503,-2.374111,-0.566751,1.22803,-0.274947,-1.316323,1.123081,-0.398956,-0.016643,-0.02328,-0.079538,-0.106091,-0.319045,-0.286955,0.260829,0.497995,0.709419,0.324715,-0.794916,0.437495,0.089887,-1.605929,0.104144
0,680,2,2.29816,-0.180962,1.383893,-0.650321,0.700879,2.427377,-0.535128,-0.81285,-0.7363,0.116907,11.0,7.0,76.0,-0.991004,1.118573,-0.340604,-0.519751,-0.310225,-0.5044,-0.147581,-2.163665,1.747068,-0.219661,0.791602,0.114922,-0.311672,0.156548,1.476346,1.301341,1.235173,-0.259882,-0.30125,-0.264237,0.233658,…,-0.799781,-0.244606,0.177445,-0.404231,0.211338,-2.945363,0.656708,-0.019618,-2.228918,-2.241239,-0.171579,-0.295761,-0.35494,-0.405964,-0.669037,-1.851066,-0.997381,0.942939,-0.204669,-1.064232,0.616298,-0.293038,-0.016643,-0.02328,0.05686,0.015116,-0.22286,-0.204045,0.398035,0.52747,0.488063,0.751176,-0.785524,1.308435,0.250879,-1.45394,1.760168
0,680,3,2.29816,-0.124249,1.744584,0.897987,0.096003,2.132334,-0.851425,-2.684773,-1.338934,0.151086,11.0,7.0,76.0,-1.11294,1.4363,-0.206028,-0.519751,-0.39899,-0.5044,-1.310143,-1.819511,1.747068,-0.219661,0.791602,0.114922,-0.311672,0.156548,1.476346,1.301341,1.235173,-0.259882,-0.30125,-0.264237,0.233658,…,-0.155125,-0.244606,0.177445,-1.505502,0.211338,-1.772148,0.813329,-0.019618,-2.694696,-2.044638,-0.171579,-0.321374,-0.258883,-0.379521,-0.746425,-2.052708,-0.812054,1.020951,-0.088977,-0.992124,0.737115,-0.358254,-0.016643,-0.02328,-0.056656,-0.061023,-0.228405,-0.303751,-0.308972,0.194933,-0.048652,0.191667,-1.004579,0.555013,0.56471,-1.449821,1.333068
0,680,4,2.29816,0.814995,1.38898,-0.253808,0.489444,2.055804,-0.897863,-0.891241,-1.131095,0.17518,11.0,7.0,76.0,-1.03724,0.772707,-0.237228,-0.519751,-0.206644,-0.348021,-1.22889,-1.41934,1.747068,-0.219661,0.791602,0.114922,-0.311672,0.156548,1.476346,1.301341,1.235173,-0.259882,-0.30125,-0.264237,0.233658,…,-0.509942,-0.244606,0.177445,-0.196582,0.211338,-2.116526,0.851539,-0.019618,-3.177497,-2.24501,-0.171579,-0.291239,-0.331689,-0.280069,-1.394332,-1.034399,-0.762234,0.542012,-0.144765,-1.023163,0.307331,-0.37056,-0.016643,-0.02328,-0.074767,-0.091099,-0.191449,-0.263896,-0.543977,0.164482,-0.485082,0.173833,-1.178417,1.133403,0.435787,-1.653541,2.013243


In [None]:
# date_id range used for the full-data run.
DATE_LO, DATE_HI = 680, 1529

BASE_PATH = f"{BASE_DIR}/final_clean.parquet"

# List every shard in the fe_shards directory. Listed paths may come back
# without the protocol, so az:// is re-attached where missing.
fe_all = fs.glob(f"{FE_SHA_DIR_B}/*_*.parquet")
fe_all = [p if p.startswith("az://") else f"az://{p}" for p in fe_all]

# Collect the (lo, hi) windows that overlap [DATE_LO, DATE_HI].
# Only the A_/B_/C_ feature shards define windows: the same directory also
# receives the train_*/val_* outputs and the test_* shards, and those must not
# be turned into training windows.
wins = set()
for p in fe_all:
    parts = Path(p).stem.split("_")        # e.g. ['C', 'lags', '1220', '1249']
    if len(parts) < 3 or parts[0] not in ("A", "B", "C"):
        continue
    try:
        lo, hi = int(parts[-2]), int(parts[-1])
    except ValueError:
        continue                           # not a <name>_<lo>_<hi>.parquet shard
    if hi >= DATE_LO and lo <= DATE_HI:
        wins.add((lo, hi))
wins = sorted(wins)
print(f"windows in range: {wins[:5]} ... (total {len(wins)})")


# Days actually present in the range, split 80/20 by day (not by row) so that
# no day contributes to both train and validation.
days = (pl.scan_parquet(BASE_PATH, storage_options=storage_options)
        .filter(pl.col("date_id").is_between(DATE_LO, DATE_HI))
        .select(pl.col("date_id").unique().sort())
        .collect(streaming=True)["date_id"].to_list())
assert days, "no days found in range"
cut = int(len(days)*0.8)
train_days, val_days = set(days[:cut]), set(days[cut:])
print(f"days total={len(days)}  train={len(train_days)}  val={len(val_days)}")



# Loop-invariant lookups, computed once instead of once per window.
protocol = FE_SHA_DIR_B.split("://", 1)[0] + "://"
# fe_all may hold protocol-less paths; normalise before membership tests.
fe_set = {p if p.startswith(protocol) else protocol + p for p in fe_all}
train_day_list = sorted(train_days)
val_day_list = sorted(val_days)
reserved_cols = {*KEYS, TARGET, WEIGHT}

for (lo, hi) in wins:
    # Intersect with the global range so edge windows do not spill outside it.
    w_lo, w_hi = max(lo, DATE_LO), min(hi, DATE_HI)

    # Base table: keys, target, weight and the base features.
    lf = (pl.scan_parquet(BASE_PATH, storage_options=storage_options)
            .filter(pl.col("date_id").is_between(w_lo, w_hi))
            .select([*KEYS, TARGET, WEIGHT, *[pl.col(c) for c in FEATURE_COLS]]))

    # A and B shards of this window (with protocol), when they exist.
    fe_files = [
        p for p in (f"{FE_SHA_DIR_B}/A_{lo}_{hi}.parquet", f"{FE_SHA_DIR_B}/B_{lo}_{hi}.parquet")
        if p in fe_set
    ]

    # Every C_* shard of the same window (glob returns protocol-less paths).
    c_files = fs.glob(f"{FE_SHA_DIR_NP}/C_*_{lo}_{hi}.parquet")
    fe_files += [p if p.startswith(protocol) else protocol + p for p in sorted(c_files)]

    # Left-join the shards one by one, adding only columns not seen yet.
    # `columns` records the order in which columns were added so that the output
    # schema is deterministic; `already` is only the fast membership test.
    columns = list(lf.collect_schema().names())
    already = set(columns)
    for fp in fe_files:
        ds = pl.scan_parquet(fp, storage_options=storage_options)
        add_cols = [c for c in ds.collect_schema().names() if c not in already]
        if add_cols:
            lf = lf.join(ds.select([*KEYS, *add_cols]), on=KEYS, how="left")
            columns.extend(add_cols)
            already.update(add_cols)

    # Cast everything to float32. Features missing from this window are simply
    # not written; shard2memmap fills them with nulls later.
    feat_present = [c for c in columns if c not in reserved_cols]
    select_exprs = [
        *KEYS,
        pl.col(TARGET).cast(pl.Float32).alias(TARGET),
        pl.col(WEIGHT).cast(pl.Float32).alias(WEIGHT),
        *[pl.col(c).cast(pl.Float32).alias(c) for c in feat_present],
    ]
    lf_win = lf.select(select_exprs)

    # Stream the train / validation rows straight to Blob.
    out_train = f"{FE_SHA_DIR_B}/train_{lo}_{hi}.parquet"
    out_val   = f"{FE_SHA_DIR_B}/val_{lo}_{hi}.parquet"
    (lf_win.filter(pl.col("date_id").is_in(train_day_list))
           .sink_parquet(out_train, compression="zstd", storage_options=storage_options))
    (lf_win.filter(pl.col("date_id").is_in(val_day_list))
           .sink_parquet(out_val, compression="zstd", storage_options=storage_options))

    print(f"[{lo}_{hi}] write: {os.path.basename(out_train)}, {os.path.basename(out_val)}")
    gc.collect()


In [28]:

import json, time


# Feature names selected in the importance step, in file order.
feat_cols = pd.read_csv('exp/v1/config/input_sets/top_imp.csv')['feature'].tolist()

In [29]:
len(feat_cols)

512

In [30]:
# Add time_id as a model input. Guarded so that re-running the cell does not
# append the column a second time.
if 'time_id' not in feat_cols:
    feat_cols = [*feat_cols, 'time_id']


In [31]:
len(feat_cols)

513

In [32]:

def shard2memmap(glob_pat, feat_cols, prefix, fs=None, storage_options=None,
                 target_col=TARGET, weight_col=WEIGHT):
    """
    Concatenate parquet shards, in sorted path order, into three float32 memmaps.

    Files written: ``{prefix}_X.float32.mmap`` (n_rows x n_feat),
    ``{prefix}_y.float32.mmap``, ``{prefix}_w.float32.mmap`` and, last of all,
    ``{prefix}.meta.json``. Any existing meta file is removed before the data
    files are rewritten, so an interrupted build never leaves an old meta next
    to half-written data.

    Args:
        glob_pat: glob pattern of the shards. ``az://`` patterns are listed
            through ``fs``; anything else through the local ``glob`` module.
        feat_cols: feature column names; fixes the column order of X. A column
            missing from a shard is filled with nulls (NaN in the memmap).
        prefix: local path prefix of the output files; its directory is created
            when missing.
        fs: fsspec filesystem, required for ``az://`` patterns.
        storage_options: passed to ``pl.scan_parquet`` for ``az://`` shards.
        target_col, weight_col: columns copied into y and w.

    Returns:
        ``(X, y, w)`` as writable ``np.memmap`` objects.

    Raises:
        FileNotFoundError: no shard matches ``glob_pat``.
        RuntimeError: a shard yields a different number of rows than counted.
    """
    feat_cols = list(feat_cols)
    meta_path = f"{prefix}.meta.json"

    def _is_az(p: str) -> bool:
        return isinstance(p, str) and p.startswith("az://")

    def _glob_paths(pattern: str):
        # Remote: glob the protocol-less pattern through fs, then restore az://.
        if _is_az(pattern):
            assert fs is not None, "fs must be provided for az:// glob"
            nopro = pattern[len("az://"):]
            listed = fs.glob(nopro)
            return ["az://" + p for p in sorted(listed)]
        else:
            import glob as _glob
            return sorted(_glob.glob(pattern))

    def _scan(p: str):
        if _is_az(p):
            return pl.scan_parquet(p, storage_options=storage_options)
        else:
            return pl.scan_parquet(p)

    # 1) List the shards.
    paths = _glob_paths(glob_pat)
    if not paths:
        raise FileNotFoundError(f"No shards matched: {glob_pat}")

    # 2) Count the rows of every shard to size the memmaps.
    counts = [int(_scan(p).select(pl.len()).collect(streaming=True).item()) for p in paths]

    n_rows, n_feat = int(sum(counts)), len(feat_cols)
    print(f"[memmap] {glob_pat}: {len(paths)} files, {n_rows} rows, {n_feat} features")

    # 3) Prepare the output location and invalidate any previous build.
    out_dir = os.path.dirname(prefix)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    if os.path.exists(meta_path):
        os.remove(meta_path)

    X = np.memmap(f"{prefix}_X.float32.mmap", dtype="float32", mode="w+", shape=(n_rows, n_feat))
    y = np.memmap(f"{prefix}_y.float32.mmap", dtype="float32", mode="w+", shape=(n_rows,))
    w = np.memmap(f"{prefix}_w.float32.mmap", dtype="float32", mode="w+", shape=(n_rows,))

    # 4) Copy shard by shard into consecutive row ranges.
    i = 0
    for p, k in zip(paths, counts):
        lf = _scan(p)
        names = set(lf.collect_schema().names())
        exprs = [
            (pl.col(c).cast(pl.Float32).alias(c) if c in names
             else pl.lit(None, dtype=pl.Float32).alias(c))
            for c in feat_cols
        ]
        df = (lf.select([
                pl.col(target_col).cast(pl.Float32).alias(target_col),
                pl.col(weight_col).cast(pl.Float32).alias(weight_col),
                *exprs
             ])
             .collect(streaming=True))

        # The slice below was sized from the first pass; fail loudly if the
        # shard changed in between rather than on an opaque shape error.
        if df.height != k:
            raise RuntimeError(f"{p}: expected {k} rows, got {df.height}")

        X[i:i+k, :] = df.select(feat_cols).to_numpy()
        y[i:i+k]    = df.select(pl.col(target_col)).to_numpy().ravel()
        w[i:i+k]    = df.select(pl.col(weight_col)).to_numpy().ravel()
        i += k
        del df; gc.collect()

    X.flush(); y.flush(); w.flush()

    # 5) Write the meta file last: its presence marks a complete build.
    meta = {"n_rows": int(n_rows), "n_feat": int(n_feat), "dtype": "float32", "ts": time.time(),
            "features": list(feat_cols), "target": target_col, "weight": weight_col}
    with open(meta_path, "w") as f:
        json.dump(meta, f)

    return X, y, w


# Build the train and validation memmaps from the shards on Blob.
mm_kwargs = {"fs": fs, "storage_options": storage_options}

train_X, train_y, train_w = shard2memmap(
    f"{FE_SHA_DIR_B}/train_*.parquet", feat_cols, f"{MM_DIR}/train", **mm_kwargs)

val_X, val_y, val_w = shard2memmap(
    f"{FE_SHA_DIR_B}/val_*.parquet", feat_cols, f"{MM_DIR}/val", **mm_kwargs)

print("train shapes:", train_X.shape, train_y.shape, train_w.shape)
print("val   shapes:", val_X.shape,   val_y.shape,   val_w.shape)


[memmap] az://jackson/js_exp/exp/v1/fe_shards/train_*.parquet: 27 files, 21216624 rows, 513 features
[memmap] az://jackson/js_exp/exp/v1/fe_shards/val_*.parquet: 27 files, 6140024 rows, 513 features
train shapes: (21216624, 513) (21216624,) (21216624,)
val   shapes: (6140024, 513) (6140024,) (6140024,)


In [33]:
import os, json, numpy as np

def load_memmap(prefix, readonly=True):
    """Open the X / y / w memmaps stored under ``prefix``.

    Shapes and dtype are read from ``{prefix}.meta.json``. The maps are opened
    in mode ``"r"`` when ``readonly`` is true and ``"r+"`` otherwise.

    Returns:
        ``(X, y, w, feat_cols, meta)`` where ``feat_cols`` is the feature list
        saved in the meta file (``None`` if absent) and ``meta`` the parsed dict.

    Raises:
        FileNotFoundError: the meta file does not exist.
    """
    meta_path = f"{prefix}.meta.json"
    if not os.path.exists(meta_path):
        raise FileNotFoundError(f"missing {meta_path}")
    with open(meta_path, "r") as fh:
        meta = json.load(fh)

    dtype = np.dtype(meta.get("dtype", "float32"))
    n_rows = int(meta["n_rows"])
    n_feat = int(meta["n_feat"])
    mode = "r+" if not readonly else "r"

    def _open(tag, shape):
        # All three files share the naming scheme {prefix}_{tag}.float32.mmap.
        return np.memmap(f"{prefix}_{tag}.float32.mmap", dtype=dtype, mode=mode, shape=shape)

    X = _open("X", (n_rows, n_feat))
    y = _open("y", (n_rows,))
    w = _open("w", (n_rows,))
    # Return the saved column order so names and X columns stay aligned.
    return X, y, w, meta.get("features"), meta

def ensure_memmap(prefix, glob_pat, feat_cols, *, fs=None, storage_options=None,
                  target_col="responder_6", weight_col="weight"):
    """Return the memmap bundle for `prefix`, building it from shards on a cache miss.

    First tries `load_memmap(prefix)`. If that raises FileNotFoundError the
    memmaps are created with `shard2memmap(...)` and then loaded again, so the
    caller always receives the `(X, y, w, feat_cols, meta)` tuple produced by
    `load_memmap`, including the feature order stored in the meta file.
    """
    try:
        bundle = load_memmap(prefix)
    except FileNotFoundError:
        # Cache miss: write the memmaps, then reopen them through the loader so
        # the meta and the persisted feature list come back with the arrays.
        shard2memmap(
            glob_pat, feat_cols, prefix,
            fs=fs, storage_options=storage_options,
            target_col=target_col, weight_col=weight_col,
        )
        bundle = load_memmap(prefix)
    return bundle

# Usage: reuse the cached memmaps when present, otherwise build them from the FE shards.
# `feat_cols` is rebound to the feature list returned by the loader for the train split.
train_X, train_y, train_w, feat_cols, _ = ensure_memmap(
    prefix=f"{MM_DIR}/train",
    glob_pat=f"{FE_SHA_DIR_B}/train_*.parquet",
    feat_cols=feat_cols,
    fs=fs,
    storage_options=storage_options,
)

val_X, val_y, val_w, _, _ = ensure_memmap(
    prefix=f"{MM_DIR}/val",
    glob_pat=f"{FE_SHA_DIR_B}/val_*.parquet",
    feat_cols=feat_cols,
    fs=fs,
    storage_options=storage_options,
)




def weighted_r2_zero_mean(y_true, y_pred, weight):
    """Weighted R^2 against a zero baseline: 1 - sum(w*(y - yhat)^2) / sum(w*y^2).

    All three inputs are converted to flat float64 arrays first.
    Returns 0.0 when the denominator is not positive.
    """
    truth, guess, wts = (
        np.asarray(arr, dtype=np.float64).ravel()
        for arr in (y_true, y_pred, weight)
    )
    residual = truth - guess
    weighted_sse = np.sum(wts * residual**2)
    weighted_energy = np.sum(wts * (truth**2))
    if weighted_energy <= 0:
        return 0.0
    return 1.0 - (weighted_sse / weighted_energy)

def lgb_wr2_eval(preds, train_data):
    """Custom LightGBM eval function reporting the zero-mean weighted R^2.

    Reads labels and weights from `train_data` (unit weights when the dataset
    has none) and returns the `(name, value, is_higher_better)` triple
    `("wr2", score, True)`.
    """
    labels = train_data.get_label()
    weights = train_data.get_weight()
    effective_w = np.ones_like(labels) if weights is None else weights
    score = weighted_r2_zero_mean(labels, preds, effective_w)
    return ("wr2", score, True)

# Wrap the memmaps as LightGBM datasets; the validation set reuses the train binning.
dtrain = lgb.Dataset(
    data=train_X,
    label=train_y,
    weight=train_w,
    feature_name=feat_cols,
    free_raw_data=True,
)
dval = lgb.Dataset(
    data=val_X,
    label=val_y,
    weight=val_w,
    feature_name=feat_cols,
    reference=dtrain,
    free_raw_data=True,
)

params = {
    "objective": "regression",
    "metric": "None",                       # built-in metrics off; only the custom feval is reported
    "learning_rate": 0.03,                  # smaller learning rate, paired with more rounds
    "num_leaves": 127,                      # more model capacity
    "max_depth": 10,
    "min_data_in_leaf": 200,                # larger leaves: more stable, less overfitting
    "num_threads": 16,
    "seed": 42,
    "deterministic": True,
    "first_metric_only": True,

    # Histogram / binning (GPU friendly)
    "max_bin": 63,
    "bin_construct_sample_cnt": 400_000,    # bigger sample for more accurate bins (full-data training)
    "min_data_in_bin": 3,

    # GPU
    "device_type": "gpu",
    "gpu_device_id": 0,

    # Sampling + regularisation
    "feature_fraction": 0.65,               # column subsampling
    "feature_fraction_bynode": 0.8,         # per-node column subsampling
    "bagging_fraction": 0.70,               # row subsampling
    "bagging_freq": 1,
    "lambda_l2": 5.0,                       # L2 regularisation for stability
    "extra_trees": True,                    # author's note: steadier and faster on large data
}

model = lgb.train(
    params=params,
    train_set=dtrain,
    num_boost_round=4000,                   # many rounds to go with the small learning rate
    valid_sets=[dval],                      # monitor the validation set only
    valid_names=["val"],
    feval=lgb_wr2_eval,
    callbacks=[
        lgb.early_stopping(stopping_rounds=200),
        lgb.log_evaluation(period=50),
    ],
)


# Feature importance table, ranked by total gain.
imp = (
    pd.DataFrame(
        {
            "feature": model.feature_name(),
            "gain": model.feature_importance(importance_type="gain"),
            "split": model.feature_importance(importance_type="split"),
        }
    )
    .sort_values(by="gain", ascending=False)
    .reset_index(drop=True)
)

print(imp.head(30))


[LightGBM] [Info] This is the GPU trainer!!
[LightGBM] [Info] Total Bins 32319
[LightGBM] [Info] Number of data points in the train set: 21216624, number of used features: 513
[LightGBM] [Info] Using GPU Device: Tesla T4, Vendor: NVIDIA Corporation
[LightGBM] [Info] Compiling OpenCL Kernel with 64 bins...
[LightGBM] [Info] GPU programs have been built
[LightGBM] [Info] Size of histogram bin entry: 8
[LightGBM] [Info] 427 dense feature groups (8660.04 MB) transferred to GPU in 2.003924 secs. 1 sparse feature groups
[LightGBM] [Info] Start training from score -0.003603
Training until validation scores don't improve for 200 rounds
[50]	val's wr2: 0.00757362
[100]	val's wr2: 0.0106474
[150]	val's wr2: 0.0124919
[200]	val's wr2: 0.0138949
[250]	val's wr2: 0.0146271
[300]	val's wr2: 0.0152303
[350]	val's wr2: 0.0154838
[400]	val's wr2: 0.0157228
[450]	val's wr2: 0.0159548
[500]	val's wr2: 0.0160824
[550]	val's wr2: 0.0162302
[600]	val's wr2: 0.0162849
[650]	val's wr2: 0.0163209
[700]	val's w

In [None]:

# Create the model directory if it does not exist yet
os.makedirs("exp/v1/models", exist_ok=True)

# Save the trained booster to disk
model.save_model("exp/v1/models/lgbm_final.txt")

In [None]:
imp.tail()

# 模型评估

首先对测试集做特征工程

In [None]:
# ====== 参数（与训练期基本一致，仅改为 test 路径/前缀） ======
import os, gc
import polars as pl

DATE_LO, DATE_HI = 1530, 1698  # test date range; adjust or override as needed

# List every test fe_shard on Blob (glob returns protocol-less paths; re-attach the protocol).
fe_all_np = fs.glob(f"{FE_SHA_DIR_NP}/test_*_*.parquet")
protocol = FE_SHA_DIR_B.split("://", 1)[0] + "://"
fe_all = [p if p.startswith(protocol) else protocol + p for p in fe_all_np]
# Protocol-qualified lookup set for existence checks. It does not depend on the
# window, so it is built once here instead of once per loop iteration.
fe_set = set(fe_all)

# ---- Parse every (lo, hi) window from the shard names; keep those overlapping the date range ----
wins = set()
for p in fe_all_np:
    base = os.path.basename(p)                         # e.g. test_C_lags_1220_1249.parquet
    stem = base[:-8]                                   # strip ".parquet"
    parts = stem.split("_")
    lo, hi = int(parts[-2]), int(parts[-1])            # the last two tokens are lo / hi
    if hi >= DATE_LO and lo <= DATE_HI:
        wins.add((lo, hi))
wins = sorted(wins)
print(f"windows in range: {wins[:5]} ... (total {len(wins)})")

# ---- date_ids actually present in the range (sanity/logging only; no train/val split here) ----
days = (pl.scan_parquet(TEST_BASE_PATH, storage_options=storage_options)
          .filter(pl.col("date_id").is_between(DATE_LO, DATE_HI))
          .select(pl.col("date_id").unique().sort())
          .collect(streaming=True)["date_id"].to_list())
assert days, "no days found in range"
print(f"days total={len(days)}")

# WEIGHT may be absent from the test base table, so resolve its presence once
# from the schema instead of selecting it unconditionally.
base_has_weight = WEIGHT in (
    pl.scan_parquet(TEST_BASE_PATH, storage_options=storage_options)
      .collect_schema()
      .names()
)
key_set = set(KEYS)

# ---- Main loop: for each (lo, hi) window, join all shards column-wise and write test_{lo}_{hi}.parquet ----
for (lo, hi) in wins:
    # Intersect with the global range so edge windows do not spill outside it.
    w_lo, w_hi = max(lo, DATE_LO), min(hi, DATE_HI)

    # Base table: KEYS + optional WEIGHT + raw features, sorted by KEYS.
    base_scan = (pl.scan_parquet(TEST_BASE_PATH, storage_options=storage_options)
                   .filter(pl.col("date_id").is_between(w_lo, w_hi))
                   .sort(KEYS))
    base_cols = [*KEYS, *([WEIGHT] if base_has_weight else []), *FEATURE_COLS]
    lf = base_scan.select(base_cols)

    # A/B shards for this window (test_ prefix), kept only when they exist.
    fe_files = []
    for name in (f"test_A_{lo}_{hi}.parquet", f"test_B_{lo}_{hi}.parquet"):
        p = f"{FE_SHA_DIR_B}/{name}"  # protocol-qualified
        if p in fe_set or fs.exists(p):
            fe_files.append(p)

    # Every C_* shard of the same window (protocol-less glob result → add protocol).
    c_files_np = fs.glob(f"{FE_SHA_DIR_NP}/test_C_*_{lo}_{hi}.parquet")
    fe_files += [p if p.startswith(protocol) else protocol + p for p in sorted(c_files_np)]

    # Left-join shard by shard, adding only columns not seen yet. `ordered_cols`
    # records first-seen order so the written column order is deterministic;
    # iterating a set here would make it depend on string hashing.
    ordered_cols = list(lf.collect_schema().names())
    already = set(ordered_cols)
    for fp in fe_files:
        ds = pl.scan_parquet(fp, storage_options=storage_options)
        add_cols = [c for c in ds.collect_schema().names() if c not in already]
        if add_cols:
            lf = lf.join(ds.select([*KEYS, *add_cols]), on=KEYS, how="left")
            ordered_cols.extend(add_cols)
            already.update(add_cols)

    # Cast everything except KEYS to Float32. The test set carries no TARGET, so
    # the output is KEYS + optional WEIGHT + all features.
    weight_present = WEIGHT in already
    feat_present = [c for c in ordered_cols if c not in key_set and c != WEIGHT]

    select_exprs = [
        *KEYS,
        *([pl.col(WEIGHT).cast(pl.Float32).alias(WEIGHT)] if weight_present else []),
        *[pl.col(c).cast(pl.Float32).alias(c) for c in feat_present],
    ]
    lf_win = lf.select(select_exprs)

    # Write a single test_{lo}_{hi}.parquet via the lazy sink (no explicit collect here).
    out_test = f"{FE_SHA_DIR_B}/test_{lo}_{hi}.parquet"
    # Overwrite a previous output if there is one.
    if fs.exists(out_test):
        fs.rm(out_test)

    lf_win.sink_parquet(
        out_test,
        compression="zstd",
        statistics=True,
        storage_options=storage_options,
        maintain_order=True,  # keep the row order produced by the plan above
    )

    print(f"[{lo}_{hi}] write: {os.path.basename(out_test)}  (cols={len(select_exprs)})")
    gc.collect()


In [None]:
import os, gc, json, time, re
import numpy as np
import polars as pl
import pandas as pd

feat_cols = pd.read_csv('exp/v1/config/input_sets/top_imp.csv')['feature'].tolist()

def shard2memmap(fe_paths, feat_cols, prefix, storage_options=None):
    """
    Concatenate parquet shards, in window order, into one float32 memmap X.

    - Only `feat_cols` are extracted; a column missing from a shard is filled
      with nulls (NaN in the array). Column order follows `feat_cols`.
    - Writes `{prefix}_X.float32.mmap` and `{prefix}.meta.json`.

    NOTE: this definition shadows the `shard2memmap` called earlier in the
    notebook for train/val, which is invoked with `fs=`, `target_col=` and
    `weight_col=`. Those calls will not work once this cell has run.

    Args:
        fe_paths: final `test_{lo}_{hi}.parquet` paths (`az://...` or local).
            They are re-sorted numerically by (lo, hi); names that do not end
            in two integer tokens sort after the parsable ones, by basename.
        feat_cols: feature names, in the desired output column order.
        prefix: local path prefix for the output files.
        storage_options: passed to `pl.scan_parquet` for `az://` paths only.

    Returns:
        The writable `np.memmap` X of shape (n_rows, len(feat_cols)).

    Raises:
        FileNotFoundError: if `fe_paths` is empty.
    """

    # 0) Normalise + numeric sort.
    def _numeric_key(p: str):
        # Every key has the layout (group, int, int, str), so parsable and
        # unparsable names can be compared without an int-vs-str TypeError.
        b = os.path.basename(p)
        stem = b[:-len(".parquet")] if b.endswith(".parquet") else b
        parts = stem.split("_")
        try:
            return (0, int(parts[-2]), int(parts[-1]), "")
        except (IndexError, ValueError):
            return (1, 0, 0, b)

    def _scan(p: str):
        # Only remote paths need the storage credentials.
        if p.startswith("az://"):
            return pl.scan_parquet(p, storage_options=storage_options)
        return pl.scan_parquet(p)

    fe_paths = sorted(fe_paths, key=_numeric_key)

    # 1) Guards: there must be input files, and the output directory must exist.
    if not fe_paths:
        raise FileNotFoundError("fe_paths is empty.")
    out_dir = os.path.dirname(prefix)
    if out_dir:  # a bare prefix means the current directory; makedirs("") would fail
        os.makedirs(out_dir, exist_ok=True)

    # 2) Row count per shard (needed up front to size the memmap).
    counts = [int(_scan(p).select(pl.len()).collect(streaming=True).item()) for p in fe_paths]
    n_rows, n_feat = int(sum(counts)), len(feat_cols)
    print(f"[memmap]: {len(fe_paths)} files, {n_rows} rows, {n_feat} features")

    # 3) Allocate the local memmap.
    X = np.memmap(f"{prefix}_X.float32.mmap", dtype="float32", mode="w+", shape=(n_rows, n_feat))

    # 4) Copy shard by shard into consecutive row ranges.
    i = 0
    for p, k in zip(fe_paths, counts):
        lf = _scan(p)
        names = set(lf.collect_schema().names())
        exprs = [
            (pl.col(c).cast(pl.Float32).alias(c) if c in names
             else pl.lit(None, dtype=pl.Float32).alias(c))
            for c in feat_cols
        ]
        # `exprs` already yields exactly feat_cols in order, so no re-select is needed.
        df = lf.select(exprs).collect(streaming=True)

        X[i:i+k, :] = df.to_numpy()
        i += k
        del df; gc.collect()

    X.flush()

    # 5) Save the meta, including the source files in the order they were written.
    meta = {
        "n_rows": int(n_rows),
        "n_feat": int(n_feat),
        "dtype": "float32",
        "ts": time.time(),
        "features": list(feat_cols),
        "paths": list(fe_paths),
    }
    with open(f"{prefix}.meta.json", "w") as f:
        json.dump(meta, f)

    return X

# ---- Pick the final test_{lo}_{hi}.parquet files (per-family A/B/C shards do not match the regex) ----
regex = re.compile(r"^test_\d+_\d+\.parquet$")
fe_all_np = [
    p
    for p in fs.glob(f"{FE_SHA_DIR_NP}/test_*_*.parquet")
    if regex.match(os.path.basename(p))
]
fe_paths = [p if p.startswith("az://") else ("az://" + p) for p in fe_all_np]

# ---- Build the test memmap ----
test_X = shard2memmap(
    fe_paths=fe_paths,
    feat_cols=feat_cols,
    prefix=f"{MM_DIR}/test",
    storage_options=storage_options,
)
print("test shape:", test_X.shape)


In [None]:
import joblib

# Paths for the saved model and the prediction outputs; the test features are
# read back as a memmap so inference can run in batches.

MODEL_PATH = "exp/v1/models/lgbm_final.txt"


PRED_DIR = "exp/v1/predictions"
os.makedirs(PRED_DIR, exist_ok=True)

# Reopen the test memmap read-only, using the shape recorded in its meta file
with open(f"{MM_DIR}/test.meta.json", "r") as f:
    meta = json.load(f)
n_rows, n_feat = meta["n_rows"], meta["n_feat"]
test_X = np.memmap(f"{MM_DIR}/test_X.float32.mmap", dtype="float32", mode="r",
                   shape=(n_rows, n_feat))


In [None]:
# Load the trained booster from the saved model file
model = lgb.Booster(model_file=MODEL_PATH)

In [None]:
BATCH = 50_000  # rows per prediction batch; tune to the machine's memory
preds = np.memmap(
    f"{PRED_DIR}/test_pred.float32.mmap",
    dtype="float32",
    mode="w+",
    shape=(n_rows,),
)
for start in range(0, n_rows, BATCH):
    stop = min(start + BATCH, n_rows)
    # Slice one batch of rows out of the memmap and predict on it.
    batch_pred = model.predict(test_X[start:stop])
    preds[start:stop] = batch_pred.astype("float32", copy=False)
    del batch_pred
    gc.collect()
preds.flush()
print("inference done:", preds.shape)

In [None]:
import re, os

KEYS = ["symbol_id","date_id","time_id"]
regex = re.compile(r"^test_\d+_\d+\.parquet$")
fe_all_np = fs.glob(f"{FE_SHA_DIR_NP}/test_*_*.parquet")
fe_all_np = [p for p in fe_all_np if regex.match(os.path.basename(p))]

# Numeric sort by (lo, hi). The test memmap was filled in this same order, so
# walking the shards this way keeps rows and predictions aligned.
def _key(p):
    """Return the integer (lo, hi) window parsed from a `test_{lo}_{hi}.parquet` path."""
    lo, hi = os.path.basename(p)[:-8].split("_")[-2:]
    return (int(lo), int(hi))
fe_paths = [("az://"+p) if not p.startswith("az://") else p for p in sorted(fe_all_np, key=_key)]

# Write one prediction file per shard (KEYS + y_pred).
i = 0
for p in fe_paths:
    # Read the keys once; the frame height is the shard's row count, so no
    # separate remote row-count scan is needed.
    df_keys = (pl.scan_parquet(p, storage_options=storage_options)
                 .select(KEYS)
                 .collect(streaming=True))
    k = df_keys.height

    # Matching slice of predictions, copied out of the memmap into memory.
    yhat = preds[i:i+k].copy()
    if len(yhat) != k:
        raise RuntimeError(
            f"prediction memmap exhausted at {p}: need rows {i}..{i+k}, have {len(preds)}"
        )
    df_pred = df_keys.with_columns(pl.Series("y_pred", yhat.astype("float32")))

    # Write the prediction file named after the same window.
    base = os.path.basename(p)[:-8]   # test_lo_hi
    outp = f"{PRED_DIR}/{base}_preds.parquet"
    # PRED_DIR is a local directory, so clean up with the local filesystem.
    # The Azure `fs` handle would resolve this path inside blob storage instead.
    if os.path.exists(outp):
        os.remove(outp)
    df_pred.write_parquet(outp, compression="zstd")

    i += k
    del df_keys, df_pred, yhat; gc.collect()

# Every prediction must have been assigned to exactly one shard row.
if i != len(preds):
    raise RuntimeError(f"row mismatch: shards hold {i} rows but preds has {len(preds)}")

print("pred shards written to:", PRED_DIR)


In [None]:
# Gather every prediction shard (local parquet files).
preds_all = pl.scan_parquet(f"{PRED_DIR}/test_*_preds.parquet")

# Evaluation date range (match it to the test windows, e.g. 1530..1698).
DATE_LO, DATE_HI = 1530, 1698
labels = (pl.scan_parquet(TEST_BASE_PATH, storage_options=storage_options)
            .filter(pl.col("date_id").is_between(DATE_LO, DATE_HI))
            .select([*KEYS,
                     pl.col(TARGET).cast(pl.Float32).alias("y_true"),
                     # Rows without a weight fall back to 1.0. The "w" column
                     # always exists after this select, so the fallback has to
                     # be applied to null values, not to a missing column.
                     pl.col(WEIGHT).cast(pl.Float32).fill_null(1.0).alias("w")]))

# Join predictions with labels on the keys.
df = preds_all.join(labels, on=KEYS, how="inner").collect(streaming=True)
# An empty join would otherwise be scored silently as 0.0 downstream.
assert df.height > 0, "no rows after joining predictions with labels"

y = df["y_true"].to_numpy()
p = df["y_pred"].to_numpy()
w = df["w"].to_numpy()



In [None]:
def weighted_r2_zero_mean(y_true, y_pred, weight):
    """Zero-mean weighted R^2: 1 - sum(w * (y - yhat)^2) / sum(w * y^2).

    Inputs are flattened to float64 before the sums are taken. A non-positive
    denominator yields 0.0.
    """
    def _flat64(values):
        return np.asarray(values, dtype=np.float64).ravel()

    yt = _flat64(y_true)
    yp = _flat64(y_pred)
    wt = _flat64(weight)

    num = np.sum(wt * (yt - yp) ** 2)
    den = np.sum(wt * (yt ** 2))
    if den <= 0:
        return 0.0
    return 1.0 - (num / den)

# Report the final metric on the joined test predictions
wr2 = weighted_r2_zero_mean(y, p, w)
print(f"Final weighted R^2 (zero-mean): {wr2:.6f}")