In [11]:
import pickle
import torch
import numpy as np
from pathlib import Path

def analyze_pkl_file_jupyter(pkl_file_path, show_all=False, output_txt='analyze_result.txt'):
    """分析PKL文件的内容和结构，适用于Jupyter Notebook，并将结果写入txt文件。"""
    result_lines = []
    def _p(s):
        print(s)
        result_lines.append(s)
        
    if not Path(pkl_file_path).exists():
        _p(f"错误: PKL文件不存在: {pkl_file_path}")
        with open(output_txt, 'w', encoding='utf-8') as f:
            f.write('\n'.join(result_lines))
        return

    try:
        _p(f"正在加载PKL文件: {pkl_file_path}")
        with open(pkl_file_path, 'rb') as f:
            data = pickle.load(f)

        _p(f"成功加载 {len(data)} 个复合物")
        _p(f"数据类型: {type(data)}")

        # 分析每个复合物
        for idx, complex_data in enumerate(data):
            _p('\n' + '='*80)
            _p(f"复合物 {idx + 1}: {complex_data.get('pdbid', 'Unknown')}")
            _p('='*80)

            _p("\n=== 所有键值对和形状统计 ===")

            # 按类别分组显示
            basic_info = {}
            graph_data = {}
            masif_features = {}
            spatial_edges = {}
            scores = {}
            other_data = {}

            for key, value in complex_data.items():
                # 分类键值对
                if key in ['pdbid', 'smiles', 'pk', 'rmsd']:
                    basic_info[key] = value
                elif key.startswith('masif_'):
                    masif_features[key] = value
                elif 'spatial' in key:
                    spatial_edges[key] = value
                elif key in ['node_feat', 'edge_feat', 'edge_index', 'coords']:
                    graph_data[key] = value
                elif key in ['rfscore', 'gbscore']:
                    scores[key] = value
                else:
                    other_data[key] = value

            # 打印基本信息
            if basic_info:
                _p("\n--- 基本信息 ---")
                for key, value in basic_info.items():
                    _p(f"  {key}: {type(value).__name__} - {value}")

            # 打印图数据
            if graph_data:
                _p("\n--- 图结构数据 ---")
                for key, value in graph_data.items():
                    if isinstance(value, (np.ndarray, torch.Tensor)):
                        _p(f"  {key}: {type(value).__name__} shape {value.shape}")
                        if hasattr(value, 'dtype'):
                            _p(f"    数据类型: {value.dtype}")
                    else:
                        _p(f"  {key}: {type(value).__name__} - {value}")

            # 打印MaSIF特征
            if masif_features:
                _p("\n--- MaSIF特征 ---")
                for key, value in masif_features.items():
                    if isinstance(value, (np.ndarray, torch.Tensor)):
                        _p(f"  {key}: {type(value).__name__} shape {value.shape}")
                        if hasattr(value, 'dtype'):
                            _p(f"    数据类型: {value.dtype}")
                        # 检查NaN值
                        if hasattr(value, 'isnan'):
                            nan_count = torch.isnan(value).sum().item() if isinstance(value, torch.Tensor) else np.isnan(value).sum()
                            _p(f"    包含NaN值: {nan_count} 个")
                    else:
                        _p(f"  {key}: {type(value).__name__} - {value}")

                # 打印masif_desc_straight的值（如果存在，前30项）
                if 'masif_desc_straight' in masif_features:
                    desc = masif_features['masif_desc_straight']
                    _p("\n+++ masif_desc_straight 的值展示 +++")
                    if isinstance(desc, (np.ndarray, torch.Tensor)):
                        arr = desc.cpu().numpy() if isinstance(desc, torch.Tensor) else desc
                        _p(f"  masif_desc_straight 前30项: {arr.flatten()[:30]}")
                        _p(f"  masif_desc_straight 总形状: {arr.shape}")
                    else:
                        _p(f"  masif_desc_straight: {desc}")

            # 打印空间边数据
            if spatial_edges:
                _p("\n--- 空间边数据 ---")
                for key, value in spatial_edges.items():
                    if isinstance(value, (np.ndarray, torch.Tensor)):
                        _p(f"  {key}: {type(value).__name__} shape {value.shape}")
                        if hasattr(value, 'dtype'):
                            _p(f"    数据类型: {value.dtype}")
                    else:
                        _p(f"  {key}: {type(value).__name__} - {value}")

            # 打印评分数据
            if scores:
                _p("\n--- 评分特征 ---")
                for key, value in scores.items():
                    if isinstance(value, (np.ndarray, torch.Tensor)):
                        _p(f"  {key}: {type(value).__name__} shape {value.shape}")
                        if hasattr(value, 'dtype'):
                            _p(f"    数据类型: {value.dtype}")
                    else:
                        _p(f"  {key}: {type(value).__name__} - {value}")

            # 打印其他数据
            if other_data:
                _p("\n--- 其他数据 ---")
                for key, value in other_data.items():
                    if isinstance(value, (np.ndarray, torch.Tensor)):
                        _p(f"  {key}: {type(value).__name__} shape {value.shape}")
                        if hasattr(value, 'dtype'):
                            _p(f"    数据类型: {value.dtype}")
                    else:
                        _p(f"  {key}: {type(value).__name__} - {value}")

            # 统计总结
            total_keys = len(complex_data.keys())
            tensor_keys = sum(1 for v in complex_data.values() if isinstance(v, (np.ndarray, torch.Tensor)))

            _p(f"\n--- 统计总结 ---")
            _p(f"  总键数: {total_keys}")
            _p(f"  张量/数组键数: {tensor_keys}")
            _p(f"  标量/字符串键数: {total_keys - tensor_keys}")

            # 只分析第一个复合物的详细信息
            if not show_all and idx == 0 and len(data) > 1:
                _p(f"\n注意: 只显示第一个复合物的详细信息，共有 {len(data)} 个复合物")
                break

        _p("\n分析完成！")

    except Exception as e:
        _p(f"处理过程中出现错误: {e}")
        import traceback
        tb_str = traceback.format_exc()
        _p(tb_str)
    
    # 保存结果到txt文件
    with open(output_txt, 'w', encoding='utf-8') as f:
        f.write('\n'.join(result_lines))
    print(f"\n分析结果已保存到: {output_txt}")

