In [1]:
import pandas as pd
import numpy as np
from abc import ABC, abstractmethod
from typing import List

In [2]:

class Strategy(ABC):
    """
    抽象策略类,用户需要继承并实现init和next方法。
    """
    def __init__(self, broker, data):
        self.broker = broker
        self.data = data

    @abstractmethod
    def init(self):
        pass
    
    @abstractmethod
    def next(self, bar):
        pass

class Broker:
    """
    用于模拟订单执行和记录交易情况。
    """
    def __init__(self):
        self.orders = [] # 订单列表，执行后清空
        self.trades = [] # 交易列表

    def buy(self):
        self.orders.append('BUY')

    def sell(self):
        self.orders.append('SELL')

    # 先让strategy来释放order，再通过回测引擎来执行
    def execute_orders(self, bar, price):
        # 如果有订单，则执行
        if self.orders:
            for order in self.orders:
                self.trades.append({'action': order, 'bar': bar, 'price': price})
            self.orders = []

class Backtest:
    """
    回测引擎，执行回测并输出结果。
    """
    def __init__(self, data: pd.DataFrame, strategy: type[Strategy]):
        self.strategy = strategy
        self.data = data
        self.broker = Broker()

    def run(self):
        strategy = self.strategy(broker=self.broker, data=self.data)
        strategy.init()
        
        for i in range(len(self.data)):
            strategy.next(i)  # 模拟每个时间点的行为
            # 以收盘价作为交易执行价格
            self.broker.execute_orders(bar=i, price=self.data['Close'].iloc[i])
        
        return self.broker.trades

def print_trades(trades):
    for trade in trades:
        print(f"Bar: {trade['bar']}, Action: {trade['action']}, Price: {trade['price']}")

In [3]:
# 示例使用
class SmaCross(Strategy):
    def init(self):
        self.sma1 = self.data['Close'].rolling(10).mean()
        self.sma2 = self.data['Close'].rolling(20).mean()

    def next(self, bar):
        if self.sma1.iloc[bar] > self.sma2.iloc[bar]:
            self.broker.buy()
        elif self.sma1.iloc[bar] < self.sma2.iloc[bar]:
            self.broker.sell()

# 模拟数据
data = pd.DataFrame({'Close': np.random.rand(100) * 100})

bt = Backtest(data, SmaCross)
trades = bt.run()
print_trades(trades)

Bar: 19, Action: SELL, Price: 41.41488393249385
Bar: 20, Action: SELL, Price: 82.81114246973213
Bar: 21, Action: SELL, Price: 32.911436066778435
Bar: 22, Action: SELL, Price: 85.22245491295378
Bar: 23, Action: SELL, Price: 23.642008483523448
Bar: 24, Action: BUY, Price: 27.785269437748795
Bar: 25, Action: SELL, Price: 24.73569886436363
Bar: 26, Action: SELL, Price: 49.39416330426308
Bar: 27, Action: BUY, Price: 80.73980866606848
Bar: 28, Action: SELL, Price: 17.650641447649317
Bar: 29, Action: BUY, Price: 39.804100325097934
Bar: 30, Action: SELL, Price: 79.20478594806059
Bar: 31, Action: SELL, Price: 15.536802049229692
Bar: 32, Action: SELL, Price: 66.19285752951359
Bar: 33, Action: SELL, Price: 90.62197920629195
Bar: 34, Action: SELL, Price: 47.42345821690552
Bar: 35, Action: BUY, Price: 91.31897881402922
Bar: 36, Action: BUY, Price: 67.8983415354752
Bar: 37, Action: BUY, Price: 33.24345987429746
Bar: 38, Action: BUY, Price: 96.59433171298087
Bar: 39, Action: BUY, Price: 49.8328530097