# 初始化

In [1]:
from pathlib import Path
import gc, re
import polars as pl
import numpy as np
from tqdm.auto import tqdm

from pipeline.io import cfg, fs, storage_options, P, ensure_dir_az
from pipeline.features import run_staged_engineering, StageA, StageB, StageC

def azify(p: str) -> str:
    return p if p.startswith("az://") else f"az://{p}"

In [2]:
# ---- 常量/列名 ----
FEATURE_ALL = [f"feature_{i:02d}" for i in range(79)]
RESP_COLS   = [f"responder_{i}" for i in range(9)]
KEYS        = tuple(cfg["keys"])
g_sym, g_date, g_time = KEYS
TB = cfg['time_bucket']
# ---- I/O ----
clean_root = azify(P("az", cfg["paths"]["clean_shards"]))
fe_root    = azify(P("az", cfg["paths"]["fe_shards"]))
ensure_dir_az(fe_root)

clean_paths = [azify(p) for p in sorted(fs.glob(f"{clean_root}/*.parquet"))]
if not clean_paths:
    raise FileNotFoundError(f"No clean shards under {clean_root}")


lc = pl.scan_parquet(clean_paths, storage_options=storage_options)
lc = (
    lc.filter(pl.col(g_date).is_between(830, 860, closed="both"))
)

In [14]:
df = lc.sort(["symbol_id", "date_id", "time_id"]).limit(1000).collect().to_pandas()

In [17]:
# 显示全部
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)
df

Unnamed: 0,symbol_id,date_id,time_id,feature_00,feature_01,feature_02,feature_03,feature_04,feature_05,feature_06,feature_07,feature_08,feature_09,feature_10,feature_11,feature_12,feature_13,feature_14,feature_15,feature_16,feature_17,feature_18,feature_19,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29,feature_30,feature_31,feature_32,feature_33,feature_34,feature_35,feature_36,feature_37,feature_38,feature_39,feature_40,feature_41,feature_42,feature_43,feature_44,feature_45,feature_46,feature_47,feature_48,feature_49,feature_50,feature_51,feature_52,feature_53,feature_54,feature_55,feature_56,feature_57,feature_58,feature_59,feature_60,feature_61,feature_62,feature_63,feature_64,feature_65,feature_66,feature_67,feature_68,feature_69,feature_70,feature_71,feature_72,feature_73,feature_74,feature_75,feature_76,feature_77,feature_78,weight,time_bucket,responder_0,responder_1,responder_2,responder_3,responder_4,responder_5,responder_6,responder_7,responder_8
0,0,830,0,-1.732551,2.838249,-1.910233,-1.909973,-2.921678,2.130022,-0.288255,1.025684,-0.61722,11.0,7.0,76.0,-0.568132,2.894893,0.541611,-0.179338,-0.071614,-0.228823,-1.216636,-1.019639,0.900644,-0.163989,1.02089,0.901643,-0.21065,-0.904227,1.03609,1.474487,1.036584,0.012248,0.006823,-0.282585,0.627613,-0.268317,-0.572415,0.412958,1.680631,-0.007865,-0.439949,0.964614,1.017879,-0.202907,-0.068881,-1.104081,-0.398352,-1.2861,0.951201,1.04909,1.382032,1.430444,-0.340676,1.536353,0.606949,0.847232,-0.832851,-0.099722,-1.25236,1.961124,0.340672,1.298574,0.502728,0.638373,-0.224157,-0.168616,-0.280784,-1.933092,-1.50325,-0.685505,2.619909,0.378071,-0.621479,3.117842,0.57858,-0.376818,-0.387218,-0.139912,-0.081436,-0.293182,-0.226534,1.613274,0,0.551002,0.112173,0.257809,-0.117991,-0.50658,0.107359,-0.234335,-0.961939,-0.088541
1,0,830,1,-1.619766,2.088053,-1.671484,-1.764021,-2.196926,1.041264,-0.922567,0.430553,-0.627672,11.0,7.0,76.0,-0.608984,2.943188,0.420189,-0.179338,-0.053845,-0.228823,-1.747756,-1.482374,0.900644,-0.163989,1.02089,0.901643,-0.213363,-0.904227,1.03609,1.474487,1.036584,0.012248,0.007117,-0.282585,0.627613,-0.268317,0.075327,0.147362,2.13707,-0.008614,-0.450105,0.964614,1.038494,-0.202907,-0.068881,0.056437,-0.398352,-1.381087,0.96558,1.643109,1.254152,0.898919,-0.340676,-0.066505,0.606949,0.847232,-0.377195,-0.099722,-1.195951,1.314484,0.340672,0.919209,0.65718,0.638373,-0.356323,-0.164762,-0.213675,-1.344059,-1.772867,-0.836668,2.685382,0.463456,-0.547803,3.160747,1.029321,-0.376818,-0.387218,-0.08606,-0.03938,-0.120607,-0.161126,1.613274,0,0.479581,0.162743,0.230476,-0.148894,-0.622764,0.005573,-0.489949,-0.781161,-0.17971
2,0,830,2,-1.759395,2.369348,-1.696708,-1.932428,-2.460622,1.866504,-2.45982,-0.156537,-0.646958,11.0,7.0,76.0,-0.630282,3.015451,0.956868,-0.179338,-0.011268,-0.228823,-1.642332,-1.335051,0.900644,-0.163989,1.02089,0.901643,-0.216021,-0.904227,1.03609,1.474487,1.036584,0.012248,0.007407,-0.282585,0.627613,-0.268317,0.013497,0.13745,1.76129,0.037328,-0.402884,0.964614,0.772581,-0.202907,-0.068881,-0.473211,-0.398352,-1.467465,1.256983,1.685913,-0.15687,0.668008,-0.340676,-1.046606,0.606949,0.847232,-1.386103,-0.099722,-1.136616,1.467468,0.340672,-0.376697,0.254106,0.638373,-0.304117,-0.234047,-0.190594,-1.30752,-1.767718,-0.404848,2.248114,0.594394,-0.483131,3.223826,0.874155,-0.376818,-0.387218,-0.036471,-0.031966,-0.212657,-0.121167,1.613274,0,0.602188,0.000249,0.138208,-0.0981,-0.919861,0.097814,-0.367869,-0.843593,0.031717
3,0,830,3,-1.118403,2.481385,-1.941222,-1.943709,-2.032991,1.674263,-2.987439,-0.442469,-0.66729,11.0,7.0,76.0,-0.633998,3.025571,1.065891,-0.179338,0.064825,-0.228823,-1.138194,-1.231092,0.900644,-0.163989,1.02089,0.901643,-0.218629,-0.904227,1.03609,1.474487,1.036584,0.012248,0.007693,-0.282585,0.627613,-0.268317,-0.523783,0.052489,2.72079,0.004992,-0.324539,0.964614,0.831047,-0.202907,-0.068881,-0.087708,-0.398352,-1.647498,0.693677,1.14599,0.256506,0.83913,-0.340676,1.137633,0.606949,0.847232,-1.375746,-0.099722,-1.328762,0.669375,0.340672,0.373901,0.412801,0.638373,-0.361982,-0.090786,-0.252602,-0.685047,-2.119442,-0.47056,2.583262,0.707315,-0.421351,3.233614,1.567712,-0.376818,-0.387218,-0.033219,-0.028263,-0.098273,-0.091303,1.613274,0,0.272396,-0.035213,-0.031412,0.03202,-0.61678,0.703515,-0.234365,-0.710919,1.267214
4,0,830,4,-1.60178,2.347465,-1.442185,-1.702303,-1.472809,1.725644,-1.938621,-0.309586,-0.684043,11.0,7.0,76.0,-0.630944,3.051734,1.467152,-0.179338,0.014355,-0.059567,-0.58009,-1.242958,0.900644,-0.163989,1.02089,0.901643,-0.221187,-0.904227,1.03609,1.474487,1.036584,0.012248,0.007975,-0.282585,0.627613,-0.268317,-0.458802,0.147759,1.670382,0.003621,-0.282672,0.964614,0.893771,-0.202907,-0.068881,-0.00964,-0.398352,-0.751949,0.926668,1.772647,0.584011,1.117206,-0.340676,1.045445,0.606949,0.847232,-0.898895,-0.099722,-0.895798,1.384581,0.340672,0.452311,0.334856,0.638373,-0.345169,-0.075059,-0.230152,-1.632946,-1.828838,-0.395274,1.896746,0.941926,-0.275199,1.724052,1.310066,-0.376818,-0.387218,0.007926,0.011083,-0.09511,-0.087271,1.613274,0,0.361236,0.03081,0.270003,-0.034661,-0.712328,-0.165719,-0.486264,-0.955219,-0.703522
5,0,830,5,-1.47052,2.406637,-1.954334,-1.63473,-2.110901,1.358015,-0.346647,-0.032392,-0.698457,11.0,7.0,76.0,-0.432312,2.566731,0.76903,-0.179338,-0.086833,-0.122555,-0.898643,-0.61853,0.900644,-0.163989,1.02089,0.901643,-0.223699,-0.904227,1.03609,1.474487,1.036584,0.012248,0.008253,-0.282585,0.627613,-0.268317,-0.225323,0.059583,1.505142,-0.043429,-0.387305,0.964614,1.305616,-0.202907,-0.068881,0.652824,-0.398352,-1.302996,0.770246,1.802846,0.561557,1.164482,-0.340676,0.730354,0.606949,0.847232,-0.717893,-0.099722,-1.555068,2.114195,0.340672,1.197533,0.695309,0.638373,-0.2872,-0.07903,-0.187281,-1.369204,-1.782373,-0.507244,1.507728,0.583661,-0.508174,1.268147,1.252941,-0.376818,-0.387218,0.071324,0.058036,-0.08632,-0.079874,1.613274,0,0.258766,-0.10179,-0.03973,-0.089858,-0.748607,-0.031083,-0.284519,-0.573038,0.007858
6,0,830,6,-1.434085,1.897221,-1.503336,-1.590392,-1.704005,1.426517,0.302406,0.161526,-0.705565,11.0,7.0,76.0,-0.452361,1.341479,0.839572,-0.179338,-0.135466,-0.156047,-1.183237,-0.988724,0.900644,-0.163989,1.02089,0.901643,-0.226166,-0.904227,1.03609,1.474487,1.036584,0.012248,0.008528,-0.282585,0.627613,-0.268317,0.11463,-0.310665,1.837835,-0.021998,-0.484232,0.964614,1.130436,-0.202907,-0.068881,0.129987,-0.398352,-1.278686,1.300538,1.579951,0.396412,0.823259,-0.340676,-0.072782,0.606949,0.847232,-0.703878,-0.099722,-1.058769,1.615847,0.340672,0.585375,0.56496,0.638373,-0.279949,-0.181133,-0.205744,-1.367953,-2.136612,-0.444784,1.077204,0.909239,-0.406598,1.547626,1.014639,-0.376818,-0.387218,0.110623,0.085989,-0.078567,-0.070957,1.613274,0,0.288134,-0.084691,0.395729,-0.219946,-0.719485,0.069922,-0.730919,-0.926907,-0.319176
7,0,830,7,-1.311603,1.865442,-1.483469,-1.846018,-2.082688,2.007598,1.684288,0.580221,-0.604876,11.0,7.0,76.0,-0.384125,1.149583,0.938246,-0.179338,-0.06756,-0.167103,-1.060532,-0.8468,0.900644,-0.163989,1.02089,0.901643,-0.22859,-0.904227,1.03609,1.474487,1.036584,0.012248,0.008799,-0.282585,0.627613,-0.268317,-0.340835,-0.310209,1.643641,-0.028078,-0.530392,0.964614,-0.523133,-0.202907,-0.068881,0.324654,-0.398352,-1.535284,1.334241,1.841466,0.286387,0.921393,-0.340676,0.801837,0.606949,0.847232,-0.431912,-0.099722,-1.03595,1.979144,0.340672,0.426605,0.550294,0.638373,-0.270915,-0.149577,-0.156279,-1.00749,-1.9099,-0.270744,0.455199,0.704623,-0.505943,0.842745,0.772139,-0.376818,-0.387218,0.126309,0.103202,-0.072739,-0.064349,1.613274,0,0.388713,-0.082443,0.30934,-0.165213,-0.884463,0.021834,-0.616084,-0.850632,-0.233127
8,0,830,8,-1.795362,2.474761,-1.684258,-1.926736,-1.949157,2.213566,1.92382,0.465569,-0.719534,11.0,7.0,76.0,-0.44773,0.504785,0.992205,-0.179338,0.004587,-0.120328,-0.431605,-1.163373,0.900644,-0.163989,1.02089,0.901643,-0.230973,-0.904227,1.03609,1.474487,1.036584,0.012248,0.009066,-0.282585,0.627613,-0.268317,-0.645611,-0.248252,1.452747,-0.030108,-0.435907,0.964614,1.399756,-0.202907,-0.068881,-0.377968,-0.398352,-1.44663,1.456728,1.866329,0.084774,0.709023,-0.340676,0.339099,0.606949,0.847232,-0.324791,-0.099722,-0.727576,1.828848,0.340672,0.253482,0.649382,0.638373,-0.287587,-0.14844,-0.256063,-0.947148,-2.280855,-0.517231,0.379745,0.780508,-0.405704,0.462603,1.105871,-0.376818,-0.387218,0.128331,0.112967,-0.067768,-0.058021,1.613274,0,0.361689,-0.062871,0.289949,-0.395257,-0.756007,-0.189198,-1.111335,-0.651781,-1.117711
9,0,830,9,-1.726992,2.08707,-1.469712,-1.991441,-1.563615,1.126233,0.104897,0.304983,-0.734718,11.0,7.0,76.0,-0.545285,0.467782,0.934947,-0.179338,-0.095037,-0.067607,-1.18179,-0.170624,0.900644,-0.163989,1.02089,0.901643,-0.233316,-0.904227,1.03609,1.474487,1.036584,0.012248,0.009238,-0.282585,1.307992,2.213154,-0.03778,-0.24309,1.44626,-0.031905,-0.492729,0.964614,1.308211,-0.202907,-0.068881,-0.01997,-0.398352,-1.071203,0.222938,1.886688,0.184167,0.871699,-0.340676,0.540622,0.606949,0.847232,-0.811018,-0.099722,-0.895964,1.938164,0.210725,0.225821,0.645921,0.638373,-0.366904,-0.102401,-0.165907,-1.234602,-1.987717,-0.48536,0.319159,0.52123,-0.48322,0.478628,1.338114,-0.154858,-0.125658,-0.004975,0.026089,-0.063097,-0.052895,1.613274,0,0.49105,-0.087375,0.3247,-0.502405,-0.606565,-0.113105,-0.969341,-0.574062,-0.722212


In [3]:
cols = lc.collect_schema().names()
cols = [c for c in cols if c not in KEYS]

# 按 symbol 判断是否存在段内波动（可按 __streak_id 更细）
df_val = lc.group_by(g_sym).agg([
    pl.col(c).var().alias(f"{c}_var") for c in cols
]).collect().to_pandas()


In [10]:
df_val.sort_values(by="feature_31_var", ascending=False).head(50)

