# 三阶段股票预测框架教程

本教程展示如何使用三阶段架构进行股票预测：

## 架构概览

```
Stage1: 空间特征提取器
  ├─ 输入: 多股票横截面数据
  ├─ 学习: 股票间关系、市场结构
  └─ 输出: 关系embedding (降维)

Stage2: 残差提升 (可选，本教程未实现)
  └─ 改进Stage1预测

Stage3: 时序预测器
  ├─ 输入: 目标股票时序 + 关系特征
  ├─ 模型: LSTM/GRU/TCN/TFT
  └─ 输出: 最终预测
```

## 核心优势

1. **资源节省**: 将多股票×多特征降维到低维关系特征
2. **信息丰富**: 保留市场全局信息
3. **可解释**: 可通过Attention权重可视化其他股票对目标股票的影响程度

## 1. 环境准备

In [None]:
import os
import sys
# Make the parent directory importable so the later `src.*` / `models.*` imports resolve
sys.path.append('..')

import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Fix random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Plot style. The 'seaborn-v0_8-*' style names only exist in matplotlib >= 3.6;
# on older versions plt.style.use would raise OSError, so fall back to seaborn's
# own darkgrid style.
if 'seaborn-v0_8-darkgrid' in plt.style.available:
    plt.style.use('seaborn-v0_8-darkgrid')
else:
    sns.set_style('darkgrid')
sns.set_palette("husl")

# Later cells call plt.savefig('../visualizations/...'), which fails with
# FileNotFoundError if the directory does not exist yet.
os.makedirs('../visualizations', exist_ok=True)

# Use the GPU when one is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"使用设备: {device}")

## 2. 数据加载

使用`stock_data_fetcher_akshare.ipynb`生成的数据（`../data/data.csv`）。列名格式假定为`股票代码_特征名`，指数列以`idx_`开头。

In [None]:
# Load the dataset produced by the data-fetching notebook
df = pd.read_csv('../data/data.csv')

print(f"数据形状: {df.shape}")
print("\n列名示例:")
print(df.columns[:20].tolist())

# Assumed column naming: <stock_code>_<feature>; index columns are idx_<code>_<feature>
all_columns = df.columns.tolist()

# Stock codes: prefix before the first '_' of every non-index column that has one.
# dict.fromkeys de-duplicates while keeping first-seen order.
stock_codes = list(dict.fromkeys(
    name.split('_')[0]
    for name in all_columns
    if '_' in name and not name.startswith('idx_')
))

# Index codes: 'idx_' plus the second '_'-separated token of each index column
index_codes = list(dict.fromkeys(
    f"idx_{name.split('_')[1]}"
    for name in all_columns
    if name.startswith('idx_')
))

print(f"\n股票数量: {len(stock_codes)}")
print(f"指数数量: {len(index_codes)}")
print(f"\n股票代码: {stock_codes[:5]}...")
print(f"指数代码: {index_codes}")

## 3. 配置Pipeline

In [None]:
from src.three_stage_pipeline import ThreeStagePipeline

# Prediction target: the first stock code parsed from the columns
target_stock = stock_codes[0]

# Per-stock technical-indicator feature names (adjust to match the real data)
feature_columns = [
    'close', 'open', 'high', 'low', 'volume',
    'MA5', 'MA10', 'MA20', 'MA60',
    'RSI', 'MACD', 'MACD_signal',
    'BB_upper', 'BB_middle', 'BB_lower',
    # append further feature names here to match the data
]

# Pipeline configuration, gathered in one place and unpacked into the constructor
pipeline_kwargs = dict(
    stock_codes=stock_codes[:10],  # start with the first 10 stocks only
    index_codes=index_codes,
    target_stock=target_stock,
    feature_columns=feature_columns,
    relationship_dim=32,  # number of relationship_{i} features produced later
    seq_len=60,  # presumably the input window length — TODO confirm in pipeline
    device=device,
)
pipeline = ThreeStagePipeline(**pipeline_kwargs)

print("✓ Pipeline配置完成")

## 4. Stage1: 训练空间特征提取器

