In [2]:
import tensorflow_gnn as tfgnn
import tensorflow as tf
# Model hyper-parameters: output width of each node set's next-state Dense layer.
h_dims = {'user': 256, 'movie': 64, 'genre': 128}


def _edge_convolution_factory(edge_set_name):
    """Return a fresh WeightedSumConvolution; every edge set gets the same layer type."""
    # NOTE: WeightedSumConvolution is resolved when this function is called, so the
    # class must already be defined in the kernel by that point.
    return WeightedSumConvolution()


def _node_next_state_factory(node_set_name):
    """Return the next-state layer for `node_set_name`, with its Dense width taken from `h_dims`."""
    return tfgnn.keras.layers.NextStateFromConcat(
        tf.keras.layers.Dense(h_dims[node_set_name]))


# Model builder initialization:
gnn = tfgnn.keras.ConvGNNBuilder(
    _edge_convolution_factory,
    _node_next_state_factory,
)

# Two rounds of message passing to target node sets, followed by a
# single-unit Dense head applied to the read-out "user" node set.
model = tf.keras.models.Sequential([
    gnn.Convolve({'genre'}),  # round 1: sends messages from movie to genre
    gnn.Convolve({'user'}),  # round 2: sends messages from movie and genre to users
    tfgnn.keras.layers.Readout(node_set_name="user"),
    tf.keras.layers.Dense(1),
])

In [10]:
class WeightedSumConvolution(tf.keras.layers.Layer):
    """Convolution that sum-pools edge-weighted source-node states into target nodes."""

    def call(self, graph: tfgnn.GraphTensor,
             edge_set_name: tfgnn.EdgeSetName) -> tfgnn.Field:
        """Compute the weighted sum of source-node states for one edge set.

        Args:
            graph: Graph whose node sets carry the `tfgnn.DEFAULT_STATE_NAME`
                feature and whose edge set `edge_set_name` carries a
                'weight' feature.
            edge_set_name: Name of the edge set to convolve over.

        Returns:
            The per-edge weighted source states, summed at the TARGET
            endpoint of each edge.
        """
        # Copy each edge's SOURCE-node state onto the edge.
        source_states = tfgnn.broadcast_node_to_edges(
            graph,
            edge_set_name,
            tfgnn.SOURCE,
            feature_name=tfgnn.DEFAULT_STATE_NAME)

        # Append a trailing axis to the weights so they broadcast against the
        # states (assumes one scalar weight per edge -- TODO confirm with the
        # graph schema).
        edge_weights = tf.expand_dims(
            graph.edge_sets[edge_set_name]['weight'], -1)

        # Sum the scaled messages at the TARGET endpoint of every edge.
        return tfgnn.pool_edges_to_node(
            graph,
            edge_set_name,
            tfgnn.TARGET,
            reduce_type='sum',
            feature_value=edge_weights * source_states)