In [2]:
import pandas as pd
import numpy as np

class BreakoutLabeler:
    """
    Breakout event labeler for OHLC bar data.

    Every bar receives an integer label:
        1 -- upside breakout: close >= prior Donchian high + ATR buffer
        2 -- downside breakout: close <= prior Donchian low - ATR buffer
        0 -- anything else

    A breakout only counts when the preceding bars were in a Bollinger
    bandwidth "squeeze" (see ``_check_squeeze``).
    """

    def __init__(self,
                 donchian_period=30,    # Donchian channel lookback (bars)
                 atr_period=10,         # ATR smoothing span
                 atr_buffer=0.1,        # ATR multiple required beyond the channel
                 bb_period=10,          # Bollinger band rolling window
                 bb_std=1.5,            # Bollinger band std-dev multiplier
                 squeeze_lookback=100,  # window for the bandwidth quantile
                 squeeze_threshold=0.2, # quantile level that defines a squeeze
                 squeeze_bars=3):       # consecutive squeeze bars required

        self.donchian_period = donchian_period
        self.atr_period = atr_period
        self.atr_buffer = atr_buffer
        self.bb_period = bb_period
        self.bb_std = bb_std
        self.squeeze_lookback = squeeze_lookback
        self.squeeze_threshold = squeeze_threshold
        self.squeeze_bars = squeeze_bars

    def _calc_atr(self, df):
        """
        Compute the Average True Range.

        True range is the max of (high - low), |high - prev_close| and
        |low - prev_close|. On the first bar prev_close is NaN, so the row-wise
        max falls back to high - low. Smoothing is an EWM with
        ``span=atr_period`` (alpha = 2 / (span + 1)), not Wilder's 1/n smoothing.
        """
        high = df['high']
        low = df['low']
        prev_close = df['close'].shift(1)

        tr = pd.concat([
            high - low,
            (high - prev_close).abs(),
            (low - prev_close).abs()
        ], axis=1).max(axis=1)

        return tr.ewm(span=self.atr_period, adjust=False).mean()

    def _calc_bollinger_bandwidth(self, close):
        """
        Compute the Bollinger bandwidth (upper - lower) / |mid|.

        The first ``bb_period - 1`` values are NaN. A zero rolling mean yields
        NaN instead of +/-inf, so it cannot distort the squeeze quantiles.
        """
        mid = close.rolling(self.bb_period, min_periods=self.bb_period).mean()
        std = close.rolling(self.bb_period, min_periods=self.bb_period).std()

        upper = mid + self.bb_std * std
        lower = mid - self.bb_std * std
        # Mask a zero denominator to NaN rather than letting it become inf.
        bandwidth = (upper - lower) / mid.abs().where(mid != 0)

        return bandwidth

    def _check_squeeze(self, bbw):
        """
        Return a boolean Series: True where the previous ``squeeze_bars`` bars
        (excluding the current one) were all in a squeeze.

        A bar is "in squeeze" when its bandwidth is at or below the
        ``squeeze_threshold`` quantile of the preceding ``squeeze_lookback``
        bandwidth values.
        """
        # Quantile over strictly earlier bars, so the current bar is not part of its own threshold.
        bbw_shifted = bbw.shift(1)
        quantile = bbw_shifted.rolling(
            self.squeeze_lookback,
            min_periods=self.squeeze_lookback
        ).quantile(self.squeeze_threshold)

        # NaN comparisons are False; the notna() guard makes the warm-up period explicit.
        in_squeeze = (bbw <= quantile) & quantile.notna()

        consecutive = self._count_consecutive(in_squeeze)

        # Shift by one so the decision on bar t only uses squeeze state up to bar t-1.
        return consecutive.shift(1).fillna(0) >= self.squeeze_bars

    def _count_consecutive(self, bool_series):
        """
        For each position, return the length of the run of truthy values
        ending there (0 at falsy positions), keeping the input index.

        Vectorised: run length at i = i - (index of the latest falsy value at
        or before i), with -1 standing in for "no falsy value seen yet".
        """
        flags = bool_series.to_numpy(dtype=bool)
        positions = np.arange(len(flags))
        last_false = np.maximum.accumulate(np.where(flags, -1, positions))
        return pd.Series(positions - last_false, index=bool_series.index)

    def _calc_donchian(self, df):
        """
        Compute the Donchian channel from strictly earlier bars.

        Highs and lows are shifted by one bar before the rolling max/min, so
        the channel at bar t covers bars t-donchian_period .. t-1.
        Returns (high_channel, low_channel).
        """
        high_channel = df['high'].shift(1).rolling(
            self.donchian_period,
            min_periods=self.donchian_period
        ).max()

        low_channel = df['low'].shift(1).rolling(
            self.donchian_period,
            min_periods=self.donchian_period
        ).min()

        return high_channel, low_channel

    def _detect_breakout(self, df, atr, high_channel, low_channel, squeeze_ok):
        """
        Build the label array (0/1/2) from closes, channels, ATR and squeeze flags.

        The buffer uses the previous bar's ATR. Any NaN input on a row makes
        its comparisons False, so that row is labelled 0.
        """
        atr_buffer = atr.shift(1) * self.atr_buffer

        up_break = ((df['close'] >= high_channel + atr_buffer) & squeeze_ok).to_numpy()
        down_break = ((df['close'] <= low_channel - atr_buffer) & squeeze_ok).to_numpy()

        # Defensive: a bar flagged both ways is ambiguous and is labelled 0.
        both_break = up_break & down_break

        label = np.zeros(len(df), dtype=int)
        label[up_break & ~both_break] = 1
        label[down_break & ~both_break] = 2

        return label

    def label(self, df):
        """
        Label every bar of ``df`` with a breakout class.

        Parameters:
            df: DataFrame with at least ``time``, ``high``, ``low`` and
                ``close`` columns. Other columns (e.g. open, volume) are
                carried through but not used. The input frame is not modified.

        Returns:
            A new DataFrame sorted by ``time`` with a fresh RangeIndex. Rows
            that share a timestamp keep their input order. The added columns
            are ``ATR``, ``BBW``, ``squeeze_ok``, ``don_high``, ``don_low``
            and ``label``.

        Raises:
            KeyError: if a required column is missing.
        """
        required = ('time', 'high', 'low', 'close')
        missing = [col for col in required if col not in df.columns]
        if missing:
            raise KeyError(f"missing required columns: {missing}")

        # sort_values already returns a new frame, so no extra copy() is needed.
        # mergesort is stable, so bars with duplicate timestamps keep their order.
        df = df.sort_values('time', kind='mergesort').reset_index(drop=True)

        # 1. Technical indicators
        df['ATR'] = self._calc_atr(df)
        df['BBW'] = self._calc_bollinger_bandwidth(df['close'])

        # 2. Squeeze state
        df['squeeze_ok'] = self._check_squeeze(df['BBW'])

        # 3. Donchian channel
        df['don_high'], df['don_low'] = self._calc_donchian(df)

        # 4. Breakout detection and labels
        df['label'] = self._detect_breakout(
            df,
            df['ATR'],
            df['don_high'],
            df['don_low'],
            df['squeeze_ok']
        )

        return df