Unnamed: 0,symbol_id,feature_00_var,feature_01_var,feature_02_var,feature_03_var,feature_04_var,feature_05_var,feature_06_var,feature_07_var,feature_08_var,feature_09_var,feature_10_var,feature_11_var,feature_12_var,feature_13_var,feature_14_var,feature_15_var,feature_16_var,feature_17_var,feature_18_var,feature_19_var,feature_20_var,feature_21_var,feature_22_var,feature_23_var,feature_24_var,feature_25_var,feature_26_var,feature_27_var,feature_28_var,feature_29_var,feature_30_var,feature_31_var,feature_32_var,feature_33_var,feature_34_var,feature_35_var,feature_36_var,feature_37_var,feature_38_var,feature_39_var,feature_40_var,feature_41_var,feature_42_var,feature_43_var,feature_44_var,feature_45_var,feature_46_var,feature_47_var,feature_48_var,feature_49_var,feature_50_var,feature_51_var,feature_52_var,feature_53_var,feature_54_var,feature_55_var,feature_56_var,feature_57_var,feature_58_var,feature_59_var,feature_60_var,feature_61_var,feature_62_var,feature_63_var,feature_64_var,feature_65_var,feature_66_var,feature_67_var,feature_68_var,feature_69_var,feature_70_var,feature_71_var,feature_72_var,feature_73_var,feature_74_var,feature_75_var,feature_76_var,feature_77_var,feature_78_var,weight_var,time_bucket_var,responder_0_var,responder_1_var,responder_2_var,responder_3_var,responder_4_var,responder_5_var,responder_6_var,responder_7_var,responder_8_var
16,37,0.366373,1.311794,0.375217,0.375432,1.206606,0.159056,0.146881,0.143001,0.190047,0.0,0.0,0.0,0.476261,0.279585,0.433256,0.029955,0.031553,0.032638,0.910983,1.336058,0.678759,0.403284,0.738717,0.080016,1.453224,0.127835,0.072822,0.03423,0.036682,0.015845,0.033802,0.245529,0.369344,0.811213,0.373492,0.369429,0.774475,0.303432,0.197682,0.372091,0.196644,0.259468,0.467479,0.205395,0.289014,0.442225,0.653706,0.166973,0.145817,0.157466,0.753307,0.279577,0.647257,0.941583,0.267637,0.694297,0.587165,0.860707,0.257595,0.266848,0.26727,1.159079,0.007899,0.007174,0.007363,1.041013,1.264573,0.501525,0.238545,0.426911,0.414348,0.191723,0.348356,1.203931,1.149972,1.513174,1.490201,1.136118,1.09958,0.219922,2.917448,0.179555,0.226124,0.182158,0.198224,0.245579,0.123772,0.411977,0.403644,0.351097
7,27,0.372001,1.302316,0.383898,0.376739,1.193192,0.139092,0.131223,0.126908,0.167216,0.0,0.0,0.0,1.445603,0.563616,1.068818,0.024236,0.025123,0.024055,0.972613,1.291435,0.10285,0.200121,0.358737,0.040511,1.090163,0.164115,0.07219,0.030093,0.027472,0.008993,0.01161,0.237248,0.390839,0.787706,0.390717,0.393404,0.819395,0.252615,0.219913,0.48714,0.280939,0.336407,0.527131,0.291992,0.370218,0.489238,0.723357,0.195065,0.186212,0.192794,0.856194,0.320292,0.682572,0.920298,0.309162,0.667605,0.588312,0.906126,0.308748,0.290165,0.288621,1.195675,0.012933,0.010051,0.007763,0.873091,1.370891,1.348792,0.436872,0.933343,1.381883,0.428493,0.964166,0.045782,0.04529,0.053129,0.053246,0.040138,0.039034,0.176211,2.917452,0.226868,0.150703,0.270938,0.253184,0.336355,0.164331,0.46683,0.523429,0.417256
2,12,0.317801,1.304645,0.316196,0.313605,1.210775,0.168195,0.163081,0.158445,0.272264,0.0,0.0,0.0,2.056197,1.860358,2.253214,0.862385,0.999979,0.885177,0.663306,0.99165,0.866565,0.182393,0.999581,0.11763,1.083105,0.947809,0.00829,0.014835,0.016528,0.493228,1.018741,0.190025,0.442605,1.097587,0.460715,0.442178,1.276466,0.617951,0.744905,0.617481,1.070083,0.709209,0.680386,1.144845,0.78743,0.582839,0.782861,0.365949,0.194688,0.331163,0.760257,0.620846,0.594664,0.744917,1.150526,0.670949,0.649046,0.810765,0.601215,0.371897,0.559282,1.135867,0.037337,0.193533,0.07919,0.856721,0.940097,2.826781,1.916498,2.846582,1.352128,1.263909,1.643951,0.005323,0.015389,0.002996,0.010323,0.003742,0.01152,0.971346,2.917452,0.539688,0.337139,0.52624,0.466599,0.62416,0.253136,0.902001,0.958529,0.751036
9,13,0.369549,1.310012,0.376786,0.369579,1.21293,0.169034,0.151511,0.148485,0.214974,0.0,0.0,0.0,5.070052,2.044687,3.776048,0.047089,0.162745,0.067783,0.780773,0.946027,0.073334,0.128331,0.317704,0.072502,0.441963,0.094576,0.164628,0.032758,0.072158,0.067168,0.016786,0.135089,0.389971,0.819248,0.412049,0.405063,0.923585,0.497759,0.44623,0.523744,0.687258,0.512033,0.568163,0.690086,0.52422,0.573439,0.708171,0.289765,0.169474,0.247325,0.713508,0.709519,0.482807,0.795271,0.847267,0.526223,0.584795,0.802727,0.525528,0.301261,0.395766,1.158948,0.003938,0.044327,0.010137,0.823761,0.881206,5.677204,1.964686,3.967215,3.869367,1.370178,2.813146,0.005229,0.011844,0.003027,0.00763,0.003662,0.008953,0.283582,2.917448,0.270276,0.222127,0.272176,0.33545,0.523903,0.194149,0.644438,0.752261,0.530947
24,20,0.366126,1.312173,0.372426,0.374001,1.205896,0.391386,0.36251,0.355134,0.499703,0.0,0.0,0.0,1.248567,0.592851,1.0807,0.018344,0.018909,0.018809,0.866226,1.104244,0.159919,0.037502,0.443188,0.016075,0.0657,0.054931,0.023353,0.041617,0.023731,0.021614,0.008269,0.034517,0.496182,0.704224,0.504133,0.497502,0.640384,0.785017,0.864237,0.58209,0.338762,0.42185,0.671607,0.341721,0.461509,0.540412,1.144018,0.461592,0.334574,0.381397,0.892998,0.357335,0.761806,1.078396,0.345476,0.823993,0.661095,1.121381,0.602239,0.488515,0.527925,1.158941,0.007649,0.010471,0.007009,0.922273,1.013293,1.33017,0.481716,1.050467,1.078867,0.468442,0.924572,0.049934,0.052228,0.035626,0.037769,0.044307,0.044979,0.277825,2.917448,0.105201,0.06711,0.154272,0.466694,0.625393,0.284835,0.778201,0.838997,0.664502
1,3,0.379479,1.294095,0.394059,0.382732,1.192196,0.098225,0.094583,0.091628,0.111249,0.0,0.0,0.0,1.285118,0.560127,1.060944,0.028075,0.025028,0.027834,0.947017,1.014697,0.179145,0.031088,0.287668,0.011499,1.832381,0.1891,0.007992,0.039793,0.030197,0.015668,0.046815,0.034255,0.283378,1.255116,0.232805,0.234923,1.347639,0.246436,0.239752,0.320899,0.158238,0.219453,0.324097,0.158494,0.215474,0.566148,0.522321,0.113151,0.088486,0.098584,0.682904,0.247793,0.527434,0.710977,0.239279,0.512872,0.515164,0.96582,0.138445,0.194911,0.189102,1.02387,0.008172,0.006457,0.007584,0.898288,0.901833,1.113648,0.413655,0.880155,1.347925,0.488923,1.041937,0.595362,0.586754,0.60057,0.585238,0.515034,0.505045,0.120699,2.917452,0.675491,0.564255,0.541516,0.135108,0.168593,0.086479,0.311749,0.237999,0.278656
30,35,0.368429,1.315445,0.380018,0.372456,1.207547,0.300931,0.27927,0.268392,0.424123,0.0,0.0,0.0,1.74635,1.64293,1.924537,1.240215,2.797181,1.462692,0.645046,0.850038,0.187041,0.036792,0.215995,0.099992,0.187947,0.412294,0.00071,0.028581,0.038558,0.781908,1.150077,0.033572,0.399872,0.989754,0.443669,0.43608,1.037727,0.565218,0.503745,0.697545,2.16203,1.592653,0.89684,2.183945,1.719066,0.517022,0.825563,0.386682,0.143297,0.315556,0.575937,2.216933,0.492367,0.557471,3.293902,0.479677,0.466812,0.692728,0.616548,0.20076,0.449892,1.159076,0.091336,0.788286,0.293033,0.836345,0.990043,2.651567,1.774573,2.610107,1.21172,1.215865,1.611164,0.005317,0.009497,0.002956,0.00601,0.003761,0.006966,0.132562,2.917448,0.490988,0.418159,0.369293,0.460433,0.682304,0.261876,0.79281,0.875489,0.587438
20,8,0.366803,1.305047,0.373161,0.373377,1.207339,0.345173,0.32056,0.312583,0.429705,0.0,0.0,0.0,1.085688,0.983412,1.255934,0.019065,0.114987,0.034521,0.822989,0.93612,0.030043,0.023192,0.235707,0.063957,0.110094,0.025433,0.006066,0.040927,0.026424,0.026011,0.009143,0.030404,0.408084,0.645453,0.411963,0.407079,0.637025,0.406777,0.414079,0.589281,0.745262,0.593614,0.599974,0.751102,0.611651,0.605746,0.614519,0.288244,0.185046,0.274364,0.687152,0.543391,0.468639,0.763418,0.677616,0.526771,0.653644,0.704885,0.436386,0.260502,0.377188,1.159075,0.003498,0.053889,0.011555,0.800526,0.778659,1.281114,0.92226,1.392528,0.934608,0.762155,1.058861,0.005206,0.008397,0.003,0.00533,0.003645,0.006216,0.249102,2.917448,0.111734,0.102016,0.12941,0.340812,0.471439,0.211723,0.603979,0.670178,0.502358
0,9,0.368724,1.307987,0.380947,0.373457,1.198591,0.822228,0.675253,0.683119,1.073354,0.0,0.0,0.0,1.584345,0.97151,1.427729,0.188956,0.298865,0.205204,0.886238,1.03869,0.049459,0.019503,0.044905,0.016233,0.979733,0.691007,1.56722,0.219252,0.02958,0.003915,0.221085,0.026681,0.46367,0.689164,0.449658,0.449622,0.541795,0.596379,1.199952,0.696527,0.741529,0.653578,0.712196,0.759777,0.709646,0.650885,0.749807,0.566644,0.340907,0.394951,0.824924,0.903413,0.700793,0.835011,1.008457,0.696923,0.686809,0.832152,0.560789,0.497157,0.569479,1.191529,0.014033,0.075265,0.02844,0.983972,0.99374,1.423725,0.734771,1.211819,1.793511,0.939036,1.534664,0.005679,0.010285,0.006021,0.009367,0.004751,0.008917,0.134305,2.917452,0.08007,0.067849,0.10389,0.430048,0.473548,0.304959,0.733733,0.649962,0.712353
28,11,0.366538,1.315638,0.375679,0.37031,1.210124,0.595098,0.540608,0.529026,0.738489,0.0,0.0,0.0,1.508709,1.14456,1.519399,0.033649,0.182364,0.052724,0.738579,0.808844,0.067223,0.020215,0.221212,0.031294,0.211458,0.050934,0.035265,0.048208,0.02832,0.00307,0.039819,0.018038,1.063904,0.724799,1.074184,1.071542,0.690229,0.604389,0.613437,0.644701,0.790039,0.692409,0.735578,0.779065,0.691658,0.752884,0.799345,0.489058,0.318779,0.413488,0.82187,0.875039,0.693428,0.824611,1.17794,0.645733,0.821341,0.87532,0.593395,0.447188,0.546947,1.158941,0.005298,0.064455,0.017453,0.957344,0.973058,1.613313,1.023375,1.578,1.383459,0.96855,1.359761,0.006127,0.009192,0.008853,0.010965,0.006195,0.008614,0.1367,2.917448,0.12816,0.170098,0.137654,0.448097,0.623296,0.290347,0.743786,0.835004,0.662293


In [8]:
cols = lc.collect_schema().names()
cols = [c for c in cols if c not in KEYS]

df_u = lc.group_by(g_sym).agg([
    pl.col(c).n_unique().alias(f"{c}__nunq_in_streak") for c in cols
]).collect().to_pandas()


In [9]:
import pandas as pd
# 显示pandas 所有列
pd.set_option('display.max_columns', None)
df_u

Unnamed: 0,symbol_id,feature_00__nunq_in_streak,feature_01__nunq_in_streak,feature_02__nunq_in_streak,feature_03__nunq_in_streak,feature_04__nunq_in_streak,feature_05__nunq_in_streak,feature_06__nunq_in_streak,feature_07__nunq_in_streak,feature_08__nunq_in_streak,feature_09__nunq_in_streak,feature_10__nunq_in_streak,feature_11__nunq_in_streak,feature_12__nunq_in_streak,feature_13__nunq_in_streak,feature_14__nunq_in_streak,feature_15__nunq_in_streak,feature_16__nunq_in_streak,feature_17__nunq_in_streak,feature_18__nunq_in_streak,feature_19__nunq_in_streak,feature_20__nunq_in_streak,feature_21__nunq_in_streak,feature_22__nunq_in_streak,feature_23__nunq_in_streak,feature_24__nunq_in_streak,feature_25__nunq_in_streak,feature_26__nunq_in_streak,feature_27__nunq_in_streak,feature_28__nunq_in_streak,feature_29__nunq_in_streak,feature_30__nunq_in_streak,feature_31__nunq_in_streak,feature_32__nunq_in_streak,feature_33__nunq_in_streak,feature_34__nunq_in_streak,feature_35__nunq_in_streak,feature_36__nunq_in_streak,feature_37__nunq_in_streak,feature_38__nunq_in_streak,feature_39__nunq_in_streak,feature_40__nunq_in_streak,feature_41__nunq_in_streak,feature_42__nunq_in_streak,feature_43__nunq_in_streak,feature_44__nunq_in_streak,feature_45__nunq_in_streak,feature_46__nunq_in_streak,feature_47__nunq_in_streak,feature_48__nunq_in_streak,feature_49__nunq_in_streak,feature_50__nunq_in_streak,feature_51__nunq_in_streak,feature_52__nunq_in_streak,feature_53__nunq_in_streak,feature_54__nunq_in_streak,feature_55__nunq_in_streak,feature_56__nunq_in_streak,feature_57__nunq_in_streak,feature_58__nunq_in_streak,feature_59__nunq_in_streak,feature_60__nunq_in_streak,feature_61__nunq_in_streak,feature_62__nunq_in_streak,feature_63__nunq_in_streak,feature_64__nunq_in_streak,feature_65__nunq_in_streak,feature_66__nunq_in_streak,feature_67__nunq_in_streak,feature_68__nunq_in_streak,feature_69__nunq_in_streak,feature_70__nunq_in_streak,feature_71__nunq_in_streak,feature_72__nunq_in_streak,feature_73__nunq_in_streak,feature_74__nunq_in_streak,feature_75__nunq_in_streak,feature_76__nunq_in_streak,feature_77__nunq_in_streak,feature_78__nunq_in_streak,weight__nunq_in_streak,time_bucket__nunq_in_streak,responder_0__nunq_in_streak,responder_1__nunq_in_streak,responder_2__nunq_in_streak,responder_3__nunq_in_streak,responder_4__nunq_in_streak,responder_5__nunq_in_streak,responder_6__nunq_in_streak,responder_7__nunq_in_streak,responder_8__nunq_in_streak
0,0,29987,30003,29980,29988,29996,30001,30006,30007,30002,1,1,1,30002,29997,29999,29259,30002,29878,29999,30004,381,357,447,351,448,367,302,438,280,421,595,189,29707,29710,30003,30002,30001,30006,30003,27899,30005,29443,27896,30005,29444,29984,29990,30005,30006,30002,27897,30002,29449,27895,29990,29448,29998,29998,29706,30004,30006,654,29980,30001,29999,29985,29989,29987,29994,29995,29999,29993,30001,29668,29672,29953,29966,29954,29951,31,6,30004,30003,30004,29987,30001,29998,29960,29978,29966
1,9,29012,29031,29020,29015,29036,29037,29034,29035,29034,1,1,1,29030,29024,29033,28310,29034,28912,29037,29036,288,537,348,320,452,509,334,503,278,389,419,735,28748,28752,29034,29032,29035,29036,29034,26994,29033,28493,26994,29038,28493,29028,29033,29035,29037,29035,26996,29037,28494,26997,29027,28498,29032,29031,28751,29037,29038,604,29025,29025,29032,29026,29027,29031,29027,29031,29037,29015,29030,28716,28718,28992,28997,28995,29004,30,6,29034,29039,29032,29008,29025,29009,28941,28989,28929
2,15,29979,30003,29986,29986,30001,30003,30004,30002,30005,1,1,1,30000,30000,29997,29237,29981,29845,30005,29998,524,281,331,135,368,497,229,433,542,145,484,245,29700,29706,30002,30001,30003,30005,29997,27891,30004,29444,27893,29999,29447,30005,30001,30003,30007,30006,27896,30000,29446,27896,29998,29439,30006,30005,29712,30004,30004,654,29974,29977,29969,30001,29999,29997,29993,29996,29997,29995,29991,29704,29710,29996,30002,29999,30000,31,6,29953,29990,29969,30001,30002,30001,29969,29992,29958
3,3,29010,29034,29011,29022,29034,29034,29034,29034,29033,1,1,1,29028,29024,29025,28284,28992,28882,29028,29034,309,394,413,423,247,284,338,286,458,364,453,420,28718,28748,29008,29009,29033,29034,29038,26991,29032,28497,26993,29036,28490,29032,29026,29037,29040,29035,26990,29023,28494,26990,29034,28487,29037,29030,28749,29037,29036,640,28987,28987,28991,29028,29038,29026,29023,29020,29032,29030,29025,28752,28752,29032,29036,29033,29034,30,6,28891,28977,28938,29032,29035,29035,28976,29030,28984
4,12,29015,29035,29013,29017,29036,29036,29036,29037,29031,1,1,1,29036,29022,29032,28299,29029,28908,29034,29035,366,535,351,252,406,504,385,241,503,284,256,404,28749,28749,29031,29035,29033,29036,29036,26994,29026,28492,26996,29030,28498,29027,29020,29035,29037,29038,26997,29022,28496,26997,29026,28493,29036,29034,28753,29035,29036,655,29021,29035,29040,29025,29025,29030,29023,29031,29036,29021,29026,28701,28730,28980,28995,28999,29009,30,6,29003,29038,28986,29005,29009,29024,28888,28965,28877
5,33,29978,29999,29979,29983,30003,30003,30003,30000,30000,1,1,1,30000,29995,29996,29259,30003,29872,30004,30005,439,704,246,301,551,551,414,469,247,287,427,776,29709,29706,29999,30001,30001,29998,30003,27899,29997,29443,27898,30001,29441,29997,29998,30004,30006,30006,27896,29989,29446,27895,29999,29444,30003,30004,29709,30006,30006,586,29989,29997,29998,29995,29996,30000,29981,29999,30003,29996,29994,29655,29678,29956,29957,29963,29952,31,6,30001,30000,30007,29914,29767,29958,29858,29754,29866
6,36,29980,30003,29993,29983,29998,30002,30006,30003,29998,1,1,1,29994,29981,29997,29251,29992,29860,30000,29996,498,303,490,228,404,455,138,454,370,492,211,536,29702,29711,30000,30000,30004,29998,30003,27895,30003,29444,27893,30004,29442,29952,29947,30003,29999,30005,27896,29984,29440,27896,29992,29441,30000,30000,29709,30006,30005,654,30004,29999,30003,29952,29950,30001,29976,29985,29991,29978,29985,29659,29669,29936,29967,29964,29961,31,6,29998,30005,29994,29985,29992,29994,29922,29955,29944
7,30,29983,30006,29984,29983,30000,30003,30002,30004,29999,1,1,1,30003,29991,29993,29241,29999,29877,29998,30005,391,338,416,521,370,409,77,329,445,392,419,343,29704,29710,30000,29990,30006,30008,30006,27898,30003,29444,27898,30002,29444,29996,30003,30003,30002,30003,27894,30004,29443,27896,30003,29446,30001,29999,29710,30006,30004,654,29975,29988,29978,30003,30003,29997,29992,29999,30001,29989,29999,29669,29685,29979,29988,29978,29986,31,6,30006,30004,30005,29956,29920,29989,29886,29905,29913
8,27,29012,29035,29018,29015,29033,29030,29035,29036,29034,1,1,1,29025,29029,29022,28270,28998,28878,29035,29033,378,248,466,463,316,512,312,421,483,173,168,309,28747,28748,29030,29028,29034,29030,29032,26994,29029,28497,26990,29032,28494,29032,29030,29034,29039,29039,26998,29034,28492,26996,29033,28494,29035,29029,28749,29037,29038,639,28990,29009,29007,29035,29032,29021,29017,29027,29024,29020,29026,28731,28732,29018,29019,29021,29027,30,6,29017,29035,29016,29033,29024,29034,28990,29005,28998
9,7,29985,30004,29977,29986,30003,30005,30003,30000,29998,1,1,1,29993,29986,29993,29235,29990,29847,30004,29994,461,628,366,283,587,486,221,258,337,258,277,421,29711,29707,29997,29996,30004,30001,30003,27891,29999,29440,27897,30000,29444,29962,29964,30005,30008,30003,27895,29981,29441,27894,29992,29439,30002,30004,29709,30007,30006,654,30005,29999,30003,29961,29954,30002,29977,29990,30001,29988,29991,29661,29686,29944,29962,29949,29959,31,6,29998,30005,30004,29997,29998,30002,29955,29995,29931


In [7]:
lc.select(pl.col(["symbol_id", "date_id", "time_id", "feature_09", "feature_10", "feature_11"])).sort(["symbol_id", "date_id", "time_id"]).slice(0, 200).collect()

symbol_id,date_id,time_id,feature_09,feature_10,feature_11
i32,i32,i32,f32,f32,f32
0,830,0,11.0,7.0,76.0
0,830,1,11.0,7.0,76.0
0,830,2,11.0,7.0,76.0
0,830,3,11.0,7.0,76.0
0,830,4,11.0,7.0,76.0
…,…,…,…,…,…
0,830,195,11.0,7.0,76.0
0,830,196,11.0,7.0,76.0
0,830,197,11.0,7.0,76.0
0,830,198,11.0,7.0,76.0


