In [9]:
import pandas as pd
import numpy as np
from collections import defaultdict
import pandas as pd

# 配置参数
CONFIG = {
    "columns": {
        "trade_id": 'dl_cd',
        "trade_time": "dl_tm",
        'amnt':'amnt',
        'acrd_intrst':'acrd_intrst',
        "trade_date": 'txn_dt',
        "settle_date": "stlmnt_dt",
        "nmnl_vol": 'nmnl_vol',  # Or nmnl_vol depending on the field
        "seller": "byr_trd_acnt_cn_shrt_nm",  # Could be different depending on the data
        "buyer": "slr_trd_acnt_cn_shrt_nm",  # Could be different depending on the data
        "seller_trader": "byr_trdr_nm",  # 买方交易员
        "buyer_trader": "slr_trdr_nm",  # 卖方交易员
        "price": "net_prc",  # 净价（元）
        "yield": 'yld_to_mrty',  # 到期收益率（%）
        "trade_method": "trdng_mthd_cd",  # 交易方式
        "bond_name": "bnds_nm",  # 债券名称
    },
    "date_format": "%Y-%m-%d",
    "datetime_format": "%Y/%m/%d",
    "datetime_format_full": "%Y-%m-%d hh:mm:ss",
    
    # 业务参数
    "find_threshold_rate_bond": 0.01,
    "find_threshold_other_bond": 0.2,
    "profit_loss_threshold": 5000,
    "remove_market_maker": True,
    "save_directory": "chain_result2",
    "market_price_threshold": 0.05,
    "price_ratio_min": 0.015,
    "single_price_boundary_min": 2,
    "target_price_ratio_min": 0.05,
    "path_length": 3
}

# 读取所有交易数据
def read_trade_data():
    """
    读取多个 CSV 文件并合并为一个 DataFrame。
    """
    data_df = pd.read_csv("./static/bond_dtl_10_08.csv")
    data_df.drop("Unnamed: 0", axis=1, inplace=True)
    data_df = data_df.reset_index(drop=True)
    return data_df

# 处理时间和日期字段
def process_time_and_date(all_trade_data):
    """
    处理交易时间和结算日期。
    """
    # 交易时间列：转换为 datetime 类型
    all_trade_data[CONFIG["columns"]["trade_time"]] = pd.to_datetime(
        all_trade_data[CONFIG["columns"]["trade_time"]].str[:19], format="%Y-%m-%d %H:%M:%S"
    )
    return all_trade_data

def select_and_rename_all_columns(all_trade_data):
    """
    提取和重命名所有指定的列。
    """
    selected_columns = {value: key for key, value in CONFIG["columns"].items()}
    return all_trade_data[list(selected_columns.keys())]

# 过滤交易方式
def filter_trade_method(all_trade_data):
    """
    过滤只保留 'RFQ' 或 'Negotiate' 方式的交易。
    """
    return all_trade_data[all_trade_data[CONFIG["columns"]["trade_method"]].isin(['RFQ', 'Negotiate'])]

def filter_strange_trade(df, threshold=0.1):
    """
    筛选偏离度较高的交易，偏离度基于 buyer 自身的市场价格。

    参数：
        df (pd.DataFrame): 包含交易数据的 DataFrame。
        threshold (float): 偏离度的阈值，默认是 10% (0.1)。

    返回：
        pd.DataFrame: 包含偏离度超过阈值的异常交易。
    """
    # 生成市场价格映射
    market_price_mapping = df.groupby(CONFIG["columns"]["buyer"])[CONFIG["columns"]["price"]].mean().to_dict()

    # 映射市场价格
    df['market_price'] = df[CONFIG["columns"]["buyer"]].map(market_price_mapping)

    # 计算价格偏离度
    df['price_dev'] = (df[CONFIG["columns"]["price"]] - df['market_price']) / df['market_price']

    # 筛选偏离度超出指定阈值的交易
    abnormal_trades = df[(df['price_dev'] > threshold) | (df['price_dev'] < -threshold)]

    return abnormal_trades

# 主函数：读取和处理交易数据
def main():
    # 读取数据
    all_trade_data = read_trade_data()
    
    # 处理时间和日期
    all_trade_data = process_time_and_date(all_trade_data)
    
    # 过滤交易方式
    all_trade_data = filter_trade_method(all_trade_data)
    
    # 提取和重命名列
    all_trade_data = select_and_rename_all_columns(all_trade_data)

    return all_trade_data

