In [1]:
import numpy as np
import pandas as pd
pd.set_option('display.max_rows', 10)
pd.set_option('display.max_columns', 1000)
import datetime
import matplotlib.pylab as plt
%matplotlib inline
import seaborn as sns
sns.set_style('whitegrid')
import time
import os
import copy

#### 定义获取数据函数

In [2]:
def fix_data(path):
    """Load a quote CSV and index it by its parsed timestamps.

    The file is read as GBK-encoded text. Its first, unnamed column (which
    pandas labels ``'Unnamed: 0'``) holds the timestamps; it is parsed into a
    ``trading_point`` datetime column and the raw text column is removed.

    Parameters
    ----------
    path : str
        Location of the CSV file.

    Returns
    -------
    pandas.DataFrame
        The quotes indexed by ``trading_point``. The ``trading_point`` column
        is also kept as the last regular column.
    """
    quotes = pd.read_csv(path, encoding="gbk", engine='python')
    quotes = quotes.rename(columns={'Unnamed: 0': 'trading_time'})
    quotes['trading_point'] = pd.to_datetime(quotes.trading_time)
    quotes = quotes.drop(columns='trading_time')
    # Passing the Series (not the column label) keeps the column in the frame.
    return quotes.set_index(quotes.trading_point)

def High_2_Low(tmp, freq):
    """Aggregate high-frequency bars into lower-frequency OHLCV bars.

    Originally written for minute bars downloaded from RiceQuant (2017-08-11).

    Parameters
    ----------
    tmp : pandas.DataFrame
        Bars indexed by a DatetimeIndex with the columns ``open``, ``high``,
        ``low``, ``close`` and ``volume``.
    freq : str
        Resampling rule passed to ``resample`` (for example ``'1D'``).

    Returns
    -------
    pandas.DataFrame
        Columns ``open``, ``high``, ``low``, ``close``, ``volume`` with one
        row per period that actually contains data.
    """
    # Each price column is reduced with the matching component of ohlc():
    # first value for open, max for high, min for low, last value for close.
    # Periods without data yield NaN and are dropped per column.
    price_parts = [
        tmp[field].resample(freq).ohlc()[field].dropna()
        for field in ('open', 'high', 'low', 'close')
    ]

    # min_count=1 makes empty periods NaN instead of 0. Without it the
    # dropna() below is a no-op on current pandas, and non-trading periods
    # leak into the result as rows with NaN prices and zero volume.
    volume = tmp['volume'].resample(freq).sum(min_count=1).dropna()

    return pd.concat(price_parts + [volume], axis=1)

#### 处理数据

In [3]:
# `get_factors`, used below, is not defined in this notebook; presumably it
# is provided by this star import -- confirm against Talib_calc.
from Talib_calc import *

tmp = fix_data('hs300.csv')

# Build the daily bars used for the targets.
tmp_1d = High_2_Low(tmp, '1d')
rolling = 88
# `targets` is the same object as `tmp_1d` (no copy), so the columns added
# below also land on `tmp_1d`; it is rebuilt from `tmp` further down.
targets = tmp_1d
# Forward return: the close two rows ahead relative to the current close.
targets['returns'] =  targets['close'].shift(-2) / targets['close'] - 1.0
# Band around the rolling mean of the returns: mean +/- 0.5 * rolling std,
# both computed over `rolling` rows.
targets['upper_boundary']= targets.returns.rolling(rolling).mean() + 0.5 * targets.returns.rolling(rolling).std()
targets['lower_boundary']= targets.returns.rolling(rolling).mean() - 0.5 * targets.returns.rolling(rolling).std()
# Drops every row with a NaN in any column: the rolling warm-up rows and the
# trailing rows that have no close two rows ahead.
targets.dropna(inplace=True)
# Three classes: 1 by default, 2 when the return is at or above the upper
# band, 0 when it is at or below the lower band.
targets['labels'] = 1
targets.loc[targets['returns']>=targets['upper_boundary'], 'labels'] = 2
targets.loc[targets['returns']<=targets['lower_boundary'], 'labels'] = 0

