In [7]:
#1.调整数据结构：将data调整为支持多个股票的格式。最简单的方式是使用Pandas的DataFrame，以股票代码作为一级列索引，时间作为行索引。
#2.修改Broker类：支持针对特定标的的买卖操作。
#3.修改Strategy类：同时考虑多个股票标的的决策。
#4.调整交易记录：确保每一笔交易都记录了具体的股票标的。

In [8]:
import pandas as pd
import numpy as np
from abc import ABC, abstractmethod
from typing import Type

In [9]:
class Strategy(ABC):
    def __init__(self, broker, data):
        self.broker = broker
        self.data = data

    @abstractmethod
    def init(self):
        pass

    @abstractmethod
    def next(self, bar):
        pass

class Broker:
    def __init__(self, cash: float):
        self.cash = cash
        self.positions = {}
        self.position_values = {}
        self.trades = []

    def buy(self, symbol, price, quantity):
        # 如果要买的symbol不在我们持仓里，我们先为它开辟空间
        if symbol not in self.positions:
            self.positions[symbol] = 0
        # 计算买入成本
        cost = price * quantity
        
        if self.cash >= cost:
            self.cash -= cost
            self.positions[symbol] += quantity
            self.trades.append((symbol, 'BUY', quantity, price))

    def sell(self, symbol, price, quantity):
        # 先检查是否有足够的持仓
        if symbol in self.positions and self.positions[symbol] >= quantity:
            self.positions[symbol] -= quantity
            self.cash += price * quantity
            self.trades.append((symbol, 'SELL', quantity, price))
    
    def update_position_value(self, symbol, price):
        """更新给定标的的持仓市值"""
        if symbol in self.positions:
            self.position_values[symbol] = self.positions[symbol] * price

    def result(self):
        total_position_value = sum(self.position_values.values())
        return self.cash + total_position_value

class Backtest:
    def __init__(self, data: pd.DataFrame, strategy: Type[Strategy], cash: float = 10_000):
        self.strategy = strategy
        self.data = data
        self.broker = Broker(cash)
        self.strategy = strategy(self.broker, self.data)
        self.strategy.symbols = list(set(column.split('-')[0] for column in data.columns))

    def run(self):
        self.strategy.init()
        
        for i in range(len(self.data)):
            self.strategy.next(i)

            for symbol in self.strategy.symbols:  # 对于策略涉及的每个标的
                close_price = self.data[symbol + '-Close'].iloc[i]  # 获取收盘价
                self.broker.update_position_value(symbol, close_price)  # 更新持仓市值        
        
        return self.broker.result(), self.broker.trades

class MultiAssetStrategy(Strategy):
    def init(self):
        pass

    def next(self, bar):
        for symbol in self.symbols:
            close = self.data[symbol + '-Close'].iloc[bar]
            # 示例策略：如果当前收盘价高于前一天，买入；如果低于前一天，卖出
            if bar > 0 and close > self.data[symbol + '-Close'].iloc[bar - 1]:
                self.broker.buy(symbol, close, 10)
            elif bar > 0 and close < self.data[symbol + '-Close'].iloc[bar - 1]:
                self.broker.sell(symbol, close, 10)


In [10]:
# 使用示例
# 假设data是一个DataFrame，包含两个股票的Close价格
data = pd.DataFrame({
    'AAPL-Close': np.random.lognormal(mean=0.0, sigma=0.2, size=60) + 100,
    'MSFT-Close': np.random.lognormal(mean=0.0, sigma=0.2, size=60) + 200,
})

bt = Backtest(data, MultiAssetStrategy)
final_result, trades = bt.run()
print(f"Final Result: ${final_result:.2f}")
print("Trades:", trades)

Final Result: $9898.03
Trades: [('MSFT', 'BUY', 10, 201.07258533138418), ('MSFT', 'SELL', 10, 200.95623331425992), ('AAPL', 'BUY', 10, 100.91435427754512), ('AAPL', 'SELL', 10, 100.66964820330193), ('MSFT', 'BUY', 10, 201.565785876715), ('AAPL', 'BUY', 10, 101.2343479099103), ('MSFT', 'SELL', 10, 200.96515925971437), ('AAPL', 'SELL', 10, 100.8866327308534), ('MSFT', 'BUY', 10, 201.2045456485376), ('AAPL', 'BUY', 10, 101.21917446633296), ('MSFT', 'BUY', 10, 201.45336590620923), ('AAPL', 'SELL', 10, 100.9750599081865), ('MSFT', 'SELL', 10, 200.88973168317483), ('MSFT', 'SELL', 10, 200.8136801707499), ('AAPL', 'BUY', 10, 100.83090611384793), ('MSFT', 'BUY', 10, 200.9511509823493), ('AAPL', 'BUY', 10, 100.96748711054508), ('MSFT', 'BUY', 10, 201.06037828863288), ('AAPL', 'BUY', 10, 101.07586415231123), ('MSFT', 'BUY', 10, 201.08554479920693), ('AAPL', 'SELL', 10, 101.06359537647239), ('MSFT', 'SELL', 10, 200.8882676949117), ('AAPL', 'SELL', 10, 101.02758154324353), ('MSFT', 'SELL', 10, 200

In [11]:
data

Unnamed: 0,AAPL-Close,MSFT-Close
0,101.247147,200.834918
1,100.942935,201.072585
2,100.776026,200.956233
3,100.914354,200.902298
4,100.669648,200.886202
5,101.234348,201.565786
6,100.886633,200.965159
7,101.219174,201.204546
8,100.97506,201.453366
9,100.594899,200.889732
