In [1]:
import os
from pathlib import Path

import joblib
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from sktime.split import temporal_train_test_split
from sqlalchemy import create_engine, text

## 读取数据集


In [6]:
# Read the connection URL from the environment so real credentials need not be committed
# with the notebook; the fallback is the original local-development URL.
db_connection_str = os.environ.get(
    "A_STOCK_DB_URL",
    "mysql+pymysql://root:123456@localhost:3306/A_stock?charset=utf8",
)
db_connection = create_engine(db_connection_str)

# One parameterised query shared by every index. ORDER BY is required: the temporal
# train/test split further down assumes the rows arrive in chronological order, and
# SQL gives no ordering guarantee without it.
INDEX_HISTORY_QUERY = text(
    "SELECT date, open, high, low, close, volume, amount "
    "FROM stockindexhistory WHERE `symbol` = :symbol ORDER BY date"
)


def load_index_history(symbol: str, con=db_connection) -> pd.DataFrame:
    """Load the daily open/high/low/close/volume/amount history of one index.

    Args:
        symbol: index symbol as stored in ``stockindexhistory.symbol`` (e.g. ``"sh000001"``).
        con: SQLAlchemy engine/connection to read from; defaults to ``db_connection``.

    Returns:
        DataFrame indexed by the parsed ``date`` column, sorted ascending by date.
    """
    return pd.read_sql(
        INDEX_HISTORY_QUERY,
        con=con,
        params={"symbol": symbol},
        index_col="date",
        parse_dates=["date"],
    )


sh_data = load_index_history("sh000001")     # Shanghai (SSE) Composite Index
hs300_data = load_index_history("sh000300")  # CSI 300 Index
sz_data = load_index_history("sz399001")     # Shenzhen (SZSE) Component Index
a_data = load_index_history("sh000002")      # A-share Index

In [7]:
# Summary statistics of the SSE Composite frame, with volume/amount still in the units
# returned by the database (the rescaling happens in the next cell).
sh_data.describe(percentiles=[0.25, 0.5, 0.75])

Unnamed: 0,open,high,low,close,volume,amount
count,8149.0,8149.0,8149.0,8149.0,8149.0,8149.0
mean,2116.190994,2135.031015,2096.073904,2117.553727,107758700.0,120604800000.0
std,1091.388919,1100.54587,1081.390828,1092.602956,131109800.0,163575500000.0
min,96.05,99.98,95.79,99.98,15.0,6000.0
25%,1226.92,1239.33,1210.72,1225.49,6440180.0,4927910000.0
50%,2074.56,2088.63,2061.65,2075.48,56893400.0,47911500000.0
75%,3029.93,3051.36,3008.73,3031.64,164960000.0,177817000000.0
max,6057.43,6124.04,6040.71,6092.06,857133000.0,1309920000000.0


In [8]:
# Express volume in units of 1e5 (100k) and amount in units of 1e8 (100M).
VOLUME_UNIT = 100_000
AMOUNT_UNIT = 100_000_000


def normalize_units(frame: pd.DataFrame) -> pd.DataFrame:
    """Rescale the ``volume`` and ``amount`` columns of *frame* in place and return it.

    After the first call the frame is tagged in ``frame.attrs``, so running this cell a
    second time does not divide the columns again (a plain in-place division compounds on
    every re-run). A freshly loaded frame carries no tag and is rescaled normally.
    """
    if frame.attrs.get("units_normalized", False):
        return frame
    frame["volume"] = frame["volume"] / VOLUME_UNIT
    frame["amount"] = frame["amount"] / AMOUNT_UNIT
    frame.attrs["units_normalized"] = True
    return frame


# NOTE(review): only the SSE Composite is rescaled; hs300_data, sz_data and a_data keep their
# raw units even though all four go through the same scaler pipeline below — confirm intended.
normalize_units(sh_data)
sh_data.describe()

Unnamed: 0,open,high,low,close,volume,amount
count,8149.0,8149.0,8149.0,8149.0,8149.0,8149.0
mean,2116.190994,2135.031015,2096.073904,2117.553727,1077.587053,1206.047754
std,1091.388919,1100.54587,1081.390828,1092.602956,1311.098173,1635.755027
min,96.05,99.98,95.79,99.98,0.00015,6e-05
25%,1226.92,1239.33,1210.72,1225.49,64.4018,49.2791
50%,2074.56,2088.63,2061.65,2075.48,568.934,479.115
75%,3029.93,3051.36,3008.73,3031.64,1649.6,1778.17
max,6057.43,6124.04,6040.71,6092.06,8571.33,13099.2


In [9]:
# Input feature columns for the models; `close` is used separately as the target.
features_one = ["open", "high", "low", "volume"]

# Parallel lists: stock_name[i] is the index symbol of the frame stock_market[i].
stock_name = ["sh000001", "sh000300", "sz399001", "sh000002"]
stock_market = [sh_data, hs300_data, sz_data, a_data]

In [8]:
# Directory for the fitted scalers; joblib.dump does not create missing directories.
SCALER_DIR = Path("./scalers")
SCALER_DIR.mkdir(parents=True, exist_ok=True)

for name, stock in zip(stock_name, stock_market):
    # Split features and target together so they are guaranteed to share the same cut-off.
    train, _test = temporal_train_test_split(stock[features_one + ["close"]], test_size=0.2)

    # Fit on the training window only, so nothing from the test period leaks into the scaling.
    # Fitted on plain arrays (no column names), as before.
    X_scaler = MinMaxScaler().fit(train[features_one].to_numpy())
    y_scaler = MinMaxScaler().fit(train[["close"]].to_numpy())

    # Persist the fitted scalers, one pair per index symbol.
    joblib.dump(X_scaler, SCALER_DIR / f"{name}_X_scaler.pkl")
    joblib.dump(y_scaler, SCALER_DIR / f"{name}_y_scaler.pkl")

In [4]:
# Alias the first frame of stock_market (sh_data, the SSE Composite) for the exploratory cells below.
df = stock_market[0]

In [5]:
# Peek at the first five rows of the exploratory frame.
df.head(n=5)

Unnamed: 0_level_0,open,high,low,close,volume,amount
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
1990-12-19,96.05,99.98,95.79,99.98,1260.0,494000.0
1990-12-20,104.3,104.39,99.98,104.39,197.0,84000.0
1990-12-21,109.07,109.13,103.73,109.13,28.0,16000.0
1990-12-24,113.57,114.55,109.13,114.55,32.0,31000.0
1990-12-25,120.09,120.25,114.55,120.25,15.0,6000.0


In [10]:
# Chronological 80/20 split of the feature columns: the last 20% of rows become the test set.
X_train, X_test = temporal_train_test_split(df.loc[:, features_one], test_size=0.2)

In [11]:
# Fit the feature scaler on the training window, then replace X_train with its scaled array.
# NOTE(review): X_train is overwritten here, so re-running this cell on its own would refit
# the scaler on already-scaled data — re-run the split cell first.
X_scaler = MinMaxScaler().fit(X_train)
X_train = X_scaler.transform(X_train)
X_train

array([[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.45251673e-06],
       [1.38390775e-03, 7.32064422e-04, 7.04803429e-04, 2.12335779e-07],
       [2.18405805e-03, 1.51890917e-03, 1.33559409e-03, 1.51668414e-08],
       ...,
       [5.32373377e-01, 5.28304831e-01, 5.28247647e-01, 2.80802383e-01],
       [5.27037364e-01, 5.22096393e-01, 5.22287937e-01, 3.06793683e-01],
       [5.21689609e-01, 5.21254768e-01, 5.23177772e-01, 2.22074040e-01]])