In [None]:
# V2版本，支持多个数据同时回测

In [None]:
import pandas as pd
import os
import glob
from functools import lru_cache  # 缓存装饰器
import logging
import traceback
from pathlib import Path

In [None]:
import sys, pathlib
# 把 tools 的绝对路径插到最前面
sys.path.insert(0, str(pathlib.Path('__file__').resolve().parent.parent / 'Quant_base'))
import Drawpic
import CLbasement
import Basement

In [None]:
# Instruments to backtest; each must match a CSV file stem in CONFIG['data_path'].
testlist = ['510050', '159819', '512660', '159928', '512880']
# testlist = ['SH.600150']

# Central backtest settings.
CONFIG = dict(
    sdate='',  # start date; '' means the earliest date present in the data
    edate='',  # end date; '' means the latest date present in the data
    # data_path=r'E:\python_vm\quant_code\data\A股\天粒度',
    data_path=r'E:\python_vm\quant_code\data\ETF\day',
    stock_list=[],  # codes to load; [] means every CSV file in data_path
    initial_money=100000,  # starting cash
)

LOG_FORMAT = '%(asctime)s [%(levelname)s] %(message)s'
logging.basicConfig(
    filename='backtest.log',
    # filemode='w',  # uncomment to truncate the log on every run
    level=logging.INFO,  # record INFO and above
    format=LOG_FORMAT,
)

In [None]:
#数据加载模块
class DataLoader:
    """Load daily bars for the configured instruments and build the trading calendar.

    Everything is read once in ``__init__``. Afterwards ``self.data`` holds the
    concatenated bars, restricted to ``COLUMNS``, and ``self.trade_cal`` holds the
    sorted list of distinct ``datetime`` values.
    """

    # Columns kept from every per-instrument file, in this order.
    COLUMNS = ['code', 'datetime', 'open', 'close', 'low', 'high', 'volume', 'amount']

    def __init__(self, CONFIG, reader=None):
        """
        Args:
            CONFIG: dict with keys 'data_path', 'stock_list', 'sdate', 'edate'.
                An empty 'stock_list' means every *.csv file in 'data_path'.
                An empty 'sdate'/'edate' means no lower/upper date bound.
            reader: optional callable ``reader(filepath) -> DataFrame`` that reads one
                file. Defaults to ``Basement.read_data_local``; pass e.g.
                ``pd.read_csv`` to read plain CSV files.
        """
        self.data_path = CONFIG['data_path']
        # Codes are the CSV file stems unless an explicit list is configured.
        self.stock_list = CONFIG['stock_list'] or [p.stem for p in Path(self.data_path).glob('*.csv')]
        self.sdate = CONFIG['sdate']
        self.edate = CONFIG['edate']
        self._reader = reader
        # File stem (text before '.csv') -> full path, for every CSV in data_path.
        self.files_map = {
            os.path.basename(f).split('.csv')[0]: f
            for f in glob.glob(os.path.join(self.data_path, "*.csv"))
        }
        # Load all data once, up front.
        self.data, self.trade_cal = self._load_data()
        logging.info(f"DataLoader初始化完成: 加载天粒度数据 {self.sdate} 至 {self.edate}, 共 {len(self.trade_cal)} 个交易日")

    def _read_file(self, filepath):
        """Read one instrument file with the configured reader."""
        if self._reader is not None:
            return self._reader(filepath)
        return Basement.read_data_local(filepath)

    def _load_data(self):
        """Load every configured instrument and clip the bars to [sdate, edate].

        Returns:
            (stockdf, trade_cal): the concatenated bars (columns ``COLUMNS``) and the
            sorted distinct ``datetime`` values. Both are empty when nothing loaded.
        """
        frames = []
        for code in self.stock_list:
            filepath = self.files_map.get(code)
            if not filepath:
                print(f'{code}无数据')
                continue  # no file for this code: skip it
            try:
                # Select the columns per file, so a malformed file is skipped on its
                # own instead of polluting the frames that were already loaded.
                frames.append(self._read_file(filepath)[self.COLUMNS])
            except Exception as e:
                logging.error(f"读取天粒度数据 {code} 出错: {str(e)}")

        # Concatenate once: concatenating inside the loop is quadratic.
        if frames:
            stockdf = pd.concat(frames, axis=0, ignore_index=True)
        else:
            # Keep the expected schema so the date filters below still work.
            stockdf = pd.DataFrame(columns=self.COLUMNS)

        if self.sdate != '':
            stockdf = stockdf[stockdf['datetime'] >= self.sdate]
        if self.edate != '':
            stockdf = stockdf[stockdf['datetime'] <= self.edate]

        trade_cal = sorted(stockdf['datetime'].drop_duplicates().tolist())
        return stockdf, trade_cal