# 获取金额组合和对应的交易索引
def get_amount_dict(transaction_indices, trade_data):
    """
    根据交易索引获取金额组合和对应的交易索引。
    
    参数:
        transaction_indices (list): 交易索引列表。
        trade_data (pd.DataFrame): 包含交易数据的 DataFrame。
    
    返回:
        dict: 以金额为键，交易索引集合为值的字典。
    """
    amount_dict = defaultdict(set)
    if not set(transaction_indices).issubset(set(trade_data.index)):
        raise ValueError("Some indices in transaction_indices are not valid.")

    for idx in transaction_indices:
        amt = trade_data.at[idx, CONFIG["columns"]["amnt"]]
        amount_dict[amt].add(idx)
    return amount_dict

# 根据方向选择交易链的下一步
def choose_next_edge(index, amt, trade_data, direction):
    """
    根据方向选择交易链的下一步。
    
    参数:
        index (int): 当前交易的索引。
        amt (float): 当前交易的金额。
        trade_data (pd.DataFrame): 包含交易数据的 DataFrame。
        direction (str): 追踪方向，"forward" 或 "backward"。
    
    返回:
        list: 下一步交易的索引列表。
    """
    seller_col = CONFIG["columns"]["seller"]
    buyer_col = CONFIG["columns"]["buyer"]

    if direction == "forward":  # 正向选择
        condition = (
            (trade_data[CONFIG["columns"]["amnt"]] == amt) & 
            (trade_data[seller_col] == trade_data.at[index, buyer_col])
        )
    elif direction == "backward":  # 反向选择
        condition = (
            (trade_data[CONFIG["columns"]["amnt"]] == amt) & 
            (trade_data[buyer_col] == trade_data.at[index, seller_col])
        )
    else:
        raise ValueError("Invalid direction. Use 'forward' or 'backward'.")

    return trade_data[condition].index.tolist()

def trace_trade_chains(trade_data, current_trade_cd, direction, path=None, path_ids=None, max_depth=5, visited_trades=None):
    """
    递归追踪交易链。

    参数:
        trade_data (pd.DataFrame): 包含交易数据的 DataFrame。
        current_trade_cd (str): 当前交易的交易代码 (Trade CD)。
        direction (str): "forward" 或 "backward"。
        path (list): 当前路径。
        path_ids (list): 当前路径中的交易 ID。
        max_depth (int): 最大递归深度。
        visited_trades (set): 已访问的交易代码集合，用于避免重复计算。

    返回:
        tuple: 路径 (list) 和路径中的交易 ID (list)。
    """
    if path is None:
        path = []
    if path_ids is None:
        path_ids = []
    if visited_trades is None:
        visited_trades = set()

    # 如果当前交易已经被访问过，直接返回
    if current_trade_cd in visited_trades:
        return path, path_ids

    # 标记当前交易为已访问
    visited_trades.add(current_trade_cd)

    # 获取当前交易的机构信息
    seller_col = CONFIG["columns"]["seller"]
    buyer_col = CONFIG["columns"]["buyer"]
    trade_cd_col = CONFIG["columns"]["trade_id"]

    current_trade = trade_data[trade_data[trade_cd_col] == current_trade_cd]
    if current_trade.empty:
        raise ValueError(f"Trade with CD {current_trade_cd} not found.")

    current_seller = current_trade.iloc[0][seller_col]
    current_buyer = current_trade.iloc[0][buyer_col]

    # 根据方向添加路径信息
    if direction == "forward":
        current_inst = current_seller
    elif direction == "backward":
        current_inst = current_buyer
    else:
        raise ValueError("Invalid direction. Use 'forward' or 'backward'.")

    path.append(current_inst)
    path_ids.append(current_trade_cd)  # 添加当前交易代码到路径

    # 检查递归深度
    if len(path) >= max_depth:
        return path, path_ids

    if direction == "forward":
        relevant_trades = trade_data[trade_data[seller_col] == current_buyer]
    elif direction == "backward":
        relevant_trades = trade_data[trade_data[buyer_col] == current_seller]

    if relevant_trades.empty:  # 如果没有相关交易，结束递归
        return path, path_ids

    all_paths = []
    for idx, trade in relevant_trades.iterrows():
        next_trade_cd = trade[trade_cd_col]
        sub_path, sub_path_ids = trace_trade_chains(
            trade_data, next_trade_cd, direction, path.copy(), path_ids.copy(), max_depth, visited_trades.copy()
        )
        all_paths.append((sub_path, sub_path_ids))

    # 返回最长路径 (可根据业务需求更改逻辑)
    longest_path = max(all_paths, key=lambda x: len(x[0]))
    return longest_path

