In [15]:
def ATR_signal_cal(df,N,k1,k2):
    # 计算指标
    ohlc_df = df[['Open','High','Low','Close']].copy()
    # 计算TR
    ohlc_df['H-L'] = ohlc_df['High'] - ohlc_df['Low']
    ohlc_df['abs(H-C_1)'] = (ohlc_df['High'] - ohlc_df['Close'].shift(1)).abs()
    ohlc_df['abs(L-C_1)'] = (ohlc_df['Low'] - ohlc_df['Close'].shift(1)).abs()
    ohlc_df['TR'] = ohlc_df[['H-L','abs(H-C_1)','abs(L-C_1)']].max(axis=1)
    # 计算ATR
    ohlc_df['ATR'] = ohlc_df['TR'].rolling(N).mean()
    print(ohlc_df.head(20))
    print(ohlc_df.tail(20))
    # 计算上下轨
    ohlc_df['Upline'] = ohlc_df['Open'] + k1 * ohlc_df['ATR']
    ohlc_df['Downline'] = ohlc_df['Open'] - k2 * ohlc_df['ATR']
    print(ohlc_df.head(20))
    print(ohlc_df.tail(20))
    # 生成突破信号
    ohlc_df['signal'] = 0
    ohlc_df.loc[ohlc_df[ohlc_df['Close'] > ohlc_df['Upline']].index,'signal'] = 1
    ohlc_df.loc[ohlc_df[ohlc_df['Close'] < ohlc_df['Downline']].index,'signal'] = -1
    print(ohlc_df.head(20))
    print(ohlc_df.tail(20))

    return ohlc_df[['signal']]

In [16]:
# 交易规则函数
def my_position_func(symbol, md_data, current_pos, user_data):
    if 'bar_index' not in user_data:
        user_data['bar_index'] = 0
    else:
        user_data['bar_index'] += 1
    target_pos = current_pos

    # 空仓时
    if current_pos == 0:
        user_data['volume'] = 0
        # 当出现交易信号时，按对应信号方向开仓
        if md_data['signal'] != 0:
            target_pos = md_data['signal']
            user_data['open_bar'] = user_data['bar_index']
    # 持有仓位时
    else:
        # 当信号与当前仓位方向相反时，平仓
        if current_pos * md_data['signal'] < 0:
            target_pos = 0
            user_data['open_bar'] = user_data['bar_index']
    return target_pos

In [None]:
import pandas as pd
from dataclasses import dataclass
from datetime import date, datetime
from collections import defaultdict
from typing import List, Dict, Tuple
'''
@dataclass
class BacktestTrade(object):
    #symbol: str    # sp500
    is_open: bool  # 开仓
    is_long: bool  # 做多
    price: float
    volume: int
    trading_date: date
    #trade_time: datetime
'''
class BacktestTrade:
    """ 记录交易信息 """
    def __init__(self, is_open: bool, is_long: bool, price: float, volume: int, trading_day: str):
        self.is_open = is_open  # 是否开仓
        self.is_long = is_long  # 多空方向
        self.price = price  # 交易价格
        self.volume = volume  # 交易手数
        self.trading_day = trading_day  # 交易日期
'''
class DailyResult:
    def __init__(self, date: date):
        """"""
        self.date = date
        self.close_price_dict: Dict[str, float] = dict()
        self.pre_close_dict: Dict[str, float] = dict()
        self.trades: list[BacktestTrade] = []
        self.trade_count = 0
        self.start_pos_dict = dict()
        self.end_pos_dict = dict()
        self.turnover = 0
        self.commission = 0
        self.trading_pnl = 0
        self.holding_pnl = 0
        self.total_pnl = 0
        self.net_pnl = 0

    def calculate_pnl(self, vt_symbol: str, pre_close: float, start_pos: float):

        self.pre_close_dict[vt_symbol] = pre_close
        # Holding pnl is the pnl from holding position at day start
        self.start_pos_dict[vt_symbol] = start_pos
        self.end_pos_dict[vt_symbol] = start_pos
        self.holding_pnl += self.start_pos_dict[vt_symbol] *  (self.close_price_dict[vt_symbol] - self.pre_close_dict[vt_symbol])
        # Trading pnl is the pnl from new trade during the day
        self.trade_count = len(self.trades)

        for trade in self.trades:
            if trade.symbol != vt_symbol: continue
            if (trade.is_long and trade.is_open) or (not trade.is_long and not trade.is_open): pos_change = trade.volume  # 开多仓&平空仓
            else: pos_change = -trade.volume   # 平多仓&开空仓

            turnover = trade.price * trade.volume

            self.trading_pnl += pos_change * (self.close_price_dict[vt_symbol] - trade.price)
            self.end_pos_dict[vt_symbol] += pos_change
            self.turnover += turnover

        # Net pnl takes account of commission and slippage cost
        self.total_pnl = self.trading_pnl + self.holding_pnl
        self.net_pnl = self.total_pnl - self.commission   # commission = 0
'''
class DailyResult:
    """ 记录每日交易结果 """
    def __init__(self, trading_day: str):
        self.trading_day = trading_day  # 交易日
        self.trades: List[BacktestTrade] = []  # 当天所有交易
        self.close_price = None  # 收盘价
        self.pnl = 0  # 当天盈亏

    def calculate_pnl(self, position: int, open_price: float):
        """ 计算每日盈亏 """
        if self.close_price is not None:
            self.pnl = position * (self.close_price - open_price)  # 盈亏计算
        return self.pnl