In [None]:
# Stage1: cross-stock spatial (transformer-style) feature extractor
pipeline.build_stage1(d_model=128, nhead=8, num_layers=3)

# Sequential 80% / 10% / 10% split by row position (no shuffling);
# assumes rows are already in time order — TODO confirm upstream
n_rows = len(df)
train_size = int(n_rows * 0.8)
val_size = int(n_rows * 0.1)
val_end = train_size + val_size

train_df = df.iloc[:train_size]
val_df = df.iloc[train_size:val_end]
test_df = df.iloc[val_end:]

for split_name, split_df in (("训练集", train_df), ("验证集", val_df), ("测试集", test_df)):
    print(f"{split_name}大小: {len(split_df)}")

In [None]:
# Train Stage1 on the training split, with the validation split for monitoring
stage1_train_config = {
    'num_epochs': 50,
    'batch_size': 64,
    'lr': 1e-4,
}
pipeline.train_stage1(train_data=train_df, val_data=val_df, **stage1_train_config)

## 5. 提取关系特征

In [None]:
# Build the relationship-feature extractor on top of the trained Stage1 model
pipeline.build_relationship_extractor(extractor_type='hybrid')

# Extract relationship features for every row; save_path presumably persists
# the augmented frame to CSV — TODO confirm in the pipeline implementation
df_with_relationships = pipeline.extract_relationship_features(
    df,
    save_path='../data/data_with_relationships.csv',
)

new_relationship_cols = [name for name in df_with_relationships.columns if 'relationship' in name]
print(f"\n添加关系特征后的数据形状: {df_with_relationships.shape}")
print(f"新增列: {new_relationship_cols[:5]}...")

## 6. 可视化关系特征

查看提取的关系特征随时间的变化，以及各关系特征之间的相关性

In [None]:
# Plot up to the first 8 relationship features over the row index on a 4x2 grid
fig, axes = plt.subplots(4, 2, figsize=(15, 12))
axes = axes.flatten()

n_plotted = min(len(axes), pipeline.relationship_dim)
for i, ax in enumerate(axes):
    if i >= n_plotted:
        # relationship_dim < 8: hide the spare panels instead of leaving empty frames
        ax.set_visible(False)
        continue
    ax.plot(df_with_relationships[f'relationship_{i}'])
    ax.set_title(f'Relationship Feature {i}')
    ax.set_xlabel('Time')
    ax.set_ylabel('Value')

plt.tight_layout()
plt.savefig('../visualizations/relationship_features_timeseries.png', dpi=300)
plt.show()

In [None]:
# Pairwise correlation between all relationship features
relationship_cols = [f'relationship_{k}' for k in range(pipeline.relationship_dim)]
corr_matrix = df_with_relationships.loc[:, relationship_cols].corr()

heatmap_style = dict(
    cmap='coolwarm',
    center=0,
    square=True,
    linewidths=0.5,
    cbar_kws={"shrink": 0.8},
)

plt.figure(figsize=(12, 10))
sns.heatmap(corr_matrix, **heatmap_style)
plt.title('Relationship Features Correlation Matrix')
plt.savefig('../visualizations/relationship_features_correlation.png', dpi=300)
plt.show()

## 7. Attention可视化

查看目标股票对其他股票的注意力分布

In [None]:
from models.relationship_extractors import visualize_attention_relationships

# Use the first 64 training rows as a sample batch for attention extraction
sample_data = train_df.iloc[:64]
X_sample, _ = pipeline._prepare_stage1_data(sample_data)
X_sample = X_sample.to(device)

# Switch to inference mode first so dropout (if any) does not perturb the
# attention weights being visualized
pipeline.stage1_model.eval()
with torch.no_grad():
    attention_weights = pipeline.stage1_model.get_attention_weights(X_sample)

# Average over the first two dims — per the original note these are the layer
# and batch dims (order not verifiable here, TODO confirm); leaves a node x node map
avg_attention = attention_weights.mean(dim=[0, 1]).cpu()