# 双向追踪交易链
def trace_trade_chains_bidirection(trade_data, start_inst, max_depth=5):
    """
    双向追踪交易链。
    
    参数:
        trade_data (pd.DataFrame): 包含交易数据的 DataFrame。
        start_inst (str): 起始机构。
        max_depth (int): 最大递归深度。

    返回:
        dict: 包含正向和反向路径及对应交易 ID 的字典。
    """
    reverse_path, reverse_path_ids = trace_trade_chains(
        trade_data, start_inst, direction="backward", max_depth=max_depth
    )
    reverse_path = reverse_path[::-1]  # 反转路径，使其从起始机构向后排列

    forward_path, forward_path_ids = trace_trade_chains(
        trade_data, start_inst, direction="forward", max_depth=max_depth
    )
    
    combined_path = reverse_path[:-1] + forward_path  # 合并路径，去掉重复的中间节点
    combined_path_ids = reverse_path_ids + forward_path_ids

    return {
        "path": combined_path,
        "path_ids": combined_path_ids
    }

# 测试整合
if __name__ == "__main__":
    # 假设 main() 加载交易数据
    df = main()
    print(df[0:1000])

    # 测试起始机构dl_cd
    start_institution = "CBT20241008313725"

    forward_chain = trace_trade_chains(df[0:1000], start_institution, direction="forward", path=None, path_ids=None, max_depth=5)

    print(forward_chain)



  data_df = pd.read_csv("./static/bond_dtl_10_08.csv")


                  dl_cd               dl_tm         amnt  acrd_intrst  \
0     CBT20241008313979 2024-10-08 16:28:37   49693900.0      0.42989   
1     CBT20241008315298 2024-10-08 16:53:57   54461500.0      0.65014   
3     CBT20241008314807 2024-10-08 16:40:47   10086230.0      1.11079   
5     CBT20241008313725 2024-10-08 16:25:35   61442400.0      0.30815   
6     CBT20241008314330 2024-10-08 16:33:30  148013850.0      0.15586   
...                 ...                 ...          ...          ...   
1096  CBT20241008214904 2024-10-08 17:14:50   40140840.0      0.53074   
1097  CBT20241008214815 2024-10-08 17:02:25   99178500.0      1.11644   
1099  CBT20241008315341 2024-10-08 17:00:16   10055930.0      2.82869   
1100  CBT20241008214802 2024-10-08 16:59:44    9794100.0      0.09014   
1102  CBT20241008315334 2024-10-08 16:59:03  101974800.0      1.98904   

          txn_dt   stlmnt_dt     nmnl_vol           byr_trd_acnt_cn_shrt_nm  \
0     2024-10-08  2024-10-09   50000000.0   

In [None]:
import pandas as pd
import numpy as np

# 读取交易数据文件
# alldf = pd.read_excel("C:\\Users\\sys\\Downloads\\190210新.xlsx", encoding="GBK")
data_files = []
for i in range(1, 3):
    data_files.append(pd.read_csv(f"{i}.csv"))
all_trade_data = pd.concat(data_files)
data_files = []
all_trade_data.drop("Unnamed: 0", axis=1, inplace=True)
all_trade_data = all_trade_data.reset_index()
all_trade_data.drop("index", axis=1, inplace=True)
all_trade_data.columns = ["交易编号", "交易日期", "交易时间", "交易方式", "债券类型", "债券名称", 
                          "买方", "买方交易员", "卖方", "卖方交易员", "净价（元）", "到期收益率（%）", 
                          "券面总额（万元）", "交易金额（元）", "结算日", "结算金额（元）"]

# 参数设置
trade_id_col = '交易编号'  # 交易编号
trade_time_col = "交易时间"  # 交易时间
settle_date_col = "结算日"  # 结算日期
bond_type_col = "债券类型"  # 债券类型
trade_date_col = "成交日"  # 成交日期
face_value_col = '券面总额（万元）'  # 券面总额
face_value_field = '券面总额（万元）'  # 券面总额字段
seller_col = "卖方"  # 卖方
buyer_col = "买方"  # 买方
seller_trader_col = "卖出方交易员"  # 卖出方交易员
buyer_trader_col = "买入方交易员"  # 买入方交易员
price_col = "净价（元）"  # 净价
settle_amount_col = "结算金额（元）"  # 结算金额
yield_col = '到期收益率（%）'  # 到期收益率
trade_method_col = "交易方式"  # 交易方式
bond_name_col = "债券名称"  # 债券名称

date_format = "%Y-%m-%d"  # 日期格式
datetime_format = "%Y/%m/%d"  # 日期格式1
datetime_format_full = "%Y-%m-%d hh:mm:ss"  # 完整的时间格式

# 参数设置
find_threshold_rate_bond = 0.01  # 利率债价差阈值
find_threshold_other_bond = 0.2  # 非利率债价差阈值
profit_loss_threshold = 5000  # 盈利亏损阈值
remove_market_maker = True  # 是否去掉做市商
save_directory = "chain_result2"  # 结果存储路径

