In [1]:
import networkx as nx
import matplotlib.pyplot as plt
import tensorflow_gnn as tfgnn
import numpy as np
import tensorflow as tf

In [None]:
# Synthetic dataset of (label, nx_graph) pairs: 100 random trees (label 1)
# followed by 100 Erdos-Renyi graphs (label 0).
graphs = []

for _ in range(100):
    # A Poisson(10) draw can be 0 and nx.random_tree rejects an empty node
    # count, so clamp the tree size to at least one node.
    n_tree_nodes = max(1, np.random.poisson(10))
    g = nx.random_tree(n_tree_nodes)
    graphs.append((1, g))

for _ in range(100):
    # Edge probability 0.2; an empty graph is acceptable here.
    g = nx.erdos_renyi_graph(np.random.poisson(10), 0.2)
    graphs.append((0, g))

In [None]:
def nx_graph_to_tfgnn(graph, label) -> tfgnn.GraphTensor:
    """Convert a NetworkX graph into a single-component ``tfgnn.GraphTensor``.

    Args:
        graph: NetworkX graph. The first two entries of every edge tuple are
            used directly as node indices, so node ids are assumed to be the
            integers ``0..n-1``.
        label: Value stored as the one-element context feature ``"label"``.

    Returns:
        A GraphTensor with node set ``"normal"`` (feature ``"size"``: random
        normal values of shape ``[n_nodes, 3]``), edge set ``"connects"`` and
        the context feature ``"label"``.
    """
    edge_list = list(graph.edges)
    n_nodes = len(graph.nodes)
    n_edges = len(edge_list)

    # Build the index tensors with an explicit integer dtype: for a graph
    # without edges the lists are empty, and TensorFlow would infer float32
    # for an empty list, which is not a valid adjacency index type.
    source_ids = tf.constant([edge[0] for edge in edge_list], dtype=tf.int32)
    target_ids = tf.constant([edge[1] for edge in edge_list], dtype=tf.int32)

    tf_g = tfgnn.GraphTensor.from_pieces(
        node_sets={
            "normal": tfgnn.NodeSet.from_fields(
                sizes=tf.constant([n_nodes]),
                features={"size": tf.random.normal((n_nodes, 3))},
            )
        },
        edge_sets={
            "connects": tfgnn.EdgeSet.from_fields(
                sizes=tf.constant([n_edges]),
                adjacency=tfgnn.Adjacency.from_indices(
                    target=("normal", target_ids),
                    source=("normal", source_ids),
                ),
            )
        },
        context=tfgnn.Context.from_fields(features={"label": [label]}),
    )

    return tf_g

In [None]:
# Convert the first dataset entry; each element of `graphs` is a
# (label, nx_graph) pair, hence index 1 for the graph and 0 for the label.
g = nx_graph_to_tfgnn(graphs[0][1], graphs[0][0])

In [None]:
# Rebinds `g` to the merged result, so the unmerged tensor from the previous
# cell is no longer reachable under this name.
g = g.merge_batch_to_components()

In [None]:
import networkx as nx
import matplotlib.pyplot as plt
import tensorflow_gnn as tfgnn
from tensorflow_gnn import runner

In [None]:
# Exactly two example graphs, stored as (label, nx_graph): the cells below
# pair them with the two entries of `means` and with two context labels.
graphs = []

# A Poisson(10) draw can be 0 and nx.random_tree rejects an empty node count,
# so clamp the tree size to at least one node.
g = nx.random_tree(max(1, np.random.poisson(10)))
graphs.append(("tree", g))

# Erdos-Renyi graph with edge probability 0.2.
g = nx.erdos_renyi_graph(np.random.poisson(10), 0.2)
graphs.append(("erdos", g))

In [None]:
# Draw the first two graphs stacked vertically, one graph per subplot row.
fig, ax = plt.subplots(nrows=2, ncols=1)
for row in (0, 1):
    nx.draw(graphs[row][1], ax=ax[row])

In [None]:
# Node and edge counts per graph, each wrapped in its own one-element list.
node_set_lengths = [[nx_graph.number_of_nodes()] for _, nx_graph in graphs]
edge_set_lengths = [[nx_graph.number_of_edges()] for _, nx_graph in graphs]

In [None]:
# Edge endpoints per graph: entry 0 of every edge tuple goes to `sources`,
# entry 1 to `targets`.
sources = [[edge[0] for edge in nx_graph.edges] for _, nx_graph in graphs]
targets = [[edge[1] for edge in nx_graph.edges] for _, nx_graph in graphs]
# One value per graph, indexed by graph position when node features are drawn.
means = [3, 10]

In [None]:
# Display the per-graph values used when drawing node features.
means

[3, 10]

In [None]:
# Sample 10 values from a normal distribution with mean 0 and stddev 1.
tf.random.normal(shape=[10], mean=0, stddev=1)

