# Graph tensor

### Import modules

In [1]:
from molgraph import chemistry
from molgraph import layers
from molgraph import GraphTensor #####

import tensorflow as tf
from tensorflow import keras

### Construct a `GraphTensor`

Although a `GraphTensor` can be constructed directly via its constructor, here we build one from a list of SMILES strings using a `MolecularGraphEncoder`.

In [2]:
# The three molecules to encode, as SMILES strings.
smiles_list = [
    'OCC1OC(C(C1O)O)n1cnc2c1ncnc2N',
    'C(C(=O)O)N',
    'C1=CC(=CC=C1CC(C(=O)O)N)O',
]

# Node (atom) featurizer. `oov_size=1` presumably reserves one extra bin for
# values outside the given vocabulary (TODO confirm); the printed node_feature
# below has 11 columns.
atom_encoder = chemistry.Featurizer([
    chemistry.features.Symbol({'C', 'N', 'O'}, oov_size=1),
    chemistry.features.Hybridization({'SP', 'SP2', 'SP3'}, oov_size=1),
    chemistry.features.HydrogenDonor(),
    chemistry.features.HydrogenAcceptor(),
    chemistry.features.Hetero(),
])

# Edge (bond) featurizer; the printed edge_feature below has 5 columns.
bond_encoder = chemistry.Featurizer([
    chemistry.features.BondType({'SINGLE', 'DOUBLE', 'TRIPLE', 'AROMATIC'}),
    chemistry.features.Rotatable(),
])

# Combine both featurizers. NOTE(review): `positional_encoding_dim=None`
# presumably disables positional encodings (the spec printed later shows
# node_position=None) — confirm against the molgraph docs.
mol_encoder = chemistry.MolecularGraphEncoder(
    atom_encoder,
    bond_encoder,
    positional_encoding_dim=None,
)

# Encoding the SMILES list yields a single GraphTensor holding all molecules;
# `sizes` has one entry per molecule.
graph_tensor = mol_encoder(smiles_list)
print(graph_tensor)

GraphTensor(
  sizes=<tf.Tensor: shape=(3,), dtype=int64>,
  node_feature=<tf.Tensor: shape=(37, 11), dtype=float32>,
  edge_src=<tf.Tensor: shape=(76,), dtype=int32>,
  edge_dst=<tf.Tensor: shape=(76,), dtype=int32>,
  edge_feature=<tf.Tensor: shape=(76, 5), dtype=float32>)


### `.separate()` &ndash; Separate subgraphs of `GraphTensor`

In [3]:
# `.separate()` splits the single disjoint graph into per-molecule components:
# every field becomes a tf.RaggedTensor with one row per subgraph (3 here).
graph_tensor = graph_tensor.separate()
print(graph_tensor)

GraphTensor(
  sizes=<tf.Tensor: shape=(3,), dtype=int64>,
  node_feature=<tf.RaggedTensor: shape=(3, None, 11), dtype=float32, ragged_rank=1>,
  edge_src=<tf.RaggedTensor: shape=(3, None), dtype=int32, ragged_rank=1>,
  edge_dst=<tf.RaggedTensor: shape=(3, None), dtype=int32, ragged_rank=1>,
  edge_feature=<tf.RaggedTensor: shape=(3, None, 5), dtype=float32, ragged_rank=1>)


### `.merge()` &ndash; Merge subgraphs of `GraphTensor`

In [4]:
# `.merge()` undoes `.separate()`: the ragged per-molecule rows are flattened
# back into one disjoint graph whose fields are plain tf.Tensor objects.
graph_tensor = graph_tensor.merge()
print(graph_tensor)

GraphTensor(
  sizes=<tf.Tensor: shape=(3,), dtype=int64>,
  node_feature=<tf.Tensor: shape=(37, 11), dtype=float32>,
  edge_src=<tf.Tensor: shape=(76,), dtype=int32>,
  edge_dst=<tf.Tensor: shape=(76,), dtype=int32>,
  edge_feature=<tf.Tensor: shape=(76, 5), dtype=float32>)


### `.propagate()` &ndash; Propagate node information within the `GraphTensor`

