In [1]:
import numpy as np
import pandas as pd

import tensorflow as tf
import nfp

print(f"tensorflow {tf.__version__}")
print(f"nfp {nfp.__version__}")

tensorflow 2.9.1
nfp 0.3.12+3.g93ba25b.dirty


In [2]:
smiles_list = ["CC", "CCC", "C1CC1", "C"]
preprocessor = nfp.preprocessing.mol_preprocessor.SmilesPreprocessor()

In [3]:
dataset = tf.data.Dataset.from_generator(
    lambda: (preprocessor(smiles, train=True) for smiles in smiles_list),
    output_signature=preprocessor.output_signature,
).apply(tf.data.experimental.dense_to_ragged_batch(batch_size=4))

2022-07-13 15:03:08.687436: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [4]:
inputs_ragged = list(dataset.take(1))[0]

In [5]:
inputs_ragged['atom'].shape

TensorShape([4, None])

In [6]:
inputs_ragged['bond'].shape

TensorShape([4, None])

In [7]:
inputs_ragged['connectivity'].shape

TensorShape([4, None, 2])

In [8]:
tf.shape(inputs_ragged['atom'])

<DynamicRaggedShape lengths=[4, (8, 11, 9, 5)] num_row_partitions=1>

In [11]:
layers = tf.keras.layers


atom_class = layers.Input(shape=[None], dtype=tf.int64, name="atom", ragged=True)
bond_class = layers.Input(shape=[None], dtype=tf.int64, name="bond", ragged=True)
connectivity = layers.Input(shape=[None, 2], dtype=tf.int64, name="connectivity", ragged=True)

atom_state = layers.Embedding(preprocessor.atom_classes, 16, mask_zero=True)(
    atom_class
)
bond_state = layers.Embedding(preprocessor.bond_classes, 16, mask_zero=True)(
    bond_class
)


# global_state = nfp.GlobalUpdate(units=8, num_heads=1)([atom_state, bond_state, connectivity])


source_atom = tf.gather(atom_state, connectivity[:, :, 1], batch_dims=1)
messages = nfp.ConcatDense()([source_atom, bond_state])

num_atoms = atom_class.row_lengths()
segment_ids = connectivity[:, :, 0] + tf.expand_dims(tf.math.cumsum(num_atoms, exclusive=True), 1)
summed_messages = tf.math.segment_sum(messages.merge_dims(0,1), segment_ids.merge_dims(0,1))
new_atom_state = tf.RaggedTensor.from_row_lengths(summed_messages, num_atoms)

model = tf.keras.Model(
    [atom_class, bond_class, connectivity], [source_atom]
)

In [12]:
source_atom = model(inputs_ragged)

In [18]:
source_atom[:, 0, :]

ValueError: Cannot index into an inner ragged dimension.

In [None]:
new_atom_state.ragged_rank

In [None]:
atom_class.ragged_rank

In [None]:
sum(num_atoms)

In [None]:
connectivity = inputs_ragged['connectivity']

In [None]:
connectivity[:, :, 0] + tf.expand_dims(tf.math.cumsum(num_atoms, exclusive=True), 1)

In [None]:
inputs_ragged['connectivity'][:, :, 0]

In [None]:
tf.math.segment_sum(messages[0], connectivity[0,:,0])

In [None]:
messages.shape

In [None]:
messages.row_lengths()

In [None]:
tf.shape(messages)