# 股票价格预测数据探索

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import yfinance as yf

# Plot / display configuration
# matplotlib 3.6 renamed the bundled 'seaborn' style to 'seaborn-v0_8' and 3.8
# removed the old name (plt.style.use('seaborn') raises OSError there), so use
# whichever of the two names this matplotlib installation provides.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')
sns.set(font_scale=1.2)
pd.set_option('display.max_columns', None)

# The plot titles, axis labels and legends in this notebook are Chinese, and
# matplotlib's default fonts have no CJK glyphs (text renders as empty boxes).
# Put any installed CJK font first in the sans-serif list. This must run after
# sns.set(), which resets font.sans-serif.
from matplotlib import font_manager

_installed_fonts = {f.name for f in font_manager.fontManager.ttflist}
_cjk_fonts = [
    name
    for name in ('SimHei', 'Microsoft YaHei', 'PingFang SC', 'Heiti TC',
                 'Noto Sans CJK SC', 'Source Han Sans SC', 'WenQuanYi Micro Hei')
    if name in _installed_fonts
]
plt.rcParams['font.sans-serif'] = _cjk_fonts + list(plt.rcParams['font.sans-serif'])
# Many CJK fonts lack the Unicode minus glyph; render minus signs as ASCII '-'.
plt.rcParams['axes.unicode_minus'] = False

## 数据获取

In [None]:
def fetch_stock_data(tickers, start_date, end_date):
    """Download price history for each ticker with ``yf.download``.

    Args:
        tickers: Iterable of ticker symbols, e.g. ``['AAPL', 'MSFT']``.
        start_date: Passed to ``yf.download`` as ``start``.
        end_date: Passed to ``yf.download`` as ``end``.

    Returns:
        dict mapping each ticker to the DataFrame returned by ``yf.download``,
        with single-level column names (so ``data['Close']`` is a Series).
    """
    stock_data = {}
    for ticker in tickers:
        data = yf.download(ticker, start=start_date, end=end_date)
        # Recent yfinance releases return (Price, Ticker) MultiIndex columns
        # even for a single symbol, which turns data['Close'] into a one-column
        # DataFrame and breaks the Series-based feature code downstream.
        # Keep only the price-field level; older flat frames are left as-is.
        if isinstance(data.columns, pd.MultiIndex):
            data.columns = data.columns.get_level_values(0)
        stock_data[ticker] = data
    return stock_data

# Symbols to analyse and the requested download window.
tickers = ['AAPL', 'GOOGL', 'MSFT', 'AMZN']
start_date, end_date = '2018-01-01', '2023-06-30'
# NOTE(review): yfinance documents `end` as exclusive, so 2023-06-30 itself is
# probably not downloaded — confirm whether that day was meant to be included.

stock_datasets = fetch_stock_data(tickers, start_date, end_date)

## 数据预处理

In [None]:
def preprocess_data(stock_datasets):
    """Add technical-indicator features and a next-day target for every ticker.

    Args:
        stock_datasets: dict mapping ticker -> DataFrame with a ``'Close'``
            column. The input DataFrames are not modified.

    Returns:
        dict mapping ticker -> a new DataFrame with the extra columns
        ``MA_20``, ``MA_50`` (rolling means of Close), ``RSI`` (14-period),
        ``Price_Change`` (fractional change of Close) and ``Target`` (the next
        row's Close). Every row that contains a NaN in any column is dropped.
    """
    processed_data = {}
    for ticker, raw in stock_datasets.items():
        # Work on a copy: adding columns directly to the caller's DataFrame
        # would silently mutate stock_datasets as a side effect.
        data = raw.copy()
        close = data['Close']

        # Technical indicators
        data['MA_20'] = close.rolling(window=20).mean()
        data['MA_50'] = close.rolling(window=50).mean()
        data['RSI'] = compute_rsi(close, 14)

        # Period-over-period return, and the next row's close as the target
        data['Price_Change'] = close.pct_change()
        data['Target'] = close.shift(-1)

        # Removes the leading rows without a full 50-row MA window and the last
        # row, whose Target is NaN after shift(-1).
        processed_data[ticker] = data.dropna()
    return processed_data

def compute_rsi(price_series, periods=14):
    """Compute the Relative Strength Index (RSI) of a price series.

    RSI = 100 - 100 / (1 + RS), where RS is the mean upward move divided by
    the mean downward move over a trailing window of ``periods`` observations.
    The means are simple rolling averages (not Wilder's exponential smoothing).

    Args:
        price_series: pandas Series of prices.
        periods: Length of the trailing window. Defaults to 14.

    Returns:
        Series aligned with ``price_series``. The first ``periods`` values are
        NaN; a window with no downward moves yields 100, and a window with no
        moves at all yields NaN (0 / 0).
    """
    moves = price_series.diff()
    # Upward part of each move, and magnitude of its downward part. clip()
    # keeps diff()'s leading NaN, so the first rolling window stays NaN.
    ups, downs = moves.clip(lower=0), -moves.clip(upper=0)
    mean_up, mean_down = (
        side.rolling(window=periods).mean() for side in (ups, downs)
    )
    return 100.0 - 100.0 / (1.0 + mean_up / mean_down)

# Build the feature tables (moving averages, RSI, price change, next-row Target)
# for every downloaded ticker.
processed_stocks = preprocess_data(stock_datasets)

## 数据可视化

In [None]:
# One panel per ticker: close price with its 20- and 50-row moving averages.
# The grid is sized from the number of tickers (two columns when there is more
# than one), so a ticker list longer than four no longer makes plt.subplot fail
# with a ValueError; four tickers still give the original 2x2 layout.
n_panels = len(processed_stocks)
n_cols = 2 if n_panels > 1 else 1
n_rows = max(1, (n_panels + n_cols - 1) // n_cols)

plt.figure(figsize=(15, 5 * n_rows))
for i, (ticker, data) in enumerate(processed_stocks.items(), 1):
    plt.subplot(n_rows, n_cols, i)
    plt.plot(data.index, data['Close'], label='收盘价')
    plt.plot(data.index, data['MA_20'], label='20日移动平均线')
    plt.plot(data.index, data['MA_50'], label='50日移动平均线')
    plt.title(f'{ticker} 股价走势')
    plt.xlabel('日期')
    plt.ylabel('价格')
    plt.legend()

plt.tight_layout()
plt.show()

## 相关性分析

In [None]:
def correlation_analysis(processed_stocks):
    combined_data = pd.concat([data[['Close', 'MA_20', 'MA_50', 'RSI', 'Price_Change']] for data in processed_stocks.values()], axis=1)
    correlation_matrix = combined_data.corr()
    
    plt.figure(figsize=(10, 8))
    sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', center=0)
    plt.title('特征相关性分析')
    plt.show()

# Heatmap of correlations across all tickers' price and indicator features.
correlation_analysis(processed_stocks)