In [9]:
lc.select(pl.col(["symbol_id", "date_id", "time_id", "feature_09", "feature_10", "feature_11"])).sort(["symbol_id", "time_id", "date_id"]).slice(0, 200).collect()

symbol_id,date_id,time_id,feature_09,feature_10,feature_11
i32,i32,i32,f32,f32,f32
0,830,0,11.0,7.0,76.0
0,831,0,11.0,7.0,76.0
0,832,0,11.0,7.0,76.0
0,833,0,11.0,7.0,76.0
0,834,0,11.0,7.0,76.0
…,…,…,…,…,…
0,839,6,11.0,7.0,76.0
0,840,6,11.0,7.0,76.0
0,841,6,11.0,7.0,76.0
0,842,6,11.0,7.0,76.0


In [14]:
df_univalue = lc.group_by("symbol_id").agg([
    pl.col("feature_09").n_unique().alias("feature_09__nunq_in_streak"),
    pl.col("feature_10").n_unique().alias("feature_10__nunq_in_streak"),
    pl.col("feature_11").n_unique().alias("feature_11__nunq_in_streak"),
]).sort("symbol_id").collect().to_pandas()

In [15]:
df_univalue

Unnamed: 0,symbol_id,feature_09__nunq_in_streak,feature_10__nunq_in_streak,feature_11__nunq_in_streak
0,0,1,1,1
1,1,1,1,1
2,2,1,1,1
3,3,1,1,1
4,5,1,1,1
5,7,1,1,1
6,8,1,1,1
7,9,1,1,1
8,10,1,1,1
9,11,1,1,1


In [None]:
test_cols = ["feature_09", "feature_10", "feature_11"]
by = ["symbol_id"]

audit_const = lf_data.select([
    *[pl.col(c).n_unique().over(by).alias(f"{c}__nunq_in_streak") for c in test_cols]
]).collect()


NameError: name 'lf_data' is not defined

In [4]:
# 环境与依赖

# 基础包
import tempfile

import os, gc, glob, json, yaml, time
from pathlib import Path
import numpy as np, pandas as pd, polars as pl
import lightgbm as lgb
from dataclasses import dataclass
import pyarrow.parquet as pq
from typing import Sequence, Optional, Union, List, Tuple, Iterable, Mapping

import matplotlib.pyplot as plt
# Azure & 文件系统
import fsspec
from getpass import getpass
from dotenv import load_dotenv
load_dotenv()  # 默认会加载当前目录下的 .env 文件


# 连接云空间

ACC = os.getenv("AZURE_STORAGE_ACCOUNT_NAME")
KEY = os.getenv("AZURE_STORAGE_ACCOUNT_KEY")
if not ACC or not KEY:
    raise RuntimeError("Azure credentials not found. Please set them in .env")
storage_options = {"account_name": ACC, "account_key": KEY}
fs = fsspec.filesystem("az", **storage_options)



# 定义路径辅助函数

# 读取配置（唯一来源）
cfg = yaml.safe_load(open("config/data.yaml"))

# 路径辅助函数
def P(kind: str, subpath: str = "") -> str:
    container  = str(cfg["blob"]["container"]).strip("/")
    prefix     = str(cfg["blob"]["prefix"]).strip("/")
    version    = str(cfg["exp_root"]).strip("/")
    local_root = Path(cfg["local"]["root"])

    sub = str(subpath).strip("/")  # 只做最小化处理；你也可以直接用 subpath

    if kind == "az":
        base = f"az://{container}" + (f"/{prefix}" if prefix else "") + f"/{version}"
        return f"{base}/{sub}" if sub else base
    if kind == "np":
        base = f"{container}" + (f"/{prefix}" if prefix else "") + f"/{version}"
        return f"{base}/{sub}" if sub else base
    if kind == "local":
        base = (local_root / version).as_posix()
        return f"{base}/{sub}" if sub else base
    raise ValueError("kind must be 'az', 'np', or 'local'")


# 全局变量
KEYS = cfg['keys']
WEIGHT = cfg['weight']
TIME_SORT = cfg['sorts']['time_major']

FEATURE_ALL = [f"feature_{i:02d}" for i in range(79)]
RESP_COLS   = [f"responder_{i}" for i in range(0, 9)]


# 读取数据


np_paths = fs.glob(f"az://jackson/js_exp/raw/train.parquet/partition_id=[0-9]/*.parquet")

paths=[]
for p in np_paths:
    paths.append("az://"+p)
lb = pl.scan_parquet(paths, storage_options=storage_options)
DATE_LO = int(1000)
DATE_HI = int(1698)
lb = lb.filter(pl.col("date_id").is_between(DATE_LO, DATE_HI, closed="both"))
lb = lb.sort(KEYS)

In [5]:
lb.select(pl.col("weight").min()).collect()

weight
f32
0.269919


In [2]:
import pandas as pd
import numpy as np
df_features = pd.read_csv("/mnt/data/js/exp/v1/models/tune/feature_importance__fixed__fixed__mm_full_train__features__fs__1300-1500__cv3-g7-r4__seed42__top1000__1760299442__range1000-1600__range1000-1600__cv2-g7-r4__1760347190.csv")

In [21]:
(df_features.iloc[500:])['mean_gain'].sum()

np.float64(0.20675190477962163)

In [1]:
with open("/mnt/data/js/exp/v1/models/tune/feature_importance__fixed__fixed__mm_full_train__features__fs__1300-1500__cv3-g7-r4__seed42__top1000__1760299442__range1000-1600__range1000-1600__cv2-g7-r4__1760347190.csv") as f:
    features = f.read().splitlines()
    
features

