In [2]:
import backtrader as bt
import pandas as pd
import numpy as np
import datetime
from copy import deepcopy

# 一、读取日度行情表

表内字段就是 Backtrader 默认情况下要求输入的 7 个字段： 'datetime' 、'open'、'high'、'low'、'close'、'volume'、'openinterest'，外加一个 'sec_code' 股票代码字段。

In [3]:
# Load the daily bar table, parsing the 'datetime' column into timestamps,
# then display it (pandas truncates long frames automatically).
daily_price = pd.read_csv(
    "./data/daily_price.csv",
    parse_dates=['datetime'],
)
daily_price

Unnamed: 0,datetime,sec_code,open,high,low,close,volume,openinterest
0,2019-01-02,600466.SH,33.064891,33.496709,31.954503,32.386321,10629352,0
1,2019-01-02,603228.SH,50.660230,51.458513,50.394136,51.120778,426147,0
2,2019-01-02,600315.SH,148.258423,150.480132,148.258423,149.558935,2138556,0
3,2019-01-02,000750.SZ,49.512579,53.154883,48.715825,51.561375,227557612,0
4,2019-01-02,002588.SZ,36.608672,36.608672,35.669988,35.763857,2841517,0
...,...,...,...,...,...,...,...,...
255967,2021-01-28,600717.SH,121.489201,122.011736,120.705400,120.966667,6022213,0
255968,2021-01-28,300558.SZ,134.155888,137.600704,130.700970,131.569750,5330301,0
255969,2021-01-28,600171.SH,39.774873,39.830040,38.864630,38.947380,12354183,0
255970,2021-01-28,600597.SH,47.190201,49.243025,46.250355,46.423484,32409940,0


In [4]:
def select_stock(prices: pd.DataFrame, sec_code: str) -> pd.DataFrame:
    """Extract one stock's bars from the long daily table.

    Args:
        prices: long table with a 'sec_code' column, a 'datetime' column and
            the OHLCV/openinterest columns.
        sec_code: stock code to keep, e.g. '600466.SH'.

    Returns:
        The rows for ``sec_code`` indexed by 'datetime' in ascending order,
        with the 'sec_code' column dropped.
    """
    return (
        prices.loc[prices['sec_code'] == sec_code]
        .set_index('datetime')
        .drop(columns=['sec_code'])
        # The data feed reads rows in index order, so make the date order explicit.
        .sort_index()
    )


# Data sets for the two stocks 600466.SH and 603228.SH.
data1 = select_stock(daily_price, '600466.SH')
data2 = select_stock(daily_price, '603228.SH')
data2

Unnamed: 0_level_0,open,high,low,close,volume,openinterest
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
2019-01-02,50.660230,51.458513,50.394136,51.120778,426147,0
2019-01-03,50.609059,51.049137,50.107573,50.639762,492071,0
2019-01-04,50.199683,51.171950,49.278588,50.455543,665486,0
2019-01-07,50.864918,51.417575,50.353199,50.967262,689444,0
2019-01-08,51.110544,52.174920,50.250855,50.527183,931211,0
...,...,...,...,...,...,...
2021-01-22,59.955312,59.955312,57.478496,58.174461,6410959,0
2021-01-25,58.113052,58.133522,57.110045,57.498966,3445027,0
2021-01-26,57.498966,57.498966,54.305716,55.124498,7340180,0
2021-01-27,55.595298,55.943280,54.960742,55.595298,2665407,0


In [6]:
# 导入指标库
import backtrader.indicators as btind # 导入策略分析模块

# 一、 建议在 __init__() 中提前计算指标

In [None]:
class MyStrategy(bt.Strategy):
    """Demonstrates precomputing indicators and signals in ``__init__``.

    All indicators and the combined buy signal are built once, as line
    objects, in ``__init__``; ``next`` only reads the signal's current value.
    """

    def __init__(self):
        # Precompute the indicators up front.
        sma1 = btind.SimpleMovingAverage(self.data)
        ema1 = btind.ExponentialMovingAverage(self.data)
        close_over_sma = self.data.close > sma1
        close_over_ema = self.data.close > ema1
        sma_ema_diff = sma1 - ema1
        # Combined entry signal: 1 when all three conditions hold, else 0.
        # It must live on ``self``: a plain local here is out of scope in
        # ``next`` and would raise NameError there.
        self.buy_sig = bt.And(close_over_sma, close_over_ema, sma_ema_diff > 0)

    def next(self):
        # Read the precomputed signal at the current bar.
        if self.buy_sig[0]:
            self.buy()

