# 股票预测模型工作流

--- 

本 Notebook 包含一个完整的机器学习工作流，分为两个主要阶段：

### **阶段一：数据准备与特征工程**
**目标**: 获取所有需要的原始数据，进行完整的特征计算，并将最终可用于建模的干净数据集保存到磁盘。
**运行频率**: 只有在股票池、数据周期或特征工程逻辑发生改变时才需要运行此阶段。

### **阶段二：模型训练与评估**
**目标**: 从磁盘加载已处理好的数据，进行超参数优化（可选）、模型训练、评估和结果可视化。
**运行频率**: 这是进行模型实验的主要区域，可以反复运行以测试不同的模型、参数或评估方法。

---

## 通用设置与导入

In [1]:
import sys
import json
import yaml
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm.autonotebook import tqdm
import hashlib

# --- 设置 Matplotlib 样式 ---
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
print("INFO: Matplotlib and Seaborn styles configured.")

# --- 健壮的模块导入逻辑 ---
try:
    from model_builders.build_models import run_training_for_ticker
    from data_process.save_data import run_data_pipeline
    from data_process.get_data import _initialize_apis, bs
    from data_process.save_data import run_data_pipeline
    print("INFO: Project modules imported successfully.")
except ImportError:
    print("WARNNING: Standard import failed. Adding project root to sys.path.")
    project_root = str(Path().resolve())
    if project_root not in sys.path:
        sys.path.append(project_root)
    from model_builders.build_models import run_training_for_ticker
    from data_process.save_data import run_data_pipeline
    print("INFO: Project modules imported successfully after path adjustment.")

# --- 加载配置文件 ---
CONFIG_PATH = 'configs/config.yaml'
try:
    with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    print(f"SUCCESS: Unified configuration file loaded from '{CONFIG_PATH}'.")
except FileNotFoundError:
    print(f"ERROR: Configuration file not found at '{CONFIG_PATH}'. Please check the path.")
    config = {}

# --- 提取核心配置块 ---
if config:
    run_data_pipeline(config_path=CONFIG_PATH)
    global_settings = config.get('global_settings', {})
    strategy_config = config.get('strategy_config', {})
    hpo_config = config.get('hpo_config', {})
    default_model_params = config.get('default_model_params', {})
    stocks_to_process = config.get('stocks_to_process', [])

# --- 核心修正：使用 try...finally 来确保 API 被正确关闭 ---
try:
    if config:
        # 1. 在开始前，初始化 API
        print("--- Initializing Data APIs (Baostock, Tushare) ---")
        _initialize_apis(config)
        
        # 2. 运行数据处理流水线
        run_data_pipeline(config_path=CONFIG_PATH)
        
        print("\n--- Stage 1 Finished: All data has been processed and saved. ---")
    else:
        print("ERROR: Config is empty. Cannot start data preparation.")

finally:
    # 3. 无论阶段一成功还是失败，最后都尝试登出 API
    print("\n" + "="*80)
    try:
        bs.logout()
        print("INFO: Baostock API has been logged out successfully from Stage 1.")
    except Exception as e:
        print(f"WARNNING: Failed to logout Baostock API from Stage 1: {e}")
    print("="*80)
# --- 修正结束 ---

  from tqdm.autonotebook import tqdm


INFO: Matplotlib and Seaborn styles configured.
INFO: Project modules imported successfully.
SUCCESS: Unified configuration file loaded from 'configs/config.yaml'.
开始执行数据管道协调任务...
将使用配置文件: configs/config.yaml

--- Starting Batch Feature Generation Process ---
Using config file: configs/config.yaml
login success!
INFO: Baostock API 登录成功。SDK版本: 00.8.90
INFO: 未在配置中提供有效的 Tushare Token。将跳过宏观数据获取。

--- Generating features for 贵州茅台 (600519.SH) ---
  - [1/7] 正在从本地缓存加载 sh.600519 的原始日线数据...