['feature,mean_gain',
 'feature_06,0.02820574331691257',
 'feature_36,0.017297077900059965',
 'time_pos,0.012457771899014179',
 'feature_04,0.00906096991481456',
 'feature_75__rstd14,0.008986778901779562',
 'feature_60__cs_z,0.00792897680901437',
 'feature_59,0.007916575192094214',
 'time_cos,0.007066777767288255',
 'responder_0_close_roll30_std,0.006706852191886811',
 'feature_59__rstd30,0.006268706288162722',
 'feature_07,0.005985360017569628',
 'feature_61__lag900,0.004928096683256502',
 'feature_60,0.004492180408317801',
 'feature_61__lag1936,0.004278408984028489',
 'responder_6_prevday_std,0.0042709471169388065',
 'responder_8_prev_tail_lag10,0.004034050832642013',
 'feature_61__ret50,0.0038953944471289033',
 'feature_61__lag6776,0.003744272957590952',
 'feature_25__diff50,0.0037141317539464667',
 'feature_76__rstd7,0.0035787409336275845',
 'feature_48,0.003577339686310038',
 'responder_5_prevday_std,0.003519433258303491',
 'feature_60__rstd30,0.003512531538114583',
 'responder_5_

In [2]:
t_path = "az://jackson/js_exp/exp/v1/panel_shards/panel_0815_0844.parquet"

lx = pl.scan_parquet(t_path, storage_options=storage_options)
names = lx.collect_schema().names()
print(f"feat number: {len(names)}")

feat number: 814


In [None]:
# 统计行数
lb.select(pl.count()).collect()

# 数据预处理

## times + clipping

In [None]:
# Add time features
# 添加时间特征 bucket，将一天T ticks分为B部分
B = cfg['trading']['bucket_size']
T = cfg['trading']['ticks']

def clip_upper(expr: pl.Expr, ub: int) -> pl.Expr:
    return pl.when(expr > pl.lit(ub)).then(pl.lit(ub)).otherwise(expr)
lb = lb.with_columns(
    bucket_raw = pl.col('time_id') * pl.lit(B) // pl.lit(T) # 这里T 我们假设为全局常数，不分组计算
).with_columns(
    time_bucket = clip_upper(pl.col('bucket_raw'), B - 1).cast(pl.UInt8)
).drop(pl.col('bucket_raw'))



# Clipping

def rolling_sigma_clip(
    lf: pl.LazyFrame,
    clip_features: Sequence[str],
    over_cols: Sequence[str],
    *,
    is_sorted: bool = False,
    window: int = 50,
    k: float = 3.0,
    ddof: int = 1,
    min_valid: int = 10,
    cast_float32: bool = True,
    sanitize: bool = True,
) -> pl.LazyFrame:
    if not is_sorted:
        raise ValueError("Input LazyFrame must be pre-sorted by ['symbol_id','date_id','time_id']")

    required = {"symbol_id","date_id","time_id","time_bucket"} | set(clip_features)
    names = set(lf.collect_schema().names())
    missing = list(required - names)
    if missing:
        raise KeyError(f"Missing columns: {missing}")


    base = lf.select(pl.col(["symbol_id","date_id","time_id","time_bucket"] + list(clip_features)))
    min_need = max(min_valid, ddof + 1)
    min_samp = ddof + 1

    exprs = []
    for c in clip_features:
        x = pl.col(c)
        if cast_float32:
            x = x.cast(pl.Float32)
        if sanitize:
            x = pl.when(x.is_finite()).then(x).otherwise(None)

        # 注意：这里不要 over
        xlag = x.shift(1)

        # 只在 rolling 结果上 over（组内历史）
        cnt = (
            xlag.is_not_null()
                .cast(pl.Int32)
                .rolling_sum(window_size=window, min_samples=ddof + 1)
        ).over(over_cols)

        mu = (
            xlag.rolling_mean(window_size=window, min_samples=ddof + 1)
        ).over(over_cols)

        sd = (
            xlag.rolling_std(window_size=window, ddof=ddof, min_samples=ddof + 1)
        ).over(over_cols)

        lo, hi = mu - k * sd, mu + k * sd
        exprs.append(
            pl.when(cnt >= max(min_valid, ddof + 1))
            .then(x.clip(lo, hi))
            .otherwise(x)
            .alias(c)
        )

    return base.with_columns(exprs)


lb = lb.sort(KEYS)

lf_clip = rolling_sigma_clip(
    lf=lb,
    clip_features=FEATURE_ALL,
    over_cols=cfg['winsorization']['groupby'],
    is_sorted=True,
    window=cfg['winsorization']['window'],
    k=cfg['winsorization']['z_k'],
    ddof=cfg['winsorization']['ddof'],
    min_valid=cfg['winsorization']['min_valid'],
    cast_float32=cfg['winsorization']['cast_float32'],
    sanitize=cfg['winsorization'].get('sanitize', True)
)

lf_clip.collect_schema().names()

In [None]:

from pathlib import Path

clip_out = Path(P("local", cfg["paths"]["cache"])) / "sample_clipped.parquet"
clip_out.parent.mkdir(parents=True, exist_ok=True)

df = lf_clip.collect()  # 非流式；会把计划完整执行后落到内存
df.write_parquet(str(clip_out), compression="zstd")  # 可加 use_pyarrow=True

In [None]:
# 统计lf_clip行数
lf_clip.select(pl.count()).collect()




In [None]:
# Imputing
def causal_impute(
    lf: pl.LazyFrame,
    impute_cols: Sequence[str],
    *,
    open_tick_window: Tuple[int, int] = (0, 10),
    ttl_days_open: int = 5,
    intra_ffill_max_gap_ticks: Optional[int] = 100,
    ttl_days_same_tick: Optional[int] = 5,
    is_sorted: bool = False,
) -> pl.LazyFrame:
    if not is_sorted:
        raise ValueError("Input LazyFrame must be pre-sorted by ['symbol_id','date_id','time_id']")

    # 参数合法性
    assert intra_ffill_max_gap_ticks is None or intra_ffill_max_gap_ticks >= 0
    assert ttl_days_same_tick is None or ttl_days_same_tick >= 0

    # 统一 dtype（可选，但更稳）
    lf = lf.with_columns([pl.col(c).cast(pl.Float32) for c in impute_cols])
    
    
    t0, t1 = open_tick_window
    is_open = pl.col("time_id").is_between(t0, t1, closed="left")  # [t0, t1)

    # ---- 1) 开盘：跨日承接（TTL）----
    open_exprs = []
    for c in impute_cols:
        last_date = (
            pl.when(pl.col(c).is_not_null()).then(pl.col("date_id"))
            .forward_fill().over("symbol_id")
        )
        cand = pl.col(c).forward_fill().over("symbol_id")
        gap  = (pl.col("date_id") - last_date).cast(pl.Int32)
        open_exprs.append(
            pl.when(is_open 
                    & pl.col(c).is_null() 
                    & (gap.fill_null(ttl_days_open + 1) <= ttl_days_open))
            .then(cand)
            .otherwise(pl.col(c))
            .alias(c)
        )
    lf1 = lf.with_columns(open_exprs)

    # ---- 2) 日内 ffill（(symbol,date)），可限步数 ----
    if intra_ffill_max_gap_ticks is None:
        lf2 = lf1.with_columns([pl.col(c).forward_fill().over(["symbol_id","date_id"]).alias(c) for c in impute_cols])
    else:
        k = intra_ffill_max_gap_ticks
        exprs = []
        for c in impute_cols:
            last_t = (
                pl.when(pl.col(c).is_not_null()).then(pl.col("time_id"))
                .forward_fill().over(["symbol_id","date_id"])
            )
            cand = pl.col(c).forward_fill().over(["symbol_id","date_id"])
            gap  = (pl.col("time_id") - last_t).cast(pl.Int32)
            exprs.append(
                pl.when(pl.col(c).is_null() & (gap.fill_null(k + 1) <= k))
                .then(cand)
                .otherwise(pl.col(c))
                .alias(c)
            )
        lf2 = lf1.with_columns(exprs)

    # ---- 3) 同一 time_id 跨日承接（TTL，可选）----
    lf3 = lf2
    if ttl_days_same_tick is not None:
        d = ttl_days_same_tick
        exprs = []
        for c in impute_cols:
            last_date_same = (
                pl.when(pl.col(c).is_not_null()).then(pl.col("date_id"))
                .forward_fill().over(["symbol_id","time_id"])
            )
            cand_same = pl.col(c).forward_fill().over(["symbol_id","time_id"])
            gap2 = (pl.col("date_id") - last_date_same).cast(pl.Int32)
            exprs.append(
                pl.when(pl.col(c).is_null() & (gap2.fill_null(d + 1) <= d))
                .then(cand_same)
                .otherwise(pl.col(c))
                .alias(c)
            )
        lf3 = lf2.with_columns(exprs)

    # ---- 4) 再日内 ffill 传播（与步骤2同逻辑）----
    if intra_ffill_max_gap_ticks is None:
        lf4 = lf3.with_columns([pl.col(c).forward_fill().over(["symbol_id","date_id"]).alias(c) for c in impute_cols])
    else:
        k = intra_ffill_max_gap_ticks
        exprs = []
        for c in impute_cols:
            last_t = (
                pl.when(pl.col(c).is_not_null()).then(pl.col("time_id"))
                .forward_fill().over(["symbol_id","date_id"])
            )
            cand = pl.col(c).forward_fill().over(["symbol_id","date_id"])
            gap  = (pl.col("time_id") - last_t).cast(pl.Int32)
            exprs.append(
                pl.when(pl.col(c).is_null() & (gap.fill_null(k + 1) <= k))
                .then(cand)
                .otherwise(pl.col(c))
                .alias(c)
            )
        lf4 = lf3.with_columns(exprs)

    KEYS = ["symbol_id","date_id","time_id"]
    return lf4.select([*KEYS, *impute_cols])


clip_path = Path(P("local", cfg["paths"]["cache"])) / "sample_clipped.parquet"
lf_clip = pl.scan_parquet(str(clip_path))

lf_clip = lf_clip.sort(KEYS)
lf_imp = causal_impute(
    lf=lf_clip,
    impute_cols=FEATURE_ALL,
    open_tick_window=cfg['fill']['open_tick_window'],
    ttl_days_open=cfg['fill']['ttl_days_open'],
    intra_ffill_max_gap_ticks=cfg['fill']['intra_ffill_max_gap_ticks'],
    ttl_days_same_tick=cfg['fill']['ttl_days_same_tick'],
    is_sorted=True
)

assert lf_imp.select(pl.len()).collect().item() == lb.select(pl.len()).collect().item()
assert lf_imp.group_by(["symbol_id","date_id","time_id"]).len().filter(pl.col("len")>1).collect().height == 0

lf_imp = lf_imp.with_columns([pl.col(c).fill_null(0.0).alias(c) for c in FEATURE_ALL])


imp_path = Path(P("local", cfg["paths"]["cache"])) / "sample_imputed.parquet"

lf_imp.collect(streaming=True).write_parquet(str(imp_path), compression="zstd", use_pyarrow=True)  # 可加 use_pyarrow=True

In [None]:
# 统计lf_imp行数
lf_imp.select(pl.count()).collect()

In [None]:
# Merging

batch_size = cfg['fill']['batch_size']

# 右表：
rhs = (
    lb.select([*KEYS, WEIGHT, 'time_bucket', *RESP_COLS])
    .with_columns([pl.col(k).cast(pl.Int32) for k in KEYS]).sort(TIME_SORT)
)
print("Right table schema:", rhs.collect_schema())
print("row count:", rhs.select(pl.count()).collect())

# 左表
imp_path = Path(P("local", cfg["paths"]["cache"])) / "sample_imputed.parquet"
lf_imp = pl.scan_parquet(str(imp_path)).with_columns([pl.col(k).cast(pl.Int32) for k in KEYS])
lf_imp = lf_imp.sort(TIME_SORT)

print("Left table schema:", lf_imp.collect_schema())
print("row count:", lf_imp.select(pl.count()).collect())


dmin, dmax = (
    lf_imp.select(
        pl.col('date_id').min().alias('dmin'),
        pl.col('date_id').max().alias('dmax')
        )
    .collect()
    .row(0)
)
print(f"Date range: {dmin} to {dmax}, total {dmax - dmin + 1} days")

path = P('az', cfg['paths']['clean_shards'])
fs.makedirs(path, exist_ok=True)
print(f"Processing date range: {dmin} to {dmax}")

for lo in range(dmin, dmax + 1, batch_size):
    hi = min(lo + batch_size, dmax + 1)

    left = (
        lf_imp
        .filter(pl.col('date_id').is_between(lo, hi, closed='left'))
    )
    
    right = rhs.filter(pl.col('date_id').is_between(lo, hi, closed='left'))

    part = (left.join(right, on=TIME_SORT, how='left')).sort(TIME_SORT)

    # 命名时注意 hi 是排他的，所以文件名用 hi-1
    out_lo, out_hi = lo, hi - 1
    (
        part.sink_parquet(
            f"{path}/clean_{out_lo:04d}_{out_hi:04d}.parquet",
            compression="zstd",
            statistics=True,                 # 写入页/列统计划出更快
            storage_options=storage_options,
        )
    )


In [None]:
test_root = P("az", cfg["paths"]["clean_shards"])
test_path = f"{test_root}/clean_1000_1029.parquet"

lx = pl.scan_parquet(str(test_path), storage_options=storage_options)
result = lx.select(pl.col("weight").is_null().sum()).collect()
print(result)

In [None]:
lb.limit().collect()  # 试运行一下，看看数据长啥样

In [None]:
days = lb.select(pl.col("date_id").unique().sort()).collect(streaming=True)["date_id"].to_list()

In [None]:
np.diff(days)

# EDA

In [None]:
# EDA：每个交易日的symbol覆盖情况,是否覆盖全程
lf_s_d = lb.select(['date_id', 'symbol_id'])

per_date = (
    lf_s_d.group_by("date_id")
      .agg(pl.col("symbol_id").n_unique().alias("n_symbols"))
      .sort("date_id")
)

max_n = per_date.select(pl.max("n_symbols")).collect().item()
summary = per_date.with_columns([
    pl.lit(max_n).alias("max_n"),
    (pl.col("n_symbols") == max_n).alias("is_full_universe")
])

dates_missing = summary.filter(pl.col("is_full_universe") == False).select("date_id")
# summary.collect(); dates_missing.collect()

dates_missing.collect()

In [None]:
# 先选一个横向，纵向都比较小的样本,按照data_id来选一小块快速试验

ls = lb.filter(pl.col('symbol_id').is_in([1,2,3,4,5]) & pl.col('date_id').is_in([1400,1420]))


# 数据预处理

添加time bucket, 将日分片

In [None]:
# 添加时间特征 bucket，将一天T ticks分为B部分
B = cfg['trading']['bucket_size']
T = cfg['trading']['ticks']

def clip_upper(expr: pl.Expr, ub: int) -> pl.Expr:
    return pl.when(expr > pl.lit(ub)).then(pl.lit(ub)).otherwise(expr)
lb = lb.with_columns(
    bucket_raw = pl.col('time_id') * pl.lit(B) // pl.lit(T) # 这里T 我们假设为全局常数，不分组计算
).with_columns(
    time_bucket = clip_upper(pl.col('bucket_raw'), B - 1).cast(pl.UInt8)
).drop(pl.col('bucket_raw'))


Clip

In [None]:
def rolling_sigma_clip(
    lf: pl.LazyFrame,
    clip_features: Sequence[str],
    over_cols: Sequence[str],
    *,
    is_sorted: bool = False,
    window: int = 50,
    k: float = 3.0,
    ddof: int = 1,
    min_valid: int = 10,
    cast_float32: bool = True,
    sanitize: bool = True,
) -> pl.LazyFrame:
    if not is_sorted:
        raise ValueError("Input LazyFrame must be pre-sorted by ['symbol_id','date_id','time_id']")

    required = {"symbol_id","date_id","time_id","time_bucket"} | set(clip_features)
    names = set(lf.collect_schema().names())
    missing = list(required - names)
    if missing:
        raise KeyError(f"Missing columns: {missing}")


    base = lf.select(pl.col(["symbol_id","date_id","time_id","time_bucket"] + list(clip_features)))
    min_need = max(min_valid, ddof + 1)
    min_samp = ddof + 1

    exprs = []
    for c in clip_features:
        x = pl.col(c)
        if cast_float32:
            x = x.cast(pl.Float32)
        if sanitize:
            x = pl.when(x.is_finite()).then(x).otherwise(None)

        # 注意：这里不要 over
        xlag = x.shift(1)

        # 只在 rolling 结果上 over（组内历史）
        cnt = (
            xlag.is_not_null()
                .cast(pl.Int32)
                .rolling_sum(window_size=window, min_samples=ddof + 1)
        ).over(over_cols)

        mu = (
            xlag.rolling_mean(window_size=window, min_samples=ddof + 1)
        ).over(over_cols)

        sd = (
            xlag.rolling_std(window_size=window, ddof=ddof, min_samples=ddof + 1)
        ).over(over_cols)

        lo, hi = mu - k * sd, mu + k * sd
        exprs.append(
            pl.when(cnt >= max(min_valid, ddof + 1))
            .then(x.clip(lo, hi))
            .otherwise(x)
            .alias(c)
        )

    return base.with_columns(exprs)


lb = lb.sort(KEYS)

lf_clip = rolling_sigma_clip(
    lf=lb,
    clip_features=FEATURE_ALL,
    over_cols=cfg['winsorization']['groupby'],
    is_sorted=True,
    window=cfg['winsorization']['window'],
    k=cfg['winsorization']['z_k'],
    ddof=cfg['winsorization']['ddof'],
    min_valid=cfg['winsorization']['min_valid'],
    cast_float32=cfg['winsorization']['cast_float32'],
    sanitize=cfg['winsorization'].get('sanitize', True)
)

In [None]:
from pathlib import Path

clip_out = Path(P("local", cfg["paths"]["cache"])) / "sample_clipped.parquet"
clip_out.parent.mkdir(parents=True, exist_ok=True)

df = lf_clip.collect()  # 非流式；会把计划完整执行后落到内存
df.write_parquet(str(clip_out), compression="zstd")  # 可加 use_pyarrow=True

Impute

In [None]:
def causal_impute(
    lf: pl.LazyFrame,
    impute_cols: Sequence[str],
    *,
    open_tick_window: Tuple[int, int] = (0, 10),
    ttl_days_open: int = 5,
    intra_ffill_max_gap_ticks: Optional[int] = 100,
    ttl_days_same_tick: Optional[int] = 5,
    is_sorted: bool = False,
) -> pl.LazyFrame:
    if not is_sorted:
        raise ValueError("Input LazyFrame must be pre-sorted by ['symbol_id','date_id','time_id']")

    # 参数合法性
    assert intra_ffill_max_gap_ticks is None or intra_ffill_max_gap_ticks >= 0
    assert ttl_days_same_tick is None or ttl_days_same_tick >= 0

    # 统一 dtype（可选，但更稳）
    lf = lf.with_columns([pl.col(c).cast(pl.Float32) for c in impute_cols])
    
    
    t0, t1 = open_tick_window
    is_open = pl.col("time_id").is_between(t0, t1, closed="left")  # [t0, t1)

    # ---- 1) 开盘：跨日承接（TTL）----
    open_exprs = []
    for c in impute_cols:
        last_date = (
            pl.when(pl.col(c).is_not_null()).then(pl.col("date_id"))
            .forward_fill().over("symbol_id")
        )
        cand = pl.col(c).forward_fill().over("symbol_id")
        gap  = (pl.col("date_id") - last_date).cast(pl.Int32)
        open_exprs.append(
            pl.when(is_open 
                    & pl.col(c).is_null() 
                    & (gap.fill_null(ttl_days_open + 1) <= ttl_days_open))
            .then(cand)
            .otherwise(pl.col(c))
            .alias(c)
        )
    lf1 = lf.with_columns(open_exprs)

    # ---- 2) 日内 ffill（(symbol,date)），可限步数 ----
    if intra_ffill_max_gap_ticks is None:
        lf2 = lf1.with_columns([pl.col(c).forward_fill().over(["symbol_id","date_id"]).alias(c) for c in impute_cols])
    else:
        k = intra_ffill_max_gap_ticks
        exprs = []
        for c in impute_cols:
            last_t = (
                pl.when(pl.col(c).is_not_null()).then(pl.col("time_id"))
                .forward_fill().over(["symbol_id","date_id"])
            )
            cand = pl.col(c).forward_fill().over(["symbol_id","date_id"])
            gap  = (pl.col("time_id") - last_t).cast(pl.Int32)
            exprs.append(
                pl.when(pl.col(c).is_null() & (gap.fill_null(k + 1) <= k))
                .then(cand)
                .otherwise(pl.col(c))
                .alias(c)
            )
        lf2 = lf1.with_columns(exprs)

    # ---- 3) 同一 time_id 跨日承接（TTL，可选）----
    lf3 = lf2
    if ttl_days_same_tick is not None:
        d = ttl_days_same_tick
        exprs = []
        for c in impute_cols:
            last_date_same = (
                pl.when(pl.col(c).is_not_null()).then(pl.col("date_id"))
                .forward_fill().over(["symbol_id","time_id"])
            )
            cand_same = pl.col(c).forward_fill().over(["symbol_id","time_id"])
            gap2 = (pl.col("date_id") - last_date_same).cast(pl.Int32)
            exprs.append(
                pl.when(pl.col(c).is_null() & (gap2.fill_null(d + 1) <= d))
                .then(cand_same)
                .otherwise(pl.col(c))
                .alias(c)
            )
        lf3 = lf2.with_columns(exprs)

    # ---- 4) 再日内 ffill 传播（与步骤2同逻辑）----
    if intra_ffill_max_gap_ticks is None:
        lf4 = lf3.with_columns([pl.col(c).forward_fill().over(["symbol_id","date_id"]).alias(c) for c in impute_cols])
    else:
        k = intra_ffill_max_gap_ticks
        exprs = []
        for c in impute_cols:
            last_t = (
                pl.when(pl.col(c).is_not_null()).then(pl.col("time_id"))
                .forward_fill().over(["symbol_id","date_id"])
            )
            cand = pl.col(c).forward_fill().over(["symbol_id","date_id"])
            gap  = (pl.col("time_id") - last_t).cast(pl.Int32)
            exprs.append(
                pl.when(pl.col(c).is_null() & (gap.fill_null(k + 1) <= k))
                .then(cand)
                .otherwise(pl.col(c))
                .alias(c)
            )
        lf4 = lf3.with_columns(exprs)

    KEYS = ["symbol_id","date_id","time_id"]
    return lf4.select([*KEYS, *impute_cols])

In [None]:
clip_path = Path(P("local", cfg["paths"]["cache"])) / "sample_clipped.parquet"
lf_clip = pl.scan_parquet(str(clip_path))

lf_clip = lf_clip.sort(KEYS)
lf_imp = causal_impute(
    lf=lf_clip,
    impute_cols=FEATURE_ALL,
    open_tick_window=cfg['fill']['open_tick_window'],
    ttl_days_open=cfg['fill']['ttl_days_open'],
    intra_ffill_max_gap_ticks=cfg['fill']['intra_ffill_max_gap_ticks'],
    ttl_days_same_tick=cfg['fill']['ttl_days_same_tick'],
    is_sorted=True
)

In [None]:
# 查看缺失情况

pre_null = lf_clip.select([pl.col(c).is_null().mean().alias(c) for c in FEATURE_ALL]).collect()
post_null = lf_imp.select([pl.col(c).is_null().mean().alias(c) for c in FEATURE_ALL]).collect()
# melt 后拼起来看变化
pre_m = pre_null.melt(variable_name="feature", value_name="pre_null")
post_m = post_null.melt(variable_name="feature", value_name="post_null")
summary = pre_m.join(post_m, on="feature").with_columns(
    (pl.col("pre_null") - pl.col("post_null")).alias("filled_delta")
).sort("post_null", descending=True)
summary  # post_null 高的基本就是那 ~20 列

In [None]:
# 按日期看缺失是否“早期更高”（排除冷启动特征）
cols = ["feature_21","feature_26"]  # 换成你缺失最高的几列
by_date = (
    lf_imp.group_by("date_id")
          .agg([pl.col(c).is_null().mean().alias(c) for c in cols])
          .sort("date_id")
          .collect()
)
by_date  # 冷启动型会在较早的 date_id 更高，然后趋于稳定


In [None]:
# 这天到底有多少 symbol 是“整天全空”？
c = "feature_21"

daily_all_null_share = (
    lf_imp
    .group_by(["date_id","symbol_id"])
    .agg(pl.col(c).is_null().mean().alias("null_rate"))
    .with_columns((pl.col("null_rate") == 1).alias("all_null_day"))
    .group_by("date_id")
    .agg(pl.mean("all_null_day").alias("share_symbols_all_null"))
    .sort("date_id")
    .collect()
)
daily_all_null_share.filter(pl.col("share_symbols_all_null") != 0)

#找到原因， 是在中间的一些日期， 有1-2个symbol_id的一些特征列缺失 1/32, 1/33. 1/34 ~0.029-0.032

In [None]:
assert lf_imp.select(pl.len()).collect().item() == lb.select(pl.len()).collect().item()
assert lf_imp.group_by(["symbol_id","date_id","time_id"]).len().filter(pl.col("len")>1).collect().height == 0

lf_imp = lf_imp.with_columns([pl.col(c).fill_null(0.0).alias(c) for c in FEATURE_ALL])

# 再次查看缺失情况
post_null2 = lf_imp.select([pl.col(c).is_null().mean().alias(c) for c in FEATURE_ALL]).collect()
post_null2  # 应该全0了

In [None]:
imp_path = Path(P("local", cfg["paths"]["cache"])) / "sample_imputed.parquet"

lf_imp.collect(streaming=True).write_parquet(str(imp_path), compression="zstd", use_pyarrow=True)  # 可加 use_pyarrow=True


# 合并响应变量


batch_size = cfg['fill']['batch_size']

# 右表：去重 + 对齐类型
rhs = (
    lb.select([*KEYS, WEIGHT, 'time_bucket', *RESP_COLS])
    .with_columns([pl.col(k).cast(pl.Int32) for k in KEYS])
)


# 左表
imp_path = Path(P("local", cfg["paths"]["cache"])) / "sample_imputed.parquet"
lf_imp = pl.scan_parquet(str(imp_path)).with_columns([pl.col(k).cast(pl.Int32) for k in KEYS])

dmin, dmax = (
    lf_imp.select(
        pl.col('date_id').min().alias('dmin'),
        pl.col('date_id').max().alias('dmax')
        )
    .collect()
    .row(0)
)

path = P('az', cfg['paths']['clean_shards'])
fs.makedirs(path, exist_ok=True)
print(f"Processing date range: {dmin} to {dmax}")


for lo in range(dmin, dmax + 1, batch_size):
    hi = min(lo + batch_size, dmax + 1)

    left = (
        lf_imp
        .filter(pl.col('date_id').is_between(lo, hi, closed='left'))
    )
    right = rhs.filter(pl.col('date_id').is_between(lo, hi, closed='left')).sort(TIME_SORT).unique(subset=KEYS, keep='last')

    part = (left.join(right, on=KEYS, how='left')).sort(TIME_SORT)
    
    feature_cols = [c for c in part.collect_schema().names() if c not in set([*KEYS, WEIGHT, 'time_bucket', *RESP_COLS])]
    part = part.select([*KEYS, WEIGHT, 'time_bucket', *feature_cols,  *RESP_COLS])


    # 命名时注意 hi 是排他的，所以文件名用 hi-1
    out_lo, out_hi = lo, hi - 1
    (
        part.sink_parquet(
            f"{path}/clean_{out_lo:04d}_{out_hi:04d}.parquet",
            compression="zstd",
            statistics=True,                 # 写入页/列统计划出更快
            storage_options=storage_options,
        )
    )


In [None]:
clean_root = P('az', cfg['paths']['clean_shards'])
clean_paths = fs.glob(f"{clean_root}/clean_*.parquet")
print(f"Found {len(clean_paths)} clean shards")


# 特征工程函数

In [None]:
# 特征工程

# A：响应列的“上一日尾部/日度摘要”
def fe_resp_daily(
    lf: pl.LazyFrame,
    *,
    keys: Tuple[str, str, str] = ("symbol_id","date_id","time_id"),
    rep_cols: Sequence[str],
    is_sorted: bool = False,
    prev_soft_days: Optional[int] = None,
    cast_f32: bool = True,
    tail_lags: Sequence[int] = (1,),
    tail_diffs: Sequence[int] = (1,),
    rolling_windows: Sequence[int] | None = (3,),
) -> pl.LazyFrame:
    """一次日频聚合得到昨日尾部与日级摘要 → 统一 TTL 到“对 d 生效的历史值” → 回拼到 tick 级。"""
    g_symbol, g_date, g_time = keys

    # 若未保证排序，这里补一次（只影响 lf；日频表仍会再按 (symbol,date) 排）
    if not is_sorted:
        lf = lf.sort([g_symbol, g_date, g_time])

    # --- 一次性日频聚合 ---
    need_L = sorted(set(tail_lags) | {k+1 for k in tail_diffs} | {1})
    agg_exprs: list[pl.Expr] = []
    for r in rep_cols:
        # 尾部倒数第 L（长度不足 L → null）
        for L in need_L:
            agg_exprs.append(
                pl.when(pl.len() >= L)
                .then(pl.col(r).sort_by(pl.col(g_time)).tail(L).first())
                .otherwise(None)
                .alias(f"{r}_prev_tail_lag{L}")
            )
        # 当日统计（显式补上 prevday_close）
        agg_exprs += [
            pl.col(r).sort_by(pl.col(g_time)).last().alias(f"{r}_prevday_close"),
            pl.col(r).mean().alias(f"{r}_prevday_mean"),
            pl.col(r).std(ddof=1).alias(f"{r}_prevday_std"),
        ]

    daily = (
        lf.group_by([g_symbol, g_date])
        .agg(agg_exprs)
        .sort([g_symbol, g_date])                # 供下面 shift/ffill 正确运行
    )

    # 派生（当日）dK：last - (K+1 from end)
    daily = daily.with_columns([
        (pl.col(f"{r}_prev_tail_lag1") - pl.col(f"{r}_prev_tail_lag{K+1}")).alias(f"{r}_prev_tail_d{K}")
        for r in rep_cols for K in tail_diffs
        if f"{r}_prev_tail_lag{K+1}" in daily.collect_schema().names()
    ])

    # prev2day/overnight/rolling（仍是“当日相对”的量）
    daily = daily.with_columns([
        pl.col(f"{r}_prevday_close").shift(1).over(g_symbol).alias(f"{r}_prev2day_close")
        for r in rep_cols
    ]).with_columns(
        [
            (pl.col(f"{r}_prevday_close") - pl.col(f"{r}_prevday_mean")).alias(f"{r}_prevday_close_minus_mean")
            for r in rep_cols
        ] + [
            (pl.col(f"{r}_prevday_close") - pl.col(f"{r}_prev2day_close")).alias(f"{r}_overnight_gap")
            for r in rep_cols
        ]
    )

    if rolling_windows:
        wins = sorted({int(w) for w in rolling_windows if int(w) > 1})
        roll_exprs: list[pl.Expr] = []
        for r in rep_cols:
            base = pl.col(f"{r}_prevday_close")
            for w in wins:
                roll_exprs += [
                    base.rolling_mean(window_size=w, min_samples=1).over(g_symbol)
                        .alias(f"{r}_close_roll{w}_mean"),
                    base.rolling_std(window_size=w, ddof=1, min_samples=2).over(g_symbol)
                        .alias(f"{r}_close_roll{w}_std"),
                ]
        daily = daily.with_columns(roll_exprs)

    # === 核心：将上面的“当日统计/尾部衍生列”转换为“对 d 生效的历史 TTL 值” ===
    prev_cols = [c for c in daily.collect_schema().names() if c not in (g_symbol, g_date)]
    exprs: list[pl.Expr] = []
    for c in prev_cols:
        # 最近一次（发生在当前日之前）的非空日期与值
        last_non_null_day = (
            pl.when(pl.col(c).is_not_null()).then(pl.col(g_date)).otherwise(None)
            .forward_fill().over(g_symbol)
            .shift(1)
        )
        last_non_null_val = pl.col(c).forward_fill().over(g_symbol).shift(1)

        if prev_soft_days is None:
            resolved = last_non_null_val  # 无限 TTL：总取最近一次历史非空
        else:
            gap_days = (pl.col(g_date) - last_non_null_day).cast(pl.Int32)
            resolved = pl.when(gap_days.is_not_null() & (gap_days <= int(prev_soft_days))) \
                        .then(last_non_null_val) \
                        .otherwise(None)

        if cast_f32:
            resolved = resolved.cast(pl.Float32)
        exprs.append(resolved.alias(c))    # 列名不变，语义已是“对 d 生效的历史值”

    daily_prev = daily.with_columns(exprs)

    # 回拼到 tick 级（左连），并固定顺序（可选）
    out = lf.join(daily_prev, on=[g_symbol, g_date], how="left")
    out = out.sort([g_symbol, g_date, g_time])
    return out



# B：同 time_id 跨日的 prev{k} + 统计
def fe_resp_same_tick_xday(
    lf: pl.LazyFrame,
    *,
    keys: Tuple[str,str,str] = ("symbol_id","date_id","time_id"),
    rep_cols: Sequence[str],
    is_sorted: bool = False,
    prev_soft_days: Optional[int] = None,   # None=严格d-k；整数=TTL
    cast_f32: bool = True,
    ndays: int = 5,
    stats_rep_cols: Optional[Sequence[str]] = None,
    add_prev1_multirep: bool = True,
    batch_size: int = 5,
) -> pl.LazyFrame:
    
    g_symbol, g_date, g_time = keys

    # 保证 (symbol,time) 组内按 date 递增（shift(k).over([symbol,time]) 的因果顺序）
    if not is_sorted:
        lf = lf.sort([g_symbol, g_time, g_date]) # 注意不是date, time

    if stats_rep_cols is None:
        stats_rep_cols = list(rep_cols)

    def _chunks(lst, k):
        for i in range(0, len(lst), k):
            yield lst[i:i+k]

    lf_cur = lf

    # 1) prev{k} with strict / TTL
    for batch in _chunks(list(rep_cols), batch_size):
        exprs = []
        for r in batch:
            for k in range(1, ndays + 1):
                val_k  = pl.col(r).shift(k).over([g_symbol, g_time])
                day_k  = pl.col(g_date).shift(k).over([g_symbol, g_time])
                gap_k  = (pl.col(g_date) - day_k).cast(pl.Int32)

                if prev_soft_days is None:
                    # 严格 d-k：gap==k
                    keep = gap_k.is_not_null() & (gap_k == k)
                else:
                    # TTL：只要在当前日之前，且 gap<=K
                    keep = gap_k.is_not_null() & (gap_k > 0) & (gap_k <= int(prev_soft_days))

                val_k = pl.when(keep).then(val_k).otherwise(None)
                if cast_f32:
                    val_k = val_k.cast(pl.Float32)
                exprs.append(val_k.alias(f"{r}_same_t_prev{k}"))
        lf_cur = lf_cur.with_columns(exprs)

    # 2) mean/std（忽略 null）
    for batch in _chunks([r for r in stats_rep_cols if r in rep_cols], batch_size):
        exprs = []
        for r in batch:
            cols = [f"{r}_same_t_prev{k}" for k in range(1, ndays + 1)]
            vals = pl.concat_list([pl.col(c) for c in cols]).list.drop_nulls()
            m = vals.list.mean()
            s = vals.list.std(ddof=1)   # 和全局统计一致
            if cast_f32:
                m = m.cast(pl.Float32); s = s.cast(pl.Float32)
            exprs += [
                m.alias(f"{r}_same_t_last{ndays}_mean"),
                s.alias(f"{r}_same_t_last{ndays}_std"),
            ]
        lf_cur = lf_cur.with_columns(exprs)

    # 3) slope：时间方向设为“最近为正、久远为负”（正=近期上升）
    x = np.arange(ndays, 0, -1, dtype=np.float64)
    x = (x - x.mean()) / (x.std() + 1e-9)
    x_lits = [pl.lit(float(v)) for v in x]

    for batch in _chunks([r for r in stats_rep_cols if r in rep_cols], batch_size):
        exprs = []
        for r in batch:
            cols = [f"{r}_same_t_prev{k}" for k in range(1, ndays + 1)]
            mean_ref = pl.col(f"{r}_same_t_last{ndays}_mean")
            std_ref  = pl.col(f"{r}_same_t_last{ndays}_std")
            terms = [((pl.col(c) - mean_ref) / (std_ref + 1e-9)) * x_lits[i]
                    for i, c in enumerate(cols)]
            # ——更稳：对 null 显式置 0，避免某些版本 sum_horizontal 因 null 变 null
            terms = [pl.when(pl.col(c).is_not_null() & mean_ref.is_not_null() & std_ref.is_not_null())
                    .then(t).otherwise(pl.lit(0.0)) for t, c in zip(terms, cols)]

            n_eff = pl.sum_horizontal([pl.col(c).is_not_null().cast(pl.Int32) for c in cols]).cast(pl.Float32)
            den   = pl.when(n_eff > 0).then(n_eff).otherwise(pl.lit(1.0))
            slope = pl.sum_horizontal(terms) / den
            if cast_f32:
                slope = slope.cast(pl.Float32)
            exprs.append(slope.alias(f"{r}_same_t_last{ndays}_slope"))
        lf_cur = lf_cur.with_columns(exprs)

    # 4) 跨 responder 的 prev1 行内统计（可选）
    if add_prev1_multirep and len(rep_cols) > 0:
        n_rep = len(rep_cols)  
        prev1_cols = [f"{r}_same_t_prev1" for r in rep_cols]
        prev1_list = pl.concat_list([pl.col(c) for c in prev1_cols]).list.drop_nulls()
        m1 = prev1_list.list.mean()
        s1 = prev1_list.list.std(ddof=1)
        if cast_f32:
            m1 = m1.cast(pl.Float32); s1 = s1.cast(pl.Float32)
        lf_cur = lf_cur.with_columns([
            m1.alias(f"prev1_same_t_mean_{n_rep}rep"),
            s1.alias(f"prev1_same_t_std_{n_rep}rep"),
        ])

    # 出口保持有序，便于后续 C 阶段 shift/rolling
    lf_cur = lf_cur.sort([g_symbol, g_date, g_time])
    return lf_cur




# C 系列：

def fe_feat_history(
    *,
    lf: pl.LazyFrame,
    keys: Tuple[str,str,str] = ("symbol_id","date_id","time_id"),
    feature_cols: Sequence[str],
    is_sorted: bool = False,
    prev_soft_days: Optional[int] = None,
    cast_f32: bool = True,
    batch_size: int = 10,
    lags: Iterable[int] = (1, 3),
    ret_periods: Iterable[int] = (1,),
    diff_periods: Iterable[int] = (1,),
    rz_windows: Iterable[int] = (5,),
    ewm_spans: Iterable[int] = (10,),
    keep_rmean_rstd: bool = True,
    cs_cols: Optional[Sequence[str]] = None,
) -> pl.LazyFrame:
    
    g_sym, g_date, g_time = keys
    
    by_grp = [g_sym]
    by_cs  = [g_date, g_time]

    need_cols = [*keys, *feature_cols]
    schema = lf.collect_schema().names()
    miss = [c for c in need_cols if c not in schema]
    if miss:
        raise KeyError(f"Columns not found: {miss}")

    lf_out = lf.select(need_cols)
    if not is_sorted:
        lf_out = lf_out.sort(list(keys))

    def _chunks(lst, k):
        for i in range(0, len(lst), k):
            yield lst[i:i+k]

    # ---- 规范化参数：None/[] -> 空元组；并去重/转 int/保正数 ----
    def _clean_pos_sorted_unique(x):
        if not x:
            return tuple()
        return tuple(sorted({int(v) for v in x if int(v) >= 1}))

    LAGS   = _clean_pos_sorted_unique(lags)
    K_RET  = _clean_pos_sorted_unique(ret_periods)
    K_DIFF = _clean_pos_sorted_unique(diff_periods)
    RZW    = _clean_pos_sorted_unique(rz_windows)
    SPANS  = _clean_pos_sorted_unique(ewm_spans)

    # C1 lags
    if LAGS:
        for batch in _chunks(feature_cols, batch_size):
            exprs = []
            for L in LAGS:
                last_date_L = pl.col(g_date).shift(L).over(by_grp)
                gap_L = (pl.col(g_date) - last_date_L).cast(pl.Int32)
                if prev_soft_days is not None:
                    keep_L = gap_L.is_not_null() & (gap_L > 0) & (gap_L <= pl.lit(int(prev_soft_days)))
                for c in batch:
                    e = pl.col(c).shift(L).over(by_grp)
                    if prev_soft_days is not None:
                        e = pl.when(keep_L).then(e).otherwise(None)
                    if cast_f32:
                        e = e.cast(pl.Float32)
                    exprs.append(e.alias(f"{c}__lag{L}"))
            lf_out = lf_out.with_columns(exprs)

    # C2 returns（可选）
    if K_RET:
        for batch in _chunks(feature_cols, batch_size):
            exprs = []
            for c in batch:
                cur = pl.col(c)
                for k in K_RET:
                    if k in LAGS:
                        prev = pl.col(f"{c}__lag{k}")  # 已含 TTL
                    else:
                        prev = pl.col(c).shift(k).over(by_grp)
                        if prev_soft_days is not None:
                            last_date_k = pl.col(g_date).shift(k).over(by_grp)
                            gap_k = (pl.col(g_date) - last_date_k).cast(pl.Int32)
                            keep_k = gap_k.is_not_null() & (gap_k > 0) & (gap_k <= pl.lit(int(prev_soft_days)))
                            prev = pl.when(keep_k).then(prev).otherwise(None)
                    ret = pl.when(prev.is_not_null() & (prev.abs() > 1e-12)).then(cur / prev - 1.0).otherwise(None)
                    if cast_f32:
                        ret = ret.cast(pl.Float32)
                    exprs.append(ret.alias(f"{c}__ret{k}"))
            lf_out = lf_out.with_columns(exprs)


    # C3 diffs（可选）
    if K_DIFF:
        for batch in _chunks(feature_cols, batch_size):
            exprs = []
            for c in batch:
                cur = pl.col(c)
                for k in K_DIFF:
                    if k in LAGS:
                        prevk = pl.col(f"{c}__lag{k}")  # 已含 TTL
                    else:
                        prevk = pl.col(c).shift(k).over(by_grp)
                        if prev_soft_days is not None:
                            last_date_k = pl.col(g_date).shift(k).over(by_grp)
                            gap_k = (pl.col(g_date) - last_date_k).cast(pl.Int32)
                            keep_k = gap_k.is_not_null() & (gap_k > 0) & (gap_k <= pl.lit(int(prev_soft_days)))
                            prevk = pl.when(keep_k).then(prevk).otherwise(None)
                    d = pl.when(prevk.is_not_null()).then(cur - prevk).otherwise(None)
                    if cast_f32:
                        d = d.cast(pl.Float32)
                    exprs.append(d.alias(f"{c}__diff{k}"))
            lf_out = lf_out.with_columns(exprs)



    # C4 rolling r-z
    if RZW:
        for batch in _chunks(feature_cols, batch_size):
            exprs_base = []
            # 统一构造 t-1 的基准值（含 TTL 掩码）
            if prev_soft_days is not None:
                last_date_1 = pl.col(g_date).shift(1).over(by_grp)
                gap_1 = (pl.col(g_date) - last_date_1).cast(pl.Int32)
                keep_1 = gap_1.is_not_null() & (gap_1 > 0) & (gap_1 <= pl.lit(int(prev_soft_days)))

            for c in batch:
                # 若之前已在 C1 产出 __lag1，可直接用： base = pl.col(f"{c}__lag1")
                base = pl.col(c).shift(1).over(by_grp)
                if prev_soft_days is not None:
                    base = pl.when(keep_1).then(base).otherwise(None)
                exprs_base.append(base.alias(f"{c}__tminus1_base"))
            lf_out = lf_out.with_columns(exprs_base)

            # 真正的 rolling r-z
            roll_exprs = []
            for c in batch:
                base = pl.col(f"{c}__tminus1_base")
                for w in RZW:
                    m  = base.rolling_mean(window_size=w, min_samples=1).over(by_grp)
                    s  = base.rolling_std(window_size=w, ddof=1, min_samples=2).over(by_grp)  # 统一 ddof=1
                    den = (s.fill_null(0.0) + 1e-9)
                    rz = (base - m) / den
                    if cast_f32:
                        m = m.cast(pl.Float32); s = s.cast(pl.Float32); rz = rz.cast(pl.Float32)
                    if keep_rmean_rstd:
                        roll_exprs += [
                            m.alias(f"{c}__rmean{w}"),
                            s.alias(f"{c}__rstd{w}"),
                            rz.alias(f"{c}__rz{w}"),
                        ]
                    else:
                        roll_exprs.append(rz.alias(f"{c}__rz{w}"))
            lf_out = lf_out.with_columns(roll_exprs)
            lf_out = lf_out.drop([f"{c}__tminus1_base" for c in batch])


    # C5 EWM（可选）
    if SPANS:
        for batch in _chunks(feature_cols, batch_size):
            exprs_base = []

            # TTL 掩码（t-1）
            if prev_soft_days is not None:
                last_date_1 = pl.col(g_date).shift(1).over(by_grp)
                gap_1 = (pl.col(g_date) - last_date_1).cast(pl.Int32)
                keep_1 = gap_1.is_not_null() & (gap_1 > 0) & (gap_1 <= pl.lit(int(prev_soft_days)))

            # 构造 t-1 基准（若你已在 C1 产出 __lag1，可以直接用它替代下面两行）
            for c in batch:
                base = pl.col(c).shift(1).over(by_grp)
                if prev_soft_days is not None:
                    base = pl.when(keep_1).then(base).otherwise(None)
                exprs_base.append(base.alias(f"{c}__tminus1_base"))
            lf_out = lf_out.with_columns(exprs_base)

            # 计算 EWM
            ewm_exprs = []
            for c in batch:
                base = pl.col(f"{c}__tminus1_base")
                for s in SPANS:
                    ema = base.ewm_mean(span=int(s), adjust=False, ignore_nulls=True).over(by_grp)
                    if cast_f32:
                        ema = ema.cast(pl.Float32)
                    ewm_exprs.append(ema.alias(f"{c}__ewm{s}"))
            lf_out = lf_out.with_columns(ewm_exprs)

            # 清理临时列
            lf_out = lf_out.drop([f"{c}__tminus1_base" for c in batch])


    # C6 cross-section rank（可选）
    if cs_cols:
        cs_cols = [c for c in cs_cols if c in feature_cols]
        if cs_cols:

            # TTL 掩码（t-1）
            if prev_soft_days is not None:
                last_date_1 = pl.col(g_date).shift(1).over(by_grp)
                gap_1 = (pl.col(g_date) - last_date_1).cast(pl.Int32)
                keep_1 = gap_1.is_not_null() & (gap_1 > 0) & (gap_1 <= pl.lit(int(prev_soft_days)))

            # 先构造每列的 t-1 基准（含 TTL）
            exprs_base = []
            for c in cs_cols:
                base = pl.col(c).shift(1).over(by_grp)
                if prev_soft_days is not None:
                    base = pl.when(keep_1).then(base).otherwise(None)
                exprs_base.append(base.alias(f"{c}__tminus1_base"))
            lf_out = lf_out.with_columns(exprs_base)

            # 基于 t-1：截面 z 与 rank(0..1)
            cs_exprs = []
            for c in cs_cols:
                base = pl.col(f"{c}__tminus1_base")

                # 截面统计（只用该列的 t-1）
                n_valid = base.is_not_null().cast(pl.Int32).sum().over(by_cs)
                mu      = base.mean().over(by_cs)
                sig     = base.std(ddof=1).over(by_cs)

                # z-score（数值更稳：sig.fill_null(0)+eps）
                z = ((base - mu) / (sig.fill_null(0.0) + 1e-9)) \
                        .cast(pl.Float32 if cast_f32 else pl.Float64)

                # 百分位排名：空→None；n=1→0.5
                rank_raw = base.rank(method="average").over(by_cs)
                csrank = pl.when(base.is_null()).then(None).otherwise(
                    pl.when(n_valid > 1)
                    .then((rank_raw - 0.5) / n_valid.cast(pl.Float32))
                    .otherwise(pl.lit(0.5))
                ).cast(pl.Float32 if cast_f32 else pl.Float64)

                cs_exprs += [z.alias(f"{c}__cs_z"), csrank.alias(f"{c}__csrank")]

            lf_out = lf_out.with_columns(cs_exprs)

            # 清理临时列
            lf_out = lf_out.drop([f"{c}__tminus1_base" for c in cs_cols])
    return lf_out

   
@dataclass
class StageA:
    tail_lags: Sequence[int]
    tail_diffs: Sequence[int]
    rolling_windows: Optional[Sequence[int]]
    prev_soft_days: Optional[int] = None
    is_sorted: bool = False
    cast_f32: bool = True

@dataclass
class StageB:
    ndays: int
    stats_rep_cols: Optional[Sequence[str]] = None
    add_prev1_multirep: bool = True
    batch_size: int = 5
    prev_soft_days: Optional[int] = None
    is_sorted: bool = False
    cast_f32: bool = True

# C 的每个操作可选；None / [] 表示跳过该操作
@dataclass
class StageC:
    lags: Optional[Iterable[int]] = None
    ret_periods: Optional[Iterable[int]] = None
    diff_periods: Optional[Iterable[int]] = None
    rz_windows: Optional[Iterable[int]] = None
    ewm_spans: Optional[Iterable[int]] = None
    cs_cols: Optional[Sequence[str]] = None
    keep_rmean_rstd: bool = True
    prev_soft_days: Optional[int] = None
    batch_size: Optional[int] = 10
    is_sorted: bool = False
    cast_f32: bool = True
    

def assert_time_monotone(path, *, date_col="date_id", time_col="time_id"):
    s = (pl.scan_parquet(path, storage_options=storage_options)
           .select([
               (pl.col(date_col).diff().fill_null(0) < 0).any().alias('date_drop'),
               ((pl.col(date_col).diff().fill_null(0) == 0) &
                (pl.col(time_col).diff().fill_null(0) < 0)).any().alias('time_drop')
           ])
           .collect(streaming=True))
    assert not s['date_drop'][0] and not s['time_drop'][0]


def run_staged_engineering(
    lf_base: pl.LazyFrame,
    *,
    keys: Sequence[str],
    rep_cols: Sequence[str],
    feature_cols: Sequence[str],
    out_dir: str,
    A: StageA | None = None,
    B: StageB | None = None,
    C: StageC | None = None,
    write_date_between: tuple[int, int] | None = None,   # 新增：只写核心区间
):
    g_symbol, g_date, g_time = keys

    def _save(lf_out: pl.LazyFrame, path: str):
        if write_date_between is None:
            raise ValueError("write_date_between must be specified to avoid date overlap")
        lo, hi = write_date_between
        
        sk = [g_date, g_time, g_symbol]
        
        df = lf_out.filter(pl.col(g_date).is_between(lo, hi)).sort(sk).collect()
        with fs.open(path, "wb") as f:   # 复用你上面构好的 fs (fsspec)
            df.write_parquet(f, compression="zstd")
        if cfg.get("debug", {}).get("check_time_monotone", True):
            assert_time_monotone(path, date_col=g_date, time_col=g_time)


        
    # ---------- A ----------
    if A is not None:
        lf_resp = lf_base.select([*keys, *rep_cols])
        lf_a_full = fe_resp_daily(
            lf_resp,
            keys=tuple(keys),
            rep_cols=rep_cols,
            is_sorted=A.is_sorted,
            prev_soft_days=A.prev_soft_days,
            cast_f32=A.cast_f32,
            tail_lags=A.tail_lags,
            tail_diffs=A.tail_diffs,
            rolling_windows=A.rolling_windows,
        )
        drop = set(keys) | set(rep_cols)
        a_cols = [c for c in lf_a_full.collect_schema().names() if c not in drop]
        _save(lf_a_full.select([*keys, *a_cols]), f"{out_dir}/stage_a.parquet")
        

    # ---------- B ----------
    if B is not None:
        lf_resp = lf_base.select([*keys, *rep_cols])
        lf_b_full = fe_resp_same_tick_xday(
            lf_resp,
            keys=tuple(keys),
            rep_cols=rep_cols,
            is_sorted=B.is_sorted,
            prev_soft_days=B.prev_soft_days,
            cast_f32=B.cast_f32,
            ndays=B.ndays,
            stats_rep_cols=B.stats_rep_cols,
            add_prev1_multirep=B.add_prev1_multirep,
            batch_size=B.batch_size,
        )
        drop = set(keys) | set(rep_cols)
        b_cols = [c for c in lf_b_full.collect_schema().names() if c not in drop]
        _save(lf_b_full.select([*keys, *b_cols]), f"{out_dir}/stage_b.parquet")

    # ---------- C（按操作分别输出） ----------
    if C is not None:
        def _do_op(op_name: str, **op_flags):
            lf_src = lf_base.select([*keys, *feature_cols])
            lf_c = fe_feat_history(
                lf=lf_src,
                keys=tuple(keys),
                feature_cols=feature_cols,
                is_sorted=C.is_sorted,
                prev_soft_days=C.prev_soft_days,
                cast_f32=C.cast_f32,
                batch_size=C.batch_size,
                lags=op_flags.get("lags"),
                ret_periods=op_flags.get("ret_periods"),
                diff_periods=op_flags.get("diff_periods"),
                rz_windows=op_flags.get("rz_windows"),
                ewm_spans=op_flags.get("ewm_spans"),
                keep_rmean_rstd=C.keep_rmean_rstd,
                cs_cols=op_flags.get("cs_cols"),
            ).drop(feature_cols)
            cols = [c for c in lf_c.collect_schema().names() if c not in keys]
            _save(lf_c.select([*keys, *cols]), f"{out_dir}/stage_c_{op_name}.parquet")

        if C.lags:         _do_op("lags",   lags=C.lags)
        if C.ret_periods:  _do_op("ret",    ret_periods=C.ret_periods)
        if C.diff_periods: _do_op("diff",   diff_periods=C.diff_periods)
        if C.rz_windows:   _do_op("rz",     rz_windows=C.rz_windows)
        if C.ewm_spans:    _do_op("ewm",    ewm_spans=C.ewm_spans)
        if C.cs_cols:      _do_op("csrank", cs_cols=C.cs_cols)
        

def weighted_r2_zero_mean(y_true, y_pred, weight) -> float:
    """
    Sample-weighted zero-mean R^2 used in Jane Street:
        R^2 = 1 - sum_i w_i (y_i - yhat_i)^2 / sum_i w_i y_i^2
    """
    y_true = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.float64).ravel()
    weight = np.asarray(weight, dtype=np.float64).ravel()
    assert y_true.shape == y_pred.shape == weight.shape

    num = np.sum(weight * (y_true - y_pred) ** 2)
    den = np.sum(weight * (y_true ** 2))
    if den <= 0:
        return 0.0  # safe fallback (shouldn't happen on the full JS eval)
    return 1.0 - (num / den)

