In [1]:
import os
import tushare as ts
import pickle
import pandas as pd
import numpy as np

In [2]:
def get_data(tickers, START_DATE, END_DATE, filename='data.pk', token='f5f8a1adb5cae8459b6ef2c43a10a528fd040cb0dc1fa44da610a83b', adj=None):
    """
    Build a dict of per-ticker DataFrames, using a local pickle file as a cache.

    :param tickers: list of str, stock codes, e.g. '600536.SH'
    :param START_DATE: start date forwarded to ``ts.pro_bar`` for missing tickers
    :param END_DATE: end date forwarded to ``ts.pro_bar`` for missing tickers
    :param filename: path of the pickle cache (default 'data.pk'); rename per task
    :param token: tushare API token; the default normally does not need changing
    :param adj: price-adjustment mode forwarded to ``ts.pro_bar``; when None the
        module-level ``ADJ`` constant is used (the original behaviour)
    :body
        A ticker counts as cached as soon as its key exists in the pickle;
        the date range of the cached frame is NOT checked against START_DATE/END_DATE.
    :return: dict (key: ticker, value: DataFrame indexed by trade date, ascending)
    """
    # NOTE(security): an API token is hard-coded as the default argument. It is kept
    # only so existing callers keep working; prefer passing a token read from the
    # environment and rotating this one.

    # Load the local pickle cache, if any. os.path.exists also works when
    # `filename` contains a directory component (os.listdir() only saw the cwd).
    Data = dict()
    if os.path.exists(filename):
        # NOTE(security): pickle.load executes arbitrary code; only load trusted files.
        with open(filename, 'rb') as f:
            Data = pickle.load(f)

    # Work out which tickers still have to be downloaded.
    todo_keys = [ticker for ticker in tickers if ticker not in Data]
    print('To do: ', todo_keys)
    if len(todo_keys) == 0:
        return Data

    if adj is None:
        adj = ADJ  # module-level default defined in the configuration cell

    # Download the raw bars for every missing ticker.
    ts.set_token(token)
    for code in todo_keys:
        try:
            info_df = ts.pro_bar(ts_code=code, adj=adj, start_date=START_DATE, end_date=END_DATE)
            if info_df is None:
                raise ValueError('no data returned')
            info_df.index = pd.to_datetime(info_df['trade_date'])
            Data[code] = info_df.sort_index(ascending=True)
        except Exception as err:
            # Best effort: report the failing ticker (and why) and keep going.
            print('Error in retrieving ', code, ' data...', err)
            continue

    # Create or refresh the local pickle cache.
    with open(filename, 'wb') as f:
        pickle.dump(Data, f)

    return Data

def get_tradeDates(start_date, end_date):
    """
    Return the open trading days between ``start_date`` and ``end_date``.

    Queries ``pro.trade_cal`` (module-level tushare client), keeps the rows whose
    ``is_open`` equals 1 and returns their ``cal_date`` values converted to
    datetimes, as a Series with a fresh 0..n-1 index.
    """
    calendar = pro.trade_cal(exchange='', start_date=start_date, end_date=end_date)
    open_days = calendar.loc[calendar['is_open'] == 1, 'cal_date']
    return pd.to_datetime(open_days).reset_index(drop=True)

In [3]:
class TradeRequest:
    """A request to change the position of ``ticker`` by ``qty`` at ``timestamp``."""

    def __init__(self, timestamp, ticker, qty):
        self.timestamp, self.ticker, self.qty = timestamp, ticker, qty

    def __str__(self):
        # Fields separated by two spaces; ticker must already be a str.
        return "  ".join((str(self.timestamp), self.ticker, str(self.qty)))

    def get(self):
        """Return the request as a ``(timestamp, ticker, qty)`` tuple."""
        fields = (self.timestamp, self.ticker, self.qty)
        return fields

