In [1]:

import os

# [Fix] Work around the OpenBLAS multi-threading segmentation fault by pinning
# the BLAS / OpenMP thread-count variables. They are set here, ahead of the
# pandas / numpy imports further down.
# NOTE(review): 64 is presumably the thread limit that is safe for the
# installed OpenBLAS build on this machine -- confirm.
os.environ.update({
    'OPENBLAS_NUM_THREADS': '64',
    'GOTO_NUM_THREADS': '64',
    'OMP_NUM_THREADS': '64',
})

from auto_config import project_dir

# Point TABPFN_MODEL_CACHE_DIR at <project_dir>/data/pretrained and echo the value.
os.environ["TABPFN_MODEL_CACHE_DIR"] = project_dir.joinpath("data/pretrained").as_posix()
print(f"设置 TABPFN_MODEL_CACHE_DIR 为: {os.environ['TABPFN_MODEL_CACHE_DIR']}")

import pandas as pd
import numpy as np
import json
from pathlib import Path
import warnings
from autogluon.tabular import TabularDataset, TabularPredictor
from autogluon.core.constants import REGRESSION
from sklearn.model_selection import TimeSeriesSplit

# Silence every warning category for the rest of the session. This also hides
# deprecation and convergence warnings, so re-enable it when debugging.
warnings.simplefilter('ignore')

#%%
# --- 1. Configuration and environment setup ---
print("--- [1/6] 加载配置 ---")

# [New] Run mode selector: 'QUICK_CV' or 'FULL_TRAIN'
RUN_MODE = 'QUICK_CV'
CV_FOLDS = 5

# Target label and evaluation metric
TARGET_LABEL = '涨跌_shift'
EVAL_METRIC = 'roc_auc'

# Input / output locations, all relative to the project root
FEATURE_JSON_PATH = project_dir / "temp/stage2/feature_selection_results_vetted.json"
TRAIN_DATA_PATH = project_dir / "temp/stage3/train.pkl"
VALID_DATA_PATH = project_dir / "temp/stage3/valid.pkl"
TEST_DATA_PATH = project_dir / "temp/stage3/test.pkl"
MODEL_OUTPUT_BASE_PATH = project_dir / "models/stage4"
MODEL_OUTPUT_BASE_PATH.mkdir(parents=True, exist_ok=True)

# Load the feature-selection JSON produced by stage 2
with open(FEATURE_JSON_PATH, 'r', encoding='utf-8') as f:
    feature_config = json.load(f)

vetted_features = feature_config.get(TARGET_LABEL, {}).get('final_results', {}).get('vetted_features', [])
categorical_features = feature_config.get('categorical_features_to_keep', [])

# Fail fast, before building the final feature list, when stage 2 produced
# nothing for this target.
if not vetted_features:
    raise ValueError(f"未能从特征选择文件 {FEATURE_JSON_PATH} 中为目标 {TARGET_LABEL} 找到'vetted_features'。请先运行stage2脚本。")

# Merge the two lists while preserving order. A name present in both lists
# (or equal to the label) would otherwise yield duplicate columns when the
# frames are sliced with `features_to_use + [TARGET_LABEL]`.
features_to_use = [
    feature_name
    for feature_name in dict.fromkeys(vetted_features + categorical_features)
    if feature_name != TARGET_LABEL
]

print(f"运行模式: {RUN_MODE}")
print(f"目标标签: {TARGET_LABEL}, 评估指标: {EVAL_METRIC}")
print(f"将使用 {len(features_to_use)} 个特征进行训练。")


#%%
# --- 2. Custom model implementations ---
print("\n--- [2/6] 定义自定义模型 ---")
# The custom models are expected to be defined in the custom_ag package.
try:
    from custom_ag.ag_svm import AgSVMModel
    from custom_ag.ag_nb import IntelligentNaiveBayesModel
    from custom_ag.ag_tabpfn import TabPFNModel
    from autogluon.tabular.models.lr.lr_model import LinearModel
    print("自定义模型已加载。")