<tf.Tensor: shape=(10,), dtype=float32, numpy=
array([ 0.20341861,  1.1126    ,  0.48681453, -0.02784214, -1.9175309 ,
       -0.98004955,  2.384279  , -0.6575413 , -1.2647614 , -0.522205  ],
      dtype=float32)>

In [None]:
# Pack two independent draws of 10 normal samples as the rows of one
# ragged constant.
tf.ragged.constant(
    [tf.random.normal(shape=[10], mean=0, stddev=1) for _ in range(2)]
)

In [None]:
# Batched GraphTensor holding every graph in `graphs`: sizes are given as one
# single-entry row per graph and all per-item data is ragged, one row per graph.
tf_g = tfgnn.GraphTensor.from_pieces(
    node_sets={
        "normal": tfgnn.NodeSet.from_fields(
            sizes=tf.constant(node_set_lengths),
            features={
                # One float per node, drawn from a normal distribution
                # centred on that graph's entry of `means`.
                tfgnn.HIDDEN_STATE: tf.ragged.constant(
                    [
                        np.random.normal(means[i], 1, size)
                        for i, size in enumerate(node_set_lengths)
                    ],
                    dtype=tf.float32,
                ),
                # Integer ids in [0, means[i]) for graph i.
                "occupation_id": tf.ragged.constant(
                    [
                        np.random.randint(0, means[i], size)
                        for i, size in enumerate(node_set_lengths)
                    ]
                ),
            },
        )
    },
    edge_sets={
        "connect": tfgnn.EdgeSet.from_fields(
            sizes=tf.constant(edge_set_lengths),
            adjacency=tfgnn.Adjacency.from_indices(
                target=("normal", tf.ragged.constant(targets)),
                source=("normal", tf.ragged.constant(sources)),
            ),
        )
    },
    context=tfgnn.Context.from_fields(
        # Take the labels from `graphs` instead of repeating them by hand, so
        # they cannot drift out of sync with the graphs they describe.
        features={"label": [[label] for label, _ in graphs]}
    ),
)

In [None]:
# Merge the batched graphs of `tf_g` into one GraphTensor; the result is kept
# under a new name so `tf_g` stays available.
tf_g_merged = tf_g.merge_batch_to_components()

In [None]:
# Graph update layer for node set "normal": messages over edge set "connect"
# go through a 64-unit ReLU dense layer, are sum-pooled at the SOURCE end of
# each edge, and the concatenated inputs pass through a 128-unit dense layer.
# Note that `out` is the layer itself, not its output.
out = tfgnn.keras.layers.GraphUpdate(
    node_sets={
        "normal": tfgnn.keras.layers.NodeSetUpdate(
            edge_set_inputs={
                "connect": tfgnn.keras.layers.SimpleConv(
                    tf.keras.layers.Dense(units=64, activation="relu"),
                    reduce_type="sum",
                    receiver_tag=tfgnn.SOURCE,
                )
            },
            next_state=tfgnn.keras.layers.NextStateFromConcat(
                tf.keras.layers.Dense(units=128)
            ),
        )
    }
)

In [None]:
# Apply the GraphUpdate layer stored in `out` to the merged graph.
# NOTE(review): the recorded output further down shows hidden_state on
# tf_g_merged with shape (24,), i.e. one scalar per node - confirm the
# convolution accepts features without a trailing feature dimension.
out(tf_g_merged)

In [None]:
# Build a stand-alone SimpleConv (64-unit ReLU message layer) and call it
# directly on the merged graph.
# NOTE(review): unlike the GraphUpdate cell, where the conv is keyed by the
# edge set "connect", no edge_set_name is passed in this call - confirm
# whether SimpleConv needs edge_set_name="connect" here.
tfgnn.keras.layers.SimpleConv(
    tf.keras.layers.Dense(64, "relu")
)(tf_g_merged)

In [None]:
from tensorflow_gnn.models import graph_sage

In [None]:
def set_initial_node_state(node_set, node_set_name):
    """Node-set callback that builds the initial state for node set "normal".

    Args:
        node_set: Node set piece; its "occupation_id" feature is read.
        node_set_name: Name of the node set being processed.

    Returns:
        The 32-dimensional embedding of "occupation_id" (passed through a
        single-input Concatenate) for node set "normal"; None for any other
        node set name.
    """
    if node_set_name != "normal":
        return None

    # Vocabulary of 10 ids, embedded into 32 dimensions.
    embed_occupation = tf.keras.layers.Embedding(10, 32)
    merge_features = tf.keras.layers.Concatenate()
    return merge_features([embed_occupation(node_set["occupation_id"])])

