# 在沙盒环境中，使用机器学习构建时序涨跌幅预测模型

In [1]:
import os,pdb,itertools,copy,datetime
# Select the ultron data set before `ultron` is imported in the next cell
# (presumably read at import time — keep this ahead of the import).
os.environ.update({'ULTRON_DATA': 'keim'})

In [2]:
import pandas as pd
import numpy as np
from ultron.env import *
from ultron.optimize.model.treemodel import XGBTrainer

/var/log/ultron/2022-09-27.log


In [3]:
# Switch ultron to its bundled example environment; its log message reports
# that it will only read from the local sandbox data directory.
enable_example_env()

2022-09-27 17:31:37,056 - [env.py:67] - ultron - INFO - enable example env will only read /home/kerry/ultron/rom/sandbox/keim


#### 加载行情数据

In [4]:
# Daily bars (trade_date, code, OHLC prices, turnover volume) from the sandbox data dir.
market_data_path = os.path.join(g_project_data, 'market_data.csv')
market_data = pd.read_csv(market_data_path, index_col=0)
# Parse the date strings so later merges/sorts operate on datetimes.
market_data = market_data.assign(trade_date=pd.to_datetime(market_data['trade_date']))
market_data.head()

Unnamed: 0,trade_date,code,openPrice,highestPrice,lowestPrice,closePrice,turnoverVol
0,2017-10-27,A,4462.578191,4463.801485,4413.646412,4435.665713,158774
1,2017-10-27,AL,15625.658581,15658.904663,15430.931529,15449.92929,293630
2,2017-10-27,BU,3310.339921,3336.950371,3283.729472,3302.356787,461826
3,2017-10-27,C,2009.751001,2014.561895,2001.331936,2002.53466,375480
4,2017-10-27,CF,20517.496003,20531.174333,20408.069357,20449.104349,84032


#### 读取选中的因子

In [5]:
# Selected factor configurations: one row per (factor, window, weekday, bins)
# combination, so the same factor name can appear on several rows.
sel_factor_path = os.path.join(g_project_data, 'sel_factor.csv')
sel_factor = pd.read_csv(sel_factor_path, index_col=0)
sel_factor.head()

Unnamed: 0,factor,window,weekday,bins
0,BM_MainFar_80D,23,5,5
1,BM_MainFar_80D,25,5,5
2,BM_MainFar_80D,27,5,5
3,BM_RecentFar_20D,5,1,5
4,BM_RecentFar_40D,3,1,3


#### 读取因子

In [6]:
total_data = pd.read_csv(os.path.join(g_project_data, 'factor.csv'), index_col=0)
# Key columns plus every distinct factor named in sel_factor (a factor may be
# listed several times there with different settings, hence `unique()`).
factor_columns = ['trade_date', 'code'] + sel_factor['factor'].unique().tolist()
# `.assign` builds a new frame instead of writing into a column-slice of
# `total_data`, which previously raised SettingWithCopyWarning.
factor_data = total_data[factor_columns].assign(
    trade_date=lambda frame: pd.to_datetime(frame['trade_date'])
)
factor_data.head()

Unnamed: 0,trade_date,code,BM_MainFar_80D,BM_RecentFar_20D,BM_RecentFar_40D,BM_RecentFar_80D,BM_RecentSecond_20D,BM_RecentSecond_40D,B_FarSpot,B_MainSpot,...,TS_MainFar,TS_RecentFar,TS_RecentSecond,T_DnIntraday_5D,T_DnVolatility_1_10D,T_DnVolatility_2_20D,WeightNetIntTotalChg5D,WeightShortVolRelTotIntChg,inventory,profitratio
0,2017-10-27,A,-0.033259,-0.026646,-0.019436,-0.041974,-0.023047,-0.013509,-0.042729,0.002378,...,-0.05761,-0.079619,-0.104757,-0.00835,-0.007715,-0.002168,-0.000633,-0.037579,,
1,2017-10-27,AL,-0.001423,0.001697,-0.000937,0.000587,0.001133,-0.000539,-0.076121,-0.084726,...,-0.069381,-0.068413,-0.067663,-0.005843,-0.008381,0.000165,-0.000352,-0.012891,-173.600006,-0.005896
2,2017-10-27,BU,-0.016537,0.059635,-0.032271,-0.034618,0.069999,-0.027086,-0.124574,-0.321128,...,-0.102761,-0.124225,-0.159247,-0.005098,-0.009538,0.001268,0.002481,0.275875,,0.055222
3,2017-10-27,C,0.007939,-0.005224,-0.014003,0.025361,-0.001541,-0.007955,-0.006522,0.172635,...,-0.063351,-0.072537,-0.092543,-0.002866,-0.003349,-0.000979,0.002547,0.245555,-369.799988,
4,2017-10-27,CF,-0.02396,0.002346,-0.028774,-0.008043,0.004838,-0.009858,0.05664,0.311288,...,-0.025782,-0.024471,-0.003359,-0.003975,-0.004373,-0.00142,-0.000643,-0.131799,-189.899994,0.123392