except ImportError as import_error:
    # Keep the best-effort fallback (all four names become None), but report
    # the underlying cause: a missing dependency inside custom_ag raises
    # ImportError too and would otherwise be indistinguishable from the
    # package simply not being present.
    print(f"未找到自定义模型，将使用AutoGluon默认模型。原因: {import_error!r}")
    AgSVMModel = IntelligentNaiveBayesModel = TabPFNModel = LinearModel = None

#%%
# --- 3. Data loading and preparation ---
print("\n--- [3/6] 加载预划分的数据集 ---")
# Read the three pre-split pickles in train / valid / test order.
train_df, valid_df, test_df = (
    pd.read_pickle(split_path)
    for split_path in (TRAIN_DATA_PATH, VALID_DATA_PATH, TEST_DATA_PATH)
)

# Dtype repair: every column whose dtype prints as 'Int64' is cast to float32,
# in place on each of the loaded frames.
print("正在检查并修复数据类型以兼容AutoGluon...")
for df_ in (train_df, valid_df, test_df):
    # `df_.dtypes` is a snapshot, so reassigning columns while iterating is safe.
    for col, col_dtype in df_.dtypes.items():
        if str(col_dtype) != 'Int64':
            continue
        df_[col] = df_[col].astype('float32')

# Keep only the selected features plus the label.
final_cols = [*features_to_use, TARGET_LABEL]
# Train / valid: rows without a label are dropped.
train_data_full = train_df.loc[:, final_cols].dropna(subset=[TARGET_LABEL])
valid_data_full = valid_df.loc[:, final_cols].dropna(subset=[TARGET_LABEL])
# Test: all rows are kept, including those whose label is missing.
test_data = test_df.loc[:, final_cols]

设置 TABPFN_MODEL_CACHE_DIR 为: /home/ye_canming/repos/novelties/ts/comp/AutoCute/data/pretrained
--- [1/6] 加载配置 ---
运行模式: QUICK_CV
目标标签: 涨跌_shift, 评估指标: roc_auc
将使用 241 个特征进行训练。

--- [2/6] 定义自定义模型 ---
自定义模型已加载。

--- [3/6] 加载预划分的数据集 ---
正在检查并修复数据类型以兼容AutoGluon...


In [2]:
# Inspect the test split restricted to the selected features plus the label
# column; the label is NaN in every row of the rendered preview.
test_data

Unnamed: 0,Amplitude,BETA10,BETA20,BETA30,BETA5,CORD20,CORD30,CORD5,CORD60,CORR10,...,龙虎_lag_2,龙虎_lag_20,龙虎_lag_27,龙虎_lag_3,龙虎_lag_4,龙虎_lag_5,龙虎_lag_55,龙虎_lag_6,龙虎_lag_83,涨跌_shift
2376,0.525845,0.000611,-0.001017,-0.002259,-0.000272,-0.638000,-0.530505,-0.581657,-0.263277,-0.004186,...,False,,,False,False,,,,,
4531,2.816834,0.000507,0.000989,-0.001840,-0.005106,-0.099982,-0.009147,-0.109185,0.352380,0.667087,...,False,,,False,False,,,,,
6666,1.085451,0.000450,-0.003073,-0.005869,-0.004649,-0.005840,-0.030570,0.762490,0.310700,1.208069,...,False,,,False,False,,,,,
8741,2.067789,0.000921,-0.004829,-0.007329,-0.000980,-0.255986,-0.210786,0.428631,-0.040913,-0.270790,...,False,,,False,False,,,,,
11117,1.276773,0.000344,0.000277,-0.003606,-0.003513,-0.298207,-0.247041,1.232518,0.216621,0.552119,...,False,,,False,False,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
602581,1.367177,0.001520,0.004382,0.000339,-0.000890,-0.485367,-0.456172,-0.396987,-0.165561,-0.081815,...,False,,,False,False,,,,,
602976,2.833823,-0.000111,-0.003327,-0.006432,0.002595,-0.635307,-0.530505,-0.219527,-0.301469,0.830333,...,False,,,False,False,,,,,
603472,4.483582,0.014891,0.010021,0.008823,0.007935,0.160840,0.356190,0.493865,0.563182,0.181876,...,True,,,True,False,,,,,
604595,1.689246,-0.007177,-0.014748,-0.015336,-0.008781,-0.542053,-0.361545,-0.603217,-0.035770,0.035030,...,False,,,False,False,,,,,