market_price_threshold = 0.05  # 市场价格阈值
price_ratio_min = 0.015  # 最小价格差比率
single_price_boundary_min = 2  # 单个价格界限最小值
target_price_ratio_min = 0.05  # 目标价格比率最小值
path_length = 3  # 路径长度


# 处理成交时间和结算日期
all_trade_data[trade_date_col] = all_trade_data["交易时间"].str[0:4] + all_trade_data["交易时间"].str[5:7] + all_trade_data["交易时间"].str[8:10]
all_trade_data[trade_time_col] = pd.to_datetime(all_trade_data["交易时间"].str[0:19], format="%Y-%m-%d %H:%M:%S")
all_trade_data = all_trade_data[all_trade_data[trade_method_col].isin(['RFQ', 'Negotiate'])]


In [1]:
import pandas as pd
from collections import defaultdict

# 获取金额组合和对应的交易索引
def get_amount_dict(transaction_indices, trade_data):
    amount_dict = defaultdict(set)
    for idx in transaction_indices:
        amt = trade_data.at[idx, '交易金额']
        amount_dict[amt].add(idx)
    return amount_dict

# 根据方向选择交易链
def choose_next_edge(index, amt, trade_data, direction):
    next_edges = []
    if direction == "f":  # 正向选择
        next_edges = trade_data[(trade_data['交易金额'] == amt) & (trade_data['卖方'] == trade_data.at[index, '买方'])].index.tolist()
    elif direction == "b":  # 反向选择
        next_edges = trade_data[(trade_data['交易金额'] == amt) & (trade_data['买方'] == trade_data.at[index, '卖方'])].index.tolist()
    return next_edges

# 追踪交易链，递归执行
def trace_trade_chains(trade_data, start_inst, direction, path=None, path_id=None, max_depth=5):
    if path is None:
        path = []
    if path_id is None:
        path_id = []
    
    path.append(start_inst)
    
    # 找到当前机构的相关交易
    current_trades = trade_data[trade_data['卖方'] == start_inst] if direction == "f" else trade_data[trade_data['买方'] == start_inst]
    import pandas as pd
from collections import defaultdict

# 获取金额组合和对应的交易索引
def get_amount_dict(transaction_indices, trade_data):
    """
    根据交易索引获取每个金额对应的交易索引集合。
    """
    amount_dict = defaultdict(set)
    for idx in transaction_indices:
        amt = trade_data.at[idx, '交易金额']
        amount_dict[amt].add(idx)
    return amount_dict

# 根据交易方向选择下一步
def choose_next_edge(index, amt, trade_data, direction):
    """
    根据交易金额和方向选择下一步交易索引。
    """
    if direction == "f":  # 正向
        return trade_data[(trade_data['交易金额'] == amt) & 
                          (trade_data['卖方'] == trade_data.at[index, '买方'])].index.tolist()
    elif direction == "b":  # 反向
        return trade_data[(trade_data['交易金额'] == amt) & 
                          (trade_data['买方'] == trade_data.at[index, '卖方'])].index.tolist()
    return []


# 双向追踪交易链
def trace_trade_chains_bidirection(trade_data, start_inst, max_depth=5):
    """
    双向追踪交易链，包括正向和反向。
    """
    # 反向追踪
    reverse_paths, reverse_path_ids = trace_trade_chains_iterative(trade_data, start_inst, "b", max_depth)
    # 正向追踪
    forward_paths, forward_path_ids = trace_trade_chains_iterative(trade_data, start_inst, "f", max_depth)

    # 合并路径
    combined_paths = reverse_paths + forward_paths
    combined_path_ids = reverse_path_ids + forward_path_ids

    return combined_paths, combined_path_ids

# 示例调用
# trade_data = pd.DataFrame(...)  # 交易数据 DataFrame
# start_inst = "机构A"  # 起始机构
# paths, path_ids = trace_trade_chains_bidirection(trade_data, start_inst)
# print("交易链路径:", paths)
# print("交易链ID:", path_ids)

    if current_trades.empty or len(path) > max_depth:
        return path, path_id
    
    for idx, trade in current_trades.iterrows():
        next_inst = trade['买方'] if direction == "f" else trade['卖方']
        path, path_id = trace_trade_chains(trade_data, next_inst, direction, path.copy(), path_id.copy(), max_depth)
        
        # 在每条路径中加入唯一标识符
        path_id.append(trade['交易编号'])
    
    return path, path_id

# 双向追踪交易链
def trace_trade_chains_bidirection(trade_data, start_inst, max_depth=5):
    # 反向追踪
    reverse_path, reverse_path_id = trace_trade_chains(trade_data, start_inst, "b", max_depth=max_depth)
    
    # 正向追踪
    forward_path, forward_path_id = trace_trade_chains(trade_data, start_inst, "f", max_depth=max_depth)
    
    # 合并反向和正向路径
    return reverse_path + forward_path, reverse_path_id + forward_path_id



