GraphNetLib is a library for building graph networks with TensorFlow 2.x.
Graph networks operate on graphs: they take a graph as input and output a graph as well. A graph is a structure with node (V), edge (E) and global (u) features, where each edge also stores the indices of its sender and receiver nodes. To learn more about graph networks, see the paper "Relational inductive biases, deep learning, and graph networks" (Battaglia et al., 2018).
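The graph-in, graph-out computation can be illustrated with a minimal NumPy sketch of one full graph-network block as described in that paper. The sketch updates the edges from their endpoints and the globals, sums the updated edges onto their receiving nodes, updates the nodes, and finally updates the globals from the summed edges and nodes. The helpers `gn_block` and `make_layer` are hypothetical and not part of GraphNetLib. Sum aggregation and single-layer ReLU update functions are assumptions chosen for brevity, and GraphNetLib's processors may combine their inputs differently.

```python
import numpy as np


def make_layer(in_dim, out_dim, rng):
    """Hypothetical update function: one random linear layer followed by ReLU."""
    w = rng.normal(size=(in_dim, out_dim)) / np.sqrt(in_dim)
    return lambda x: np.maximum(x @ w, 0.0)


def gn_block(nodes, edges, globals_, senders, receivers, phi_e, phi_v, phi_u):
    """One full graph-network block for a single graph, using sum aggregation."""
    n_nodes, n_edges = nodes.shape[0], edges.shape[0]
    # Edge update: each edge sees its features, its receiver, its sender and the globals.
    g_e = np.repeat(globals_[None, :], n_edges, axis=0)
    new_edges = phi_e(np.concatenate(
        [edges, nodes[receivers], nodes[senders], g_e], axis=1))
    # Aggregate the updated edges onto their receiving nodes (sum per receiver).
    agg = np.zeros((n_nodes, new_edges.shape[1]))
    np.add.at(agg, receivers, new_edges)
    # Node update: aggregated incoming edges, the node's own features and the globals.
    g_v = np.repeat(globals_[None, :], n_nodes, axis=0)
    new_nodes = phi_v(np.concatenate([agg, nodes, g_v], axis=1))
    # Global update: summed edges, summed nodes and the previous globals.
    new_globals = phi_u(np.concatenate(
        [new_edges.sum(axis=0), new_nodes.sum(axis=0), globals_]))
    return new_nodes, new_edges, new_globals


rng = np.random.default_rng(0)
nodes = rng.normal(size=(4, 3))       # 4 nodes with 3 features each
edges = rng.normal(size=(5, 2))       # 5 edges with 2 features each
globals_ = rng.normal(size=2)         # 2 global features
senders = np.array([0, 1, 2, 3, 0])
receivers = np.array([1, 2, 3, 0, 2])

phi_e = make_layer(2 + 3 + 3 + 2, 4, rng)  # edge, receiver, sender, globals
phi_v = make_layer(4 + 3 + 2, 5, rng)      # aggregated edges, node, globals
phi_u = make_layer(4 + 5 + 2, 6, rng)      # summed edges, summed nodes, globals
new_nodes, new_edges, new_globals = gn_block(
    nodes, edges, globals_, senders, receivers, phi_e, phi_v, phi_u)
print(new_nodes.shape, new_edges.shape, new_globals.shape)
```

The output is again a graph with the same connectivity (`senders`, `receivers`) but new node, edge and global features. This is why such blocks can be stacked or applied repeatedly.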
GraphNetLib can be installed using pip. This installation is compatible with Linux/Mac OS X and Python 3.6+.
Please note that the package is not currently registered on PyPI and must be installed directly from this repository:
python3 -m pip install git+https://github.com/Rufaim/graph_net_lib
The package supports both the CPU and GPU versions of TensorFlow.
The following small example creates graph data and processes it with a graph network.
import tensorflow as tf
import graphnetlib as gnl
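# Toy input data for a single graph; replace it with your own features.
# Shapes follow the Graph Nets convention (an assumption for this example).
# Node features: one row per node, shape [n_nodes, node_dim].
nodes_ = tf.random.normal([4, 8])
# Edge features: one row per edge, shape [n_edges, edge_dim].
edges_ = tf.random.normal([3, 6])
# senders and receivers are 1-D integer tensors holding, for each edge,
# the index of its sending node and of its receiving node, respectively.
senders_ = tf.constant([0, 1, 2])
receivers_ = tf.constant([1, 2, 3])
# Global features: one row per graph, shape [n_graphs, global_dim].
globals_ = tf.random.normal([1, 4])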
# Create the Graph structure
graph_data = gnl.Graph(nodes_,
                       edges_,
                       globals_,
                       receivers_,
                       senders_,
                       tf.shape(nodes_)[:1],  # number of nodes in each graph of the batch (here a single graph)
                       tf.shape(edges_)[:1])  # number of edges in each graph of the batch (here a single graph)
# Create the graph network.
# Every processor is configured explicitly, which gives the user full control
# over its architecture. The name node_proc avoids shadowing NumPy's usual "np" alias.
node_proc = gnl.NodeProcessor(tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation=tf.nn.relu),
    tf.keras.layers.Dense(32, activation=tf.nn.relu),
    tf.keras.layers.LayerNormalization()
]))
edge_proc = gnl.EdgeProcessor(tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation=tf.nn.relu),
    tf.keras.layers.Dense(32, activation=tf.nn.relu),
    tf.keras.layers.LayerNormalization()
]))
global_proc = gnl.GlobalProcessor(tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation=tf.nn.relu),
    tf.keras.layers.Dense(32, activation=tf.nn.relu),
    tf.keras.layers.LayerNormalization()
]))
graph_network = gnl.GraphNetwork(node_processor=node_proc,
                                 edge_processor=edge_proc,
                                 global_processor=global_proc)
# Process the graph with the graph network
processed_graph_data = graph_network(graph_data)
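The last two arguments of `gnl.Graph` in the example above (`tf.shape(nodes_)[:1]` and `tf.shape(edges_)[:1]`) give the number of nodes and edges per graph. With them, several graphs can be packed into one batch. The following NumPy sketch shows one way to do that packing, assuming the Graph Nets convention that a batch concatenates all node and edge rows, stacks one globals row per graph, and shifts each graph's sender and receiver indices by the number of nodes that precede it. `concat_graphs` is a hypothetical helper, not a GraphNetLib function. Whether `gnl.Graph` accepts NumPy arrays directly or needs them converted with `tf.constant` first is not stated here.

```python
import numpy as np


def concat_graphs(graphs):
    """Hypothetical helper: merge single graphs (dicts of arrays) into one batch."""
    nodes, edges, globals_ = [], [], []
    senders, receivers, n_node, n_edge = [], [], [], []
    offset = 0  # number of nodes contributed by the previous graphs
    for g in graphs:
        nodes.append(g["nodes"])
        edges.append(g["edges"])
        globals_.append(g["globals"])
        # Shift indices so they point into the concatenated node array.
        senders.append(g["senders"] + offset)
        receivers.append(g["receivers"] + offset)
        n_node.append(len(g["nodes"]))
        n_edge.append(len(g["edges"]))
        offset += len(g["nodes"])
    return dict(nodes=np.concatenate(nodes),
                edges=np.concatenate(edges),
                globals=np.stack(globals_),  # one row per graph
                senders=np.concatenate(senders),
                receivers=np.concatenate(receivers),
                n_node=np.array(n_node),
                n_edge=np.array(n_edge))


g1 = dict(nodes=np.zeros((3, 2)), edges=np.zeros((2, 1)), globals=np.zeros(4),
          senders=np.array([0, 1]), receivers=np.array([1, 2]))
g2 = dict(nodes=np.ones((2, 2)), edges=np.ones((1, 1)), globals=np.ones(4),
          senders=np.array([1]), receivers=np.array([0]))
batch = concat_graphs([g1, g2])
print(batch["n_node"], batch["senders"], batch["receivers"])
```

In this batch, the single edge of the second graph goes from its node 1 to its node 0. That edge becomes sender 4 and receiver 3, because the first graph already occupies node indices 0 to 2.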
The repository includes three demos showing how to use the package. They are similar to the demos of DeepMind's Graph Nets library.
The "shortest path demo" shows how to train a graph network to label the nodes and edges that lie on the shortest path between two nodes. The graphs are generated randomly, but every generated graph is guaranteed to be connected.
Over a sequence of message-passing steps, the model refines its prediction of the shortest path.
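The training targets for such a task can be produced with an ordinary shortest-path search. Below is a minimal sketch, assuming an unweighted, undirected graph given as a list of (sender, receiver) pairs and using breadth-first search. The demo's actual graph generator and distance metric are not described here and may differ, for example by using weighted edges. The function `shortest_path_labels` is hypothetical. It also raises an error when the end node is unreachable, which is the situation the connectivity guarantee rules out.

```python
from collections import deque


def shortest_path_labels(n_nodes, edges, start, end):
    """Hypothetical target builder: 1 for nodes/edges on one shortest path, else 0."""
    adj = {v: [] for v in range(n_nodes)}
    for i, (s, r) in enumerate(edges):
        adj[s].append((r, i))
        adj[r].append((s, i))  # treat every edge as undirected
    parent = {start: None}  # node -> (previous node, edge index) on the BFS tree
    queue = deque([start])
    while queue:
        v = queue.popleft()
        if v == end:
            break
        for u, i in adj[v]:
            if u not in parent:
                parent[u] = (v, i)
                queue.append(u)
    if end not in parent:
        raise ValueError("end is not reachable from start: graph is not connected")
    node_labels = [0] * n_nodes
    edge_labels = [0] * len(edges)
    v = end
    node_labels[v] = 1
    # Walk back from end to start, marking every node and edge on the path.
    while parent[v] is not None:
        prev, i = parent[v]
        edge_labels[i] = 1
        node_labels[prev] = 1
        v = prev
    return node_labels, edge_labels


edges = [(0, 1), (1, 2), (2, 3), (0, 4), (4, 3)]
node_labels, edge_labels = shortest_path_labels(5, edges, start=0, end=3)
print(node_labels, edge_labels)
```

Here the path 0 → 4 → 3 (two hops) is shorter than 0 → 1 → 2 → 3, so only nodes 0, 4 and 3 and edges 3 and 4 are labeled.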
The "physics demo" predicts the dynamics of randomly generated mass-spring systems. A graph network is trained to predict how the system evolves over a fixed timestep. Its predictions are then fed back into the network as input to roll out the full dynamics of the system.
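The rollout idea is independent of the model: each predicted state becomes the next input. In the sketch below, an analytic semi-implicit Euler step for a mass-spring system stands in for the trained graph network. `spring_step`, `rollout` and all physical constants are hypothetical, and the demo's actual state representation (positions, velocities, which nodes are fixed) is not described here.

```python
import numpy as np


def spring_step(pos, vel, springs, k=10.0, rest=1.0, mass=1.0, dt=0.01):
    """Stand-in for the trained network: one semi-implicit Euler step."""
    force = np.zeros_like(pos)
    for a, b in springs:
        d = pos[b] - pos[a]
        length = np.linalg.norm(d)
        f = k * (length - rest) * d / length  # Hooke's law along the spring axis
        force[a] += f
        force[b] -= f
    vel = vel + dt * force / mass
    pos = pos + dt * vel
    return pos, vel


def rollout(step_fn, pos, vel, n_steps, **kwargs):
    """Feed each prediction back in as the next input; return all positions."""
    trajectory = [pos]
    for _ in range(n_steps):
        pos, vel = step_fn(pos, vel, **kwargs)
        trajectory.append(pos)
    return np.stack(trajectory)


pos0 = np.array([[0.0, 0.0], [1.5, 0.0]])  # two masses, spring stretched by 0.5
vel0 = np.zeros_like(pos0)
traj = rollout(spring_step, pos0, vel0, n_steps=100, springs=[(0, 1)])
separation = np.linalg.norm(traj[:, 1] - traj[:, 0], axis=1)
print(traj.shape, separation.min())
```

With a learned model, errors from each step feed into the next one, so long rollouts are a stricter test than single-step accuracy.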
In the "sort demo", a graph network is trained to sort a list of random numbers.
The network learns to classify each edge by whether its sender node (columns in the figure) comes before its receiver node (rows) in the sorted list.
Figure: true connections (left) and predicted connections (right).
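One way to build the sort demo's edge targets is sketched below in NumPy. It covers a fully connected graph over the list and lays the labels out as a matrix with receivers as rows and senders as columns, matching the figure. This sketch assumes that "comes before" means *immediately* before, which yields a linked list with exactly one positive edge per element except the last. The demo may define it differently, for example by marking all earlier elements. `sort_targets` is a hypothetical helper.

```python
import numpy as np


def sort_targets(values):
    """Hypothetical edge labels for sorting over a fully connected graph.

    targets[r, s] == 1 iff node s is immediately followed by node r in the
    sorted order (rows = receivers, columns = senders).
    """
    order = np.argsort(values)
    n = len(values)
    targets = np.zeros((n, n), dtype=int)
    for sender, receiver in zip(order[:-1], order[1:]):
        targets[receiver, sender] = 1
    return targets


targets = sort_targets([0.3, 0.1, 0.2])
print(targets)
```

For the list [0.3, 0.1, 0.2], the sorted order of indices is 1, 2, 0. The positive edges are therefore 1 → 2 and 2 → 0.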
This implementation is based on adapted code from DeepMind's Graph Nets library. All rights to the original implementation belong to DeepMind.