In [None]:
#仓位管理模块
class PositionManager:
    """Cash, open positions and a trade log for one backtest.

    ``positions`` maps codename -> [hands held, price]. One hand is 100 shares.
    ``price`` is the last fill price, or the close written by ``update_positions``;
    ``get_equities`` values positions at that price. The average cost of each open
    position is kept separately in ``cost_price``, so realised profit stays correct
    even after ``update_positions`` has re-priced the positions.
    """

    def __init__(self, CONFIG, ori_data):
        """
        Args:
            CONFIG: configuration dict; only 'initial_money' (starting cash) is read.
            ori_data: bar DataFrame with 'code', 'datetime' and 'close' columns, read
                by ``update_positions``.
        """
        self.cash = CONFIG['initial_money']  # current cash
        self.positions = {}  # codename -> [hands held, price]
        self.cost_price = {}  # codename -> average cost per share of the open position
        self.ori_data = ori_data
        self.record = pd.DataFrame(columns=['code','date','type','info','profit'])
        logging.info('持仓初始化完成')

    def buy(self, codename, price, volume, date):
        """Buy ``volume`` hands of ``codename`` at ``price`` (rounded to 2 dp) on ``date``.

        Cash is debited without a sufficiency check; the caller sizes the order.
        Adding to an open position recomputes its volume-weighted average cost.
        """
        price = round(price, 2)
        cost = round(price * volume * 100, 2)
        self.cash -= cost
        held = self.positions.get(codename, [0, 0])[0]
        if held > 0:
            prev_cost = self.cost_price.get(codename, self.positions[codename][1])
            self.cost_price[codename] = (prev_cost * held + price * volume) / (held + volume)
        else:
            self.cost_price[codename] = price
        self.positions[codename] = [held + volume, price]
        self.record.loc[len(self.record)] = [codename, date, '买入',f'买入{codename}共{volume}手，买入价格{price}','/']
        logging.info(f'交易信息：{date}买入{codename}共{volume}手，买入价格{price}')

    def sell(self, codename, price, volume, date):
        """Sell ``volume`` hands of ``codename`` at ``price`` (rounded to 2 dp) on ``date``.

        Realised profit is measured against the position's average cost and rounded
        to 2 dp. A position sold down to exactly zero hands is removed.

        Raises:
            KeyError: if ``codename`` is not currently held.
        """
        price = round(price, 2)
        held, last_price = self.positions[codename]
        old_price = self.cost_price.get(codename, last_price)
        proceeds = round(price * volume * 100, 2)
        profit = round((price - old_price) * volume * 100, 2)
        self.cash += proceeds
        remaining = held - volume
        if remaining == 0:
            # Fully closed: stop valuing and re-pricing it.
            del self.positions[codename]
            self.cost_price.pop(codename, None)
        else:
            # Partial close keeps its average cost; only the valuation price moves.
            self.positions[codename] = [remaining, price]
        self.record.loc[len(self.record)] = [codename, date, '卖出',f'卖出{codename}共{volume}手，卖出价格{price}',profit]
        logging.info(f'交易信息：{date}卖出{codename}共{volume}手，卖出价格{price},此次盈利{profit}')

    def update_positions(self, date):
        """Re-price every open position at its close on ``date``.

        A position whose code has no bar on ``date`` (e.g. suspended) keeps its
        previous price instead of raising. Average cost is not touched.
        """
        day_close = (
            self.ori_data[self.ori_data['datetime'] == date]
            .drop_duplicates(subset='code', keep='first')
            .set_index('code')['close']
        )
        self.positions = {
            code: [hands, float(day_close[code]) if code in day_close.index else price]
            for code, (hands, price) in self.positions.items()
        }

    def get_equities(self):
        """Return cash plus every position valued at its stored price."""
        cur_equities = self.cash
        for hands, price in self.positions.values():
            cur_equities += hands * price * 100
        return cur_equities
            