def lgb_wr2_eval(preds, train_data):
    y = train_data.get_label()
    w = train_data.get_weight()
    if w is None:
        w = np.ones_like(y)
    score = weighted_r2_zero_mean(y, preds, w)
    return ('wr2', score, True)  # higher is better

# 特征选择- 初选 (选特征，省略)

In [None]:
import os, gc, glob
import polars as pl
import numpy as np
import lightgbm as lgb
import pandas as pd
from pathlib import Path

BASE_PATH = ["/mnt/data/js/clean/final_clean.parquet"]
KEYS = ["symbol_id","date_id","time_id"]
TARGET = "responder_6"
WEIGHT = 'weight'
FEATURE_COLS = [f"feature_{i:02d}" for i in range(79)]
REP_COLS = [f"responder_{i}" for i in range(9)]

OUT_DIR = "/mnt/data/js/cache/first_selection"
os.makedirs(OUT_DIR, exist_ok=True)


In [None]:

# --- A: prev-day tails + daily summaries ---
A = StageA(
    tail_lags=(1,),
    tail_diffs=(1,),
    rolling_windows=(5,),
    prev_soft_days=3,          # allow fallback up to 3 calendar days
)

# --- B: same time_id cross-day ---
B = StageB(
    ndays=3,                   # prev{1..3} at same time_id
    stats_rep_cols=None,       # default: use rep_cols
    add_prev1_multirep=True,
    batch_size=5,
    prev_soft_days=3,          # TTL for gaps
    strict_k=False,            # allow ≤K-day gaps instead of strict d-k
)

