# Graph tensor

In [23]:
import tensorflow as tf
# Pass an empty device list for 'GPU' so TensorFlow sees no GPUs in this notebook.
tf.config.set_visible_devices([], 'GPU')

import sys
# Add the directory four levels up to the module search path.
# NOTE(review): presumably the repository root, so `molgraph` is imported
# from the local checkout rather than an installed package — confirm.
sys.path.append('../../../../')

from molgraph.chemistry import MolecularGraphEncoder
from molgraph.chemistry import Featurizer
from molgraph.chemistry import features

from molgraph.layers import GCNConv
from molgraph.layers import Readout

### Construct **GraphTensor** 

Construct `GraphTensor` from a `MolecularGraphEncoder`

In [24]:
# Featurizer for atoms (graph nodes), built from five atom features.
# NOTE(review): `oov_size=1` presumably reserves one extra slot for values
# outside the given set; that would match the node feature width of 11
# reported in the output below (3+1, 3+1, and three single-value features)
# — confirm against the molgraph documentation.
atom_encoder = Featurizer([
    features.Symbol({'C', 'N', 'O'}, oov_size=1),
    features.Hybridization({'SP', 'SP2', 'SP3'}, oov_size=1),
    features.HydrogenDonor(),
    features.HydrogenAcceptor(),
    features.Hetero()
])

# Featurizer for bonds (graph edges): bond type plus a rotatable flag.
bond_encoder = Featurizer([
    features.BondType({'SINGLE', 'DOUBLE', 'TRIPLE', 'AROMATIC'}),
    features.Rotatable()
])

# Encoder that is later called on a list of molecule strings to produce a
# `GraphTensor`, using the two featurizers above.
mol_encoder = MolecularGraphEncoder(atom_encoder, bond_encoder)

# Molecule strings that are encoded together into one `GraphTensor`.
molecules = [
    'OCC1OC(C(C1O)O)n1cnc2c1ncnc2N',
    'C(C(=O)O)N',
    'C1=CC(=CC=C1CC(C(=O)O)N)O'
]

graph_tensor = mol_encoder(molecules)

# Show the tensor summary first, then the shape of each field of interest.
print(graph_tensor, end='\n\n')
for label, field in (
    ('node_feature shape:', graph_tensor.node_feature),
    ('edge_dst shape:    ', graph_tensor.edge_dst),
    ('edge_src shape:    ', graph_tensor.edge_src),
    ('edge_feature shape:', graph_tensor.edge_feature),
):
    print(label, field.shape)

GraphTensor(
  node_feature=<tf.RaggedTensor: shape=(3, None, 11), dtype=float32>,
  positional_encoding=<tf.RaggedTensor: shape=(3, None, 16), dtype=float32>,
  edge_feature=<tf.RaggedTensor: shape=(3, None, 5), dtype=float32>,
  edge_dst=<tf.RaggedTensor: shape=(3, None), dtype=int32>,
  edge_src=<tf.RaggedTensor: shape=(3, None), dtype=int32>)

node_feature shape: (3, None, 11)
edge_dst shape:     (3, None)
edge_src shape:     (3, None)
edge_feature shape: (3, None, 5)


### Merge subgraphs of **GraphTensor**

By converting nested ragged tensors to tensors, via the `merge()` method, a more efficient representation of `GraphTensor` is obtained. I.e., in this example, the `GraphTensor` now encodes the three molecules as *a single disjoint graph* instead of *three separate graphs*. 

In [25]:
# Merge the per-molecule subgraphs into one disjoint graph.
graph_tensor = graph_tensor.merge()

# Summary first, then each field's shape and the node-to-graph indicator.
print(graph_tensor, end='\n\n')
for label, value in (
    ('node_feature shape:', graph_tensor.node_feature.shape),
    ('edge_dst shape:    ', graph_tensor.edge_dst.shape),
    ('edge_src shape:    ', graph_tensor.edge_src.shape),
    ('edge_feature shape:', graph_tensor.edge_feature.shape),
    ('graph_indicator:   ', graph_tensor.graph_indicator.numpy()),
):
    print(label, value)

