In [None]:
import os
import collections
import multiprocessing
from typing import List, Tuple, Dict, Set
from tqdm import tqdm

# --- Core predicate: does a PDB entry have several chains? ---
def is_complex(pdb_id: str, pdb_to_chains_map: Dict[str, Set[str]]) -> bool:
    """Return True when more than one distinct chain is recorded for ``pdb_id``.

    Args:
        pdb_id: the PDB identifier to look up.
        pdb_to_chains_map: mapping from PDB ID to the set of chain IDs observed.

    Returns:
        True if the map holds two or more chains for ``pdb_id``; False if it holds
        one or none, or if ``pdb_id`` is absent from the map.
    """
    known_chains = pdb_to_chains_map.get(pdb_id, ())
    return len(known_chains) >= 2

# --- Phase 1: worker that parses one byte range of the cluster TSV ---
def read_file_chunk(args: Tuple[str, int, int]) -> Tuple[Dict[str, List[str]], Dict[str, Set[str]]]:
    """Parse the TSV lines whose first byte lies in the byte range ``[start, end)``.

    Expected line format is ``<center>\\t<member>``, each token of the form
    ``<pdb_id>_<chain_id>`` (split at the first ``_``). Blank lines and lines that
    do not match this format are skipped without recording anything.

    Ownership rule: a line belongs to the chunk that contains its first byte, so
    contiguous ranges covering ``[0, file_size)`` read every line exactly once.

    Args:
        args: ``(filename, start, end)`` with byte offsets; packed into one tuple
            so the function can be mapped directly with ``Pool.imap``.

    Returns:
        ``(clusters, pdb_map)``: ``clusters`` maps each center to its members in
        file order; ``pdb_map`` maps each PDB ID to the set of chain IDs seen for
        it (as center or as member). Both are ``collections.defaultdict``.

    Raises:
        UnicodeDecodeError: if a line is not valid UTF-8.
    """
    filename, start, end = args
    local_clusters = collections.defaultdict(list)
    local_pdb_map = collections.defaultdict(set)
    # Binary mode: the offsets are raw byte positions, while a text-mode seek()
    # only accepts tell() cookies and may land inside a multi-byte UTF-8 character.
    with open(filename, 'rb') as f:
        if start > 0:
            # Back up one byte and discard through the next newline. If the byte
            # before `start` is '\n', only that byte is consumed, so a line that
            # begins exactly at `start` is kept. Seeking to `start` and discarding a
            # line used to drop such a line, because the previous chunk stops as
            # soon as tell() == its end.
            f.seek(start - 1)
            f.readline()
        while f.tell() < end:
            raw_line = f.readline()
            if not raw_line:
                break
            line = raw_line.decode('utf-8').strip()
            if not line:
                continue
            fields = line.split('\t')
            if len(fields) != 2:
                continue  # malformed: not exactly two columns
            center, member = fields
            center_pdb_id, center_sep, center_chain_id = center.partition('_')
            member_pdb_id, member_sep, member_chain_id = member.partition('_')
            if not (center_sep and member_sep):
                continue  # malformed: token without a '_' separator
            # Record only after the whole line validated, so a malformed line
            # never leaves a half-written entry behind.
            local_clusters[center].append(member)
            local_pdb_map[center_pdb_id].add(center_chain_id)
            local_pdb_map[member_pdb_id].add(member_chain_id)
    return local_clusters, local_pdb_map

# --- Phase 2: per-process state for the parallel filter ---
# PDB-ID -> chain-set map read by filter_single_cluster. Each worker process
# gets its own copy, installed once by init_filter_worker (the Pool initializer).
worker_pdb_map = None


def init_filter_worker(pdb_map: Dict[str, Set[str]]):
    """Pool initializer: store ``pdb_map`` in this process's module-level global."""
    globals()['worker_pdb_map'] = pdb_map

