# Incremental Model Update and Daily Prediction Engine (Batch Version)

---

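### **Objective**
This notebook is the project's **daily automated run script**. It iterates over every stock defined in the configuration file and, for each stock whose models have been trained successfully, performs online learning (a hot update) and generates a new trading signal.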

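### **Workflow**
1.  **Environment setup**: import libraries and load the configuration.
2.  **Main loop (over stocks)**: for each stock in the stock pool, perform the following steps:
    a. **Load models**: load the stock's base models (LGBM, LSTM) and its fusion model (Fuser). If any model is missing, skip the stock.
    b. **Fetch incremental data**: fetch all new data since the last batch training run. The start date is read from the fuser's `trained_at` metadata and falls back to three months ago if no metadata exists.
    c. **Run incremental training**: generate the base models' historical predictions and the true labels for the new period in one batch, then update the fusion model with `partial_train`.
    d. **Generate today's prediction**: use the updated fusion model to produce a trading recommendation for the next trading day.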
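Step c depends on the fuser accepting incremental updates through `partial_train`. This notebook does not show how `ModelFuser` works internally, so the following is only a minimal sketch under an assumption: a linear fuser over the base-model predictions, updated by per-sample stochastic gradient descent on squared error. `OnlineLinearFuser`, `learning_rate`, and `epochs` are hypothetical names, not the project's API; only the column names `pred_lgbm`, `pred_lstm`, and `y_true` come from the loop in section 2.

```python
import numpy as np


class OnlineLinearFuser:
    """Hypothetical linear fuser y_hat = w . preds + b, updated online by SGD on squared error."""

    def __init__(self, feature_names, learning_rate=0.01, epochs=5):
        self.feature_names = list(feature_names)
        self.learning_rate = learning_rate
        self.epochs = epochs
        # Start from an equal-weight average of the base models
        self.w = np.full(len(self.feature_names), 1.0 / len(self.feature_names))
        self.b = 0.0

    def partial_train(self, data) -> None:
        """Update w and b in place from base predictions plus a 'y_true' column.

        `data` can be a pandas DataFrame or any mapping from column name to 1-D values.
        """
        X = np.column_stack([np.asarray(data[name], dtype=float) for name in self.feature_names])
        y = np.asarray(data['y_true'], dtype=float)
        for _ in range(self.epochs):
            for x_i, y_i in zip(X, y):
                error = float(x_i @ self.w + self.b - y_i)
                # Gradient step on 0.5 * error**2 with respect to w and b
                self.w -= self.learning_rate * error * x_i
                self.b -= self.learning_rate * error

    def predict(self, preds: dict) -> float:
        """Fuse one set of base-model predictions, e.g. {'pred_lgbm': ..., 'pred_lstm': ...}."""
        x = np.array([preds[name] for name in self.feature_names], dtype=float)
        return float(x @ self.w + self.b)


# Usage on synthetic, noise-free data where the "true" blend is 0.7 * LGBM + 0.3 * LSTM
rng = np.random.default_rng(0)
history = {'pred_lgbm': rng.normal(size=200), 'pred_lstm': rng.normal(size=200)}
history['y_true'] = 0.7 * history['pred_lgbm'] + 0.3 * history['pred_lstm']

demo_fuser = OnlineLinearFuser(['pred_lgbm', 'pred_lstm'])
demo_fuser.partial_train(history)
demo_signal = demo_fuser.predict({'pred_lgbm': 0.01, 'pred_lstm': -0.02})
```

With noise-free targets like these the weights should move toward 0.7 and 0.3. Real return labels are much noisier and daily-return predictions are small in magnitude, so a sketch like this would need its learning rate tuned before it says anything about the project's fuser.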

## 1. Environment Setup and Imports

In [3]:
import sys
import json
from pathlib import Path

import yaml
import joblib
import numpy as np
import pandas as pd
import torch
from tqdm.autonotebook import tqdm

# --- Project module imports ---
try:
    from data_process.get_data import initialize_apis, shutdown_apis, get_full_feature_df
    from model_builders.model_fuser import ModelFuser
    from model_builders.lstm_builder import LSTMModel
    print("INFO: Project modules imported successfully.")
except ImportError as e:
    print(f"WARNING: Import failed: {e}. Adding project root to sys.path...")
    # Assumes the notebook is started from the project root directory
    project_root = str(Path().resolve())
    if project_root not in sys.path:
        sys.path.append(project_root)
    from data_process.get_data import initialize_apis, shutdown_apis, get_full_feature_df
    from model_builders.model_fuser import ModelFuser
    from model_builders.lstm_builder import LSTMModel
    print("INFO: Re-imported successfully.")

# --- Load the configuration file ---
CONFIG_PATH = 'configs/config.yaml'
try:
    with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f) or {}
    print(f"SUCCESS: Config loaded from '{CONFIG_PATH}'.")
except FileNotFoundError:
    print(f"ERROR: Config file '{CONFIG_PATH}' not found.")
    config = {}

# --- Extract the core configuration blocks (empty defaults if the config is missing) ---
global_settings = config.get('global_settings', {})
stocks_to_process = config.get('stocks_to_process', [])
models_to_train = global_settings.get('models_to_train', ['lgbm', 'lstm'])

INFO: Project modules imported successfully.
SUCCESS: Config loaded from 'configs/config.yaml'.
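The main loop reads several nested keys from this config. As a minimal sketch, assuming the key layout inferred from the notebook (`global_settings.use_lstm_globally`, `default_model_params.lstm_params`, per-stock `use_lstm` and `lstm_params`, and `strategy_config.earliest_start_date`/`end_date`), the two hypothetical helpers below isolate the per-stock LSTM settings and the incremental date window. The second one uses `copy.deepcopy` because `dict.copy()` is shallow: writing into `inc_config['strategy_config']` on a shallow copy would also modify the shared `config` seen by every later stock.

```python
import copy
from datetime import date


def resolve_lstm_settings(config: dict, stock_info: dict):
    """Return (use_lstm, lstm_params) for one stock (hypothetical helper).

    A per-stock 'use_lstm' wins; if it is absent (None), fall back to
    global_settings.use_lstm_globally, defaulting to True. A per-stock
    'lstm_params' block replaces the default block as a whole, because the
    merge {**defaults, **stock_info} is shallow.
    """
    use_lstm = stock_info.get('use_lstm')
    if use_lstm is None:
        use_lstm = config.get('global_settings', {}).get('use_lstm_globally', True)
    merged = {**config.get('default_model_params', {}), **stock_info}
    return bool(use_lstm), merged.get('lstm_params', {})


def build_incremental_config(config: dict, start: date, end: date) -> dict:
    """Return a copy of config whose strategy_config covers [start, end] (hypothetical helper)."""
    inc_config = copy.deepcopy(config)  # a shallow copy would share strategy_config
    strategy_cfg = inc_config.setdefault('strategy_config', {})
    strategy_cfg['earliest_start_date'] = start.strftime('%Y-%m-%d')
    strategy_cfg['end_date'] = end.strftime('%Y-%m-%d')
    return inc_config


# Usage with an illustrative config (values are made up)
example_config = {
    'global_settings': {'use_lstm_globally': False},
    'default_model_params': {'lstm_params': {'units_1': 64, 'sequence_length': 60}},
    'strategy_config': {'earliest_start_date': '2020-01-01'},
}
stock = {'ticker': '000001.SZ', 'use_lstm': True, 'lstm_params': {'units_1': 128}}
use_lstm, lstm_params = resolve_lstm_settings(example_config, stock)
# use_lstm is True and lstm_params == {'units_1': 128}
inc = build_incremental_config(example_config, date(2024, 1, 2), date(2024, 3, 29))
# example_config['strategy_config'] still holds only earliest_start_date '2020-01-01'
```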


## 2. Batch Incremental Update and Prediction Main Loop
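Step 2.1 of the loop below pairs the newest model file with the scaler saved under the same version timestamp. The sketch below shows that lookup, assuming files named `<type>_model_<timestamp>.<ext>` and `<type>_scaler_<timestamp>.pkl` with timestamps that sort lexicographically (such as `20240329`). `find_latest_artifacts` and the accepted suffix list are hypothetical. An explicit suffix check is used because a glob such as `*.p*t` matches `.pt` but not `.pkl`.

```python
from pathlib import Path
from typing import Optional, Tuple

# Assumed extensions: joblib pickles for LGBM, torch state dicts for LSTM
MODEL_SUFFIXES = ('.pkl', '.pt', '.pth')


def find_latest_artifacts(model_dir: Path, model_type: str) -> Optional[Tuple[Path, Path]]:
    """Return (model_file, scaler_file) for the newest version, or None if either is missing."""
    model_files = sorted(
        p for p in model_dir.glob(f"{model_type}_model_*") if p.suffix in MODEL_SUFFIXES
    )
    if not model_files:
        return None
    latest_model = model_files[-1]
    # The version is the last '_'-separated token of the stem,
    # e.g. 'lgbm_model_20240329' -> '20240329'
    version = latest_model.stem.split('_')[-1]
    scaler_file = model_dir / f"{model_type}_scaler_{version}.pkl"
    return (latest_model, scaler_file) if scaler_file.exists() else None
```

Returning `None` when either file is missing lets a caller treat a model without its matching scaler the same way as a missing model, which matches how the loop skips stocks with incomplete artifacts.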

In [None]:
import copy
from IPython.display import display


def predict_base_models(X, base_models, base_scalers, seq_len):
    """Return per-row base-model predictions for feature frame X as a dict of 1-D arrays."""
    # LGBM: the artifact is a dict of quantile models; use the 0.5-quantile (median) model
    X_scaled_lgbm = base_scalers['lgbm'].transform(X)
    preds = {'pred_lgbm': np.asarray(base_models['lgbm']['q_0.5'].predict(X_scaled_lgbm), dtype=float)}
    if 'lstm' in base_models:
        # Pseudo-sequence: each row is repeated seq_len times to build the
        # (batch, seq_len, n_features) input the LSTM expects
        X_scaled_lstm = np.asarray(base_scalers['lstm'].transform(X))
        X_seq = np.repeat(X_scaled_lstm[:, np.newaxis, :], seq_len, axis=1)
        with torch.no_grad():
            out = base_models['lstm'](torch.from_numpy(X_seq).float())
        preds['pred_lstm'] = out.reshape(-1).cpu().numpy()
    return preds


if config and stocks_to_process:
    stock_iterator = tqdm(stocks_to_process, desc="Batch updating models")
    all_predictions_summary = []

    for stock_info in stock_iterator:
        ticker = stock_info.get('ticker')
        keyword = stock_info.get('keyword', ticker)
        if not ticker:
            continue
        stock_iterator.set_description(f"Processing {keyword}")

        # ======================================================================
        # 2.1 Load all models
        # ======================================================================
        base_models, base_scalers = {}, {}
        all_models_loaded = True

        # model_dir is read directly from the top-level global_settings block
        model_dir = Path(config.get('global_settings', {}).get('model_dir', 'models')) / ticker

        # Decide which base models this stock needs: LGBM always; LSTM if the stock's
        # use_lstm flag (or, when unset, global_settings.use_lstm_globally) is true
        models_to_load_for_this_stock = ['lgbm']
        use_lstm_for_this_stock = stock_info.get('use_lstm')
        if use_lstm_for_this_stock is None:
            use_lstm_for_this_stock = global_settings.get('use_lstm_globally', True)
        if use_lstm_for_this_stock:
            models_to_load_for_this_stock.append('lstm')

        for model_type in models_to_load_for_this_stock:
            # Assumes LGBM models are joblib .pkl files and LSTM weights are .pt/.pth files
            model_files = sorted(
                p for p in model_dir.glob(f"{model_type}_model_*")
                if p.suffix in ('.pkl', '.pt', '.pth')
            )
            if not model_files:
                print(f"WARNING: No {model_type.upper()} model file for {keyword} found under '{model_dir}'.")
                all_models_loaded = False
                break

            # The newest version sorts last; its scaler shares the same timestamp suffix
            latest_model_file = model_files[-1]
            version_timestamp = latest_model_file.stem.split('_')[-1]
            latest_scaler_file = model_dir / f"{model_type}_scaler_{version_timestamp}.pkl"

            try:
                if model_type == 'lgbm':
                    base_models[model_type] = joblib.load(latest_model_file)
                elif model_type == 'lstm':
                    # input_size is only known once the features are fetched, so keep
                    # the path for now and build the network later
                    base_models[model_type] = {'path': latest_model_file, 'config': stock_info}
                base_scalers[model_type] = joblib.load(latest_scaler_file)
            except Exception as e:
                print(f"ERROR: Failed to load {model_type.upper()} artifacts for {keyword}: {e}")
                all_models_loaded = False
                break

        fuser = ModelFuser(ticker, config)
        if not fuser.load():
            all_models_loaded = False

        if not all_models_loaded:
            print(f"WARNING: Model artifacts for {keyword} are incomplete; skipping the incremental update.")
            continue

        # ======================================================================
        # 2.2 Fetch incremental data
        # ======================================================================
        incremental_df = None
        try:
            initialize_apis(config)

            # Start from the fuser's last training date; fall back to three months ago
            last_train_date = None
            if fuser.meta_path.exists():
                with open(fuser.meta_path, 'r', encoding='utf-8') as f:
                    meta_info = json.load(f)
                if meta_info.get('trained_at'):
                    last_train_date = pd.to_datetime(meta_info['trained_at']).date()
            if last_train_date is None:
                last_train_date = (pd.Timestamp.now().normalize() - pd.DateOffset(months=3)).date()

            # Deep copy: a shallow config.copy() would share strategy_config and
            # leak this stock's date range into the global config
            inc_config = copy.deepcopy(config)
            strategy_cfg = inc_config.setdefault('strategy_config', {})
            strategy_cfg['earliest_start_date'] = last_train_date.strftime('%Y-%m-%d')
            strategy_cfg['end_date'] = pd.Timestamp.now().strftime('%Y-%m-%d')

            # prediction_mode=False so that the date range above is respected
            incremental_df = get_full_feature_df(ticker, inc_config, keyword=keyword, prediction_mode=False)
        finally:
            shutdown_apis()

        if incremental_df is None or incremental_df.empty:
            print(f"INFO: No new data for {keyword}; skipping the update and the prediction.")
            continue

        label_col = config.get('global_settings', {}).get('label_alpha', 'label_return')
        feature_cols = [c for c in incremental_df.columns if c != label_col and not c.startswith('future_')]

        # Build the LSTM now that input_size (the number of features) is known
        lstm_cfg = {}
        if 'lstm' in base_models:
            lstm_model_info = base_models['lstm']
            lstm_cfg = {**config.get('default_model_params', {}), **lstm_model_info['config']}.get('lstm_params', {})
            model_instance = LSTMModel(
                input_size=len(feature_cols),
                hidden_size_1=lstm_cfg.get('units_1', 64),
                hidden_size_2=lstm_cfg.get('units_2', 32),
                dropout=lstm_cfg.get('dropout', 0.2),
            )
            model_instance.load_state_dict(torch.load(lstm_model_info['path'], map_location='cpu'))
            model_instance.eval()
            base_models['lstm'] = model_instance
        seq_len = lstm_cfg.get('sequence_length', 60)

        X_incremental = incremental_df[feature_cols]
        base_preds = predict_base_models(X_incremental, base_models, base_scalers, seq_len)

        # ======================================================================
        # 2.3 Run incremental training
        # ======================================================================
        new_data_for_update = pd.DataFrame(base_preds, index=X_incremental.index)
        new_data_for_update['y_true'] = incremental_df[label_col]
        # The most recent rows have no realised label yet (NaN) and are dropped
        new_data_for_update = new_data_for_update.dropna()

        if len(incremental_df) > 1 and not new_data_for_update.empty:
            fuser.partial_train(new_data_for_update)
        else:
            print(f"INFO: {keyword} has no new data available for incremental training.")

        # ======================================================================
        # 2.4 Generate today's prediction
        # ======================================================================
        # The last row holds the latest features; reuse its base-model predictions
        preds_today = {name: float(values[-1]) for name, values in base_preds.items()}

        final_signal = fuser.predict(preds_today)
        direction = 'Bullish (BUY)' if final_signal > 0 else 'Bearish (SELL)'

        all_predictions_summary.append({'Stock': keyword, 'Ticker': ticker, 'Direction': direction, 'Signal strength': final_signal})

        # (Optional) print a per-stock report
        print(f"--- Trading recommendation for {keyword} (next trading day) ---")
        print(f"Core view: {direction}, signal strength: {final_signal:.4%}")

    # --- After the loop, show a summary of all predictions ---
    if all_predictions_summary:
        print("\n" + "=" * 80)
        print("--- Summary of today's prediction signals for all stocks ---")
        # Set the index on the DataFrame first; a pandas Styler has no set_index method
        summary_df = pd.DataFrame(all_predictions_summary).set_index(['Stock', 'Ticker'])
        display(summary_df.style.format({'Signal strength': '{:+.4%}'}))
        print("=" * 80)
    else:
        print("\nWARNING: No valid prediction was generated for any stock.")

else:
    print("ERROR: The config file or the stock pool is empty.")

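Batch updating models:   0%|          | 0/11 [00:00<?, ?it/s]

WARNING: No LGBM model file for 平安银行 found under 'models\000001.SZ'.
SUCCESS: Fusion artifacts loaded.
WARNING: Model artifacts for 平安银行 are incomplete; skipping the incremental update.
WARNING: No LGBM model file for 长城军工 found under 'models\601606.SH'.
INFO: No trained fusion model file found; a newly initialized model will be used.
WARNING: Model artifacts for 长城军工 are incomplete; skipping the incremental update.
WARNING: No LGBM model file for TCL科技 found under 'models\000100.SZ'.
SUCCESS: Fusion artifacts loaded.
WARNING: Model artifacts for TCL科技 are incomplete; skipping the incremental update.
WARNING: No LGBM model file for 兴业矿业 found under 'models\000426.SZ'.
INFO: No trained fusion model file found; a newly initialized model will be used.
WARNING: Model artifacts for 兴业矿业 are incomplete; skipping the incremental update.
WARNING: No LGBM model file for 孚日股份 found under 'models\002083.SZ'.
SUCCESS: Fusion artifacts loaded.
WARNING: Model artifacts for 孚日股份 are incomplete; skipping the incremental update.
WARNING: No LGBM model file for 宜华健康 found under 'models\000150.SZ'.
SUCCESS: Fusion artifacts loaded.
WARNING: Model artifacts for 宜华健康 are incomplete; skipping the incremental update.
WARNING: No LGBM model file for 新宁物流 found under 'models\300013.SZ'.
INFO: No trained fusion model file found; a newly initialized model will be used.
WARNING: Model artifacts for 新宁物流 are incomplete; skipping the incremental update.
WARNING: No LGBM model file for 佳云科技 found under 'models\300242.SZ'.
SUCCESS: Fusion artifacts loaded.
WARNING: Model artifacts for 佳云科技 are incomplete; skipping the incremental update.
WARNING: No LGBM model file for ST南化 found under 'models\600301.SH'.
SUCCESS: Fusion artifacts loaded.
WARNING: Model artifacts for ST南化 are incomplete; skipping the incremental update.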
WA
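In the loop above, both the incremental update and today's prediction feed the LSTM a pseudo-sequence: one feature row repeated `sequence_length` times. If the LSTM was trained on windows of consecutive trading days, a sliding window over the preceding `sequence_length` rows is closer to its training input; this needs at least `sequence_length - 1` extra rows of history before the incremental start date. The sketch below contrasts the two constructions on an already-scaled NumPy array; `make_sliding_windows` and `make_pseudo_sequences` are hypothetical helpers.

```python
import numpy as np


def make_sliding_windows(X_scaled: np.ndarray, seq_len: int) -> np.ndarray:
    """Stack the previous seq_len rows for every row that has enough history.

    Returns shape (n_rows - seq_len + 1, seq_len, n_features); window i ends at
    row i + seq_len - 1, so its prediction belongs to that row's date.
    """
    n_rows, n_features = X_scaled.shape
    if n_rows < seq_len:
        return np.empty((0, seq_len, n_features), dtype=X_scaled.dtype)
    return np.stack([X_scaled[i:i + seq_len] for i in range(n_rows - seq_len + 1)])


def make_pseudo_sequences(X_scaled: np.ndarray, seq_len: int) -> np.ndarray:
    """Repeat each row seq_len times: shape (n_rows, seq_len, n_features)."""
    return np.repeat(X_scaled[:, np.newaxis, :], seq_len, axis=1)
```

Either result can be passed to the network as `torch.from_numpy(windows).float()`, assuming a batch-first input of shape `(batch, seq_len, n_features)`, which is the shape the loop builds. With sliding windows, the first `seq_len - 1` rows get no LSTM prediction, which is why the data fetch would have to start earlier.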