def main(csv_path="/home/mengxiaosen/mxs/workspace/Quantrade/data/NEIRO/NEIROUSDT_2h_data.csv",
         output_path="2h_kline_labeled.csv"):
    """
    Load a bar CSV, label breakout events, print a summary and save the result.

    Parameters:
        csv_path: input CSV with a ``time`` column plus OHLC columns. Header
            names are stripped and lower-cased before use.
        output_path: destination CSV for the labelled frame (written without
            the index).
    """
    # Load data
    df = pd.read_csv(csv_path)

    # Normalise column names
    df.columns = [c.strip().lower() for c in df.columns]
    df['time'] = pd.to_datetime(df['time'])

    # Build the labeler and run it
    labeler = BreakoutLabeler()

    df_labeled = labeler.label(df)

    # Show breakout events
    events = df_labeled[df_labeled['label'] != 0]
    print("\n=== 突破事件 ===")
    print(events[['time', 'close', 'don_high', 'don_low', 'ATR', 'label']])

    # Summary counts
    print(f"\n总K线数: {len(df_labeled)}")
    print(f"上破事件: {(df_labeled['label']==1).sum()}")
    print(f"下破事件: {(df_labeled['label']==2).sum()}")

    # Save the result
    df_labeled.to_csv(output_path, index=False)
    print(f"\n已保存到: {output_path}")


if __name__ == "__main__":
    main()



=== 突破事件 ===
                    time     close  don_high   don_low       ATR  label
114  2024-09-26 06:00:00  0.001080  0.001376  0.001114  0.000070      2
242  2024-10-06 22:00:00  0.001147  0.001083  0.000915  0.000053      1
572  2024-11-03 10:00:00  0.001337  0.001698  0.001366  0.000050      2
593  2024-11-05 04:00:00  0.001210  0.001473  0.001228  0.000048      2
654  2024-11-10 06:00:00  0.002595  0.002510  0.002022  0.000150      1
780  2024-11-20 18:00:00  0.001885  0.002300  0.001954  0.000094      2
1069 2024-12-14 20:00:00  0.001601  0.001811  0.001622  0.000039      2
1106 2024-12-17 22:00:00  0.001473  0.001677  0.001515  0.000046      2
1211 2024-12-26 16:00:00  0.000951  0.001085  0.000971  0.000024      2
1409 2025-01-12 04:00:00  0.000832  0.000829  0.000751  0.000018      1
1718 2025-02-06 22:00:00  0.000327  0.000404  0.000341  0.000015      2
2227 2025-03-21 08:00:00  0.000273  0.000272  0.000243  0.000006      1
2290 2025-03-26 14:00:00  0.000306  0.000296  0.00