# --- C: history features (keep it tiny) ---
C = StageC(
    lags=(1, ),
    ret_periods=(1,),
    diff_periods=(1,),
    rz_windows=(5,),
    ewm_spans=(10,),
    keep_rmean_rstd=True,
    cs_cols=None,        # must be subset of feature_cols
    cs_by=("date_id","time_id"),
    prev_soft_days=3,
)

# example call
paths = run_staged_engineering(
    lf_base=lf_base,                # your base LazyFrame
    keys=KEYS,
    rep_cols=REP_COLS,         # updated to use REP_COLS
    feature_cols=FEATURE_COLS, # updated to use FEATURE_COLS
    out_dir=OUT_DIR,
    A=A, B=B, C=C,
)


0. 准备与切分天数

In [None]:

STAGE_PATHS = [
    "/mnt/data/js/cache/first_selection/stage_a.parquet",
    "/mnt/data/js/cache/first_selection/stage_b.parquet",
    "/mnt/data/js/cache/first_selection/stage_c_lags.parquet",
    "/mnt/data/js/cache/first_selection/stage_c_ret.parquet",
    "/mnt/data/js/cache/first_selection/stage_c_diff.parquet",
    "/mnt/data/js/cache/first_selection/stage_c_rz.parquet",
    "/mnt/data/js/cache/first_selection/stage_c_ewm.parquet",
]

DATE_LO, DATE_HI = 1200, 1400
OUT_DIR = "/mnt/data/js/cache/first_selection/run_full"
SHARD_DIR = f"{OUT_DIR}/shards_all"
Path(SHARD_DIR).mkdir(parents=True, exist_ok=True)

lf_base = pl.scan_parquet(BASE_PATH)
# 仅拿目标区间的base
lf_range = lf_base.filter(pl.col("date_id").is_between(DATE_LO, DATE_HI))

# days & split
days = (lf_range.select(pl.col("date_id").unique().sort())
                .collect(streaming=True)["date_id"].to_list())
cut = int(len(days) * 0.8)
train_days, val_days = days[:cut], days[cut:]
print(f"[split] train {len(train_days)} days, val {len(val_days)} days, range={days[0]}..{days[-1]}")


1.收集全量特征列名（并集）

In [None]:
# 来自 base 的特征列
feat_cols = set(FEATURE_COLS)

# 各 stage 全部列（除 KEYS/TARGET/WEIGHT）
for p in STAGE_PATHS:
    if not os.path.exists(p):
        print(f"[skip] missing: {p}")
        continue
    names = pl.scan_parquet(p).collect_schema().names()
    for c in names:
        if c not in KEYS and c not in (TARGET, WEIGHT):
            feat_cols.add(c)

feat_cols = sorted(feat_cols)
print(f"[cols] total feature columns = {len(feat_cols)}")


2. 写“天片”—把所有列拼上并立刻落盘

In [None]:
DAYS_PER_SHARD = 16

# 左表（含 base 的 FEATURE_COLS）
lf_left_base = (
    lf_range
    .select([*KEYS, TARGET, WEIGHT, *[pl.col(c) for c in FEATURE_COLS]])
)

# 为每个 stage 准备元信息（列名 + 文件大小，先拼小文件更省内存）
stage_meta = []
for p in STAGE_PATHS:
    if not os.path.exists(p): 
        continue
    scan = pl.scan_parquet(p).filter(pl.col("date_id").is_between(DATE_LO, DATE_HI))
    cols = [c for c in scan.collect_schema().names() if c not in KEYS]
    if cols:
        stage_meta.append({"path": p, "cols": cols, "size": os.path.getsize(p)})
stage_meta.sort(key=lambda d: d["size"])


In [None]:
def write_shards(tag, days_list):
    ds = sorted(days_list)
    for i in range(0, len(ds), DAYS_PER_SHARD):
        batch = set(ds[i:i+DAYS_PER_SHARD])

        # 当前片的左表
        lf_chunk = lf_left_base.filter(pl.col("date_id").is_in(batch))
        already = set(lf_chunk.collect_schema().names())

        # 逐 stage 拼接（右表只取该片天数 + 只取未存在列）
        for m in stage_meta:
            need = [c for c in m["cols"] if c not in already]
            if not need:
                continue
            lf_add = (pl.scan_parquet(m["path"])
                        .filter(pl.col("date_id").is_in(batch))
                        .select([*KEYS, *need]))
            lf_chunk = lf_chunk.join(lf_add, on=KEYS, how="left")
            already.update(need)

        # 统一 float32 并落盘（列按 feat_cols 顺序对齐；片内缺失的列自然是 null）
        present = [c for c in feat_cols if c in already]
        cast_feats = [pl.col(c).cast(pl.Float32).alias(c) for c in present]
        lf_out = lf_chunk.select([
            *KEYS,
            pl.col(WEIGHT).cast(pl.Float32).alias(WEIGHT),
            pl.col(TARGET).cast(pl.Float32).alias(TARGET),
            *cast_feats,
        ])
        out_path = f"{SHARD_DIR}/{tag}_shard_{i//DAYS_PER_SHARD:04d}.parquet"
        lf_out.sink_parquet(out_path, compression="zstd")
        print(f"[{tag}] wrote {out_path}")
        gc.collect()

write_shards("train", train_days)
write_shards("val",   val_days)

3. 从 shards 构建 memmap 数组 （恒定内存）

In [None]:
def memmap_from_shards(glob_pat, feat_cols, prefix):
    paths = sorted(glob.glob(glob_pat))
    counts = [pl.scan_parquet(p).select(pl.len()).collect(streaming=True).item() for p in paths]
    n_rows, n_feat = int(sum(counts)), len(feat_cols)
    print(f"[memmap] {glob_pat}: {len(paths)} files, {n_rows} rows, {n_feat} features")

    X = np.memmap(f"{prefix}_X.float32.mmap", dtype="float32", mode="w+", shape=(n_rows, n_feat))
    y = np.memmap(f"{prefix}_y.float32.mmap", dtype="float32", mode="w+", shape=(n_rows,))
    w = np.memmap(f"{prefix}_w.float32.mmap", dtype="float32", mode="w+", shape=(n_rows,))

    i = 0
    for p, k in zip(paths, counts):
        lf = pl.scan_parquet(p)
        names = set(lf.collect_schema().names())
        exprs = [
            (pl.col(c).cast(pl.Float32).alias(c) if c in names
             else pl.lit(None, dtype=pl.Float32).alias(c))
            for c in feat_cols
        ]
        df = lf.select([
            pl.col(TARGET).cast(pl.Float32).alias(TARGET),
            pl.col(WEIGHT).cast(pl.Float32).alias(WEIGHT),
            *exprs
        ]).collect(streaming=True)

        X[i:i+k, :] = df.select(feat_cols).to_numpy()
        y[i:i+k]    = df.select(pl.col(TARGET)).to_numpy().ravel()
        w[i:i+k]    = df.select(pl.col(WEIGHT)).to_numpy().ravel()
        i += k
        del df; gc.collect()

    X.flush(); y.flush(); w.flush()
    return X, y, w

train_X, train_y, train_w = memmap_from_shards(f"{SHARD_DIR}/train_shard_*.parquet", feat_cols, f"{OUT_DIR}/train")
val_X,   val_y,   val_w   = memmap_from_shards(f"{SHARD_DIR}/val_shard_*.parquet",   feat_cols, f"{OUT_DIR}/val")

print("train shapes:", train_X.shape, train_y.shape, train_w.shape)
print("val   shapes:", val_X.shape,   val_y.shape,   val_w.shape)


4. LightGBM 训练 + 重要性 （一次性全列）

In [None]:
dtrain = lgb.Dataset(train_X, label=train_y, weight=train_w,
                     feature_name=feat_cols, free_raw_data=True)
dval   = lgb.Dataset(val_X,   label=val_y,   weight=val_w,
                     feature_name=feat_cols, reference=dtrain, free_raw_data=True)

params = dict(
    objective="regression", metric="None",
    num_threads=16, seed=42, deterministic=True, first_metric_only=True,
    learning_rate=0.05, num_leaves=31, max_depth=-1, min_data_in_leaf=20,
    # 内存友好
    max_bin=63, bin_construct_sample_cnt=100_000, min_data_in_bin=3,
)

model = lgb.train(
    params, dtrain,
    valid_sets=[dval, dtrain], valid_names=["val","train"],
    num_boost_round=1000, callbacks=[lgb.early_stopping(50)],
    feval=lgb_wr2_eval,   # 你的评估函数
)

imp = pd.DataFrame({
    "feature": model.feature_name(),
    "gain": model.feature_importance("gain"),
    "split": model.feature_importance("split"),
}).sort_values("gain", ascending=False).reset_index(drop=True)

print(imp.head(30))



In [None]:
imp.to_csv(f"{OUT_DIR}/imp_1r.csv", index=False)

In [None]:
imp = pd.read_csv(f"{OUT_DIR}/imp_1r.csv")
top_feats = imp.loc[imp.gain > 0]

In [None]:
fam = top_feats['feature'].str.extract(r'^(feature_\d{2}|responder_\d)', expand=False)
top_feats['family'] = fam

In [None]:
top_feats

In [None]:
fam_feats = top_feats.groupby('family').agg(
    n = ('feature', 'count'),
    gain = ('gain', 'sum'),
    split = ('split', 'sum'),
).reset_index().sort_values('gain', ascending=False)

In [None]:
print(fam_feats.shape)

In [None]:
mask_feat = fam_feats['family'].str.startswith('feature_', na=False)
mask_resp = fam_feats["family"].str.startswith("responder_", na=False)
features_only   = fam_feats[mask_feat].sort_values("gain", ascending=False)
responders_only = fam_feats[mask_resp].sort_values("gain", ascending=False)

In [None]:
selected_features = features_only['family'][:79] # select all
selected_resps = responders_only['family'][:9] # select all

# save the Series (no index)
selected_features.to_csv(f"{OUT_DIR}/selected_features_1r.csv", index=False, header=False)
selected_resps.to_csv(f"{OUT_DIR}/selected_responders_1r.csv", index=False, header=False)

# 特征工程

In [None]:
# 训练集参数
paths = fs.glob(f"{P('az', cfg['paths']['clean_shards'])}/*.parquet")
az_paths = []
for p in paths:
    az_paths.append(f"az://{p}")
az_paths = sorted(az_paths)  # Use sorted() instead of sort() to create a new sorted list   
lc = pl.scan_parquet(az_paths, storage_options=storage_options)

days = lc.select(pl.col("date_id").unique().sort()).collect(streaming=True)["date_id"].to_list()

In [None]:
# ------- step 2: FE per clean shard (A+B once, C batched internally via C.batch_size) -------
fea = cfg.get("feature_eng", {})
A_cfg = fea.get("A", {})
B_cfg = fea.get("B", {})
C_cfg = fea.get("C", {})
A_enabled = A_cfg.get("enabled", True)
B_enabled = B_cfg.get("enabled", True)
C_enabled = C_cfg.get("enabled", True)

A = (StageA(
        tail_lags=A_cfg.get("tail_lags", [1]),
        tail_diffs=A_cfg.get("tail_diffs", [1]),
        rolling_windows=A_cfg.get("rolling_windows", [3]),
        prev_soft_days=A_cfg.get("prev_soft_days", 7),
        is_sorted=A_cfg.get("is_sorted", False),
        cast_f32=A_cfg.get("cast_f32", True),
    ) if A_enabled else None)

B = (StageB(
        ndays=B_cfg.get("ndays", 5),
        stats_rep_cols=B_cfg.get("stats_rep_cols", None),
        add_prev1_multirep=B_cfg.get("add_prev1_multirep", True),
        batch_size=B_cfg.get("batch_size", 5),
        prev_soft_days=B_cfg.get("prev_soft_days", 7),
        is_sorted=B_cfg.get("is_sorted", False),
        cast_f32=B_cfg.get("cast_f32", True),
    ) if B_enabled else None)

C = (StageC(
        lags=C_cfg.get("lags", [1,3]),
        ret_periods=C_cfg.get("ret_periods", [1]),
        diff_periods=C_cfg.get("diff_periods", [1]),
        rz_windows=C_cfg.get("rz_windows", [5]),
        ewm_spans=C_cfg.get("ewm_spans", [10]),
        keep_rmean_rstd=C_cfg.get("keep_rmean_rstd", True),
        cs_cols=C_cfg.get("cs_cols", None),
        prev_soft_days=C_cfg.get("prev_soft_days", 7),
        batch_size=C_cfg.get("batch_size", 10),
        is_sorted=C_cfg.get("is_sorted", False),
        cast_f32=C_cfg.get("cast_f32", True),
    ) if C_enabled else None)



# 创建 FE 输出目录
fe_dir = P("az", cfg["paths"]["fe_shards"])
fs.mkdir(fe_dir, exist_ok=True)

In [None]:
# -------- 分片循环：每片包含 [pad_lo .. core_hi] 的输入，但只写 [core_lo .. core_hi] --------
PAD_DAYS = 30 # 后期可定义函数取最小有效日期
CORE_DAYS = 30


g_date= cfg['keys'][1]
for start in range(PAD_DAYS, len(days), CORE_DAYS):
    core_lo_idx = start
    core_hi_idx = min(start + CORE_DAYS - 1, len(days) - 1) # 闭区间
    pad_lo_idx = core_lo_idx - PAD_DAYS
    
    core_lo, core_hi = days[core_lo_idx], days[core_hi_idx]
    pad_lo = days[pad_lo_idx]
    
    # 仅读本片+pad的输入 （懒加载 + 谓词下推）
    lf_shard = (lc.filter(pl.col(g_date).is_between(pad_lo, core_hi))
                .select([*cfg['keys'], cfg['weight'], 'time_bucket', *RESP_COLS, *FEATURE_ALL]))
    out_path = f"{fe_dir}/fe_{core_lo:04d}_{core_hi:04d}"
    fs.mkdir(out_path, exist_ok=True)
    print(f"[FE] days {core_lo}..{core_hi} (pad from {pad_lo}) -> {out_path}")
    run_staged_engineering(
        lf_base = lf_shard,
        keys = cfg['keys'],
        rep_cols = RESP_COLS,
        feature_cols = FEATURE_ALL,
        out_dir = out_path,
        A = A,
        B = B,
        C = C,
        write_date_between=(core_lo, core_hi)
        )

