# GCPNet 结合位点三图特征提取流水线（使用完整版 gcpnet）

目标：
- 从 PDB 复合物结构 + `binding_sites.csv` 中的关键残基信息出发；
- 为每个 (PDB, 配体) 样本构建三张图：
  - 蛋白 binding 残基图（Cα 节点）；
  - 配体原子图；
  - 蛋白–配体相互作用图；
- 使用你新导入的完整版 `gcpnet` 及其 YAML 配置 (`config_gcpnet_encoder.yaml`) 初始化原版 `GCPNetModel`；
- 对三张图分别编码并拼接，最终导出 `binding_embeddings_triplet.csv`。

In [1]:
# 基础依赖 & GCPNet 特征模块
import sys
from pathlib import Path
from collections import defaultdict
from typing import Dict, List, Tuple

import math
import csv

import numpy as np
import pandas as pd

# PDB 解析
from Bio.PDB import PDBParser

# 数学/图构建
import torch
from torch_geometric.data import Data, Batch

# 添加 gcpnet 父目录到 Python 路径
BASE_DIR = Path(r"c:\Users\Administrator\Desktop\IGEM\stage1\notebook-lab")
# Prepend the directory that contains the local `gcpnet` package so that it is
# importable (and wins over any same-named installed package); skip if present.
if not any(entry == str(BASE_DIR) for entry in sys.path):
    sys.path.insert(0, str(BASE_DIR))

# gcpnet 特征模块
from gcpnet.features.factory import ProteinFeaturiser
from gcpnet.models.graph_encoders.gcpnet import GCPNetModel

# Reached only if the gcpnet imports above succeeded.
print("✓ 成功导入 gcpnet 模块")


  from .autonotebook import tqdm as notebook_tqdm


✓ 成功导入 gcpnet 模块


## Part 1：蛋白口袋图构建与 GCPNet 特征（已实现部分的小结）

这一部分沿用前面已经实现的逻辑：
- 从 PDB 构建蛋白 Cα 图（节点=残基，边=KNN）；
- 根据 binding_sites 标记 binding 残基；
- 使用 ProteinFeaturiser 提取氨基酸 one-hot 与几何边特征；

在后续的高级版本中，我们会在此基础上再接入 GCPNet encoder，对节点进行多层消息传播，然后对节点 embedding 做 pooling 得到蛋白侧的口袋 embedding。

In [2]:
import math
import csv

import numpy as np
import pandas as pd

In [3]:
from pathlib import Path
from collections import defaultdict

import math
import csv

import numpy as np
import pandas as pd

In [4]:
# 路径与全局参数
BASE_DIR = Path(r"c:\Users\Administrator\Desktop\IGEM\stage1\notebook-lab")

PDB_DIR = BASE_DIR.joinpath("complex-20251129T063258Z-1-001", "complex")
BINDING_CSV = BASE_DIR.joinpath("binding_sites.csv")

# Output file for the per-(pdb, ligand) binding embeddings.
BINDING_EMBEDDINGS_CSV = BASE_DIR.joinpath("binding_embeddings.csv")

# Graph-construction parameters.
K_NEIGHBORS = 16  # K of the residue KNN graph

# 3-letter amino-acid code -> integer ID for the 20 standard residues, in a
# fixed order (ALA = 0 ... TYR = 19).
AA3_TO_ID = {
    code: idx
    for idx, code in enumerate(
        (
            "ALA", "CYS", "ASP", "GLU", "PHE",
            "GLY", "HIS", "ILE", "LYS", "LEU",
            "MET", "ASN", "PRO", "GLN", "ARG",
            "SER", "THR", "VAL", "TRP", "TYR",
        )
    )
}
# Any other residue name maps to this extra ID (20).
UNKNOWN_AA_ID = len(AA3_TO_ID)

In [5]:
# 读取 binding_sites.csv，并按 (pdb_id, ligand) 分组

def load_binding_sites(csv_path: Path):
    """Group binding-site records by ligand instance.

    Rows of ``csv_path`` are grouped by the key
    ``(pdb_id, ligand_resname, ligand_chain, ligand_resnum)``.

    Identifier columns are read as strings so that pandas does not coerce
    them to numbers. Otherwise an ID such as ``"0001"`` would become ``"1"``,
    and a single missing value in an integer-like column would turn every ID
    into a float string such as ``"1.0"``. Either would break the
    ``<pdb_id>.pdb`` lookup and chain matching downstream.

    Args:
        csv_path: Path to ``binding_sites.csv``.

    Returns:
        ``defaultdict(list)`` mapping each key tuple to the list of row
        Series (one per protein-residue record for that ligand).
    """
    df = pd.read_csv(
        csv_path,
        dtype={
            "pdb_id": str,
            "ligand_resname": str,
            "ligand_chain": str,
            "protein_chain": str,
        },
    )

    groups = defaultdict(list)
    for _, row in df.iterrows():
        key = (
            str(row["pdb_id"]),
            str(row["ligand_resname"]),
            str(row["ligand_chain"]),  # a missing chain becomes "nan"
            int(row["ligand_resnum"]),
        )
        groups[key].append(row)

    print(f"共 {len(df)} 条 binding 记录，{len(groups)} 个 (pdb, ligand) 组合。")
    return groups

binding_groups = load_binding_sites(BINDING_CSV)
# Preview the first five group keys (iterating a dict yields its keys).
list(binding_groups)[:5]

共 25626 条 binding 记录，3139 个 (pdb, ligand) 组合。


[('0', 'FAD', 'B', 1),
 ('1', 'UNL', ' ', 1),
 ('100', 'UNL', ' ', 1),
 ('1000', 'UNL', ' ', 1),
 ('1001', 'UNL', ' ', 1)]

In [6]:
# 从 PDB 构建 Cα 图

parser = PDBParser(QUIET=True)


def build_ca_graph_from_pdb(pdb_path: Path, protein_chains=None):
    """Build a residue-level Cα KNN graph from the first model of a PDB file.

    Only residues with a blank hetero flag (standard ATOM residues) that have
    a ``CA`` atom become nodes; HETATM residues and waters are skipped.

    Args:
        pdb_path: PDB file to parse.
        protein_chains: Optional collection of chain IDs to keep; when falsy,
            every chain is used.

    Returns:
        ``(chain_ids, resnums, coords, residue_type_ids, edge_index)``:

        - ``coords``: float32 ``[N, 3]`` tensor.
        - ``residue_type_ids``: long ``[N]`` tensor; unknown residue names
          map to ``UNKNOWN_AA_ID``.
        - ``edge_index``: long ``[2, E]`` tensor of directed edges ``i -> j``
          from each node to its ``min(K_NEIGHBORS, N - 1)`` nearest other
          nodes, nearest first.

        Returns ``([], [], None, None, None)`` when no Cα atom is found.
    """
    structure = parser.get_structure(pdb_path.stem, str(pdb_path))
    model = next(structure.get_models())  # use the first model only

    ca_coords, chain_ids, resnums, res_types = [], [], [], []
    for chain in model:
        if protein_chains and chain.id not in protein_chains:
            continue
        for residue in chain:
            if residue.id[0].strip():  # non-blank hetero flag -> HETATM / water
                continue
            if "CA" not in residue:
                continue
            ca_coords.append(residue["CA"].coord)
            chain_ids.append(chain.id)
            # NOTE: the insertion code (residue.id[2]) is ignored, so residues
            # that differ only by insertion code share a residue number.
            resnums.append(residue.id[1])
            res_types.append(AA3_TO_ID.get(residue.get_resname().strip(), UNKNOWN_AA_ID))

    if not ca_coords:
        return [], [], None, None, None

    # Stack into one ndarray first: building a tensor from a list of ndarrays
    # is very slow (PyTorch emits a warning for it).
    coords = torch.tensor(np.stack(ca_coords, axis=0), dtype=torch.float32)  # [N, 3]
    residue_type_ids = torch.tensor(res_types, dtype=torch.long)  # [N]

    N = coords.shape[0]
    k = min(K_NEIGHBORS, N - 1)
    if k == 0:  # a single node has no neighbours
        edge_index = torch.zeros((2, 0), dtype=torch.long)
    else:
        dist_mat = torch.cdist(coords, coords)  # [N, N]
        # Mask self-distances explicitly instead of assuming each node is its
        # own nearest neighbour (which fails for duplicate coordinates).
        dist_mat.fill_diagonal_(float("inf"))
        _, knn_idx = torch.topk(dist_mat, k=k, dim=-1, largest=False)  # [N, k]
        rows = torch.arange(N).repeat_interleave(k)
        edge_index = torch.stack([rows, knn_idx.reshape(-1)], dim=0)  # [2, N*k]

    return chain_ids, resnums, coords, residue_type_ids, edge_index

In [7]:
# 根据 binding_sites 标注图，并构建 PyG Data