GraphTensor(
  node_feature=<tf.Tensor: shape=(37, 11), dtype=float32>,
  positional_encoding=<tf.Tensor: shape=(37, 16), dtype=float32>,
  edge_feature=<tf.Tensor: shape=(76, 5), dtype=float32>,
  edge_dst=<tf.Tensor: shape=(76,), dtype=int32>,
  edge_src=<tf.Tensor: shape=(76,), dtype=int32>,
  graph_indicator=<tf.Tensor: shape=(37,), dtype=int32>)

node_feature shape: (37, 11)
edge_dst shape:     (76,)
edge_src shape:     (76,)
edge_feature shape: (76, 5)
graph_indicator:    [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2]


### Separate subgraphs of **GraphTensor**

By converting nested tensors to ragged tensors, via the `separate()` method, a batchable representation of `GraphTensor` is obtained (see later).

In [26]:
# Split the disjoint graph back into one (ragged) row per molecule.
graph_tensor = graph_tensor.separate()

# Summary first, then the shape of each field of interest.
print(graph_tensor, end='\n\n')
for label, field in (
    ('node_feature shape:', graph_tensor.node_feature),
    ('edge_dst shape:    ', graph_tensor.edge_dst),
    ('edge_src shape:    ', graph_tensor.edge_src),
    ('edge_feature shape:', graph_tensor.edge_feature),
):
    print(label, field.shape)

GraphTensor(
  node_feature=<tf.RaggedTensor: shape=(3, None, 11), dtype=float32>,
  positional_encoding=<tf.RaggedTensor: shape=(3, None, 16), dtype=float32>,
  edge_feature=<tf.RaggedTensor: shape=(3, None, 5), dtype=float32>,
  edge_dst=<tf.RaggedTensor: shape=(3, None), dtype=int32>,
  edge_src=<tf.RaggedTensor: shape=(3, None), dtype=int32>)

node_feature shape: (3, None, 11)
edge_dst shape:     (3, None)
edge_src shape:     (3, None)
edge_feature shape: (3, None, 5)


### Update the **GraphTensor**

The `GraphTensor` can conveniently be updated via the `update()` and `remove()` methods. Make sure the new fields have either `node` or `edge` in their names, depending on whether they are associated with nodes or edges of the graph tensor.

In [27]:
graph_tensor = graph_tensor.merge()

# Add a new node-level field with the same shape as `node_feature`.
# A stateless, explicitly seeded generator is used so that the values are
# identical on every execution of this cell (including re-runs without a
# kernel restart); the default range [0, 1) and float32 dtype match
# `tf.random.uniform`.
random_features = tf.random.stateless_uniform(
    shape=graph_tensor['node_feature'].shape, seed=(0, 42))
graph_tensor = graph_tensor.update({'node_random_features': random_features})

# Drop an existing field by name.
graph_tensor = graph_tensor.remove(['edge_feature'])

graph_tensor

GraphTensor(
  node_feature=<tf.Tensor: shape=(37, 11), dtype=float32>,
  positional_encoding=<tf.Tensor: shape=(37, 16), dtype=float32>,
  edge_dst=<tf.Tensor: shape=(76,), dtype=int32>,
  edge_src=<tf.Tensor: shape=(76,), dtype=int32>,
  graph_indicator=<tf.Tensor: shape=(37,), dtype=int32>,
  node_random_features=<tf.Tensor: shape=(37, 11), dtype=float32>)

### Index the **GraphTensor**

The `GraphTensor` can be indexed either by passing a `str` (to obtain a specific field of `GraphTensor`) or `int`, `list[int]` or `slice` (to extract specific subgraphs (molecules) from `GraphTensor`). (Alternatively, the `GraphTensor` can be passed to `tf.gather` to extract specific subgraphs.)