INFO: No macroeconomic data to merge. Proceeding without it.
INFO: Starting feature calculation pipeline...
  - [Calculating Features] Running: Technical Indicators...
    - Calculated: ema with params {'length': 10}
    - Calculated: ema with params {'length': 30}
    - Calculated: rsi with params {'length': 14}
    - Calculated: macd with params {'fast': 12, 'slow': 26, 'signal': 9}
    - Calculated: bbands with params {'length': 20, 'std': 2}
  - [Calculating Features] Running: Calendar Features...
  - [Ca

top-level pandera module will be **removed in a future version of pandera**.
If you're using pandera to validate pandas objects, we highly recommend updating
your import:

```
# old import
import pandera as pa

# new import
import pandera.pandas as pa
```

If you're using pandera to validate objects from other compatible libraries
like pyspark or polars, see the supported libraries section of the documentation
for more information on how to import pandera:

https://pandera.readthedocs.io/en/stable/supported_libraries.html


```
```



WARNNING: Found 5 large gaps in time series index. Max gap: 17 days 00:00:00
SUCCESS: Data schema and time index validation passed.
--- SUCCESS: Features generated for 平安银行. Shape: (4952, 23) ---

--- Generating features for 寒武纪-U (688256.SH) ---
  - [1/7] 正在从本地缓存加载 sh.688256 的原始日线数据...
INFO: No macroeconomic data to merge. Proceeding without it.
INFO: Starting feature calculation pipeline...
  - [Calculating Features] Running: Technical Indicators...
    - Calculated: ema with params {'length': 10}
    - Calculated: ema with params {'length': 30}
    - Calculated: rsi with params {'length': 14}
    - Calculated: macd with params {'fast': 12, 'slow': 26, 'signal': 9}
    - Calculated: bbands with params {'length': 20, 'std': 2}
  - [Calculating Features] Running: Calendar Features...
  - [Calculating Features] INFO: No candlestick patterns specified in config. Skipping Candlestick Patterns.
INFO: Feature calculation pipeline finished.
  - [1/7] 正在从本地缓存加载 sh.000300 的原始日线数据...
  - [1/7] 

# **阶段一：数据准备与特征工程**

运行下面的单元格将为 `config.yaml` 中定义的所有股票生成特征数据，并将其保存到 `data/processed` 目录（可在配置中修改）。

In [2]:
# --- 这是一个辅助函数，用于根据配置定位保存好的数据文件 ---
def get_processed_data_path(stock_info: dict, config: dict) -> Path:
    relevant_config = {**config.get('global_settings', {}), **config.get('strategy_config', {}), **stock_info}
    relevant_config.pop('keyword', None)
    relevant_config.pop('ticker', None)
    relevant_config.pop('tushare_api_token', None)
    config_string = json.dumps(relevant_config, sort_keys=True)
    config_hash = hashlib.sha256(config_string.encode('utf-8')).hexdigest()[:12]

    output_dir_base = config.get('global_settings', {}).get('output_dir', 'data/processed')
    start_date = config.get('strategy_config', {}).get('start_date')
    end_date = config.get('strategy_config', {}).get('end_date')
    date_range_str = f"{start_date}_to_{end_date}"

    target_dir = Path(output_dir_base) / stock_info['ticker'] / date_range_str
    return target_dir / f"features_{config_hash}.pkl"

# --- 运行数据准备流水线 ---
print("--- Starting Stage 1: Data Preparation and Feature Engineering ---")
if config:
    # 调用 data_process/save_data.py 中的主函数来协调整个流程
    # 注意：save_data.py 内部会调用 get_data.py
    # 我们需要一个协调器来处理保存逻辑，而不是直接调用 get_data
    from data_process.save_data import run_data_pipeline
    run_data_pipeline(config_path=CONFIG_PATH)
    print("\n--- Stage 1 Finished: All data has been processed and saved. ---")
else:
    print("ERROR: Config is empty. Cannot start data preparation.")


--- Starting Stage 1: Data Preparation and Feature Engineering ---
开始执行数据管道协调任务...
将使用配置文件: configs/config.yaml

--- Starting Batch Feature Generation Process ---
Using config file: configs/config.yaml
INFO: 未在配置中提供有效的 Tushare Token。将跳过宏观数据获取。

--- Generating features for 贵州茅台 (600519.SH) ---
  - [1/7] 正在从本地缓存加载 sh.600519 的原始日线数据...
INFO: No macroeconomic data to merge. Proceeding without it.
INFO: Starting feature calculation pipeline...
  - [Calculating Features] Running: Technical Indicators...
    - Calculated: ema with params {'length': 10}
    - Calculated: ema with params {'length': 30}
    - Calculated: rsi with params {'length': 14}
    - Calculated: macd with params {'fast': 12, 'slow': 26, 'signal': 9}
    - Calculated: bbands with params {'length': 20, 'std': 2}
  - [Calculating Features] Running: Calendar Features...
  - [Calculating Features] INFO: No candlestick patterns specified in config. Skipping Candlestick Patterns.
INFO: Feature calculation pipeline finished.
  - 

# **阶段二：模型训练与评估**

在运行此阶段之前，**请确保阶段一已成功运行**。

此阶段将**从本地加载数据**，然后执行模型训练和评估。

### 2.1 (可选) 超参数优化

In [3]:
# --- 定义配置文件路径 ---
CONFIG_PATH = 'configs/config.yaml'

# --- 加载配置文件 ---
try:
    with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    print(f"SUCCESS: Unified configuration file loaded from '{CONFIG_PATH}'.")
except FileNotFoundError:
    print(f"ERROR: Configuration file not found at '{CONFIG_PATH}'. Please check the path.")
    config = {}

# --- 提取核心配置块 ---
# 这样做可以让后续代码更清晰
if config:
    global_settings = config.get('global_settings', {})
    strategy_config = config.get('strategy_config', {})
    hpo_config = config.get('hpo_config', {})
    default_model_params = config.get('default_model_params', {})
    stocks_to_process = config.get('stocks_to_process', [])

SUCCESS: Unified configuration file loaded from 'configs/config.yaml'.


### 2.2 模型训练

In [4]:
# 是否从头训练
FORCE_RETRAIN = True
all_ic_history = []

if config and stocks_to_process:
    models_to_train = global_settings.get('models_to_train', ['lgbm', 'lstm'])
    stock_iterator = tqdm(stocks_to_process, desc="Processing Stocks")

    for stock_info in stock_iterator:
        ticker = stock_info.get('ticker')
        keyword = stock_info.get('keyword', ticker)
        stock_iterator.set_description(f"Processing {keyword}")

        if not ticker:
            continue
        
        # --- 核心修正：从本地加载数据，而不是重新生成 ---
        data_path = get_processed_data_path(stock_info, config)
        if not data_path.exists():
            print(f"\nERROR: Processed data for {keyword} not found at {data_path}. Please run Stage 1 first. Skipping.")
            continue
        
        try:
            df = pd.read_pickle(data_path)
            print(f"\nINFO: Successfully loaded processed data for {keyword} from {data_path}. Shape: {df.shape}")
        except Exception as e:
            print(f"\nERROR: Failed to load data for {keyword} from {data_path}: {e}. Skipping.")
            continue

        for model_type in models_to_train:
            # ... (这部分与之前相同，只是数据源变了) ...
            # run_training_for_ticker 函数需要稍微修改，使其接受 DataFrame 而不是重新调用 get_data
            # 为了保持模块化，我们暂时不修改 run_training_for_ticker，而是让它内部的 get_data 从缓存加载
            # 一个更优的重构是让 run_training_for_ticker 直接接收 df，但目前的缓存机制已经实现了高效加载
            # 因此，我们保持调用不变
            base_config = {**global_settings, **strategy_config, **default_model_params, **stock_info}
            final_run_config = {'global_settings': base_config, 'stocks_to_process': [stock_info]}

            ic_history = run_training_for_ticker(
                ticker=ticker,
                model_type=model_type,
                config=final_run_config, 
                force_retrain=FORCE_RETRAIN,
                keyword=keyword
            )
            
            if ic_history is not None and not ic_history.empty:
                all_ic_history.append(ic_history)
else:
    print("ERROR: Config is empty or 'stocks_to_process' list is missing/empty. Cannot start training.")

Processing TCL科技: 100%|██████████| 7/7 [00:00<00:00, 1585.66it/s]


ERROR: Processed data for 贵州茅台 not found at data\processed\600519.SH\2000-01-01_to_2025-08-31\features_29c486410003.pkl. Please run Stage 1 first. Skipping.

ERROR: Processed data for 平安银行 not found at data\processed\000001.SZ\2000-01-01_to_2025-08-31\features_a7a139023ee9.pkl. Please run Stage 1 first. Skipping.

ERROR: Processed data for 寒武纪-U not found at data\processed\688256.SH\2000-01-01_to_2025-08-31\features_a9103f5aa238.pkl. Please run Stage 1 first. Skipping.

ERROR: Processed data for 长城军工 not found at data\processed\601606.SH\2000-01-01_to_2025-08-31\features_1539dfeccccd.pkl. Please run Stage 1 first. Skipping.

ERROR: Processed data for 视觉中国 not found at data\processed\000681.SZ\2000-01-01_to_2025-08-31\features_97f71e6bbd04.pkl. Please run Stage 1 first. Skipping.

ERROR: Processed data for 长白山 not found at data\processed\603099.SH\2000-01-01_to_2025-08-31\features_e4093a246eef.pkl. Please run Stage 1 first. Skipping.

ERROR: Processed data for TCL科技 not found at data\p




### 2.3 结果聚合、评估与可视化

这部分与之前完全相同，用于分析训练结果。

#### 2.3.1 聚合并保存IC历史

训练完成后，所有模型在各自验证集上的表现（IC）被收集起来。我们将它们合并成一个大的 DataFrame，并保存到磁盘，以便进行后续的统一分析和作为模型融合策略的输入。

In [5]:
if all_ic_history:
    full_ic_df = pd.concat(all_ic_history)
    
    # 定义保存路径
    output_dir = Path(config.get('global_settings', {}).get('model_dir', 'models'))
    output_dir.mkdir(parents=True, exist_ok=True)
    ic_output_path = output_dir / 'full_ic_history.csv'
    
    # 保存到 CSV
    full_ic_df.to_csv(ic_output_path)
    
    print(f"SUCCESS: Aggregated IC history saved to '{ic_output_path}'.")
    print("\n--- Aggregated IC History (first 5 rows) ---")
    display(full_ic_df.head())
else:
    print("WARNNING: No IC history was generated during training. Skipping aggregation and evaluation.")

WARNNING: No IC history was generated during training. Skipping aggregation and evaluation.


#### 2.3.2 评估模型

我们使用聚合后的 IC 历史数据来计算每个模型/股票组合的性能指标：

- **IC Mean (IC均值)**: 预测方向准确性的平均度量。越高越好。
- **IC Std (IC标准差)**: 表现的稳定性。越低越好。
- **ICIR (信息比率)**: `IC Mean / IC Std`，综合了准确性和稳定性，是衡量策略质量的核心指标。通常认为大于 0.5 即为不错的策略。

In [6]:
if 'full_ic_df' in locals() and not full_ic_df.empty:
    evaluation_summary = full_ic_df.groupby(['ticker', 'model_type'])['rank_ic'].agg(['mean', 'std'])
    evaluation_summary['icir'] = evaluation_summary['mean'] / evaluation_summary['std']
    
    print("--- Model Performance Evaluation Summary ---")
    display(evaluation_summary)
else:
    print("INFO: No data available for evaluation.")

INFO: No data available for evaluation.


#### 2.3.3 评估结果可视化

##### 2.3.3.1 评估指标汇总表

In [7]:
if 'evaluation_summary' in locals():
    display(evaluation_summary.style.format({
        'mean': '{:.4f}',
        'std': '{:.4f}',
        'icir': '{:.4f}'
    }).background_gradient(cmap='viridis', subset=['icir']))
else:
    print("INFO: No evaluation summary to display.")

INFO: No evaluation summary to display.


##### 2.3.3.2 7.2 ICIR 对比图

该图直观地比较了不同模型在各个股票上的 ICIR 表现。

In [8]:
if 'evaluation_summary' in locals():
    plt.figure(figsize=(12, 6))
    sns.barplot(data=evaluation_summary.reset_index(), x='ticker', y='icir', hue='model_type')
    plt.title('模型信息比率 (ICIR) 对比', fontsize=16)
    plt.xlabel('股票代码', fontsize=12)
    plt.ylabel('ICIR (信息比率)', fontsize=12)
    plt.axhline(0, color='grey', linestyle='--')
    plt.axhline(0.5, color='red', linestyle='--', label='ICIR=0.5 (良好)')
    plt.legend()
    plt.show()
else:
    print("INFO: No data to plot ICIR comparison.")

INFO: No data to plot ICIR comparison.


##### 2.3.3.3 累积 IC 曲线图

该图展示了模型 IC 随时间的累积变化。一个稳定有效的模型应该呈现出一条持续向上倾斜的曲线。

In [9]:
if 'full_ic_df' in locals() and not full_ic_df.empty:
    # 创建用于绘图的副本
    plot_df = full_ic_df.copy()
    plot_df['date'] = pd.to_datetime(plot_df['date'])
    plot_df.set_index('date', inplace=True)
    
    # 计算累积IC
    plot_df['cumulative_ic'] = plot_df.groupby(['ticker', 'model_type'])['rank_ic'].cumsum()
    plot_df.reset_index(inplace=True)
    
    plt.figure(figsize=(14, 8))
    sns.lineplot(data=plot_df, x='date', y='cumulative_ic', hue='ticker', style='model_type')
    
    plt.title('模型累积 Rank IC 曲线', fontsize=16)
    plt.xlabel('日期', fontsize=12)
    plt.ylabel('累积 Rank IC', fontsize=12)
    plt.legend(title='股票/模型')
    plt.show()
else:
    print("INFO: No data to plot cumulative IC.")

INFO: No data to plot cumulative IC.