# 执行双向追踪

paths, path_ids = trace_trade_chains_bidirection(trade_data, start_inst)

print("交易链路径:", paths)
print("交易链ID:", path_ids)


交易链路径: ['A', 'D', 'C', 'B', 'A', 'D', 'A', 'B', 'A', 'B', 'C', 'D', 'A', 'B', 'C', 'D', 'C']
交易链ID: [4, 1, 2, 5, 3, 6, 4, 1, 5, 4, 3, 2, 6, 1, 5]


In [None]:
import pandas as pd
from collections import defaultdict

# 获取金额组合和对应的交易索引
def get_amount_dict(transaction_indices, trade_data):
    amount_dict = defaultdict(set)
    for idx in transaction_indices:
        amt = trade_data.at[idx, '交易金额']
        amount_dict[amt].add(idx)
    return amount_dict

# 根据方向选择交易链
def get_next_edges(index, amt, trade_data, direction):
    """获取下一个边，基于方向和交易金额"""
    if direction == "f":  # 正向选择
        return trade_data[(trade_data['交易金额'] == amt) & (trade_data['卖方'] == trade_data.at[index, '买方'])].index.tolist()
    elif direction == "b":  # 反向选择
        return trade_data[(trade_data['交易金额'] == amt) & (trade_data['买方'] == trade_data.at[index, '卖方'])].index.tolist()
    return []

# 追踪交易链（通用函数，支持正向和反向）
def trace_trade_chain_for_inst(trade_data, start_inst, direction, path=None, path_id=None, max_depth=5):
    """
    追踪某个机构的交易链，根据方向递归追踪交易路径。
    direction: "f" 表示正向，"b" 表示反向
    """
    if path is None:
        path = []
    if path_id is None:
        path_id = []
    
    path.append(start_inst)
    
    # 获取当前机构相关的交易
    current_trades = trade_data[trade_data['卖方'] == start_inst] if direction == "f" else trade_data[trade_data['买方'] == start_inst]
    
    if current_trades.empty or len(path) > max_depth:
        return path, path_id
    
    # 递归追踪
    for idx, trade in current_trades.iterrows():
        next_inst = trade['买方'] if direction == "f" else trade['卖方']
        new_path, new_path_id = trace_trade_chain_for_inst(
            trade_data, next_inst, direction, path.copy(), path_id.copy(), max_depth)
        
        # 在每条路径中加入唯一标识符
        new_path_id.append(trade['交易编号'])
        path, path_id = new_path, new_path_id
    
    return path, path_id

# 双向追踪交易链（统一调用正向和反向追踪）
def trace_trade_chains_bidirection(trade_data, start_inst, max_depth=5):
    """
    进行双向追踪，返回正向和反向的交易链
    """
    # 反向追踪
    reverse_path, reverse_path_id = trace_trade_chain_for_inst(trade_data, start_inst, "b", max_depth=max_depth)
    
    # 正向追踪
    forward_path, forward_path_id = trace_trade_chain_for_inst(trade_data, start_inst, "f", max_depth=max_depth)
    
    # 合并反向和正向路径
    return reverse_path + forward_path, reverse_path_id + forward_path_id

# 示例数据框
data = {
    '卖方': ['A', 'B', 'C', 'D', 'A', 'B'],
    '买方': ['B', 'C', 'D', 'A', 'C', 'D'],
    '交易金额': [100, 100, 100, 100, 200, 200],
    '交易编号': [1, 2, 3, 4, 5, 6]
}
trade_data = pd.DataFrame(data)

# 执行双向追踪
start_inst = 'A'
paths, path_ids = trace_trade_chains_bidirection(trade_data, start_inst)

print("交易链路径:", paths)
print("交易链ID:", path_ids)


In [None]:
import pandas as pd
import numpy as np

# 读取机构数据，若需要
if remove_maker:
    institution_df = pd.read_csv("maker.csv", encoding="GBK")
    institution_set = set(institution_df["机构名称"].values)
else:
    institution_set = set()

# 数据清理：去除空值并筛选特定交易方式
alldf = alldf.dropna()
alldf = alldf[alldf[transaction_method_column].isin(['RFQ', 'Negotiate'])]

# 初始化统计信息
result_data = []
summary_stats1 = []
summary_stats2 = []
new_group_id = 0
max_pair_count = 0

# 定义数据列名称
data_columns = [
    transaction_number_column, date_column, transaction_time_column, amount_column, 
    sender_column, receiver_column, sender_account_column, receiver_account_column, 
    price_column, transaction_method_column, trade_date_column, trade_price_column, rate_column
]