In [4]:
class Instrument:
    """
    State of a single ticker: position, last known price, market value,
    tradability flag and a log of position changes.
    """

    def __init__(self, ticker):
        # Core state
        self.ticker = ticker
        self.position = 0
        self.lastPrice = 0
        self.value = 0
        self.timestamp = np.nan
        # Not tracked yet: previous/current-day OHLC, cost basis, realized/unrealized PnL.
        self.isTradable = True
        # History of (datetime, qty, price) tuples, one per position change
        self.records = []

    def __update_value(self):
        # Market value = position * last price
        self.value = self.position * self.lastPrice

    def update_lastPrice(self, val, timestamp, isTradable):
        '''
        used in WATCHLIST: update the latest price, its timestamp and tradability
        (the market value is NOT recomputed here)
        '''
        self.lastPrice, self.timestamp, self.isTradable = val, timestamp, isTradable

    def retrieve_from_watchlist(self, rhs):
        '''
        used in Portfolio class
        copy price, timestamp and tradability from the WATCHLIST instrument `rhs`,
        then recompute the market value
        '''
        for field in ('lastPrice', 'timestamp', 'isTradable'):
            setattr(self, field, getattr(rhs, field))
        self.__update_value()

    def change_position(self, qty, datetime):
        '''
        Intended to be private; only called by Portfolio objects
        Adds `qty` to the position, recomputes the value and logs the change
        '''
        new_position = self.position + qty
        self.position = new_position
        self.__update_value()
        entry = (datetime, qty, self.lastPrice)
        self.records.append(entry)

    def snapshot(self):
        """Return a one-row DataFrame with ticker, position and lastPrice."""
        columns = ('ticker', 'position', 'lastPrice')
        return pd.DataFrame({name: [getattr(self, name)] for name in columns})

In [5]:
class WATCHLIST:
    """
    Holds one Instrument per watched ticker and refreshes their prices
    from ``src`` for a given timestamp.
    """

    def __init__(self, timestamp, src):
        self.timestamp = timestamp
        self.instruments = dict()
        # request_data indexes src as src[ticker].loc[timestamp]; the original note
        # says src could also be a url (another service), which is not implemented here.
        self.src = src

    def add(self, ticker):
        """Start watching ``ticker`` (replaces any existing entry for it)."""
        self.instruments[ticker] = Instrument(ticker)

    def request_data(self, ticker, timestamp):
        """
        Private; called by update. Returns ``src[ticker].loc[timestamp]``.
        This is the place to dispatch on the kind of source if more are added.
        """
        return self.src[ticker].loc[timestamp]

    def update(self, timestamp):
        """
        Move the watchlist to ``timestamp`` and refresh every instrument.

        If the data for a ticker can be fetched, its ``close`` becomes the last
        price and it is marked tradable; otherwise the previous price is kept and
        the instrument is marked not tradable.
        """
        self.timestamp = timestamp
        for ticker, ins in self.instruments.items():
            try:
                sub = self.request_data(ticker, timestamp)
                price = sub.close
            except Exception:
                # Missing ticker/date or malformed row: keep the old price.
                # `Exception` (not a bare except) so Ctrl-C / SystemExit still propagate.
                ins.update_lastPrice(ins.lastPrice, timestamp, isTradable=False)
            else:
                ins.update_lastPrice(price, timestamp, isTradable=True)