In [28]:
# The full tensor, followed by three ways of selecting subgraphs 1 and 2:
# a list of indices, a slice, and a list of indices on the separated form.
for selection in (
    graph_tensor,
    graph_tensor[[1, 2]],
    graph_tensor[1:3],
    graph_tensor.separate()[[1, 2]],
):
    print(selection, end='\n\n')

GraphTensor(
  node_feature=<tf.Tensor: shape=(37, 11), dtype=float32>,
  positional_encoding=<tf.Tensor: shape=(37, 16), dtype=float32>,
  edge_dst=<tf.Tensor: shape=(76,), dtype=int32>,
  edge_src=<tf.Tensor: shape=(76,), dtype=int32>,
  graph_indicator=<tf.Tensor: shape=(37,), dtype=int32>,
  node_random_features=<tf.Tensor: shape=(37, 11), dtype=float32>)

GraphTensor(
  node_feature=<tf.Tensor: shape=(18, 11), dtype=float32>,
  positional_encoding=<tf.Tensor: shape=(18, 16), dtype=float32>,
  node_random_features=<tf.Tensor: shape=(18, 11), dtype=float32>,
  edge_dst=<tf.Tensor: shape=(34,), dtype=int32>,
  edge_src=<tf.Tensor: shape=(34,), dtype=int32>,
  graph_indicator=<tf.Tensor: shape=(18,), dtype=int32>)