output_columns = [
    transaction_number_column, date_column, transaction_time_column, amount_column, 
    receiver_column, sender_column, receiver_account_column, sender_account_column, 
    price_column, transaction_method_column, trade_price_column, rate_column, 
    '路径标识', '路径端点', '方向', '传递交易量', '端点交易量', '母路径', '目标交易员', '目标机构', 
    '目标债券', '目标日期', '目标分组', '交易损益'
]

# 按日期分组数据
date_groups = alldf.groupby(date_column)

# 遍历每一天的交易数据
for current_date, daily_data in date_groups:
    print(f"开始处理 {str(current_date)[:10]} 数据 .....")
    
    # 将日期格式转为正确格式
    daily_data[date_column] = pd.to_datetime(daily_data[date_column], format="%Y-%m-%d")
    daily_data[trade_date_column] = pd.to_datetime(daily_data[trade_date_column], format=date_format)
    
    # 按债券类型分组
    bond_groups = daily_data.groupby(bond_column)

    # 遍历每种债券的交易数据
    for bond_name, bond_data in bond_groups:
        bond_data = bond_data.copy()
        bond_type = bond_data.head(1)[bond_type_column].values[0]

        # 根据债券类型设置阈值
        threshold = low_interest_rate_threshold if bond_type in ["国债", "政策性金融债"] else other_bond_threshold
        
        # 只保留需要的列
        bond_data = bond_data[data_columns]

        # 获取同时作为发送方和接收方的机构交集，并排除指定的机构
        intersecting_institutions = set(bond_data[sender_column].unique()).intersection(bond_data[receiver_column].unique()) - institution_set
        
        # 按发送方和接收方分组
        buy_groups = bond_data.groupby([sender_column])
        sell_groups = bond_data.groupby([receiver_column])

        transaction_institutions = []
        net_positions = []

        # 计算每个机构的净交易量和价格差
        for institution in intersecting_institutions:
            buy_data = buy_groups.get_group(institution)
            sell_data = sell_groups.get_group(institution)
            total_buy_amount = buy_data[amount_column].sum()
            total_sell_amount = sell_data[amount_column].sum()
            
            # 如果买入量等于卖出量，计算价格差
            if total_buy_amount == total_sell_amount:
                net_position = buy_data[trade_price_column].sum() - sell_data[trade_price_column].sum()
                
                # 如果价格差超过设定的阈值，则记录此机构
                if abs(net_position) > amount_threshold:
                    transaction_institutions.append(institution)
                    net_positions.append(net_position)

        # 如果有符合条件的机构，进一步处理数据
        if transaction_institutions:
            daily_data_copy = bond_data.copy()
            grouped_data = daily_data_copy.groupby([sender_column, receiver_column, sender_account_column, receiver_account_column, rate_column])

            additional_data = []
            data_set = []

            # 计算每个组合的统计数据
            for group_key, group_data in grouped_data:
                if len(group_data) > 1:
                    summary_data = group_data.head(1).copy()
                    summary_data[transaction_number_column] = "-".join(group_data[transaction_number_column].values)
                    summary_data[amount_column] = group_data[amount_column].sum()
                    summary_data[trade_price_column] = group_data[trade_price_column].sum()
                    data_set += list(group_data.index)
                    additional_data.append(np.array(summary_data.values))
            data_set = set(data_set)

            # 如果有需要合并的数据
            if additional_data:
                additional_data_frame = pd.DataFrame(np.concatenate(additional_data), columns=data_columns)

            # 处理符合条件的每个机构
            for institution_index, institution in enumerate(transaction_institutions):
                net_position = net_positions[institution_index]
                
                # 找到该机构对应的数据
                institution_data_set = set(daily_data_copy[(daily_data_copy[sender_column] == institution) | (daily_data_copy[receiver_column] == institution)].index).intersection(data_set)
                
                # 删除该机构的数据并合并其他数据
                daily_data_copy.drop(institution_data_set, inplace=True)
                if additional_data:
                    daily_data_copy = pd.concat([daily_data_copy, additional_data_frame])
                
                daily_data_copy.reset_index(drop=True, inplace=True)

                # 进行双向交易链追踪
                trade_data, pair_count, stats1, stats2 = trace_trade_chains_bidirection(daily_data_copy, institution, max_pair_count)
                max_pair_count = pair_count

                # 添加目标信息
                trade_data["目标机构"] = institution
                trade_data["目标债券"] = bond_name
                trade_data["目标日期"] = current_date
                trade_data["目标分组"] = new_group_id
                trade_data["交易损益"] = net_position
                
                # 更新统计信息
                stats1["目标机构"] = institution
                stats1["目标债券"] = bond_name
                stats1["目标日期"] = current_date
                stats1["目标分组"] = new_group_id
                
                stats2["目标机构"] = institution
                stats2["目标债券"] = bond_name
                stats2["目标日期"] = current_date
                stats2["目标分组"] = new_group_id
                
                # 更新组ID
                new_group_id += 1
                result_data.append(trade_data)
                summary_stats1.append(stats1)
                summary_stats2.append(stats2)

    # 输出每个日期的结果
    if result_data:
        pd.concat(result_data)[output_columns].to_csv(f"{save_folder}/{str(current_date)[:10]}_result.csv")
        pd.concat(summary_stats1).to_csv(f"{save_folder}/{str(current_date)[:10]}_s1.csv")
        pd.concat(summary_stats2).to_csv(f"{save_folder}/{str(current_date)[:10]}_s2.csv")