# Build the daily bars used for the factors (fresh frame, without the
# target columns added above).
tmp_1d = High_2_Low(tmp, '1d')
Index = tmp_1d.index
High = tmp_1d.high.values
Low = tmp_1d.low.values
Close = tmp_1d.close.values
Open = tmp_1d.open.values
Volume = tmp_1d.volume.values
factors = get_factors(Index, Open, Close, High, Low, Volume, rolling = 26, drop=True)

# Discard factor rows dated after the last labelled row.
factors = factors.loc[:targets.index[-1]]

# Keep only the targets from the timestamp of the 12th factor row onwards.
# NOTE(review): this lines up with `gather_list` only if `targets` already
# contains that timestamp, i.e. the 88-row warm-up above ends no later than
# the 12th factor row -- confirm against what get_factors drops.
tmp_factors_1 = factors.iloc[:12]
targets = targets.loc[tmp_factors_1.index[-1]:]

# Row positions 11 .. len(factors)-1 of `factors`.
gather_list = np.arange(factors.shape[0])[11:]

print(tmp)

                 open   close    high     low      volume trading_point
trading_point                                                          
2014-01-02     3.1274  3.2937  3.3148  3.1227  51073320.0    2014-01-02
2014-01-03     3.2797  3.3406  3.3945  3.2633  61031636.0    2014-01-03
2014-01-06     3.2797  3.2117  3.3265  3.1860  46161948.0    2014-01-06
2014-01-07     3.1860  3.2235  3.2516  3.1860  24545424.0    2014-01-07
2014-01-08     3.2281  3.3242  3.3382  3.2281  37178264.0    2014-01-08
...               ...     ...     ...     ...         ...           ...
2016-12-26     4.8251  4.9330  4.9428  4.7368  47045690.0    2016-12-26
2016-12-27     5.0016  4.9526  5.0899  4.9330  79567159.0    2016-12-27
2016-12-28     4.9232  4.8741  5.0016  4.8349  58147107.0    2016-12-28
2016-12-29     4.8545  4.8251  4.8741  4.7859  31178446.0    2016-12-29
2016-12-30     4.8251  4.8153  4.8545  4.7760  30835028.0    2016-12-30

[733 rows x 6 columns]


#### 转换数据

In [4]:
print(factors.shape)

# Turn the 2-D factor table into a 3-D array with a length-1 middle axis:
# (rows, 1, number of factor columns).
factor_values = np.array(factors)
inputs = factor_values.reshape(-1, 1, factors.shape[1])

def dense_to_one_hot(labels_dense, num_classes=None):
    """Convert integer class labels to a one-hot encoded matrix.

    Written 2016-11-21.

    Parameters
    ----------
    labels_dense : array-like of int
        Non-negative class ids, one per sample (a NumPy array, a list or a
        pandas Series).
    num_classes : int, optional
        Width of the encoding. Defaults to ``max(labels_dense) + 1`` so that
        every id addresses its own column even when some ids never occur.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(len(labels_dense), num_classes)`` with a
        single 1 per row.

    Raises
    ------
    ValueError
        If a label is negative, or not smaller than an explicit
        ``num_classes``.
    """
    labels = np.asarray(labels_dense).ravel().astype(np.intp)

    if labels.size and labels.min() < 0:
        raise ValueError("labels_dense must contain non-negative class ids")

    if num_classes is None:
        # Counting distinct values (the previous approach) is too narrow when
        # an id is missing, e.g. labels {0, 2}: writes then land in the wrong
        # row or run past the end of the array.
        num_classes = int(labels.max()) + 1 if labels.size else 0
    elif labels.size and labels.max() >= num_classes:
        raise ValueError("labels_dense contains a class id >= num_classes")

    n_rows = labels.shape[0]
    labels_one_hot = np.zeros((n_rows, num_classes))
    labels_one_hot[np.arange(n_rows), labels] = 1
    return labels_one_hot