In [3]:
# Inspect the complete test split (all columns, after the Int64 -> float32 cast).
test_df

Unnamed: 0,timestamp,item_id,KMID,KLEN,KMID2,KUP,KUP2,KLOW,KLOW2,KSFT,...,timestamp_day_of_week,timestamp_day_of_year,timestamp_month_end,timestamp_quarter_end,timestamp_month_sin,timestamp_month_cos,timestamp_day_of_week_sin,timestamp_day_of_week_cos,timestamp_day_of_month_sin,timestamp_day_of_month_cos
2376,2025-04-25,1,-0.002717,0.005435,-0.498351,0.000906,0.420538,0.001811,0.615485,-1.811532e-03,...,4,115,0,0,0.866025,-0.5,-0.866025,-0.5,-0.937752,0.347305
4531,2025-04-25,2,-0.005642,0.031034,-0.181581,0.025348,1.130286,0.000000,0.000000,-3.102450e-02,...,4,115,0,0,0.866025,-0.5,-0.866025,-0.5,-0.937752,0.347305
6666,2025-04-25,63,-0.001579,0.011368,-0.138746,0.008836,1.079914,0.000947,0.292846,-9.472185e-03,...,4,115,0,0,0.866025,-0.5,-0.866025,-0.5,-0.937752,0.347305
8741,2025-04-25,100,0.012404,0.022335,0.557559,0.007441,0.615482,0.002481,0.339841,7.444398e-03,...,4,115,0,0,0.866025,-0.5,-0.866025,-0.5,-0.937752,0.347305
11117,2025-04-25,157,-0.006712,0.013424,-0.498356,0.000000,0.000000,0.006709,0.785401,6.400499e-08,...,4,115,0,0,0.866025,-0.5,-0.866025,-0.5,-0.937752,0.347305
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
602581,2025-04-25,688396,0.002763,0.014454,0.191440,0.000000,0.000000,0.011681,1.118270,1.445383e-02,...,4,115,0,0,0.866025,-0.5,-0.866025,-0.5,-0.937752,0.347305
602976,2025-04-25,688472,-0.010697,0.031021,-0.344015,0.020295,0.943175,0.000000,0.000000,-3.101095e-02,...,4,115,0,0,0.866025,-0.5,-0.866025,-0.5,-0.937752,0.347305
603472,2025-04-25,688506,-0.025011,0.051034,-0.488414,0.000000,0.000000,0.025978,0.795405,1.021055e-03,...,4,115,0,0,0.866025,-0.5,-0.866025,-0.5,-0.937752,0.347305
604595,2025-04-25,688599,-0.002266,0.018129,-0.124885,0.012076,0.955317,0.003776,0.473986,-1.057338e-02,...,4,115,0,0,0.866025,-0.5,-0.866025,-0.5,-0.937752,0.347305


In [4]:
# Inspect the training rows that remain after dropping rows with a missing label.
train_data_full

Unnamed: 0,Amplitude,BETA10,BETA20,BETA30,BETA5,CORD20,CORD30,CORD5,CORD60,CORR10,...,龙虎_lag_2,龙虎_lag_20,龙虎_lag_27,龙虎_lag_3,龙虎_lag_4,龙虎_lag_5,龙虎_lag_55,龙虎_lag_6,龙虎_lag_83,涨跌_shift
0,4.507924,0.004848,-0.006496,-0.010828,-0.022094,0.160210,0.347773,0.786340,0.256858,0.521272,...,,,,,,,,,,False
1,2.859289,0.001759,-0.005303,-0.011115,-0.031714,0.156045,0.341971,1.049609,0.270261,0.542126,...,,,,,,,,,,False
3,2.328427,-0.007293,-0.004000,-0.010665,-0.002925,0.172696,0.311374,0.933282,0.271253,0.636488,...,,,,False,False,False,,,,False
4,2.172364,-0.005237,-0.003348,-0.009756,0.000000,0.153358,0.175880,1.167320,0.274004,0.569079,...,,,,,False,False,,False,,True
5,1.786681,-0.011456,-0.002336,-0.008956,-0.003855,0.167270,0.168308,0.969668,0.275008,1.241004,...,False,,,,,False,,False,,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
605683,2.397492,-0.002873,-0.004408,-0.004905,0.013085,-0.163380,-0.073449,-0.454506,0.252113,-0.195584,...,False,False,False,False,,,True,,True,True
605685,1.848464,0.006370,-0.002541,-0.003746,0.018176,-0.052097,-0.050839,1.022719,0.269758,0.003303,...,,False,False,False,False,False,False,False,False,False
605686,2.120131,0.008132,-0.002031,-0.003659,0.001626,-0.045951,-0.052143,1.017065,0.271288,-0.037460,...,,False,False,,False,False,False,False,False,True
605687,2.172364,0.008371,-0.001026,-0.003209,-0.000446,-0.045415,-0.045976,0.944557,0.283053,-0.152639,...,False,False,False,,,False,False,False,True,False