In [None]:
import pandas as pd
import numpy as np

def read_institution_data(remove_maker, institution_file="maker.csv"):
    """读取机构数据，若需要"""
    if remove_maker:
        institution_df = pd.read_csv(institution_file, encoding="GBK")
        return set(institution_df["机构名称"].values)
    return set()

def clean_data(df, transaction_method_column, valid_methods=['RFQ', 'Negotiate']):
    """数据清理：去除空值并筛选特定交易方式"""
    df = df.dropna()
    df = df[df[transaction_method_column].isin(valid_methods)]
    return df

def get_threshold(bond_type, low_interest_rate_threshold, other_bond_threshold):
    """根据债券类型返回阈值"""
    return low_interest_rate_threshold if bond_type in ["国债", "政策性金融债"] else other_bond_threshold

def calculate_net_positions(buy_groups, sell_groups, amount_column, trade_price_column, amount_threshold):
    """计算每个机构的净交易量和价格差"""
    transaction_institutions = []
    net_positions = []

    for institution in buy_groups.keys():
        buy_data = buy_groups.get_group(institution)
        sell_data = sell_groups.get_group(institution)
        total_buy_amount = buy_data[amount_column].sum()
        total_sell_amount = sell_data[amount_column].sum()
        
        # 如果买入量等于卖出量，计算价格差
        if total_buy_amount == total_sell_amount:
            net_position = buy_data[trade_price_column].sum() - sell_data[trade_price_column].sum()
            
            # 如果价格差超过设定的阈值，则记录此机构
            if abs(net_position) > amount_threshold:
                transaction_institutions.append(institution)
                net_positions.append(net_position)

    return transaction_institutions, net_positions

def process_grouped_data(grouped_data, data_columns):
    """处理分组后的数据，计算每个组合的统计数据"""
    additional_data = []
    data_set = []

    for group_key, group_data in grouped_data:
        if len(group_data) > 1:
            summary_data = group_data.head(1).copy()
            summary_data[transaction_number_column] = "-".join(group_data[transaction_number_column].values)
            summary_data[amount_column] = group_data[amount_column].sum()
            summary_data[trade_price_column] = group_data[trade_price_column].sum()
            data_set += list(group_data.index)
            additional_data.append(np.array(summary_data.values))

    data_set = set(data_set)
    
    if additional_data:
        additional_data_frame = pd.DataFrame(np.concatenate(additional_data), columns=data_columns)
        return additional_data_frame, data_set
    return None, data_set

def process_institution_data(daily_data_copy, transaction_institutions, additional_data_frame, data_set, max_pair_count, new_group_id, bond_name, current_date):
    """处理符合条件的每个机构的数据"""
    result_data = []
    summary_stats1 = []
    summary_stats2 = []

    for institution_index, institution in enumerate(transaction_institutions):
        net_position = net_positions[institution_index]

        # 找到该机构对应的数据
        institution_data_set = set(daily_data_copy[(daily_data_copy[sender_column] == institution) | 
                                                   (daily_data_copy[receiver_column] == institution)].index).intersection(data_set)
        
        # 删除该机构的数据并合并其他数据
        daily_data_copy.drop(institution_data_set, inplace=True)
        if additional_data_frame is not None:
            daily_data_copy = pd.concat([daily_data_copy, additional_data_frame])

        daily_data_copy.reset_index(drop=True, inplace=True)

        # 进行双向交易链追踪
        trade_data, pair_count, stats1, stats2 = trace_trade_chains_bidirection(daily_data_copy, institution, max_pair_count)
        max_pair_count = pair_count

        # 添加目标信息
        trade_data["目标机构"] = institution
        trade_data["目标债券"] = bond_name
        trade_data["目标日期"] = current_date
        trade_data["目标分组"] = new_group_id
        trade_data["交易损益"] = net_position
        
        # 更新统计信息
        stats1["目标机构"] = institution
        stats1["目标债券"] = bond_name
        stats1["目标日期"] = current_date
        stats1["目标分组"] = new_group_id
        
        stats2["目标机构"] = institution
        stats2["目标债券"] = bond_name
        stats2["目标日期"] = current_date
        stats2["目标分组"] = new_group_id

        # 更新组ID
        new_group_id += 1
        result_data.append(trade_data)
        summary_stats1.append(stats1)
        summary_stats2.append(stats2)

    return result_data, summary_stats1, summary_stats2, max_pair_count