def _normalise_chain_id(raw) -> str:
    """Map a CSV chain value to a Bio.PDB chain ID; blank/NaN ("nan") become " "."""
    chain = str(raw)
    if chain.strip() == "" or chain.lower() == "nan":
        return " "
    return chain


def build_pyg_data_for_group(pdb_dir: Path, group_key, group_rows, protein_chains=None):
    """Build the protein Cα graph for one (pdb, ligand) group and label binding residues.

    Args:
        pdb_dir: Directory that holds ``<pdb_id>.pdb`` files.
        group_key: ``(pdb_id, ligand_resname, ligand_chain, ligand_resnum)``.
        group_rows: Records for this group; each must provide
            ``protein_chain`` and ``protein_resnum``.
        protein_chains: Optional chain filter forwarded to
            ``build_ca_graph_from_pdb``.

    Returns:
        A ``Data`` with ``pos`` ``[N, 3]``, ``residue_type`` ``[N]``,
        ``edge_index`` ``[2, E]``, ``y`` ``[N]`` (1 for residues listed in
        ``group_rows``, else 0) and the ligand metadata. Returns ``None``
        (after printing a message) if the PDB file is missing or has no Cα
        nodes.
    """
    pdb_id, lig_resname, lig_chain, lig_resnum = group_key
    pdb_path = pdb_dir / f"{pdb_id}.pdb"
    if not pdb_path.exists():
        print(f"PDB 文件不存在：{pdb_path}")
        return None

    node_chain_ids, node_resnums, coords, res_type_ids, edge_index = build_ca_graph_from_pdb(
        pdb_path, protein_chains=protein_chains
    )
    if coords is None:
        print(f"在 {pdb_path} 中未找到蛋白 Cα 节点。")
        return None

    # (chain, resnum) -> node index
    index_map = {
        (cid, int(rnum)): i
        for i, (cid, rnum) in enumerate(zip(node_chain_ids, node_resnums))
    }

    # Binding labels. Chains are normalised the same way as ligand chains, so a
    # blank chain read back as NaN still matches Bio.PDB's " " chain.
    y = torch.zeros(coords.shape[0], dtype=torch.long)
    for row in group_rows:
        idx = index_map.get((_normalise_chain_id(row["protein_chain"]), int(row["protein_resnum"])))
        if idx is not None:
            y[idx] = 1

    data = Data()
    data.pos = coords                 # [N, 3]
    data.residue_type = res_type_ids  # [N]
    data.edge_index = edge_index      # [2, E]
    data.y = y                        # [N]

    data.pdb_id = pdb_id
    data.ligand_resname = lig_resname
    data.ligand_chain = lig_chain
    data.ligand_resnum = int(lig_resnum)

    return data

In [8]:
# 将多个 Data 合并为 Batch，并构造 coords/seq_pos 供 ProteinFeaturiser 使用


def to_batch_for_featuriser(data_list):
    """Collate ``Data`` objects and add the ``coords``/``seq_pos`` fields ProteinFeaturiser reads.

    - ``coords``: ``[N, 2, 3]``. Slot 1 holds the Cα position (``pos``) and
      slot 0 is a zero placeholder; this layout is meant for
      ``representation="CA"``.
    - ``seq_pos``: ``0..N-1`` over the concatenated batch (it does not restart
      per graph).
    """
    batch = Batch.from_data_list(data_list)

    pos = batch.pos  # [N, 3]
    batch.coords = torch.stack([torch.zeros_like(pos), pos], dim=1)  # [N, 2, 3]
    # Create seq_pos on pos's device so GPU-resident batches stay consistent.
    batch.seq_pos = torch.arange(pos.size(0), dtype=torch.long, device=pos.device)

    return batch

In [9]:
# “完整版” featuriser 配置：节点 one-hot + 边距离/方向

# Node features: amino-acid one-hot only (add "sequence_positional_encoding"
# to scalar_node_features for positional encodings). Edges: 16-NN with
# distance scalars and direction vectors. "CA" reads the Cα from coords[:, 1, :].
featuriser = ProteinFeaturiser(
    **dict(
        representation="CA",
        edge_types=["knn_16"],
        scalar_node_features=["amino_acid_one_hot"],
        vector_node_features=[],
        scalar_edge_features=["edge_distance"],
        vector_edge_features=["edge_vectors"],
    )
)

In [10]:
# 计算每个 (pdb, ligand) 的 binding embedding 并导出


def compute_binding_embeddings(
    pdb_dir: Path,
    binding_groups: dict,
    max_groups: int | None = 50,
    protein_chains=None,
    output_csv: Path | None = None,
):
    """Featurise each (pdb, ligand) group and mean-pool node features over binding residues.

    For every group whose protein graph can be built, node features come from
    the module-level ``featuriser``. They are averaged over binding residues
    (``y > 0``), or over all residues when none is labelled.

    Args:
        pdb_dir: Directory with ``<pdb_id>.pdb`` files.
        binding_groups: Mapping from group key to records, as returned by
            ``load_binding_sites``.
        max_groups: Process only the first ``max_groups`` keys; ``None``
            processes all.
        protein_chains: Optional chain filter for graph construction.
        output_csv: Destination CSV. Defaults to ``BINDING_EMBEDDINGS_CSV``,
            resolved at call time.

    Returns:
        DataFrame with the four metadata columns followed by ``feat_0..``,
        one row per built sample; ``None`` if no sample could be built.
    """
    keys = list(binding_groups)
    if max_groups is not None:
        keys = keys[:max_groups]

    data_list, meta_list = [], []
    for key in keys:
        data = build_pyg_data_for_group(pdb_dir, key, binding_groups[key], protein_chains)
        if data is None:
            continue
        data_list.append(data)
        meta_list.append(key)

    if not data_list:
        print("没有成功构建的样本。")
        return None

    batch = featuriser(to_batch_for_featuriser(data_list))
    x = batch.x              # [total nodes, F]
    y = batch.y              # [total nodes]
    graph_idx = batch.batch  # [total nodes]

    h_list, meta_records = [], []
    # Graph g in the batch corresponds to meta_list[g] (same collation order).
    for g, (pdb_id, lig_resname, lig_chain, lig_resnum) in enumerate(meta_list):
        mask_g = graph_idx == g
        x_g = x[mask_g]
        binding = y[mask_g] > 0
        h_binding = x_g[binding].mean(dim=0) if binding.any() else x_g.mean(dim=0)

        meta_records.append({
            "pdb_id": pdb_id,
            "ligand_resname": lig_resname,
            "ligand_chain": lig_chain,
            "ligand_resnum": lig_resnum,
        })
        h_list.append(h_binding.detach().cpu().numpy())

    H = np.stack(h_list, axis=0)
    df_meta = pd.DataFrame(meta_records)
    df_emb = pd.DataFrame(H, columns=[f"feat_{i}" for i in range(H.shape[1])])
    df_out = pd.concat([df_meta, df_emb], axis=1)

    out_path = BINDING_EMBEDDINGS_CSV if output_csv is None else Path(output_csv)
    df_out.to_csv(out_path, index=False)
    print(f"已保存 {len(df_out)} 条样本到 {out_path}")

    return df_out


# Run on the first 50 groups only.
df_binding_emb = compute_binding_embeddings(PDB_DIR, binding_groups, max_groups=50)
# compute_binding_embeddings returns None when no sample was built; guard the
# preview so the cell does not raise AttributeError in that case.
df_binding_emb.head() if df_binding_emb is not None else None


Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\pytorch\torch\csrc\utils\tensor_new.cpp:256.)



已保存 50 条样本到 c:\Users\Administrator\Desktop\IGEM\stage1\notebook-lab\binding_embeddings.csv


Unnamed: 0,pdb_id,ligand_resname,ligand_chain,ligand_resnum,feat_0,feat_1,feat_2,feat_3,feat_4,feat_5,...,feat_13,feat_14,feat_15,feat_16,feat_17,feat_18,feat_19,feat_20,feat_21,feat_22
0,0,FAD,B,1,0.184211,0.026316,0.052632,0.026316,0.052632,0.131579,...,0.026316,0.0,0.078947,0.026316,0.078947,0.0,0.026316,0.0,0.0,0.0
1,1,UNL,,1,0.0,0.0,0.0,0.0,0.0,0.0,...,0.2,0.0,0.0,0.0,0.0,0.0,0.2,0.0,0.0,0.0
2,100,UNL,,1,0.4,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.1,0.0,0.0,0.0,0.1,0.0,0.0,0.0
3,1000,UNL,,1,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.2,0.2,0.0,0.0,0.2,0.0,0.0,0.0
4,1001,UNL,,1,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.166667,0.0,0.166667,0.0,0.0,0.0,0.0,0.0,0.0


In [11]:
# 基础依赖
from pathlib import Path
from collections import defaultdict

import math
import csv

import numpy as np
import pandas as pd

# PDB 解析
from Bio.PDB import PDBParser

# 数学/图构建
import torch
from torch_geometric.data import Data, Batch

# gcpnet 特征模块
from gcpnet.features.factory import ProteinFeaturiser

