In [1]:
import sys

# Add the parent directory to the module search path so that packages located
# there can be imported by the cells below.
# NOTE(review): ".." is resolved against the kernel's working directory, so
# this presumably assumes the notebook is launched from its own folder, one
# level below the package root - confirm.
# The membership guard keeps this cell idempotent: re-running it does not
# append duplicate ".." entries to sys.path.
if ".." not in sys.path:
    sys.path.append("..")

# Model example

## 1. Setup
Import the packages needed for this tutorial.

In [2]:
import graphtrack as gt 
import numpy as np

import logging
# Globally suppress every log record of severity WARNING and below, so the
# cell outputs of this tutorial only contain the explicit prints.
# NOTE(review): presumably aimed at warnings emitted while building or calling
# the model - confirm. Be aware that this also hides warnings coming from any
# other library; call logging.disable(logging.NOTSET) to re-enable them.
logging.disable(logging.WARNING) 

## 2. Define the neural network model

The neural network architecture used is a message-passing graph neural network. We create this model by calling the function `gt.gnns.mpGraphNet()`.

In [3]:
# Hyperparameters of the message passing graph network, gathered in one place
# and forwarded to the constructor as keyword arguments.
mp_graphnet_kwargs = dict(
    dense_layer_dimensions=(32, 64, 96),
    base_layer_dimensions=(96, 96),
    number_of_outputs=2,
    output_activation="sigmoid",
)

model = gt.gnns.mpGraphNet(**mp_graphnet_kwargs)

## 3. Test model

Now, we evaluate the model on randomly generated node features, edge features, and a connection matrix.

In [4]:
_printer = "{0} shape: {1}"

# Seed NumPy's global random generator so that "Restart Kernel -> Run All"
# reproduces exactly the same random inputs.
SEED = 0
np.random.seed(SEED)

# Sizes of the synthetic batch of graphs used to exercise the model.
BATCH_SIZE = 8
N_NODES = 5
N_NODE_FEATURES = 64
N_EDGES = 10
N_EDGE_FEATURES = 32

# Node features: random matrix of shape
# (batch size, number of nodes, number of node features).
nodes = np.random.rand(BATCH_SIZE, N_NODES, N_NODE_FEATURES)
print(_printer.format("Nodes", nodes.shape))

# Connection matrix: random matrix of shape (batch size, number of edges, 2).
# First column is the sender, and second column is the receiver. Entries are
# node indices, so the (exclusive) upper bound has to be the number of nodes.
edges = np.random.randint(low=0, high=N_NODES, size=(BATCH_SIZE, N_EDGES, 2))
print(_printer.format("Edges", edges.shape))

# Edge features: random matrix of shape
# (batch size, number of edges, number of edge features).
edge_f = np.random.rand(BATCH_SIZE, N_EDGES, N_EDGE_FEATURES)
print(_printer.format("Edge features", edge_f.shape))

Nodes shape: (8, 5, 64)
Edges shape: (8, 10, 2)
Edge features shape: (8, 10, 32)


In [5]:
# Run the model once on the random inputs. The inputs are handed over as a
# single list ordered [node features, edge features, connection matrix], and
# the call yields two results that are unpacked below.
model_inputs = [nodes, edge_f, edges]
output_nodes, output_edge_f = model(model_inputs)

# Report the shape of each returned output.
for label, output in (
    ("Output nodes", output_nodes),
    ("Output edge features", output_edge_f),
):
    print(_printer.format(label, output.shape))

Output nodes shape: (8, 5, 2)
Output edge features shape: (8, 10, 2)