In [None]:
#回测主函数
class Backtester:
    """Day-by-day backtest loop: feeds rolling windows to ``Strategy`` and executes its signals."""

    def __init__(self, config, data_loader):
        """
        Initialise the backtester.

        Args:
            config: configuration dict, passed on to PositionManager and Strategy.
            data_loader: an already-initialised DataLoader. Its ``data`` and
                ``trade_cal`` are shared, not copied.
        """
        self.config = config
        self.data_loader = data_loader
        self.cashlist = []
        self.buysell_point = {}
        if not self.data_loader.trade_cal:
            logging.info("警告：无有效交易日数据，回测可能无法正常运行")

    def run(self):
        """Run the backtest over every date in ``data_loader.trade_cal``.

        All run state is rebuilt on each call. ``run`` can therefore be called
        repeatedly, and ``position_manager``, ``strategy``, ``cashlist`` and
        ``buysell_point`` exist even when there is no data. Afterwards:
          - ``cashlist`` holds one equity value per trading day: cash plus positions
            valued at that day's close (or the stored position price when a code has
            no bar that day);
          - ``buysell_point`` maps codename -> list of [bar index within that code's
            rows, fill price, 1 for buy / 0 for sell], used for charting.
        """
        # Reset first, so the result attributes exist even on the early return below.
        self.position_manager = PositionManager(self.config, self.data_loader.data)
        self.strategy = Strategy(self.position_manager, self.config)
        self.windowsize = self.strategy.window_size
        self.buysell_point = {}
        self.cashlist = []

        if not self.data_loader.trade_cal:
            logging.info("无法执行回测：无有效交易日数据")
            return

        logging.info('初始化完成，开始回测')
        trade_cal = self.data_loader.trade_cal
        stockdata = self.data_loader.data
        # Built once, instead of re-filtering the full data for every fill.
        bar_index = self._build_bar_index(stockdata)

        for i in range(1, len(trade_cal) + 1):
            # Rolling window of at most `windowsize` trading days, ending on trade_cal[i-1].
            tmp_datelist = trade_cal[max(0, i - self.windowsize):i]
            tmp_sdate = tmp_datelist[0]
            tmp_edate = tmp_datelist[-1]
            tmpdf = stockdata[(stockdata['datetime'] >= tmp_sdate) & (stockdata['datetime'] <= tmp_edate)]

            buysignal_list = []
            sellsignal_list = []
            for _code, tmpdf_code in tmpdf.groupby('code', sort=False):
                # Skip codes with no bar today (e.g. suspended): their latest close is
                # stale, and no fill can be placed on today's bar.
                if not (tmpdf_code['datetime'] == tmp_edate).any():
                    continue
                signal = self.strategy.generate_signals(tmpdf_code)
                if signal is None:
                    continue
                if signal['type'] == 0:
                    sellsignal_list.append(signal)
                elif signal['type'] == 1:
                    buysignal_list.append(signal)

            # Sells first, so their proceeds are available to same-day buys.
            for each_signal in sellsignal_list:
                self.position_manager.sell(
                    codename=each_signal['codename'],
                    price=each_signal['price'],
                    volume=each_signal['hands'],
                    date=tmp_edate,
                )
                self._add_point(bar_index, each_signal, tmp_edate, 0)

            for each_signal in buysignal_list:
                self.position_manager.buy(
                    codename=each_signal['codename'],
                    price=each_signal['price'],
                    volume=each_signal['hands'],
                    date=tmp_edate,
                )
                self._add_point(bar_index, each_signal, tmp_edate, 1)

            cur_equities = self._mark_to_market(tmpdf, tmp_edate)
            self.cashlist.append(cur_equities)

            logging.info(f'日期{tmp_edate}，当前权益{cur_equities}，其中现金{self.position_manager.cash},持仓{self.position_manager.positions}')

    @staticmethod
    def _build_bar_index(stockdata):
        """Map codename -> {datetime: position of that bar among the code's rows}.

        Positions follow the data's row order, matching
        ``stockdata[stockdata['code'] == code].reset_index(drop=True).index``.
        """
        bar_index = {}
        for code, rows in stockdata.groupby('code', sort=False):
            per_code = {}
            for pos, dt in enumerate(rows['datetime'].tolist()):
                per_code.setdefault(dt, pos)  # first row wins for a duplicated datetime
            bar_index[code] = per_code
        return bar_index

    def _add_point(self, bar_index, signal, date, side):
        """Record a chart marker [bar index, price, side] for a filled ``signal``."""
        cname = signal['codename']
        self.buysell_point.setdefault(cname, []).append([int(bar_index[cname][date]), signal['price'], side])

    def _mark_to_market(self, window_df, date):
        """Return cash plus every open position valued at its close on ``date``.

        Falls back to the position's stored price when its code has no bar on
        ``date``. Positions are only read, never modified.
        """
        today = window_df[window_df['datetime'] == date].drop_duplicates(subset='code', keep='first')
        closes = dict(zip(today['code'].tolist(), today['close'].tolist()))
        equity = self.position_manager.cash
        for code, (hands, price) in self.position_manager.positions.items():
            equity += hands * float(closes.get(code, price)) * 100
        return equity