In [12]:
BASE_DIR = Path(r"c:\Users\Administrator\Desktop\IGEM\stage1\notebook-lab")

PDB_DIR = BASE_DIR.joinpath("complex-20251129T063258Z-1-001", "complex")
BINDING_CSV = BASE_DIR.joinpath("binding_sites.csv")

# Output file for the "binding embeddings".
BINDING_EMBEDDINGS_CSV = BASE_DIR.joinpath("binding_embeddings.csv")

# Graph-construction parameters.
K_NEIGHBORS = 16  # K of the KNN graph
# Contact cutoff (Å). It was already applied when binding_sites.csv was
# produced; graph construction here relies on KNN proximity, and this
# constant is not referenced elsewhere in the visible notebook code.
DIST_CUTOFF = 4.0

# 3-letter code -> integer ID for the 20 standard amino acids (not exhaustive,
# but covers the common residues).
AA3_TO_ID = dict(
    zip(
        (
            "ALA", "CYS", "ASP", "GLU", "PHE",
            "GLY", "HIS", "ILE", "LYS", "LEU",
            "MET", "ASN", "PRO", "GLN", "ARG",
            "SER", "THR", "VAL", "TRP", "TYR",
        ),
        range(20),
    )
)

# Fallback ID (20) for any residue name not in AA3_TO_ID.
UNKNOWN_AA_ID = len(AA3_TO_ID)

In [13]:
def load_binding_sites(csv_path: Path):
    """Group ``binding_sites.csv`` rows by ``(pdb_id, ligand_resname, ligand_chain, ligand_resnum)``.

    Identifier columns are read as strings so that pandas does not coerce
    them to numbers. Otherwise ``"0001"`` would become ``"1"``, and a single
    missing value in an integer-like column would produce float strings such
    as ``"1.0"``. Either would break the ``<pdb_id>.pdb`` file lookup and
    chain matching.

    Returns:
        groups: ``defaultdict(list)`` mapping
            ``key = (pdb_id, lig_resname, lig_chain, lig_resnum)`` to the list
            of row Series, each describing one protein residue (chain,
            residue number, ...).
    """
    df = pd.read_csv(
        csv_path,
        dtype={
            "pdb_id": str,
            "ligand_resname": str,
            "ligand_chain": str,
            "protein_chain": str,
        },
    )

    groups = defaultdict(list)
    for _, row in df.iterrows():
        key = (
            str(row["pdb_id"]),
            str(row["ligand_resname"]),
            str(row["ligand_chain"]),  # a missing chain becomes "nan"
            int(row["ligand_resnum"]),
        )
        groups[key].append(row)

    print(f"共 {len(df)} 条 binding 记录，"
          f"分成 {len(groups)} 个 (pdb, ligand) 组。")
    return groups

binding_groups = load_binding_sites(BINDING_CSV)
# Show the first five (pdb, ligand) keys; iterating the dict yields keys.
list(binding_groups)[:5]

共 25626 条 binding 记录，分成 3139 个 (pdb, ligand) 组。


[('0', 'FAD', 'B', 1),
 ('1', 'UNL', ' ', 1),
 ('100', 'UNL', ' ', 1),
 ('1000', 'UNL', ' ', 1),
 ('1001', 'UNL', ' ', 1)]

从 PDB 构建 Cα 图（Code）

In [14]:
parser = PDBParser(QUIET=True)

def build_ca_graph_from_pdb(pdb_path: Path, protein_chains=None):
    """Build a Cα KNN graph from the first model of a PDB file.

    - Nodes: every residue of the selected chains that has a blank hetero flag
      (i.e. not HETATM/water) and a ``CA`` atom.
    - Features: only the residue type ID for now (ProteinFeaturiser adds the
      rest later).
    - Coordinates: the Cα position.
    - Edges: KNN in coordinate space, excluding self-loops.

    Args:
        pdb_path: PDB file to parse.
        protein_chains: Optional collection of chain IDs to keep; when falsy,
            every chain is used.

    Returns:
        - node_chain_ids: list of chain IDs
        - node_resnums: list of residue numbers (insertion codes ignored)
        - coords: FloatTensor ``[N, 3]``
        - residue_type_ids: LongTensor ``[N]`` (``UNKNOWN_AA_ID`` for
          non-standard names)
        - edge_index: LongTensor ``[2, E]`` of edges ``i -> j`` to the
          ``min(K_NEIGHBORS, N - 1)`` nearest other nodes, nearest first

        ``([], [], None, None, None)`` when no Cα atom is found.
    """
    structure = parser.get_structure(pdb_path.stem, str(pdb_path))
    model = next(structure.get_models())  # use the first model only

    ca_coords, chain_ids, resnums, res_types = [], [], [], []
    for chain in model:
        if protein_chains and chain.id not in protein_chains:
            continue
        for residue in chain:
            if residue.id[0].strip():  # non-blank hetero flag -> HETATM / water
                continue
            if "CA" not in residue:
                continue
            ca_coords.append(residue["CA"].coord)
            chain_ids.append(chain.id)
            resnums.append(residue.id[1])
            res_types.append(AA3_TO_ID.get(residue.get_resname().strip(), UNKNOWN_AA_ID))

    if not ca_coords:
        return [], [], None, None, None

    # One ndarray -> tensor conversion; a list of ndarrays is a slow path.
    coords = torch.tensor(np.stack(ca_coords, axis=0), dtype=torch.float32)  # [N, 3]
    residue_type_ids = torch.tensor(res_types, dtype=torch.long)  # [N]

    N = coords.shape[0]
    k = min(K_NEIGHBORS, N - 1)
    if k == 0:  # single node: no edges
        edge_index = torch.zeros((2, 0), dtype=torch.long)
    else:
        # O(N^2) pairwise distances; acceptable at this data scale.
        dist_mat = torch.cdist(coords, coords)  # [N, N]
        # Mask the diagonal rather than assuming self is the nearest point,
        # which fails when two nodes share coordinates.
        dist_mat.fill_diagonal_(float("inf"))
        _, knn_idx = torch.topk(dist_mat, k=k, dim=-1, largest=False)  # [N, k]
        rows = torch.arange(N).repeat_interleave(k)
        edge_index = torch.stack([rows, knn_idx.reshape(-1)], dim=0)  # [2, N*k]

    return chain_ids, resnums, coords, residue_type_ids, edge_index

把图打包成 PyG Data，并标注 binding 节点（Code）

In [15]:
def _normalise_chain_id(raw) -> str:
    """Map a CSV chain value to a Bio.PDB chain ID; blank/NaN ("nan") become " "."""
    chain = str(raw)
    if chain.strip() == "" or chain.lower() == "nan":
        return " "
    return chain


def build_pyg_data_for_group(pdb_dir: Path, group_key, group_rows, protein_chains=None):
    """Build the Cα graph for one (pdb, ligand) group and mark its binding residues.

    - Builds the Cα graph from ``<pdb_dir>/<pdb_id>.pdb``.
    - Sets ``y = 1`` on residues listed in ``group_rows``
      (``protein_chain``, ``protein_resnum``), 0 elsewhere.
    - Returns a PyG ``Data`` with ``pos``, ``residue_type``, ``edge_index``,
      ``y`` and ligand metadata, or ``None`` (after printing a message) when
      the PDB file is missing or contains no Cα nodes.
    """
    pdb_id, lig_resname, lig_chain, lig_resnum = group_key
    pdb_path = pdb_dir / f"{pdb_id}.pdb"
    if not pdb_path.exists():
        print(f"PDB 文件不存在：{pdb_path}")
        return None

    node_chain_ids, node_resnums, coords, res_type_ids, edge_index = build_ca_graph_from_pdb(
        pdb_path, protein_chains=protein_chains
    )
    if coords is None:
        print(f"在 {pdb_path} 中未找到蛋白 Cα 节点。")
        return None

    # (chain, resnum) -> node index
    index_map = {
        (cid, int(rnum)): i
        for i, (cid, rnum) in enumerate(zip(node_chain_ids, node_resnums))
    }

    # Binding labels. CSV chains are normalised like ligand chains so that a
    # blank chain read back as NaN still matches Bio.PDB's " " chain.
    y = torch.zeros(coords.shape[0], dtype=torch.long)
    for row in group_rows:
        idx = index_map.get((_normalise_chain_id(row["protein_chain"]), int(row["protein_resnum"])))
        if idx is not None:
            y[idx] = 1

    data = Data()
    data.pos = coords                 # [N, 3]
    data.residue_type = res_type_ids  # [N]
    data.edge_index = edge_index      # [2, E]
    data.y = y                        # [N], 0/1 binding labels

    # Extra metadata (optional)
    data.pdb_id = pdb_id
    data.ligand_resname = lig_resname
    data.ligand_chain = lig_chain
    data.ligand_resnum = int(lig_resnum)

    return data

