In [1]:
%matplotlib inline

In [2]:
import myalgotrade

# sample strategy

双均线只做多策略，在10分钟K线上进行开仓、平仓操作。
短期均线上穿长期均线（即短期均线位于长期均线之上）且空仓时，市价开多仓；短期均线下穿长期均线时，市价平仓。本例中短、长均线周期分别为5和40根K线。

In [3]:
# %load /Users/jnq/work/commodity/backtest_platform/myalgotrade/strategy/sample.py
from myalgotrade import strategy
from pyalgotrade.technical import ma
import multiprocessing


class SampleStrategy(strategy.StrategyBase):
    """Long-only dual moving-average strategy on a single instrument.

    On every bar, once both moving averages of the close price are available:
    * short SMA > long SMA and no open position -> open a 1-lot long position at market;
    * short SMA < long SMA and a position is open (with no exit pending) -> close it at market.
    At most one position is held at a time.

    ``params`` must contain ``'ma_short'`` and ``'ma_long'`` (SMA periods,
    converted with ``int()``).
    """

    def __init__(self, bar_feed, log_path, params, cash):
        super(SampleStrategy, self).__init__(bar_feed, log_path, params, cash_or_brk=cash)
        self.instrument = bar_feed.getDefaultInstrument()  # contract name
        # SMA periods, counted in bars of the feed's frequency (not calendar days).
        short_period = int(params['ma_short'])
        long_period = int(params['ma_long'])
        self.price_ds = bar_feed[self.instrument].getCloseDataSeries()  # close-price history
        # The SMA series are derived from price_ds and are updated as it grows.
        self.sma_short = ma.SMA(self.price_ds, short_period)
        self.sma_long = ma.SMA(self.price_ds, long_period)

    @classmethod
    def get_log_key(cls, params):
        """Build a log-name key from the parameter values.

        Values are ordered by sorted parameter name and joined with '-', e.g.
        {'ma_short': 5, 'ma_long': 40} -> '40-5'.
        """
        return '-'.join(str(params[key]) for key in sorted(params.keys()))

    @staticmethod
    def _format_fill(action, position, order):
        """Tab-separated fill line: action, instrument, shares, average fill price."""
        return '\t'.join(
            str(i) for i in (action, position.getInstrument(), position.getShares(), order.getAvgFillPrice()))

    def on_start(self):
        """Strategy start: print the broker's starting cash."""
        print('start cash: %s' % self.getBroker().getCash())

    def on_finish(self, bars):
        """Strategy end: print the broker's remaining cash."""
        print('end cash: %s' % self.getBroker().getCash())

    def on_enter_ok(self, position):
        """Called when the entry order of ``position`` has been filled."""
        # Fill logging is disabled; to enable it:
        # self.info(self._format_fill('enter', position, position.getEntryOrder()))

    def on_enter_canceled(self, position):
        """Called when the entry order of ``position`` was canceled."""
        print('enter canceled!')

    def on_exit_ok(self, position):
        """Called when the exit order of ``position`` has been filled."""
        # Fill logging is disabled; to enable it:
        # self.info(self._format_fill('exit', position, position.getExitOrder()))

    def on_exit_canceled(self, position):
        """Exit order canceled: report it and resubmit a market exit."""
        print('exit canceled!')
        position.exitMarket()

    def on_order_updated(self, order):
        """Order status updates are ignored by this strategy."""
        pass

    def on_bars(self, bars):
        """Called once per new bar; ``bars`` holds every instrument's bar for the same time."""
        sma_short = self.sma_short[-1]
        sma_long = self.sma_long[-1]
        # Not enough bars yet for the averages. Both are checked so that a
        # configuration with ma_short > ma_long never compares None with a price.
        if sma_short is None or sma_long is None:
            return

        positions = list(self.getActivePositions())
        if len(positions) > 1:
            raise Exception('we should at most have one position in this strategy.')

        if sma_short > sma_long:
            # Short SMA above long SMA while flat: open a long position.
            if not positions:
                shares = 1
                # Market-order entry; the third positional argument is goodTillCanceled in
                # pyalgotrade's enterLong. NOTE(review): confirm StrategyBase keeps that signature.
                self.enterLong(self.instrument, shares, True)
        elif sma_short < sma_long:
            # Short SMA below long SMA while holding: close the position at market,
            # unless an exit order is already pending.
            if len(positions) == 1:
                position = positions[0]
                if not position.exitActive():
                    position.exitMarket()




## 回测运行示例（SQL数据源）

In [4]:
# coding=utf-8
from datetime import datetime
from myalgotrade.util import dbutil
from myalgotrade.feed import Frequency, feed_manager
from myalgotrade import strategy
import multiprocessing
import pprint

# 参数定义见下个cell
def run_sample_sql(experiment_key, strategy_class, instrument, param, start, end, frequency, before_days=0):

    # 获取该品种的主力合约时间段
    feed_infos = dbutil.get_dominant_contract_infos(instrument, frequency, start, end, before_days) 
    print 'feed infos:'
    pprint.pprint(feed_infos)

    feed_mng = feed_manager.DataServerFeedManager(feed_infos) # feed管理器
    feeds_dict = feed_mng.get_feeds_by_range(start, end) #获取feed
    print 'feeds:'
    pprint.pprint(feeds_dict)

    # log key 用来标识每次回测的log文件，这里采用 experiment_key + instrument
    log_key = strategy.log_path_delimiter.join((experiment_key, instrument))
    result = strategy.run_strategy(strategy_class, feeds_dict, log_key, param, initial_cash=1000000, use_previous_cash=False)

