In [131]:
from rdkit import Chem
from torch_geometric.data import Data
import torch
import math
import numpy as np
from tqdm import tqdm
import pickle
import warnings
 
# 忽略未使用变量的警告
warnings.filterwarnings("ignore")


In [132]:
# SMARTS strings for the chemical atom flags. The featurizers flag atoms they
# match (excluding carbon) as donor / acceptor / acidic / basic. Adjacent
# string literals are joined by Python into a single SMARTS each.
_ACCEPTOR_SMARTS = ('[$([O;H1;v2]),'
                    '$([O;H0;v2;!$(O=N-*),'
                    '$([O;-;!$(*-N=O)]),'
                    '$([o;+0])]),'
                    '$([n;+0;!X3;!$([n;H1](cc)cc),'
                    '$([$([N;H0]#[C&v4])]),'
                    '$([N&v3;H0;$(Nc)])]),'
                    '$([F;$(F-[#6]);!$(FC[F,Cl,Br,I])])]')

_DONOR_SMARTS = ('[$([N&!H0&v3,N&!H0&+1&v4,n&H1&+0,$([$([Nv3](-C)(-C)-C)]),'
                 '$([$(n[n;H1]),'
                 '$(nc[n;H1])])]),'
                 '$([NX3,NX2]([!O,!S])!@C(!@[NX3,NX2]([!O,!S]))!@[NX3,NX2]([!O,!S])),'
                 '$([O,S;H1;+0])]')

_BASIC_SMARTS = ('[$([N;H2&+0][$([C,a]);!$([C,a](=O))]),'
                 '$([N;H1&+0]([$([C,a]);!$([C,a](=O))])[$([C,a]);!$([C,a](=O))]),'
                 '$([N;H0&+0]([C;!$(C(=O))])([C;!$(C(=O))])[C;!$(C(=O))]),'
                 '$([N,n;X2;+0])]')

# Carboxyl group: sp2 carbon with a C=O and either a charged O- or an OH.
_ACIDIC_SMARTS = '[CX3](=O)[OX1H0-,OX2H1]'

acceptor_patterns = Chem.MolFromSmarts(_ACCEPTOR_SMARTS)
donor_patterns = Chem.MolFromSmarts(_DONOR_SMARTS)
basic_patterns = Chem.MolFromSmarts(_BASIC_SMARTS)
acidic_patterns = Chem.MolFromSmarts(_ACIDIC_SMARTS)

# Atomic numbers flagged as metals: Li, Be, Na, Mg, Al, then the contiguous
# runs K..Ga, Rb..Sn, Cs..Bi and Fr..Lr.
metal_anums = [3, 4, 11, 12, 13,
               *range(19, 32),   # K .. Ga
               *range(37, 51),   # Rb .. Sn
               *range(55, 84),   # Cs .. Bi
               *range(87, 104)]  # Fr .. Lr

# Atomic numbers of F, Cl, Br, I.
halogens_anums = [9, 17, 35, 53]

In [133]:
def one_hot_encoding_unk(x, allowable_set):
    """One-hot encode ``x`` against the ordered ``allowable_set``.

    A value not present in ``allowable_set`` is encoded as its final entry,
    so the last slot doubles as the "unknown" bucket.

    Args:
        x: value to encode.
        allowable_set: indexable sequence of candidate values.

    Returns:
        list of booleans, one per element of ``allowable_set``.
    """
    # Fallback is only looked up when needed, mirroring the original lazy access.
    value = x if x in allowable_set else allowable_set[-1]
    return [value == candidate for candidate in allowable_set]