In [None]:
# #单均线策略
# class Strategy:
#     def __init__(self,position_manager,CONFIG):
#         self.window_size = 50
#         self.initial_money = CONFIG['initial_money']
#         self.position_manager = position_manager
#         self.high_price = {}
#         self.max_hold = 5 #最大持仓数量
#         logging.info('策略初始化完成')

#     def calculate_moving_average(self, prices_list, x):
#         """
#         简单移动平均（SMA）
#         :param prices_list: 按时间顺序的收盘价列表 [float]
#         :param x: 周期
#         :return: 与 prices_list 等长，前 x-1 个元素保持原值，之后为对应均线
#         """
#         if x <= 0:
#             raise ValueError('周期 x 必须为正整数')
    
#         n = len(prices_list)
#         ma = prices_list.copy()          # 先复制一份，前 x-1 个位置保留原值
    
#         for i in range(x, n + 1):        # 从第 x 个元素开始算
#             window = prices_list[i - x:i]
#             ma[i - 1] = round(sum(window) / x, 2)
#         return ma
    
#     def generate_signals(self, df):
#         if len(df) < self.window_size:
#             return None
#         else:
#             codename = df['code'].drop_duplicates().tolist()[0] #名称
#             #curdate = df.tail(1)['datetime'].values[0]
#             pricelist = df['close'].tolist()
#             curprice = pricelist[-1] #最新价位
#             moving_average = self.calculate_moving_average(pricelist,20)
#             A = pricelist[-5:]
#             B = moving_average[-5:]
#             current_hold = self.position_manager.positions.get(codename, [0, 0])[0] #当前持仓
            
#             #当前有持仓,需要判断出场逻辑
#             if  current_hold > 0:
#                 hprice = self.high_price.get(codename, 0)
#                 if curprice < hprice * 0.90:
#                     logging.info(f'触发出场信号，价格{curprice},最高价{self.high_price}')
#                     sellcount = self.position_manager.positions[codename][0]
#                     return {'codename':codename,'type':0,'hands':sellcount,'price':curprice}
#                 else:
#                     self.high_price[codename] = max(curprice,self.high_price[codename])
            
#             #当前没有持仓
#             elif current_hold == 0:
#                 #入场条件满足
#                 if (all(a < b for a, b in zip(A[:-1], B[:-1]))) and (A[-1] > B[-1]):
#                     logging.info(f'触发均线入场信号，价格{str(A)},均线{str(B)}')
#                     #计算购买量
#                     buycount = self.money_management(curprice)
#                     if buycount > 0:
#                         self.high_price[codename] = curprice #记录买入以来最高价格
#                         #购买量大于0，买入
#                         return {'codename':codename,'type':1,'hands':buycount,'price':curprice}
#             else:
#                 return None
            

#     def money_management(self,curprice):
#         cost = curprice * 100
#         cur_equities = self.position_manager.get_equities()
#         avalible_cash = min(cur_equities/self.max_hold, self.position_manager.cash)
#         curaval_hands = int(avalible_cash /cost)
#         logging.info(f'当前权益为{cur_equities},每支票分配权益为{cur_equities/self.max_hold}，可用现金为{self.position_manager.cash},价格{curprice},一手花费{cost},可买手数为{curaval_hands}')
#         return curaval_hands
        
        