# ------------------------------------------------------------------------
# 只需修改下面这两行即可分析你的PKL文件
your_pkl = '/xcfhome/zncao02/affincraft/data/2aco-VCA/output/2aco-VCA_features_with_masif.pkl'  # ← 换成你的PKL路径
output_txt = 'analyze_result.txt'    # 输出TXT文件名
# show_all=True显示全部复合物，False只显示第一个
analyze_pkl_file_jupyter(your_pkl, show_all=False, output_txt=output_txt)

正在加载PKL文件: /xcfhome/zncao02/affincraft/data/2aco-VCA/output/2aco-VCA_features_with_masif.pkl
成功加载 1 个复合物
数据类型: <class 'list'>

复合物 1: 2aco-VCA

=== 所有键值对和形状统计 ===

--- 基本信息 ---
  smiles: str - CCCCCC/C=C\CCCCCCCCCC(=O)[O-]
  rmsd: float - 0.0
  pk: float - 6.62
  pdbid: str - 2aco-VCA

--- 图结构数据 ---
  edge_index: ndarray shape (2, 1224)
    数据类型: int64
  edge_feat: ndarray shape (1224, 4)
    数据类型: float32
  node_feat: ndarray shape (118, 9)
    数据类型: int64
  coords: ndarray shape (118, 3)
    数据类型: float64

--- MaSIF特征 ---
  masif_input_feat: ndarray shape (3444, 100, 5)
    数据类型: float64
  masif_rho_wrt_center: ndarray shape (3444, 100)
    数据类型: float64
  masif_theta_wrt_center: ndarray shape (3444, 100)
    数据类型: float64
  masif_mask: ndarray shape (3444, 100)
    数据类型: float64
  masif_desc_straight: ndarray shape (3444, 80)
    数据类型: float32
  masif_desc_flipped: ndarray shape (3444, 80)
    数据类型: float32