Cell 7：构建 ProteinFeaturiser 并特征化（Code）
注意：gcpnet 的 `ProteinFeaturiser` 期望输入中有 `coords`，并通过 `representation="CA"` 把 `coords` 中索引为 1 的原子槽位（即第 2 个槽位）作为 Cα。我们目前只有 `pos`（Cα 坐标）。为了简单起步，可以先绕过 `coords`，直接用 `pos` + 自己写的简单几何特征；或者给 `coords` 伪造一个形状为 `[N, 2, 3]` 的张量，让索引 1 为 `pos`。

这里采用“简单兼容方案”：构造假的 `coords`，只有索引 1 存放真实的 Cα 坐标，索引 0 用零向量占位。

In [16]:
def to_batch_for_featuriser(data_list):
    """Collate ``Data`` objects into a Batch and add the fields ProteinFeaturiser needs.

    - coords : ``[N, 2, 3]``; index 1 = Cα coordinates, index 0 = zero vector.
    - pos    : ``[N, 3]`` Cα coordinates, as collated from the inputs.
    - seq_pos: ``[N]`` simply ``0..N-1`` over the concatenated batch (it does
      not restart per graph).
    """
    # Collates pos, residue_type, edge_index and y.
    batch = Batch.from_data_list(data_list)

    pos = batch.pos  # [N, 3], used as the Cα coordinates
    batch.coords = torch.stack([torch.zeros_like(pos), pos], dim=1)  # [N, 2, 3]
    # Allocate seq_pos on pos's device so a GPU batch stays on one device.
    batch.seq_pos = torch.arange(pos.size(0), dtype=torch.long, device=pos.device)

    return batch

## Part 2：构建配体图 ligand_graph（原子级图）

目标：
- 对于每个 (pdb, ligand) 样本，从 PDB 中提取该配体残基的所有原子；
- 构建一个原子级别的图：
  - 节点：配体原子（坐标=原子坐标；类型=原子类型 ID）；
  - 边：基于原子坐标的 KNN 图；
- 该图将作为 GCPNet encoder 的输入，用于获得配体 embedding。

In [17]:
# Ligand atom-type vocabulary (simple version): element symbol in conventional
# capitalisation ("Cl", "Br") -> integer ID.
ATOM_SYMBOL_TO_ID = {
    "C": 0,
    "N": 1,
    "O": 2,
    "S": 3,
    "P": 4,
    "F": 5,
    "Cl": 6,
    "Br": 7,
    "I": 8,
}
LIG_UNKNOWN_ATOM_ID = len(ATOM_SYMBOL_TO_ID)


def _atom_symbol(atom) -> str:
    """Return an atom's element symbol capitalised like the ATOM_SYMBOL_TO_ID keys.

    Bio.PDB reports elements in upper case (e.g. "CL", "BR"), so the symbol is
    normalised with ``str.capitalize``. Without this, chlorine and bromine
    never match ``"Cl"``/``"Br"``. When the element is empty, the first letter
    of the atom name is used, falling back to "C".
    """
    symbol = (getattr(atom, "element", "") or "").strip()
    if not symbol:
        name = atom.get_name().strip()
        symbol = name[0] if name else "C"
    return symbol.capitalize()


def build_ligand_graph_from_pdb(
    pdb_path: Path,
    lig_resname: str,
    lig_chain: str,
    lig_resnum: int,
    k_neighbors: int = 8,
) -> Data | None:
    """Build an atom-level graph for one ligand residue in a PDB file.

    - Nodes: all atoms of the HETATM residue matching ``lig_resname``,
      ``lig_chain`` and ``lig_resnum`` in the first model.
    - Features: ``atom_type`` (IDs from ``ATOM_SYMBOL_TO_ID``; others map to
      ``LIG_UNKNOWN_ATOM_ID``).
    - Coordinates: atom positions, ``pos`` ``[N, 3]``.
    - Edges: KNN on coordinates, up to ``min(k_neighbors, N - 1)`` nearest
      other atoms per atom, ``edge_index`` ``[2, E]``.

    A blank or ``"nan"`` ``lig_chain`` is treated as Bio.PDB's blank chain
    ``" "``. Returns ``None`` when no matching atom is found.
    """
    structure = parser.get_structure(pdb_path.stem, str(pdb_path))
    model = next(structure.get_models())

    if lig_chain.strip() == "" or lig_chain.lower() == "nan":
        lig_chain_id = " "
    else:
        lig_chain_id = lig_chain

    target_resname = lig_resname.strip()
    atom_coords, atom_types = [], []
    for chain in model:
        if chain.id != lig_chain_id:
            continue
        for residue in chain:
            hetfield, resseq, _icode = residue.id
            if not hetfield.strip():  # HETATM residues only
                continue
            if residue.get_resname().strip() != target_resname:
                continue
            if int(resseq) != int(lig_resnum):
                continue

            # Target ligand residue found.
            for atom in residue:
                atom_coords.append(atom.coord)
                atom_types.append(ATOM_SYMBOL_TO_ID.get(_atom_symbol(atom), LIG_UNKNOWN_ATOM_ID))

    if not atom_coords:
        return None

    coords = torch.tensor(np.stack(atom_coords, axis=0), dtype=torch.float32)  # [N, 3]
    atom_type_ids = torch.tensor(atom_types, dtype=torch.long)

    N = coords.shape[0]
    k = min(k_neighbors, N - 1)
    if k <= 0:  # single atom (or k_neighbors == 0): no edges
        edge_index = torch.zeros((2, 0), dtype=torch.long)
    else:
        dist_mat = torch.cdist(coords, coords)
        # Mask self-distances instead of assuming each atom is its own nearest
        # neighbour (which fails for atoms sharing coordinates).
        dist_mat.fill_diagonal_(float("inf"))
        _, knn_idx = torch.topk(dist_mat, k=k, dim=-1, largest=False)  # [N, k]
        rows = torch.arange(N).repeat_interleave(k)
        edge_index = torch.stack([rows, knn_idx.reshape(-1)], dim=0)

    lig = Data()
    lig.pos = coords
    lig.atom_type = atom_type_ids
    lig.edge_index = edge_index

    return lig

## Part 3：构建蛋白–配体相互作用图 interaction_graph

目标：
- 在同一张图中同时包含：
  - binding 残基的 Cα（蛋白侧节点）；
  - 配体的原子（配体侧节点）；
- 边包括：
  - 残基–残基（binding 子图的 KNN）；
  - 残基–配体（每个 binding 残基连接到若干最近的配体原子）；
  - 配体–配体（直接沿用配体图中的 KNN 边，当前实现始终包含）。

这张图将捕获真正的“蛋白–配体对”的几何关系。

In [18]:
def build_interaction_graph(
    protein_data: Data,
    ligand_data: Data,
    k_residue_residue: int = 8,
    k_residue_ligand: int = 4,
    residue_type_vocab_size: int = 21,
) -> Data:
    """Build a joint protein–ligand interaction graph.

    Nodes are the Cα of binding residues (``protein_data.y > 0``) followed by
    all ligand atoms.

    Returned attributes (the same set whether or not any residue is binding):

    - ``pos``: ``[Nb + Nl, 3]`` coordinates.
    - ``node_type_id``: one shared ID space. Residue types keep their IDs in
      ``[0, residue_type_vocab_size)``; ligand atom types are shifted by
      ``residue_type_vocab_size``, so an ID means the same thing in every
      sample (the shift does not depend on which residues happen to bind).
    - ``node_type``: 0 for protein nodes, 1 for ligand nodes.
    - ``edge_index``: ``[2, E]`` directed edges, made of:
      residue -> residue KNN among binding residues;
      residue -> its ``k_residue_ligand`` nearest ligand atoms;
      ligand -> ligand edges copied from ``ligand_data`` and shifted by ``Nb``.

    With no binding residue, the result is the ligand-only graph with the
    attributes above.

    Args:
        protein_data: Needs ``pos``, ``residue_type`` and ``y``.
        ligand_data: Needs ``pos``, ``atom_type`` and ``edge_index``.
        k_residue_residue: Neighbours per binding residue among residues.
        k_residue_ligand: Ligand atoms linked from each binding residue.
        residue_type_vocab_size: Number of residue type IDs. Must exceed
            every ``residue_type`` value; 21 matches ``AA3_TO_ID`` plus
            ``UNKNOWN_AA_ID``.
    """
    device = protein_data.pos.device

    # 1) Protein side: binding residues only.
    binding_mask = protein_data.y > 0
    pos_bind = protein_data.pos[binding_mask]                # [Nb, 3]
    res_type_bind = protein_data.residue_type[binding_mask]  # [Nb]

    # 2) Ligand side: all atoms.
    pos_lig = ligand_data.pos          # [Nl, 3]
    atom_type = ligand_data.atom_type  # [Nl]

    Nb, Nl = pos_bind.size(0), pos_lig.size(0)
    no_edges = torch.zeros((2, 0), dtype=torch.long, device=device)

    # 3) residue -> residue KNN (self excluded by masking the diagonal).
    k_rr = min(k_residue_residue, Nb - 1)
    if k_rr > 0:
        dist_rr = torch.cdist(pos_bind, pos_bind)
        dist_rr.fill_diagonal_(float("inf"))
        _, knn_rr = torch.topk(dist_rr, k=k_rr, dim=-1, largest=False)  # [Nb, k_rr]
        src_rr = torch.arange(Nb, device=device).repeat_interleave(k_rr)
        edge_rr = torch.stack([src_rr, knn_rr.reshape(-1)], dim=0)
    else:
        edge_rr = no_edges

    # 4) residue -> nearest ligand atoms (ligand node indices start at Nb).
    # NOTE(review): these edges are one-directional (residue -> ligand). If the
    # encoder aggregates messages at edge_index[1], residues receive nothing
    # from the ligand; confirm whether reverse edges are wanted.
    k_rl = min(k_residue_ligand, Nl)
    if Nb > 0 and k_rl > 0:
        dist_rl = torch.cdist(pos_bind, pos_lig)  # [Nb, Nl]
        _, knn_rl = torch.topk(dist_rl, k=k_rl, dim=-1, largest=False)  # [Nb, k_rl]
        src_rl = torch.arange(Nb, device=device).repeat_interleave(k_rl)
        edge_rl = torch.stack([src_rl, knn_rl.reshape(-1) + Nb], dim=0)
    else:
        edge_rl = no_edges

    # 5) ligand -> ligand: reuse the ligand graph's edges, shifted by Nb.
    if ligand_data.edge_index.numel() > 0:
        edge_ll = ligand_data.edge_index + Nb
    else:
        edge_ll = no_edges

    inter = Data()
    inter.pos = torch.cat([pos_bind, pos_lig], dim=0)  # [Nb + Nl, 3]
    inter.node_type_id = torch.cat(
        [res_type_bind, atom_type + residue_type_vocab_size], dim=0
    )  # [Nb + Nl]
    inter.node_type = torch.cat(
        [
            torch.zeros(Nb, dtype=torch.long, device=device),
            torch.ones(Nl, dtype=torch.long, device=device),
        ],
        dim=0,
    )
    inter.edge_index = torch.cat([edge_rr, edge_rl, edge_ll], dim=1)

    return inter