In [6]:
class Portfolio:
    """
    A cash balance plus a dict of Instruments, with a history of total values.
    """

    def __init__(self, name, initCash):
        self.name = name
        self.cash = initCash
        self.instruments = dict()
        self.totalValue = initCash
        self.totalValues = []   # one entry appended per update_instruments call
        self.datetimes = []     # NOTE(review): never populated by any method of this class

    def update_instruments(self, rhs):
        '''
        A wrapper function that calls instrument.retrieve_from_watchlist(...)
        Also updates totalValue (cash + value of every held instrument) and
        appends it to totalValues.

        rhs: mapping ticker -> watchlist instrument; must contain every held ticker.
        '''
        self.totalValue = self.cash
        for ticker, ins in self.instruments.items():
            ins.retrieve_from_watchlist(rhs[ticker])
            self.totalValue += ins.value
        self.totalValues.append(self.totalValue)

    def update_position(self, tradeRequest, rhs):
        '''
        input a trade request, assume it's successful
        A wrapper function that call instrument.change_position(qty, timestamp)

        Returns False (and changes nothing) when the instrument is not tradable,
        True otherwise. Cash is reduced by qty * lastPrice; no cash check is made.
        '''
        timestamp, ticker, qty = tradeRequest.get()
        if ticker not in self.instruments:
            # Create the portfolio's own Instrument instead of storing the
            # watchlist's object: sharing it would let trades mutate the
            # watchlist (and any other portfolio using it).
            self.add_instrument(ticker, rhs)
        ins = self.instruments[ticker]
        if not ins.isTradable:
            print('The instrument is not tradable now:', timestamp, ',', ticker, ',', qty)
            return False
        ins.change_position(qty, timestamp)
        self.cash -= qty * ins.lastPrice
        return True

    def add_instrument(self, ticker, rhs, position=0):
        '''
        Add an instrument;
        Intended to have position 0 when selecting the stock and waiting for the best timing
        Also allows to have nonzero position, meaning owning a stock from other sources

        The initial position is applied before the price is pulled from the
        watchlist, so the instrument's value reflects it. Cash is not changed.
        '''
        ins = Instrument(ticker)
        ins.position = position
        ins.retrieve_from_watchlist(rhs[ticker])
        self.instruments[ticker] = ins
        

