/
mol_graph.py
94 lines (76 loc) · 3.41 KB
/
mol_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import numpy as np
from rdkit.Chem import MolFromSmiles
from features import atom_features, bond_features
# Atom degrees (number of same-type neighbours) that MolGraph.sort_nodes_by_degree
# buckets nodes into: 0 through 5 inclusive.
degrees = list(range(6))
class MolGraph(object):
    """Container of typed graph nodes (e.g. 'atom', 'bond', 'molecule').

    ``self.nodes`` maps a node type to the list of ``Node`` objects of that
    type.  After ``sort_nodes_by_degree(ntype)`` it additionally holds keys
    ``(ntype, degree)`` whose values are the per-degree slices of that list.
    Edges are not stored here; they live on the ``Node`` objects themselves.
    """

    def __init__(self):
        self.nodes = {}  # dict of lists of nodes, keyed by node type

    def new_node(self, ntype, features=None, rdkit_ix=None):
        """Create a ``Node`` of type *ntype*, register it in this graph and return it."""
        new_node = Node(ntype, features, rdkit_ix)
        self.nodes.setdefault(ntype, []).append(new_node)
        return new_node

    def add_subgraph(self, subgraph):
        """Append every node of *subgraph* onto this graph's per-type lists.

        The node objects are shared (not copied), so edges between them are
        preserved.  This graph's lists are extended in place; the subgraph's
        lists are never aliased into this graph.
        """
        for ntype, nodes in subgraph.nodes.items():
            self.nodes.setdefault(ntype, []).extend(nodes)

    def sort_nodes_by_degree(self, ntype):
        """Reorder the nodes of *ntype* by their number of *ntype* neighbours.

        Nodes are grouped into the buckets listed in ``degrees`` (relative
        order inside a bucket is preserved).  ``self.nodes[ntype]`` becomes the
        concatenation of the buckets, and each bucket is also stored under
        ``self.nodes[(ntype, degree)]``.  A graph with no nodes of *ntype*
        (e.g. an empty batch) ends up with empty lists instead of failing.

        Raises:
            KeyError: if a node's degree is not one of ``degrees``.
        """
        nodes_by_degree = {i: [] for i in degrees}
        for node in self.nodes.get(ntype, []):
            degree = len(node.get_neighbors(ntype))
            if degree not in nodes_by_degree:
                # Kept as KeyError (what the bare dict lookup used to raise)
                # so existing callers are unaffected, but with a useful message.
                raise KeyError(
                    "{!r} node has degree {}, but only degrees {} are supported".format(
                        ntype, degree, list(degrees)))
            nodes_by_degree[degree].append(node)
        new_nodes = []
        for degree in degrees:
            cur_nodes = nodes_by_degree[degree]
            self.nodes[(ntype, degree)] = cur_nodes
            new_nodes.extend(cur_nodes)
        self.nodes[ntype] = new_nodes

    def feature_array(self, ntype):
        """Return ``np.array`` of the ``features`` of every *ntype* node, in list order.

        When every node's features are equal-length sequences this is a 2-D
        array with one row per node.
        """
        assert ntype in self.nodes
        return np.array([node.features for node in self.nodes[ntype]])

    def rdkit_ix_array(self, ntype='atom'):
        """Return ``np.array`` of the ``rdkit_ix`` of every *ntype* node, in list order.

        Defaults to atom nodes, the only nodes ``graph_from_smiles`` gives an
        ``rdkit_ix``; other node types hold ``None`` there.
        """
        return np.array([node.rdkit_ix for node in self.nodes[ntype]])

    def neighbor_list(self, self_ntype, neighbor_ntype):
        """For each *self_ntype* node, list the positions of its *neighbor_ntype* neighbours.

        Positions index into ``self.nodes[neighbor_ntype]`` as currently
        ordered, so call this after any reordering (``sort_nodes_by_degree``).
        """
        assert self_ntype in self.nodes and neighbor_ntype in self.nodes
        neighbor_idxs = {n: i for i, n in enumerate(self.nodes[neighbor_ntype])}
        return [[neighbor_idxs[neighbor]
                 for neighbor in self_node.get_neighbors(neighbor_ntype)]
                for self_node in self.nodes[self_ntype]]