# One-hot encode the class labels, then insert a length-1 middle axis so the
# targets use the same 3-axis layout as `inputs`.
# Note: `targets` is rebound from a DataFrame to an ndarray here, so this
# cell cannot be re-run without re-running the data-preparation cell first.
labels_one_hot = dense_to_one_hot(targets['labels'])
targets = labels_one_hot[:, np.newaxis, :]

print(factors.shape[1])
print(inputs.shape)

NameError: name 'inputs' is not defined

#### 训练模型

In [None]:
# Star import from the project module. `tf`, used in later cells, is not
# imported anywhere else in this notebook, so it presumably comes from here
# -- confirm.
from Classifier_PonderDNC_BasicLSTM_L3 import *

# First model instance, fed the arrays prepared above. This stage uses the
# largest learning rate of the notebook (1e-2).
op1 = Classifier_PonderDNC_BasicLSTM_L3(
    inputs= inputs, 
    targets= targets, 
    gather_list= gather_list, 
    hidden_size= 50, 
    memory_size= 50, 
    pondering_coefficient= 1e-2, 
    learning_rate= 1e-2)

In [None]:
# Stage 1 training; the resulting checkpoint path is what the next stage
# passes as `restore_path`.
# Argument meanings are inferred from their names -- confirm against the class.
op1.fit(training_iters = 100,
        display_step = 10,
        save_path = "model/ResidualPonderDNC_1.ckpt")

In [None]:
# Presumably releases the resources held by op1 (e.g. its session) before
# the default graph is reset in the next stage -- confirm against the class.
op1.close()

#### Second stage: resume from the stage-1 checkpoint with a lower learning rate (1e-3)

In [None]:
# Stage 2: start from a clean default graph and rebuild the model with a
# smaller learning rate (1e-3).
tf.reset_default_graph()
op2 = Classifier_PonderDNC_BasicLSTM_L3(
    inputs=inputs,
    targets=targets,
    gather_list=gather_list,
    hidden_size=50,
    memory_size=50,
    pondering_coefficient=1e-2,
    learning_rate=1e-3)

# Resume from the stage-1 checkpoint and write back to the same path.
# op2 is not used after this cell, so close it once training ends (as is done
# for op1, op4 and op5); otherwise the next cell resets the default graph
# while this model is still open.
try:
    op2.fit(training_iters=100,
            display_step=10,
            save_path="model/ResidualPonderDNC_1.ckpt",
            restore_path="model/ResidualPonderDNC_1.ckpt")
finally:
    op2.close()

#### Third stage onward: repeated fine-tuning runs, each restoring the previous checkpoint

In [None]:
# Stage 3: same configuration as stage 2, continuing from the checkpoint it
# wrote.
tf.reset_default_graph()
op3 = Classifier_PonderDNC_BasicLSTM_L3(
    inputs=inputs,
    targets=targets,
    gather_list=gather_list,
    hidden_size=50,
    memory_size=50,
    pondering_coefficient=1e-2,
    learning_rate=1e-3)

# op3 is not used after this cell, so close it once training ends (as is done
# for op1, op4 and op5); otherwise the next cell resets the default graph
# while this model is still open.
try:
    op3.fit(training_iters=100,
            display_step=10,
            save_path="model/ResidualPonderDNC_1.ckpt",
            restore_path="model/ResidualPonderDNC_1.ckpt")
finally:
    op3.close()

In [None]:
# Stage 4: same configuration again, with a shorter run (50 iterations) that
# restores and overwrites the same checkpoint file.
tf.reset_default_graph()
op4 = Classifier_PonderDNC_BasicLSTM_L3(
    inputs= inputs, 
    targets= targets, 
    gather_list= gather_list, 
    hidden_size= 50, 
    memory_size= 50, 
    pondering_coefficient= 1e-2, 
    learning_rate= 1e-3)

