In [6]:
import numpy as np
import pandas as pd
import os
import tushare as ts
from datetime import datetime
import json
from loguru import logger
from typing import List, Tuple, Dict, Any, Callable
pd.set_option("display.unicode.east_asian_width", True)

TOKEN_PATH = os.path.expanduser('~/.tushare.token')

with open(TOKEN_PATH, 'r') as f:
    token = f.read().strip()
    ts.set_token(token)
    pro = ts.pro_api()

from datetime import datetime
from sklearn.linear_model import LinearRegression

TICKER = "600970.SH"
START = "20210101"
END   = "20241023"

data = pro.daily_basic(
    ts_code=TICKER,
    start_date=START,
    end_date=END,
    fields="trade_date,turnover_rate_f,pe_ttm,pb,ps,close,vol"
)

data["trade_date"] = pd.to_datetime(data["trade_date"], format=r"%Y%m%d")
data.rename(columns={"trade_date": "date"}, inplace=True)
data.set_index("date", drop=True, inplace=True)
data.sort_index(ascending=True, inplace=True)

future_period = 20
sum_period = 10

data[f"returns-window-{future_period}"] = data["close"].pct_change(periods=future_period) + 1
# print(data[f"returns-window-{period}"])
data[f"future-returns-{future_period}"] = data[f"returns-window-{future_period}"].shift(-future_period)

data[f"trf-sum-{sum_period}"] = data["turnover_rate_f"].rolling(sum_period).sum()
data[f"vol-sum-{sum_period}"] = data["vol"].rolling(sum_period).sum()

data.dropna(inplace=True)
print(data)

def zscore(ser: pd.Series) -> pd.Series:
    return (ser - ser.mean()) / ser.std()

columns = [
    f"future-returns-{future_period}",
    # "turnover_rate_f",
    f"trf-sum-{sum_period}",
    f"vol-sum-{sum_period}",
    "pe_ttm",
    "pb",
    "ps",
]

import matplotlib.pyplot as plt
import seaborn as sns
input = np.array([zscore(data[col]) for col in columns])
    # zscore(data["close-future"]),
    # zscore(data["turnover_rate_f"]),
    # zscore(data["pe_ttm"]),
    # zscore(data["pb"]),
    # zscore(data["ps"]),
# ])

# 计算相关系数矩阵
correlation_matrix = np.corrcoef(input)

# 设置图形大小
plt.figure(figsize=(10, 8))

# 使用 seaborn 库绘制热图
ax = sns.heatmap(correlation_matrix, annot=True, fmt=".2f", cmap='coolwarm', square=True, cbar_kws={"shrink": .5})

# 添加标题和轴标签
plt.title("Correlation Matrix Heatmap")
ax.set_xticklabels(columns, rotation=45, ha="center")
ax.set_yticklabels(columns, rotation=0)

# 显示图形
plt.show()

KeyError: 'vol'