Merge pull request #13 from NREL/tf-2.6
new preprocessor inheritance model
pstjohn committed Oct 19, 2021
2 parents e4ed176 + 1ea10ef commit 32cbb54
Showing 22 changed files with 762 additions and 412 deletions.
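The headline change: preprocessor objects are now callable, replacing the explicit construct_feature_matrices method used throughout the example notebook. A minimal before/after sketch (the SmilesPreprocessor import is an assumption about the public API, and since this commit also empties nfp/preprocessing/__init__.py, the exact import location may differ):

from nfp.preprocessing import SmilesPreprocessor  # import path assumed

preprocessor = SmilesPreprocessor()

# Before this commit: an explicit method call
# features = preprocessor.construct_feature_matrices('CCO', train=True)

# After this commit: the preprocessor is called directly
features = preprocessor('CCO', train=True)
print(features['atom'], features['bond'], features['connectivity'])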
1 change: 1 addition & 0 deletions .gitignore
@@ -111,3 +111,4 @@ examples/*.tfrecord
examples/*.p
examples/*.h5
.idea/
.vscode/settings.json
5 changes: 5 additions & 0 deletions .vscode/extensions.json
@@ -0,0 +1,5 @@
{
"recommendations": [
"ms-python.python"
]
}
6 changes: 4 additions & 2 deletions etc/environment.yml
@@ -8,7 +8,9 @@ dependencies:
- rdkit
- pytest
- tqdm
- numpy
- numpy=1.19
- networkx
- pymatgen
- pip
- pip:
- tensorflow
- tensorflow
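A plausible reason for the numpy pin, inferred rather than stated in the diff: the TensorFlow 2.6 wheels declare numpy~=1.19.2, so a newer conda-resolved numpy can break the pip-installed tensorflow. A trivial smoke test for the resolved environment:

import numpy as np
import tensorflow as tf

# TF 2.6 declares numpy~=1.19.2; an incompatible numpy usually fails
# at import time, before reaching this line
print(np.__version__, tf.__version__)
assert np.__version__.startswith('1.19')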
14 changes: 7 additions & 7 deletions examples/creating_and_training_a_model.ipynb
@@ -232,7 +232,7 @@
"print(preprocessor.atom_tokenizer._data)\n",
"\n",
"for smiles in train:\n",
" preprocessor.construct_feature_matrices(smiles, train=True)\n",
" preprocessor(smiles, train=True)\n",
" \n",
"print()\n",
"print(\"after pre-allocating\")\n",
@@ -260,7 +260,7 @@
"smiles = 'CCO'\n",
"\n",
"# Atom types, as integer classes\n",
"preprocessor.construct_feature_matrices(smiles, train=True)['atom']"
"preprocessor(smiles, train=True)['atom']"
]
},
{
@@ -281,7 +281,7 @@
],
"source": [
"# Bond types, as integer classes\n",
"preprocessor.construct_feature_matrices(smiles, train=True)['bond']"
"preprocessor(smiles, train=True)['bond']"
]
},
{
@@ -305,7 +305,7 @@
],
"source": [
"# A connectivity array, where row i indicates bond i connects atom j to atom k\n",
"preprocessor.construct_feature_matrices(smiles, train=True)['connectivity']"
"preprocessor(smiles, train=True)['connectivity']"
]
},
{
@@ -322,7 +322,7 @@
"# hence why the atom and bond types above start with 1 as the unknown class)\n",
"\n",
"train_dataset = tf.data.Dataset.from_generator(\n",
" lambda: ((preprocessor.construct_feature_matrices(row.SMILES, train=False), row.YSI)\n",
" lambda: ((preprocessor(row.SMILES, train=False), row.YSI)\n",
" for i, row in ysi[ysi.SMILES.isin(train)].iterrows()),\n",
" output_types=(preprocessor.output_types, tf.float32),\n",
" output_shapes=(preprocessor.output_shapes, []))\\\n",
@@ -334,7 +334,7 @@
"\n",
"\n",
"valid_dataset = tf.data.Dataset.from_generator(\n",
" lambda: ((preprocessor.construct_feature_matrices(row.SMILES, train=False), row.YSI)\n",
" lambda: ((preprocessor(row.SMILES, train=False), row.YSI)\n",
" for i, row in ysi[ysi.SMILES.isin(valid)].iterrows()),\n",
" output_types=(preprocessor.output_types, tf.float32),\n",
" output_shapes=(preprocessor.output_shapes, []))\\\n",
@@ -498,7 +498,7 @@
"# Here, we create a test dataset that doesn't assume we know the values for the YSI\n",
"\n",
"test_dataset = tf.data.Dataset.from_generator(\n",
" lambda: (preprocessor.construct_feature_matrices(smiles, train=False)\n",
" lambda: (preprocessor(smiles, train=False)\n",
" for smiles in test),\n",
" output_types=preprocessor.output_types,\n",
" output_shapes=preprocessor.output_shapes)\\\n",
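The trailing backslashes in the notebook cells above continue into batching calls that this view truncates. A self-contained sketch of the updated test pipeline under the new callable API (the batch size, padded_batch, and prefetch settings are assumptions; preprocessor and model stand in for the fitted objects from the notebook):

import tensorflow as tf

test_smiles = ['CCO', 'c1ccccc1']  # hypothetical hold-out molecules

test_dataset = tf.data.Dataset.from_generator(
    lambda: (preprocessor(smiles, train=False) for smiles in test_smiles),
    output_types=preprocessor.output_types,
    output_shapes=preprocessor.output_shapes,
).padded_batch(batch_size=32).prefetch(tf.data.AUTOTUNE)

predictions = model.predict(test_dataset)

Passing train=False freezes the tokenizers, so atom or bond types unseen during training map to the reserved unknown class (hence the notebook's note that the integer classes start at 1).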
1 change: 0 additions & 1 deletion nfp/__init__.py
@@ -5,7 +5,6 @@
from .preprocessing import *

custom_objects = {
'Slice': Slice,
'Gather': Gather,
'Reduce': Reduce,
'masked_mean_squared_error': masked_mean_squared_error,
53 changes: 27 additions & 26 deletions nfp/layers/graph_layers.py
@@ -1,4 +1,3 @@
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

@@ -39,13 +38,12 @@ def build(self, input_shape):
super().build(input_shape)

self.gather = nfp.Gather()
self.slice1 = nfp.Slice(np.s_[:, :, 1])
self.slice0 = nfp.Slice(np.s_[:, :, 0])
self.concat = nfp.ConcatDense()

def call(self, inputs, mask=None):
def call(self, inputs, mask=None, **kwargs):
""" Inputs: [atom_state, bond_state, connectivity]
Outputs: bond_state
"""
if not self.use_global:
atom_state, bond_state, connectivity = inputs
@@ -54,13 +52,15 @@ def call(self, inputs, mask=None):
global_state = self.tile([global_state, bond_state])

# Get nodes at start and end of edge
source_atom = self.gather([atom_state, self.slice1(connectivity)])
target_atom = self.gather([atom_state, self.slice0(connectivity)])
source_atom = self.gather([atom_state, connectivity[:, :, 0]])
target_atom = self.gather([atom_state, connectivity[:, :, 1]])

if not self.use_global:
new_bond_state = self.concat([bond_state, source_atom, target_atom])
new_bond_state = self.concat(
[bond_state, source_atom, target_atom])
else:
new_bond_state = self.concat([bond_state, source_atom, target_atom, global_state])
new_bond_state = self.concat(
[bond_state, source_atom, target_atom, global_state])

if self.dropout > 0.:
new_bond_state = self.dropout_layer(new_bond_state)
@@ -84,26 +84,25 @@ def build(self, input_shape):
num_features = input_shape[1][-1]

self.gather = nfp.Gather()
self.slice0 = nfp.Slice(np.s_[:, :, 0])
self.slice1 = nfp.Slice(np.s_[:, :, 1])

self.concat = nfp.ConcatDense()
self.reduce = nfp.Reduce(reduction='sum')

self.dense1 = layers.Dense(2 * num_features, activation='relu')
self.dense2 = layers.Dense(num_features)

def call(self, inputs, mask=None):
def call(self, inputs, mask=None, **kwargs):
""" Inputs: [atom_state, bond_state, connectivity]
Outputs: atom_state
"""
if not self.use_global:
atom_state, bond_state, connectivity = inputs
else:
atom_state, bond_state, connectivity, global_state = inputs
global_state = self.tile([global_state, bond_state])

source_atom = self.gather([atom_state, self.slice1(connectivity)])
source_atom = self.gather([atom_state, connectivity[:, :, 1]])

if not self.use_global:
messages = self.concat([source_atom, bond_state])
@@ -112,10 +111,11 @@ def call(self, inputs, mask=None):

if mask is not None:
# Only works for sum, max
messages = tf.where(tf.expand_dims(mask[1], axis=-1),
messages, tf.zeros_like(messages))
messages = tf.where(tf.expand_dims(mask[1], axis=-1), messages,
tf.zeros_like(messages))

new_atom_state = self.reduce([messages, self.slice0(connectivity), atom_state])
new_atom_state = self.reduce(
[messages, connectivity[:, :, 0], atom_state])

# Dense net after message reduction
new_atom_state = self.dense1(new_atom_state)
@@ -151,11 +151,13 @@ def build(self, input_shape):

def transpose_scores(self, input_tensor):
input_shape = tf.shape(input_tensor)
output_shape = [input_shape[0], input_shape[1], self.num_heads, self.units]
output_shape = [
input_shape[0], input_shape[1], self.num_heads, self.units
]
output_tensor = tf.reshape(input_tensor, output_shape)
return tf.transpose(a=output_tensor, perm=[0, 2, 1, 3]) # [B,N,S,H]

def call(self, inputs, mask=None):
def call(self, inputs, mask=None, **kwargs):
if not self.use_global:
atom_state, bond_state, connectivity = inputs
else:
@@ -168,17 +170,18 @@ def call(self, inputs, mask=None):

if mask is not None:
graph_element_mask = tf.concat([mask[0], mask[1]], axis=1)
query = tf.where(
tf.expand_dims(graph_element_mask, axis=-1),
query,
tf.ones_like(query) * query.dtype.min)
query = tf.where(tf.expand_dims(graph_element_mask, axis=-1),
query,
tf.ones_like(query) * query.dtype.min)

query = tf.transpose(query, perm=[0, 2, 1])
value = self.transpose_scores(self.value_layer(graph_elements)) # [B,N,S,H]
value = self.transpose_scores(
self.value_layer(graph_elements)) # [B,N,S,H]

attention_probs = tf.nn.softmax(query)
context = tf.matmul(tf.expand_dims(attention_probs, 2), value)
context = tf.reshape(context, [batch_size, self.num_heads * self.units])
context = tf.reshape(context,
[batch_size, self.num_heads * self.units])

if self.dropout > 0.:
context = self.dropout_layer(context)
@@ -187,7 +190,5 @@ def call(self, inputs, mask=None):

def get_config(self):
config = super(GlobalUpdate, self).get_config()
config.update(
{"units": self.units,
"num_heads": self.num_heads})
config.update({"units": self.units, "num_heads": self.num_heads})
return config
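The recurring edit in this file swaps the custom nfp.Slice layer for plain tensor indexing on the connectivity array; in TF 2.x, __getitem__ on a symbolic Keras tensor is traced as a built-in SlicingOpLambda layer, which presumably lets the hand-rolled layer (and its eval-based from_config, commented out below) be deleted. A toy sketch of the equivalence, with shapes assumed:

import numpy as np
import tensorflow as tf

# connectivity has shape [batch, num_edges, 2]; connectivity[i, j] holds
# the indices of the two atoms that edge j in graph i connects
connectivity = tf.constant(np.arange(12).reshape(2, 3, 2))

# Old (removed): nfp.Slice(np.s_[:, :, 0])(connectivity)
# New: native indexing
source_indices = connectivity[:, :, 0]  # shape [2, 3]
target_indices = connectivity[:, :, 1]  # shape [2, 3]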
109 changes: 84 additions & 25 deletions nfp/layers/layers.py
@@ -1,9 +1,66 @@
import logging

import tensorflow as tf
from tensorflow.keras import layers


class RBFExpansion(layers.Layer):
def __init__(self,
dimension=128,
init_gap=10,
init_max_distance=7,
trainable=False):
""" Layer to calculate radial basis function 'embeddings' for a continuous input variable. The width and
location of each bin can be optionally trained. Essentially equivalent to a 1-hot embedding for a continuous
variable.
Parameters
----------
dimension: The total number of distance bins
init_gap: The initial width of each gaussian distribution
init_max_distance: the initial maximum value of the continuous variable
trainable: Whether the centers and gap parameters should be added as trainable NN parameters.
"""
super(RBFExpansion, self).__init__()
self.init_gap = init_gap
self.init_max_distance = init_max_distance
self.dimension = dimension
self.trainable = trainable

def build(self, input_shape):
self.centers = tf.Variable(
name='centers',
initial_value=tf.range(0,
self.init_max_distance,
delta=self.init_max_distance /
self.dimension),
trainable=self.trainable,
dtype=tf.float32)

self.gap = tf.Variable(name='gap',
initial_value=tf.constant(self.init_gap,
dtype=tf.float32),
trainable=self.trainable,
dtype=tf.float32)

def call(self, inputs, **kwargs):
distances = tf.where(tf.math.is_nan(inputs),
tf.zeros_like(inputs, dtype=inputs.dtype), inputs)
offset = tf.expand_dims(distances, -1) - tf.cast(
self.centers, inputs.dtype)
logits = -self.gap * offset**2
return tf.exp(logits)

def compute_mask(self, inputs, mask=None):
return tf.logical_not(tf.math.is_nan(inputs))

def get_config(self):
return {
'init_gap': self.init_gap,
'init_max_distance': self.init_max_distance,
'dimension': self.dimension,
'trainable': self.trainable
}


def batched_segment_op(data,
segment_ids,
num_segments,
@@ -42,22 +99,22 @@ def batched_segment_op(data,
return tf.reshape(reduced_data, [batch_size, num_segments, data.shape[-1]])


class Slice(layers.Layer):
def __init__(self, slice_obj, *args, **kwargs):
super(Slice, self).__init__(*args, **kwargs)
self.slice_obj = slice_obj
self.supports_masking = True

def call(self, inputs, mask=None):
return inputs[self.slice_obj]

def get_config(self):
return {'slice_obj': str(self.slice_obj)}

@classmethod
def from_config(cls, config):
config['slice_obj'] = eval(config['slice_obj'])
return cls(**config)
# class Slice(layers.Layer):
# def __init__(self, slice_obj, *args, **kwargs):
# super(Slice, self).__init__(*args, **kwargs)
# self.slice_obj = slice_obj
# self.supports_masking = True
#
# def call(self, inputs, mask=None):
# return inputs[self.slice_obj]
#
# def get_config(self):
# return {'slice_obj': str(self.slice_obj)}
#
# @classmethod
# def from_config(cls, config):
# config['slice_obj'] = eval(config['slice_obj'])
# return cls(**config)


class Gather(layers.Layer):
@@ -82,16 +139,19 @@ def _parse_inputs_and_mask(self, inputs, mask=None):

return data, segment_ids, target, data_mask

def compute_output_shape(self, input_shape):
data_shape, _, target_shape = input_shape
return [data_shape[0], target_shape[1], data_shape[-1]]

def call(self, inputs, mask=None):
data, segment_ids, target, data_mask = self._parse_inputs_and_mask(
inputs, mask)
num_segments = tf.shape(target, out_type=segment_ids.dtype)[1]
return batched_segment_op(
data,
segment_ids,
num_segments,
data_mask=data_mask,
reduction=self.reduction)
return batched_segment_op(data,
segment_ids,
num_segments,
data_mask=data_mask,
reduction=self.reduction)

def get_config(self):
return {'reduction': self.reduction}
@@ -100,7 +160,6 @@ def get_config(self):
class ConcatDense(layers.Layer):
""" Layer to combine the concatenation and two dense layers. Just useful as a common operation in the graph
layers """

def __init__(self, **kwargs):
super(ConcatDense, self).__init__(**kwargs)
self.supports_masking = True
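The main addition to this file is RBFExpansion, which embeds a scalar input (for example an interatomic distance) as a vector of gaussian basis activations, exp(-init_gap * (d - center)**2) with centers evenly spaced on [0, init_max_distance), and masks NaN inputs. A short usage sketch (the import path is an assumption):

import numpy as np
import tensorflow as tf
from nfp.layers import RBFExpansion  # import path assumed

distances = tf.constant([[1.2, 3.4, np.nan]])  # NaN marks a padded entry

rbf = RBFExpansion(dimension=8, init_gap=10, init_max_distance=7)
embedded = rbf(distances)           # shape [1, 3, 8]: one activation per bin
mask = rbf.compute_mask(distances)  # [[True, True, False]]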
3 changes: 0 additions & 3 deletions nfp/preprocessing/__init__.py
@@ -1,3 +0,0 @@
from .features import *
from .preprocessor import *
from .tfrecord import *
