In [127]:
import datetime
import getpass
import os

import numpy as np
import pandas as pd
import pymysql
from pyecharts import options as opts
from pyecharts.charts import Kline, Candlestick, Line

In [2]:
class MySQLDataHandler(object):
    """
    Reads candle data from a MySQL database and keeps it per symbol.

    The handler loads the ``CANDLE60S`` table, splits it by ``instrument_id``
    and aligns every requested symbol on one shared timestamp index, so the
    bars can later be consumed in a manner similar to a live trading feed.

    Attributes set by ``__init__``:
        connection - the open pymysql connection (kept so it can be closed).
        cursor - cursor used for all queries.
        symbol_list - symbols requested by the caller.
        symbol_data - dict symbol -> DataFrame, filled by _query_convert_data.
        latest_symbol_data - dict symbol -> list, reset by _query_convert_data.
        continue_backtest - flag initialised to True.
    """
    def __init__(self, symbol_list, host, user, passwd, db):
        """
        Initialises the MySQL data handler by connecting to the
        MySQL database.

        Parameters:
        symbol_list - List of instrument ids to extract from the table.
        host - MySQL host.
        user - MySQL username.
        passwd - MySQL password.
        db - MySQL database name.
        """
        # Keyword arguments are used because pymysql.connect no longer
        # accepts positional arguments (PyMySQL 1.0+); `password`/`database`
        # replace the deprecated `passwd`/`db` keyword names.
        self.connection = pymysql.connect(
            host=host, user=user, password=passwd, database=db
        )
        self.cursor = self.connection.cursor()
        self.symbol_list = symbol_list

        self.symbol_data = {}
        self.latest_symbol_data = {}
        self.continue_backtest = True

    def _query_convert_data(self):
        """
        Load all rows of CANDLE60S and populate ``self.symbol_data``.

        For every symbol in ``self.symbol_list`` the matching rows are stored
        as a DataFrame indexed by ``timestamp`` (the ``id`` column is dropped),
        then all frames are reindexed onto the union of their timestamps with
        forward filling. ``self.latest_symbol_data[symbol]`` is reset to an
        empty list.

        Raises:
        Any exception raised by the cursor while executing or fetching the
        query is printed and re-raised.

        Note: ``reindex(method='pad')`` needs a monotonic index whenever a
        symbol's own index differs from the combined one.
        """
        sql = "SELECT * FROM CANDLE60S"
        try:
            self.cursor.execute(sql)
            res = self.cursor.fetchall()
        except Exception as e:
            # Re-raise: continuing would only fail later on the unbound
            # `res` and hide the real database error.
            print(e)
            raise

        # list(...) so a tuple-of-rows result is handled like a list of rows.
        symbol_data = (
            pd.DataFrame(
                list(res),
                columns=['id', 'timestamp', 'open', 'high', 'low', 'close', 'volume', 'instrument_id'],
            )
            .set_index('timestamp')
            .drop(['id'], axis=1)
        )

        comb_index = None
        for s in self.symbol_list:
            self.symbol_data[s] = symbol_data[symbol_data['instrument_id'] == s]

            # Combine the index to pad forward values
            if comb_index is None:
                comb_index = self.symbol_data[s].index
            else:
                # Index.union returns a new index; it must be re-assigned,
                # otherwise only the first symbol's timestamps are kept.
                comb_index = comb_index.union(self.symbol_data[s].index)
            # Reset the latest bars of this symbol to an empty list
            self.latest_symbol_data[s] = []

        # Reindex the dataframes onto the shared timestamp index
        for s in self.symbol_list:
            self.symbol_data[s] = self.symbol_data[s].reindex(index=comb_index, method='pad')

In [3]:
# Connection settings are read from the environment so that no credentials
# are stored in the notebook. Host, user and database fall back to the local
# defaults; the password is prompted for when MYSQL_PASSWORD is not set.
bars = MySQLDataHandler(
    ['BTC-USDT'],
    os.environ.get('MYSQL_HOST', '127.0.0.1'),
    os.environ.get('MYSQL_USER', 'root'),
    os.environ.get('MYSQL_PASSWORD') or getpass.getpass('MySQL password: '),
    os.environ.get('MYSQL_DB', 'okex'),
)