# Node names in the order the pipeline was configured with (stocks, then indices);
# assumes the model keeps that node order — TODO confirm
stock_names = stock_codes[:10] + index_codes
# Look the target up by name instead of assuming it sits at position 0
target_idx = stock_names.index(target_stock)

visualize_attention_relationships(
    avg_attention,
    stock_names,
    target_idx,
    save_path='../visualizations/attention_weights.png'
)

## 8. Stage3: 训练时序预测器

In [None]:
# Stage3: temporal predictor, LSTM variant
pipeline.build_stage3(
    model_type='lstm',
    hidden_dim=128,
    num_layers=2
)

# Keep the held-out test rows out of Stage3 training. df_with_relationships
# covers every row of df, so with train_ratio=0.8 the remaining 20% handed to
# train_stage3 (presumably its validation / early-stopping split — TODO
# confirm) would include all of test_df.
stage3_df = df_with_relationships[~df_with_relationships.index.isin(test_df.index)]
# Rescale so the training portion still spans about the same rows as train_df,
# assuming train_ratio is a sequential row fraction — TODO confirm
stage3_train_ratio = len(train_df) / len(stage3_df)

pipeline.train_stage3(
    df_with_relationships=stage3_df,
    target_column=f'{target_stock}_target_return_1d',
    train_ratio=stage3_train_ratio,
    num_epochs=100,
    batch_size=64,
    lr=1e-4
)

## 9. 模型评估

In [None]:
# Predict on the held-out test split (test_features is returned but not used here)
predictions, test_features = pipeline.predict(
    test_df,
    return_features=True
)
# Flatten to 1-D in case the model returns an (n, 1) array
predictions = np.ravel(np.asarray(predictions, dtype=float))

# predict() may return fewer rows than test_df, so align the ground truth to the
# tail of the split. An explicit start index avoids values[-0:] returning the
# whole column when there are no predictions.
target_values = test_df[f'{target_stock}_target_return_1d'].values.astype(float)
actual = target_values[len(target_values) - len(predictions):]

# Drop rows with a missing target or prediction (e.g. a forward return that is
# unavailable for the last rows); sklearn metrics raise on NaN input
valid_mask = np.isfinite(actual) & np.isfinite(predictions)
actual = actual[valid_mask]
predictions = predictions[valid_mask]

# Regression metrics
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

mse = mean_squared_error(actual, predictions)
mae = mean_absolute_error(actual, predictions)
r2 = r2_score(actual, predictions)

print("测试集性能:")
print(f"MSE: {mse:.6f}")
print(f"MAE: {mae:.6f}")
print(f"R²: {r2:.6f}")

In [None]:
# Two stacked panels: return series on top, scatter against y = x below
fig, axes = plt.subplots(2, 1, figsize=(15, 10))
ax_line, ax_scatter = axes

# Top: predicted vs actual returns over time
ax_line.plot(actual, label='Actual', alpha=0.7)
ax_line.plot(predictions, label='Predicted', alpha=0.7)
ax_line.set_title('Predictions vs Actual Returns')
ax_line.set_xlabel('Time')
ax_line.set_ylabel('Return')
ax_line.legend()
ax_line.grid(True)

# Bottom: scatter with the perfect-prediction diagonal over the actual range
ax_scatter.scatter(actual, predictions, alpha=0.5)
lo, hi = actual.min(), actual.max()
ax_scatter.plot([lo, hi], [lo, hi], 'r--', lw=2, label='Perfect Prediction')
ax_scatter.set_xlabel('Actual Return')
ax_scatter.set_ylabel('Predicted Return')
ax_scatter.set_title(f'Prediction Scatter Plot (R²={r2:.4f})')
ax_scatter.legend()
ax_scatter.grid(True)

plt.tight_layout()
plt.savefig('../visualizations/stage3_predictions.png', dpi=300)
plt.show()

## 10. 保存Pipeline

In [None]:
# Persist the whole pipeline under a single directory
pipeline_dir = '../saved_models/three_stage_pipeline'
pipeline.save_pipeline(pipeline_dir)

print("✓ Pipeline已保存")

## 11. 加载和使用已保存的Pipeline