In [134]:
def smiles_to_graph(smiles):
    """Convert a SMILES string into a PyTorch Geometric ``Data`` graph.

    Hydrogens written explicitly in the SMILES (e.g. ``[H]``) are kept as
    nodes; no implicit hydrogens are added.

    Node features (23 per atom, in this order):
      * 9 one-hot element symbol over C, N, O, F, S, Cl, Br, P, I -- any
        other element (including H) lands in the last slot ("I");
      * atomic mass, degree, ``GetTotalNumHs()``;
      * 2 one-hot hybridization over SP2, SP3 -- any other hybridization
        lands in the SP3 slot;
      * in-ring flag, aromatic flag;
      * carbon whose neighbours are all C/H/dummy (atomic number 6/1/0),
        metal, halogen, non-carbon donor match, non-carbon acceptor match;
      * acidic (non-carbon acidic match or negative formal charge) and
        basic (non-carbon basic match or positive formal charge) flags.

    Edges: both directions of every bond, followed by one self-loop per atom.

    Args:
        smiles: SMILES string of the molecule.

    Returns:
        ``Data`` with ``x`` a float32 tensor of shape (num_atoms, 23) and
        ``edge_index`` an int64 tensor of shape (num_edges, 2).
        NOTE(review): PyG documents ``edge_index`` as [2, num_edges]; the
        original row-per-edge layout is kept for compatibility -- confirm
        callers transpose it (e.g. ``edge_index.t().contiguous()``).

    Raises:
        ValueError: if RDKit cannot parse ``smiles``.
    """
    # MolFromSmiles has no ``removeHs`` keyword (passing one raises a
    # Boost.Python ArgumentError); explicit Hs are kept via SmilesParserParams.
    params = Chem.SmilesParserParams()
    params.removeHs = False
    mol = Chem.MolFromSmiles(smiles, params)
    if mol is None:
        raise ValueError(f"RDKit could not parse SMILES: {smiles!r}")

    def matched_atoms(pattern):
        # Every atom index taking part in any match of ``pattern``; for the
        # multi-atom carboxyl pattern this includes all of its atoms.
        matches = mol.GetSubstructMatches(pattern, maxMatches=100_000)
        return {atom_idx for match in matches for atom_idx in match}

    donor_idxs = matched_atoms(donor_patterns)
    acceptor_idxs = matched_atoms(acceptor_patterns)
    basic_idxs = matched_atoms(basic_patterns)
    acidic_idxs = matched_atoms(acidic_patterns)

    # 'Cl' (lower-case L): the previous 'CI' entry never matched any RDKit
    # symbol, so chlorine was silently encoded in the iodine slot.
    symbols = ['C', 'N', 'O', 'F', 'S', 'Cl', 'Br', 'P', 'I']
    hybridizations = [Chem.rdchem.HybridizationType.SP2,
                      Chem.rdchem.HybridizationType.SP3]

    edge_index = []
    for bond in mol.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edge_index.append([begin, end])
        edge_index.append([end, begin])

    features = []
    for atom in mol.GetAtoms():
        idx = atom.GetIdx()
        anum = atom.GetAtomicNum()
        charge = atom.GetFormalCharge()
        not_carbon = anum != 6
        edge_index.append([idx, idx])  # self-loop
        features.append(
            one_hot_encoding_unk(atom.GetSymbol(), symbols)
            # NOTE(review): GetTotalNumHs() defaults to includeNeighbors=False,
            # so explicit H atoms kept as nodes are not counted -- confirm intended.
            + [atom.GetMass(), atom.GetDegree(), atom.GetTotalNumHs()]
            + one_hot_encoding_unk(atom.GetHybridization(), hybridizations)
            + [atom.IsInRing(), atom.GetIsAromatic()]
            + [
                # carbon bonded only to C, H or dummy (atomic number 0) atoms
                anum == 6 and all(nbr.GetAtomicNum() in (6, 1, 0)
                                  for nbr in atom.GetNeighbors()),
                anum in metal_anums,
                anum in halogens_anums,
                idx in donor_idxs and not_carbon,
                idx in acceptor_idxs and not_carbon,
                int((idx in acidic_idxs and not_carbon) or charge < 0),
                int((idx in basic_idxs and not_carbon) or charge > 0),
            ]
        )

    x = torch.from_numpy(np.array(features, dtype=np.float32))
    # Build indices directly as int64 instead of round-tripping through float32.
    edge_index = torch.tensor(edge_index, dtype=torch.long).view(-1, 2)
    return Data(x=x, edge_index=edge_index)

