In [32]:
import pandas as pd
import numpy as np
from pathlib import Path

class BondTransactionAnalyzer:
    def __init__(self, config):
        self.config = {
            "input_files": ["1.csv", "2.csv"],
            "result_folder": "chain_result2",
            "interest_rate_bond_threshold": 0.01,
            "other_bond_threshold": 0.2,
            "profit_threshold": 5000,  # 万元
            "remove_market_makers": False,
            "market_maker_list": "maker.csv",
            "required_columns": {
                "trade_id": "dl_cd",
                "settlement_date": "stlmnt_dt",
                "trade_time": "dl_tm",
                "trade_type": "trdng_md_cd",
                "bond_type": "bond_indctr_list",
                "bond_name": "bnds_nm",
                "seller": "slr_instn_cn_shrt_nm",
                "buyer": "byr_instn_cn_shrt_nm",
                "seller_trader": "slr_trdr_nm",
                "buyer_trader": "byr_trdr_nm",
                "clean_price": "net_prc",
                "yield": "yld_to_mrty",
                "nominal_amount": "nmnl_vol",
                "settlement_amount": "stlmnt_amnt",
                "trade_date": "txn_dt"
            }
        }
        self.config.update(config)
        
        self.all_transactions = None
        self.market_makers = set()
        Path(self.config["result_folder"]).mkdir(exist_ok=True)

    def load_data(self):
        if self.config["remove_market_makers"]:
            maker_df = pd.read_csv(self.config["market_maker_list"], encoding="GBK")
            self.market_makers = set(maker_df["机构名称"].values)
        
        dfs = []
        for f in self.config["input_files"]:
            try:
                df = pd.read_csv(f, encoding='utf-8')
            except UnicodeDecodeError:
                df = pd.read_csv(f, encoding='GB18030')  # 处理中文编码
            dfs.append(df)
        combined_df = pd.concat(dfs, ignore_index=True)
        self.all_transactions = combined_df
        
        self._preprocess_data()
        # print("all_transactions", self.all_transactions)
        
    def _preprocess_data(self):
        # 列重命名
        self.all_transactions.rename(columns={
            v: k for k, v in self.config["required_columns"].items()
        }, inplace=True)
        
        # 处理时间格式（带时区信息）
        self.all_transactions["trade_time"] = pd.to_datetime(
            self.all_transactions["trade_time"].str[:19]
        )
        
        # 转换金额单位为万元（假设原始单位是元）
        self.all_transactions["nominal_amount"] = self.all_transactions["nominal_amount"] / 10000
        self.all_transactions["settlement_amount"] = self.all_transactions["settlement_amount"] / 10000
        
        # 过滤有效交易类型
        valid_types = ['NDM']  # 根据数据中的trdng_md_cd列值
        self.all_transactions = self.all_transactions[
            self.all_transactions["trade_type"].isin(valid_types)
        ]
        
        if self.config["remove_market_makers"]:
            self.all_transactions = self.all_transactions[
                ~self.all_transactions["seller"].isin(self.market_makers) &
                ~self.all_transactions["buyer"].isin(self.market_makers)
            ]

    def analyze_transaction_chains(self):
        grouped = self.all_transactions.groupby("trade_date")
        for date_str, daily_transactions in grouped:
            print("Analyzing transaction chains", daily_transactions)
            date_folder = Path(self.config["result_folder"]) / date_str
            date_folder.mkdir(exist_ok=True)
            
            bond_groups = daily_transactions.groupby("bond_name")
            # print("Analyzing transaction chains", daily_transactions)
            # for bond_name, bond_transactions in bond_groups:
            #     self._process_bond_transactions(
            #         bond_name, bond_transactions, date_folder
            #     )

    def _process_bond_transactions(self, bond_name, transactions, output_dir):
        """处理单只债券的交易"""
        # 确定价格阈值
        bond_type = transactions["bond_type"].iloc[0]
        price_threshold = (
            self.config["interest_rate_bond_threshold"] 
            if bond_type in ["国债", "政策性金融债"] 
            else self.config["other_bond_threshold"]
        )
        
        # 识别可疑机构
        suspicious_institutions = self._find_suspicious_institutions(transactions)
        if not suspicious_institutions:
            return
        
        # 合并相同交易对手的交易
        consolidated_transactions = self._consolidate_duplicate_trades(transactions)
        
        # 追踪交易链
        results = []
        for institution in suspicious_institutions:
            chain_results = self._trace_institution_chains(
                institution, consolidated_transactions, price_threshold
            )
            results.extend(chain_results)
        
        # 分批保存结果
        if results:
            output_path = output_dir / f"{bond_name}_results.csv"
            pd.DataFrame(results).to_csv(output_path, index=False)

    def _find_suspicious_institutions(self, transactions):
        """识别存在可疑净头寸的机构"""
        # 计算各机构的净交易量和损益
        buyers = transactions.groupby("byr_instn_cn_shrt_nm").agg({
            "nmnl_vol": "sum",  # 券面总额
            "stlmnt_amnt": "sum"  # 结算金额
        })
        sellers = transactions.groupby("slr_instn_cn_shrt_nm").agg({
            "nmnl_vol": "sum",  # 券面总额
            "stlmnt_amnt": "sum"  # 结算金额
        })
        
        net_position = buyers.join(sellers, how="outer", lsuffix="_buy", rsuffix="_sell").fillna(0)
        net_position["net_amount"] = net_position["nmnl_vol_sell"] - net_position["nmnl_vol_buy"]
        net_position["net_profit"] = net_position["stlmnt_amnt_sell"] - net_position["stlmnt_amnt_buy"]
        
        # 筛选可疑机构
        threshold = self.config["profit_threshold"] * 10000  # 转换为元
        suspicious = net_position[
            (abs(net_position["net_profit"]) > threshold) &
            (net_position["net_amount"] != 0)
        ].index.tolist()
        
        return suspicious

    def _consolidate_duplicate_trades(self, transactions):
        """合并相同交易对手的重复交易"""
        group_keys = [
            "slr_instn_cn_shrt_nm",  # 卖方机构
            "byr_instn_cn_shrt_nm",  # 买方机构
            "slr_trdr_nm",  # 卖方交易员
            "byr_trdr_nm",  # 买方交易员
            "yld_to_mrty"  # 到期收益率
        ]
        grouped = transactions.groupby(group_keys)
        
        consolidated = []
        for (seller, buyer, s_trader, b_trader, yld), group in grouped:
            if len(group) > 1:
                combined = pd.Series({
                    "dl_cd": "|".join(group["dl_cd"]),  # 成交编号
                    "nmnl_vol": group["nmnl_vol"].sum(),  # 券面总额
                    "stlmnt_amnt": group["stlmnt_amnt"].sum(),  # 结算金额
                    "net_prc": group["net_prc"].mean(),  # 净价
                    **dict(zip(group_keys, [seller, buyer, s_trader, b_trader, yld]))
                })
                consolidated.append(combined)
            else:
                consolidated.append(group.iloc[0])
        
        return pd.DataFrame(consolidated)

    def _trace_institution_chains(self, institution, transactions, price_threshold):
        """追踪单个机构的交易链条"""
        # 正向追踪（上游）
        upstream_chains = self._trace_transaction_chain(
            institution, transactions, direction="upstream", 
            price_threshold=price_threshold
        )
        
        # 反向追踪（下游）
        downstream_chains = self._trace_transaction_chain(
            institution, transactions, direction="downstream",
            price_threshold=price_threshold
        )
        
        # 合并并分析链条
        combined_chains = self._analyze_combined_chains(
            upstream_chains, downstream_chains, institution
        )
        
        return combined_chains

    def _trace_transaction_chain(self, start_institution, transactions, 
                                direction="upstream", price_threshold=0.1):
        """追踪单个方向的交易链"""
        # 实现核心的链条追踪逻辑（此处需要根据原代码逻辑补充实现细节）
        # 返回追踪到的交易链信息
        return []  # 示例返回

    def _analyze_combined_chains(self, upstream, downstream, institution):
        """分析合并后的交易链条"""
        # 实现链条合并和损益计算逻辑
        return []  # 示例返回
    # 使用示例
