In [14]:
pip install mysql-connector-python

Collecting mysql-connector-python
  Downloading mysql_connector_python-9.1.0-cp312-cp312-macosx_13_0_arm64.whl.metadata (6.0 kB)
Downloading mysql_connector_python-9.1.0-cp312-cp312-macosx_13_0_arm64.whl (15.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m15.1/15.1 MB[0m [31m18.6 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: mysql-connector-python
Successfully installed mysql-connector-python-9.1.0
Note: you may need to restart the kernel to use updated packages.


In [15]:
import os

import mysql.connector

# Connection settings are read from the environment so real credentials never
# have to be typed into the notebook. The fallbacks are the original
# placeholder values, so an unconfigured environment behaves as before.
# The recorded run failed with error 2003 (could not connect to
# localhost:3306), so a MySQL server must be reachable at the configured host.
conn = mysql.connector.connect(
    host=os.environ.get("MYSQL_HOST", "localhost"),
    database=os.environ.get("MYSQL_DATABASE", "stock_history"),
    user=os.environ.get("MYSQL_USER", "username"),
    password=os.environ.get("MYSQL_PASSWORD", "password"),
)
cursor = conn.cursor()
# Bind the symbol as a query parameter instead of splicing it into the SQL text.
cursor.execute(
    "SELECT date, open, high, low, close FROM stock_data WHERE symbol = %s",
    ("AAPL",),
)

DatabaseError: 2003 (HY000): Can't connect to MySQL server on 'localhost:3306' (61)

In [7]:
import os
from datetime import datetime, timedelta

import numpy as np
import pandas as pd
import tushare as ts
from rich.progress import Progress
# The Tushare token is read from the TUSHARE_TOKEN environment variable rather
# than being stored in the notebook, where it would be shared with the file.
# NOTE(review): any token that was ever saved in this notebook should be rotated.
# NOTE(review): with an empty token, tushare is expected to fall back to the one
# saved earlier via ts.set_token() -- confirm against the installed version.
pro = ts.pro_api(os.environ.get("TUSHARE_TOKEN", ""))

In [8]:
# Given a trading day, return the A-share stock codes selected for that day.
def get_stocklist(date: str, num: int):
    """Return up to ``num`` constituent codes of index 000002.SH around ``date``.

    The constituents come from ``pro.index_weight`` queried over the 30
    calendar days that end on ``date``.

    Args:
        date: End of the query window, as a ``YYYYMMDD`` string.
        num: Maximum number of codes to return.

    Returns:
        The ``con_code`` values in the order returned by the API, with repeated
        codes removed, truncated to the first ``num`` entries.
    """
    window_start = (pd.to_datetime(date) - timedelta(30)).strftime('%Y%m%d')
    weights = pro.index_weight(index_code='000002.SH',
                               start_date=window_start, end_date=date)
    # NOTE(review): the 30-day window can presumably span more than one weight
    # snapshot, which would list the same code several times -- confirm. Only
    # the first occurrence of each code is kept, preserving the API's order.
    codes = list(dict.fromkeys(weights['con_code']))
    # Honour the caller's cap; the previous version always cut at 1000 and
    # ignored ``num``.
    return codes[:num]

# Given the endpoints of a date range, return the trading days in that range
# sampled at a fixed step.


def get_datelist(start: str, end: str, interval: int):
    """Return every ``interval``-th trading date between ``start`` and ``end``.

    Trading dates are the ``trade_date`` values that ``pro.index_daily``
    returns for index 399300.SZ.

    Args:
        start: First date of the range, passed through as ``start_date``.
        end: Last date of the range, passed through as ``end_date``.
        interval: Sampling step; positions 0, interval, 2*interval, ... are kept.

    Returns:
        A list of the sampled ``trade_date`` values.
    """
    daily_bars = pro.index_daily(ts_code='399300.SZ', start_date=start, end_date=end)
    # The API's row order is reversed before sampling.
    # NOTE(review): presumably the API returns newest first, making this
    # chronological -- confirm.
    ordered_days = list(daily_bars.iloc[::-1]['trade_date'])
    return [day for position, day in enumerate(ordered_days)
            if position % interval == 0]

# Returns two values: the 9-indicator feature panel for the trailing look-back
# window (9 x look_back) and the return over the look-forward window.


def get_x_y(instrument_id: str, timestamp: datetime, look_back: int, look_forward: int,
            df: pd.DataFrame):
    """Build one (features, target) sample around ``timestamp``.

    Args:
        instrument_id: Identifier of the instrument. It is not used in the
            computation; ``df`` is taken as-is.
        timestamp: Split point. Rows with ``trade_time <= timestamp`` feed the
            features, rows with ``trade_time > timestamp`` feed the target.
        look_back: Number of most recent past rows used as features.
        look_forward: Number of earliest future rows used for the target.
        df: Frame holding ``trade_time`` and the nine feature columns listed
            in ``feature_columns`` below.

    Returns:
        ``(features, ret)``. ``features`` is a ``(9, look_back)`` array, one row
        per feature column, in chronological order, with missing values
        replaced by 0. ``ret`` is ``last / first - 1`` over the first
        ``look_forward`` future ``last_price`` values.
        ``(None, None)`` is returned when either window length is not
        positive, when there are too few rows on either side of ``timestamp``,
        or when the return cannot be computed (missing or zero price).
    """
    feature_columns = [
        'last_price',          # 1
        'highest_price',       # 2
        'lowest_price',        # 3
        'cum_volume',          # 4
        'cum_turnover',        # 5
        'bid_price1',          # 6
        'bid_volume1',         # 7
        'ask_price1',          # 8
        'ask_volume1',         # 9
    ]

    # A non-positive window cannot yield a sample: a zero look_forward used to
    # raise IndexError and a negative look_back sliced the wrong rows.
    if look_back < 1 or look_forward < 1:
        return None, None

    past_data = df[df['trade_time'] <= timestamp]
    future_data = df[df['trade_time'] > timestamp]
    # Check the row counts before sorting so short histories cost no sort.
    if past_data.shape[0] < look_back or future_data.shape[0] < look_forward:
        return None, None

    # Take the most recent look_back rows, then restore chronological order.
    recent = past_data.sort_values('trade_time', ascending=False).iloc[:look_back]
    features = recent[feature_columns].fillna(0).iloc[::-1]

    future_prices = future_data.sort_values('trade_time')['last_price'].iloc[:look_forward]
    first_price = future_prices.iloc[0]
    final_price = future_prices.iloc[-1]
    # A missing or zero price would produce a NaN/inf target; signal "no
    # sample" instead so it never reaches the training set.
    if pd.isna(first_price) or pd.isna(final_price) or first_price == 0:
        return None, None

    # NOTE(review): the return runs from the first future row to the
    # look_forward-th future row, not from the row at ``timestamp``; it is
    # therefore 0 when look_forward == 1. Confirm this is the intended target.
    ret = final_price / first_price - 1
    return features.T.values, ret


def get_length(date: str, pass_day: int, future_day: int):
    """Count index rows in a window before and a window after ``date``.

    The backward window starts ``2 * pass_day`` calendar days before ``date``
    and the forward window ends ``2 * future_day`` calendar days after it.
    Both windows include ``date`` itself as an endpoint.

    Args:
        date: Anchor date, parsed with ``pd.to_datetime`` and also passed to
            the API unchanged.
        pass_day: Half the length, in calendar days, of the backward window.
        future_day: Half the length, in calendar days, of the forward window.

    Returns:
        ``(rows_before, rows_after)``: the number of rows ``pro.index_daily``
        returns for index 399300.SZ in each window.
    """
    def _compact(moment):
        # "YYYY-MM-DD hh:mm:ss" -> "YYYYMMDD"
        return str(moment)[:10].replace('-', '')

    window_start = _compact(pd.to_datetime(date) - timedelta(pass_day * 2))
    window_end = _compact(pd.to_datetime(date) + timedelta(future_day * 2))

    before = pro.index_daily(ts_code='399300.SZ',
                             start_date=window_start, end_date=date)
    after = pro.index_daily(ts_code='399300.SZ',
                            start_date=date, end_date=window_end)
    return before.shape[0], after.shape[0]

# Dataset builder: given the endpoints of a date range, collect the samples for
# every sampled trading day in that range.


def get_dataset(num: int, start: str, end: str, interval: int, pass_day: int, future_day: int):
    """Collect feature panels and targets for every sampled date and stock.

    For each date from ``get_datelist(start, end, interval)`` the stock list is
    fetched with ``get_stocklist(date, num)`` and ``get_x_y`` is called once per
    stock. A sample is kept only when its feature array has 9 rows and
    ``pass_day`` columns.

    Args:
        num: Forwarded to ``get_stocklist``.
        start: Start of the date range, forwarded to ``get_datelist``.
        end: End of the date range, forwarded to ``get_datelist``.
        interval: Sampling step, forwarded to ``get_datelist``.
        pass_day: Look-back length; also the expected number of feature columns.
        future_day: Look-forward length.

    Returns:
        ``(X_train, y_train)``: two parallel lists of kept feature arrays and
        their targets.
    """
    X_train = []
    y_train = []
    trade_date_list = get_datelist(start, end, interval)
    # Add a progress bar.
    with Progress() as progress:
        task_date = progress.add_task(
            "[red]Date...", total=len(trade_date_list))
        for date in trade_date_list:
            # Advance the per-date progress bar.
            progress.update(task_date, advance=1)
            stock_list = get_stocklist(date, num)
            len1, len2 = get_length(date, pass_day, future_day)
            # A new "Stock..." task is added for every date; earlier ones are
            # not removed in this function.
            task_stock = progress.add_task(
                "[green]Stock...", total=len(range(len(stock_list))))
            for i in range(len(stock_list)):
                # Advance the per-stock progress bar.
                progress.update(task_stock, advance=1)
                code = stock_list[i]
                # NOTE(review): get_x_y as defined in this notebook takes five
                # parameters (instrument_id, timestamp, look_back,
                # look_forward, df), but six positional arguments are passed
                # here and no DataFrame is supplied. The call sits outside the
                # try block, so the resulting TypeError is not caught. The
                # call needs to be reconciled with the current get_x_y.
                x, y = get_x_y(code, date, pass_day, future_day, len1, len2)
                try:
                    if (x.shape[0] == 9) & (x.shape[1] == pass_day):
                        X_train.append(x)
                        y_train.append(y)
                except Exception:
                    # Any failure in the shape check (for example x being
                    # None) silently drops the sample.
                    continue
    return X_train, y_train

In [10]:
# Settings: use the past 30 days of data to predict the return over the next
# 10 days (a regression problem). The settings shared by the train and test
# sets live in one place so the two calls cannot drift apart.
DATASET_KWARGS = dict(num=1000, interval=10, pass_day=30, future_day=10)

X_train, y_train = get_dataset(
    start='20220101', end='20220630', **DATASET_KWARGS)
# The test window used to start on '20220931', which is not a calendar date
# (September has 30 days). It now starts on the first valid date after it.
# NOTE(review): confirm whether '20220930' was intended instead.
X_test, y_test = get_dataset(
    start='20221001', end='20221231', **DATASET_KWARGS)
print("there are in total", len(X_train), "training samples")
print("there are in total", len(X_test), "testing samples")
# Save the data locally for offline training.
Xa = np.array(X_train)
ya = np.array(y_train)
Xe = np.array(X_test)
ye = np.array(y_test)
np.save('./X_train.npy', Xa)
np.save('./y_train.npy', ya)
np.save('./X_test.npy', Xe)
np.save('./y_test.npy', ye)

Exception: 抱歉，您没有接口访问权限，权限的具体详情访问：https://tushare.pro/document/1?doc_id=108。