In [154]:
def sdf_to_graph(path, name):
    """Read the first molecule of the SDF file ``{path}/{name}`` into a PyG graph.

    Hydrogens present in the file are kept as nodes (``removeHs=False``).

    Node features (23 per atom, in this order):
      * 9 one-hot element symbol over C, N, O, F, S, Cl, Br, P, I -- any
        other element lands in the last slot ("I").
        NOTE(review): this includes the explicit H atoms kept here, which are
        therefore encoded in the iodine slot -- confirm intended;
      * atomic mass, degree, ``GetTotalNumHs()``;
      * 2 one-hot hybridization over SP2, SP3 -- any other hybridization
        lands in the SP3 slot;
      * in-ring flag, aromatic flag;
      * carbon whose neighbours are all C/H/dummy (atomic number 6/1/0),
        metal, halogen, non-carbon donor match, non-carbon acceptor match;
      * acidic (non-carbon acidic match or negative formal charge) and
        basic (non-carbon basic match or positive formal charge) flags.

    Edges: both directions of every bond, followed by one self-loop per atom.

    Args:
        path: directory containing the file.
        name: SDF file name inside ``path``.

    Returns:
        ``Data`` with ``x`` a float32 tensor of shape (num_atoms, 23) and
        ``edge_index`` an int64 tensor of shape (num_edges, 2).
        NOTE(review): PyG documents ``edge_index`` as [2, num_edges]; the
        original row-per-edge layout is kept for compatibility -- confirm
        callers transpose it (e.g. ``edge_index.t().contiguous()``).

    Raises:
        ValueError: if RDKit cannot parse the first record of the file.
    """
    mol = Chem.SDMolSupplier(f'{path}/{name}', removeHs=False)[0]
    # SDMolSupplier yields None for records it fails to parse; fail clearly
    # here instead of with an AttributeError further down.
    if mol is None:
        raise ValueError(f"RDKit could not parse the first molecule in {path}/{name}")

    def matched_atoms(pattern):
        # Every atom index taking part in any match of ``pattern``; for the
        # multi-atom carboxyl pattern this includes all of its atoms.
        matches = mol.GetSubstructMatches(pattern, maxMatches=100_000)
        return {atom_idx for match in matches for atom_idx in match}

    donor_idxs = matched_atoms(donor_patterns)
    acceptor_idxs = matched_atoms(acceptor_patterns)
    basic_idxs = matched_atoms(basic_patterns)
    acidic_idxs = matched_atoms(acidic_patterns)

    # 'Cl' (lower-case L): the previous 'CI' entry never matched any RDKit
    # symbol, so chlorine was silently encoded in the iodine slot.
    symbols = ['C', 'N', 'O', 'F', 'S', 'Cl', 'Br', 'P', 'I']
    hybridizations = [Chem.rdchem.HybridizationType.SP2,
                      Chem.rdchem.HybridizationType.SP3]

    edge_index = []
    for bond in mol.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edge_index.append([begin, end])
        edge_index.append([end, begin])

    features = []
    for atom in mol.GetAtoms():
        idx = atom.GetIdx()
        anum = atom.GetAtomicNum()
        charge = atom.GetFormalCharge()
        not_carbon = anum != 6
        edge_index.append([idx, idx])  # self-loop
        features.append(
            one_hot_encoding_unk(atom.GetSymbol(), symbols)
            # NOTE(review): GetTotalNumHs() defaults to includeNeighbors=False,
            # so explicit H atoms kept as nodes are not counted -- confirm intended.
            + [atom.GetMass(), atom.GetDegree(), atom.GetTotalNumHs()]
            + one_hot_encoding_unk(atom.GetHybridization(), hybridizations)
            + [atom.IsInRing(), atom.GetIsAromatic()]
            + [
                # carbon bonded only to C, H or dummy (atomic number 0) atoms
                anum == 6 and all(nbr.GetAtomicNum() in (6, 1, 0)
                                  for nbr in atom.GetNeighbors()),
                anum in metal_anums,
                anum in halogens_anums,
                idx in donor_idxs and not_carbon,
                idx in acceptor_idxs and not_carbon,
                int((idx in acidic_idxs and not_carbon) or charge < 0),
                int((idx in basic_idxs and not_carbon) or charge > 0),
            ]
        )

    x = torch.from_numpy(np.array(features, dtype=np.float32))
    # Build indices directly as int64 instead of round-tripping through float32.
    edge_index = torch.tensor(edge_index, dtype=torch.long).view(-1, 2)
    return Data(x=x, edge_index=edge_index)