#     process = multiprocessing.Process(target=result.analyze_result.plotEquityCurve, args=(log_key,))
#     process.start()  #notebook中无法开多进程，在其他ide中跑

    return result, log_key

In [5]:

# Backtest settings shared by the cells below. The instrument is passed
# separately so the same settings can be reused for other commodities.
args = {
    'experiment_key': 'tutorial',               # log identifier
    'strategy_class': SampleStrategy,           # strategy class to backtest
    'param': {'ma_short': 5, 'ma_long': 40},    # strategy parameters
    'start': datetime(2014, 1, 1),              # backtest start
    'end': datetime(2014, 6, 1),                # backtest end
    'frequency': Frequency.MINUTE * 10,         # input bar frequency: 10 minutes
    'before_days': 0,                           # days of data to fetch before a contract becomes dominant
}

result, log_key = run_sample_sql(instrument='SR', **args)

{'afterday': 0,
 'beforday': 0,
 'commodity': 'SR',
 'dataName': 'domInfo',
 'end': datetime.datetime(2014, 6, 1, 0, 0),
 'start': datetime.datetime(2014, 1, 1, 0, 0)}
feed infos:
{'SR1405': (600,
            datetime.datetime(2014, 1, 2, 0, 0),
            (datetime.datetime(2014, 6, 10, 0, 0), )),
 'SR1409': (600,
            datetime.datetime(2014, 2, 21, 0, 0),
            (datetime.datetime(2014, 6, 10, 0, 0), ))}


TypeError: can't compare datetime.datetime to pyodbc.Row

In [None]:
# Pretty-print the backtest result object.
pprint.pprint(result)

In [None]:
# Display the summary of the backtest result.
result.show_result()

In [None]:
# Fields of the backtest result:
#   day_summary_log - daily settlement log (cumulative profit for each day)
#   trade_log       - trade records
#   analyze_result  - the analysis result object
#   sub_records     - per-contract sub-records (two contracts in this example)
type(result.analyze_result)

In [None]:
# Plot the equity curve of the overall result, passing the log key returned by run_sample_sql.
result.analyze_result.plotEquityCurve(log_key)

In [None]:
# Pick one contract out of the per-contract sub-records.
sub_result = result.sub_records['SR1409']
sub_result

In [None]:
# Run the analysis on this contract's record, then display it.
# Presumably analyze() fills sub_result.analyze_result -- TODO confirm in myalgotrade.
sub_result.analyze()
sub_result

In [None]:
# Display the summary of the selected contract's result.
sub_result.show_result()

In [None]:
# Plot the equity curve of the selected contract.
sub_result.analyze_result.plotEquityCurve()

In [None]:
# Show the first lines of the trade log and of the daily summary log
# (IPython expands $result.<attr> into the shell command; presumably file paths).
!head $result.trade_log
!head $result.day_summary_log

In [None]:
# Backtest the same settings on several commodities. Preferably run this outside
# the notebook: the logs are long and numerous (options to trim them may come later).
# Distinct names keep this loop from overwriting `result` / `log_key` from the
# single-instrument run above, which the earlier cells rely on.
for instrument in ['SR', 'L', 'P', 'M', 'RB', 'RU']:
    instrument_result, instrument_log_key = run_sample_sql(instrument=instrument.upper(), **args)
    print(instrument_log_key)
    # Plot directly in the notebook; outside a notebook, run
    # instrument_result.analyze_result.plotEquityCurve in a multiprocessing.Process instead.
    instrument_result.analyze_result.plotEquityCurve(instrument_log_key)
print('done')



In [None]:
# Combine the saved results of several instruments into one result.
def get_combine_result(experiment_key, param_key, combine_set):
    """Combine the saved backtest records of several instruments.

    Each instrument's record is rebuilt from the log named by joining
    (experiment_key, instrument, param_key) with strategy.log_path_delimiter.
    The records are merged with strategy.combine_result under the name
    experiment_key + delimiter + param_key.

    Prints experiment_key and combine_set and returns the combined record's
    analyze_result.
    """
    delimiter = strategy.log_path_delimiter
    records = {
        instrument: strategy.StrategyRecord.construct_by_log_name(
            delimiter.join((experiment_key, instrument, param_key)))
        for instrument in combine_set
    }
    combined_analysis = strategy.combine_result(
        records, delimiter.join((experiment_key, param_key))).analyze_result
    print(experiment_key)
    print(combine_set)
    return combined_analysis

In [None]:
combine_set = ('SR', 'L', 'M', 'P')  # instruments whose backtests are combined
# Key built from the strategy parameters; the strategy uses it for its log names.
param_key = SampleStrategy.get_log_key(args['param'])
result = get_combine_result(args['experiment_key'], param_key, combine_set)
result.show_result()
result.plotEquityCurve()