In [1]:
#@title #### (Imports)

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from graph_nets import blocks
from graph_nets import graphs
from graph_nets import modules
from graph_nets import utils_np
from graph_nets import utils_tf

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import sonnet as snt
import tensorflow as tf


For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.



In [7]:
# Names of the fields of a graph_nets `GraphsTuple`.
# NOTE(review): these duplicate the constants exported by
# `graph_nets.graphs` (e.g. `graphs.ALL_FIELDS`); consider aliasing them.
NODES, EDGES = "nodes", "edges"
RECEIVERS, SENDERS = "receivers", "senders"
GLOBALS = "globals"
N_NODE, N_EDGE = "n_node", "n_edge"

# Every field, in the order the printed `GraphsTuple` reprs list them.
ALL_FIELDS = (NODES, EDGES, RECEIVERS, SENDERS, GLOBALS, N_NODE, N_EDGE)

In [2]:
# Two small example graphs, each described as a "data dict" that is later fed
# to `utils_np.data_dicts_to_graphs_tuple`. Edge k of a graph goes from node
# `senders[k]` to node `receivers[k]`. Node/edge features have one row per
# node/edge; globals are a single feature vector per graph. `n_node` and
# `n_edge` are not given explicitly here.

# ----- Graph 0: 5 nodes, 6 edges ---------------------------------------------
globals_0 = [1., 2., 3.]

nodes_0 = [
    [10., 20., 30.],  # node 0
    [11., 21., 31.],  # node 1
    [12., 22., 32.],  # node 2
    [13., 23., 33.],  # node 3
    [14., 24., 34.],  # node 4
]

edges_0 = [
    [100., 200.],  # edge 0: 0 -> 1
    [101., 201.],  # edge 1: 1 -> 2
    [102., 202.],  # edge 2: 1 -> 3
    [103., 203.],  # edge 3: 2 -> 0
    [104., 204.],  # edge 4: 2 -> 3
    [105., 205.],  # edge 5: 3 -> 4
]

# Column k of these two lists is the (sender, receiver) pair of edge k.
senders_0 = [0, 1, 1, 2, 2, 3]
receivers_0 = [1, 2, 3, 0, 3, 4]

# ----- Graph 1: 2 nodes, 4 edges ---------------------------------------------
globals_1 = [1001., 1002., 1003.]

nodes_1 = [
    [1010., 1020., 1030.],  # node 0
    [1011., 1021., 1031.],  # node 1
]

edges_1 = [
    [1100., 1200.],  # edge 0: 0 -> 0 (self-loop)
    [1101., 1201.],  # edge 1: 0 -> 1
    [1102., 1202.],  # edge 2: 1 -> 0
    [1103., 1203.],  # edge 3: 1 -> 0
]

senders_1 = [0, 0, 1, 1]
receivers_1 = [0, 1, 0, 0]

# ----- Bundle each graph into a data dict -------------------------------------
data_dict_0 = {
    "globals": globals_0,
    "nodes": nodes_0,
    "edges": edges_0,
    "senders": senders_0,
    "receivers": receivers_0,
}
data_dict_1 = {
    "globals": globals_1,
    "nodes": nodes_1,
    "edges": edges_1,
    "senders": senders_1,
    "receivers": receivers_1,
}

In [3]:
# Batch both example graphs into a single `GraphsTuple`. As the printed output
# shows, rows of per-graph features are stacked and the sender/receiver indices
# of later graphs are shifted by the node count of the graphs before them.
data_dict_list = [
    data_dict_0,
    data_dict_1,
]
graphs_tuple = utils_np.data_dicts_to_graphs_tuple(data_dict_list)

In [8]:
def print_graphs_tuple(graphs_tuple, fields=None):
    """Print the dtypes, shapes and raw contents of a `GraphsTuple`.

    Args:
      graphs_tuple: a graph_nets `GraphsTuple`. Each field is expected to be
        either `None` or an array-like exposing `.dtype` and `.shape`
        (e.g. numpy arrays) -- TODO confirm for other inputs.
      fields: iterable of field names to include in the dtype/shape
        summaries. Defaults to the notebook-level `ALL_FIELDS`.
    """
    if fields is None:
        fields = ALL_FIELDS

    # A single `fields` set is used for all three summaries. The earlier
    # version mixed `ALL_FIELDS` and `graphs.ALL_FIELDS`.
    graph_dtypes = graphs_tuple.map(
        lambda v: None if v is None else tf.as_dtype(v.dtype), fields)
    graph_shapes = graphs_tuple.map(
        lambda v: None if v is None else list(v.shape), fields)
    print("graph_dtypes:", graph_dtypes)
    print("graph_shapes:", graph_shapes)

    print("Shapes of `GraphsTuple`'s fields:")
    print(graphs_tuple.map(lambda x: x if x is None else x.shape, fields=fields))

    print("\nData contained in `GraphsTuple`'s fields:")
    # The raw data section always lists every field, in this fixed order,
    # independent of `fields`.
    for name in ("globals", "nodes", "edges", "senders", "receivers",
                 "n_node", "n_edge"):
        print("{}:\n{}".format(name, getattr(graphs_tuple, name)))

In [9]:
# Inspect the batched graphs: dtypes, shapes and raw field values.
print_graphs_tuple(graphs_tuple=graphs_tuple)

graph_dtypes: GraphsTuple(nodes=tf.float64, edges=tf.float64, receivers=tf.int32, senders=tf.int32, globals=tf.float64, n_node=tf.int32, n_edge=tf.int32)
graph_shapes: GraphsTuple(nodes=[7, 3], edges=[10, 2], receivers=[10], senders=[10], globals=[2, 3], n_node=[2], n_edge=[2])
Shapes of `GraphsTuple`'s fields:
GraphsTuple(nodes=(7, 3), edges=(10, 2), receivers=(10,), senders=(10,), globals=(2, 3), n_node=(2,), n_edge=(2,))

Data contained in `GraphsTuple`'s fields:
globals:
[[1.000e+00 2.000e+00 3.000e+00]
 [1.001e+03 1.002e+03 1.003e+03]]
nodes:
[[  10.   20.   30.]
 [  11.   21.   31.]
 [  12.   22.   32.]
 [  13.   23.   33.]
 [  14.   24.   34.]
 [1010. 1020. 1030.]
 [1011. 1021. 1031.]]
edges:
[[ 100.  200.]
 [ 101.  201.]
 [ 102.  202.]
 [ 103.  203.]
 [ 104.  204.]
 [ 105.  205.]
 [1100. 1200.]
 [1101. 1201.]
 [1102. 1202.]
 [1103. 1203.]]
senders:
[0 1 1 2 2 3 5 5 6 6]
receivers:
[1 2 3 0 3 4 5 6 5 5]
n_node:
[5 2]
n_edge:
[6 4]