class SimpleStrategy:
    def __init__(self) -> None:
        super().__init__()
        self._current_pos = 0  #当前持仓
        self._trades = []  #存储所有交易
        self._target_pos = 0  #目标持仓
        self._daily_results = {}  #记录每日交易数据
        self._user_data = {}
        self._my_position_func = None  #持仓决策函数
        self._volume = {}  #交易手数
        self._trade_num = 0  #交易次数

    def process_data(self, symbol, principal, md_data):
        trading_day = md_data.get('tradingDay', None)

        if md_data['tradingDay'] not in self._daily_results:
            self._daily_results[md_data['tradingDay']] = DailyResult(md_data['tradingDay'])
            self._daily_results[md_data['tradingDay']].close_price_dict[symbol] = md_data['close']

        while self._target_pos != self._current_pos:
            is_open = (self._target_pos < self._current_pos and self._current_pos <= 0) or (
                    self._target_pos > self._current_pos and self._current_pos >= 0)         # 开仓

            is_long = (self._target_pos < self._current_pos and self._current_pos > 0) or (
                    self._target_pos > self._current_pos and self._current_pos >= 0)         # 多空

            if is_open: self._volume[symbol] = max(principal // (md_data['open']),1)         # 最少交易1手
            new_trade = BacktestTrade(is_open, is_long, md_data['open'], self._volume, trading_day)
            self._trades.append(new_trade)
            self._daily_results[trading_day].trades.append(new_trade)
            if self._target_pos !=0: self._current_pos += self._target_pos
            else:self._current_pos = self._target_pos

        # calculate next bar target position direction
        self._target_pos = self._my_position_func(md_data, self._current_pos, self._user_data)

    def get_trades(self) -> List[BacktestTrade]:
        return self._trades

    def get_daily_results(self):
        return self._daily_results



class SimpleBacktest(object):
    def __init__(self) -> None:
        self._symbol_ohlc_data = dict()
        self._symbol_pred_data = dict()
        # self._symbol_commission = dict() # 不考虑手续费
        self._symbol_principal = dict()
        self._strategy = None

    def set_symbol_data(self, symbol: str, ohlc_data: pd.DataFrame, predict_data: pd.DataFrame) -> None:
        self._symbol_ohlc_data[symbol] = ohlc_data
        self._symbol_pred_data[symbol] = predict_data
        # set default commission
        # self._symbol_commission[symbol] = (0.0005, 0.0005)

    '''
    def set_symbol_commission(self, symbol: str, open_commission_rate: float, close_commission_rate: float):
        self._symbol_commission[symbol] = (open_commission_rate, close_commission_rate)
    '''

    def set_symbol_principal(self, symbol: str, principal_value: float):
        self._symbol_principal[symbol] = principal_value

    def set_strategy(self, strategy):
        self._strategy = strategy

    def run_backtests(self,my_position_func):
        self._strategy._my_position_func = my_position_func
        daily_df_dict = dict()
        for symbol in self._symbol_ohlc_data.keys():
            ohlc_data = self._symbol_ohlc_data[symbol]
            pred_data = self._symbol_pred_data[symbol]
            df = pd.merge(ohlc_data, pred_data, how='outer', left_index=True, right_index=True)

            for md_data in df.iterrows():
                self._strategy.process_data(symbol, self._symbol_principal[symbol], md_data)

            daily_results = self._strategy.get_daily_results(symbol)

            pre_close = defaultdict(float)
            start_pos = defaultdict(int)

            for daily_result in daily_results.values():
                daily_result.calculate_pnl(symbol, pre_close[symbol], start_pos[symbol])

                pre_close[symbol] = daily_result.close_price_dict[symbol]
                start_pos[symbol] = daily_result.end_pos_dict[symbol]

            # Generate dataframe
            results = defaultdict(list)

            for daily_result in daily_results.values():
                for key, value in daily_result.__dict__.items():
                    results[key].append(value)
            df = pd.DataFrame.from_dict(results).set_index("date")
            df.index = pd.to_datetime(df.index, format="%Y-%m-%d")
            df.loc[:, "balance"] = df["net_pnl"].cumsum() + self._symbol_principal[symbol]

            daily_df_dict[symbol] = df
        return daily_df_dict


# 框架调用函数（依次代入品种代码，行情数据，信号数据，回测本金, 交易规则函数）
# 返回日频回测结果
def call_framework_func(symbol, ohlc_df, signal_df, market_value, pos_func):
    simple_bt = SimpleBacktest()
    simple_strategy = SimpleStrategy()
    simple_bt.set_symbol_data(symbol, ohlc_df, signal_df)
    #simple_bt.set_symbol_commission(symbol, o_cms,c_cms) # 设置开仓和平仓手续费
    simple_bt.set_symbol_principal(symbol, market_value)  # 设置每次交易所用资金量（假设固定）
    simple_bt.set_strategy(simple_strategy)
    daily_df_dict = simple_bt.run_backtests(pos_func)     # 运行回测，获取回测结果

    return daily_df_dict

In [None]:
# 1. 准备示例数据
dates = pd.date_range(start="2023-01-01", periods=5)
ohlc_data = pd.DataFrame({
    'open': [100, 102, 105, 103, 108],
    'high': [101, 104, 106, 105, 110],
    'low': [99, 101, 104, 102, 107],
    'close': [101, 103, 105, 104, 109],
    'tradingDay': dates.date
}, index=dates)

signal_data = pd.DataFrame({
    'signal': [1, 1, -1, 0, 1]  # 1: 做多，-1: 做空，0: 平仓
}, index=dates)

# 2. 定义仓位决策函数
def position_strategy(symbol: str, md_time: datetime, md_data: dict,
                      current_pos: int, user_data: dict) -> int:
    """简单趋势策略"""
    # 使用信号列决定目标仓位
    return md_data['signal']  # 直接返回信号值作为目标仓位

# 3. 运行回测
results = call_framework_func(
    symbol="STOCK_A",
    ohlc_df=ohlc_data,
    signal_df=signal_data,
    market_value=10000,  # 初始资金1万元
    pos_func=my_position_func
)

# 4. 查看结果
result_df = results["STOCK_A"]
print(result_df[['trade_count', 'total_pnl', 'balance']])

TypeError: tuple indices must be integers or slices, not str

In [8]:
from collections import defaultdict
from typing import List, Dict

class BacktestTrade:
    """ 记录交易信息 """
    def __init__(self, is_open: bool, is_long: bool, price: float, volume: int, trading_day: str):
        self.is_open = is_open  # 是否开仓
        self.is_long = is_long  # 多空方向
        self.price = price  # 交易价格
        self.volume = volume  # 交易手数
        self.trading_day = trading_day  # 交易日期

class DailyResult:
    """ 记录每日交易结果 """
    def __init__(self, trading_day: str):
        self.trading_day = trading_day  # 交易日
        self.trades: List[BacktestTrade] = []  # 当天所有交易
        self.close_price = None  # 收盘价
        self.pnl = 0  # 当天盈亏

    def calculate_pnl(self, position: int, open_price: float):
        """ 计算每日盈亏 """
        if self.close_price is not None:
            self.pnl = position * (self.close_price - open_price)  # 盈亏计算
        return self.pnl

class SimpleBacktest:
    """ 简化回测框架，适用于单一指数 """
    def __init__(self, initial_cash: float):
        self.current_pos = 0  # 当前持仓
        self.trades = []  # 交易记录
        self.daily_results = {}  # 每日回测结果
        self.cash = initial_cash  # 账户资金
        self.target_pos = 0  # 目标持仓
        self.position_func = None  # 用户提供的持仓决策函数
        self.trade_size = 0  # 默认最小交易单位

    def process_data(self, md_data: Dict):
        """ 处理市场数据并执行交易 """
        trading_day = md_data.get('tradingDay', None)
        if trading_day is None:
            print("Error: md_data missing 'tradingDay'")
            return

        if trading_day not in self.daily_results:
            self.daily_results[trading_day] = DailyResult(trading_day)

        self.daily_results[trading_day].close_price = md_data['close']

        # 交易逻辑
        while self.target_pos != self.current_pos:
            is_open = (self.target_pos > self.current_pos and self.current_pos >= 0) or (
                    self.target_pos < self.current_pos and self.current_pos <= 0)         # 是否开仓
            is_long = (self.target_pos < self.current_pos and self.current_pos > 0) or (
                    self.target_pos > self.current_pos and self.current_pos >= 0)         # 多or空

            trade_volume = max(int(self.cash // md_data['open']), 1)
            new_trade = BacktestTrade(
                is_open=is_open, is_long=is_long, price=md_data['open'],
                volume=trade_volume, trading_day=trading_day
            )
            self.trades.append(new_trade)
            self.daily_results[trading_day].trades.append(new_trade)

            # 更新持仓
            if self.target_pos != 0:
                self.current_pos = self.target_pos
            else:
                self.current_pos = 0

        # 计算下一步的目标持仓
        if self.position_func:
            self.target_pos = self.position_func(md_data, self.current_pos)

    def run_backtest(self, market_data: List[Dict]):
        """ 运行回测 """
        for md_data in market_data:
            self.process_data(md_data)

    def get_trades(self) -> List[BacktestTrade]:
        return self.trades

    def get_daily_results(self) -> Dict[str, DailyResult]:
        return self.daily_results


In [10]:
import yfinance as yf

sp500 = yf.Ticker("^GSPC")

# Get data for a specific date range
data = sp500.history(start="2020-01-01", end="2024-01-03")

In [12]:
N, k1, k2 = 7, 0.8, 0.8

signal = ATR_signal_cal(data,N,k1,k2)

                                  Open         High          Low        Close  \
Date                                                                            
2020-01-02 00:00:00-05:00  3244.669922  3258.139893  3235.530029  3257.850098   
2020-01-03 00:00:00-05:00  3226.360107  3246.149902  3222.340088  3234.850098   
2020-01-06 00:00:00-05:00  3217.550049  3246.840088  3214.639893  3246.280029   
2020-01-07 00:00:00-05:00  3241.860107  3244.909912  3232.429932  3237.179932   
2020-01-08 00:00:00-05:00  3238.590088  3267.070068  3236.669922  3253.050049   
2020-01-09 00:00:00-05:00  3266.030029  3275.580078  3263.669922  3274.699951   
2020-01-10 00:00:00-05:00  3281.810059  3282.989990  3260.860107  3265.350098   
2020-01-13 00:00:00-05:00  3271.129883  3288.129883  3268.429932  3288.129883   
2020-01-14 00:00:00-05:00  3285.350098  3294.250000  3277.189941  3283.149902   
2020-01-15 00:00:00-05:00  3282.270020  3298.659912  3280.689941  3289.290039   
2020-01-16 00:00:00-05:00  3

In [9]:
# 1. 创建示例市场数据
import pandas as pd

dates = pd.date_range(start="2023-01-01", periods=5)
ohlc_data = pd.DataFrame({
    'open': [100, 102, 105, 103, 108],
    'high': [99, 101, 104, 102, 107],
    'low': [101, 103, 106, 104, 109],
    'close': [101, 103, 105, 104, 109],
    'tradingDay': dates.astype(str)  # 确保交易日为字符串
})

market_data = ohlc_data.to_dict(orient='records')  # 转换成字典列表

# 2. 定义持仓策略
def position_strategy(md_data: Dict, current_pos: int) -> int:
    """ 简单持仓策略：信号=1时做多，-1时做空，0时平仓 """
    return 1 if md_data['close'] > md_data['open'] else 0

# 3. 运行回测
backtest = SimpleBacktest(initial_cash=10000)
backtest.position_func = position_strategy  # 绑定持仓决策函数
backtest.run_backtest(market_data)

# 4. 查看回测结果
trades = backtest.get_trades()
daily_results = backtest.get_daily_results()

for trade in trades:
    print(f"Trade - Day: {trade.trading_day}, Price: {trade.price}, Volume: {trade.volume}, Long: {trade.is_long}")

for day, result in daily_results.items():
    print(f"Day: {day}, PnL: {result.calculate_pnl(backtest.current_pos, result.trades[0].price if result.trades else 0)}")


Trade - Day: 2023-01-02, Price: 102, Volume: 98, Long: True
Trade - Day: 2023-01-04, Price: 103, Volume: 97, Long: True
Trade - Day: 2023-01-05, Price: 108, Volume: 92, Long: True
Day: 2023-01-01, PnL: 101
Day: 2023-01-02, PnL: 1
Day: 2023-01-03, PnL: 105
Day: 2023-01-04, PnL: 1
Day: 2023-01-05, PnL: 1