## Part 4：使用完整版 GCPNet encoder 对三类图进行编码

这里我们**不再裁剪 gcpnet 源码**，而是：

- 使用你提供的 `config_gcpnet_encoder.yaml` 中的 `encoder.module_cfg/model_cfg/layer_cfg`；
- 用 `OmegaConf` 读取这些配置，初始化原版 `GCPNetModel`；
- 按照 GCPNet 预期的字段组织三种图的 `Batch`：
  - `h` / `chi`（节点标量/向量特征）；
  - `e` / `xi`（边标量/向量特征）；
  - `pos`（节点坐标）、`edge_index`（边列表）、`batch`（图索引）；
- 分别定义：
  - `encode_protein_graph`：对蛋白图进行编码，并在 binding 残基上 pooling；
  - `encode_ligand_graph`：对配体图进行编码，并在所有原子上 pooling；
  - `encode_interaction_graph`：对蛋白–配体相互作用图进行编码，并在所有节点上 pooling。

In [29]:
from omegaconf import OmegaConf
from pprint import pprint

# Load the full YAML config and dump its tree for inspection.
CFG_PATH = BASE_DIR.joinpath("config_gcpnet_encoder.yaml")
cfg = OmegaConf.load(str(CFG_PATH))
pprint(cfg)


{'features': {'module': 'models.gcpnet.features.factory.ProteinFeaturiser', 'kwargs': {'representation': 'CA', 'scalar_node_features': ['amino_acid_one_hot', 'sequence_positional_encoding', 'alpha', 'kappa', 'dihedrals'], 'vector_node_features': ['orientation'], 'edge_types': ['knn_16'], 'scalar_edge_features': ['edge_distance'], 'vector_edge_features': ['edge_vectors']}}, 'task': {'transform': None}, 'encoder': {'module': 'models.gcpnet.models.graph_encoders.gcpnet.GCPNetModel', 'kwargs': {'num_layers': 6, 'emb_dim': 128, 'node_s_emb_dim': 128, 'node_v_emb_dim': 16, 'edge_s_emb_dim': 32, 'edge_v_emb_dim': 4, 'r_max': 10.0, 'num_rbf': 8, 'activation': 'silu', 'pool': 'sum', 'module_cfg': {'norm_pos_diff': True, 'scalar_gate': 0, 'vector_gate': True, 'scalar_nonlinearity': 'silu', 'vector_nonlinearity': 'silu', 'nonlinearities': ['silu', 'silu'], 'r_max': 10.0, 'num_rbf': 8, 'bottleneck': 4, 'vector_linear': True, 'vector_identity': True, 'default_bottleneck': 4, 'predict_node_positions

In [31]:
# 4.1 从 YAML 加载完整版 GCPNet encoder 配置并初始化模型

from omegaconf import OmegaConf
from types import SimpleNamespace

CFG_PATH = BASE_DIR.joinpath("config_gcpnet_encoder.yaml")

cfg = OmegaConf.load(str(CFG_PATH))
encoder_cfg = cfg["encoder"]

# Resolve interpolations and copy the encoder kwargs into plain dict/list
# containers, so nothing downstream holds OmegaConf nodes.
enc_kwargs_dict = OmegaConf.to_container(encoder_cfg["kwargs"], resolve=True)


def dict_to_namespace(d):
    """递归地将字典转换为 SimpleNamespace，支持属性访问但避免 OmegaConf 的循环引用问题"""
    if isinstance(d, dict):
        return SimpleNamespace(**{k: dict_to_namespace(v) for k, v in d.items()})
    elif isinstance(d, list):
        return [dict_to_namespace(item) for item in d]
    else:
        return d


# GCPNetModel kwargs: scalar hyper-parameters are copied unchanged, and the
# nested module/model/layer configs become SimpleNamespace objects
# (attribute access without OmegaConf's recursion issues).
# NOTE(review): model_cfg still carries the YAML input dims here (the next
# cell's log suggests h_input_dim 49 / chi_input_dim 2), whereas
# featuriser_for_protein below emits only amino-acid one-hot node scalars and
# no vector node features. A forward pass with this encoder may therefore hit
# a shape mismatch; the next cell rebuilds the encoder with probed dims.
enc_kwargs = {
    key: enc_kwargs_dict[key]
    for key in (
        "num_layers",
        "emb_dim",
        "node_s_emb_dim",
        "node_v_emb_dim",
        "edge_s_emb_dim",
        "edge_v_emb_dim",
        "r_max",
        "num_rbf",
        "activation",
        "pool",
    )
}
enc_kwargs.update(
    {key: dict_to_namespace(enc_kwargs_dict[key]) for key in ("module_cfg", "model_cfg", "layer_cfg")}
)

# Create the GCPNet encoder and switch it to inference mode.
full_gcpnet_encoder = GCPNetModel(**enc_kwargs)
full_gcpnet_encoder.eval()

print(f"✓ GCPNet 编码器初始化成功")
print(f"  - 层数: {enc_kwargs['num_layers']}")
print(f"  - 节点标量维度: {enc_kwargs['node_s_emb_dim']}")
print(f"  - 节点向量维度: {enc_kwargs['node_v_emb_dim']}")

# Reduced featuriser: vector node features such as "orientation" are left out
# because they require `_slice_dict` on the batch.
featuriser_for_protein = ProteinFeaturiser(
    representation="CA",
    edge_types=["knn_16"],
    scalar_node_features=["amino_acid_one_hot"],
    vector_node_features=[],
    scalar_edge_features=["edge_distance"],
    vector_edge_features=["edge_vectors"],
)

print(f"✓ 蛋白质 Featuriser 初始化成功")


def _build_gcpnet_batch_from_featurised_batch(batch: Batch) -> Batch:
    """Expose ProteinFeaturiser outputs under the names GCPNetModel reads.

    Mapping after ``featuriser(batch)``:

    - ``batch.x``                -> ``h``   (node scalars)
    - ``batch.x_vector_attr``    -> ``chi`` (node vectors; empty ``[N, 0, 3]`` if absent)
    - ``batch.edge_attr``        -> ``e``   (edge scalars; zeros ``[E, 1]`` if absent)
    - ``batch.edge_vector_attr`` -> ``xi``  (edge vectors; empty ``[E, 0, 3]`` if absent)

    An attribute that exists but is ``None`` is treated like a missing one.
    Placeholders are only built when needed. ``pos``, ``edge_index`` and
    ``batch`` are used as they are.
    """
    device = batch.x.device

    batch.h = batch.x

    chi = getattr(batch, "x_vector_attr", None)
    batch.chi = chi if chi is not None else torch.zeros(batch.x.size(0), 0, 3, device=device)

    e = getattr(batch, "edge_attr", None)
    batch.e = e if e is not None else torch.zeros(batch.edge_index.size(1), 1, device=device)

    xi = getattr(batch, "edge_vector_attr", None)
    batch.xi = xi if xi is not None else torch.zeros(batch.edge_index.size(1), 0, 3, device=device)

    return batch


def _pool_by_graph(node_emb: torch.Tensor, batch_index: torch.Tensor, reduce: str = "mean") -> torch.Tensor:
    """简单的按图 pooling，使用 PyTorch 实现 mean/sum。"""
    num_graphs = int(batch_index.max().item()) + 1 if batch_index.numel() > 0 else 0
    if num_graphs == 0:
        return node_emb.new_zeros((0, node_emb.size(-1)))
    out = []
    for g in range(num_graphs):
        m = (batch_index == g)
        if not m.any():
            out.append(node_emb.new_zeros((node_emb.size(-1),)))
        else:
            if reduce == "sum":
                out.append(node_emb[m].sum(dim=0))
            else:
                out.append(node_emb[m].mean(dim=0))
    return torch.stack(out, dim=0)


def encode_protein_graph(batch: Batch) -> torch.Tensor:
    """Encode a protein batch with GCPNet and mean-pool each graph over its binding residues.

    The batch goes through ``featuriser_for_protein`` and then the field
    renaming in ``_build_gcpnet_batch_from_featurised_batch``. It is run
    through ``full_gcpnet_encoder`` without gradients, and the output's
    ``"node_embedding"`` is pooled per graph. Nodes with ``y > 0`` are
    averaged, or all nodes when a graph has none.

    Returns:
        ``[num_graphs, D]`` tensor, where ``D`` is the node-embedding width.
    """
    batch = featuriser_for_protein(batch)
    batch = _build_gcpnet_batch_from_featurised_batch(batch)

    with torch.no_grad():
        enc_out = full_gcpnet_encoder(batch)
    node_emb = enc_out["node_embedding"]  # [N_total, D]

    y = batch.y
    graph_idx = batch.batch

    num_graphs = int(graph_idx.max().item()) + 1
    h_list = []
    for g in range(num_graphs):
        mask_g = graph_idx == g
        x_g = node_emb[mask_g]
        if x_g.size(0) == 0:
            # Use the real embedding width (not a config value) so that
            # torch.stack below always sees equal-sized rows.
            h_list.append(node_emb.new_zeros(node_emb.size(-1)))
            continue
        binding = y[mask_g] > 0
        h_list.append(x_g[binding].mean(dim=0) if binding.any() else x_g.mean(dim=0))

    return torch.stack(h_list, dim=0)

print(f"✓ 编码函数定义完成")


✓ GCPNet 编码器初始化成功
  - 层数: 6
  - 节点标量维度: 128
  - 节点向量维度: 16
✓ 蛋白质 Featuriser 初始化成功
✓ 编码函数定义完成


In [32]:
# 4.1 从 YAML 加载完整版 GCPNet encoder 配置并初始化模型

from omegaconf import OmegaConf
from types import SimpleNamespace

CFG_PATH = BASE_DIR.joinpath("config_gcpnet_encoder.yaml")

cfg = OmegaConf.load(str(CFG_PATH))
encoder_cfg = cfg["encoder"]

# Plain dict/list copy of the encoder kwargs, interpolations resolved.
enc_kwargs_dict = OmegaConf.to_container(encoder_cfg["kwargs"], resolve=True)

# Report the input dimensions declared in the YAML model_cfg.
print("原始配置文件中的输入维度：")
print("  - h_input_dim:", enc_kwargs_dict["model_cfg"]["h_input_dim"])
print("  - chi_input_dim:", enc_kwargs_dict["model_cfg"]["chi_input_dim"])
print("  - e_input_dim:", enc_kwargs_dict["model_cfg"]["e_input_dim"])
print("  - xi_input_dim:", enc_kwargs_dict["model_cfg"]["xi_input_dim"])


def dict_to_namespace(d):
    """Turn nested dicts into ``SimpleNamespace`` objects (attribute access).

    Lists are mapped element-wise; every other value is passed through as-is.
    """
    if isinstance(d, list):
        return list(map(dict_to_namespace, d))
    if not isinstance(d, dict):
        return d
    converted = {key: dict_to_namespace(value) for key, value in d.items()}
    return SimpleNamespace(**converted)


# Reduced featuriser: vector node features such as "orientation" are left out
# because they require `_slice_dict` on the batch.
featuriser_for_protein = ProteinFeaturiser(
    representation="CA",
    scalar_node_features=["amino_acid_one_hot"],
    vector_node_features=[],  # orientation etc. need `_slice_dict`
    edge_types=["knn_16"],
    scalar_edge_features=["edge_distance"],
    vector_edge_features=["edge_vectors"],
)

print(f"\n✓ 蛋白质 Featuriser 初始化成功")

# Run the featuriser on a tiny random 5-node graph to measure the node
# feature widths it really produces.
print("\n检测 featuriser 实际输出维度...")
test_data = Data()
test_data.pos = torch.randn(5, 3)
test_data.residue_type = torch.randint(0, 20, (5,))
test_data.edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]], dtype=torch.long)
test_data.y = torch.zeros(5, dtype=torch.long)

test_batch = Batch.from_data_list([test_data])
test_batch.coords = torch.stack([torch.zeros_like(test_batch.pos), test_batch.pos], dim=1)
test_batch.seq_pos = torch.arange(test_batch.pos.size(0), dtype=torch.long)

test_batch = featuriser_for_protein(test_batch)
actual_h_dim = test_batch.x.size(-1)
actual_chi_dim = test_batch.x_vector_attr.size(1) if hasattr(test_batch, 'x_vector_attr') and test_batch.x_vector_attr is not None else 0

print(f"  - 实际节点标量特征维度 (h): {actual_h_dim}")
print(f"  - 实际节点向量特征维度 (chi): {actual_chi_dim}")
# NOTE(review): the two edge-dimension lines below are fixed text, not values
# measured from test_batch; re-check them if the edge features change.
print(f"  - 实际边标量特征维度 (e): 1 → 9 (GCPNet内部会添加8维RBF)")
print(f"  - 实际边向量特征维度 (xi): 1")

# Overwrite the YAML node input dims with the measured ones. Remember the
# YAML values first so the log reports what was actually replaced, instead
# of hard-coded numbers that go stale when the config changes.
model_cfg_dict = enc_kwargs_dict['model_cfg']
orig_h_input_dim = model_cfg_dict['h_input_dim']
orig_chi_input_dim = model_cfg_dict['chi_input_dim']

print(f"\n更新配置以匹配 featuriser 输出...")
model_cfg_dict['h_input_dim'] = actual_h_dim
model_cfg_dict['chi_input_dim'] = actual_chi_dim
# e_input_dim / xi_input_dim are left unchanged (per the original note,
# GCPNet handles them internally).

print(f"✓ 已更新配置:")
print(f"  - h_input_dim: {orig_h_input_dim} → {actual_h_dim}")
print(f"  - chi_input_dim: {orig_chi_input_dim} → {actual_chi_dim}")
print(f"  - e_input_dim: {model_cfg_dict['e_input_dim']} (保持不变)")
print(f"  - xi_input_dim: {model_cfg_dict['xi_input_dim']} (保持不变)")

# Convert the nested configs to SimpleNamespace; scalars are copied as-is.
enc_kwargs = {
    'num_layers': enc_kwargs_dict['num_layers'],
    'emb_dim': enc_kwargs_dict['emb_dim'],
    'node_s_emb_dim': enc_kwargs_dict['node_s_emb_dim'],
    'node_v_emb_dim': enc_kwargs_dict['node_v_emb_dim'],
    'edge_s_emb_dim': enc_kwargs_dict['edge_s_emb_dim'],
    'edge_v_emb_dim': enc_kwargs_dict['edge_v_emb_dim'],
    'r_max': enc_kwargs_dict['r_max'],
    'num_rbf': enc_kwargs_dict['num_rbf'],
    'activation': enc_kwargs_dict['activation'],
    'pool': enc_kwargs_dict['pool'],
    'module_cfg': dict_to_namespace(enc_kwargs_dict['module_cfg']),
    'model_cfg': dict_to_namespace(enc_kwargs_dict['model_cfg']),
    'layer_cfg': dict_to_namespace(enc_kwargs_dict['layer_cfg']),
}

# Initialise the GCPNet encoder only after the config update above.
print(f"\n初始化 GCPNet 编码器...")
full_gcpnet_encoder = GCPNetModel(
    **enc_kwargs
).eval()

print(f"✓ GCPNet 编码器初始化成功")
print(f"  - 层数: {enc_kwargs['num_layers']}")
print(f"  - 节点标量嵌入维度: {enc_kwargs['node_s_emb_dim']}")
print(f"  - 节点向量嵌入维度: {enc_kwargs['node_v_emb_dim']}")


def _ensure_batch_vector_attrs(batch: Batch) -> Batch:
    """Give ``batch`` empty vector features where they are missing.

    ``x_vector_attr`` and ``edge_vector_attr`` are set to zero-channel
    tensors of shape ``[N, 0, 3]`` / ``[E, 0, 3]`` (on ``batch.x``'s device)
    when absent or ``None``; existing values are left alone. The same batch
    object is returned after in-place modification.
    """
    if getattr(batch, "x_vector_attr", None) is None:
        num_nodes = batch.x.size(0)
        batch.x_vector_attr = torch.zeros(num_nodes, 0, 3, device=batch.x.device)
    if getattr(batch, "edge_vector_attr", None) is None:
        num_edges = batch.edge_index.size(1)
        batch.edge_vector_attr = torch.zeros(num_edges, 0, 3, device=batch.x.device)
    return batch


def _pool_by_graph(node_emb: torch.Tensor, batch_index: torch.Tensor, reduce: str = "mean") -> torch.Tensor:
    """Pool node embeddings into one vector per graph.

    Args:
        node_emb: node embeddings of shape ``[N, D]``.
        batch_index: graph id of every node, shape ``[N]``.
        reduce: ``"mean"`` (default) or ``"sum"``.

    Returns:
        Tensor of shape ``[max(batch_index) + 1, D]``. Graph ids that have no
        nodes get a zero row; an empty ``batch_index`` gives ``[0, D]``.

    Raises:
        ValueError: if ``reduce`` is not ``"mean"`` or ``"sum"``.
    """
    if reduce not in ("mean", "sum"):
        raise ValueError(f"_pool_by_graph: unsupported reduce={reduce!r}; expected 'mean' or 'sum'")
    if batch_index.numel() == 0:
        return node_emb.new_zeros((0, node_emb.size(-1)))

    batch_index = batch_index.to(torch.long)
    num_graphs = int(batch_index.max().item()) + 1

    # Scatter-add every node into its graph row in a single vectorised pass
    # (instead of one boolean mask scan per graph).
    pooled = node_emb.new_zeros((num_graphs, node_emb.size(-1)))
    pooled.index_add_(0, batch_index, node_emb)

    if reduce == "mean":
        # clamp(min=1) keeps empty graphs at zero instead of dividing by zero.
        counts = torch.bincount(batch_index, minlength=num_graphs).clamp(min=1)
        pooled = pooled / counts.unsqueeze(-1).to(pooled.dtype)
    return pooled


def encode_protein_graph(batch: Batch) -> torch.Tensor:
    """Encode a batch of protein residue graphs into one pocket embedding each.

    Pipeline: ``featuriser_for_protein`` -> fill missing vector attributes ->
    ``full_gcpnet_encoder`` forward (no grad) -> per-graph pooling of the
    returned ``"node_embedding"``:

    * mean over the graph's binding residues (nodes with ``y > 0``);
    * if the graph has no binding residue, mean over all of its nodes;
    * if the graph has no nodes at all, a zero vector.

    Args:
        batch: batched residue graphs; after featurisation it must carry
            ``y`` (per-node binding label) and ``batch`` (node -> graph id).

    Returns:
        Tensor of shape ``[num_graphs, D]``, where ``D`` is the width of the
        encoder's node embedding.
    """
    batch = featuriser_for_protein(batch)
    batch = _ensure_batch_vector_attrs(batch)

    with torch.no_grad():
        enc_out = full_gcpnet_encoder(batch)
    node_emb = enc_out["node_embedding"]  # [N_total, D]
    emb_dim = node_emb.size(-1)

    y = batch.y
    graph_idx = batch.batch

    # Prefer the batch's own graph count so graphs without nodes (including
    # trailing ones, which `max() + 1` would silently drop) still get a row and
    # stay aligned with the per-sample metadata / other embeddings.
    num_graphs = getattr(batch, "num_graphs", None)
    if num_graphs is None:
        num_graphs = int(graph_idx.max().item()) + 1 if graph_idx.numel() > 0 else 0
    if num_graphs == 0:
        return node_emb.new_zeros((0, emb_dim))

    pooled = []
    for g in range(num_graphs):
        mask_g = graph_idx == g
        x_g = node_emb[mask_g]
        if x_g.size(0) == 0:
            # Size the placeholder from the actual embedding, not a config key.
            pooled.append(node_emb.new_zeros(emb_dim))
            continue
        is_binding = y[mask_g] > 0
        pooled.append(x_g[is_binding].mean(dim=0) if is_binding.any() else x_g.mean(dim=0))

    return torch.stack(pooled, dim=0)

# Notebook progress marker: the encoder and the pooling/encoding helpers are defined.
print("\n✓ 编码函数定义完成")


原始配置文件中的输入维度：
  - h_input_dim: 49
  - chi_input_dim: 2
  - e_input_dim: 9
  - xi_input_dim: 1

✓ 蛋白质 Featuriser 初始化成功

检测 featuriser 实际输出维度...
  - 实际节点标量特征维度 (h): 23
  - 实际节点向量特征维度 (chi): 0
  - 实际边标量特征维度 (e): 1 → 9 (GCPNet内部会添加8维RBF)
  - 实际边向量特征维度 (xi): 1

更新配置以匹配 featuriser 输出...
✓ 已更新配置:
  - h_input_dim: 49 → 23
  - chi_input_dim: 2 → 0
  - e_input_dim: 9 (保持不变)
  - xi_input_dim: 1 (保持不变)

初始化 GCPNet 编码器...
✓ GCPNet 编码器初始化成功
  - 层数: 6
  - 节点标量嵌入维度: 128
  - 节点向量嵌入维度: 16

✓ 编码函数定义完成


In [33]:
# 4.2 定义 encode_ligand_graph：对配体原子图进行编码

def encode_ligand_graph(ligand_data_list: List[Data]) -> torch.Tensor:
    """Encode ligand atom graphs and return one mean-pooled embedding per graph.

    - Node features: one-hot of ``atom_type`` written directly at the
      encoder's scalar input width (``model_cfg.h_input_dim``, i.e. the same
      width the protein featuriser produces). Atom types outside
      ``[0, h_input_dim)`` cannot be represented and get an all-zero feature
      row; a warning lists them when that happens.
    - Edge features: distance (scalar, ``[E, 1]``) and unit direction vector
      (``[E, 1, 3]``) computed from ``pos`` along ``edge_index``.
    - Node vector features: none (zero channels).

    NOTE(review): ligands go through the same ``full_gcpnet_encoder`` as the
    proteins, so atom-type indices occupy the same input slots the protein
    featuriser uses for its own features - confirm this sharing is intended.

    Args:
        ligand_data_list: graphs carrying ``atom_type``, ``pos`` and ``edge_index``.

    Returns:
        Tensor of shape ``[num_graphs, D]`` of mean-pooled node embeddings.
    """
    import warnings

    batch = Batch.from_data_list(ligand_data_list)

    # Node features: one-hot at the encoder's input width. Writing the one-hot
    # at a fixed width keeps the encoding independent of which atom types
    # happen to be present in this batch.
    target_dim = enc_kwargs['model_cfg'].h_input_dim
    atom_type = batch.atom_type.long()
    in_range = (atom_type >= 0) & (atom_type < target_dim)
    if not bool(in_range.all()):
        unrepresentable = torch.unique(atom_type[~in_range]).tolist()
        warnings.warn(
            f"encode_ligand_graph: atom_type values {unrepresentable} fall outside "
            f"[0, {target_dim}); those atoms get all-zero node features.",
            stacklevel=2,
        )
    x = torch.zeros(atom_type.size(0), target_dim, device=atom_type.device)
    rows = torch.nonzero(in_range, as_tuple=True)[0]
    x[rows, atom_type[rows]] = 1.0
    batch.x = x

    # Edge features: distance + unit direction vector.
    pos = batch.pos  # [N, 3]
    row, col = batch.edge_index  # [2, E]
    diff = pos[row] - pos[col]  # [E, 3]
    dist = torch.norm(diff, dim=-1, keepdim=True)  # [E, 1]
    unit = diff / (dist + 1e-8)  # [E, 3]; epsilon avoids 0/0 for coincident atoms

    # Attributes the encoder forward reads.
    batch.edge_attr = dist  # [E, 1] scalar edge feature
    batch.edge_vector_attr = unit.unsqueeze(-2)  # [E, 1, 3] vector edge feature
    batch.x_vector_attr = torch.zeros(batch.x.size(0), 0, 3, device=batch.x.device)  # no node vector features

    with torch.no_grad():
        enc_out = full_gcpnet_encoder(batch)
    node_emb = enc_out["node_embedding"]  # [N_total, D]

    return _pool_by_graph(node_emb, batch.batch, reduce="mean")


## Part 5：整合三类图并导出三重 embedding

这一部分将前面构建好的三类图（蛋白 binding 残基图、配体原子图、蛋白–配体相互作用图）通过完整版 GCPNet encoder 编码为三个向量，并在 `compute_triplet_embeddings` 中进行整合：

- **蛋白图编码**：先用 `to_batch_for_featuriser` + `encode_protein_graph` 得到蛋白口袋 embedding（对 binding 残基节点的 embedding 取均值；若某图没有 binding 残基，则对该图全部残基取均值）；
- **配体图编码**：用 `encode_ligand_graph` 对配体原子图做 mean pooling 得到配体 embedding；
- **相互作用图编码**：用 `encode_interaction_graph` 对蛋白–配体相互作用图做 pooling 得到复合物级 embedding；
- 最后按“蛋白 | 配体 | 相互作用”的顺序将三者拼接（当前配置下各 128 维，共 384 维，即 CSV 中的 `feat_0`–`feat_383`），并写入 `binding_embeddings_triplet.csv`，作为下游模型/分析的输入特征。

In [35]:
run debug_gcpnet.py

步骤 1: 检测 ProteinFeaturiser 实际输出维度
✓ Featuriser 实际输出维度:
  - 节点标量特征 (h): 23
  - 节点向量特征 (chi): 0
  - 边标量特征 (e): 1
  - 边向量特征 (xi): 1

步骤 2: 检查 YAML 配置文件中的默认维度
配置文件中的默认维度:
  - h_input_dim: 49
  - chi_input_dim: 2
  - e_input_dim: 9
  - xi_input_dim: 1

步骤 3: 修正配置并初始化模型

更新配置维度以匹配 featuriser 输出...
⚠️  注意：GCPNet 会自动将边特征 (1维) + RBF展开 (8维) = 9维
   所以 e_input_dim 应保持为 9，不要修改为 1！
✓ 更新后的配置维度:
  - h_input_dim: 23 (修改: 49 → 23)
  - chi_input_dim: 0 (修改: 2 → 0)
  - e_input_dim: 9 (保持不变: 1+8 RBF)
  - xi_input_dim: 1 (保持不变)

初始化 GCPNet 模型...
✓ 模型初始化成功！

步骤 4: 测试前向传播
测试 batch 的属性:
  - x shape: torch.Size([10, 23])
  - x_vector_attr shape: torch.Size([10, 0, 3])
  - edge_attr shape: torch.Size([4, 1])
  - edge_vector_attr shape: torch.Size([4, 1, 3])
✓ 前向传播成功！
  - node_embedding shape: torch.Size([10, 128])
  - graph_embedding shape: torch.Size([1, 128])

调试完成！所有测试通过。


In [36]:
# Output path for the concatenated [protein | ligand | interaction] embeddings
# written by compute_triplet_embeddings.
TRIPLET_EMBEDDINGS_CSV = BASE_DIR / "binding_embeddings_triplet.csv"


def compute_triplet_embeddings(
    pdb_dir: Path,
    binding_groups: dict,
    max_groups: int | None = 50,
    protein_chains=None,
    output_csv: Path | None = None,
):
    """Build, encode and export the protein / ligand / interaction graph triplet per binding group.

    For each key ``(pdb_id, lig_resname, lig_chain, lig_resnum)`` the protein
    binding-residue graph, the ligand atom graph and the protein-ligand
    interaction graph are built. Keys whose protein or ligand graph cannot
    be built are skipped, and the number skipped is reported. The three graph
    sets are encoded batch-wise and concatenated per sample in the order
    ``[protein | ligand | interaction]``.

    Args:
        pdb_dir: directory holding ``<pdb_id>.pdb`` files.
        binding_groups: mapping from the 4-tuple key above to that group's
            rows (forwarded to ``build_pyg_data_for_group``).
        max_groups: only the first ``max_groups`` keys are tried (``None`` = all).
            Because failed keys are skipped, fewer rows may be written.
        protein_chains: forwarded to ``build_pyg_data_for_group``.
        output_csv: destination CSV; defaults to ``TRIPLET_EMBEDDINGS_CSV``.

    Returns:
        DataFrame with columns ``pdb_id, ligand_resname, ligand_chain,
        ligand_resnum, feat_0 ... feat_{D-1}``, or ``None`` if no sample
        could be built.

    Raises:
        ValueError: if an encoder returns a different number of rows than
            the number of samples (the embeddings would be misaligned).
    """
    if output_csv is None:
        output_csv = TRIPLET_EMBEDDINGS_CSV

    keys = list(binding_groups.keys())
    if max_groups is not None:
        keys = keys[:max_groups]

    protein_data_list = []
    ligand_data_list = []
    inter_data_list = []
    meta_list = []
    n_skipped = 0

    for key in keys:
        pdb_id, lig_resname, lig_chain, lig_resnum = key
        pdb_path = pdb_dir / f"{pdb_id}.pdb"

        protein_data = build_pyg_data_for_group(pdb_dir, key, binding_groups[key], protein_chains)
        if protein_data is None:
            n_skipped += 1
            continue

        ligand_data = build_ligand_graph_from_pdb(pdb_path, lig_resname, lig_chain, lig_resnum)
        if ligand_data is None:
            n_skipped += 1
            continue

        inter_data = build_interaction_graph(protein_data, ligand_data)

        protein_data_list.append(protein_data)
        ligand_data_list.append(ligand_data)
        inter_data_list.append(inter_data)
        meta_list.append(key)

    if n_skipped:
        print(f"已跳过 {n_skipped} 个无法构建蛋白图或配体图的样本。")

    if not protein_data_list:
        print("没有成功构建的样本。")
        return None

    # 1) Protein graphs: featuriser, then the full GCPNet encoder.
    protein_batch = to_batch_for_featuriser(protein_data_list)
    h_protein = encode_protein_graph(protein_batch)  # [B, Dp]
    # 2) Ligand graphs.
    h_ligand = encode_ligand_graph(ligand_data_list)  # [B, Dl]
    # 3) Interaction graphs.
    h_inter = encode_interaction_graph(inter_data_list)  # [B, Di]

    # Every encoder must yield exactly one row per sample, otherwise the
    # concatenation (or the metadata join) would silently misalign samples.
    n_samples = len(meta_list)
    for part_name, part in (("protein", h_protein), ("ligand", h_ligand), ("interaction", h_inter)):
        if part.size(0) != n_samples:
            raise ValueError(
                f"{part_name} embeddings have {part.size(0)} rows, expected {n_samples}"
            )

    H_full = torch.cat([h_protein, h_ligand, h_inter], dim=1).detach().cpu().numpy()

    df_meta = pd.DataFrame(
        meta_list,
        columns=["pdb_id", "ligand_resname", "ligand_chain", "ligand_resnum"],
    )
    feat_cols = [f"feat_{i}" for i in range(H_full.shape[1])]
    df_feat = pd.DataFrame(H_full, columns=feat_cols)
    df_out = pd.concat([df_meta, df_feat], axis=1)
    df_out.to_csv(output_csv, index=False)
    print(f"已保存 {len(df_out)} 条三重 embedding 到 {output_csv}")

    return df_out


# Example run: build and encode the three graphs for the first 50 binding groups
# (heavier than the earlier single-graph version).
df_triplet = compute_triplet_embeddings(PDB_DIR, binding_groups, max_groups=50)
# compute_triplet_embeddings returns None when no sample could be built; guard the
# preview so the cell does not raise AttributeError in that case.
df_triplet.head() if df_triplet is not None else None

已保存 50 条三重 embedding 到 c:\Users\Administrator\Desktop\IGEM\stage1\notebook-lab\binding_embeddings_triplet.csv


Unnamed: 0,pdb_id,ligand_resname,ligand_chain,ligand_resnum,feat_0,feat_1,feat_2,feat_3,feat_4,feat_5,...,feat_374,feat_375,feat_376,feat_377,feat_378,feat_379,feat_380,feat_381,feat_382,feat_383
0,0,FAD,B,1,-0.160192,0.037487,0.343853,0.107562,-0.045167,-0.109482,...,0.01529,0.039096,0.030261,0.059242,0.183268,-0.042948,0.263075,0.095507,-0.121052,0.07266
1,1,UNL,,1,-0.140156,0.044201,0.377216,0.274628,-0.066483,0.099652,...,0.044303,-0.060594,0.118736,-0.085967,0.095112,-0.050972,0.074142,-0.077221,0.015196,0.4545
2,100,UNL,,1,-0.174216,0.054912,0.565381,0.199195,0.056334,-0.095419,...,0.074319,-0.054754,0.080264,-0.049534,0.192698,-0.046946,0.171882,0.023477,-0.068335,0.349761
3,1000,UNL,,1,-0.210572,0.193983,0.098806,-0.067299,0.027583,-0.232041,...,0.047724,0.078645,-0.045543,0.01125,0.219708,0.067847,0.251348,0.196277,-0.032712,0.107715
4,1001,UNL,,1,-0.139574,0.238269,0.317292,0.02982,0.067032,0.024362,...,0.069117,0.075436,0.151512,-0.037952,0.214505,0.021167,0.288775,0.015469,-0.017705,0.316339
