# Models

This page shows how to set up models and uses some example data to show the model inputs.

```python
%%capture
import keras_core as ks
```

## Functional API

Like most models in `kgcnn.literature`, models can be set up with the Keras functional API. Here is an example of a simple message passing GNN. The layers are taken from `kgcnn.layers`. See the documentation of the layers for further details.

```python
import keras_core as ks
from kgcnn.layers.casting import CastBatchedIndicesToDisjoint
from kgcnn.layers.gather import GatherNodes
from kgcnn.layers.pooling import PoolingNodes
from kgcnn.layers.aggr import AggregateLocalEdges

ns = ks.layers.Input(shape=(None, 1), dtype="float32")  # Padded node features.
e_idx = ks.layers.Input(shape=(None, 2), dtype="int64")  # Padded edge indices.
total_n = ks.layers.Input(shape=(), dtype="int64")  # Number of nodes per graph (or a mask).
total_e = ks.layers.Input(shape=(), dtype="int64")  # Number of edges per graph (or a mask).

# The model is built with padded input, which is first cast to a disjoint graph.
n, idx, batch_id, _, _, _, _, _ = CastBatchedIndicesToDisjoint()([ns, e_idx, total_n, total_e])
n_in_out = GatherNodes()([n, idx])
node_messages = ks.layers.Dense(64, activation='relu')(n_in_out)
node_updates = AggregateLocalEdges()([n, node_messages, idx])
n_node_updates = ks.layers.Concatenate()([n, node_updates])
n_embedding = ks.layers.Dense(1)(n_node_updates)
g_embedding = PoolingNodes()([total_n, n_embedding, batch_id])

message_passing = ks.models.Model(inputs=[ns, e_idx, total_n, total_e], outputs=g_embedding)
```
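Conceptually, `CastBatchedIndicesToDisjoint` converts the padded batch into a disjoint representation: the valid nodes of all graphs are stacked into one node list, the edge indices are shifted by the node offset of their graph, and a batch ID assigns each node to its graph. The NumPy sketch below illustrates this idea under those assumptions. `padded_to_disjoint` is a hypothetical helper, not the `kgcnn` implementation, which runs inside the model and returns further outputs (eight in total in the cell above).

```python
import numpy as np

def padded_to_disjoint(nodes, edge_indices, node_counts, edge_counts):
    """Hypothetical sketch: flatten a padded batch into one disjoint graph.

    nodes: (batch, max_nodes, features), edge_indices: (batch, max_edges, 2),
    node_counts / edge_counts: (batch,) number of valid entries per graph.
    """
    disjoint_nodes, disjoint_indices = [], []
    node_batch_id, edge_batch_id = [], []
    offset = 0
    for g, (n_count, e_count) in enumerate(zip(node_counts, edge_counts)):
        # Keep only the valid (non-padded) nodes and edges of graph g.
        disjoint_nodes.append(nodes[g, :n_count])
        # Shift local node indices so they point into the stacked node list.
        disjoint_indices.append(edge_indices[g, :e_count] + offset)
        node_batch_id.append(np.full(n_count, g))
        edge_batch_id.append(np.full(e_count, g))
        offset += n_count
    return (np.concatenate(disjoint_nodes), np.concatenate(disjoint_indices),
            np.concatenate(node_batch_id), np.concatenate(edge_batch_id))


# Same padded example as used with the model further down this page.
nodes = np.array([[[1.], [2.]], [[1.], [0.]], [[2.], [0.]]])
edges = np.array([[[0, 1], [1, 0], [1, 1]], [[0, 0], [0, 0], [0, 0]], [[0, 0], [0, 0], [0, 0]]])
n, idx, batch_id, edge_batch_id = padded_to_disjoint(nodes, edges, np.array([2, 1, 1]), np.array([3, 1, 1]))
print(n.ravel())      # [1. 2. 1. 2.]
print(idx.tolist())   # [[0, 1], [1, 0], [1, 1], [2, 2], [3, 3]]
print(batch_id)       # [0 0 1 2]
```

Working on one disjoint graph lets gather and aggregation use plain index operations without handling padding in every layer.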

## Subclassing Model

A model can also be constructed by subclassing `keras.models.Model`, in which case the `call` method must be implemented.

```python
# Uses `ks` and the kgcnn layers imported in the cells above.
class MessagePassingModel(ks.models.Model):

    def __init__(self):
        super().__init__()
        # Create all layers once; they are reused in every call.
        self._layer_casting = CastBatchedIndicesToDisjoint()
        self._layer_gather_nodes = GatherNodes()
        self._layer_dense = ks.layers.Dense(64, activation='relu')
        self._layer_aggregate_edges = AggregateLocalEdges()
        self._layer_cat = ks.layers.Concatenate(axis=-1)
        self._layer_dense_last = ks.layers.Dense(1)
        self._layer_pool_nodes = PoolingNodes()

    def build(self, input_shape):
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        # Cast padded input to a disjoint graph; the 7th output is the node count per graph.
        n, idx, batch_id, _, _, _, total_n, _ = self._layer_casting(inputs)
        n_in_out = self._layer_gather_nodes([n, idx])
        node_messages = self._layer_dense(n_in_out)
        node_updates = self._layer_aggregate_edges([n, node_messages, idx])
        n_node_updates = self._layer_cat([n, node_updates])
        n_embedding = self._layer_dense_last(n_node_updates)
        g_embedding = self._layer_pool_nodes([total_n, n_embedding, batch_id])
        return g_embedding

message_passing_2 = MessagePassingModel()
```
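The `call` method chains gather, dense, aggregate, concatenate, dense and pooling steps on the disjoint tensors. The following NumPy sketch mirrors that sequence with random weights so the data flow can be inspected outside Keras. It is an illustration, not the `kgcnn` code: the function name is hypothetical, and it assumes that `GatherNodes` concatenates the features of both nodes of an edge, that `AggregateLocalEdges` sums messages onto the node in the first index column, and that `PoolingNodes` sums over the nodes of each graph.

```python
import numpy as np

rng = np.random.default_rng(0)

def relu(x):
    return np.maximum(x, 0.0)

def message_passing_forward(n, idx, batch_id, num_graphs, units=64):
    """Hypothetical NumPy sketch of MessagePassingModel.call on disjoint tensors."""
    f = n.shape[1]
    # Random weights and zero biases (Keras Dense also starts biases at zero).
    w1, b1 = rng.normal(size=(2 * f, units)), np.zeros(units)
    w2, b2 = rng.normal(size=(f + units, 1)), np.zeros(1)
    # Gather: features of both nodes of every edge, concatenated (assumption).
    n_in_out = np.concatenate([n[idx[:, 0]], n[idx[:, 1]]], axis=-1)
    # Dense(64, relu): one message per edge.
    messages = relu(n_in_out @ w1 + b1)
    # Aggregate: sum messages onto the node in index column 0 (assumption).
    updates = np.zeros((n.shape[0], units))
    np.add.at(updates, idx[:, 0], messages)
    # Concatenate + Dense(1): one embedding per node.
    n_embedding = np.concatenate([n, updates], axis=-1) @ w2 + b2
    # Pool: sum node embeddings per graph (assumption).
    g_embedding = np.zeros((num_graphs, 1))
    np.add.at(g_embedding, batch_id, n_embedding)
    return g_embedding


# Disjoint form of the padded example used further down this page.
n = np.array([[1.], [2.], [1.], [2.]])
idx = np.array([[0, 1], [1, 0], [1, 1], [2, 2], [3, 3]])
batch_id = np.array([0, 0, 1, 2])
out = message_passing_forward(n, idx, batch_id, num_graphs=3)
print(out.shape)  # (3, 1)
```

Because the biases start at zero and ReLU is positively homogeneous, doubling all features of a graph doubles its output. This is why the single-node graph with feature 2 gets exactly twice the output of the single-node graph with feature 1, which is also visible in the `predict` output of the padded example.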

## Templates

Layers can also be subclassed to build a GNN. For example, when subclassing the message passing base layer, only `message_function` and `update_nodes` must be implemented.
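The template idea can be illustrated without `kgcnn`: a base class fixes the order of the steps in `call`, and a subclass only fills in the two hooks. `MessagePassingTemplate` and `SimpleSumGNN` are hypothetical names, and the hook signatures below are assumptions that need not match the actual `kgcnn` base layer.

```python
import numpy as np

class MessagePassingTemplate:
    """Hypothetical template base class (not the kgcnn API)."""

    def message_function(self, n_receive, n_send, edges):
        raise NotImplementedError

    def update_nodes(self, nodes, aggregated):
        raise NotImplementedError

    def call(self, nodes, edges, idx):
        # Gather receiving and sending node features for every edge.
        n_receive, n_send = nodes[idx[:, 0]], nodes[idx[:, 1]]
        messages = self.message_function(n_receive, n_send, edges)
        # Sum messages onto the receiving node (column 0 of idx).
        aggregated = np.zeros((nodes.shape[0],) + messages.shape[1:])
        np.add.at(aggregated, idx[:, 0], messages)
        return self.update_nodes(nodes, aggregated)


class SimpleSumGNN(MessagePassingTemplate):
    def message_function(self, n_receive, n_send, edges):
        # Message: sending node features scaled by a scalar edge weight.
        return n_send * edges

    def update_nodes(self, nodes, aggregated):
        # Residual update of the node features.
        return nodes + aggregated


nodes = np.array([[1.0], [2.0], [3.0]])
idx = np.array([[0, 1], [0, 2], [1, 0]])
edges = np.array([[1.0], [0.5], [2.0]])
result = SimpleSumGNN().call(nodes, edges, idx)
print(result.ravel().tolist())  # [4.5, 4.0, 3.0]
```

Keeping gather and aggregation in the base class means every subclass shares the same indexing logic and only defines what a message is and how nodes are updated.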

## Loading options

There are many options to load data into a Keras model, depending on the size and location of the data. The loading methods can differ in speed and utility. For more examples, see https://github.com/aimat-lab/gcnn_keras/blob/master/notebooks/tutorial_model_loading_options.ipynb.

##### 1. Padded Tensor

The simplest way to pass tensors to the model is to pad them to tensors of the same size. The model then requires additional information about the padding, either a length tensor or a mask.

```python
from keras_core import ops
example_nodes = ops.convert_to_tensor([[[1.], [2.]], [[1.0], [0.0]], [[2.0], [0.0]]])
example_indices = ops.convert_to_tensor([[[0, 1], [1, 0], [1, 1]], [[0, 0], [0, 0], [0, 0]], [[0, 0], [0, 0], [0, 0]]], dtype="int64")
example_total_nodes = ops.convert_to_tensor([2, 1, 1], dtype="int64")
example_total_edges = ops.convert_to_tensor([3, 1, 1], dtype="int64")
```
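The hand-written padded arrays above can also be generated from per-graph lists. The sketch below uses plain NumPy; `pad_graphs` is a hypothetical helper, not part of `kgcnn`. It returns the padded tensors together with the node and edge counts, and shows how a boolean mask can be derived from the counts as an alternative encoding of the same padding information. The exact mask format expected by `kgcnn` layers is not covered here.

```python
import numpy as np

def pad_graphs(node_lists, edge_lists):
    """Hypothetical helper: pad per-graph nodes and edge indices to common sizes."""
    node_counts = np.array([len(n) for n in node_lists], dtype="int64")
    edge_counts = np.array([len(e) for e in edge_lists], dtype="int64")
    num_features = len(node_lists[0][0])  # Assumes the first graph has at least one node.
    nodes = np.zeros((len(node_lists), node_counts.max(), num_features), dtype="float32")
    indices = np.zeros((len(edge_lists), edge_counts.max(), 2), dtype="int64")
    for g, (n, e) in enumerate(zip(node_lists, edge_lists)):
        # Fill the valid part; the remainder stays zero-padded.
        nodes[g, :len(n)] = n
        indices[g, :len(e)] = e
    return nodes, indices, node_counts, edge_counts


nodes, indices, node_counts, edge_counts = pad_graphs(
    [[[1.0], [2.0]], [[1.0]], [[2.0]]],
    [[[0, 1], [1, 0], [1, 1]], [[0, 0]], [[0, 0]]],
)
# Boolean node mask: True for real nodes, False for padding.
node_mask = np.arange(nodes.shape[1])[None, :] < node_counts[:, None]
print(node_counts.tolist(), edge_counts.tolist())  # [2, 1, 1] [3, 1, 1]
print(node_mask.tolist())  # [[True, True], [True, False], [True, False]]
```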

```python
message_passing.predict([example_nodes, example_indices, example_total_nodes, example_total_edges])
```

```
array([[-0.10487814],
       [-0.1265333 ],
       [-0.2530666 ]], dtype=float32)
```

##### 2. Ragged input

> **NOTE**: You can find this page as a Jupyter notebook at https://github.com/aimat-lab/gcnn_keras/tree/master/docs/source