In [6]:
import pickle
import torch
import numpy as np
import json
from tqdm import tqdm

import dgl
from rdkit import Chem

from tqdm import tqdm

In [2]:
# one_hot
def one_hot(char, dict):
    """One-hot encode ``char`` against a vocabulary, with a trailing "unknown" slot.

    Args:
        char: Value to encode. It is compared with ``==`` against each
            vocabulary entry.
        dict: Indexable vocabulary (``dict[0] .. dict[len(dict) - 1]``).
            The parameter name is kept for backward compatibility; it
            shadows the builtin only inside this function.

    Returns:
        A 1-D float tensor of length ``len(dict) + 1``. The position of the
        first matching entry is set to 1; if nothing matches (including an
        empty vocabulary) the last position is set to 1 instead.
    """
    h = torch.zeros(len(dict) + 1)
    for i in range(len(dict)):
        if dict[i] == char:
            h[i] = 1
            break
    else:
        # The loop ended without a match. This branch also runs when the
        # vocabulary is empty, so the "unknown" slot is always set.
        h[-1] = 1

    return h


def smi_2_graph(smi):
    """Convert a SMILES string into a bidirected DGL graph with atom features.

    Nodes are the molecule's atoms (indexed by RDKit atom index) and edges are
    its bonds, made bidirectional with ``dgl.to_bidirected``. Each node gets a
    ``'feat'`` vector that concatenates:

    * element symbol one-hot (43 known symbols + 1 "other" slot),
    * formal charge one-hot over 1..8 (+ 1 "other" slot),
    * degree one-hot over 1..8 (+ 1 "other" slot),
    * aromaticity flag,
    * ring-membership flag,

    for a total of 64 values per atom.

    Args:
        smi: SMILES string of the molecule.

    Returns:
        A DGL graph with one node per atom and ``ndata['feat']`` of shape
        ``(num_atoms, 64)``.

    Raises:
        ValueError: If ``smi`` cannot be parsed by RDKit or contains no atoms.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        # RDKit signals a parse failure by returning None rather than raising.
        raise ValueError(f"could not parse SMILES: {smi!r}")

    nums_atom = mol.GetNumAtoms()
    if nums_atom == 0:
        raise ValueError(f"SMILES contains no atoms: {smi!r}")

    # Edges: one (begin, end) pair per bond.
    u, v = [], []
    for bond in mol.GetBonds():
        u.append(bond.GetBeginAtom().GetIdx())
        v.append(bond.GetEndAtom().GetIdx())

    # An explicit integer dtype keeps the empty (bond-less) case valid, and an
    # explicit num_nodes keeps atoms that have no bonds, or that come after the
    # last bonded atom, as graph nodes. Otherwise the graph would have fewer
    # nodes than the molecule has atoms.
    u = torch.tensor(u, dtype=torch.int64)
    v = torch.tensor(v, dtype=torch.int64)
    g = dgl.graph((u, v), num_nodes=nums_atom)

    # Vocabularies are loop-invariant, so build them once.
    symbols = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
               'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn',
               'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au',
               'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb']
    small_ints = [1, 2, 3, 4, 5, 6, 7, 8]

    # Node features, one row per atom in RDKit atom order.
    feats = []
    for atom in mol.GetAtoms():
        feat = torch.cat([
            one_hot(atom.GetSymbol(), symbols),
            # Values outside 1..8 (including 0 and negatives) fall into the
            # trailing "other" slot of one_hot.
            one_hot(atom.GetFormalCharge(), small_ints),
            one_hot(atom.GetDegree(), small_ints),
            torch.tensor([float(atom.GetIsAromatic())]),
            torch.tensor([float(atom.IsInRing())]),
        ])
        feats.append(feat)

    feats = torch.stack(feats)  # (num_atoms, 64)

    bg = dgl.to_bidirected(g)  # undirected graph: every bond in both directions
    bg.ndata['feat'] = feats

    return bg

In [3]:
def get_G(ligands):
    """Build a molecular graph for every entry of ``ligands``.

    Args:
        ligands: Mapping whose values are passed to ``smi_2_graph`` as SMILES
            strings.

    Returns:
        Dict mapping ``str(key)`` to the graph returned by ``smi_2_graph``.
        Entries whose conversion raises are printed and left out, so the
        result may have fewer entries than ``ligands``.
    """
    XD = {}
    for d in tqdm(ligands.keys(), ncols=0):
        try:
            XD[str(d)] = smi_2_graph(ligands[d])
        except Exception:
            # Best effort: report the offending SMILES and keep going.
            # Catching Exception rather than using a bare except lets
            # KeyboardInterrupt and SystemExit still stop this long loop.
            print(ligands[d])

    return XD

# BIOSNAP

In [5]:
# Load the BIOSNAP ligand mapping; get_G later passes its values to
# smi_2_graph as SMILES strings.
# NOTE: pickle.load can execute arbitrary code from the file, so only load
# pickles from a trusted source.
with open("BIOSNAP/drug_smi_raw.pkl", "rb") as f:
    ligands = pickle.load(f)
print(len(ligands))  # number of entries loaded

4510


In [None]:
# Convert every BIOSNAP ligand to a graph. Entries that fail inside get_G are
# printed and left out of XD.
XD = get_G(ligands)

In [9]:
# Sanity check: show the first graph's node/edge counts and feature schema.
list(XD.values())[0]

Graph(num_nodes=18, num_edges=40,
      ndata_schemes={'feat': Scheme(shape=(64,), dtype=torch.float32)}
      edata_schemes={})

In [10]:
# Save the BIOSNAP graphs (the dict returned by get_G) as a pickle.
with open("BIOSNAP/drug_graph.pkl", 'wb') as f:
    pickle.dump(XD, f)

In [4]:
# Load the BDB ligand mapping. This rebinds `ligands`, so the BIOSNAP mapping
# loaded earlier is no longer reachable under that name.
# NOTE: pickle.load can execute arbitrary code from the file, so only load
# pickles from a trusted source.
with open("BDB/drug_smi_raw.pkl", "rb") as f:
    ligands = pickle.load(f)
print(len(ligands))  # number of entries loaded

10636


In [5]:
# Convert every BDB ligand to a graph. This rebinds `XD`, replacing the
# BIOSNAP graphs held under that name.
XD = get_G(ligands)

In [7]:
# Sanity check: show the first BDB graph's node/edge counts and feature schema.
list(XD.values())[0]

Graph(num_nodes=21, num_edges=44,
      ndata_schemes={'feat': Scheme(shape=(64,), dtype=torch.float32)}
      edata_schemes={})

In [8]:
# Compare with len(ligands) printed above: a smaller value means get_G skipped
# some entries.
len(XD)

10636

In [9]:
# Save the BDB graphs (the dict returned by get_G) as a pickle.
with open("BDB/drug_graph.pkl", 'wb') as f:
    pickle.dump(XD, f)