# In [None]:
from __future__ import division

from vnpy.trader.vtObject import VtBarData
from vnpy.trader.vtConstant import EMPTY_STRING
from vnpy.trader.app.ctaStrategy.ctaTemplate import (CtaTemplate,
                                                     BarGenerator,
                                                     ArrayManager)
import talib as ta


########################################################################
class RegressionStrategy(CtaTemplate):
    """Mean-reversion entries gated by a regression trend and an ATR-slope filter.

    Incoming 1-minute bars are aggregated into three timeframes:

    * 60-minute bars set ``Trend`` from the linear-regression slopes of the
      close over ``trendFastWindow`` and ``trendSlowWindow`` bars.
    * 30-minute bars set ``Volatility`` to 1 when the regression slope (over
      ``volatilityThresholdWindow`` bars) of ATR(``volatilityWindow``) is
      positive, else 0.
    * 15-minute bars produce entries: the close crossing back inside a band of
      +/- ``revertPercent`` around MA(``SignalFastWindow``), in the direction
      of ``Trend`` and only while ``Volatility == 1`` and flat.

    On every 1-minute bar an open position is closed when price moves
    ``stopRatio`` against, or ``profitMultiplier * stopRatio`` in favour of,
    the last fill price (``transactionPrice``).

    NOTE(review): the position keys (``Longpos``/``Shortpos``) come from
    ``symbolList[0]`` only, and ``Trend``/``Volatility``/``transactionPrice``
    are single scalars, so the strategy is effectively single-symbol.
    """
    className = 'RegressionStrategy'
    author = 'ChannelCMT'

    # Symbols (vtSymbol) traded by this strategy; listed in paramList, so
    # presumably supplied through ``setting``.
    symbolList = []
    # Position volumes, read with the keys ``Longpos`` / ``Shortpos``;
    # presumably populated by ctaEngine.initPosition() in onInit.
    # NOTE(review): a class-level dict is shared by all instances unless the
    # engine/template rebinds it per instance -- confirm.
    posDict = {}

    # posDict keys for the long / short side, set in __init__.
    Longpos = EMPTY_STRING
    Shortpos = EMPTY_STRING

    # Strategy parameters
    amWindow = 100                  # bars kept by each ArrayManager
    trendFastWindow = 10            # fast LINEARREG_SLOPE period (60-min close)
    trendSlowWindow = 40            # slow LINEARREG_SLOPE period (60-min close)
    volatilityWindow = 10           # ATR period (30-min bars)
    volatilityThresholdWindow = 10  # LINEARREG_SLOPE period applied to the ATR series
    SignalFastWindow = 30           # MA period for the 15-min reversion band
    revertPercent = 0.04            # band half-width as a fraction of the MA
    stopRatio = 0.04                # stop-loss distance as a fraction of entry price
    profitMultiplier = 0.7          # take-profit distance = profitMultiplier * stopRatio
    initDays = 2                    # days of history replayed in onInit
    fixedSize = 1                   # order volume for entries

    # Strategy variables
    Trend = 0              # 1 / -1 (0 only when a slope is NaN), from 60-min bars
    Volatility = 0         # 1 when the ATR slope is positive, else 0 (30-min bars)
    transactionPrice = 0   # price of the most recent fill; reference for exits

    # Parameter names
    paramList = ['name',
                 'className',
                 'author',
                 'symbolList',
                 'amWindow',
                 'trendFastWindow',
                 'trendSlowWindow',
                 'volatilityWindow',
                 'volatilityThresholdWindow',
                 'SignalFastWindow',
                 'revertPercent',
                 'stopRatio',
                 'profitMultiplier',
                 'initDays',
                 'fixedSize']

    # Variable names
    varList = ['inited',
               'trading',
               'posDict',
               'Volatility',
               'Trend',
               'transactionPrice']

    # Variables persisted across restarts. Every name here must exist on the
    # strategy. transactionPrice is included because the exit checks in onBar
    # are measured against it; while it is 0 the take-profit test always fires.
    syncList = ['posDict',
                'transactionPrice']

    # ----------------------------------------------------------------------
    def __init__(self, ctaEngine, setting):
        """Create per-symbol bar generators and array managers.

        Raises IndexError if ``symbolList`` is empty, because the position keys
        are derived from its first entry.
        """
        super(RegressionStrategy, self).__init__(ctaEngine, setting)

        # Position-dictionary keys for the first (effectively only) symbol.
        symbol = self.symbolList[0]
        self.Longpos = symbol.replace('.', '_') + "_LONG"
        self.Shortpos = symbol.replace('.', '_') + "_SHORT"

        # Tick -> 1-minute bar synthesis used by onTick; every completed
        # minute bar is delivered to onBar.
        self.bgDict = {
            sym: BarGenerator(self.onBar)
            for sym in self.symbolList
        }

        # 1-minute -> N-minute aggregation, fed from onBar via updateBar.
        self.bg60Dict = {
            sym: BarGenerator(self.onBar, 60, self.on60minBar)
            for sym in self.symbolList
        }

        self.bg30Dict = {
            sym: BarGenerator(self.onBar, 30, self.on30minBar)
            for sym in self.symbolList
        }

        self.bg15Dict = {
            sym: BarGenerator(self.onBar, 15, self.on15minBar)
            for sym in self.symbolList
        }

        # Rolling price history per timeframe.
        self.am60Dict = {
            sym: ArrayManager(size=self.amWindow)
            for sym in self.symbolList
        }

        self.am30Dict = {
            sym: ArrayManager(size=self.amWindow)
            for sym in self.symbolList
        }

        self.am15Dict = {
            sym: ArrayManager(size=self.amWindow)
            for sym in self.symbolList
        }

    # ----------------------------------------------------------------------
    def onInit(self):
        """Initialise positions, then replay ``initDays`` of history through onBar."""
        self.writeCtaLog(u'%s策略初始化' % self.name)
        self.ctaEngine.initPosition(self)
        # Warm up the indicators by replaying historical bars.
        initData = self.loadBar(self.initDays)
        for bar in initData:
            self.onBar(bar)

        self.putEvent()

    # ----------------------------------------------------------------------
    def onStart(self):
        """Log the start of the strategy and publish a status event."""
        self.writeCtaLog(u'%s策略启动' % self.name)
        self.putEvent()

    # ----------------------------------------------------------------------
    def onStop(self):
        """Log the stop of the strategy and publish a status event."""
        self.writeCtaLog(u'%s策略停止' % self.name)
        self.putEvent()

    # ----------------------------------------------------------------------
    def onTick(self, tick):
        """Feed a tick into its symbol's 1-minute bar generator (see __init__)."""
        self.bgDict[tick.vtSymbol].updateTick(tick)

    # ----------------------------------------------------------------------
    def onBar(self, bar):
        """Aggregate a 1-minute bar and check stop-loss / take-profit exits."""
        symbol = bar.vtSymbol

        # Each generator calls its on<N>minBar callback when an N-minute bar
        # completes.
        self.bg60Dict[symbol].updateBar(bar)
        self.bg30Dict[symbol].updateBar(bar)
        self.bg15Dict[symbol].updateBar(bar)

        # Exit check: stop at stopRatio, target at profitMultiplier * stopRatio,
        # both relative to transactionPrice. The held volume is closed (not a
        # single lot) so entries larger than 1 are fully exited. Limit prices
        # 10% through the close are presumably meant to fill like market orders.
        longVolume = self.posDict[self.Longpos]
        if longVolume > 0:
            stopPrice = self.transactionPrice * (1 - self.stopRatio)
            targetPrice = self.transactionPrice * (1 + self.profitMultiplier * self.stopRatio)
            if bar.close < stopPrice or bar.close > targetPrice:
                self.cancelAll()
                self.sell(symbol, bar.close * 0.9, longVolume)
        else:
            shortVolume = self.posDict[self.Shortpos]
            if shortVolume > 0:
                stopPrice = self.transactionPrice * (1 + self.stopRatio)
                targetPrice = self.transactionPrice * (1 - self.profitMultiplier * self.stopRatio)
                if bar.close > stopPrice or bar.close < targetPrice:
                    self.cancelAll()
                    self.cover(symbol, bar.close * 1.1, shortVolume)

    # ----------------------------------------------------------------------
    def on60minBar(self, bar):
        """Update ``Trend`` from fast/slow regression slopes of the 60-min close."""
        symbol = bar.vtSymbol

        am60 = self.am60Dict[symbol]
        am60.updateBar(bar)
        if not am60.inited:
            return

        fastSlope = ta.LINEARREG_SLOPE(am60.close, self.trendFastWindow)[-1]
        slowSlope = ta.LINEARREG_SLOPE(am60.close, self.trendSlowWindow)[-1]

        # For finite slopes the first two branches are exhaustive (if
        # fast > slow fails then fast <= slow holds), so Trend is 0 only when
        # a slope is NaN.
        # NOTE(review): if the intent was "fast > slow AND fast > 0" (and the
        # mirror for -1), these `or`s should be `and` -- confirm.
        if fastSlope > slowSlope or fastSlope > 0:
            self.Trend = 1
        elif fastSlope <= slowSlope or slowSlope < 0:
            self.Trend = -1
        else:
            self.Trend = 0

    # ----------------------------------------------------------------------
    def on30minBar(self, bar):
        """Update ``Volatility`` from the regression slope of ATR on 30-min bars."""
        symbol = bar.vtSymbol

        am30 = self.am30Dict[symbol]
        am30.updateBar(bar)
        if not am30.inited:
            return

        volatilityArray = ta.ATR(am30.high, am30.low, am30.close, self.volatilityWindow)
        volatilityThreshold = ta.LINEARREG_SLOPE(volatilityArray, self.volatilityThresholdWindow)
        # A NaN slope matches neither branch and leaves Volatility unchanged.
        if volatilityThreshold[-1] > 0:
            self.Volatility = 1
        elif volatilityThreshold[-1] <= 0:
            self.Volatility = 0

    # ----------------------------------------------------------------------
    def on15minBar(self, bar):
        """Open a position when price re-enters the MA band in the trend direction."""
        symbol = bar.vtSymbol

        am15 = self.am15Dict[symbol]
        am15.updateBar(bar)
        if not am15.inited:
            return

        signalMeanArray = ta.MA(am15.close, self.SignalFastWindow)
        signalUpperArray = signalMeanArray * (1 + self.revertPercent)
        signalLowerArray = signalMeanArray * (1 - self.revertPercent)

        # Enter only when flat on both sides.
        if self.posDict[self.Longpos] == 0 and self.posDict[self.Shortpos] == 0:
            # Long: previous close below the lower band, current close back at/above it.
            if self.Trend == 1 and self.Volatility == 1:
                if am15.close[-1] >= signalLowerArray[-2] and am15.close[-2] < signalLowerArray[-2]:
                    self.buy(symbol, bar.close * 1.1, self.fixedSize)
            # Short: previous close above the upper band, current close back at/below it.
            if self.Trend == -1 and self.Volatility == 1:
                if am15.close[-1] <= signalUpperArray[-2] and am15.close[-2] > signalUpperArray[-2]:
                    self.short(symbol, bar.close * 0.9, self.fixedSize)

        # Publish a status-update event.
        self.putEvent()

    # ----------------------------------------------------------------------
    def onOrder(self, order):
        """Order updates are not used by this strategy."""
        pass

    # ----------------------------------------------------------------------
    def onTrade(self, trade):
        """Record the fill price as the reference for the exit checks in onBar."""
        self.transactionPrice = trade.price
        self.putEvent()

    # ----------------------------------------------------------------------
    def onStopOrder(self, so):
        """Stop-order updates are not used by this strategy."""
        pass