In [None]:
import rdkit
from rdkit import Chem

from utils import *

from graphenvironments import mol2vecstupidsimple

In [None]:
# Import AllChem explicitly: `Chem.AllChem` only resolves if some other module
# has already imported rdkit.Chem.AllChem, which a fresh kernel may not have done.
from rdkit.Chem import AllChem

# The bracketed `[O]` carries no hydrogens, so as written this is an oxygen
# radical — NOTE(review): confirm that is intended.
two = 'COc1cc(C2Oc3c(OC)cc(C=CCO)cc3C2CO)ccc1[O]'
m = Chem.MolFromSmiles(two)
assert m is not None, "RDKit could not parse the SMILES string"
m = Chem.AddHs(m)  # explicit H atoms so they receive 3D coordinates

# Fixed seed so the 200 embedded conformers are reproducible across runs.
conf_ids = AllChem.EmbedMultipleConfs(m, numConfs=200, numThreads=-1, randomSeed=42)
assert len(conf_ids) > 0, "conformer embedding failed"
# MMFF-optimize the conformers of `m` in place (previously passed an undefined `x`).
mmff_results = AllChem.MMFFOptimizeMoleculeConfs(m, numThreads=-1)


In [None]:
import py3Dmol

# 800x800 px interactive 3D viewer. `drawit` is not defined in this notebook;
# presumably it comes from `from utils import *` — it renders `m` into the view.
p = py3Dmol.view(height=800, width=800)
drawit(m, p)

In [None]:
from torch_geometric.data import Data, Batch
from torch_geometric.transforms import Distance, NormalizeScale, RadiusGraph

import glob
import json

def bond_features(bond, use_chirality=False, use_basic_feats=True, connectivity=False):
    """Encode one RDKit bond as a 1-D numpy feature vector.

    Feature groups are concatenated in this order, each only when its flag is set:

    * ``use_basic_feats``: bond-order one-hot (single, double, triple,
      aromatic), then the conjugated flag and the in-ring flag (6 entries).
    * ``use_chirality``: the list returned by ``one_of_k_encoding_unk`` for the
      bond's stereo label against STEREONONE / STEREOANY / STEREOZ / STEREOE
      (helper defined outside this notebook, presumably in ``utils``).
    * ``connectivity``: a constant ``1``.

    Returns:
        ``np.ndarray``. Its dtype is numpy's inference over the collected
        bool/int entries; with every flag off it is an empty float64 array
        of shape ``(0,)``.
    """
    bond_type = bond.GetBondType()
    feats = []
    if use_basic_feats:
        kinds = Chem.rdchem.BondType
        order_onehot = [bond_type == k for k in (kinds.SINGLE, kinds.DOUBLE, kinds.TRIPLE, kinds.AROMATIC)]
        feats = feats + order_onehot + [bond.GetIsConjugated(), bond.IsInRing()]
    if use_chirality:
        stereo_labels = ["STEREONONE", "STEREOANY", "STEREOZ", "STEREOE"]
        feats = feats + one_of_k_encoding_unk(str(bond.GetStereo()), stereo_labels)
    if connectivity:
        feats = feats + [1]
    return np.array(feats)

def get_bond_pair(mol):
    """Return the bonds of ``mol`` as a directed COO edge list ``[sources, targets]``.

    Each bond contributes two consecutive edges, begin->end then end->begin,
    so edge columns 2k and 2k+1 both belong to the k-th bond of ``mol.GetBonds()``.
    """
    sources, targets = [], []
    for bond in mol.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        sources.extend((begin, end))
        targets.extend((end, begin))
    return [sources, targets]

def atom_features_simple(atom, conf):
    """Return the atom's coordinates in conformer ``conf`` as a numpy array ``[x, y, z]``."""
    point = conf.GetAtomPosition(atom.GetIdx())
    return np.array([getattr(point, axis) for axis in ("x", "y", "z")])