op4.fit(training_iters = 50,
        display_step = 10,
        save_path = "model/ResidualPonderDNC_1.ckpt",
        restore_path = "model/ResidualPonderDNC_1.ckpt")

In [None]:
# Presumably releases op4's resources before the next stage resets the
# default graph -- confirm against the class.
op4.close()

In [None]:
# Stage 5: the pondering coefficient is raised from 1e-2 to 1e-1. Training
# restores checkpoint 1 and writes its result to a new file (checkpoint 2).
tf.reset_default_graph()
op5 = Classifier_PonderDNC_BasicLSTM_L3(
    inputs= inputs, 
    targets= targets, 
    gather_list= gather_list, 
    hidden_size= 50, 
    memory_size= 50, 
    pondering_coefficient= 1e-1, 
    learning_rate= 1e-3)

op5.fit(training_iters = 50,
        display_step = 10,
        save_path = "model/ResidualPonderDNC_2.ckpt",
        restore_path = "model/ResidualPonderDNC_1.ckpt")

In [None]:
# Presumably releases op5's resources before the next stage resets the
# default graph -- confirm against the class.
op5.close()

In [None]:
# Stage 6: learning rate lowered to 1e-4. Training restores checkpoint 2 and
# writes checkpoint 3.
tf.reset_default_graph()
op6 = Classifier_PonderDNC_BasicLSTM_L3(
    inputs=inputs,
    targets=targets,
    gather_list=gather_list,
    hidden_size=50,
    memory_size=50,
    pondering_coefficient=1e-1,
    learning_rate=1e-4)

# op6 is not used after this cell, so close it once training ends (as is done
# for op1, op4 and op5); otherwise the next cell resets the default graph
# while this model is still open.
try:
    op6.fit(training_iters=100,
            display_step=10,
            save_path="model/ResidualPonderDNC_3.ckpt",
            restore_path="model/ResidualPonderDNC_2.ckpt")
finally:
    op6.close()

In [None]:
# Stage 7 (final): same configuration as stage 6. Training restores
# checkpoint 3 and writes checkpoint 4.
# op7 is deliberately not closed: the next cell binds it to `model`, which
# the backtest uses for prediction.
tf.reset_default_graph()
op7 = Classifier_PonderDNC_BasicLSTM_L3(
    inputs= inputs, 
    targets= targets, 
    gather_list= gather_list, 
    hidden_size= 50, 
    memory_size= 50, 
    pondering_coefficient= 1e-1, 
    learning_rate= 1e-4)

op7.fit(training_iters = 100,
        display_step = 10,
        save_path = "model/ResidualPonderDNC_4.ckpt",
        restore_path = "model/ResidualPonderDNC_3.ckpt")

In [None]:
# The last trained instance is the classifier that init() hands to the
# backtest via context.model_classifier.
model = op7

#### 设置回测框架

In [None]:
import rqalpha.api as rqa
from rqalpha import run_func

def init(context):
    """Store the strategy settings on the backtest context.

    Attributes set on ``context``: the traded contract, the number of
    history bars requested per step, the bar frequency, the quote fields to
    fetch, a quantity setting, the factor function and the trained
    classifier. ``get_factors`` and ``model`` are taken from the notebook's
    global namespace.

    NOTE(review): the training data is read from 'hs300.csv' while the traded
    contract is '601933.XSHG' -- confirm this pairing is intended.
    """
    settings = {
        'contract': '601933.XSHG',
        'BarSpan': 200,
        'TransactionRate': '1d',
        'DataFields': ['datetime', 'open', 'close','high', 'low', 'volume'],
        'DefineQuantity': 5,
        'func_get_factors': get_factors,
        'model_classifier': model,
    }
    for attribute, value in settings.items():
        setattr(context, attribute, value)