## 把同一分片内 A/B/C 拼成 Panel 分片

1. 基本配置

In [None]:
KEYS = tuple(cfg["keys"])
g_sym, g_date, g_time = KEYS
TARGET, WEIGHT = cfg["target"], cfg["weight"]
TIME_SORT = cfg['sorts']['time_major']

clean_root = P("az", cfg["paths"]["clean_shards"])
fe_root    = P("az", cfg["paths"]["fe_shards"])
panel_root = P("az", cfg["paths"]["panel_shards"])
fs.mkdir(panel_root, exist_ok=True)

DATE_LO, DATE_HI = 900, 1000

paths = fs.glob(f"{fe_root }/*")
sorted_paths = [f"az://{p}" for p in sorted(paths)]

print("ready")

2.枚举窗口

In [None]:
wins = set()
for p in sorted_paths:
    base = p.split("/")[-1]  # e.g. C_lags_1220_1249.parquet
    lo = int(base.split("_")[-2]); hi = int(base.split("_")[-1])
    if hi >= DATE_LO and lo <= DATE_HI:
        wins.add((lo, hi))
wins = sorted(wins)
print(f"windows in range: {wins[:5]} ... (total {len(wins)})")




3. 按窗口拼接 (A + B + 所有 C_*) → 直接写数据分片（无大表）

In [None]:
import numpy as np
import polars as pl
import fsspec, gc

T = np.float32(cfg["trading"]["ticks"])
TWOPI_over_T = np.float32(2.0*np.pi) / T     # 全是 float32
twopi_over_T_lit = pl.lit(TWOPI_over_T, dtype=pl.Float32)
cast_keys = [pl.col(k).cast(pl.Int32).alias(k) for k in KEYS]


lc = pl.scan_parquet(f"{clean_root}/*.parquet", storage_options=storage_options).with_columns(cast_keys)


def assert_panel_shard(path, lo, hi, *, date_col="date_id", time_col="time_id"):
    s = (pl.scan_parquet(path, storage_options=storage_options)
           .select([
               pl.col(date_col).min().alias("dmin"),
               pl.col(date_col).max().alias("dmax"),
               (pl.col(date_col).diff().fill_null(0) < 0).any().alias("date_drop"),
               ((pl.col(date_col).diff().fill_null(0) == 0) &
                (pl.col(time_col).diff().fill_null(0) < 0)).any().alias("time_drop"),
           ])
           .collect(streaming=True))
    dmin, dmax = int(s["dmin"][0]), int(s["dmax"][0])
    assert dmin == lo and dmax == hi, f"date range mismatch: got [{dmin},{dmax}] expect [{lo},{hi}]"
    assert not s["date_drop"][0] and not s["time_drop"][0], "time not monotone"



for (lo, hi) in wins:
    # 与全局区间取交集，防止边缘窗口越界
    w_lo, w_hi = max(lo, DATE_LO), min(hi, DATE_HI)
    
    shard_name = f"fe_{lo:04d}_{hi:04d}"
    
    ti_f = pl.col("time_id").cast(pl.Float32)
    # 基表（先筛行，再一次性加时间特征）
    lf = (
        lc.filter(pl.col("date_id").is_between(w_lo, w_hi))
          .select([*KEYS, "time_bucket", TARGET, WEIGHT, *FEATURE_ALL])
        .with_columns([
            ti_f.alias("time_pos"),
            (ti_f * twopi_over_T_lit).alias("_phase_").cast(pl.Float32),
        ])
        .with_columns([
            # 兼容旧版：对表达式调用 .sin() / .cos()
            pl.col("_phase_").sin().cast(pl.Float32).alias("time_sin"),
            pl.col("_phase_").cos().cast(pl.Float32).alias("time_cos"),
        ])
        .drop(["_phase_"])
    )
    
    fe_dir = f"{fe_root}/{shard_name}"
    # A/B
    A = pl.scan_parquet(f"{fe_dir}/stage_a.parquet", storage_options=storage_options).with_columns(cast_keys)
    B = pl.scan_parquet(f"{fe_dir}/stage_b.parquet", storage_options=storage_options).with_columns(cast_keys)
    
    # C
    C_paths = sorted(fs.glob(f"{fe_dir}/stage_c_*.parquet"))
    C_paths = [f"az://{p}" for p in C_paths]
    C_scans = [pl.scan_parquet(p, storage_options=storage_options).with_columns(cast_keys) for p in C_paths]
    
    # 逐个 join 
    panel = lf.join(A, on=list(KEYS), how="left", suffix="_A")
    panel = panel.join(B, on=list(KEYS), how="left", suffix="_B")
    for C in C_scans:
        panel = panel.join(C, on=list(KEYS), how="left", suffix="_C")
        
    panel = panel.sort(TIME_SORT)
    
    df_out = panel.collect(streaming=True)
    out_path = f"{panel_root}/panel_{w_lo:04d}_{w_hi:04d}.parquet"
    
    with fs.open(out_path, "wb") as f:
        df_out.write_parquet(f, compression="zstd")
    print(f"[panel] wrote {out_path} with {df_out.shape[0]} rows")
    
    if cfg.get("debug", {}).get("check_time_monotone", True):
        assert_panel_shard(out_path, w_lo, w_hi, date_col=g_date, time_col=g_time)
    del df_out

    gc.collect()

4-5.构建memmap

In [None]:
# 路径
mm_root = P("local",  cfg["paths"]["sample_mm"])
os.makedirs(mm_root, exist_ok=True)

# 选定区间的panel分片
panel_paths = sorted([
    f"az://{p}" for p in fs.glob(f"{panel_root}/panel_*.parquet")
    if int(p.split("_")[-2]) <= DATE_HI and int(p.split("_")[-1].split(".")[0]) >= DATE_LO
])

assert panel_paths, "no panel shards matched DATE_LO/DATE_HI"

# 任选一个训练分片当“列模板”
sample_path = panel_paths[0]
names = pl.scan_parquet(sample_path, storage_options=storage_options).collect_schema().names()

# 直接从这个分片拿特征列（已包含 base + engineered）

feat_cols = [c for c in names if c not in (*cfg['keys'], cfg['target'], cfg['weight'])]

In [None]:
import os, json, time, gc
import numpy as np
import polars as pl

def shard2memmap(sorted_paths: list[str], feat_cols: list[str], prefix: str):
    date_col   = cfg["keys"][1]
    target_col = cfg["target"]
    weight_col = cfg["weight"]

    # 统计每片行数
    counts = []
    for p in sorted_paths:
        k = (pl.scan_parquet(p, storage_options=storage_options)
            .select(pl.len())
            .collect(streaming=True)
            .item())
        counts.append(int(k))

    n_rows, n_feat = int(sum(counts)), len(feat_cols)
    os.makedirs(os.path.dirname(prefix), exist_ok=True)

    # 创建 memmap
    X = np.memmap(f"{prefix}_X.float32.mmap", dtype=np.float32, mode="w+", shape=(n_rows, n_feat))
    y = np.memmap(f"{prefix}_y.float32.mmap", dtype=np.float32, mode="w+", shape=(n_rows,))
    w = np.memmap(f"{prefix}_w.float32.mmap", dtype=np.float32, mode="w+", shape=(n_rows,))
    d = np.memmap(f"{prefix}_date.int32.mmap", dtype=np.int32,   mode="w+", shape=(n_rows,))

    need_cols = [date_col, target_col, weight_col, *feat_cols]
    ofs = 0
    for p, k in zip(sorted_paths, counts):
        df = (pl.scan_parquet(p, storage_options=storage_options)
                .select(need_cols)
                .collect(streaming=True))

        # 先 to_numpy，再 astype 保证为 float32（兼容老版本 Polars）
        X_block = df.select(feat_cols).to_numpy().astype(np.float32, copy=False)
        y_block = df.get_column(target_col).to_numpy().astype(np.float32, copy=False).ravel()
        w_block = df.get_column(weight_col).to_numpy().astype(np.float32, copy=False).ravel()
        d_block = df.get_column(date_col).to_numpy().astype(np.int32,   copy=False).ravel()

        X[ofs:ofs+k, :] = X_block
        y[ofs:ofs+k]    = y_block
        w[ofs:ofs+k]    = w_block
        d[ofs:ofs+k]    = d_block

        ofs += k
        del df, X_block, y_block, w_block, d_block
        gc.collect()

    X.flush(); y.flush(); w.flush(); d.flush()

    meta = {
        "n_rows": int(n_rows),
        "n_feat": int(n_feat),
        "dtype": {"X":"float32","y":"float32","w":"float32","date_id":"int32"},
        "features": list(feat_cols),
        "target": target_col,
        "weight": weight_col,
        "date_col": date_col,
        "files": sorted_paths,
        "ts": time.time(),
    }
    with open(f"{prefix}.meta.json", "w") as f:
        json.dump(meta, f, indent=2)

    return {
        "X": f"{prefix}_X.float32.mmap",
        "y": f"{prefix}_y.float32.mmap",
        "w": f"{prefix}_w.float32.mmap",
        "date": f"{prefix}_date.int32.mmap",
        "meta": f"{prefix}.meta.json",
    }

# —— 调用 —— #
prefix = os.path.join(mm_root, "full_sample_v1")
mm_paths = shard2memmap(sorted_paths=panel_paths, feat_cols=feat_cols, prefix=prefix)
print(mm_paths)

开始训练

In [None]:
# ---------- 加载 memmap ----------
with open(mm_paths["meta"]) as f:
    meta = json.load(f)
N, F = meta["n_rows"], meta["n_feat"]
X = np.memmap(mm_paths["X"], dtype=np.float32, mode="r", shape=(N, F))
y = np.memmap(mm_paths["y"], dtype=np.float32, mode="r", shape=(N,))
w = np.memmap(mm_paths["w"], dtype=np.float32, mode="r", shape=(N,))
d = np.memmap(mm_paths["date"], dtype=np.int32,   mode="r", shape=(N,))

In [None]:
# 构建并加载 d 后
assert np.all(np.diff(d) >= 0), "memmap d 不是非降序；请检查 panel 分片或排序"


In [None]:
def _day_ptrs_from_sorted_dates(d: np.ndarray):
    # 假设 d 非降序
    d = d.ravel()
    starts = np.r_[0, np.flatnonzero(d[1:] != d[:-1]) + 1]
    days   = d[starts]
    ends   = np.r_[starts[1:], d.size]     # 每个 day 的 [start,end)
    return days, starts, ends

def make_sliding_cv_fast(date_ids: np.ndarray, *, n_splits: int, gap_days: int = 5, train_to_val: int = 9):
    days, starts, ends = _day_ptrs_from_sorted_dates(date_ids)
    N, R, K, G = len(days), int(train_to_val), int(n_splits), int(gap_days)
    usable = N - G
    if usable <= 0 or K <= 0 or R <= 0:
        return []
    V_base, rem = divmod(usable, R + K)
    if V_base <= 0:
        return []
    T = R * V_base
    v_lens = [V_base + 1 if i < rem else V_base for i in range(K)]
    v_lo = T + G
    folds = []
    for V_i in v_lens:
        v_hi  = v_lo + V_i
        tr_hi = v_lo - G
        tr_lo = tr_hi - T
        if tr_lo < 0 or v_hi > N:
            break
        # 由于 d 全局有序，每个区间对应“连续行切片”
        tr_idx = np.arange(starts[tr_lo], ends[tr_hi-1])
        va_idx = np.arange(starts[v_lo],   ends[v_hi-1])
        folds.append((tr_idx, va_idx))
        v_lo = v_hi
    return folds

# 用 fast 版
folds = make_sliding_cv_fast(d, n_splits=2, gap_days=5, train_to_val=9)


In [None]:
# 1) 统计 N 天窗口的行数（按你真实筛选逻辑来）
n_rows = (
    lc.filter(pl.col("date_id").is_between(DATE_LO, DATE_HI)) 
      .select(pl.len())
      .collect()
      .item()  # -> int
)

# 2) 估算 GPU “transfer to GPU” 的大头（经验值）
n_feat = len(feat_cols)
dense_groups = int(n_feat)  # 按之前比例估
bytes_est = n_rows * 0.8* dense_groups         
gb_est = bytes_est / (1024**3)

print(f"rows≈{n_rows:,}, dense_groups≈{dense_groups}, est GPU load≈{gb_est:.2f} GiB")


In [None]:
ds_params = dict(
    max_bin=63,                    # 降低直方图桶数，省显存/内存
    bin_construct_sample_cnt=200000,# 构桶采样行数（默认是20万）
    min_data_in_bin=3,
    data_random_seed=42,
)

# 1) 全集 Dataset
d_all = lgb.Dataset(
    X, label=y, weight=w,
    feature_name=feat_cols,
    free_raw_data=True,
    params=ds_params,               # 让子集也继承这些设置
)

params = dict(
    objective="regression",
    metric="None",
    device_type="gpu",
    learning_rate=0.05,
    num_leaves=63,
    max_depth=8,
    feature_fraction=0.80,
    bagging_fraction=0.80,
    bagging_freq=1,
    min_data_in_leaf=200,
    seed=42,
)

# 2) 多折训练 + 每折 wr2 + 汇总 gain_share（仅一张表）

fi = pd.DataFrame({"feature": feat_cols})
scores = []

for k, (tr, va) in enumerate(folds, 1):
    dtrain = d_all.subset(tr, params=ds_params)    # 只构建本折的子集
    dvalid = d_all.subset(va, params=ds_params)

    bst = lgb.train(
        params, dtrain,
        valid_sets=[dvalid, dtrain],
        valid_names=["val", "train"],
        feval=lgb_wr2_eval,
        num_boost_round=4000,
        callbacks=[
            lgb.early_stopping(stopping_rounds=100, verbose=True),
            lgb.log_evaluation(period=100),
        ],
    )

    # 每折分数
    scores.append(bst.best_score["val"]["wr2"])   # or bst.best_score["val"]["wr2"]

    # 每折 gain_share → 作为一列加入
    g = bst.feature_importance(importance_type="gain", iteration=bst.best_iteration).astype(float)
    denom = g.sum()
    fi[f"fold{k}_gain_share"] = (g / denom) if denom > 0 else np.zeros_like(g, dtype=float)
    bst.free_dataset()                 # 释放 booster 里持有的 Dataset
    del dtrain, dvalid, bst; gc.collect()

In [None]:
# 汇总均值 + 排序 + 保存
fold_cols = [c for c in fi.columns if c.startswith("fold")]
fi["mean_gain_share"] = fi[fold_cols].mean(axis=1)
fi = fi.sort_values("mean_gain_share", ascending=False, ignore_index=True)

In [None]:
# 保存 汇总均值
fi.to_csv(f"/mnt/data/js/exp/v1/sample_mm//fe_v1_gain_share.csv", index=False)

In [None]:
fi = pd.read_csv(f"/mnt/data/js/exp/v1/sample_mm//fe_v1_gain_share.csv")

In [None]:
fi.head(15)

In [None]:
rl = [c for c in fi['feature'][:100] if c.startswith('responder_')]

In [None]:
rl

In [None]:
whitelist = cfg.get("white_list", [])
fi_normal = fi[~fi["feature"].isin(whitelist)].reset_index(drop=True)

展示

In [None]:
dfi = fi_normal[["feature", "mean_gain_share"]].copy()
dfi.reset_index(drop=True, inplace=True)
dfi['rank'] = dfi.index + 1

cum_share = dfi["mean_gain_share"].cumsum()

fig, ax1 = plt.subplots(figsize=(8,4))
ax1.plot(dfi["rank"], dfi["mean_gain_share"], color="tab:blue")
ax1.set_xlabel("Feature rank (desc)")
ax1.set_ylabel("Mean gain share", color="tab:blue")

ax2 = ax1.twinx()
ax2.plot(dfi["rank"], cum_share, color="tab:orange")
ax2.set_ylabel("Cumulative share", color="tab:orange")

plt.title("Feature importance (gain share)")
plt.tight_layout()
plt.show()


In [None]:
cum_share = dfi["mean_gain_share"].cumsum()
tot = cum_share.iloc[-1]
for th in [0.8, 0.9, 0.95]:
    k = (cum_share <= th*tot).sum()
    print(f"{th*100:.0f}% → Top-{k}")


In [None]:
whitelist

In [None]:
final_feats = list(dict.fromkeys(whitelist + dfi['feature'][:632].tolist()))  # 保持顺序且不重复
final_feats = pd.Series(final_feats)

final_feats.to_csv("/mnt/data/js/exp/v1/sample_mm/top_fi_0911.csv", index=False, header=False)

In [None]:
final_feats

In [None]:
pd.Series(final_feats, name="feature").to_csv(f"{P('local', 'exp/v1', cfg['paths']['sample_mm'])}/top_fi_0911.csv", index=False)

In [None]:
# check it
df_check = pd.read_csv(f"/mnt/data/js/exp/v1/sample_mm/top_fi_0911.csv")
df_check

# 去共线性

In [None]:
PARQUET_PATHS = ["/mnt/data/js/final_clean.parquet"]
KEYS = ["symbol_id","date_id","time_id"]
TARGET = "responder_6"
FEATURE_COLS = pd.read_csv('/home/admin_ml/Jackson/projects/selected_features.csv')['family'].tolist()
REP_COLS = pd.read_csv('/home/admin_ml/Jackson/projects/selected_resps.csv')['family'].tolist()

lf_base = pl.scan_parquet(PARQUET_PATHS).select(KEYS+FEATURE_COLS+REP_COLS)


lf_slice = lf_base.filter((pl.col("date_id") >= 1200) & (pl.col("date_id") <= 1400))

PARAMS = dict(
        prev_soft_days=7,
        tail_lags=(2, 5, 20, 40, 100),
        tail_diffs=(2, 5,),
        rolling_windows=(5, 20),
        same_time_ndays=5,
        strict_k=False,
        hist_lags=(1,2,5,10,20,100),
        ret_periods=(1,5,20),
        diff_periods=(1,5),
        rz_windows=(5,20),
        ewm_spans=(5,40,100),
        cs_cols=None,       # <- keep this small to avoid blow-up
    )

lf_eng = run_engineering_on_slice(lf_slice, **PARAMS)

feats = pd.read_csv("/home/admin_ml/Jackson/projects/final_fi_mean.csv")["feature"].tolist()

lf_small = lf_eng.select(feats[:500])
lf_small.collect(streaming=True).write_parquet("/mnt/data/js/X_ready.parquet", compression="zstd")


In [None]:

lf = pl.scan_parquet("/mnt/data/js/X_ready.parquet")

df = lf.collect(streaming=True).to_pandas()

# Correlation (pairwise complete obs) + guard on min observations
min_obs = max(50, int(0.3 * len(df)))  # tweak as you like
C = df.corr(method="pearson", min_periods=min_obs).abs().fillna(0.0)