+++ masif_desc_straight 的值展示 +++
  masif_desc_straight 前30项: [nan nan nan n

In [1]:
#!/usr/bin/env python3  
"""  
PKL文件详细分析脚本 - 专门用于分析AffinCraft生成的特征文件  
"""  
  
import pickle  
import torch  
import numpy as np  
from pathlib import Path  
import sys  
  
def analyze_edge_types_detailed(edge_attr):  
    """详细分析边的类型分布，包括所有PLIP相互作用类型"""  
    edge_types = {}  
    edge_type_stats = {}  
      
    # 处理numpy数组或tensor  
    if isinstance(edge_attr, torch.Tensor):  
        edge_attr = edge_attr.numpy()  
      
    for i, edge_feature in enumerate(edge_attr):  
        if len(edge_feature) >= 3:  
            # 前3维是边类型编码  
            edge_type_code = tuple(edge_feature[:3].astype(int))  
              
            # 根据边类型映射进行分类  
            if edge_type_code == (4, 0, 0):  
                edge_type = "SPATIAL_EDGE"  
                edge_category = "空间边"  
            elif edge_type_code == (5, 1, 0):  
                edge_type = "HYDROGEN_BOND"  
                edge_category = "PLIP相互作用边"  
            elif edge_type_code == (5, 2, 0):  
                edge_type = "HYDROPHOBIC_CONTACT"  
                edge_category = "PLIP相互作用边"  
            elif edge_type_code == (5, 3, 0):  
                edge_type = "PI_STACKING"  
                edge_category = "PLIP相互作用边"  
            elif edge_type_code == (5, 4, 0):  
                edge_type = "PI_CATION"  
                edge_category = "PLIP相互作用边"  
            elif edge_type_code == (5, 5, 0):  
                edge_type = "SALT_BRIDGE"  
                edge_category = "PLIP相互作用边"  
            elif edge_type_code == (5, 6, 0):  
                edge_type = "WATER_BRIDGE"  
                edge_category = "PLIP相互作用边"  
            elif edge_type_code == (5, 7, 0):  
                edge_type = "HALOGEN_BOND"  
                edge_category = "PLIP相互作用边"  
            elif edge_type_code == (5, 8, 0):  
                edge_type = "METAL_COMPLEX"  
                edge_category = "PLIP相互作用边"  
            elif edge_type_code == (5, 9, 0):  
                edge_type = "OTHERS"  
                edge_category = "PLIP相互作用边"  
            else:  
                # 可能是结构边（化学键）  
                edge_type = f"STRUCTURAL_BOND_{edge_type_code}"  
                edge_category = "结构边"  
              
            if edge_type not in edge_types:  
                edge_types[edge_type] = []  
                edge_type_stats[edge_type] = {  
                    'category': edge_category,  
                    'code': edge_type_code,  
                    'count': 0,  
                    'distances': []  
                }  
              
            edge_type_stats[edge_type]['count'] += 1  
              
            # 如果有第4维，那是距离信息  
            if len(edge_feature) > 3:  
                distance = edge_feature[3]  
                edge_types[edge_type].append(distance)  
                edge_type_stats[edge_type]['distances'].append(distance)  
            else:  
                edge_types[edge_type].append("N/A")  
      
    return edge_types, edge_type_stats  
  
def print_detailed_analysis(pkl_file_path):  
    """打印PKL文件的详细分析"""  
      
    # 检查文件是否存在  
    if not Path(pkl_file_path).exists():  
        print(f"错误: PKL文件不存在: {pkl_file_path}")  
        return  
      
    try:  
        # 加载pkl文件  
        print(f"正在加载PKL文件: {pkl_file_path}")  
        with open(pkl_file_path, 'rb') as f:  
            graphs = pickle.load(f)  
          
        print(f"成功加载 {len(graphs)} 个复合物")  
        print(f"数据类型: {type(graphs[0])}")  
          
        # 分析每个复合物  
        for idx, graph in enumerate(graphs):  
            print(f"\n{'='*80}")  
            print(f"复合物 {idx + 1}: {graph.get('pdbid', 'Unknown')}")  
            print(f"{'='*80}")  
              
            # 基本信息  
            print("\n=== 基本信息 ===")  
            print(f"PDB ID: {graph.get('pdbid', 'N/A')}")  
            print(f"结合亲和力 (pK): {graph.get('pk', 'N/A')}")  
            print(f"SMILES: {graph.get('smiles', 'N/A')}")  
            print(f"RMSD: {graph.get('rmsd', 'N/A')}")  
              
            # 数据结构概览  
            print("\n=== 数据结构概览 ===")  
            for key, value in graph.items():  
                if isinstance(value, (np.ndarray, torch.Tensor)):  
                    print(f"  {key}: {type(value).__name__} shape {value.shape}")  
                else:  
                    print(f"  {key}: {type(value).__name__} - {value}")  
              
            # 图结构详细信息  
            print("\n=== 图结构详细信息 ===")  
            node_feat = graph.get('node_feat', np.array([]))  
            edge_index = graph.get('edge_index', np.array([]))  
            edge_feat = graph.get('edge_feat', np.array([]))  
            coords = graph.get('coords', np.array([]))  
              
            print(f"节点数量: {node_feat.shape[0] if len(node_feat.shape) > 0 else 0}")  
            print(f"边数量: {edge_index.shape[1] if len(edge_index.shape) > 1 else 0}")  
            print(f"节点特征维度: {node_feat.shape[1] if len(node_feat.shape) > 1 else 0}")  
            print(f"边特征维度: {edge_feat.shape[1] if len(edge_feat.shape) > 1 else 0}")  
              
            # 分子组成统计  
            if 'num_node' in graph and 'num_edge' in graph:  
                print("\n=== 分子组成统计 ===")  
                num_node = graph['num_node']  
                num_edge = graph['num_edge']  
                print(f"配体节点数: {num_node[0]}")  
                print(f"蛋白质节点数: {num_node[1]}")  
                if len(num_edge) >= 5:  
                    print(f"配体结构边数: {num_edge[0]}")  
                    print(f"蛋白质结构边数: {num_edge[1]}")  
                    print(f"配体-蛋白相互作用边数: {num_edge[2]}")  
                    print(f"配体空间边数: {num_edge[3]}")  
                    print(f"蛋白质空间边数: {num_edge[4]}")  
              
            # 详细边类型分析  
            if len(edge_feat.shape) > 1:  
                print("\n=== 详细边类型分析 ===")  
                edge_types, edge_type_stats = analyze_edge_types_detailed(edge_feat)  
                  
                # 总体统计  
                total_edges = sum(stats['count'] for stats in edge_type_stats.values())  
                print(f"总边数: {total_edges}")  
                  
                # 按类别分组显示  
                categories = {}  
                for edge_type, stats in edge_type_stats.items():  
                    category = stats['category']  
                    if category not in categories:  
                        categories[category] = []  
                    categories[category].append((edge_type, stats))  
                  
                for category, edges in categories.items():  
                    category_total = sum(stats['count'] for _, stats in edges)  
                    print(f"\n{category}: {category_total} 条边 ({category_total/total_edges*100:.1f}%)")  
                      
                    for edge_type, stats in sorted(edges, key=lambda x: x[1]['count'], reverse=True):  
                        count = stats['count']  
                        percentage = count / total_edges * 100  
                        print(f"  {edge_type}:")  
                        print(f"    数量: {count} 条 ({percentage:.1f}%)")  
                        print(f"    编码: {stats['code']}")  
                          
                        if stats['distances'] and len([d for d in stats['distances'] if d != "N/A"]) > 0:  
                            distances = [d for d in stats['distances'] if d != "N/A"]  
                            print(f"    距离范围: {min(distances):.3f} - {max(distances):.3f} Å")  
                            print(f"    平均距离: {np.mean(distances):.3f} Å")  
                            print(f"    距离标准差: {np.std(distances):.3f} Å")  
              
            # 3D坐标信息  
            if len(coords.shape) > 1:  
                print("\n=== 3D坐标信息 ===")  
                print(f"坐标形状: {coords.shape}")  
                print(f"坐标范围:")  
                print(f"  X: {coords[:, 0].min():.3f} - {coords[:, 0].max():.3f}")  
                print(f"  Y: {coords[:, 1].min():.3f} - {coords[:, 1].max():.3f}")  
                print(f"  Z: {coords[:, 2].min():.3f} - {coords[:, 2].max():.3f}")  
              
            # 分子相互作用评分  
            print("\n=== 分子相互作用评分 ===")  
            if 'rfscore' in graph:  
                rfscore = graph['rfscore']  
                print(f"RF-Score 维度: {len(rfscore)}")  
                print(f"RF-Score 前10个值: {rfscore[:10].tolist()}")  
            if 'gbscore' in graph:  
                gbscore = graph['gbscore']  
                print(f"GB-Score 维度: {len(gbscore)}")  
                print(f"GB-Score 前10个值: {gbscore[:10].tolist()}")  
            if 'ecif' in graph:  
                ecif = graph['ecif']  
                print(f"ECIF 维度: {len(ecif)}")  
                print(f"ECIF 非零元素数: {np.count_nonzero(ecif)}")  
              
            # MaSIF特征（如果存在）  
            if 'masif_features' in graph:  
                masif_features = graph['masif_features']  
                print(f"\n=== MaSIF特征 ===")  
                print(f"MaSIF特征形状: {masif_features.shape}")  
                print(f"MaSIF特征范围: {masif_features.min():.3f} - {masif_features.max():.3f}")  
          
        print(f"\n分析完成！")  
          
    except Exception as e:  
        print(f"处理过程中出现错误: {e}")  
        import traceback  
        traceback.print_exc()  
  
if __name__ == "__main__":  
    # 您的PKL文件路径  
    pkl_file_path = "/xcfhome/zncao02/affincraft/data/6w1i-G4P/output/6w1i-G4P_features_with_masif.pkl"  
      
    print_detailed_analysis(pkl_file_path)

正在加载PKL文件: /xcfhome/zncao02/affincraft/data/6w1i-G4P/output/6w1i-G4P_features_with_masif.pkl
成功加载 1 个复合物
数据类型: <class 'dict'>

复合物 1: 6w1i-G4P

=== 基本信息 ===
PDB ID: 6w1i-G4P
结合亲和力 (pK): 6.62
SMILES: NC1=NC2=C(/N=C\N2C2OC(COP(=O)([O-])OP(=O)([O-])[O-])C(OP(=O)([O-])OP(=O)([O-])[O-])C2O)C(=O)N1
RMSD: 1.4

=== 数据结构概览 ===
  edge_index: ndarray shape (2, 2266)
  edge_feat: ndarray shape (2266, 4)
  node_feat: ndarray shape (170, 9)
  coords: ndarray shape (170, 3)
  pro_name: ndarray shape (612,)
  AA_name: ndarray shape (612,)
  smiles: str - NC1=NC2=C(/N=C\N2C2OC(COP(=O)([O-])OP(=O)([O-])[O-])C(OP(=O)([O-])OP(=O)([O-])[O-])C2O)C(=O)N1
  rmsd: float - 1.4
  rfscore: ndarray shape (100,)
  gbscore: ndarray shape (400,)
  pk: float - 6.62
  pdbid: str - 6w1i-G4P
  num_node: ndarray shape (2,)
  num_edge: ndarray shape (5,)
  lig_spatial_edge_index: ndarray shape (2, 434)
  lig_spatial_edge_attr: ndarray shape (434, 4)
  pro_spatial_edge_index: ndarray shape (2, 1498)
  pro_spatial_edge_attr:

In [None]:
"""  
PKL文件边类型分离分析脚本 - 分别分析共价边、内部空间边和相互作用边  
"""  
  
import pickle  
import torch  
import numpy as np  
from pathlib import Path  
  
def analyze_edges_by_category(edge_index, edge_feat, num_ligand_atoms):  
    """  
    根据边的连接模式和类型编码分离三种边类型  
      
    Args:  
        edge_index: 边索引 [2, num_edges]  
        edge_feat: 边特征 [num_edges, feat_dim]  
        num_ligand_atoms: 配体原子数量  
      
    Returns:  
        dict: 包含三种边类型的详细信息  
    """  
    if isinstance(edge_feat, torch.Tensor):  
        edge_feat = edge_feat.numpy()  
    if isinstance(edge_index, torch.Tensor):  
        edge_index = edge_index.numpy()  
      
    # 初始化三种边类型的存储  
    structural_bonds = {'indices': [], 'features': [], 'stats': {}}  
    internal_spatial = {'ligand': {'indices': [], 'features': [], 'stats': {}},   
                       'protein': {'indices': [], 'features': [], 'stats': {}}}  
    interaction_edges = {'indices': [], 'features': [], 'stats': {}}  
      
    for i in range(edge_index.shape[1]):  
        src_idx, tgt_idx = edge_index[0, i], edge_index[1, i]  
        edge_feature = edge_feat[i]  
          
        if len(edge_feature) >= 3:  
            edge_type_code = tuple(edge_feature[:3].astype(int))  
              
            # 判断边的类型和位置  
            src_is_ligand = src_idx < num_ligand_atoms  
            tgt_is_ligand = tgt_idx < num_ligand_atoms  
              
            # 分类边类型  
            if edge_type_code[0] in [0, 1]:  # 结构边（共价键）  
                structural_bonds['indices'].append([src_idx, tgt_idx])  
                structural_bonds['features'].append(edge_feature)  
                  
                bond_type = f"LIGAND_BOND_{edge_type_code}" if src_is_ligand and tgt_is_ligand else f"PROTEIN_BOND_{edge_type_code}"  
                if bond_type not in structural_bonds['stats']:  
                    structural_bonds['stats'][bond_type] = {'count': 0, 'distances': []}  
                structural_bonds['stats'][bond_type]['count'] += 1  
                if len(edge_feature) > 3:  
                    structural_bonds['stats'][bond_type]['distances'].append(edge_feature[3])  
                      
            elif edge_type_code[0] == 4:  # 原始空间边  
                if src_is_ligand and tgt_is_ligand:  
                    # 配体内部空间边  
                    internal_spatial['ligand']['indices'].append([src_idx, tgt_idx])  
                    internal_spatial['ligand']['features'].append(edge_feature)  
                elif not src_is_ligand and not tgt_is_ligand:  
                    # 蛋白质内部空间边  
                    internal_spatial['protein']['indices'].append([src_idx, tgt_idx])  
                    internal_spatial['protein']['features'].append(edge_feature)  
                else:  
                    # 蛋白-配体相互作用边  
                    interaction_edges['indices'].append([src_idx, tgt_idx])  
                    interaction_edges['features'].append(edge_feature)  
                      
            elif edge_type_code[0] == 5:  # PLIP相互作用边  
                # 根据原子位置判断是内部空间边还是相互作用边  
                if src_is_ligand and tgt_is_ligand:  
                    # 配体内部的PLIP类型空间边  
                    internal_spatial['ligand']['indices'].append([src_idx, tgt_idx])  
                    internal_spatial['ligand']['features'].append(edge_feature)  
                      
                    edge_type = get_plip_edge_type_name(edge_type_code)  
                    if edge_type not in internal_spatial['ligand']['stats']:  
                        internal_spatial['ligand']['stats'][edge_type] = {'count': 0, 'distances': []}  
                    internal_spatial['ligand']['stats'][edge_type]['count'] += 1  
                    if len(edge_feature) > 3:  
                        internal_spatial['ligand']['stats'][edge_type]['distances'].append(edge_feature[3])  
                          
                elif not src_is_ligand and not tgt_is_ligand:  
                    # 蛋白质内部的PLIP类型空间边  
                    internal_spatial['protein']['indices'].append([src_idx, tgt_idx])  
                    internal_spatial['protein']['features'].append(edge_feature)  
                      
                    edge_type = get_plip_edge_type_name(edge_type_code)  
                    if edge_type not in internal_spatial['protein']['stats']:  
                        internal_spatial['protein']['stats'][edge_type] = {'count': 0, 'distances': []}  
                    internal_spatial['protein']['stats'][edge_type]['count'] += 1  
                    if len(edge_feature) > 3:  
                        internal_spatial['protein']['stats'][edge_type]['distances'].append(edge_feature[3])  
                          
                else:  
                    # 蛋白-配体PLIP相互作用边  
                    interaction_edges['indices'].append([src_idx, tgt_idx])  
                    interaction_edges['features'].append(edge_feature)  
                      
                    edge_type = get_plip_edge_type_name(edge_type_code)  
                    if edge_type not in interaction_edges['stats']:  
                        interaction_edges['stats'][edge_type] = {'count': 0, 'distances': []}  
                    interaction_edges['stats'][edge_type]['count'] += 1  
                    if len(edge_feature) > 3:  
                        interaction_edges['stats'][edge_type]['distances'].append(edge_feature[3])  
      
    return {  
        'structural_bonds': structural_bonds,  
        'internal_spatial': internal_spatial,  
        'interaction_edges': interaction_edges  
    }  
  
def get_plip_edge_type_name(edge_type_code):  
    """根据边类型编码返回PLIP相互作用类型名称"""  
    plip_type_map = {  
        (5, 1, 0): "HYDROGEN_BOND",  
        (5, 2, 0): "HYDROPHOBIC_CONTACT",   
        (5, 3, 0): "PI_STACKING",  
        (5, 4, 0): "PI_CATION",  
        (5, 5, 0): "SALT_BRIDGE",  
        (5, 6, 0): "WATER_BRIDGE",  
        (5, 7, 0): "HALOGEN_BOND",  
        (5, 8, 0): "METAL_COMPLEX",  
        (5, 9, 0): "OTHERS"  
    }  
    return plip_type_map.get(edge_type_code, f"UNKNOWN_{edge_type_code}")  
  
def print_edge_category_analysis(category_name, edge_data):  
    """打印特定边类别的详细分析"""  
    print(f"\n=== {category_name} ===")  
      
    if isinstance(edge_data, dict) and 'stats' in edge_data:  
        # 单一类别（如结构边或相互作用边）  
        total_edges = sum(stats['count'] for stats in edge_data['stats'].values())  
        print(f"总边数: {total_edges}")  
          
        for edge_type, stats in sorted(edge_data['stats'].items(), key=lambda x: x[1]['count'], reverse=True):  
            count = stats['count']  
            percentage = count / total_edges * 100 if total_edges > 0 else 0  
            print(f"  {edge_type}: {count} 条 ({percentage:.1f}%)")  
              
            if stats['distances']:  
                distances = stats['distances']  
                print(f"    距离范围: {min(distances):.3f} - {max(distances):.3f} Å")  
                print(f"    平均距离: {np.mean(distances):.3f} Å")  
                  
    elif isinstance(edge_data, dict) and 'ligand' in edge_data:  
        # 内部空间边（包含配体和蛋白质）  
        ligand_total = sum(stats['count'] for stats in edge_data['ligand']['stats'].values())  
        protein_total = sum(stats['count'] for stats in edge_data['protein']['stats'].values())  
          
        print(f"配体内部空间边: {ligand_total} 条")  
        for edge_type, stats in sorted(edge_data['ligand']['stats'].items(), key=lambda x: x[1]['count'], reverse=True):  
            count = stats['count']  
            percentage = count / ligand_total * 100 if ligand_total > 0 else 0  
            print(f"  {edge_type}: {count} 条 ({percentage:.1f}%)")  
            if stats['distances']:  
                distances = stats['distances']  
                print(f"    距离范围: {min(distances):.3f} - {max(distances):.3f} Å")  
                print(f"    平均距离: {np.mean(distances):.3f} Å")  
          
        print(f"\n蛋白质内部空间边: {protein_total} 条")  
        for edge_type, stats in sorted(edge_data['protein']['stats'].items(), key=lambda x: x[1]['count'], reverse=True):  
            count = stats['count']  
            percentage = count / protein_total * 100 if protein_total > 0 else 0  
            print(f"  {edge_type}: {count} 条 ({percentage:.1f}%)")  
            if stats['distances']:  
                distances = stats['distances']  
                print(f"    距离范围: {min(distances):.3f} - {max(distances):.3f} Å")  
                print(f"    平均距离: {np.mean(distances):.3f} Å")  
  
def analyze_pkl_by_edge_types(pkl_file_path):  
    """按边类型分析PKL文件"""  
      
    if not Path(pkl_file_path).exists():  
        print(f"错误: PKL文件不存在: {pkl_file_path}")  
        return  
      
    try:  
        print(f"正在加载PKL文件: {pkl_file_path}")  
        with open(pkl_file_path, 'rb') as f:  
            graphs = pickle.load(f)  
          
        print(f"成功加载 {len(graphs)} 个复合物")  
          
        for idx, graph in enumerate(graphs):  
            print(f"\n{'='*80}")  
            print(f"复合物 {idx + 1}: {graph.get('pdbid', 'Unknown')}")  
            print(f"{'='*80}")  
              
            # 获取基本信息  
            edge_index = graph.get('edge_index', np.array([]))  
            edge_feat = graph.get('edge_feat', np.array([]))  
            num_node = graph.get('num_node', [0, 0])  
            num_ligand_atoms = num_node[0] if len(num_node) > 0 else 0  
              
            print(f"配体原子数: {num_ligand_atoms}")  
            print(f"蛋白质原子数: {num_node[1] if len(num_node) > 1 else 0}")  
            print(f"总边数: {edge_index.shape[1] if len(edge_index.shape) > 1 else 0}")  
              
            if len(edge_index.shape) > 1 and len(edge_feat.shape) > 1:  
                # 分析边类型  
                edge_analysis = analyze_edges_by_category(edge_index, edge_feat, num_ligand_atoms)  
                  
                # 打印三种边类型的详细分析  
                print_edge_category_analysis("共价边（结构边）", edge_analysis['structural_bonds'])  
                print_edge_category_analysis("内部空间边", edge_analysis['internal_spatial'])  
                print_edge_category_analysis("蛋白-配体相互作用边", edge_analysis['interaction_edges'])  
          
        print(f"\n分析完成！")  
          
    except Exception as e:  
        print(f"处理过程中出现错误: {e}")  
        import traceback  
        traceback.print_exc()  
  
if __name__ == "__main__":  
    # 您的PKL文件路径  
    pkl_file_path = "/xcfhome/zncao02/affincraft/data/test.pkl"  
      
    analyze_pkl_by_edge_types(pkl_file_path)

正在加载PKL文件: /xcfhome/zncao02/affincraft/data/test.pkl
成功加载 1 个复合物

复合物 1: 6w1i-G4P
配体原子数: 36
蛋白质原子数: 134
总边数: 2266

=== 共价边（结构边） ===
总边数: 290
  PROTEIN_BOND_(0, 0, 0): 136 条 (46.9%)
    距离范围: 1.411 - 1.543 Å
    平均距离: 1.502 Å
  LIGAND_BOND_(0, 0, 0): 60 条 (20.7%)
    距离范围: 1.321 - 1.616 Å
    平均距离: 1.477 Å
  PROTEIN_BOND_(0, 0, 1): 42 条 (14.5%)
    距离范围: 1.251 - 1.341 Å
    平均距离: 1.323 Å
  PROTEIN_BOND_(1, 0, 1): 36 条 (12.4%)
    距离范围: 1.226 - 1.252 Å
    平均距离: 1.235 Å
  LIGAND_BOND_(1, 0, 0): 16 条 (5.5%)
    距离范围: 1.248 - 1.515 Å
    平均距离: 1.431 Å

=== 内部空间边 ===
配体内部空间边: 434 条
  OTHERS: 434 条 (100.0%)
    距离范围: 2.178 - 4.991 Å
    平均距离: 3.555 Å

蛋白质内部空间边: 1498 条
  OTHERS: 1338 条 (89.3%)
    距离范围: 2.198 - 4.996 Å
    平均距离: 3.848 Å
  HYDROGEN_BOND: 110 条 (7.3%)
    距离范围: 2.205 - 3.947 Å
    平均距离: 2.974 Å
  HYDROPHOBIC_CONTACT: 40 条 (2.7%)
    距离范围: 2.383 - 3.920 Å
    平均距离: 2.745 Å
  PI_CATION: 10 条 (0.7%)
    距离范围: 3.395 - 4.762 Å
    平均距离: 4.225 Å

=== 蛋白-配体相互作用边 ===
总边数: 32
  HYDROGEN

In [3]:
"""  
PKL文件边类型分离分析脚本 - 分别分析共价边、内部空间边和相互作用边  
"""  
  
import pickle  
import torch  
import numpy as np  
from pathlib import Path  
  
def analyze_edges_by_category(edge_index, edge_feat, num_ligand_atoms):  
    """  
    根据边的连接模式和类型编码分离三种边类型  
      
    Args:  
        edge_index: 边索引 [2, num_edges]  
        edge_feat: 边特征 [num_edges, feat_dim]  
        num_ligand_atoms: 配体原子数量  
      
    Returns:  
        dict: 包含三种边类型的详细信息  
    """  
    if isinstance(edge_feat, torch.Tensor):  
        edge_feat = edge_feat.numpy()  
    if isinstance(edge_index, torch.Tensor):  
        edge_index = edge_index.numpy()  
      
    # 初始化三种边类型的存储  
    structural_bonds = {'indices': [], 'features': [], 'stats': {}}  
    internal_spatial = {'ligand': {'indices': [], 'features': [], 'stats': {}},   
                       'protein': {'indices': [], 'features': [], 'stats': {}}}  
    interaction_edges = {'indices': [], 'features': [], 'stats': {}}  
      
    for i in range(edge_index.shape[1]):  
        src_idx, tgt_idx = edge_index[0, i], edge_index[1, i]  
        edge_feature = edge_feat[i]  
          
        if len(edge_feature) >= 3:  
            edge_type_code = tuple(edge_feature[:3].astype(int))  
              
            # 判断边的类型和位置  
            src_is_ligand = src_idx < num_ligand_atoms  
            tgt_is_ligand = tgt_idx < num_ligand_atoms  
              
            # 分类边类型  
            if edge_type_code[0] in [0, 1]:  # 结构边（共价键）  
                structural_bonds['indices'].append([src_idx, tgt_idx])  
                structural_bonds['features'].append(edge_feature)  
                  
                bond_type = f"LIGAND_BOND_{edge_type_code}" if src_is_ligand and tgt_is_ligand else f"PROTEIN_BOND_{edge_type_code}"  
                if bond_type not in structural_bonds['stats']:  
                    structural_bonds['stats'][bond_type] = {'count': 0, 'distances': []}  
                structural_bonds['stats'][bond_type]['count'] += 1  
                if len(edge_feature) > 3:  
                    structural_bonds['stats'][bond_type]['distances'].append(edge_feature[3])  
                      
            elif edge_type_code[0] == 4:  # 原始空间边  
                if src_is_ligand and tgt_is_ligand:  
                    # 配体内部空间边  
                    internal_spatial['ligand']['indices'].append([src_idx, tgt_idx])  
                    internal_spatial['ligand']['features'].append(edge_feature)  
                elif not src_is_ligand and not tgt_is_ligand:  
                    # 蛋白质内部空间边  
                    internal_spatial['protein']['indices'].append([src_idx, tgt_idx])  
                    internal_spatial['protein']['features'].append(edge_feature)  
                else:  
                    # 蛋白-配体相互作用边  
                    interaction_edges['indices'].append([src_idx, tgt_idx])  
                    interaction_edges['features'].append(edge_feature)  
                      
            elif edge_type_code[0] == 5:  # PLIP相互作用边  
                # 根据原子位置判断是内部空间边还是相互作用边  
                if src_is_ligand and tgt_is_ligand:  
                    # 配体内部的PLIP类型空间边  
                    internal_spatial['ligand']['indices'].append([src_idx, tgt_idx])  
                    internal_spatial['ligand']['features'].append(edge_feature)  
                      
                    edge_type = get_plip_edge_type_name(edge_type_code)  
                    if edge_type not in internal_spatial['ligand']['stats']:  
                        internal_spatial['ligand']['stats'][edge_type] = {'count': 0, 'distances': []}  
                    internal_spatial['ligand']['stats'][edge_type]['count'] += 1  
                    if len(edge_feature) > 3:  
                        internal_spatial['ligand']['stats'][edge_type]['distances'].append(edge_feature[3])  
                          
                elif not src_is_ligand and not tgt_is_ligand:  
                    # 蛋白质内部的PLIP类型空间边  
                    internal_spatial['protein']['indices'].append([src_idx, tgt_idx])  
                    internal_spatial['protein']['features'].append(edge_feature)  
                      
                    edge_type = get_plip_edge_type_name(edge_type_code)  
                    if edge_type not in internal_spatial['protein']['stats']:  
                        internal_spatial['protein']['stats'][edge_type] = {'count': 0, 'distances': []}  
                    internal_spatial['protein']['stats'][edge_type]['count'] += 1  
                    if len(edge_feature) > 3:  
                        internal_spatial['protein']['stats'][edge_type]['distances'].append(edge_feature[3])  
                          
                else:  
                    # 蛋白-配体PLIP相互作用边  
                    interaction_edges['indices'].append([src_idx, tgt_idx])  
                    interaction_edges['features'].append(edge_feature)  
                      
                    edge_type = get_plip_edge_type_name(edge_type_code)  
                    if edge_type not in interaction_edges['stats']:  
                        interaction_edges['stats'][edge_type] = {'count': 0, 'distances': []}  
                    interaction_edges['stats'][edge_type]['count'] += 1  
                    if len(edge_feature) > 3:  
                        interaction_edges['stats'][edge_type]['distances'].append(edge_feature[3])  
      
    return {  
        'structural_bonds': structural_bonds,  
        'internal_spatial': internal_spatial,  
        'interaction_edges': interaction_edges  
    }  
  
def get_plip_edge_type_name(edge_type_code):  
    """根据边类型编码返回PLIP相互作用类型名称"""  
    plip_type_map = {  
        (5, 1, 0): "HYDROGEN_BOND",  
        (5, 2, 0): "HYDROPHOBIC_CONTACT",   
        (5, 3, 0): "PI_STACKING",  
        (5, 4, 0): "PI_CATION",  
        (5, 5, 0): "SALT_BRIDGE",  
        (5, 6, 0): "WATER_BRIDGE",  
        (5, 7, 0): "HALOGEN_BOND",  
        (5, 8, 0): "METAL_COMPLEX",  
        (5, 9, 0): "OTHERS"  
    }  
    return plip_type_map.get(edge_type_code, f"UNKNOWN_{edge_type_code}")  
  
def print_edge_category_analysis(category_name, edge_data):  
    """打印特定边类别的详细分析"""  
    print(f"\n=== {category_name} ===")  
      
    if isinstance(edge_data, dict) and 'stats' in edge_data:  
        # 单一类别（如结构边或相互作用边）  
        total_edges = sum(stats['count'] for stats in edge_data['stats'].values())  
        print(f"总边数: {total_edges}")  
          
        for edge_type, stats in sorted(edge_data['stats'].items(), key=lambda x: x[1]['count'], reverse=True):  
            count = stats['count']  
            percentage = count / total_edges * 100 if total_edges > 0 else 0  
            print(f"  {edge_type}: {count} 条 ({percentage:.1f}%)")  
              
            if stats['distances']:  
                distances = stats['distances']  
                print(f"    距离范围: {min(distances):.3f} - {max(distances):.3f} Å")  
                print(f"    平均距离: {np.mean(distances):.3f} Å")  
                  
    elif isinstance(edge_data, dict) and 'ligand' in edge_data:  
        # 内部空间边（包含配体和蛋白质）  
        ligand_total = sum(stats['count'] for stats in edge_data['ligand']['stats'].values())  
        protein_total = sum(stats['count'] for stats in edge_data['protein']['stats'].values())  
          
        print(f"配体内部空间边: {ligand_total} 条")  
        for edge_type, stats in sorted(edge_data['ligand']['stats'].items(), key=lambda x: x[1]['count'], reverse=True):  
            count = stats['count']  
            percentage = count / ligand_total * 100 if ligand_total > 0 else 0  
            print(f"  {edge_type}: {count} 条 ({percentage:.1f}%)")  
            if stats['distances']:  
                distances = stats['distances']  
                print(f"    距离范围: {min(distances):.3f} - {max(distances):.3f} Å")  
                print(f"    平均距离: {np.mean(distances):.3f} Å")  
          
        print(f"\n蛋白质内部空间边: {protein_total} 条")  
        for edge_type, stats in sorted(edge_data['protein']['stats'].items(), key=lambda x: x[1]['count'], reverse=True):  
            count = stats['count']  
            percentage = count / protein_total * 100 if protein_total > 0 else 0  
            print(f"  {edge_type}: {count} 条 ({percentage:.1f}%)")  
            if stats['distances']:  
                distances = stats['distances']  
                print(f"    距离范围: {min(distances):.3f} - {max(distances):.3f} Å")  
                print(f"    平均距离: {np.mean(distances):.3f} Å")  
  
def analyze_pkl_by_edge_types(pkl_file_path):  
    """按边类型分析PKL文件"""  
      
    if not Path(pkl_file_path).exists():  
        print(f"错误: PKL文件不存在: {pkl_file_path}")  
        return  
      
    try:  
        print(f"正在加载PKL文件: {pkl_file_path}")  
        with open(pkl_file_path, 'rb') as f:  
            graphs = pickle.load(f)  
          
        print(f"成功加载 {len(graphs)} 个复合物")  
          
        for idx, graph in enumerate(graphs):  
            print(f"\n{'='*80}")  
            print(f"复合物 {idx + 1}: {graph.get('pdbid', 'Unknown')}")  
            print(f"{'='*80}")  
              
            # 获取基本信息  
            edge_index = graph.get('edge_index', np.array([]))  
            edge_feat = graph.get('edge_feat', np.array([]))  
            num_node = graph.get('num_node', [0, 0])  
            num_ligand_atoms = num_node[0] if len(num_node) > 0 else 0  
              
            print(f"配体原子数: {num_ligand_atoms}")  
            print(f"蛋白质原子数: {num_node[1] if len(num_node) > 1 else 0}")  
            print(f"总边数: {edge_index.shape[1] if len(edge_index.shape) > 1 else 0}")  
              
            if len(edge_index.shape) > 1 and len(edge_feat.shape) > 1:  
                # 分析边类型  
                edge_analysis = analyze_edges_by_category(edge_index, edge_feat, num_ligand_atoms)  
                  
                # 打印三种边类型的详细分析  
                print_edge_category_analysis("共价边（结构边）", edge_analysis['structural_bonds'])  
                print_edge_category_analysis("内部空间边", edge_analysis['internal_spatial'])  
                print_edge_category_analysis("蛋白-配体相互作用边", edge_analysis['interaction_edges'])  
          
        print(f"\n分析完成！")  
          
    except Exception as e:  
        print(f"处理过程中出现错误: {e}")  
        import traceback  
        traceback.print_exc()  
  
if __name__ == "__main__":  
    # 您的PKL文件路径  
    pkl_file_path = "/xcfhome/zncao02/affincraft/data/2aco.pkl"  
      
    analyze_pkl_by_edge_types(pkl_file_path)

正在加载PKL文件: /xcfhome/zncao02/affincraft/data/2aco.pkl
成功加载 1 个复合物

复合物 1: 2aco-VCA
配体原子数: 20
蛋白质原子数: 98
总边数: 1224

=== 共价边（结构边） ===
总边数: 146
  PROTEIN_BOND_(0, 0, 0): 72 条 (49.3%)
    距离范围: 1.423 - 1.602 Å
    平均距离: 1.503 Å
  LIGAND_BOND_(0, 0, 0): 34 条 (23.3%)
    距离范围: 1.251 - 1.536 Å
    平均距离: 1.497 Å
  PROTEIN_BOND_(1, 0, 1): 20 条 (13.7%)
    距离范围: 1.219 - 1.334 Å
    平均距离: 1.243 Å
  PROTEIN_BOND_(0, 0, 1): 16 条 (11.0%)
    距离范围: 1.249 - 1.396 Å
    平均距离: 1.329 Å
  LIGAND_BOND_(1, 0, 0): 4 条 (2.7%)
    距离范围: 1.250 - 1.527 Å
    平均距离: 1.389 Å

=== 内部空间边 ===
配体内部空间边: 102 条
  OTHERS: 102 条 (100.0%)
    距离范围: 2.200 - 4.978 Å
    平均距离: 3.434 Å

蛋白质内部空间边: 886 条
  OTHERS: 624 条 (70.4%)
    距离范围: 2.192 - 4.983 Å
    平均距离: 3.868 Å
  HYDROPHOBIC_CONTACT: 118 条 (13.3%)
    距离范围: 2.373 - 3.991 Å
    平均距离: 2.851 Å
  PI_STACKING: 72 条 (8.1%)
    距离范围: 2.136 - 4.924 Å
    平均距离: 3.648 Å
  HYDROGEN_BOND: 42 条 (4.7%)
    距离范围: 2.233 - 3.703 Å
    平均距离: 2.829 Å
  PI_CATION: 26 条 (2.9%)
    距离范围: 3.164