In [None]:
def handle_bar(context, bar_dict):
    """Per-bar strategy step: predict the class of the latest bar and trade.

    Fetches the last ``context.BarSpan`` bars of ``context.contract``,
    computes the factor table, runs the classifier stored on the context and
    acts on the prediction for the most recent row:

    * 0 -> target position 0 % (sell everything),
    * 2 -> target position 100 % (fully invested),
    * 1 -> no order is placed.

    The three class probabilities of the latest row are plotted on every bar.

    Parameters
    ----------
    context : object
        Strategy context populated by ``init``.
    bar_dict : object
        Current bars supplied by the framework; not used here.
    """
    contract = context.contract

    # Fetch the most recent quotes for the contract.
    quotes = pd.DataFrame(rqa.history_bars(
        order_book_id=contract,
        bar_count=context.BarSpan,
        frequency=context.TransactionRate,
        fields=context.DataFields))

    # Compute the technical factors. `rolling=26` is passed explicitly so the
    # features match the ones the model was trained on (the data-preparation
    # cell calls get_factors with rolling=26); relying on the function's
    # default could silently feed the model different features.
    # NOTE(review): the 'datetime' field is passed to pd.to_datetime as-is;
    # if the framework returns integer-encoded dates, the parsed index is
    # not a real calendar date -- confirm (only the index is affected).
    factor_table = context.func_get_factors(
        index=pd.to_datetime(quotes['datetime']),
        Open=quotes['open'].values,
        Close=quotes['close'].values,
        High=quotes['high'].values,
        Low=quotes['low'].values,
        Volume=quotes['volume'].values,
        rolling=26,
        drop=True)
    model_inputs = np.expand_dims(np.array(factor_table), axis=1)

    # Predict and keep the class of the latest row.
    probability, classification = context.model_classifier.pred(model_inputs)
    flag = classification[-1][0]
    rqa.logger.info(str(flag))

    # Plot the estimated class probabilities of the latest row.
    latest_probability = probability[-1][0]
    rqa.plot("估空概率", latest_probability[0])
    rqa.plot("振荡概率", latest_probability[1])
    rqa.plot("估多概率", latest_probability[2])

    if flag == 0:
        # Bearish signal: sell.
        rqa.logger.info ('沽空')
        rqa.order_target_percent(contract, 0)
    elif flag == 2:
        # Bullish signal: buy.
        rqa.logger.info ('沽多')
        rqa.order_target_percent(contract, 1)
    # flag == 1 (range-bound): no order is placed.

In [None]:
# Backtest over calendar year 2016 with a stock account of 1e5.
# NOTE(review): the printed hs300.csv data spans 2014-01-02 .. 2016-12-30, so
# this period overlaps the dates the model was trained on -- presumably an
# in-sample check; confirm.
start_date = '2016-01-01'
end_date = '2017-01-01'
accounts = {'stock':1e5}

# Framework configuration: backtest period and accounts, info-level logging,
# and the sys_analyser mod enabled with plotting.
config = {
    'base':{'start_date':start_date, 'end_date':end_date, 'accounts':accounts},
    'extra':{'log_level':'info'},
    'mod':{'sys_analyser':{'enabled':True, 'plot':True}}
}

results = run_func(init=init, handle_bar=handle_bar, config=config)

In [None]:
# Backtest from 2017-01-01 to 2017-08-27 with a stock account of 1e5. This
# period starts after the last date shown in the printed hs300.csv data
# (2016-12-30).
# NOTE(review): the end date is not zero-padded ('2017-8-27'), unlike the
# other dates in this notebook -- confirm the framework parses it as intended.
start_date = '2017-01-01'
end_date = '2017-8-27'
accounts = {'stock':1e5}

# Same framework configuration as the previous run, with the new period.
config = {
    'base':{'start_date':start_date, 'end_date':end_date, 'accounts':accounts},
    'extra':{'log_level':'info'},
    'mod':{'sys_analyser':{'enabled':True, 'plot':True}}
}

results = run_func(init=init, handle_bar=handle_bar, config=config)