# Ensure to align rows & cols to the same (priority) order, fill any NaNs with 0
order = feats
C = C.reindex(index=order, columns=order).fillna(0.0).copy()


# --- Prepare NumPy array for the greedy loop ---
A = C.values
np.fill_diagonal(A, 0.0)  # ensure the value is smaller than thresh, so the feature won't be dropped by value'1'
m = len(order)

# --- Greedy de-correlation (keep-by-priority, drop neighbors) ---
THRESH = 0.97
keep_mask = np.ones(m, dtype=bool)

for i in range(m):
    if not keep_mask[i]:
        continue  # already removed by a higher-priority pick
    # only check j > i (upper triangle) among still-active features
    active_slice = keep_mask[i+1:]
    drop = (A[i, i+1:] >= THRESH) & active_slice
    active_slice[drop] = False  # marks into keep_mask[i+1:] via view
keep = [order[i] for i in range(m) if keep_mask[i]]


pd.DataFrame({'feature': keep}).to_csv("final_selected_features.csv", index=False)

# 全数据训练

基本配置

In [None]:
DATE_LO, DATE_HI = 680, 1530 # 指定训练/验证的 date_id 范围, 后期转为全部训练集
# 基本量
FEATURE_COLS = [f"feature_{i:02d}" for i in range(79)] #FEATURE_COLS = pd.read_csv(f"{INPUT_DIR}/selected_features_1r.csv", header=None).squeeze().tolist()
REP_COLS = [f"responder_{i}" for i in range(9)] #REP_COLS = pd.read_csv(f"{INPUT_DIR}/selected_responders_1r.csv", header=None).squeeze().tolist()

paths = fs.glob(f"{P('az', 'exp/v1', cfg['paths']['clean'])}/*.parquet")
clean_files = []
for p in paths:
    bn = os.path.basename(p)
    parts = bn.split('_')
    start = int(parts[1])
    clean_files.append((start, p))
    
clean_sorted_file_list = [f"az://{f}" for _, f in sorted(clean_files)]

lc = pl.scan_parquet(clean_sorted_file_list, storage_options=storage_options)

print("ready")

枚举窗口

In [None]:
# 从 Blob 列出全部 fe_shards 分片（返回不带协议的路径，要手动加 az://）

fe_all = fs.glob(f"{P('np', 'exp/v1',cfg['paths']['fe_shards'])}/*.parquet")
fe_all = [f"az://{p}" for p in fe_all]

# 按日期范围筛选
wins = set()
for p in fe_all:
    base = p.split("/")[-1]  # e.g. C_lags_1220_1249.parquet
    lo = int(base.split("_")[-2]); hi = int(base.split("_")[-1].split(".")[0])
    if hi >= DATE_LO and lo <= DATE_HI:
        wins.add((lo, hi))
wins = sorted(wins)
print(f"windows in range: {wins[:5]} ... (total {len(wins)})")

# 取得该区间实际天
days = [d for d in range(DATE_LO, DATE_HI+1)]
#cut = int(len(days)*0.8)


3. 按窗口拼接 (A + B + 所有 C_*) → 直接写数据分片（无大表）

In [None]:
import numpy as np
import polars as pl
import gc

T = int(cfg['ticks'])                 # 例如 968 
K = int(cfg['bucket_size'])           # 例如 6
open_n  = int(cfg.get('open_auction_ticks', 5))
close_n = int(cfg.get('close_auction_ticks', 5))

# 安全的“上界截断”工具（兼容旧版 Polars 无 clip_max）
def clip_upper(expr: pl.Expr, ub: int) -> pl.Expr:
    return pl.when(expr > pl.lit(ub)).then(pl.lit(ub)).otherwise(expr)

for (lo, hi) in wins:
    # 与全局区间取交集，防止边缘窗口越界
    w_lo, w_hi = max(lo, DATE_LO), min(hi, DATE_HI)

    # 基表（先筛行，再一次性加时间特征）
    base = (
        lc.filter(pl.col("date_id").is_between(w_lo, w_hi))
          .select([*cfg['keys'], cfg['target'], cfg['weight'], *FEATURE_COLS])
        .with_columns([
            # 复制一份 time_id 作为“位置特征”，避免与 key 列冲突
            pl.col("time_id").cast(pl.Float32).alias("time_pos"),

              # 周期相位：phase = 2π * time_id / T
              (2*np.pi * pl.col("time_id") / pl.lit(T, dtype=pl.Float32)).alias("_phase_"),
        ])
        .with_columns([
            # 兼容旧版：对表达式调用 .sin() / .cos()
            pl.col("_phase_").sin().cast(pl.Float32).alias("time_sin"),
            pl.col("_phase_").cos().cast(pl.Float32).alias("time_cos"),
        ])
        .drop(["_phase_"])
        .with_columns([
            # 开盘/收盘指示器（恰好 open_n / close_n 个 tick）
            (pl.col("time_id") <  pl.lit(open_n)).cast(pl.Int8).alias("is_open_auction"),
            (pl.col("time_id") >= pl.lit(T - close_n)).cast(pl.Int8).alias("is_close_auction"),
        ])
    )

    # 分桶：bucket = floor(time_id * K / T)，并防御性截到 [0..K-1]
    bucket_raw = ( (pl.col('time_id') * pl.lit(K)) // pl.lit(T) )
    bucket_capped = clip_upper(bucket_raw, K - 1)
    base = base.with_columns([
        bucket_capped.cast(pl.Int8).alias(f"time_bucket_{K}")
    ])

    lf = base  # 后面继续你的 join 逻辑

    fe_files = []
    for name in (f"A_{lo}_{hi}.parquet", f"B_{lo}_{hi}.parquet"):
        p = f"{P('az', 'exp/v1', cfg['paths']['fe_shards'])}/{name}"
        fe_files.append(p)

    # 同窗口所有 C_* 分片
    c_files = fs.glob(f"{P('np', 'exp/v1', cfg['paths']['fe_shards'])}/C_*_{lo}_{hi}.parquet")
    c_files = [f"az://{p}" for p in c_files]
    fe_files.extend(c_files)

    # 逐个左连接
    already = set(lf.collect_schema().names())
    for fp in fe_files:
        ds = pl.scan_parquet(fp, storage_options=storage_options)
        names = ds.collect_schema().names()
        add_cols = [c for c in names if c not in already]
        if add_cols:
            lf = lf.join(ds.select([*cfg['keys'], *add_cols]), on=cfg['keys'], how="left")
            already.update(add_cols)

    # 选出特征并统一类型
    feat_present = [c for c in already if c not in (*cfg['keys'], cfg['target'], cfg['weight'])]
    select_exprs = [
        *cfg['keys'],
        pl.col(cfg['target']).cast(pl.Float32).alias(cfg['target']),
        pl.col(cfg['weight']).cast(pl.Float32).alias(cfg['weight']),
        *[pl.col(c).cast(pl.Float32).alias(c) for c in feat_present],
    ]
    lf_win = lf.select(select_exprs)

    # 直接流式写分片
    panel_path = P("az", "exp/v1", cfg["paths"]["panel_shards"])
    fs.mkdir(panel_path, exist_ok=True)
    out_path = f"{panel_path}/panel_{w_lo}_{w_hi}.parquet"
    (
        lf_win.sink_parquet(
            out_path,
            compression="zstd",
            storage_options=storage_options,
            statistics=True,
            maintain_order=True,
        )
    )

    gc.collect()


In [None]:
# 检查关键分片/总表是否按 (symbol_id, date_id, time_id) 排序

paths = sorted(fs.glob(f"{P('az', 'exp/v1', cfg['paths']['clean_shards'])}/*.parquet"))
for p in paths:
    df = pl.read_parquet(f"az://{p}", storage_options=storage_options).select(["symbol_id","date_id","time_id"])
    n  = df.height
    # 计算“按 key 的正确顺序”
    sid = df["symbol_id"].to_numpy()
    did = df["date_id"].to_numpy()
    tid = df["time_id"].to_numpy()
    ord_keys = np.lexsort((tid, did, sid))    # 以 symbol->date->time 升序
    is_sorted = np.all(ord_keys == np.arange(n))
    print(os.path.basename(p), "sorted_by_keys:", is_sorted, "rows:", n)

4.导入最终特征清单

In [None]:

feat_cols = pd.read_csv("/mnt/data/js/exp/v1/sample_mm/top_fi_0911.csv")
feat_cols = feat_cols['feature'].tolist()
    

In [None]:
len(feat_cols)

5.构建memmap

In [None]:
import re
mm_dir = P("local", "exp/v1", cfg["paths"]["panel_mm"])
os.makedirs(mm_dir, exist_ok=True)

def full_shard_key(p: str):
    bn = os.path.basename(p)          # e.g. panel_1200_1219.parquet
    m = re.match(r"panel_(\d+)_(\d+)\.parquet$", bn)
    if not m:
        return (10**12, 10**12, bn)
    lo, hi = map(int, m.groups())
    return (lo, hi)


def shard2memmap(glob_paths: list[str], feat_cols: list[str], prefix: str):
    date_col   = cfg["keys"][1]
    target_col = cfg["target"]
    weight_col = cfg["weight"]

    paths = sorted(glob_paths, key=full_shard_key)

    counts = []
    for p in paths:
        k = (pl.scan_parquet(p, storage_options=storage_options)
               .select(pl.len()).collect(streaming=True).item())
        counts.append(int(k))
    n_rows, n_feat = int(sum(counts)), len(feat_cols)

    os.makedirs(os.path.dirname(prefix), exist_ok=True)

    X = np.memmap(f"{prefix}_X.float32.mmap", dtype="float32", mode="w+", shape=(n_rows, n_feat))
    y = np.memmap(f"{prefix}_y.float32.mmap", dtype="float32", mode="w+", shape=(n_rows,))
    w = np.memmap(f"{prefix}_w.float32.mmap", dtype="float32", mode="w+", shape=(n_rows,))
    d = np.memmap(f"{prefix}.date.i32.mmap",  dtype="int32",   mode="w+", shape=(n_rows,))

    i = 0
    need_cols = [date_col, target_col, weight_col, *feat_cols]
    for p, k in zip(paths, counts):
        df = (pl.scan_parquet(p, storage_options=storage_options)
                .select(need_cols).collect(streaming=True))

        X[i:i+k, :] = df.select(feat_cols).to_numpy()
        y[i:i+k]    = df.select(pl.col(target_col)).to_numpy().ravel()
        w[i:i+k]    = df.select(pl.col(weight_col)).to_numpy().ravel()
        d[i:i+k]    = df.select(pl.col(date_col)).to_numpy().ravel().astype("int32")
        i += k
        del df; gc.collect()

    X.flush(); y.flush(); w.flush(); d.flush()

    meta = {
        "n_rows": int(n_rows), "n_feat": int(n_feat), "dtype": "float32", "ts": time.time(),
        "features": list(feat_cols), "target": target_col, "weight": weight_col,
        "date_col": date_col, "files": paths
    }
    with open(f"{prefix}.meta.json", "w") as f:
        json.dump(meta, f)
    return X, y, w

In [None]:
np_paths = fs.glob(f"{P('np', 'exp/v1', cfg['paths']['panel_shards'])}/panel_*_*.parquet")
glob_paths = []
for p in np_paths:
    glob_paths.append(f"az://{p}")
    
X, y, w = shard2memmap(glob_paths= glob_paths, feat_cols=feat_cols, prefix=f"{mm_dir}/full_panel_v1")


In [None]:
# 滑动式来分割训练集/验证集

d = np.memmap(f"/mnt/data/js/exp/v1/panel_mm/full_panel_v1.date.i32.mmap", dtype="int32", mode="r")

In [None]:
import numpy as np

d = np.memmap("/mnt/data/js/exp/v1/panel_mm/full_panel_v1.date.i32.mmap", dtype="int32", mode="r")

mono = np.all(d[1:] >= d[:-1])
viol = np.flatnonzero(d[1:] < d[:-1])
print("monotonic_non_decreasing:", bool(mono), "| violations:", viol.size)

# 看前几个“逆序”位置（如果有）
for j in viol[:10]:
    print(f"idx {j}->{j+1}: {d[j]} -> {d[j+1]}")


In [None]:
import numpy as np

d = np.memmap("/mnt/data/js/exp/v1/panel_mm/full_panel_v1.date.i32.mmap", dtype="int32", mode="r")

days_all = np.unique(d)
print("rows:", d.size, "unique_days:", days_all.size, "min/max:", days_all.min(), days_all.max())

# 相邻唯一天的差
gaps = np.diff(days_all)
gap_pos = np.flatnonzero(gaps > 1)
print("global missing-day blocks:", gap_pos.size)

# 看前几段缺口：[前一天, 后一天, 差值]
if gap_pos.size:
    preview = np.c_[days_all[gap_pos], days_all[gap_pos+1], gaps[gap_pos]]
    print(preview[:10])


In [None]:
import numpy as np

def make_sliding_cv(date_ids: np.ndarray, *, n_splits: int, gap_days: int = 5, train_to_val: int = 9):
    # ---- 构造唯一天轴 ----
    days = np.unique(date_ids)                 # 关键：按天计算窗口
    N, R, K, G = len(days), int(train_to_val), int(n_splits), int(gap_days)

    usable = N - G
    if usable <= 0 or K <= 0 or R <= 0:
        return []

    V_base, rem = divmod(usable, R + K)
    if V_base <= 0:
        return []

    T = R * V_base
    v_lens = [V_base + 1 if i < rem else V_base for i in range(K)]
    v_lo = T + G

    folds = []
    for V_i in v_lens:
        v_hi  = v_lo + V_i
        tr_hi = v_lo - G
        tr_lo = tr_hi - T
        if tr_lo < 0 or v_hi > N:
            break

        tr_days = days[tr_lo:tr_hi]
        va_days = days[v_lo:v_hi]

        tr_idx = np.flatnonzero(np.isin(date_ids, tr_days))
        va_idx = np.flatnonzero(np.isin(date_ids, va_days))

        # 保险丝（行不重叠、天不重叠、gap 按“天位置”生效）
        assert np.intersect1d(tr_idx, va_idx).size == 0, "row overlap"
        assert np.intersect1d(tr_days, va_days).size == 0, "day overlap"
        gap_pos = (np.searchsorted(days, va_days.min())
                   - np.searchsorted(days, tr_days.max()) - 1)
        assert gap_pos >= G, f"gap_days not enforced: {gap_pos} < {G}"

        folds.append((tr_idx, va_idx))
        v_lo = v_hi

    return folds

# 用法
d = np.memmap("/mnt/data/js/exp/v1/panel_mm/full_panel_v1.date.i32.mmap", dtype="int32", mode="r")
folds = make_sliding_cv(d, n_splits=3, gap_days=5, train_to_val=9)


In [None]:
folds

In [None]:
days_all = np.unique(d)
for i,(tr,va) in enumerate(folds,1):
    tr_days = np.unique(d[tr]); va_days = np.unique(d[va])
    row_ovl = np.intersect1d(tr,va).size
    day_ovl = np.intersect1d(tr_days,va_days).size
    gap_pos = (np.searchsorted(days_all, va_days.min())
               - np.searchsorted(days_all, tr_days.max()) - 1)
    print(f"fold{i}: row_ovl={row_ovl}, day_ovl={day_ovl}, gap_days={gap_pos}")


In [None]:
d

In [None]:
folds

In [None]:
# ---------- 2) 加载 memmap ----------
import json, numpy as np, lightgbm as lgb
prefix = f"/mnt/data/js/exp/v1/panel_mm/full_panel_v1"
with open(f"{prefix}.meta.json") as f:
    meta = json.load(f)
n_rows, n_feat = meta["n_rows"], meta["n_feat"]
feat_names = meta["features"]

X = np.memmap(f"{prefix}_X.float32.mmap", dtype="float32", mode="r", shape=(n_rows, n_feat))
y = np.memmap(f"{prefix}_y.float32.mmap", dtype="float32", mode="r", shape=(n_rows,))
w = np.memmap(f"{prefix}_w.float32.mmap", dtype="float32", mode="r", shape=(n_rows,))
# 你之前已定义：weighted_r2_zero_mean、lgb_wr2_eval


训练模型

In [None]:
# 估算 GPU “transfer to GPU” 的大头（经验值）

n_rows = (
    lc.filter(pl.col("date_id").is_between(DATE_LO, DATE_HI)) 
      .select(pl.len())
      .collect()
      .item()  # -> int
)

n_feat = len(feat_names)
dense_groups = int(n_feat)  # 按之前比例估
bytes_est = n_rows * 0.8* dense_groups         
gb_est = bytes_est / (1024**3)

print(f"rows≈{n_rows:,}, dense_groups≈{dense_groups}, est GPU load≈{gb_est:.2f} GiB")


In [None]:
ds_params = dict(
    max_bin=31,                    
    bin_construct_sample_cnt=100000,
    min_data_in_bin=3,
    data_random_seed=42,
)

# 1) 全集 Dataset
d_all = lgb.Dataset(
    X, label=y, weight=w,
    feature_name=feat_names,
    free_raw_data=True,
    params=ds_params,               # 让子集也继承这些设置
)

params = dict(
    objective="regression",
    metric="None",
    device_type="gpu",
    num_threads=16,
    learning_rate=0.08,
    num_leaves=31,
    max_depth=8,
    feature_fraction=0.60,
    bagging_fraction=0.60,
    bagging_freq=1,
    min_data_in_leaf=200,
    seed=42,
)

# 2) 多折训练 + 每折 wr2 + 汇总 gain_share（仅一张表）
import numpy as np, pandas as pd, os

fi = pd.DataFrame({"feature": feat_names})
scores = [] 

for k, (tr, va) in enumerate(folds, 1):
    dtrain = d_all.subset(tr, params=ds_params)    # 只构建本折的子集
    dvalid = d_all.subset(va, params=ds_params)

    bst = lgb.train(
        params, dtrain,
        valid_sets=[dvalid, dtrain],
        valid_names=["val", "train"],
        feval=lgb_wr2_eval,
        num_boost_round=4000,
        callbacks=[
            lgb.early_stopping(stopping_rounds=100, verbose=True),
            lgb.log_evaluation(period=100),
        ],
    )

    # 每折分数
    scores.append(bst.best_score["val"]["wr2"])   # or bst.best_score["val"]["wr2"]

    # 每折 gain_share → 作为一列加入
    g = bst.feature_importance(importance_type="gain", iteration=bst.best_iteration).astype(float)
    denom = g.sum()
    fi[f"fold{k}_gain_share"] = (g / denom) if denom > 0 else np.zeros_like(g, dtype=float)
    bst.free_dataset()                 # 释放 booster 里持有的 Dataset
    del dtrain, dvalid, bst; gc.collect()

In [None]:
# 汇总均值 + 排序 + 保存
fold_cols = [c for c in fi.columns if c.startswith("fold")]
fi["mean_gain_share"] = fi[fold_cols].mean(axis=1)
fi = fi.sort_values("mean_gain_share", ascending=False, ignore_index=True)

In [None]:
fi

# 模型评估

## 1.数据清洗,预处理

数据集：test + pad

## 2.特征工程