def filter_single_cluster(cluster_item: Tuple[str, List[str]]) -> List[Tuple[str, str]]:
    """Keep only the ``(center, member)`` pairs whose member belongs to a complex.

    A member counts as a complex when its PDB ID (text before the first ``_``)
    has more than one chain in the module-level ``worker_pdb_map``. That map must
    already be set in this process (via ``init_filter_worker``).

    Args:
        cluster_item: ``(center, members)``, one entry of ``dict.items()``.

    Returns:
        ``(center, member)`` pairs for the complex members, in their original
        order. The list is empty when no member is a complex.
    """
    center, members = cluster_item
    # One pass is enough. If no member is a complex, the filtered list is empty,
    # which is exactly the "not a complex cluster" result.
    kept_pairs = []
    for member in members:
        member_pdb_id, _, _ = member.partition('_')
        if is_complex(member_pdb_id, worker_pdb_map):
            kept_pairs.append((center, member))
    return kept_pairs

# --- Main pipeline: parallel read -> parallel filter -> sorted write ---
def run_parallel_pipeline(input_file: str, output_file: str, num_workers: int):
    """Extract the complex members of each cluster from a ``center\\tmember`` TSV.

    Phase 1 splits ``input_file`` into ``num_workers`` contiguous byte ranges,
    parses them in parallel with ``read_file_chunk`` and merges the partial
    results. Phase 2 runs ``filter_single_cluster`` over every cluster in a pool
    whose workers receive the merged PDB->chains map through ``init_filter_worker``.
    Phase 3 sorts the surviving ``(center, member)`` pairs and writes them to
    ``output_file``, one tab-separated pair per line.

    NOTE(review): the pool workers use functions defined in this notebook's
    ``__main__``. That works with the 'fork' start method, but the 'spawn' method
    cannot import them from an interactive session. Confirm the target platform.

    Args:
        input_file: path to the two-column cluster TSV.
        output_file: path of the TSV to (over)write.
        num_workers: processes per pool and number of byte ranges in phase 1;
            must be >= 1.

    Returns:
        None. If ``input_file`` does not exist, an error is printed and nothing
        is written.

    Raises:
        ValueError: if ``num_workers`` < 1 (this used to surface as a
            ZeroDivisionError).
    """
    if num_workers < 1:
        raise ValueError(f"num_workers must be >= 1, got {num_workers}")
    if not os.path.exists(input_file):
        print(f"错误: 输入文件 '{input_file}' 未找到。")
        return

    # Phase 1: parallel read and merge
    print("--- 阶段一: 并行读取文件 ---")
    file_size = os.path.getsize(input_file)
    chunk_size = file_size // num_workers
    # Contiguous byte ranges covering [0, file_size); the last one absorbs the remainder.
    chunk_args = [
        (input_file, i * chunk_size, (i + 1) * chunk_size if i < num_workers - 1 else file_size)
        for i in range(num_workers)
    ]

    clusters = collections.defaultdict(list)
    pdb_to_chains_map = collections.defaultdict(set)

    with multiprocessing.Pool(num_workers) as pool:
        # Ordered imap, so a center's members are concatenated in file order.
        results_iterator = pool.imap(read_file_chunk, chunk_args)
        print("正在合并读取结果...")
        for local_clusters, local_pdb_map in tqdm(results_iterator, total=len(chunk_args), desc="读取文件块"):
            # A cluster can straddle a chunk boundary, so extend/update instead of assigning.
            for center, members in local_clusters.items():
                clusters[center].extend(members)
            for pdb_id, chains in local_pdb_map.items():
                pdb_to_chains_map[pdb_id].update(chains)
    print(f"文件读取与合并完成。共找到 {len(clusters)} 个原始聚类。")

    # Phase 2: parallel filter
    print("\n--- 阶段二: 并行过滤聚类 ---")
    final_results = []
    # Batch the tasks: with the default chunksize=1, each cluster is its own IPC round trip.
    filter_chunksize = max(1, len(clusters) // (num_workers * 4))
    with multiprocessing.Pool(num_workers, initializer=init_filter_worker, initargs=(pdb_to_chains_map,)) as pool:
        filter_iterator = pool.imap_unordered(filter_single_cluster, clusters.items(), chunksize=filter_chunksize)
        for result_list in tqdm(filter_iterator, total=len(clusters), desc="过滤聚类"):
            final_results.extend(result_list)

    # Phase 3: write the results
    print("\n--- 阶段三: 写入最终结果 ---")
    print(f"正在将 {len(final_results)} 条复合物聚类条目写入到: {output_file}...")
    # imap_unordered yields in completion order; sorting makes the output deterministic.
    final_results.sort()
    with open(output_file, 'w', encoding='utf-8') as f:
        for center, member in tqdm(final_results, desc="写入文件"):
            f.write(f"{center}\t{member}\n")

    print("\n所有处理已完成！")


if __name__ == "__main__":
    # --- Configuration ---
    INPUT_FILENAME = "cluster.tsv"
    OUTPUT_FILENAME = "complex_cluster.tsv"

    # Number of worker processes: one per CPU core, capped at the previously
    # hard-coded 64. A fixed 64 oversubscribes machines with fewer cores, and the
    # intent was already "usually equal to the CPU core count".
    # os.cpu_count() may return None, hence the fallback to 1.
    NUM_WORKERS = min(64, os.cpu_count() or 1)

    print(f"输入文件: {INPUT_FILENAME}")
    print(f"输出文件: {OUTPUT_FILENAME}")
    print(f"将使用 {NUM_WORKERS} 个工作进程。")
    print("-" * 30)

    # Run the pipeline
    run_parallel_pipeline(INPUT_FILENAME, OUTPUT_FILENAME, NUM_WORKERS)

输入文件: cluster.tsv
输出文件: complex_cluster.tsv
将使用 64 个工作进程。
------------------------------
--- 阶段一: 并行读取文件 ---
正在合并读取结果...


读取文件块: 100%|██████████| 64/64 [00:01<00:00, 52.57it/s]


文件读取与合并完成。共找到 27081 个原始聚类。

--- 阶段二: 并行过滤聚类 ---


过滤聚类: 100%|██████████| 27081/27081 [00:01<00:00, 21389.29it/s]



--- 阶段三: 写入最终结果 ---
正在将 151342 条复合物聚类条目写入到: complex_cluster.tsv...


写入文件: 100%|██████████| 151342/151342 [00:00<00:00, 4363640.06it/s]


所有处理已完成！





In [3]:
! top

[?1h=[H[2J[mtop - 15:23:47 up 76 days,  4:26, 11 users,  load average: 4.05, 3.45, 3.57[m[m[m[m[K
Tasks:[m[m[1m 1749 [m[mtotal,[m[m[1m   4 [m[mrunning,[m[m[1m 1745 [m[msleeping,[m[m[1m   0 [m[mstopped,[m[m[1m   0 [m[mzombie[m[m[m[m[K
%Cpu(s):[m[m[1m  2.0 [m[mus,[m[m[1m  0.8 [m[msy,[m[m[1m  0.0 [m[mni,[m[m[1m 97.1 [m[mid,[m[m[1m  0.0 [m[mwa,[m[m[1m  0.0 [m[mhi,[m[m[1m  0.0 [m[msi,[m[m[1m  0.0 [m[mst[m[m[m[m[K
MiB Mem :[m[m[1m 1031675.+[m[mtotal,[m[m[1m 924562.4 [m[mfree,[m[m[1m  43600.3 [m[mused,[m[m[1m  63513.0 [m[mbuff/cache[m[m[m[m[K
MiB Swap:[m[m[1m      0.0 [m[mtotal,[m[m[1m      0.0 [m[mfree,[m[m[1m      0.0 [m[mused.[m[m[1m 964285.0 [m[mavail Mem [m[m[m[m[K
[K
[7m    PID USER      PR  NI    VIRT    RES    SHR S  %CPU  %MEM     TIME+ COMMAND  [m[m[K
[m[1m1762345 renju     20   0  108.9g   9.9g 224480 R 105.6   1.0 353:20.70 python3  [m[m[