In [14]:
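# Add position management: extend the Broker class with cash tracking and position control.
# Compute the final result: when the backtest ends, report the final account value
# (remaining cash plus the market value of any open position) together with the trade history.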

In [15]:
import pandas as pd
import numpy as np
from abc import ABC, abstractmethod
from typing import Type

In [16]:
class Strategy(ABC):
    """Abstract base class for trading strategies.

    A concrete strategy must implement both `init` and `next`; the class
    cannot be instantiated until it does.

    Args:
        broker: object the strategy sends its orders to.
        data: price data the strategy reads from.
    """

    def __init__(self, broker, data):
        # Both collaborators are stored as-is for use by subclasses.
        self.broker, self.data = broker, data

    @abstractmethod
    def init(self):
        """One-off setup hook (e.g. precomputing indicators)."""

    @abstractmethod
    def next(self, bar):
        """Per-bar hook; `bar` identifies the bar being processed."""

class Broker:
    """Minimal all-in / all-out broker that tracks cash, a position and a trade log.

    Attributes:
        cash: uninvested cash.
        position: number of units currently held (fractional units allowed).
        position_value: value of the position at the last price passed to
            `buy`, `sell` or `update_position_value`.
        trades: chronological list of `(side, size, price)` tuples, where
            `side` is `'BUY'` or `'SELL'`.
    """

    def __init__(self, cash: float):
        self.cash = cash  # available cash
        self.position = 0  # units held
        self.position_value = 0  # value of the held units at the last seen price
        self.trades = []  # trade log

    def buy(self, price):
        """Spend all available cash at `price`; do nothing when there is no cash.

        The purchased units are added to whatever is already held, so cash
        that appears while a position is open never wipes out that position.

        Args:
            price: price per unit; must be a positive number.

        Raises:
            ValueError: if a purchase would execute at a zero, negative or
                NaN price.
        """
        if self.cash > 0:  # simplified sizing: invest everything
            # `not price > 0` is also true for NaN, which `price <= 0` would miss.
            if not price > 0:
                raise ValueError(f"price must be a positive number, got {price!r}")

            size = self.cash / price  # units bought by this order
            self.position += size
            self.position_value = self.position * price

            # Every unit of cash was spent. Assign zero directly: computing
            # cash - size * price can leave a tiny float residue that would
            # pass the `cash > 0` check on the next buy signal.
            self.cash = 0.0
            self.trades.append(('BUY', size, price))

    def sell(self, price):
        """Sell the entire position at `price`; do nothing when flat."""
        if self.position > 0:
            size = self.position
            self.cash = self.cash + size * price  # proceeds go back to cash
            self.trades.append(('SELL', size, price))
            self.position = 0
            self.position_value = 0

    def update_position_value(self, price):
        """Re-value the held units at `price` (units held * latest price)."""
        self.position_value = self.position * price

    def result(self):
        """Return total account value: cash plus the last recorded position value."""
        return self.cash + self.position_value

class Backtest:
    """Drive a strategy bar by bar over a price DataFrame.

    Args:
        data: price data; `run` reads its 'Close' column.
        strategy: the Strategy subclass to run (the class, not an instance).
        cash: starting cash handed to the broker.
    """

    def __init__(self, data: pd.DataFrame, strategy: Type[Strategy], cash: float = 10_000):
        self.strategy = strategy
        self.data = data
        self._initial_cash = cash  # remembered so every run can start from the same account
        self.broker = Broker(cash)

    def run(self):
        """Execute the backtest and return `(final_value, trades)`.

        Each call starts from a fresh broker funded with the initial cash, so
        calling `run` again (e.g. re-executing a notebook cell) reproduces the
        same result instead of continuing from the previous run's cash,
        position and trade log.

        Returns:
            A tuple of the broker's final result (cash plus position value)
            and its list of recorded trades.
        """
        self.broker = Broker(self._initial_cash)

        strategy = self.strategy(self.broker, self.data)
        strategy.init()

        for i in range(len(self.data)):
            strategy.next(i)
            # Re-value the open position at this bar's close after the
            # strategy has acted, so the final result reflects the last price.
            self.broker.update_position_value(self.data['Close'].iloc[i])

        return self.broker.result(), self.broker.trades

In [17]:
# Example strategy
class SmaCross(Strategy):
    """Moving-average crossover strategy on the 'Close' column.

    Buys when the fast average crosses above the slow average and sells when
    it crosses below. Subclass and override `fast_window` / `slow_window` to
    try other lookbacks; the defaults are 10 and 20 bars.
    """

    fast_window = 10  # lookback (in bars) of the fast moving average
    slow_window = 20  # lookback (in bars) of the slow moving average

    def init(self):
        """Precompute both moving averages over the whole series."""
        close = self.data['Close']
        self.sma1 = close.rolling(self.fast_window).mean()
        self.sma2 = close.rolling(self.slow_window).mean()

    def next(self, bar):
        """Emit a buy/sell order at this bar's close when a crossover occurs."""
        if bar == 0:
            return  # no previous bar to compare with (bar - 1 would wrap to the end)

        fast_now, slow_now = self.sma1.iloc[bar], self.sma2.iloc[bar]
        fast_prev, slow_prev = self.sma1.iloc[bar - 1], self.sma2.iloc[bar - 1]

        # While an average is still NaN (warm-up), every comparison is False,
        # so no order is sent.
        if fast_now > slow_now and fast_prev <= slow_prev:
            self.broker.buy(price=self.data["Close"].iloc[bar])
        elif fast_now < slow_now and fast_prev >= slow_prev:
            self.broker.sell(price=self.data["Close"].iloc[bar])

# Simulated data: 100 closing prices made of log-normal noise shifted up by 50.
np.random.seed(42)  # fixed seed so the run is reproducible
data = pd.DataFrame({'Close': np.random.lognormal(mean=0.0, sigma=0.2, size=100) + 50})

# Run the crossover strategy with the default starting cash;
# run() returns the final account value and the list of trades.
bt = Backtest(data, SmaCross)
final_result, trades = bt.run()
print(f"Final Result: ${final_result:.2f}")
print("Trades:", trades)

Final Result: $9871.99
Trades: [('BUY', 196.87219019581974, 50.79437573205976), ('SELL', 196.87219019581974, 51.0401576073225), ('BUY', 196.4142518605342, 51.15915734683227), ('SELL', 196.4142518605342, 51.03486776600187), ('BUY', 195.64535861387165, 51.23543661898979), ('SELL', 195.64535861387165, 51.06696353222703), ('BUY', 195.02653995890992, 51.22899886184242), ('SELL', 195.02653995890992, 51.17645439420683), ('BUY', 194.85211396175487, 51.22226608137345), ('SELL', 194.85211396175487, 50.87895291356166), ('BUY', 193.0264669271383, 51.36016676464511), ('SELL', 193.0264669271383, 51.01852159102574), ('BUY', 194.3465260981272, 50.67198868061742), ('SELL', 194.3465260981272, 50.957016757180035), ('BUY', 193.77198055240058, 51.10810736851555), ('SELL', 193.77198055240058, 50.86900133386499), ('BUY', 193.5149420895431, 50.936568673982336), ('SELL', 193.5149420895431, 51.06101293980262), ('BUY', 193.74256317427475, 51.001023214455586), ('SELL', 193.74256317427475, 50.95416618301184)]