In [18]:
# Inspect the validation rows that remain after dropping rows with a missing label.
valid_data_full

Unnamed: 0,Amplitude,BETA10,BETA20,BETA30,BETA5,CORD20,CORD30,CORD5,CORD60,CORR10,...,龙虎_lag_2,龙虎_lag_20,龙虎_lag_27,龙虎_lag_3,龙虎_lag_4,龙虎_lag_5,龙虎_lag_55,龙虎_lag_6,龙虎_lag_83,涨跌_shift
2372,1.349130,0.003087,-0.002376,-0.002862,0.002904,-0.641068,-0.446182,-0.723995,-0.272394,-0.337083,...,,False,False,False,False,False,False,False,,True
2373,0.705392,0.002686,-0.002106,-0.002815,0.000362,-0.641068,-0.453824,-0.633408,-0.281520,-0.131777,...,,False,False,,False,False,False,False,,False
2374,0.696001,0.001784,-0.001767,-0.002764,-0.002180,-0.639535,-0.448944,-0.548233,-0.277814,0.156168,...,False,False,False,,,False,False,False,,True
2375,0.620643,0.001346,-0.001407,-0.002515,-0.002810,-0.636922,-0.530505,-0.523619,-0.276859,0.256472,...,False,,False,False,,,False,False,,False
4527,1.167699,0.004215,-0.000056,-0.002342,0.009054,-0.047299,0.022666,0.702937,0.322791,0.515579,...,,False,False,False,True,False,False,False,,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
604594,2.189751,-0.006668,-0.016885,-0.014744,-0.006727,-0.514830,-0.356997,-0.595523,-0.031742,0.247598,...,False,,False,False,,,False,False,,False
605690,1.653707,0.003867,0.001036,-0.001983,-0.000453,-0.029470,-0.096039,0.053565,0.275658,-0.196814,...,,True,False,False,False,False,False,False,,False
605691,1.185924,-0.000764,0.001364,-0.001744,-0.005855,-0.024274,-0.097001,-0.146422,0.286538,0.195273,...,,False,False,,False,False,False,False,,True
605692,0.928986,-0.002214,0.001442,-0.001425,-0.005329,-0.031099,-0.092377,-0.470496,0.275927,0.629978,...,False,False,False,,,False,False,False,,False


In [None]:
# NOTE(review): `predictor_explore` is not created in any cell visible here --
# presumably a fitted TabularPredictor from a training cell; confirm it exists
# before running this cell on a fresh kernel.

# Score every trained model on the labelled validation split.
leaderboard_explore = predictor_explore.leaderboard(valid_data_full)

# Predicted class probabilities for the test split.
pred_proba = predictor_explore.predict_proba(test_data)

# Feature importance is computed by scoring predictions against ground truth,
# so it needs real labels. `test_data` keeps rows with a missing label (its
# preview shows only NaN labels), so the labelled validation split is used.
feature_importance = predictor_explore.feature_importance(valid_data_full)

#%%
# Write a plain-text summary of the exploration results.
summary_path = project_dir / "temp/stage4/train_updown_results_summary.txt"
# The target directory is not created anywhere else in this notebook.
summary_path.parent.mkdir(parents=True, exist_ok=True)