if __name__ == "__main__":
    config = {
        "input_files": ["bond_2005496_2006_2402.csv"],
        "result_folder": "analysis_results",
        "profit_threshold": 1000,
        "remove_market_makers": False,
        "market_maker_list": "market_makers.csv"
    }
    
    analyzer = BondTransactionAnalyzer(config)
    analyzer.load_data()
    analyzer.analyze_transaction_chains()

Analyzing transaction chains      Unnamed: 0           trade_id  dt_cnfrm  trade_date          trade_time  \
108         108  CBT20200601301687       NaN  2020-06-01 2020-06-01 14:03:49   
109         109  CBT20200601301390       NaN  2020-06-01 2020-06-01 13:52:37   
110         110  CBT20200601303563       NaN  2020-06-01 2020-06-01 15:23:16   
111         111  CBT20200601302578       NaN  2020-06-01 2020-06-01 14:39:45   
112         112  CBT20200601301038       NaN  2020-06-01 2020-06-01 13:38:38   
113         113  CBT20200601302178       NaN  2020-06-01 2020-06-01 14:24:08   
114         114  CBT20200601301470       NaN  2020-06-01 2020-06-01 13:55:45   
115         115  CBT20200601302590       NaN  2020-06-01 2020-06-01 14:40:04   
116         116  CBT20200601302177       NaN  2020-06-01 2020-06-01 14:24:08   
117         117  CBT20200601302283       NaN  2020-06-01 2020-06-01 14:28:19   

                       bsns_tm  qt_rqst_cd  trd_rcrd_cd prdct_cd trade_type  \
108  2020-0