def mol2vecsimple(mol):
    """Featurize a conformer-bearing RDKit molecule as a torch_geometric graph.

    Node features ``x`` are each atom's raw 3D coordinates from the default
    conformer. Edges are the bonds in both directions; each edge carries
    ``bond_features(bond)`` plus the edge length appended by ``Distance()``
    (torch_geometric default ``norm=True``).

    Args:
        mol: RDKit Mol with at least one conformer.

    Returns:
        torch_geometric ``Data`` with ``x``, ``edge_index``, ``edge_attr`` and ``pos``.
    """
    conf = mol.GetConformer(id=-1)
    node_f = [atom_features_simple(atom, conf) for atom in mol.GetAtoms()]
    edge_index = get_bond_pair(mol)
    # get_bond_pair lists each bond twice in a row (begin->end, end->begin), so
    # each bond's feature row is repeated in place to stay aligned with the
    # edge_index columns. Listing all bonds and then all bonds again would pair
    # column 2k+1 with the wrong bond.
    edge_attr = []
    for bond in mol.GetBonds():
        feats = bond_features(bond, use_chirality=False)
        edge_attr += [feats, feats]
    data = Data(
                x=torch.tensor(node_f, dtype=torch.float),
                edge_index=torch.tensor(edge_index, dtype=torch.long),
                edge_attr=torch.tensor(edge_attr, dtype=torch.float),
                pos=torch.Tensor(conf.GetPositions())
            )
    data = Distance()(data)
    return data

def mol2vecstupidsimple(mol):
    """Featurize a conformer-bearing molecule as a bond graph with coordinate node features.

    Any hydrogens present in ``mol`` are kept. Positions from the default
    conformer go through ``NormalizeScale`` (centred and scaled into (-1, 1)
    per torch_geometric), and ``x`` is set to those normalized positions.
    Edges are the bonds in both directions. The only edge feature is the edge
    length appended by ``Distance(norm=False)``, min-max rescaled to [-1, 1]
    over all edges of this molecule.

    Note: this definition shadows the ``mol2vecstupidsimple`` imported from
    ``graphenvironments`` in the first cell.

    Args:
        mol: RDKit Mol with at least one conformer.

    Returns:
        torch_geometric ``Data`` with ``x``, ``edge_index``, ``edge_attr`` and ``pos``.
        If every edge has the same length, ``edge_attr`` is all zeros instead
        of NaN from 0/0. A bond-less molecule keeps its empty ``edge_attr``
        instead of erroring on ``min()`` of an empty tensor.
    """
    conf = mol.GetConformer(id=-1)
    node_f = [[] for atom in mol.GetAtoms()]
    edge_index = get_bond_pair(mol)
    # get_bond_pair lists each bond twice in a row (begin->end, end->begin), so
    # each bond's feature row is repeated in place to stay column-aligned.
    edge_attr = []
    for bond in mol.GetBonds():
        feats = bond_features(bond, use_chirality=False, use_basic_feats=False)
        edge_attr += [feats, feats]

    data = Data(
                x=torch.tensor(node_f, dtype=torch.float),
                edge_index=torch.tensor(edge_index, dtype=torch.long),
                edge_attr=torch.tensor(edge_attr, dtype=torch.float),
                pos=torch.Tensor(conf.GetPositions())
            )

    data = NormalizeScale()(data)
    data = Distance(norm=False)(data)
    data.x = data.pos

    # Min-max rescale edge features to [-1, 1].
    e = data.edge_attr
    if e.numel() > 0:
        lo, hi = e.min(), e.max()
        if hi == lo:
            # All edges equally long: 0/0 would give NaN, so use the midpoint.
            data.edge_attr = torch.zeros_like(e)
        else:
            data.edge_attr = -1 + ((e - lo) * 2) / (hi - lo)

    return data