# Explicit UTF-8: column / feature names contain non-ASCII characters, and the
# platform default encoding is not guaranteed to handle them.
with open(summary_path, "w", encoding="utf-8") as f:

    # Leaderboard: the model name is a regular column, so the index is dropped.
    f.write("===== Leaderboard Explore =====\n")
    f.write(leaderboard_explore.to_string(index=False))
    f.write("\n\n")

    # Predicted probabilities, keeping the row index so rows stay identifiable.
    f.write("===== Predicted Probabilities =====\n")
    f.write(pred_proba.to_string())
    f.write("\n\n")

    # Feature importance: the feature names live in the index, so it must be
    # kept -- with index=False the table is a list of anonymous numbers.
    f.write("===== Feature Importance =====\n")
    f.write(feature_importance.to_string())
    f.write("\n")

In [30]:
import pandas as pd
from pathlib import Path
from auto_config import project_dir

# Load the lag158 feature table from <project_dir>/temp/lag158.pkl.
# NOTE: despite its name, `output_file` is only read from in this cell.
output_file = project_dir / "temp/lag158.pkl"
df_lag158 = pd.read_pickle(output_file)
# Disabled alternative: load the stage-3 training split into df_lag158 instead.
# OUTPUT_DIR = project_dir / "temp/stage3"
# OUTPUT_TEST_PATH = OUTPUT_DIR / "train.pkl"
# df_lag158 = pd.read_pickle(OUTPUT_TEST_PATH )


In [31]:
# Inspect the lag158 table loaded from temp/lag158.pkl.
df_lag158

Unnamed: 0,timestamp,item_id,KMID,KLEN,KMID2,KUP,KUP2,KLOW,KLOW2,KSFT,...,涨跌幅排名,涨跌正负,涨跌,龙虎,收盘_shift,涨跌幅_shift,涨跌幅排名_shift,涨跌正负_shift,涨跌_shift,龙虎_shift
0,2015-07-15,1,-0.007831,0.052586,-0.148775,0.027916,0.817336,0.016761,0.600351,-0.019014,...,-0.908883,False,True,False,8.890000,0.23,0.034349,True,False,False
1,2015-07-16,1,0.000000,0.031501,0.000000,0.014610,0.749653,0.016855,0.821144,0.002250,...,0.034349,True,False,False,9.070000,2.02,0.536411,True,False,False
2,2015-07-17,1,0.014538,0.038038,0.383346,0.011178,0.573204,0.012295,0.605041,0.015661,...,0.536411,True,False,False,8.890000,-1.98,0.208044,False,False,False
3,2015-07-20,1,-0.018770,0.025390,-0.735745,0.000000,0.000000,0.006620,0.536063,-0.012140,...,0.208044,False,False,False,8.870000,-0.22,-0.090278,False,False,False
4,2015-07-21,1,0.005668,0.023812,0.238497,0.005667,0.509741,0.012462,0.809215,0.012472,...,-0.090278,False,False,False,8.820000,-0.56,-0.305364,False,True,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
605690,2025-04-21,688981,-0.009088,0.017738,-0.510623,0.004488,0.527157,0.004159,0.505591,-0.009415,...,1.344507,False,False,False,90.010002,-0.55,0.684920,False,False,False
605691,2025-04-22,688981,-0.005415,0.012487,-0.432372,0.004198,0.618603,0.002872,0.500288,-0.006740,...,0.684920,False,False,False,90.070000,0.07,-0.138453,True,True,False
605692,2025-04-23,688981,-0.000887,0.009651,-0.091893,0.006320,0.943174,0.002440,0.526911,-0.004770,...,-0.138453,True,True,False,89.099998,-1.08,1.115323,False,False,False
605693,2025-04-24,688981,-0.011650,0.018859,-0.615218,0.000555,0.172343,0.006653,0.636132,-0.005546,...,1.115323,False,False,False,88.419998,-0.76,1.255134,False,False,False


In [32]:
# Show the shifted label next to the unshifted 涨跌 column for each
# (item_id, timestamp) row -- presumably to sanity-check the shift; confirm.
df_lag158[['item_id','timestamp',"涨跌_shift",'涨跌']]

