# In [None]:  (Jupyter cell marker left over from the notebook export; commented out so the module parses)
from __future__ import division
import pandas as pd
import warnings
import numpy as np
from datetime import datetime
import matplotlib.pyplot as plt
import talib as ta
from talib import abstract


from vnpy.trader.vtConstant import *
from vnpy.trader.app.ctaStrategy import CtaTemplate
import talib as ta

########################################################################
# 策略继承CtaTemplate
class MultiSignalStrategy(CtaTemplate):
    """CTA strategy that combines signals computed on several bar periods.

    * ``on5MinBar``  -- volatility filter, stored in ``filterCanTrade``.
    * ``on15MinBar`` -- crossover of two moving averages of a short ROC,
      stored in ``ROC_MA[symbol]``.
    * ``on30minBar`` -- CCI threshold, fast/slow MA crossover and crossover
      of two moving averages of a long ROC; entries, reversals and
      filter-driven exits are sent from here.
    * ``onBar``      -- trailing stop, take-profit and a single add-on order.
    """

    className = 'MultiSignalStrategy'
    author = 'Qiudaokai'

    # Strategy parameters
    barPeriod = 200
    short_roc_period = 7
    roc_10_ma1_period = 6
    roc_10_ma2_period = 18
    signalMaPeriod = 20
    fastPeriod = 6
    slowPeriod = 22
    roc_period = 40
    roc_ma1_period = 11
    roc_ma2_period = 27
    cciPeriod = 10
    cciThrehold = 30

    # Volatility filter: on5MinBar sets it to 1 (tradable) or -1 (too quiet);
    # 0 means it has not been evaluated yet.
    filterCanTrade = 0
    volPeriod = 60
    lowVolThrehold = 0.001

    # Risk-control parameters
    trailingPct = 0.04
    stopRatio = 0.04
    nPos = 0
    fixsize = 3
    lot = 3
    stopLossPct = 0.04
    takeProfitPct = 0.3

    # Position management
    addPct = 0.005
    addMultipler = 1

    # Latest signal values, reported through writeCtaLog
    short_roc_check = 0
    long_roc_check = 0
    maTrend_check = 0
    cci_check = 0

    # Strategy variables. These class-level dicts are only placeholders so
    # that every name in varList resolves before onInit runs; onInit replaces
    # them with fresh per-symbol dicts on the instance.
    transactionPrice = {}  # last open-fill price per symbol
    intraTradeHighDict = {}
    intraTradeLowDict = {}
    CCI = {}
    ROC_MA = {}
    stopProtect = {}
    shortStop = {}
    longStop = {}

    # Parameter list
    paramList = [
                'short_roc_period','roc_period','roc_10_ma1_period','roc_10_ma2_period',
                'fastPeriod','slowPeriod','cciPeriod','cciThrehold','roc_ma1_period','roc_ma2_period',
                'trailingPct', 'stopRatio', 'volPeriod', 'lowVolThrehold', 'signalMaPeriod']

    # Variable list
    varList = ['barPeriod','transactionPrice','intraTradeHighDict', 'intraTradeLowDict',
              'shortStop', 'longStop','CCI', 'ROC_MA','short_roc_check','long_roc_check','maTrend_check','cci_check']

    # Sync list: names of the variables that are persisted to the database
    syncList = ['posDict', 'eveningDict']

    #----------------------------------------------------------------------
    def __init__(self, ctaEngine, setting):
        """Delegate construction to CtaTemplate."""
        super().__init__(ctaEngine, setting)

    #----------------------------------------------------------------------
    @staticmethod
    def _crossSignal(fast, slow, strict=False):
        """Classify the relation of two series over their last two points.

        Returns 1 when ``fast`` is above ``slow`` on the last point and was
        not above it on the previous one, -1 for the mirrored case, else 0.
        With ``strict=True`` the previous point must be strictly on the other
        side (equality does not count as a cross).
        """
        if fast[-1] > slow[-1]:
            wasBelow = fast[-2] < slow[-2] if strict else fast[-2] <= slow[-2]
            return 1 if wasBelow else 0
        if fast[-1] < slow[-1]:
            wasAbove = fast[-2] > slow[-2] if strict else fast[-2] >= slow[-2]
            return -1 if wasAbove else 0
        return 0

    #----------------------------------------------------------------------
    def onInit(self):
        """Initialise the strategy: build the per-symbol state dictionaries."""
        self.writeCtaLog(u'策略初始化')
        symbols = self.symbolList
        self.transactionPrice = {s: 0 for s in symbols}  # last open-fill price
        self.intraTradeHighDict = {s: 0 for s in symbols}
        self.intraTradeLowDict = {s: 999999 for s in symbols}
        self.ROC_MA = {s: 0 for s in symbols}
        self.CCI = {s: 0 for s in symbols}
        self.longStop = {s: 0 for s in symbols}
        self.shortStop = {s: 999999 for s in symbols}
        self.mail('策略启动')
        self.putEvent()

    #----------------------------------------------------------------------
    def onStart(self):
        """Start the strategy: log the start and publish a status event."""
        self.writeCtaLog(u'策略启动')
        self.putEvent()

    #----------------------------------------------------------------------
    def onStop(self):
        """Stop the strategy: log the stop and publish a status event."""
        self.writeCtaLog(u'策略停止')
        self.putEvent()

    #----------------------------------------------------------------------
    def onTick(self, tick):
        """Tick callback; this strategy does nothing on ticks."""
        pass

    #----------------------------------------------------------------------
    def onBar(self, bar):
        """Bar callback: log positions, then run risk control and add-on logic."""
        self.writeCtaLog('MultiSignalStrategy####5S####posDict:%s####'%(self.posDict))
        self.onBarRiskControl(bar)
        self.onBarPosition(bar)

    def onBarRiskControl(self, bar):
        """Apply the trailing stop and take-profit to the position in ``bar.vtSymbol``.

        While flat, the per-symbol high/low watermarks and stop levels are
        reset. While long (checked first) or short, the watermark and the
        trailing stop are updated and the whole position is closed when the
        close crosses the stop or the take-profit level.
        """
        symbol = bar.vtSymbol
        entryPrice = self.transactionPrice[symbol]
        longPos = self.posDict[symbol+'_LONG']
        shortPos = self.posDict[symbol+'_SHORT']
        # Without a recorded entry price the take-profit level would be 0 and
        # a long position would be closed on the very first bar.
        hasEntry = entryPrice > 0

        if longPos == 0 and shortPos == 0:
            self.intraTradeHighDict[symbol] = 0
            self.intraTradeLowDict[symbol] = 999999
            # Stop levels are per-symbol dicts (see onInit); never rebind them
            # to scalars.
            self.longStop[symbol] = 0
            self.shortStop[symbol] = 999999
        elif longPos > 0:
            self.intraTradeHighDict[symbol] = max(self.intraTradeHighDict[symbol], bar.high)
            self.longStop[symbol] = self.intraTradeHighDict[symbol] * (1 - self.trailingPct)
            if bar.close <= self.longStop[symbol]:
                self.writeCtaLog("LONG_stopLoss")
                self.mail('多头止损，已平多')
                self.cancelAll()
                self.sell(symbol, bar.close*0.98, longPos)
            elif hasEntry and bar.close > entryPrice * (1 + self.takeProfitPct):
                self.writeCtaLog("LONG_takeProfit")
                self.mail('多头止盈，已平多')
                self.cancelAll()
                self.sell(symbol, bar.close*0.98, longPos)
            self.writeCtaLog('longStop%s'%(self.longStop[symbol]))
        elif shortPos > 0:
            self.intraTradeLowDict[symbol] = min(self.intraTradeLowDict[symbol], bar.low)
            self.shortStop[symbol] = self.intraTradeLowDict[symbol] * (1 + self.trailingPct)
            if bar.close >= self.shortStop[symbol]:
                self.writeCtaLog("SHORT_stopLoss")
                self.mail('空头止损，已平空')
                self.cancelAll()
                self.cover(symbol, bar.close*1.02, shortPos)
            elif hasEntry and bar.close < entryPrice * (1 - self.takeProfitPct):
                self.writeCtaLog("SHORT_takeProfit")
                self.mail('空头止盈，已平空')
                self.cancelAll()
                self.cover(symbol, bar.close*1.02, shortPos)
            self.writeCtaLog('shortStop%s'%(self.shortStop[symbol]))

        self.putEvent()

    def onBarPosition(self, bar):
        """Add to a winning position once, after a move of ``addPct`` from the last open fill.

        ``nPos`` counts add-on orders and is reset to 0 while flat.
        NOTE(review): ``nPos`` is a single counter shared by all symbols --
        confirm this is intended when trading more than one symbol.
        """
        symbol = bar.vtSymbol
        lastOrder = self.transactionPrice[symbol]
        longPos = self.posDict[symbol+'_LONG']
        shortPos = self.posDict[symbol + "_SHORT"]
        # transactionPrice starts at 0 and is only set by onTrade; skip the
        # add-on check until a fill price exists (also avoids dividing by 0).
        hasEntry = lastOrder > 0

        if longPos == 0 and shortPos == 0:
            self.nPos = 0
        elif longPos > 0 and self.nPos < 1:
            if hasEntry and (bar.close/lastOrder-1) >= self.addPct:
                self.nPos += 1
                self.buy(symbol, bar.close*1.02, int(self.lot*(self.addMultipler**self.nPos)))
                self.writeCtaLog('LongAddPosition')
                self.mail('多头加仓')
        elif shortPos > 0 and self.nPos < 1:
            if hasEntry and (lastOrder/bar.close-1) >= self.addPct:
                self.nPos += 1
                self.short(symbol, bar.close*0.98, int(self.lot*(self.addMultipler**self.nPos)))
                self.writeCtaLog('ShortAddPosition')
                self.mail('空头加仓')

        self.putEvent()

    def on5MinBar(self, bar):
        """Update the volatility filter from the "5m" array manager.

        The smallest of STDDEV, ATR and the high-low range over ``volPeriod``
        is compared with ``lowVolThrehold`` times the last close:
        ``filterCanTrade`` becomes 1 when it is at least that large, else -1.
        """
        symbol = bar.vtSymbol
        am5 = self.getArrayManager(symbol, "5m")

        if not am5.inited:
            return

        std = ta.STDDEV(am5.close, self.volPeriod)
        atr = ta.ATR(am5.high, am5.low, am5.close, self.volPeriod)
        rangeHL = ta.MAX(am5.high, self.volPeriod) - ta.MIN(am5.low, self.volPeriod)
        minVol = min(std[-1], atr[-1], rangeHL[-1])
        lowFilterRange = am5.close[-1] * self.lowVolThrehold

        self.filterCanTrade = 1 if minVol >= lowFilterRange else -1

        self.writeCtaLog(u'Vol_filter_check: %s'%(self.filterCanTrade))
        self.putEvent()

    #----------------------------------------------------------------------
    def on15MinBar(self, bar):
        """Compute the short-ROC moving-average crossover on the "15m" series.

        ``ROC_MA[symbol]`` and ``short_roc_check`` are overwritten with 1 on
        an upward cross, -1 on a downward cross and 0 otherwise.
        """
        symbol = bar.vtSymbol
        am15 = self.getArrayManager(symbol, "15m")

        if not am15.inited:
            return

        ROC_10 = ta.ROC(am15.close, self.short_roc_period)
        roc_10_ma1 = ta.MA(ROC_10, self.roc_10_ma1_period)
        roc_10_ma2 = ta.MA(ROC_10, self.roc_10_ma2_period)

        direction = self._crossSignal(roc_10_ma1, roc_10_ma2)
        self.ROC_MA[symbol] = direction
        self.short_roc_check = direction

        self.writeCtaLog(u'short_roc_ma1: %s, short_roc_ma2: %s, short_roc_check: %s'%(roc_10_ma1[-2:],roc_10_ma2[-2:],self.short_roc_check))
        # Publish the status update event
        self.putEvent()

    def on30minBar(self, bar):
        """Evaluate the "30m" signals and send entry, reversal and exit orders.

        Computes the CCI state, adds the fast/slow MA cross to
        ``ROC_MA[symbol]``, detects the long-ROC moving-average cross and then
        applies the entry/exit rules below.
        """
        symbol = bar.vtSymbol
        am30 = self.getArrayManager(symbol, "30m")

        if not am30.inited:
            return

        # CCI state: 1 above +threshold, -1 below -threshold, else 0
        cci = ta.CCI(am30.high, am30.low, am30.close, self.cciPeriod)
        if cci[-1] > self.cciThrehold:
            cciState = 1
        elif cci[-1] < -self.cciThrehold:
            cciState = -1
        else:
            cciState = 0
        self.CCI[symbol] = cciState
        self.cci_check = cciState

        self.writeCtaLog(u'cci: %s, cci_check: %s'%(cci[-1], self.cci_check))

        # Fast/slow MA cross (strict on the previous bar) is added on top of
        # the value on15MinBar last stored in ROC_MA[symbol].
        fastMa = ta.MA(am30.close, self.fastPeriod)
        slowMa = ta.MA(am30.close, self.slowPeriod)
        maTrend = self._crossSignal(fastMa, slowMa, strict=True)
        self.ROC_MA[symbol] += maTrend
        self.maTrend_check = maTrend

        self.writeCtaLog(u'fastMa: %s, slowMa: %s, maTrend_check: %s'%(fastMa[-2:],slowMa[-2:],self.maTrend_check))

        # Trigger condition: cross of the two moving averages of the long ROC
        ROC = ta.ROC(am30.close, self.roc_period)
        roc_ma1 = ta.MA(ROC, self.roc_ma1_period)
        roc_ma2 = ta.MA(ROC, self.roc_ma2_period)
        self.long_roc_check = self._crossSignal(roc_ma1, roc_ma2)
        breakUp = self.long_roc_check == 1
        breakDn = self.long_roc_check == -1

        self.writeCtaLog(u'long_roc_ma1: %s, long_roc_ma2: %s, long_roc_check: %s'%(roc_ma1[-2:],roc_ma2[-2:],self.long_roc_check))

        Signal = self.ROC_MA[symbol] + self.CCI[symbol]
        longPos = self.posDict[symbol + "_LONG"]
        shortPos = self.posDict[symbol + "_SHORT"]

        # Entry and exit conditions
        # NOTE(review): the long side requires Signal >= 2 while the short
        # side only requires CCI == -1 -- confirm the asymmetry is intended.
        if Signal >= 2 and breakUp and self.filterCanTrade == 1 and longPos == 0:
            if shortPos == 0:
                self.buy(symbol, bar.close * 1.02, self.lot)
                self.mail('开多')
            elif shortPos > 0:
                # Reverse: close the short, then open the long
                self.cancelAll()
                self.cover(symbol, bar.close * 1.02, shortPos)
                self.buy(symbol, bar.close * 1.02, self.lot)
                self.mail('平空开多')
        elif self.CCI[symbol] == -1 and breakDn and self.filterCanTrade == 1 and shortPos == 0:
            if longPos == 0:
                self.short(symbol, bar.close * 0.98, self.lot)
                self.mail('开空')
            elif longPos > 0:
                # Reverse: close the long, then open the short
                self.cancelAll()
                self.sell(symbol, bar.close * 0.98, longPos)
                self.short(symbol, bar.close *0.98, self.lot)
                self.mail('平多开空')
        elif self.filterCanTrade == -1:
            # Volatility too low: flatten whatever is open
            if longPos > 0:
                self.cancelAll()
                self.nPos = 0
                self.sell(symbol, bar.close*0.98, longPos)
                self.mail('平多')
            elif shortPos > 0:
                self.cancelAll()
                self.nPos = 0
                self.cover(symbol, bar.close*1.02, shortPos)
                self.mail('平空')

        # Publish the status update event
        self.putEvent()

    #----------------------------------------------------------------------
    def onOrder(self, order):
        """Order-update callback; ignored, no fine-grained order control is needed."""
        pass

    #----------------------------------------------------------------------
    def onTrade(self, trade):
        """Trade callback: record the fill price of opening trades and mail it."""
        symbol = trade.vtSymbol
        if trade.offset == OFFSET_OPEN:  # only opening fills set the reference price
            self.transactionPrice[symbol] = trade.price
            self.mail('已成交，成交价格为：%s' % trade.price)

    #----------------------------------------------------------------------
    def onStopOrder(self, so):
        """Stop-order callback; unused."""
        pass