In [4]:
# Run the CANDLE60S query and fill bars.symbol_data / bars.latest_symbol_data.
bars._query_convert_data()

### talib计算均线

In [61]:
import talib as ta

In [63]:
# Boolean mask over the BTC-USDT index: True for every timestamp that occurs
# again later, so that only the last row per timestamp is kept when negated.
x_axis_duplicated = bars.symbol_data['BTC-USDT'].index.duplicated(keep='last')

In [149]:
# Keep only the last row for each timestamp.
data = bars.symbol_data['BTC-USDT'].loc[~x_axis_duplicated]

In [65]:
# Copy of the de-duplicated candles with a 5-period talib moving average of
# the close price added as column 'MA_5'.
x = data.assign(MA_5=ta.MA(data['close'], timeperiod=5))

In [150]:
# Rolling means of the close price (5 and 13 bars) as plain lists for plotting.
ma5, ma13 = (
    data['close'].rolling(window).mean().values.tolist()
    for window in (5, 13)
)

---------

In [78]:
# Reduce the stored frame to the price columns, in the (open, close, low, high)
# order in which the rows are later handed to the Kline chart.
bars.symbol_data['BTC-USDT'] = bars.symbol_data['BTC-USDT'].loc[:, ['open', 'close', 'low', 'high']]

In [152]:
# Mask of timestamps that appear again later; negated it keeps the last row
# of every timestamp.
x_axis_duplicated = bars.symbol_data['BTC-USDT'].index.duplicated(keep='last')
# Axis labels: unique timestamps reformatted from ISO-8601 to 'MM-DD HH:MM'.
x_axis = [
    datetime.datetime.strptime(stamp, '%Y-%m-%dT%H:%M:%S.%fZ').strftime('%m-%d %H:%M')
    for stamp in bars.symbol_data['BTC-USDT'].index[~x_axis_duplicated].tolist()
]
# Candle rows matching x_axis, as nested lists for the chart.
data = bars.symbol_data['BTC-USDT'].loc[~x_axis_duplicated].values.tolist()

In [153]:
# Candlestick chart of the 60s candles, built step by step on one Kline object.
kline = Kline()
kline.add_xaxis(x_axis)
kline.add_yaxis(series_name='candle60s', y_axis=data)
# Single hard-coded 'Buy' marker at a fixed (time label, price) coordinate.
kline.set_series_opts(
    markpoint_opts=opts.MarkPointOpts(
        data=[opts.MarkPointItem(coord=['05-08 10:29', 9900], value='Buy')]
    )
)
kline.set_global_opts(
    title_opts=opts.TitleOpts(title="回测结果", pos_left="0"),
    xaxis_opts=opts.AxisOpts(type_="category", is_scale=True, boundary_gap=False),
    yaxis_opts=opts.AxisOpts(
        is_scale=True,
        splitarea_opts=opts.SplitAreaOpts(
            is_show=True,
            areastyle_opts=opts.AreaStyleOpts(opacity=1),
        ),
    ),
    # Slider zoom initially showing the first 20% of the data.
    datazoom_opts=[opts.DataZoomOpts(type_="slider", range_start=0, range_end=20)],
)
# Line chart carrying the two moving-average series on the same x axis.
line = Line()
line.add_xaxis(x_axis)
line.add_yaxis(
    series_name="MA5",
    y_axis=ma5,
    is_smooth=True,
    is_hover_animation=False,
    linestyle_opts=opts.LineStyleOpts(width=3, opacity=0.5),
    label_opts=opts.LabelOpts(is_show=False),
)
line.add_yaxis(
    series_name="MA13",
    y_axis=ma13,
    is_smooth=True,
    is_hover_animation=False,
    linestyle_opts=opts.LineStyleOpts(width=3, opacity=0.5),
    label_opts=opts.LabelOpts(is_show=False),
)
line.set_global_opts(
    xaxis_opts=opts.AxisOpts(type_="category"),
    yaxis_opts=opts.AxisOpts(is_scale=True),
)
# Layer the MA5/MA13 line chart over the candlestick chart so both are shown
# in one figure.
overlap_kline_line = kline.overlap(line)
# Render the combined chart inline as the output of this cell.
overlap_kline_line.render_notebook()