def set_context(context):
    """Context callback for ``tfgnn.keras.layers.MapFeatures`` (placeholder).

    The callback receives only the context piece, so there is no name to
    branch on. The previous body compared an undefined ``context_name`` and
    therefore raised NameError on every call; it also printed the context and
    bound the ``TextVectorization`` class without instantiating or using it.

    Args:
        context: Context graph piece; its ``features`` mapping is inspected.

    Returns:
        None in every case for now, which MapFeatures documents as a request
        to leave the context features as they are.
    """
    if "label" not in context.features:
        return None

    # TODO: convert the string label into a numeric feature here (for example
    # with an instantiated tf.keras.layers.TextVectorization that has a
    # vocabulary) and return it as {"label": ...}.
    return None

In [None]:
# Inspect the "label" context feature of the merged graph.
tf_g_merged.context["label"]

<tf.Tensor: shape=(2,), dtype=string, numpy=array([b'tree', b'erdos'], dtype=object)>

In [None]:
# Build the feature-mapping layer with the node-set callback only (no context
# callback is registered), then apply it to the merged graph.
initial_state_layer = tfgnn.keras.layers.MapFeatures(
    node_sets_fn=set_initial_node_state
)
graph = initial_state_layer(tf_g_merged)

In [None]:
# Inspect the raw per-node hidden_state of the merged graph, i.e. the values
# stored on tf_g_merged itself rather than on the mapped `graph`.
tf_g_merged.node_sets["normal"]["hidden_state"]

<tf.Tensor: shape=(24,), dtype=float64, numpy=
array([ 3.01497136,  4.47924201,  1.72604073,  3.15120397,  2.9806009 ,
        3.05342666,  1.87367667,  2.87844245,  2.53267698,  4.35745556,
        2.24356763,  9.09085008, 10.41487504,  9.42272877, 10.92901437,
       11.68119551,  9.57105215, 10.28672148,  8.88603723,  9.23901135,
       10.37924045, 10.22259462, 10.75000173, 11.138661  ])>

In [None]:
# NOTE(review): `graph_test` is only assigned in a later cell of this
# notebook, so on a fresh kernel this cell fails when run in order - confirm
# the intended cell order.
graph_test.node_sets["nodes"]["hidden_state"]

In [None]:
# Source node index of every edge in the merged graph. The edge set is
# registered as "connect" when tf_g is built; "edges" is not one of its keys.
tf_g_merged.edge_sets["connect"].adjacency[tfgnn.SOURCE]

<tf.Tensor: shape=(25,), dtype=int32, numpy=
array([ 0,  0,  1,  1,  2,  3,  5,  5,  5,  6, 11, 12, 12, 12, 12, 13, 15,
       15, 15, 15, 16, 16, 17, 19, 20], dtype=int32)>

In [None]:
import tensorflow as tf
import tensorflow_gnn as tfgnn
from tensorflow_gnn.models.gcn import gcn_conv
# Small homogeneous test graph: two components with two nodes and two edges
# each, and a 3-dimensional hidden state per node.
graph_test = tfgnn.GraphTensor.from_pieces(
    node_sets={
        tfgnn.NODES: tfgnn.NodeSet.from_fields(
            sizes=[2, 2],
            features={
                tfgnn.HIDDEN_STATE: tf.constant([[1., 0, 0], [0, 1, 0]] * 2)
            },
        )
    },
    edge_sets={
        tfgnn.EDGES: tfgnn.EdgeSet.from_fields(
            sizes=[2, 2],
            adjacency=tfgnn.Adjacency.from_indices(
                source=(tfgnn.NODES, tf.constant([0, 1, 2, 3], dtype=tf.int64)),
                target=(tfgnn.NODES, tf.constant([1, 0, 3, 2], dtype=tf.int64)),
            ),
        )
    },
)
gcnconv = gcn_conv.GCNConv(3)
# Run the convolution on the test graph built in this cell: it is the graph
# that has the tfgnn.EDGES edge set and four nodes. The previous call passed
# tf_g_merged, whose only edge set is named "connect".
gcnconv(graph_test, edge_set_name=tfgnn.EDGES)   # Has shape=(4, 3).

In [None]:
# Inspect the node sets of the graph returned by the MapFeatures cell.
graph.node_sets

{'normal': NodeSet(features={'hidden_state': <tf.Tensor: shape=(24, 32), dtype=tf.float32>}, sizes=[11 13])}

In [None]:
# Inspect the node sets of the small GCN test graph.
graph_test.node_sets

{'nodes': NodeSet(features={'hidden_state': <tf.Tensor: shape=(4, 3), dtype=tf.float32>}, sizes=[2 2])}

In [None]:
from tensorflow_gnn.models.gcn import gcn_conv

# 32-unit GCN convolution over the graph returned by the MapFeatures cell.
# That graph is derived from tf_g, whose only edge set is registered as
# "connect", so that name has to be passed instead of tfgnn.EDGES.
gcnconv = gcn_conv.GCNConv(32)
gcnconv(graph, edge_set_name="connect")