class Node(object):
    """A typed graph vertex carrying a payload and an undirected adjacency list.

    Attributes:
        ntype: node type label used by ``get_neighbors`` for filtering.
        features: arbitrary feature payload (may be ``None``).
        rdkit_ix: optional external index (``graph_from_smiles`` stores the
            RDKit atom index here for atom nodes), otherwise ``None``.
    """
    __slots__ = ['ntype', 'features', '_neighbors', 'rdkit_ix']

    def __init__(self, ntype, features, rdkit_ix):
        self.ntype, self.features = ntype, features
        self._neighbors = []
        self.rdkit_ix = rdkit_ix

    def add_neighbors(self, neighbor_list):
        """Link this node with each node of *neighbor_list*, in both directions.

        No de-duplication happens: linking the same pair twice records the
        edge twice on both ends.
        """
        mine = self._neighbors
        for other in neighbor_list:
            mine.append(other)
            other._neighbors.append(self)

    def get_neighbors(self, ntype):
        """Return neighbours whose ``ntype`` equals *ntype*, in the order they were linked."""
        matches = []
        for candidate in self._neighbors:
            if candidate.ntype == ntype:
                matches.append(candidate)
        return matches
def graph_from_smiles_tuple(smiles_tuple):
    """Parse each SMILES string in *smiles_tuple* and merge the results into one MolGraph.

    Every molecule contributes its own 'molecule' node linked to its atoms,
    so molecule membership survives the merge.  Atom nodes are finally
    reordered by degree via ``MolGraph.sort_nodes_by_degree('atom')``.

    Raises:
        ValueError: propagated from ``graph_from_smiles`` for an unparsable string.
    """
    merged = MolGraph()
    for subgraph in map(graph_from_smiles, smiles_tuple):
        merged.add_subgraph(subgraph)
    # Degree-sorting enables an efficient (but brittle!) positional indexing later on.
    merged.sort_nodes_by_degree('atom')
    return merged
def graph_from_smiles(smiles):
    """Build a MolGraph for the single molecule described by *smiles*.

    The resulting graph contains:
      * one 'atom' node per RDKit atom, with ``atom_features(atom)`` as
        features and the RDKit atom index as ``rdkit_ix``;
      * one 'bond' node per RDKit bond, with ``bond_features(bond)`` as
        features, linked to both of its atoms;
      * a direct atom-atom link for every bond;
      * one 'molecule' node linked to every atom.

    Raises:
        ValueError: if RDKit cannot parse *smiles*.
    """
    graph = MolGraph()
    mol = MolFromSmiles(smiles)
    if not mol:
        # Format the message explicitly: passing smiles as a second argument
        # made str(exc) render as a tuple.
        raise ValueError("Could not parse SMILES string: {}".format(smiles))
    atoms_by_rd_idx = {}
    for atom in mol.GetAtoms():
        new_atom_node = graph.new_node('atom', features=atom_features(atom), rdkit_ix=atom.GetIdx())
        atoms_by_rd_idx[atom.GetIdx()] = new_atom_node
    for bond in mol.GetBonds():
        atom1_node = atoms_by_rd_idx[bond.GetBeginAtom().GetIdx()]
        atom2_node = atoms_by_rd_idx[bond.GetEndAtom().GetIdx()]
        new_bond_node = graph.new_node('bond', features=bond_features(bond))
        new_bond_node.add_neighbors((atom1_node, atom2_node))
        atom1_node.add_neighbors((atom2_node,))
    mol_node = graph.new_node('molecule')
    # A molecule with no atoms (e.g. from an empty SMILES string) never
    # created the 'atom' key, so fall back to an empty list.
    mol_node.add_neighbors(graph.nodes.get('atom', []))
    return graph