def save_data(result_data, summary_stats1, summary_stats2, save_folder, current_date, output_columns):
    """保存数据到文件"""
    if result_data:
        pd.concat(result_data)[output_columns].to_csv(f"{save_folder}/{str(current_date)[:10]}_result.csv")
        pd.concat(summary_stats1).to_csv(f"{save_folder}/{str(current_date)[:10]}_s1.csv")
        pd.concat(summary_stats2).to_csv(f"{save_folder}/{str(current_date)[:10]}_s2.csv")

# 主函数
def main(alldf, remove_maker, save_folder, amount_threshold, low_interest_rate_threshold, other_bond_threshold, date_format, bond_column, bond_type_column, date_column, transaction_method_column, amount_column, trade_date_column, trade_price_column, sender_column, receiver_column, sender_account_column, receiver_account_column, rate_column):
    institution_set = read_institution_data(remove_maker)

    # 数据清理
    alldf = clean_data(alldf, transaction_method_column)

    # 初始化统计信息
    result_data = []
    summary_stats1 = []
    summary_stats2 = []
    new_group_id = 0
    max_pair_count = 0

    # 定义数据列名称
    data_columns = [
        transaction_number_column, date_column, transaction_time_column, amount_column, 
        sender_column, receiver_column, sender_account_column, receiver_account_column, 
        price_column, transaction_method_column, trade_date_column, trade_price_column, rate_column
    ]

    output_columns = [
        transaction_number_column, date_column, transaction_time_column, amount_column, 
        receiver_column, sender_column, receiver_account_column, sender_account_column, 
        price_column, transaction_method_column, trade_price_column, rate_column, 
        '路径标识', '路径端点', '方向', '传递交易量', '端点交易量', '母路径', '目标交易员', '目标机构', 
        '目标债券', '目标日期', '目标分组', '交易损益'
    ]

    # 按日期分组数据
    date_groups = alldf.groupby(date_column)

    # 遍历每一天的交易数据
    for current_date, daily_data in date_groups:
        print(f"开始处理 {str(current_date)[:10]} 数据 .....")
        
        # 将日期格式转为正确格式
        daily_data[date_column] = pd.to_datetime(daily_data[date_column], format="%Y-%m-%d")
        daily_data[trade_date_column] = pd.to_datetime(daily_data[trade_date_column], format=date_format)
        
        # 按债券类型分组
        bond_groups = daily_data.groupby(bond_column)

        # 遍历每种债券的交易数据
        for bond_name, bond_data in bond_groups:
            bond_data = bond_data.copy()
            bond_type = bond_data.head(1)[bond_type_column].values[0]

            # 根据债券类型设置阈值
            threshold = get_threshold(bond_type, low_interest_rate_threshold, other_bond_threshold)

            # 只保留需要的列
            bond_data = bond_data[data_columns]

            # 获取同时作为发送方和接收方的机构交集，并排除指定的机构
            intersecting_institutions = set(bond_data[sender_column].unique()).intersection(bond_data[receiver_column].unique()) - institution_set
            
            # 按发送方和接收方分组
            buy_groups = bond_data.groupby([sender_column])
            sell_groups = bond_data.groupby([receiver_column])

            # 计算净交易量和价格差
            transaction_institutions, net_positions = calculate_net_positions(buy_groups, sell_groups, amount_column, trade_price_column, amount_threshold)

            # 处理符合条件的机构数据
            if transaction_institutions:
                daily_data_copy = bond_data.copy()
                grouped_data = daily_data_copy.groupby([sender_column, receiver_column, sender_account_column, receiver_account_column, rate_column])

                additional_data_frame, data_set = process_grouped_data(grouped_data, data_columns)

                trade_data, stats1, stats2, max_pair_count = process_institution_data(daily_data_copy, transaction_institutions, additional_data_frame, data_set, max_pair_count, new_group_id, bond_name, current_date)

                result_data.extend(trade_data)
                summary_stats1.extend(stats1)
                summary_stats2.extend(stats2)

        # 输出每个日期的结果
        save_data(result_data, summary_stats1, summary_stats2, save_folder, current_date, output_columns)