In [None]:
# Create an empty shell pipeline; load_pipeline is expected to restore the
# saved configuration and weights
new_pipeline = ThreeStagePipeline(
    stock_codes=[],  # restored from the saved configuration
    index_codes=[],
    target_stock='',
    feature_columns=[],
    # Pass the notebook's device like every other construction does, instead
    # of silently relying on the constructor's default device
    device=device
)

new_pipeline.load_pipeline('../saved_models/three_stage_pipeline')

# Predict with the reloaded model
new_predictions = new_pipeline.predict(test_df)

print(f"预测结果: {new_predictions[:5]}")

## 12. 对比不同方案

对比三种Stage3模型: LSTM vs GRU vs TCN（复用已训练的Stage1模型和关系特征提取器，仅替换时序模型）

In [None]:
# Compare three Stage3 backbones while reusing the Stage1 model and relationship
# extractor trained above, so only the temporal model differs between runs.
target_column = f'{target_stock}_target_return_1d'

# Keep the held-out test rows out of Stage3 training; rescale train_ratio so the
# training portion spans about the same rows as train_df (assumes train_ratio
# is a sequential row fraction — TODO confirm)
stage3_df = df_with_relationships[~df_with_relationships.index.isin(test_df.index)]
stage3_train_ratio = len(train_df) / len(stage3_df)

results = {}
comparison_actual = None

for model_type in ['lstm', 'gru', 'tcn']:
    print(f"\n训练 {model_type.upper()}...")

    # Fresh pipeline with the same configuration as the main one
    test_pipeline = ThreeStagePipeline(
        stock_codes=stock_codes[:10],
        index_codes=index_codes,
        target_stock=target_stock,
        feature_columns=feature_columns,
        relationship_dim=32,
        seq_len=60,
        device=device
    )

    # Reuse the already-trained Stage1 objects (shared references, not copies)
    test_pipeline.stage1_model = pipeline.stage1_model
    test_pipeline.relationship_extractor = pipeline.relationship_extractor

    # Train only the Stage3 model
    test_pipeline.build_stage3(model_type=model_type)
    test_pipeline.train_stage3(
        stage3_df,
        target_column=target_column,
        train_ratio=stage3_train_ratio,
        num_epochs=50
    )

    # Evaluate: align ground truth to this model's own predictions (explicit
    # start index, so an empty result does not select the whole column) and
    # drop rows with NaN in either array, which sklearn metrics reject
    preds = np.ravel(np.asarray(test_pipeline.predict(test_df), dtype=float))
    target_values = test_df[target_column].values.astype(float)
    y_true = target_values[len(target_values) - len(preds):]
    valid = np.isfinite(y_true) & np.isfinite(preds)
    y_true, preds = y_true[valid], preds[valid]

    # Local name on purpose: the global `r2` is used by the earlier scatter-plot title
    model_r2 = r2_score(y_true, preds)

    results[model_type] = {
        'r2': model_r2,
        'predictions': preds
    }
    comparison_actual = y_true

# Overlay every model's predictions on the actual returns
fig, ax = plt.subplots(figsize=(10, 6))

ax.plot(comparison_actual, label='Actual', linewidth=2, alpha=0.7)
for model_type, result in results.items():
    ax.plot(result['predictions'],
            label=f"{model_type.upper()} (R²={result['r2']:.4f})",
            alpha=0.7)

ax.set_title('Model Comparison: LSTM vs GRU vs TCN')
ax.set_xlabel('Time')
ax.set_ylabel('Return')
ax.legend()
ax.grid(True)

plt.tight_layout()
plt.savefig('../visualizations/model_comparison.png', dpi=300)
plt.show()

## 总结

本教程展示了完整的三阶段股票预测流程:

1. **Stage1**: 学习股票间的空间关系
2. **关系特征提取**: 将高维市场信息压缩到低维向量
3. **Stage3**: 结合关系特征进行时序预测

### 下一步

- 尝试不同的关系特征维度
- 调整序列长度
- 添加更多技术指标
- 实现Stage2残差提升
- 尝试TFT模型