GraphTensor(
  node_feature=<tf.Tensor: shape=(18, 11), dtype=float32>,
  positional_encoding=<tf.Tensor: shape=(18, 16), dtype=float32>,
  node_random_features=<tf.Tensor: shape=(18, 11), dtype=float32>,
  edge_dst=<tf.Tensor: shape=(34,), dtype=int32>,
  edge_src=<tf.Tens

### Concatenating **GraphTensor**s

In [29]:
# Concatenate the tensor with itself, once in merged form and once in
# separated (ragged) form.
for graph in (graph_tensor, graph_tensor.separate()):
    print(tf.concat([graph, graph], axis=0), end='\n\n')

GraphTensor(
  node_feature=<tf.Tensor: shape=(74, 11), dtype=float32>,
  positional_encoding=<tf.Tensor: shape=(74, 16), dtype=float32>,
  edge_dst=<tf.Tensor: shape=(152,), dtype=int32>,
  edge_src=<tf.Tensor: shape=(152,), dtype=int32>,
  graph_indicator=<tf.Tensor: shape=(74,), dtype=int32>,
  node_random_features=<tf.Tensor: shape=(74, 11), dtype=float32>)

GraphTensor(
  node_feature=<tf.RaggedTensor: shape=(2, None, 11), dtype=float32>,
  positional_encoding=<tf.RaggedTensor: shape=(2, None, 16), dtype=float32>,
  node_random_features=<tf.RaggedTensor: shape=(2, None, 11), dtype=float32>,
  edge_dst=<tf.RaggedTensor: shape=(2, None), dtype=int32>,
  edge_src=<tf.RaggedTensor: shape=(2, None), dtype=int32>)



### Spec of **GraphTensor**

Every `GraphTensor` has an associated `GraphTensorSpec`, which can be obtained via the `.spec` or `.unspecific_spec` property. The latter is recommended because it leaves the outermost dimension "unknown" (`None`), as the number of nodes and edges varies from input to input. 

In [30]:
# Print the fully specified spec, then the one with unknown outer dimensions.
for label, spec in (
    ('spec:', graph_tensor.spec),
    ('unspecific spec:', graph_tensor.unspecific_spec),
):
    print(label, spec, end='\n\n')

spec: GraphTensorSpec({'node_feature': TensorSpec(shape=(37, 11), dtype=tf.float32, name=None), 'positional_encoding': TensorSpec(shape=(37, 16), dtype=tf.float32, name=None), 'edge_dst': TensorSpec(shape=(76,), dtype=tf.int32, name=None), 'edge_src': TensorSpec(shape=(76,), dtype=tf.int32, name=None), 'graph_indicator': TensorSpec(shape=(37,), dtype=tf.int32, name=None), 'node_random_features': TensorSpec(shape=(37, 11), dtype=tf.float32, name=None)}, TensorShape([37, 11]), tf.float32)

unspecific spec: GraphTensorSpec({'node_feature': TensorSpec(shape=(None, 11), dtype=tf.float32, name=None), 'positional_encoding': TensorSpec(shape=(None, 16), dtype=tf.float32, name=None), 'edge_dst': TensorSpec(shape=(None,), dtype=tf.int32, name=None), 'edge_src': TensorSpec(shape=(None,), dtype=tf.int32, name=None), 'graph_indicator': TensorSpec(shape=(None,), dtype=tf.int32, name=None), 'node_random_features': TensorSpec(shape=(None, 11), dtype=tf.float32, name=None)}, TensorShape([None, 11]), tf

### Properties of **GraphTensor**

In [31]:
# Tensor-like properties exposed directly on the `GraphTensor`.
for label, value in (
    ('(partial) shape:', graph_tensor.shape),
    ('(partial) dtype:', graph_tensor.dtype.name),
    ('(partial) rank: ', graph_tensor.rank),
):
    print(label, value)

(partial) shape: (37, 11)
(partial) dtype: float32
(partial) rank:  2


### Passing **GraphTensor** to **tf.data.Dataset** 

The "separated" ("ragged") `GraphTensor` can be passed to a TF dataset, and subsequently batched (for modeling).

In [32]:
# Switch to the separated (ragged) form, which is the form passed to the
# dataset below.
graph_tensor = graph_tensor.separate()
# Slice the separated tensor along its outer dimension into dataset elements.
ds = tf.data.Dataset.from_tensor_slices(graph_tensor)
# Display the dataset (its repr includes the element spec).
ds

<TensorSliceDataset element_spec=GraphTensorSpec({'node_feature': RaggedTensorSpec(TensorShape([None, 11]), tf.float32, 0, tf.int32), 'positional_encoding': RaggedTensorSpec(TensorShape([None, 16]), tf.float32, 0, tf.int32), 'node_random_features': RaggedTensorSpec(TensorShape([None, 11]), tf.float32, 0, tf.int32), 'edge_dst': RaggedTensorSpec(TensorShape([None]), tf.int32, 0, tf.int32), 'edge_src': RaggedTensorSpec(TensorShape([None]), tf.int32, 0, tf.int32)}, TensorShape([None, 11]), tf.float32)>

In [33]:
# Take the first batch of two subgraphs and merge it into one disjoint graph.
# `next(iter(...))` binds `x` explicitly, so an empty dataset raises
# StopIteration instead of silently displaying a stale `x` left in the kernel.
x = next(iter(ds.batch(2).map(lambda graph: graph.merge())))
x

GraphTensor(
  node_feature=<tf.Tensor: shape=(None, 11), dtype=float32>,
  positional_encoding=<tf.Tensor: shape=(None, 16), dtype=float32>,
  node_random_features=<tf.Tensor: shape=(None, 11), dtype=float32>,
  edge_dst=<tf.Tensor: shape=(None,), dtype=int32>,
  edge_src=<tf.Tensor: shape=(None,), dtype=int32>,
  graph_indicator=<tf.Tensor: shape=(None,), dtype=int32>)

### Passing **GraphTensor**  to GNN layers

The `GraphTensor` can be passed to GNN layers either as a single disjoint graph or as separate subgraphs.

In [34]:
gcn_layer = GCNConv(128)

# Apply the same layer to the merged batch and to its separated form.
for graph in (x, x.separate()):
    print(gcn_layer(graph), end='\n\n')

GraphTensor(
  node_feature=<tf.Tensor: shape=(24, 128), dtype=float32>,
  positional_encoding=<tf.Tensor: shape=(24, 16), dtype=float32>,
  node_random_features=<tf.Tensor: shape=(24, 11), dtype=float32>,
  edge_dst=<tf.Tensor: shape=(50,), dtype=int32>,
  edge_src=<tf.Tensor: shape=(50,), dtype=int32>,
  graph_indicator=<tf.Tensor: shape=(24,), dtype=int32>)

GraphTensor(
  node_feature=<tf.RaggedTensor: shape=(2, None, 128), dtype=float32>,
  positional_encoding=<tf.RaggedTensor: shape=(2, None, 16), dtype=float32>,
  node_random_features=<tf.RaggedTensor: shape=(2, None, 11), dtype=float32>,
  edge_dst=<tf.RaggedTensor: shape=(2, None), dtype=int32>,
  edge_src=<tf.RaggedTensor: shape=(2, None), dtype=int32>)



### Additional features of **GraphTensor**

In [35]:
# Case 1: the graph tensor is ragged, the replacement field is a flat tensor.
# Collapsing the two outer axes turns the ragged `node_feature` into a tf.Tensor.
node_feature = graph_tensor.node_feature.merge_dims(0, 1)
print('extracted node_feature shape =', node_feature.shape, end='\n\n')
# On update, the `GraphTensor` converts the field back to a tf.RaggedTensor.
graph_tensor = graph_tensor.update({'node_feature': node_feature})
print(graph_tensor)

print('\n--------------------------\n')

# Case 2: the graph tensor is merged, the replacement field is ragged.
# This time the extracted `node_feature` is kept as a tf.RaggedTensor.
node_feature = graph_tensor.node_feature
print('extracted node_feature shape =', node_feature.shape, end='\n\n')
# Merging makes the nested fields tf.Tensors; the ragged field is then written back.
graph_tensor = graph_tensor.merge().update({'node_feature': node_feature})
print(graph_tensor)

extracted node_feature shape = (37, 11)

GraphTensor(
  node_feature=<tf.RaggedTensor: shape=(3, None, 11), dtype=float32>,
  positional_encoding=<tf.RaggedTensor: shape=(3, None, 16), dtype=float32>,
  node_random_features=<tf.RaggedTensor: shape=(3, None, 11), dtype=float32>,
  edge_dst=<tf.RaggedTensor: shape=(3, None), dtype=int32>,
  edge_src=<tf.RaggedTensor: shape=(3, None), dtype=int32>)

--------------------------

extracted node_feature shape = (3, None, 11)

GraphTensor(
  node_feature=<tf.Tensor: shape=(37, 11), dtype=float32>,
  positional_encoding=<tf.Tensor: shape=(37, 16), dtype=float32>,
  node_random_features=<tf.Tensor: shape=(37, 11), dtype=float32>,
  edge_dst=<tf.Tensor: shape=(76,), dtype=int32>,
  edge_src=<tf.Tensor: shape=(76,), dtype=int32>,
  graph_indicator=<tf.Tensor: shape=(37,), dtype=int32>)


### Passing **GraphTensor** to GNN models

In [37]:
# Sequential model: two GCN layers, a readout, and a single-unit dense head.
# The input is declared with the unspecific spec of the graph tensor.
model = tf.keras.Sequential([
    tf.keras.layers.Input(type_spec=graph_tensor.unspecific_spec),
    GCNConv(),
    GCNConv(),
    Readout(),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer='sgd', loss='mse')

# One dummy target per molecule.
y_dummy = tf.constant([5.1, 2.3, -5.1])

# First fit: pass the separated graph tensor and the targets directly.
model.fit(graph_tensor.separate(), y_dummy, epochs=5)

print('\n------------------------------\n')

# Second fit: wrap the same inputs and targets in a batched tf.data.Dataset.
dataset = tf.data.Dataset.from_tensor_slices((graph_tensor.separate(), y_dummy))
dataset = dataset.batch(3)
model.fit(dataset, epochs=5);

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5

------------------------------

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
