In [1]:
import torch
from torch_geometric.data import Data
from rdkit import Chem
from rdkit.Chem import AllChem
import numpy as np

# Length of the feature vectors built by _atom_features / _bond_features;
# used to give empty tensors the correct trailing dimension.
_NUM_ATOM_FEATURES = 9
_NUM_BOND_FEATURES = 3


def _atom_features(atom):
    """Return the 9-element node-feature vector for one RDKit atom.

    RDKit enum values (chiral tag, hybridization) and booleans are cast to
    ``int`` explicitly so the list holds plain numbers before tensor conversion.
    """
    return [
        atom.GetAtomicNum(),
        int(atom.GetChiralTag()),
        atom.GetDegree(),
        atom.GetFormalCharge(),
        atom.GetNumExplicitHs(),
        atom.GetNumRadicalElectrons(),
        int(atom.GetHybridization()),
        int(atom.GetIsAromatic()),
        int(atom.IsInRing()),
    ]


def _bond_features(bond):
    """Return the 3-element edge-feature vector (type, stereo, conjugated) for one bond."""
    return [
        int(bond.GetBondType()),
        int(bond.GetStereo()),
        int(bond.GetIsConjugated()),
    ]


def sdf_to_pyg_data(sdf_file):
    """Convert the first molecule of an SDF file into a PyG ``Data`` graph.

    Hydrogens are kept (``removeHs=False``). Every bond is emitted as two
    directed edges (i->j and j->i) that share the same edge attributes.

    Args:
        sdf_file: Path to the SDF file. Only the first record is used.

    Returns:
        ``torch_geometric.data.Data`` with
          x          -- float tensor of shape [num_atoms, 9]
          edge_index -- long tensor of shape [2, 2 * num_bonds]
          edge_attr  -- float tensor of shape [2 * num_bonds, 3]
          pos        -- float tensor of conformer coordinates, [num_atoms, 3]

    Raises:
        ValueError: if the file has no records, the first record cannot be
            parsed, or the molecule has no conformer.
        Errors raised by RDKit when the file itself cannot be opened are
        propagated unchanged.
    """
    supplier = Chem.SDMolSupplier(sdf_file, removeHs=False)
    # next(iter(...), None) covers an empty file without an indexing error;
    # a record RDKit fails to parse is also yielded as None.
    mol = next(iter(supplier), None)
    if mol is None:
        raise ValueError(f"无法读取SDF文件: {sdf_file}")

    # Node features.
    atom_features = [_atom_features(atom) for atom in mol.GetAtoms()]

    # Edge index and edge attributes.
    edge_indices = []
    edge_attributes = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        # PyG edges are directed, so an undirected bond is stored both ways.
        edge_indices += [[i, j], [j, i]]
        edge_attributes += [_bond_features(bond)] * 2

    # 3D coordinates; fail with a clear message instead of an opaque RDKit error.
    if mol.GetNumConformers() == 0:
        raise ValueError(f"SDF文件中没有构象(坐标): {sdf_file}")
    positions = mol.GetConformer().GetPositions()

    # reshape keeps the [N, F] / [2, E] / [E, 3] layouts even when a list is
    # empty (e.g. a single-atom molecule with no bonds); torch.tensor([])
    # alone would yield shape (0,).
    x = torch.tensor(atom_features, dtype=torch.float).reshape(-1, _NUM_ATOM_FEATURES)
    edge_index = (
        torch.tensor(edge_indices, dtype=torch.long).reshape(-1, 2).t().contiguous()
    )
    edge_attr = torch.tensor(edge_attributes, dtype=torch.float).reshape(
        -1, _NUM_BOND_FEATURES
    )
    pos = torch.tensor(positions, dtype=torch.float)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos)

# 使用示例
if __name__ == "__main__":
    # Demo run: build the graph for one test ligand and print each tensor's shape.
    sdf_file = './../Testset/4ytc/4ytc_ligand.sdf'
    try:
        ligand_data = sdf_to_pyg_data(sdf_file)
        print("成功创建PyG Data对象:")
        for _label, _attr in (
            ("节点特征", "x"),
            ("边索引", "edge_index"),
            ("边属性", "edge_attr"),
            ("3D坐标", "pos"),
        ):
            print(f"{_label}: {getattr(ligand_data, _attr).shape}")
    except Exception as err:
        # Report any failure as a single message instead of a traceback.
        print(f"处理SDF文件时出错: {err}")

成功创建PyG Data对象:
节点特征: torch.Size([42, 9])
边索引: torch.Size([2, 90])
边属性: torch.Size([90, 3])
3D坐标: torch.Size([42, 3])