In [5]:
# Show the node features, call `.propagate()`, then show them again.
# NOTE(review): the exact aggregation `.propagate()` performs over the edges
# is not visible here — confirm against the molgraph documentation.
print('Node features before:\n', graph_tensor.node_feature, end='\n\n')
graph_tensor = graph_tensor.propagate()
print('Node features after:\n', graph_tensor.node_feature)

Node features before:
 tf.Tensor(
[[0. 0. 0. 1. 0. 0. 0. 1. 1. 1. 1.]
 [0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 1. 0. 1. 1.]
 [0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 1. 1. 1. 1.]
 [0. 0. 0. 1. 0. 0. 0. 1. 1. 1. 1.]
 [0. 0. 1. 0. 0. 0. 1. 0. 0. 1. 1.]
 [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 1. 0. 0. 1. 1.]
 [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 1. 0. 0. 1. 1.]
 [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 1. 0. 0. 1. 1.]
 [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 1. 0. 1. 1. 1.]
 [0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 1. 0. 0. 1. 1.]
 [0. 0. 0. 1. 0. 0. 1. 0. 1. 0. 1.]
 [0. 0. 1. 0. 0. 0. 0. 1. 1. 1. 1.]
 [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 1. 0. 0. 0

### `.update()` &ndash; Update data of the `GraphTensor`


In [6]:
# Seed TensorFlow's global RNG so the random tensors below (and any weight
# initialisation in later cells) are reproducible under Restart & Run All.
tf.random.set_seed(42)

# Random stand-ins: one extra scalar per node, and a 128-wide replacement for
# the existing node features. graph_tensor is merged (non-ragged) at this
# point, so `shape[:-1]` is the concrete node count.
node_supplementary_data = tf.random.uniform(
    shape=graph_tensor.node_feature.shape[:-1] + [1])

node_feature_updated = tf.random.uniform(
    shape=graph_tensor.node_feature.shape[:-1] + [128])

# Both add new data and update existing data of the GraphTensor:
graph_tensor = graph_tensor.update({
    'node_supplementary_data': node_supplementary_data,
    'node_feature': node_feature_updated,
})
print(graph_tensor)

GraphTensor(
  sizes=<tf.Tensor: shape=(3,), dtype=int64>,
  node_feature=<tf.Tensor: shape=(37, 128), dtype=float32>,
  edge_src=<tf.Tensor: shape=(76,), dtype=int32>,
  edge_dst=<tf.Tensor: shape=(76,), dtype=int32>,
  edge_feature=<tf.Tensor: shape=(76, 5), dtype=float32>,
  node_supplementary_data=<tf.Tensor: shape=(37, 1), dtype=float32>)


### `.remove()` &ndash; Remove data from `GraphTensor`

In [7]:
# Drop the supplementary node data added by `.update()` and the edge features.
graph_tensor = graph_tensor.remove(['node_supplementary_data', 'edge_feature'])
print(graph_tensor)

GraphTensor(
  sizes=<tf.Tensor: shape=(3,), dtype=int64>,
  node_feature=<tf.Tensor: shape=(37, 128), dtype=float32>,
  edge_src=<tf.Tensor: shape=(76,), dtype=int32>,
  edge_dst=<tf.Tensor: shape=(76,), dtype=int32>)


### `__getitem__` &ndash; Index lookup on the `GraphTensor`

The `GraphTensor` can be indexed either with a `str` (to obtain a specific field of the `GraphTensor`) or with an `int`, `list[int]` or `slice` (to extract specific subgraphs, i.e. molecules, from the `GraphTensor`). Alternatively, specific subgraphs can be extracted by passing the `GraphTensor` to `tf.gather`.

In [8]:
# Print the full graph, then two subgraph selections. A list of ints picks
# subgraphs by position (here molecules 2 and 3); a slice picks a contiguous
# range (here `:2`, i.e. molecules 1 and 2).
print("Complete graph:\n")
print("---" * 20)
print(graph_tensor, end='\n\n')

print("---" * 20)
print("Subgraph (2) and (3) of graph:\n")
print(graph_tensor[[1, 2]], end='\n\n')

print("---" * 20)
# `[:2]` selects the FIRST two subgraphs, (1) and (2); the resulting node
# count (19 + 5 = 24) confirms it. The label previously said "(2) and (3)".
print("Subgraph (1) and (2) of graph:\n")
print(graph_tensor[:2], end='\n\n')

Complete graph:

------------------------------------------------------------
GraphTensor(
  sizes=<tf.Tensor: shape=(3,), dtype=int64>,
  node_feature=<tf.Tensor: shape=(37, 128), dtype=float32>,
  edge_src=<tf.Tensor: shape=(76,), dtype=int32>,
  edge_dst=<tf.Tensor: shape=(76,), dtype=int32>)

------------------------------------------------------------
Subgraph (2) and (3) of graph:

GraphTensor(
  sizes=<tf.Tensor: shape=(2,), dtype=int64>,
  node_feature=<tf.Tensor: shape=(18, 128), dtype=float32>,
  edge_src=<tf.Tensor: shape=(34,), dtype=int32>,
  edge_dst=<tf.Tensor: shape=(34,), dtype=int32>)

------------------------------------------------------------
Subgraph (2) and (3) of graph:

GraphTensor(
  sizes=<tf.Tensor: shape=(2,), dtype=int64>,
  node_feature=<tf.Tensor: shape=(24, 128), dtype=float32>,
  edge_src=<tf.Tensor: shape=(50,), dtype=int32>,
  edge_dst=<tf.Tensor: shape=(50,), dtype=int32>)



### `__getattr__` &ndash; Attribute lookup on the `GraphTensor`

In [9]:
# Fields stored on the GraphTensor are reachable as plain attributes.
print("Access `node_feature` field:\n", "---" * 20, sep='\n')
print(graph_tensor.node_feature, end='\n\n')

print("---" * 20, "Access `edge_src` field:\n", sep='\n')
print(graph_tensor.edge_src, end='\n\n')

# `graph_indicator` has one entry per node: the index of the subgraph that
# node belongs to.
print("---" * 20, "Access `graph_indicator` field:\n", sep='\n')
print(graph_tensor.graph_indicator, end='\n\n')

Access `node_feature` field:

------------------------------------------------------------
tf.Tensor(
[[0.30606592 0.01332998 0.28550065 ... 0.30522108 0.43709052 0.2496804 ]
 [0.47505558 0.6802629  0.12628877 ... 0.54731417 0.85908985 0.01080072]
 [0.32505012 0.16541815 0.9268564  ... 0.19977057 0.6975106  0.63107324]
 ...
 [0.06981373 0.0497787  0.7329197  ... 0.72168195 0.992267   0.4002931 ]
 [0.6254629  0.77454865 0.4750824  ... 0.21217322 0.10769343 0.71567035]
 [0.29524624 0.7836231  0.7198993  ... 0.94255567 0.926514   0.62505746]], shape=(37, 128), dtype=float32)

------------------------------------------------------------
Access `edge_src` field:

tf.Tensor(
[ 0  1  1  2  2  2  3  3  4  4  4  5  5  5  6  6  6  7  8  9  9  9 10 10
 11 11 12 12 12 13 13 13 14 14 15 15 16 16 17 17 17 18 19 19 20 20 20 21
 22 23 24 24 25 25 26 26 26 27 27 28 28 29 29 29 30 30 31 31 31 32 32 32
 33 34 35 36], shape=(76,), dtype=int32)

------------------------------------------------------------


### `tf.concat` &ndash; Concatenating multiple `GraphTensor` instances

In [10]:
# Concatenating along axis 0 stacks subgraphs: 3 + 3 = 6 in the result.
print("Concatenating two graphs in non-ragged states:\n")
graph_tensor_concat = tf.concat([graph_tensor, graph_tensor], axis=0)
print(graph_tensor_concat, "Inspect `graph_indicator`:\n", sep='\n\n')
print(graph_tensor_concat.graph_indicator, end='\n\n')

# The same works on the ragged (separated) representation.
print('---' * 20, "Concatenating two graphs in ragged states", sep='\n')
graph_tensor_concat = tf.concat(
    [graph_tensor.separate(), graph_tensor.separate()], axis=0)
print(graph_tensor_concat, end='\n\n')

Concatenating two graphs in non-ragged states:

GraphTensor(
  sizes=<tf.Tensor: shape=(6,), dtype=int64>,
  node_feature=<tf.Tensor: shape=(74, 128), dtype=float32>,
  edge_src=<tf.Tensor: shape=(152,), dtype=int32>,
  edge_dst=<tf.Tensor: shape=(152,), dtype=int32>)

Inspect `graph_indicator`:

tf.Tensor(
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 4 4 4 4 4 5 5 5 5 5 5 5 5 5 5 5 5 5], shape=(74,), dtype=int64)

------------------------------------------------------------
Concatenating two graphs in ragged states
GraphTensor(
  sizes=<tf.Tensor: shape=(6,), dtype=int64>,
  node_feature=<tf.RaggedTensor: shape=(6, None, 128), dtype=float32, ragged_rank=1>,
  edge_src=<tf.RaggedTensor: shape=(6, None), dtype=int32, ragged_rank=1>,
  edge_dst=<tf.RaggedTensor: shape=(6, None), dtype=int32, ragged_rank=1>)



### `tf.split` &ndash; Splitting a `GraphTensor` into multiple `GraphTensor` instances

In [11]:
# Merge the 6-subgraph ragged tensor from the concatenation cell back into one
# disjoint graph, then split it into 6 single-subgraph GraphTensor instances.
# The resulting list is the cell's displayed output.
tf.split(graph_tensor_concat.merge(), num_or_size_splits=6)

[GraphTensor(
   sizes=<tf.Tensor: shape=(1,), dtype=int64>,
   node_feature=<tf.Tensor: shape=(19, 128), dtype=float32>,
   edge_src=<tf.Tensor: shape=(42,), dtype=int32>,
   edge_dst=<tf.Tensor: shape=(42,), dtype=int32>),
 GraphTensor(
   sizes=<tf.Tensor: shape=(1,), dtype=int64>,
   node_feature=<tf.Tensor: shape=(5, 128), dtype=float32>,
   edge_src=<tf.Tensor: shape=(8,), dtype=int32>,
   edge_dst=<tf.Tensor: shape=(8,), dtype=int32>),
 GraphTensor(
   sizes=<tf.Tensor: shape=(1,), dtype=int64>,
   node_feature=<tf.Tensor: shape=(13, 128), dtype=float32>,
   edge_src=<tf.Tensor: shape=(26,), dtype=int32>,
   edge_dst=<tf.Tensor: shape=(26,), dtype=int32>),
 GraphTensor(
   sizes=<tf.Tensor: shape=(1,), dtype=int64>,
   node_feature=<tf.Tensor: shape=(19, 128), dtype=float32>,
   edge_src=<tf.Tensor: shape=(42,), dtype=int32>,
   edge_dst=<tf.Tensor: shape=(42,), dtype=int32>),
 GraphTensor(
   sizes=<tf.Tensor: shape=(1,), dtype=int64>,
   node_feature=<tf.Tensor: shape=(5, 128)

### `.spec` &ndash; The spec of the `GraphTensor`

In [12]:
# The spec lists a TensorSpec per field; fields that are not set show as None.
print(graph_tensor.spec)

GraphTensor.Spec(sizes=TensorSpec(shape=(None,), dtype=tf.int64, name=None), node_feature=TensorSpec(shape=(None, 128), dtype=tf.float32, name=None), edge_src=TensorSpec(shape=(None,), dtype=tf.int32, name=None), edge_dst=TensorSpec(shape=(None,), dtype=tf.int32, name=None), edge_feature=None, edge_weight=None, node_position=None, auxiliary={})


### `.shape` &ndash; Partial shape of the `GraphTensor`

In [13]:
# Partial shape; for this tensor it appears to be
# (num_subgraphs, None, node_feature_dim) — TODO confirm the exact definition.
print('(partial) shape:', graph_tensor.shape)

(partial) shape: (3, None, 128)


### `.dtype` &ndash; Partial dtype of the `GraphTensor`

In [14]:
# Partial dtype; here it matches the float32 dtype of node_feature.
print('(partial) dtype:', graph_tensor.dtype.name)

(partial) dtype: float32


### `.rank` &ndash; Partial rank of the `GraphTensor` 

In [15]:
print('(partial) rank: ', graph_tensor.rank)

(partial) rank:  3


### `tf.data.Dataset` &ndash; Creating a TF dataset from a `GraphTensor`

In [16]:
# Slice the GraphTensor into a dataset with one element per subgraph.
ds = tf.data.Dataset.from_tensor_slices(graph_tensor)
print('Dataset object:\n', ds)

print('\n' + '---' * 20)
# Loop over dataset, two subgraphs per batch. The identity `map` changes
# nothing; it presumably just shows batched GraphTensors pass through `map`.
for i, x in enumerate(ds.batch(2).map(lambda graph: graph)):
    print(f"\nbatch {i + 1}:\n", x, '\n' + '---' * 20, sep='\n')

Dataset object:
 <_TensorSliceDataset element_spec=GraphTensor.Spec(sizes=TensorSpec(shape=(), dtype=tf.int64, name=None), node_feature=TensorSpec(shape=(None, 128), dtype=tf.float32, name=None), edge_src=TensorSpec(shape=(None,), dtype=tf.int32, name=None), edge_dst=TensorSpec(shape=(None,), dtype=tf.int32, name=None), edge_feature=None, edge_weight=None, node_position=None, auxiliary={})>

------------------------------------------------------------

batch 1:

GraphTensor(
  sizes=<tf.Tensor: shape=(2,), dtype=int64>,
  node_feature=<tf.Tensor: shape=(24, 128), dtype=float32>,
  edge_src=<tf.Tensor: shape=(50,), dtype=int32>,
  edge_dst=<tf.Tensor: shape=(50,), dtype=int32>)

------------------------------------------------------------

batch 2:

GraphTensor(
  sizes=<tf.Tensor: shape=(1,), dtype=int64>,
  node_feature=<tf.Tensor: shape=(13, 128), dtype=float32>,
  edge_src=<tf.Tensor: shape=(26,), dtype=int32>,
  edge_dst=<tf.Tensor: shape=(26,), dtype=int32>)

---------------------

### `layers` &ndash; Passing a `GraphTensor` to a layer

The `GraphTensor` can be passed to GNN layers either as a single disjoint graph (non-ragged state) or as separate subgraphs (ragged state, obtained via `.separate()`).

In [17]:
# One GINConv layer applied to both representations; each output keeps the
# ragged / non-ragged state of its input.
gin_conv = layers.GINConv(128)

print("Pass GraphTensor in non-ragged state:\n")
print(gin_conv(graph_tensor), '---' * 20, sep='\n\n')
print('\nPass GraphTensor in ragged state:\n')
print(gin_conv(graph_tensor.separate()), end='\n\n')

Pass GraphTensor in non-ragged state:

GraphTensor(
  sizes=<tf.Tensor: shape=(3,), dtype=int64>,
  node_feature=<tf.Tensor: shape=(37, 128), dtype=float32>,
  edge_src=<tf.Tensor: shape=(76,), dtype=int32>,
  edge_dst=<tf.Tensor: shape=(76,), dtype=int32>)

------------------------------------------------------------

Pass GraphTensor in ragged state:

GraphTensor(
  sizes=<tf.Tensor: shape=(3,), dtype=int64>,
  node_feature=<tf.RaggedTensor: shape=(3, None, 128), dtype=float32, ragged_rank=1>,
  edge_src=<tf.RaggedTensor: shape=(3, None), dtype=int32, ragged_rank=1>,
  edge_dst=<tf.RaggedTensor: shape=(3, None), dtype=int32, ragged_rank=1>)



### `Model` &ndash; Passing a `GraphTensor` to a model

In [18]:
# Two GCN convolutions, a readout (presumably pooling nodes per graph), and a
# single-output dense head — one prediction per molecule.
model = keras.Sequential([
    layers.GCNConv(),
    layers.GCNConv(),
    layers.Readout(),
    keras.layers.Dense(1),
])

# One dummy regression target per molecule.
y_dummy = tf.constant([[1.0], [2.0], [3.0]])

model.compile(optimizer='adam', loss='huber')
print("Using (graph_tensor, label) pair as input:\n")
model.fit(graph_tensor, y_dummy, batch_size=2, epochs=5)

print('\n------------------------------\n')
print("Using tf.data.Dataset as input:\n")
# Note: this continues training the model fitted above, not a fresh one.
dataset = tf.data.Dataset.from_tensor_slices((graph_tensor, y_dummy))
model.fit(dataset.batch(2), epochs=5);

Using (graph_tensor, label) pair as input:

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5

------------------------------

Using tf.data.Dataset as input:

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