# 二、计算指标时的各种简写形式

In [7]:
class TestStrategy(bt.Strategy):
    """Builds the same 5-period SMA in four equivalent ways and prints them."""

    def __init__(self):
        # Shortest form: no data argument, so the first data feed is used.
        self.sma1 = btind.SimpleMovingAverage(period=5)
        # First data feed passed explicitly; SMA is the short alias of SimpleMovingAverage.
        self.sma2 = btind.SMA(self.data, period=5)
        # Close line of the first data feed.
        self.sma3 = btind.SMA(self.data.close, period=5)
        # Fully spelled-out form: line 0 of datas[0].
        self.sma4 = btind.SMA(self.datas[0].lines[0], period=5)

    def next(self):
        # Date of the current bar.
        print('datetime', self.datas[0].datetime.date(0))
        # Last three values of each SMA, oldest first.
        named_smas = (
            ('sma1', self.sma1),
            ('sma2', self.sma2),
            ('sma3', self.sma3),
            ('sma4', self.sma4),
        )
        for label, sma in named_smas:
            print(label, sma.get(ago=0, size=3))
        
# Backtest both stocks over the full sample window.
cerebro = bt.Cerebro()
st_date = datetime.datetime(2019, 1, 2)
end_date = datetime.datetime(2021, 1, 28)
# The first feed added is the strategy's self.data / datas[0].
datafeed1 = bt.feeds.PandasData(dataname=data1, fromdate=st_date, todate=end_date)
cerebro.adddata(datafeed1, name='600466.SH')
datafeed2 = bt.feeds.PandasData(dataname=data2, fromdate=st_date, todate=end_date)
cerebro.adddata(datafeed2, name='603228.SH')
cerebro.addstrategy(TestStrategy)
result = cerebro.run()

