Skip to content

Commit

Permalink
Merge pull request #12 from NREL/new_dataset_and_preprocessor_edits
Browse files Browse the repository at this point in the history
New dataset and preprocessor edits
  • Loading branch information
pstjohn committed Sep 8, 2021
2 parents ff38498 + b512f7c commit ac60063
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 15 deletions.
48 changes: 33 additions & 15 deletions nfp/preprocessing/preprocessor.py
Expand Up @@ -37,7 +37,9 @@ class MolPreprocessor(object):

def __init__(self,
atom_features: Optional[Callable[[rdkit.Chem.Atom], Hashable]] = None,
bond_features: Optional[Callable[[rdkit.Chem.Bond], Hashable]] = None) -> None:
bond_features: Optional[Callable[[rdkit.Chem.Bond], Hashable]] = None,
output_dtype: str = 'int32',
) -> None:

self.atom_tokenizer = Tokenizer()
self.bond_tokenizer = Tokenizer()
Expand All @@ -50,6 +52,7 @@ def __init__(self,

self.atom_features = atom_features
self.bond_features = bond_features
self.output_dtype = output_dtype

# Keep track of biggest molecules seen in training
self.max_atoms = 0
Expand All @@ -75,11 +78,22 @@ def bond_classes(self):
""" The number of bond types found (includes the 0 null-bond type) """
return self.bond_tokenizer.num_classes + 1

def construct_feature_matrices(self, mol: rdkit.Chem.Mol, train: bool = False) -> {}:
def construct_feature_matrices(self,
mol: rdkit.Chem.Mol,
train: bool = False,
max_num_atoms: Optional[int] = None,
max_num_bonds: Optional[int] = None,
) -> {str: np.ndarray}:
""" Convert an rdkit Mol to a list of tensors
'atom' : (n_atom,) length list of atom classes
'bond' : (n_bond,) list of bond classes
'connectivity' : (n_bond, 2) array of source atom, target atom pairs.
Parameters
----------
mol : rdkit.Chem.Mol
train : bool
max_num_atoms : int, optional
Specify the size of the output arrays with a maximum number of atoms
max_num_bonds : int, optional
Maximum number of bonds in the output array
"""

self.atom_tokenizer.train = train
Expand All @@ -92,9 +106,12 @@ def construct_feature_matrices(self, mol: rdkit.Chem.Mol, train: bool = False) -
if n_bond == 0:
n_bond = 1

atom_feature_matrix = np.zeros(n_atom, dtype='int32')
bond_feature_matrix = np.zeros(n_bond, dtype='int32')
connectivity = np.zeros((n_bond, 2), dtype='int32')
max_num_atoms = mol.GetNumAtoms() if max_num_atoms is None else max_num_atoms
max_num_bonds = n_bond if max_num_bonds is None else max_num_bonds

atom_feature_matrix = np.zeros(max_num_atoms, dtype=self.output_dtype)
bond_feature_matrix = np.zeros(max_num_bonds, dtype=self.output_dtype)
connectivity = np.zeros((max_num_bonds, 2), dtype=self.output_dtype)

if n_bond == 1:
bond_feature_matrix[0] = self.bond_tokenizer('self-link')
Expand Down Expand Up @@ -131,14 +148,16 @@ def construct_feature_matrices(self, mol: rdkit.Chem.Mol, train: bool = False) -
'connectivity': connectivity,
}

output_signature = {'atom': tf.TensorSpec(shape=(None,), dtype=tf.int32),
'bond': tf.TensorSpec(shape=(None,), dtype=tf.int32),
'connectivity': tf.TensorSpec(shape=(None, 2), dtype=tf.int32)}
@property
def output_signature(self):
return {'atom': tf.TensorSpec(shape=(None,), dtype=self.output_dtype),
'bond': tf.TensorSpec(shape=(None,), dtype=self.output_dtype),
'connectivity': tf.TensorSpec(shape=(None, 2), dtype=self.output_dtype)}


def load_from_json(obj, data):
for key, val in obj.__dict__.items():
if type(val) == type(data[key]):
if isinstance(val, type(data[key])):
obj.__dict__[key] = data[key]
elif hasattr(val, '__dict__'):
load_from_json(val, data[key])
Expand All @@ -150,12 +169,11 @@ def __init__(self, *args, explicit_hs: bool = True, **kwargs):
super(SmilesPreprocessor, self).__init__(*args, **kwargs)
self.explicit_hs = explicit_hs

def construct_feature_matrices(self, smiles: str, train: bool = False) -> {}:
def construct_feature_matrices(self, smiles: str, train: bool = False, **kwargs) -> {}:
mol = rdkit.Chem.MolFromSmiles(smiles)
if self.explicit_hs:
mol = rdkit.Chem.AddHs(mol)
return super(SmilesPreprocessor, self).construct_feature_matrices(mol, train=train)

return super(SmilesPreprocessor, self).construct_feature_matrices(mol, train=train, **kwargs)


def get_max_atom_bond_size(smiles_iterator, explicit_hs=True):
Expand Down
2 changes: 2 additions & 0 deletions tests/layers/test_graph_layers.py
Expand Up @@ -125,6 +125,8 @@ def test_masking_message(inputs_no_padding, inputs_with_padding, smiles_inputs):


def test_no_residual(inputs_no_padding, inputs_with_padding, smiles_inputs):
""" This model might not work when saved and loaded, see
https://github.com/tensorflow/tensorflow/issues/38620 """
preprocessor, inputs = smiles_inputs

atom_class = layers.Input(shape=[None], dtype=tf.int64, name='atom')
Expand Down
2 changes: 2 additions & 0 deletions tests/models/test_save_and_load.py
Expand Up @@ -8,6 +8,8 @@


def test_save_and_load_message(inputs_no_padding, inputs_with_padding, smiles_inputs, tmpdir: 'py.path.local'):
""" mainly to do with https://github.com/tensorflow/tensorflow/issues/38620 """

preprocessor, inputs = smiles_inputs

atom_class = layers.Input(shape=[None], dtype=tf.int64, name='atom')
Expand Down

0 comments on commit ac60063

Please sign in to comment.