In [7]:
def rebalance_EW(portfolio, datetime, rhs):
    """
    Rebalance ``portfolio`` to equal weights and return the trades that failed.

    The portfolio's totalValue is split evenly across its instruments; each
    target position is rounded down to a multiple of 100 shares. For every
    instrument whose target differs from its current position, a TradeRequest
    stamped with ``datetime`` is sent to ``portfolio.update_position``.

    :param portfolio: the Portfolio to rebalance
    :param datetime: timestamp put on the generated trade requests
    :param rhs: mapping ticker -> watchlist instrument, forwarded to update_position
    :return: list of TradeRequest that were rejected (e.g. instrument not
        tradable today), to be retried later; empty for an empty portfolio
    """
    # Snapshot the holdings so the iteration is independent of later dict changes.
    holdings = list(portfolio.instruments.items())
    if not holdings:
        return []

    # Cash amount allotted to each instrument.
    each = portfolio.totalValue / len(holdings)

    todo_list = []
    for ticker, ins in holdings:
        # No usable price (0, negative or NaN): a target cannot be computed.
        if not ins.lastPrice > 0:
            continue
        # Target number of shares, rounded down to a whole lot of 100.
        target_position = int(each / ins.lastPrice // 100 * 100)
        # Shares to buy (positive) or sell (negative).
        delta = target_position - ins.position
        if delta == 0:
            continue
        request = TradeRequest(datetime, ticker, delta)
        if not portfolio.update_position(request, rhs):
            todo_list.append(request)
    return todo_list

In [8]:
def simple_trade(portfolio, tradeRequest, rhs):
    """
    Execute a single trade request on ``portfolio``.

    Delegates to ``portfolio.update_position(tradeRequest, rhs)`` and returns
    its result unchanged.
    """
    outcome = portfolio.update_position(tradeRequest, rhs)
    return outcome

In [9]:
pd.set_option('display.max_columns', None)  # show all columns

START_DATE = '20160101'  # start date
END_DATE = '20211010'  # end date
ADJ = 'qfq'  # forward-adjusted prices; read as a global by get_data
# tushare pro client, read as a global by get_tradeDates.
# NOTE(review): no token is passed here — presumably this relies on a token
# saved by an earlier ts.set_token call; confirm it works on a fresh machine.
pro = ts.pro_api()

In [10]:
# Basic listing data (ts_code, symbol, name, area, industry, list_date)
stock_basic = pro.stock_basic(exchange='', list_status='L', fields='ts_code,symbol,name,area,industry,list_date')
# Hand-picked stock pool: 12 symbols, given without the exchange suffix
code_list = ['600096', '600328', '000852', '601118', '603019', '600792', '300369', '300188', '600536', '000878',
             '600456', '600683']
# Keep only the hand-picked symbols; isin is the vectorized form of the membership test
StockList_basic = stock_basic[stock_basic['symbol'].isin(code_list)].reset_index(drop=True)
# Full tushare codes (with exchange suffix) of the selected stocks
TS_CODE_POOL = list(StockList_basic.ts_code)

In [11]:
# Per-ticker price frames for the pool (from the local pickle cache, downloading what is missing)
DATA = get_data(TS_CODE_POOL, START_DATE, END_DATE)
# Open trading days in the backtest window
TRADE_DAYS = get_tradeDates(START_DATE, END_DATE)

To do:  []


In [12]:
#INIT_AMOUNT = 5
#INIT_CASH = 200000 / INIT_AMOUNT

#JUDGE_TRADE_DAY_DF_LIST = []

In [13]:
# As time advances, the watchlist is updated first: it refreshes the instruments'
# prices and flags whether each one is tradable today.
TODOLIST = [] # trade requests left unfinished from the previous day

p1 = Portfolio('buy and hold', 100000)
watchlist = WATCHLIST(TRADE_DAYS[0], src=DATA)

# Watch every ticker in the pool
for ticker in TS_CODE_POOL:
    watchlist.add(ticker)

In [14]:
for today in TRADE_DAYS:
    # Refresh prices/tradability first, then mark the portfolio to market.
    watchlist.update(today)
    p1.update_instruments(watchlist.instruments)

    # First deal with the trade requests that failed on a previous day;
    # the ones that fail again are queued for the next day.
    tmpTODO = []
    while len(TODOLIST) > 0:
        tradeRequest = TODOLIST.pop()
        # Fixed: the watchlist's instruments are the price source
        # (the original passed the undefined name WATCHLIST_dict).
        isSuccessful = simple_trade(p1, tradeRequest, rhs=watchlist.instruments)
        if not isSuccessful:
            tmpTODO.append(tradeRequest)
    TODOLIST.extend(tmpTODO)

    # buy and hold: build the position once, on the first trading day
    if today == TRADE_DAYS[0]:
        # To hold the whole pool instead, add every ticker in DATA here.
        p1.add_instrument('300188.SZ', rhs=watchlist.instruments)
        res = rebalance_EW(p1, today, watchlist.instruments)
        TODOLIST.extend(res)

In [15]:
# sanity check
# With a single stock, equal weighting means a full position, so very little cash should remain
p1.cash

21.160000000003492

In [16]:
# Position changes logged for the held stock, as (datetime, qty, price) tuples
p1.instruments['300188.SZ'].records

[(Timestamp('2016-01-04 00:00:00'), 3600, 27.7719)]

In [17]:
# Portfolio total value recorded at each update_instruments call (one per trading day)
pd.Series(p1.totalValues)

0       100000.00
1        90030.88
2        92155.96
3        82953.64
4        78637.24
          ...    
1396     60249.16
1397     60177.16
1398     58701.16
1399     59385.16
1400     60393.16
Length: 1401, dtype: float64

In [18]:
# Independent recomputation: 3600 shares (the qty in the records above) * close + remaining cash.
# NOTE(review): this series has 1400 rows while totalValues has 1401, so the two
# only line up by position at the start — presumably the stock has no bar on one
# trading day; confirm before comparing the tails.
pd.Series(3600 * DATA['300188.SZ'].close.values + p1.cash)

0       100000.00
1        90030.88
2        92155.96
3        82953.64
4        78637.24
          ...    
1395     59385.16
1396     60393.16
1397     60933.16
1398     60141.16
1399     60393.16
Length: 1400, dtype: float64