In [7]:
# Model inputs: every factor column except the merge keys and the
# `inventory` / `profitratio` columns.
EXCLUDED_COLUMNS = {'trade_date', 'code', 'inventory', 'profitratio'}
features = [name for name in factor_data.columns if name not in EXCLUDED_COLUMNS]

#### 构建训练集Y值

##### 通过ATR进行涨跌幅判断：将当日收盘价相对前一日的变化除以20日ATR，大于0记为上涨(1)，否则记为0

In [8]:
def ATR(high, low, close, N=14):
    """Average True Range as an N-period simple rolling mean of the true range.

    The true range of a bar is the largest of ``high - low``,
    ``|high - prev_close|`` and ``|low - prev_close|``, where ``prev_close``
    is ``close`` shifted down by one row. Inputs are aligned pandas Series or
    DataFrames (e.g. a trade_date x code panel); the operation is element-wise.

    A candidate only replaces the running maximum when the ``>`` comparison is
    True, so comparisons against NaN never overwrite a value. The first row
    has no previous close, so the result is NaN for the first ``N`` rows (and
    wherever the rolling window contains NaN).
    """
    prev_close = close.shift()
    bar_range = high - low
    up_gap = (high - prev_close).abs()
    down_gap = (low - prev_close).abs()
    # Step-wise maximum: `mask` swaps in the other value only where the
    # comparison is True, leaving the left-hand value otherwise.
    widest = up_gap.mask(bar_range > up_gap, bar_range)
    true_range = down_gap.mask(widest > down_gap, widest)
    return true_range.rolling(N).mean()

In [9]:
def create_lable(market_data, atr_window=20):
    """Build per-(trade_date, code) up/down labels from daily price bars.

    The close-to-close change of each contract is divided by its ATR over
    ``atr_window`` days and then binarised: 1 when the normalised change is
    positive, 0 otherwise.

    Parameters
    ----------
    market_data : pd.DataFrame
        Long-format bars with at least ``trade_date``, ``code``,
        ``highestPrice``, ``lowestPrice`` and ``closePrice`` columns. The
        frame itself is not modified (only sorted/re-indexed copies are used).
    atr_window : int, default 20
        Rolling window passed to ``ATR`` as ``N``.

    Returns
    -------
    pd.DataFrame
        Columns ``trade_date`` (datetime), ``code``, ``value`` (ATR-normalised
        change) and ``signal`` (0/1), one row per (date, code) cell that has a
        value.
    """
    ordered = market_data.sort_values(by=['trade_date', 'code'], ascending=True)

    def _wide(column):
        # Pivot one price column into a (trade_date x code) frame.
        return ordered.set_index(['trade_date', 'code'])[column].unstack()

    highest_price = _wide('highestPrice')
    lowest_price = _wide('lowestPrice')
    close_price = _wide('closePrice')
    atr = ATR(high=highest_price, low=lowest_price, close=close_price, N=atr_window)

    signal_wide = (close_price - close_price.shift(1)) / atr
    # Drop missing cells individually after stacking. Calling `.dropna()` on
    # the wide frame removed an entire trade_date whenever *any* contract was
    # missing that day, throwing away valid labels of all other contracts.
    signal_dt = signal_wide.stack().dropna()
    signal_dt.name = 'value'
    signal_dt = signal_dt.reset_index()

    # Binary labelling: 1 = up, 0 = not up. (A three-class variant — 0 flat,
    # 1 up, 2 down using a +/-0.3 band — was sketched in an earlier version.)
    signal_dt['signal'] = np.where(signal_dt['value'] > 0, 1, 0)
    signal_dt['trade_date'] = pd.to_datetime(signal_dt['trade_date'])
    return signal_dt

In [10]:
# Label on a deep copy so `market_data` stays untouched by the labelling step.
lable_data = create_lable(market_data.copy(deep=True))

In [11]:
#### Shift by one day: the factors of day T-1 are paired with the move of day T.
# The shift must happen *within each contract*: `lable_data` is ordered by
# (trade_date, code), so a plain `.shift(-1)` pulled the signal of the next
# row, i.e. a different contract on the same date.
lable_next = lable_data.sort_values(['trade_date', 'code']).assign(
    signal=lambda frame: frame.groupby('code')['signal'].shift(-1)
)
# NOTE: only `signal` is shifted; `value` keeps the same-day (unshifted)
# ATR-normalised move. Re-running this cell merges a second time — restart
# from the factor-loading cell instead.
factor_data = factor_data.merge(lable_next, on=['trade_date', 'code'])

# The last available day of each contract has no next-day label.
factor_data = factor_data.dropna(subset=['signal'])

#### 保存数据集，供其他例子使用