In [None]:
#均线 +中枢策略 小时粒度
class Strategy:
    """Pivot (zhongshu) breakout entry with a trailing-stop exit.

    Entry: the instrument is flat and ``if_uptrend_judge`` reports a breakout above
    the most recent pivot found by ``CLbasement.cl_base``.
    Exit: the close falls below ``stop_ratio`` times the highest close seen since entry.
    """

    def __init__(self, position_manager, CONFIG):
        self.window_size = 200  # bars required before any signal is produced
        self.initial_money = CONFIG['initial_money']
        self.position_manager = position_manager
        self.high_price = {}  # codename -> highest close since entry (trailing-stop anchor)
        self.max_hold = 1  # max instruments held at once; each is sized to equity / max_hold
        self.stop_ratio = 0.90  # exit when close < stop_ratio * high_price
        self.ma_period = 20  # moving-average period shown in the entry log
        logging.info('策略初始化完成')

    def calculate_moving_average(self, prices_list, x):
        """Simple moving average (SMA).

        :param prices_list: closes in chronological order [float]
        :param x: period (positive int)
        :return: list as long as prices_list. The first x-1 entries keep their original
            values; every later entry is the x-period mean, rounded to 2 decimals.
        :raises ValueError: if x <= 0
        """
        if x <= 0:
            raise ValueError('周期 x 必须为正整数')

        n = len(prices_list)
        ma = prices_list.copy()  # the first x-1 entries keep the raw price
        for i in range(x, n + 1):  # window ending at element i-1
            window = prices_list[i - x:i]
            ma[i - 1] = round(sum(window) / x, 2)
        return ma

    def money_management(self, curprice):
        """Return how many whole hands (100 shares each) can be bought at ``curprice``.

        The budget is the smaller of equity / max_hold and the available cash.
        Returns 0 for a non-positive price.
        """
        cost = curprice * 100  # price of one hand
        if cost <= 0:
            return 0
        cur_equities = self.position_manager.get_equities()
        avalible_cash = min(cur_equities/self.max_hold, self.position_manager.cash)
        curaval_hands = int(avalible_cash /cost)
        logging.info(f'当前权益为{cur_equities},每支票分配权益为{cur_equities/self.max_hold}，可用现金为{self.position_manager.cash},价格{curprice},一手花费{cost},可买手数为{curaval_hands}')
        return curaval_hands

    def if_uptrend_judge(self, df_window_L1):
        """Return True when the latest bar breaks out above the most recent pivot.

        ``CLbasement.cl_base(df)`` returns (original_kline, draw, markarea). Judging
        by how they are indexed here, ``draw`` holds (bar position, price) stroke
        vertices, and each ``markarea`` entry is a pair of corner dicts with 'xAxis'
        (bar position) and 'yAxis' (price). ``[0]['yAxis']`` is the pivot's upper
        edge and ``[1]['yAxis']`` its lower edge.

        True when all of the following hold:
          - the last pivot ends where the last stroke starts;
          - the latest close is above the pivot's upper edge;
          - the pivot spans more than 3 stroke vertices.
        Returns False, instead of raising, when there is no pivot, fewer than two
        stroke vertices, or a pivot corner that is not a stroke vertex.
        """
        original_kline, draw, markarea = CLbasement.cl_base(df_window_L1)
        if not markarea or len(draw) < 2:
            return False
        lastzs = markarea[-1]  # most recent pivot
        start_x, end_x = lastzs[0]['xAxis'], lastzs[1]['xAxis']
        # Stroke vertices delimiting the last pivot.
        spos = next((i for i, (a, _) in enumerate(draw) if a == start_x), None)
        epos = next((i for i, (a, _) in enumerate(draw) if a == end_x), None)
        if spos is None or epos is None:
            return False
        zs_length = len(draw[spos:epos + 1])

        lastzs_edatetime = df_window_L1.iloc[end_x]['datetime']  # pivot end time
        last_draw_starttime = df_window_L1.iloc[draw[-2][0]]['datetime']  # last stroke start
        zs_top = lastzs[0]['yAxis']  # pivot upper edge
        curclose = float(df_window_L1['close'].iloc[-1])

        return bool((lastzs_edatetime == last_draw_starttime) and (curclose > zs_top) and (zs_length > 3))

    def generate_signals(self, df):
        """Return a signal for the latest bar of ``df`` (one instrument's window), or None.

        Signal format: {'codename', 'type' (1 buy / 0 sell), 'hands', 'price'}.
        """
        if len(df) < self.window_size:
            return None

        codename = df['code'].drop_duplicates().tolist()[0]
        pricelist = df['close'].tolist()
        curprice = pricelist[-1]  # latest close
        current_hold = self.position_manager.positions.get(codename, [0, 0])[0]

        if current_hold > 0:
            # Holding: trailing-stop exit. Defaulting to curprice starts tracking a
            # position that has no recorded high yet, instead of raising KeyError.
            hprice = self.high_price.get(codename, curprice)
            if curprice < hprice * self.stop_ratio:
                logging.info(f'触发出场信号，价格{curprice},最高价{self.high_price}')
                sellcount = self.position_manager.positions[codename][0]
                return {'codename':codename,'type':0,'hands':sellcount,'price':curprice}
            self.high_price[codename] = max(curprice, hprice)
            return None

        if current_hold == 0:
            # Flat: only now is the (expensive) pivot analysis needed.
            if self.if_uptrend_judge(df):
                moving_average = self.calculate_moving_average(pricelist, self.ma_period)
                A = pricelist[-5:]
                B = moving_average[-5:]
                logging.info(f'触发中枢入场信号，价格{str(A)},均线{str(B)}')
                buycount = self.money_management(curprice)
                if buycount > 0:
                    self.high_price[codename] = curprice  # anchor for the trailing stop
                    return {'codename':codename,'type':1,'hands':buycount,'price':curprice}
        return None
            


        
        