Unnamed: 0,item_id,timestamp,涨跌_shift,涨跌
0,1,2015-07-15,False,True
1,1,2015-07-16,False,False
2,1,2015-07-17,False,False
3,1,2015-07-20,False,False
4,1,2015-07-21,True,False
...,...,...,...,...
605690,688981,2025-04-21,False,False
605691,688981,2025-04-22,True,False
605692,688981,2025-04-23,False,True
605693,688981,2025-04-24,False,False


In [None]:
import pandas as pd
from pathlib import Path
from auto_config import project_dir
# Re-declare the stage-4 model directory so this cell can run without the
# configuration cell at the top of the notebook.
MODEL_OUTPUT_BASE_PATH = project_dir / "models/stage4"
# Sub-directory for the exploration predictor. `path` is not used by any cell
# visible here -- NOTE(review): presumably intended for loading a saved predictor.
path= MODEL_OUTPUT_BASE_PATH / "updown_classfication_explore"

# Load the stage-3 training split.
# NOTE: despite the names, `output_file1` is only read from, and
# `stock_finance` holds the contents of temp/stage3/train.pkl.
output_file1 = project_dir / "temp/stage3/train.pkl"
stock_finance = pd.read_pickle(output_file1)



In [36]:
# Inspect the frame loaded from temp/stage3/train.pkl.
stock_finance

Unnamed: 0,timestamp,item_id,KMID,KLEN,KMID2,KUP,KUP2,KLOW,KLOW2,KSFT,...,timestamp_day_of_week,timestamp_day_of_year,timestamp_month_end,timestamp_quarter_end,timestamp_month_sin,timestamp_month_cos,timestamp_day_of_week_sin,timestamp_day_of_week_cos,timestamp_day_of_month_sin,timestamp_day_of_month_cos
0,2015-07-15,2,0.041134,0.129692,0.318139,0.060836,0.755970,0.027404,0.478197,0.007623,...,2,196,0,0,-0.500000,-0.866025,8.660254e-01,-0.5,0.101168,-0.994869
1,2015-07-16,2,-0.001464,0.065917,-0.022219,0.024867,0.661925,0.039458,0.886077,0.013179,...,3,197,0,0,-0.500000,-0.866025,1.224647e-16,-1.0,-0.101168,-0.994869
2,2015-07-17,2,0.043514,0.085684,0.509674,0.020304,0.508791,0.021748,0.528478,0.045017,...,4,198,0,0,-0.500000,-0.866025,-8.660254e-01,-0.5,-0.299363,-0.954139
3,2015-07-20,2,0.053554,0.081853,0.657094,0.004231,0.229438,0.023950,0.572090,0.073406,...,0,201,0,0,-0.500000,-0.866025,0.000000e+00,1.0,-0.790776,-0.612106
4,2015-07-21,2,-0.035201,0.036546,-0.959101,0.000000,0.000000,0.001353,0.193661,-0.033816,...,1,202,0,0,-0.500000,-0.866025,8.660254e-01,0.5,-0.897805,-0.440394
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
499526,2025-04-14,688981,-0.016047,0.019789,-0.807982,0.000428,0.147580,0.003315,0.421740,-0.013153,...,0,104,0,0,0.866025,-0.500000,0.000000e+00,1.0,0.299363,-0.954139
499527,2025-04-15,688981,-0.015363,0.023098,-0.663127,0.000000,0.000000,0.007731,0.617146,-0.007625,...,1,105,0,0,0.866025,-0.500000,8.660254e-01,0.5,0.101168,-0.994869
499528,2025-04-16,688981,0.010987,0.023411,0.470517,0.000000,0.000000,0.012410,0.815933,0.023413,...,2,106,0,0,0.866025,-0.500000,8.660254e-01,-0.5,-0.101168,-0.994869
499529,2025-04-17,688981,0.005476,0.015117,0.362952,0.006351,0.705343,0.003285,0.485053,0.002410,...,3,107,0,0,0.866025,-0.500000,1.224647e-16,-1.0,-0.299363,-0.954139