In [12]:
# Persist factors + labels (index included) for reuse by the other examples.
train_data_path = os.path.join(g_project_data, 'train_datas.csv')
factor_data.to_csv(train_data_path, encoding='UTF-8')

##### 树模型：目前框架中包含两类——1. XGBTrainer，2. LGBMTrainer；本例子使用 XGBTrainer

In [13]:
# Hyper-parameters passed to XGBTrainer together with the feature list.
XGB_PARAMS = {
    'objective': 'binary:logistic',  # the label is binary (1 = up, 0 = not up)
    'booster': 'gbtree',
    'tree_method': 'hist',
    'n_estimators': 200,
}
model = XGBTrainer(features=features, **XGB_PARAMS)

In [14]:
# NOTE(review): the model is fit on the whole history; no hold-out period is
# split off here, so nothing below measures out-of-sample accuracy.
train_labels = factor_data['signal'].values
model.fit(factor_data, train_labels)

[17:32:06] Tree method is selected to be 'hist', which uses a single updater grow_fast_histmaker.
[17:32:06] /workspace/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=3
[17:32:06] /workspace/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=3
[17:32:06] /workspace/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=3
[17:32:06] /workspace/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=3
[17:32:06] /workspace/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=3
[17:32:06] /workspace/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=3
[17:32:06] /workspace/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=3
[17:32:06] /workspace/src/tree/updater_pr

[17:32:06] /workspace/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=3
[17:32:06] /workspace/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=3
[17:32:06] /workspace/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=3
[17:32:06] /workspace/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 12 extra nodes, 0 pruned nodes, max_depth=3
[17:32:06] /workspace/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=3
[17:32:06] /workspace/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=3
[17:32:06] /workspace/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=3
[17:32:06] /workspace/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=3
[17:32:06] /work

[17:32:06] /workspace/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=3
[17:32:06] /workspace/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 12 extra nodes, 0 pruned nodes, max_depth=3
[17:32:06] /workspace/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=3
[17:32:06] /workspace/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=3
[17:32:06] /workspace/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=3
[17:32:06] /workspace/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=3


##### 预测结果

##### 随机构造测试集用于预测（仅用于演示预测接口，并非真实的样本外评估）

In [15]:
# Seeded generator so Restart & Run All reproduces the same synthetic test set
# (and therefore the same predictions); the original used the unseeded global RNG.
rng = np.random.default_rng(42)
# 1000 rows of standard-normal noise, one column per model feature. This only
# exercises the predict API; it is not real market data.
test_data = pd.DataFrame(rng.standard_normal((1000, len(features))), columns=features)
test_data.head()

Unnamed: 0,BM_MainFar_80D,BM_RecentFar_20D,BM_RecentFar_40D,BM_RecentFar_80D,BM_RecentSecond_20D,BM_RecentSecond_40D,B_FarSpot,B_MainSpot,B_RecentSpot,B_SecondSpot,...,R_UpVolatility_1_40D,R_UpVolatility_1_60D,TS_MainFar,TS_RecentFar,TS_RecentSecond,T_DnIntraday_5D,T_DnVolatility_1_10D,T_DnVolatility_2_20D,WeightNetIntTotalChg5D,WeightShortVolRelTotIntChg
0,-0.030111,-0.85986,0.959281,0.949809,0.487641,0.481177,0.104874,0.725514,-1.609181,-0.021068,...,-1.642276,1.231477,-1.463811,0.339812,-0.703375,0.972938,1.701776,0.175061,-0.437382,0.144127
1,-1.341278,0.365956,0.078692,-0.1718,0.091617,-0.36801,-0.003588,-0.296222,0.671802,0.82208,...,-0.845307,0.965778,0.604448,1.417201,-0.318443,0.830627,0.946477,0.442881,-0.321635,-0.611773
2,-0.824769,-1.955698,0.562405,0.460463,-0.095581,-0.42646,-1.040424,-0.359388,-0.556199,0.348545,...,0.632305,-0.348812,-0.726851,0.744127,-0.468048,1.312655,0.307666,-1.016679,-0.10503,0.251106
3,0.184538,0.995819,0.835242,0.270242,-1.445475,-0.395513,0.013711,0.973669,-1.310532,0.358438,...,-0.283124,-0.881857,0.583492,-1.306986,-0.282947,0.619602,-0.599103,1.682886,0.471341,-0.511075
4,-0.280706,-0.196826,-0.008669,-0.293083,-1.82322,-0.215796,-0.267172,1.535816,-1.365632,-0.124178,...,-1.131305,-0.433448,0.224725,1.456106,-0.794028,-0.018542,0.056244,0.493308,1.587679,0.057096


In [16]:
# Predict on the synthetic test set. NOTE(review): the exact output of
# XGBTrainer.predict (class labels vs. probabilities) is not visible here — confirm.
y = model.predict(test_data)