datetime 2019-01-08
sma1 array('d', [nan, nan, 33.015540696])
sma2 array('d', [nan, nan, 33.015540696])
sma3 array('d', [nan, nan, 33.015540696])
sma4 array('d', [nan, nan, 33.015540696])
datetime 2019-01-09
sma1 array('d', [nan, 33.015540696, 33.286968908])
sma2 array('d', [nan, 33.015540696, 33.286968908])
sma3 array('d', [nan, 33.015540696, 33.286968908])
sma4 array('d', [nan, 33.015540696, 33.286968908])
datetime 2019-01-10
sma1 array('d', [33.015540696, 33.286968908, 33.62008535])
sma2 array('d', [33.015540696, 33.286968908, 33.62008535])
sma3 array('d', [33.015540696, 33.286968908, 33.62008535])
sma4 array('d', [33.015540696, 33.286968908, 33.62008535])
datetime 2019-01-11
sma1 array('d', [33.286968908, 33.62008535, 33.546059473999996])
sma2 array('d', [33.286968908, 33.62008535, 33.546059473999996])
sma3 array('d', [33.286968908, 33.62008535, 33.546059473999996])
sma4 array('d', [33.286968908, 33.62008535, 33.546059473999996])
datetime 2019-01-14
sma1 array('d', [33.62008535, 33

# 三、调用指标时的各种简写形式

In [8]:
class TestStrategy(bt.Strategy):
    """Shows the shorthand forms for reading line and indicator values in ``next``."""

    def __init__(self):
        self.sma5 = btind.SimpleMovingAverage(period=5)  # 5-day moving average
        self.sma10 = btind.SimpleMovingAverage(period=10)  # 10-day moving average
        # 1.0 while the 5-day MA is above the 10-day MA, else 0.0. This is a
        # level condition, not a crossover event: it stays 1.0 on every bar
        # where sma5 > sma10 (btind.CrossOver detects an actual crossing).
        self.buy_sig = self.sma5 > self.sma10

    def next(self):
        # Date of the current bar.
        print('datetime', self.datas[0].datetime.date(0))
        # `line[0]` is the current value; printing the bare object shows its repr.
        print('close', self.data.close[0], self.data.close)
        print('sma5', self.sma5[0], self.sma5)
        print('sma10', self.sma10[0], self.sma10)
        print('buy_sig', self.buy_sig[0], self.buy_sig)
        # Inside next(), comparing line objects directly compares their current
        # values, so `close > sma5` and `close[0] > sma10` are both valid forms.
        # These test "above", not "crossed above".
        if self.data.close > self.sma5:
            print('------收盘价高于5日均线------')
        if self.data.close[0] > self.sma10:
            print('------收盘价高于10日均线------')
        if self.buy_sig:
            print('------ buy ------')
        
# Single-stock backtest (600466.SH) over the full sample window.
cerebro2 = bt.Cerebro()
st_date = datetime.datetime(2019, 1, 2)
end_date = datetime.datetime(2021, 1, 28)
datafeed1 = bt.feeds.PandasData(dataname=data1, fromdate=st_date, todate=end_date)
cerebro2.adddata(datafeed1, name='600466.SH')
cerebro2.addstrategy(TestStrategy)
result = cerebro2.run()


datetime 2019-01-15
close 33.06489128 <backtrader.linebuffer.LineBuffer object at 0x7f4dc71a5100>
sma5 33.18826774 <backtrader.indicators.sma.SimpleMovingAverage object at 0x7f4dc71a5760>
sma10 33.101904218 <backtrader.indicators.sma.SimpleMovingAverage object at 0x7f4dc71a5070>
buy_sig 1.0 <backtrader.linebuffer.LinesOperation object at 0x7f4dc7199340>
------ buy ------
datetime 2019-01-16
close 32.63307367 <backtrader.linebuffer.LineBuffer object at 0x7f4dc71a5100>
sma5 32.966190112 <backtrader.indicators.sma.SimpleMovingAverage object at 0x7f4dc71a5760>
sma10 33.12657951 <backtrader.indicators.sma.SimpleMovingAverage object at 0x7f4dc71a5070>
buy_sig 0.0 <backtrader.linebuffer.LinesOperation object at 0x7f4dc7199340>
datetime 2019-01-17
close 32.0778796 <backtrader.linebuffer.LineBuffer object at 0x7f4dc71a5100>
sma5 32.682424254 <backtrader.indicators.sma.SimpleMovingAverage object at 0x7f4dc71a5760>
sma10 33.151254802 <backtrader.indicators.sma.SimpleMovingAverage object at 0x7f4d

# 四、好用的运算指标

In [11]:
import backtrader.indicators as btind # 导入策略分析模块
class TestStrategy(bt.Strategy):
    """Compares backtrader's logic/aggregation helpers with plain-Python equivalents.

    Each helper is built once in ``__init__`` as a line object. ``next`` prints
    its current value next to the same quantity computed by hand on the current bar.
    """

    def __init__(self):
        self.sma5 = btind.SimpleMovingAverage(period=5)  # 5-day moving average
        self.sma10 = btind.SimpleMovingAverage(period=10)  # 10-day moving average
        # bt.And: 1 when every condition holds; 0 as soon as any one fails.
        self.And = bt.And(self.data > self.sma5, self.data > self.sma10, self.sma5 > self.sma10)
        # bt.Or: 1 when at least one condition holds; 0 when none do.
        self.Or = bt.Or(self.data > self.sma5, self.data > self.sma10, self.sma5 > self.sma10)
        # bt.If(a, b, c): b on bars where condition a holds, otherwise c.
        self.If = bt.If(self.data > self.sma5, 1000, 5000)
        # bt.All: same truth logic as bt.And.
        self.All = bt.All(self.data > self.sma5, self.data > self.sma10, self.sma5 > self.sma10)
        # bt.Any: same truth logic as bt.Or.
        self.Any = bt.Any(self.data > self.sma5, self.data > self.sma10, self.sma5 > self.sma10)
        # bt.Max: largest of the inputs at each bar.
        self.Max = bt.Max(self.data, self.sma10, self.sma5)
        # bt.Min: smallest of the inputs at each bar.
        self.Min = bt.Min(self.data, self.sma10, self.sma5)
        # bt.Sum: sum of the inputs at each bar.
        self.Sum = bt.Sum(self.data, self.sma10, self.sma5)
        # bt.Cmp(a, b): three-way comparison -> 1 if a > b, -1 if a < b, 0 if equal.
        self.Cmp = bt.Cmp(self.data, self.sma5)

    def next(self):
        print('---------- datetime', self.data.datetime.date(0), '------------------')
        print('close:', self.data[0], 'ma5:', self.sma5[0], 'ma10:', self.sma10[0])
        print('close>ma5', self.data > self.sma5, 'close>ma10', self.data > self.sma10, 'ma5>ma10', self.sma5 > self.sma10)
        print('self.And', self.And[0], self.data > self.sma5 and self.data > self.sma10 and self.sma5 > self.sma10)
        print('self.Or', self.Or[0], self.data > self.sma5 or self.data > self.sma10 or self.sma5 > self.sma10)
        print('self.If', self.If[0], 1000 if self.data > self.sma5 else 5000)
        print('self.All', self.All[0], self.data > self.sma5 and self.data > self.sma10 and self.sma5 > self.sma10)
        print('self.Any', self.Any[0], self.data > self.sma5 or self.data > self.sma10 or self.sma5 > self.sma10)
        print('self.Max', self.Max[0], max([self.data[0], self.sma10[0], self.sma5[0]]))
        print('self.Min', self.Min[0], min([self.data[0], self.sma10[0], self.sma5[0]]))
        print('self.Sum', self.Sum[0], sum([self.data[0], self.sma10[0], self.sma5[0]]))
        # Hand-computed three-way compare. Unlike `1 if a > b else -1`, this
        # also yields 0 on a tie, so it matches bt.Cmp on every bar.
        close, ma5 = self.data[0], self.sma5[0]
        print('self.Cmp', self.Cmp[0], (close > ma5) - (close < ma5))
        
# Single-stock backtest (600466.SH) over the full sample window.
cerebro3 = bt.Cerebro()
st_date = datetime.datetime(2019, 1, 2)
end_date = datetime.datetime(2021, 1, 28)
datafeed1 = bt.feeds.PandasData(dataname=data1, fromdate=st_date, todate=end_date)
cerebro3.adddata(datafeed1, name='600466.SH')

cerebro3.addstrategy(TestStrategy)
result = cerebro3.run()

---------- datetime 2019-01-15 ------------------
close: 33.06489128 ma5: 33.18826774 ma10: 33.101904218
close>ma5 False close>ma10 False ma5>ma10 True
self.And 0.0 False
self.Or 1.0 True
self.If 5000.0 5000
self.All 0.0 False
self.Any 1.0 True
self.Max 33.18826774 33.18826774
self.Min 33.06489128 33.06489128
self.Sum 99.355063238 99.355063238
self.Cmp -1.0 -1
---------- datetime 2019-01-16 ------------------
close: 32.63307367 ma5: 32.966190112 ma10: 33.12657951
close>ma5 False close>ma10 False ma5>ma10 False
self.And 0.0 False
self.Or 0.0 False
self.If 5000.0 5000
self.All 0.0 False
self.Any 0.0 False
self.Max 33.12657951 33.12657951
self.Min 32.63307367 32.63307367
self.Sum 98.72584329200001 98.72584329200001
self.Cmp -1.0 -1
---------- datetime 2019-01-17 ------------------
close: 32.0778796 ma5: 32.682424254 ma10: 33.151254802
close>ma5 False close>ma10 False ma5>ma10 False
self.And 0.0 False
self.Or 0.0 False
self.If 5000.0 5000
self.All 0.0 False
self.Any 0.0 False
self.Max 33.1

# 五、如何对齐不同周期的指标

In [None]:
class MonthlyAlignedStrategy(bt.Strategy):
    """Template: compare a daily feed with an indicator computed on a monthly feed.

    Assumes two feeds are added to Cerebro: ``data0`` with daily bars and
    ``data1`` with monthly bars (TODO: confirm the feed setup when using it).
    This cell only defines the class; nothing runs until it is added to a
    Cerebro and run.
    """

    def __init__(self):
        # Any indicator can go here. Because it is computed on the monthly
        # feed, self.month is monthly as well. SMA is just an example.
        self.month = btind.SMA(self.data1, period=3)
        # Option 1: align a single line. Calling the line with () aligns the
        # monthly line to the daily data before comparing.
        self.sellsignal = self.data0.close < self.month.lines[0]()
        # Option 2: align the whole indicator once, then use any of its lines.
        self.month_ = self.month()
        self.signal = self.data0.close < self.month_.lines[0]

# Usage per this tutorial: run strategies like this one with
#     cerebro.run(runonce=False)

# 六、在 Backtrader 中调用 TA-Lib 库

In [None]:
class TALibStrategy(bt.Strategy):
    """Builds the same indicators through TA-Lib (bt.talib) and through backtrader's own library.

    The indicator objects are kept as attributes so ``next`` can use and
    compare them. ``bt.talib`` presumably needs the TA-Lib Python package
    installed — TODO confirm in the environment.
    """

    def __init__(self):
        # 5-day simple moving average: TA-Lib version vs. built-in version.
        self.sma_talib = bt.talib.SMA(self.data.close, timeperiod=5)
        self.sma_bt = bt.indicators.SMA(self.data, period=5)
        # 25-period Bollinger Bands: TA-Lib version vs. built-in version.
        # Note the different keyword names: TA-Lib uses `timeperiod`, backtrader `period`.
        self.bbands_talib = bt.talib.BBANDS(self.data, timeperiod=25)
        self.bbands_bt = bt.indicators.BollingerBands(self.data, period=25)

# 七、自定义新指标

In [None]:
class MyInd(bt.Indicator):
    """Skeleton of a custom indicator. Replace the placeholder names and values.

    ``lines`` (and usually ``params``) are declared as class attributes.
    ``__init__``, ``next`` and ``once`` are each optional ways of filling the lines.
    """

    # Output line names. Keep the trailing comma: a one-element tuple needs it.
    lines = ('line1', 'line2',)
    # (name, default) pairs. Keep the trailing comma for the same reason.
    params = (('param1', 5),)

    def __init__(self):
        '''Optional: define the lines from other lines/indicators.'''
        pass

    def next(self):
        '''Optional: compute the values for the current bar.'''
        pass

    def once(self, start, end):
        '''Optional: compute the values for bars in [start, end) in one call.'''
        pass

    # Plot configuration; empty in this template.
    plotinfo = dict()
    plotlines = dict()

In [None]:
class DummyInd(bt.Indicator):
    """Toy indicator: one line holding max(0.0, value) on every bar.

    The three methods show three alternative ways of producing the same line:
    in ``__init__`` from an operation object, bar by bar in ``next``, and in
    one pass over the underlying array in ``once``.
    """

    # Output line, accessible as self.lines.dummyline, self.l.dummyline or self.dummyline.
    lines = ('dummyline',)
    # Parameter, accessible as self.params.value or self.p.value.
    params = (('value', 5),)

    def __init__(self):
        # Bind the line to a bt.Max operation.
        self.l.dummyline = bt.Max(0.0, self.p.value)

    def next(self):
        # Write the current bar's value.
        self.l.dummyline[0] = max(0.0, self.p.value)

    def once(self, start, end):
        # Fill the underlying array for bars in [start, end).
        dummy_array = self.l.dummyline.array
        for i in range(start, end):  # `xrange` does not exist in Python 3
            dummy_array[i] = max(0.0, self.p.value)

In [None]:
class My_MACD(bt.Indicator):
    """MACD built from exponential moving averages.

    Lines:
        macd:   EMA(period_me1) minus EMA(period_me2) of the input data.
        signal: EMA(period_signal) of the macd line.
        histo:  macd minus signal.

    Params:
        period_me1:    period of the first (shorter, by default) EMA; default 12.
        period_me2:    period of the second EMA; default 26.
        period_signal: period of the signal-line EMA; default 9.
    """

    lines = ('macd', 'signal', 'histo')
    params = (('period_me1', 12),
              ('period_me2', 26),
              ('period_signal', 9),)

    def __init__(self):
        # The bare name `EMA` is never imported in this notebook, so use the
        # indicators module alias `btind`.
        me1 = btind.EMA(self.data, period=self.p.period_me1)
        me2 = btind.EMA(self.data, period=self.p.period_me2)
        self.l.macd = me1 - me2
        self.l.signal = btind.EMA(self.l.macd, period=self.p.period_signal)
        self.l.histo = self.l.macd - self.l.signal