In [None]:
# Backtest each instrument independently (each starts with CONFIG['initial_money'])
# and collect the per-instrument results for the charts below.
trade_record = {}        # str code -> trade log DataFrame
trade_cashlist = {}      # str code -> daily equity list
trade_callist = {}       # str code -> trading calendar
trade_buysellpoint = {}  # the data's own 'code' value -> chart markers (merged across runs)
trade_odata = {}         # str code -> loaded bars
for eachcode in testlist:
    # Build a per-run copy: `config = CONFIG` would alias the global dict and
    # permanently overwrite CONFIG['stock_list'] with the last code.
    config = {**CONFIG, 'stock_list': [eachcode]}
    dloader = DataLoader(config)
    if not dloader.trade_cal:
        # Nothing to backtest for this code; skip it rather than record empty results.
        print(f'{eachcode}无有效交易日数据，已跳过')
        continue
    backtester = Backtester(config, dloader)
    backtester.run()

    trade_record[eachcode] = backtester.position_manager.record
    trade_cashlist[eachcode] = backtester.cashlist
    trade_callist[eachcode] = dloader.trade_cal
    trade_buysellpoint.update(backtester.buysell_point)
    trade_odata[eachcode] = dloader.data

In [None]:
# Left-pad every shorter equity curve with the starting capital, so all curves line
# up with the longest trading calendar (this assumes every calendar ends on the same date).
longestname = max(trade_callist, key=lambda code: len(trade_callist[code]), default='')
if not trade_callist.get(longestname):
    longestname = ''  # no non-empty calendar at all
longest = len(trade_callist.get(longestname, []))
for code, cal in trade_callist.items():
    trade_cashlist[code] = [CONFIG['initial_money']] * (longest - len(cal)) + trade_cashlist[code]

In [None]:
# Overlay each instrument's equity curve, plus their average, on the longest calendar.
key_list = list(trade_cashlist)
fname = key_list[0]
linechart = Drawpic.generate_line_chart(trade_cashlist[fname], trade_callist[longestname], fname)
for fname in key_list[1:]:
    tmpchart = Drawpic.generate_line_chart(trade_cashlist[fname], trade_callist[longestname], fname)
    linechart.overlap(tmpchart)
# Element-wise mean across instruments, matching the 'avg' label and keeping the line
# on the same scale as the individual curves.
avg = [sum(col) / len(col) for col in zip(*trade_cashlist.values())]
avgchart = Drawpic.generate_line_chart(avg, trade_callist[longestname], 'avg')
linechart.overlap(avgchart)
linechart.load_javascript()

In [None]:
# Display the combined equity-curve chart inline.
linechart.render_notebook()

In [None]:
# Instrument to inspect below. Kept as an int: the next cells use str(codename) for
# trade_record / trade_odata (keyed by the str codes in testlist), but the raw value
# for trade_buysellpoint, which is keyed by the data's own 'code' values
# (presumably parsed as int by the loader — confirm).
codename = 512880

In [None]:
# Trade log of the selected instrument (trade_record is keyed by str code).
trade_record[str(codename)]

In [None]:
code_key = str(codename)  # trade_odata is keyed by the str codes from testlist
datas = trade_odata[code_key]
buysell_point = trade_buysellpoint[codename]  # keyed by the data's own 'code' value

In [None]:
# Chart the selected instrument's bars together with its buy/sell markers.
# NOTE(review): [10, 20] presumably lists moving-average periods to draw — confirm in Drawpic.
grid = Drawpic.generate_pic_by_df(datas,buysell_point,[10,20])
# NOTE(review): presumably loads the chart's JS assets so render_notebook works — confirm.
grid.load_javascript()

In [None]:
# Display the selected instrument's chart inline.
grid.render_notebook()