In [1]:
import tensorflow as tf

import graph_nets as gn
from graph_nets import utils_tf

import sonnet as snt
tf.enable_eager_execution()

  from ._conv import register_converters as _register_converters
  return _inspect.getargspec(target)


In [6]:
dims = 7
n_edges = 74
n_nodes = 5
rank = 2

edges = [[tf.Variable(tf.random_normal([dims, rank, 3], stddev=1.0)), 
                 tf.Variable(tf.random_normal([rank, rank, rank], stddev=1.0))]
               for _ in range(n_edges)]

graph = {
      "globals": None,
      "nodes": tf.random_normal([n_nodes, dims]),
      "edges": edges,
      "receivers": tf.random_uniform(minval=0, maxval=n_nodes, dtype=tf.int32, shape=[n_edges]),
      "senders": tf.random_uniform(minval=0, maxval=n_nodes, dtype=tf.int32, shape=[n_edges, 2]),
      "n_node": n_nodes,
      "n_edge": n_edges
}
input_graphs = gn.graphs.GraphsTuple(**graph)

In [7]:
class Tucker(snt.AbstractModule):
    def __init__(self,
                 model_fn,
                 name="tucker"):
        """Initializes a Tucker module.
        Args:
          model_fn: A callable
          name: The module name.
        """
        super(Tucker, self).__init__(name=name)
        # TODO extend to hypergraphs (not just binary relations)
        with self._enter_variable_scope():
            self._model_fn = model_fn()
        
    def _build(self, graph):
        # gather nodes vals from senders
        sender_nodes = gn.blocks.broadcast_sender_nodes_to_edges(graph)
        
        # apply edge to each relevant node
        node_messages = self._model_fn(graph.edges, sender_nodes)
        
        # aggregate according to the receivers
        nodes = tf.unsorted_segment_sum(node_messages, graph.receivers, tf.reduce_sum(graph.n_node))
        return graph.replace(nodes=nodes)
    
def tensor3_model_fn(edges, nodes):
    """
    Args:
        edges (list): a list of edges. each edge has 3 cores and a center defining
            a tucker decomposition of a 3-tensor.
        nodes (tf.tensor): the two inputs to the 3-tensor [A, B].
    """
    # treating the edges like linear functions applied to each node
    A = nodes[..., 0, :]
    B = nodes[..., 1, :]
    
    factors, core = zip(*edges)
    factors = tf.stack(factors, axis=0)
    core = tf.stack(core, axis=0)
    
    return tucker3_decomp(A, B, factors, core)

def tucker3_decomp(A, B, factors, core):
    """
    Tensordot with a decomposed tensor.
    """
    with tf.name_scope('tucker_tensor_dot'):
        U, V, W = factors[..., 0], factors[..., 1], factors[..., 2]

        A_ = tf.einsum('bij,bi->bj', U, A)/tf.cast(tf.shape(A)[-1], tf.float32)
        B_ = tf.einsum('bij,bi->bj', V, B)/tf.cast(tf.shape(B)[-1], tf.float32)

        C_ = tf.einsum('bijk,bi,bj->bk', core, A_, B_)
        return tf.einsum('bij,bj->bi', W, C_)

In [8]:
tucker = Tucker(lambda: tensor3_model_fn)
output_graphs = tucker(input_graphs)

In [9]:
output_graphs.nodes.shape

TensorShape([Dimension(5), Dimension(7)])

In [15]:
def step(model):
    with tf.GradientTape() as tape:
        x = tf.random_normal([1, dims])
        t = 2*x

        for _ in range(5):
            nodes = tf.concat([x + model.nodes[0, ...], 
                               model.nodes[1:, ...]], 
                              axis=0)
            model = model.replace(nodes=nodes)
            model = tucker(model)

        loss = tf.reduce_sum(tf.square(t - model.nodes[-1, ...]))
#         loss += 1e-8*tf.reduce_sum(tf.square(model.edges))

    variables = [v for e in edges for v in e ]
    g = tape.gradient(loss, variables)
    opt = tf.train.AdamOptimizer()
    train_step = opt.apply_gradients(zip(g, variables), global_step=tf.train.get_or_create_global_step())
    print('\rloss: {}'.format(loss), end='', flush=True)
    return model

In [16]:
output_graphs = input_graphs
for _ in range(100):
    output_graphs = step(output_graphs)

loss: nan

AttributeError: 'variable_scope' object has no attribute '_graph_context_manager'