def mol2vecbasic(mol):
    """Featurize the heavy-atom graph of a conformer-bearing molecule.

    Hydrogens are stripped with ``RemoveHs``. Positions from the default
    conformer go through ``NormalizeScale`` (centred and scaled into (-1, 1)
    per torch_geometric), and ``x`` is set to those normalized positions.
    Edges are the bonds in both directions. The only edge feature is the edge
    length appended by ``Distance(norm=False)``, min-max rescaled to [-1, 1]
    over all edges of this molecule.

    Args:
        mol: RDKit Mol with at least one conformer.

    Returns:
        torch_geometric ``Data`` with ``x``, ``edge_index``, ``edge_attr`` and ``pos``.
        If every edge has the same length, ``edge_attr`` is all zeros instead
        of NaN from 0/0. A bond-less molecule keeps its empty ``edge_attr``
        instead of erroring on ``min()`` of an empty tensor.
    """
    mol = Chem.rdmolops.RemoveHs(mol)
    conf = mol.GetConformer(id=-1)
    node_f = [[] for atom in mol.GetAtoms()]
    edge_index = get_bond_pair(mol)
    # get_bond_pair lists each bond twice in a row (begin->end, end->begin), so
    # each bond's feature row is repeated in place to stay column-aligned.
    edge_attr = []
    for bond in mol.GetBonds():
        feats = bond_features(bond, use_chirality=False, use_basic_feats=False)
        edge_attr += [feats, feats]

    data = Data(
                x=torch.tensor(node_f, dtype=torch.float),
                edge_index=torch.tensor(edge_index, dtype=torch.long),
                edge_attr=torch.tensor(edge_attr, dtype=torch.float),
                pos=torch.Tensor(conf.GetPositions())
            )

    data = NormalizeScale()(data)
    data = Distance(norm=False)(data)
    data.x = data.pos

    # Min-max rescale edge features to [-1, 1].
    e = data.edge_attr
    if e.numel() > 0:
        lo, hi = e.min(), e.max()
        if hi == lo:
            # All edges equally long: 0/0 would give NaN, so use the midpoint.
            data.edge_attr = torch.zeros_like(e)
        else:
            data.edge_attr = -1 + ((e - lo) * 2) / (hi - lo)

    return data

def mol2vecskeleton(mol):
    """Featurize the heavy-atom graph with bond-type edge features and raw edge lengths.

    Hydrogens are stripped with ``RemoveHs``. Positions from the default
    conformer go through ``NormalizeScale`` (centred and scaled into (-1, 1)
    per torch_geometric), and ``x`` is set to those normalized positions.
    Edges are the bonds in both directions. Each edge carries the basic
    ``bond_features`` block, followed by the (unrescaled) edge length appended
    by ``Distance(norm=False)``.

    Args:
        mol: RDKit Mol with at least one conformer.

    Returns:
        torch_geometric ``Data`` with ``x``, ``edge_index``, ``edge_attr`` and ``pos``.
    """
    mol = Chem.rdmolops.RemoveHs(mol)
    conf = mol.GetConformer(id=-1)
    node_f = [[] for atom in mol.GetAtoms()]
    edge_index = get_bond_pair(mol)
    # get_bond_pair lists each bond twice in a row (begin->end, end->begin), so
    # each bond's feature row is repeated in place. Listing all bonds and then
    # all bonds again would attach bond-type features to the wrong edges.
    edge_attr = []
    for bond in mol.GetBonds():
        feats = bond_features(bond, use_chirality=False, use_basic_feats=True)
        edge_attr += [feats, feats]

    data = Data(
                x=torch.tensor(node_f, dtype=torch.float),
                edge_index=torch.tensor(edge_index, dtype=torch.long),
                edge_attr=torch.tensor(edge_attr, dtype=torch.float),
                pos=torch.Tensor(conf.GetPositions())
            )

    data = NormalizeScale()(data)
    data = Distance(norm=False)(data)
    data.x = data.pos

    return data


import time
def mol2vecdense(mol):
    """Featurize a molecule as a fully connected directed graph without self-loops.

    Every ordered atom pair (i, j) with i != j is an edge, listed in row-major
    order (all j for i=0, then i=1, ...). The first edge feature is the
    adjacency-matrix entry (1.0 if i and j are bonded, else 0.0).
    ``Distance(norm=False)`` then appends the edge length computed from the
    positions after ``NormalizeScale``. Node features ``x`` are empty
    (shape ``[n, 0]``).

    Args:
        mol: RDKit Mol with at least one conformer.

    Returns:
        torch_geometric ``Data`` with ``x``, ``edge_index`` (``[2, n*(n-1)]``),
        ``edge_attr`` and ``pos``.
    """
    conf = mol.GetConformer(id=-1)
    adj = torch.as_tensor(Chem.rdmolops.GetAdjacencyMatrix(mol), dtype=torch.float)
    n = adj.shape[0]

    # Off-diagonal mask. nonzero() and boolean indexing both walk it in row-major
    # order, matching the nested i/j loop this replaces, in O(n^2) tensor ops
    # rather than Python iterations. For n == 1, edge_index is a proper [2, 0].
    off_diag = ~torch.eye(n, dtype=torch.bool)
    edge_index = off_diag.nonzero().T
    edge_attr = adj[off_diag]

    data = Data(
                x=torch.zeros((n, 0), dtype=torch.float),
                edge_index=edge_index.to(torch.long),
                edge_attr=edge_attr,
                pos=torch.Tensor(conf.GetPositions())
            )

    data = NormalizeScale()(data)
    data = Distance(norm=False)(data)

    return data


In [None]:
import time

# Time one call of each featurizer on `m`. perf_counter is the monotonic,
# high-resolution clock meant for intervals (time.time can jump with the
# system clock). Single calls include any first-call overhead.
# After the loop, `data` holds the mol2vecskeleton graph, as before.
for featurize in (mol2vecdense, mol2vecstupidsimple, mol2vecbasic, mol2vecskeleton):
    start = time.perf_counter()
    data = featurize(m)
    elapsed = time.perf_counter() - start
    print(f"{featurize.__name__}: {elapsed:.4f} s")

In [21]:
# Show torch_geometric's documentation for the Data object held in `data`.
help(data)

Help on Data in module torch_geometric.data.data object:

class Data(builtins.object)
 |  A plain old python object modeling a single graph with various
 |  (optional) attributes:
 |  
 |  Args:
 |      x (Tensor, optional): Node feature matrix with shape :obj:`[num_nodes,
 |          num_node_features]`. (default: :obj:`None`)
 |      edge_index (LongTensor, optional): Graph connectivity in COO format
 |          with shape :obj:`[2, num_edges]`. (default: :obj:`None`)
 |      edge_attr (Tensor, optional): Edge feature matrix with shape
 |          :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`)
 |      y (Tensor, optional): Graph or node targets with arbitrary shape.
 |          (default: :obj:`None`)
 |      pos (Tensor, optional): Node position matrix with shape
 |          :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`)
 |      norm (Tensor, optional): Normal vector matrix with shape
 |          :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`)
 |

In [6]:
# Peek at the first five rows of the default conformer's coordinate array.
conf = m.GetConformer(-1)
conf.GetPositions()[:5]

array([[ 5.36624129,  1.53623619,  0.5008547 ],
       [ 4.89901223,  0.64599879, -0.65641336],
       [ 3.37139128,  0.53296038, -0.4621809 ],
       [ 2.76025379, -0.30969669, -1.5183659 ],
       [ 1.22672118, -0.39979683, -1.33461131]])

In [29]:
# Display the repr of `data`: each attribute with its tensor shape.
data

Data(edge_attr=[26, 1], edge_index=[2, 26], pos=[14, 3], x=[14, 3])