# In [1]:
import tensorflow as tf
tfk = tf.keras
tfkl = tfk.layers

import tensorflow_probability as tfp
tfd = tfp.distributions
tfpl = tfp.layers
tfb = tfp.bijectors

# In [26]:
class GraphNet(tfk.Model):
    """One message-passing update from a source vertex set to a destination set.

    The update is assembled from five pluggable callables (Keras layers or
    models). Ready-made choices are provided by the ``f_*`` factories and the
    nested layer classes defined on this class.

    Shape conventions, as fixed by the einsum subscripts used in `call`:
        V_src: [..., N_src, d_src]
        V_dst: [..., N_dst, d_dst]
        E:     [..., N_src, N_dst, d_e]
        A:     [..., N_src, N_dst]
    """

    def __init__(self,
        f_inp,
        f_pool,
        f_v_up,
        f_e_up,
        f_adj_up,
        **kwargs):
        """
        f_inp: called with `[V_src_loc, E]`; returns the per-edge input
            `[..., N_src, N_dst, d_inp]` sent towards each destination vert.
        f_pool: called with `[V_dst, inp]` where `inp` has been permuted to
            `[..., N_dst, N_src, d_inp]`; returns one pooled vector per
            destination vert.
        f_v_up: called with `[V_dst, V_dst_new]`; returns the updated
            destination verts.
        f_e_up: called with `[V_src_loc, V_dst_loc, E]`; returns updated edges.
        f_adj_up: called with the updated edges `E`; returns the updated
            adjacency matrix.
        **kwargs: forwarded to `tf.keras.Model`.
        """

        # kwargs must be unpacked; passing the dict positionally is an error.
        super().__init__(**kwargs)

        self.f_inp = f_inp
        self.f_pool = f_pool
        self.f_v_up = f_v_up
        self.f_e_up = f_e_up
        self.f_adj_up = f_adj_up

    # ------------------------------------------------------------------
    # private helpers shared by the nested layers
    # ------------------------------------------------------------------
    @staticmethod
    def _last_dim(shape):
        """Return the static size of the last axis of `shape`.

        Raises:
            ValueError: if that size is not statically known.
        """
        size = shape[-1]
        if size is None:
            raise ValueError(
                'the last (feature) dimension must be statically known')
        return int(size)

    @staticmethod
    def _split_heads(x, N_heads, d_head):
        """Reshape `[..., N_heads*d_head]` into `[..., N_heads, d_head]`."""
        new_shape = tf.concat([tf.shape(x)[:-1], [N_heads, d_head]], axis=0)
        return tf.reshape(x, new_shape)

    @staticmethod
    def _merge_heads(x, N_heads, d_head):
        """Reshape `[..., N_heads, d_head]` into `[..., N_heads*d_head]`."""
        new_shape = tf.concat([tf.shape(x)[:-2], [N_heads * d_head]], axis=0)
        return tf.reshape(x, new_shape)

    def call(self, inputs, training=False):
        """Run one update step.

        inputs: sequence `(V_src, V_dst, E, A)`.
        training: forwarded to every pluggable component.

        Returns the tuple `(V_dst, E, A)` of updated destination verts,
        edges and adjacency matrix.
        """

        # unpack inputs
        V_src, V_dst, E, A = inputs

        # provide vert-localized copies of src and dst verts; each copy is
        # weighted by the adjacency entry of its (src, dst) pair
        V_src_loc = tf.einsum('...sv,...sd->...sdv', V_src, A)
        V_dst_loc = tf.einsum('...dv,...sd->...dsv', V_dst, A)

        # get src-dst pair-specific inputs to dst verts
        inp = self.f_inp([V_src_loc, E], training=training)
        # put the dst axis first so pooling can reduce over the src axis
        inp = tf.einsum('...sdv->...dsv', inp)

        # pool src-dst pair-specific inputs
        V_dst_new = self.f_pool([V_dst, inp], training=training)

        # update dst verts
        V_dst = self.f_v_up([V_dst, V_dst_new], training=training)

        # update edges
        E = self.f_e_up([V_src_loc, V_dst_loc, E], training=training)

        # update adjacency matrix
        A = self.f_adj_up(E, training=training)

        return V_dst, E, A

    # ------------------------------------------------------------------
    # pooling functions: called with [V_dst, inp], inp is
    # [..., N_dst, N_src, d_inp]; the src axis (-2) is reduced
    # ------------------------------------------------------------------
    @staticmethod
    def f_pool_sum():
        """Pool by summing the inputs over the src axis."""
        return tfkl.Lambda(lambda inputs: tf.reduce_sum(inputs[1], axis=-2))

    @staticmethod
    def f_pool_ave():
        """Pool by averaging the inputs over the src axis."""
        return tfkl.Lambda(lambda inputs: tf.reduce_mean(inputs[1], axis=-2))

    @staticmethod
    def f_pool_prod():
        """Pool by multiplying the inputs over the src axis."""
        return tfkl.Lambda(lambda inputs: tf.reduce_prod(inputs[1], axis=-2))

    class f_pool_attn(tfkl.Layer):
        """Multi-head attention pooling.

        Every destination vert emits one query per head and attends over the
        per-source inputs addressed to it.
        """

        def __init__(self, d_key=8, d_val=None, N_heads=8, pre_layer_normalization=True, **kwargs):
            """
            pre-LN (https://arxiv.org/abs/2004.08249)

            d_key: size of each query/key vector.
            d_val: size of each value vector; defaults to `d_key`.
            N_heads: number of attention heads.
            pre_layer_normalization: layer-normalize both inputs first.
            """
            super().__init__(**kwargs)

            self.pre_layer_normalization = pre_layer_normalization
            self.d_key = d_key
            self.d_val = self.d_key if d_val is None else d_val
            self.N_heads = N_heads

            if self.pre_layer_normalization:
                self.V_dst_LN = tfkl.LayerNormalization()
                self.inp_LN = tfkl.LayerNormalization()

            self.f_val = tfkl.Dense(self.N_heads * self.d_val, 'relu')
            self.f_key = tfkl.Dense(self.N_heads * self.d_key, 'relu')
            self.f_query = tfkl.Dense(self.N_heads * self.d_key, 'relu')

        def build(self, input_shape):
            V_dst_shape, _ = input_shape
            # the output is embedded back into the dst vert feature space
            self.f_emb_cat = tfkl.Dense(
                GraphNet._last_dim(V_dst_shape), 'relu')
            super().build(input_shape)

        def call(self, inputs, training=False):
            # unpack inputs
            V_dst, inp = inputs

            # pre-LN
            if self.pre_layer_normalization:
                V_dst = self.V_dst_LN(V_dst)
                inp = self.inp_LN(inp)

            # generate queries, keys, and values for all heads
            queries = self.f_query(V_dst)  # [..., N_dst, N_heads*d_key]
            keys = self.f_key(inp)  # [..., N_dst, N_src, N_heads*d_key]
            values = self.f_val(inp)  # [..., N_dst, N_src, N_heads*d_val]

            # reshape into separate heads
            queries = GraphNet._split_heads(
                queries, self.N_heads, self.d_key)
            # [..., N_dst, N_heads, d_key]
            keys = GraphNet._split_heads(keys, self.N_heads, self.d_key)
            # [..., N_dst, N_src, N_heads, d_key]
            values = GraphNet._split_heads(values, self.N_heads, self.d_val)
            # [..., N_dst, N_src, N_heads, d_val]

            # perform multi-head attention
            score = tf.einsum('...dhq,...dshq->...dsh', queries, keys)
            score = score / (float(self.d_key) ** 0.5)
            # normalize over the src axis so that, per dst vert and head,
            # the weights of all sources sum to one
            score = tf.nn.softmax(score, axis=-2)
            mha_lookup = tf.einsum('...dsh,...dshv->...dhv', score, values)
            # [..., N_dst, N_heads, d_val]

            # concatenate heads
            mha_cat = GraphNet._merge_heads(
                mha_lookup, self.N_heads, self.d_val)
            # [..., N_dst, N_heads*d_val]

            # embed in output space
            return self.f_emb_cat(mha_cat)

    # ------------------------------------------------------------------
    # vert update functions: called with [V_dst, V_dst_new]
    # ------------------------------------------------------------------
    @staticmethod
    def f_v_up_add():
        """Residual update: `V_dst + V_dst_new`."""
        return tfkl.Add()

    @staticmethod
    def f_v_up_direct():
        """Replace the old verts by the pooled values."""
        return tfkl.Lambda(lambda inputs: inputs[1])

    class f_v_up_beta(tfkl.Layer):
        """Gated update `beta*V_dst + (1-beta)*V_dst_new`."""

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            # sigmoid, not softmax: a softmax over a single unit is
            # constantly 1, which would freeze the verts
            self.f_beta = tfkl.Dense(1, 'sigmoid')

        def call(self, inputs, training=False):
            V_dst, V_dst_new = inputs
            beta = self.f_beta(V_dst_new)
            return beta * V_dst + (1 - beta) * V_dst_new

    class f_v_up_alphabeta(tfkl.Layer):
        """Doubly gated update `alpha*V_dst + beta*V_dst_new`."""

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            # sigmoid gates; a single-unit softmax would always output 1
            self.f_beta = tfkl.Dense(1, 'sigmoid')
            self.f_alpha = tfkl.Dense(1, 'sigmoid')

        def call(self, inputs, training=False):
            V_dst, V_dst_new = inputs
            alpha = self.f_alpha(V_dst)
            beta = self.f_beta(V_dst_new)
            return alpha * V_dst + beta * V_dst_new

    # ------------------------------------------------------------------
    # input functions: called with [V_src_loc, E]
    # ------------------------------------------------------------------
    @staticmethod
    def f_inp_concat():
        """Concatenate localized src verts and edges along the last axis."""
        return tfkl.Concatenate()

    @staticmethod
    def f_inp_edges():
        """Use only the edges as input."""
        return tfkl.Lambda(lambda inputs: inputs[1])

    @staticmethod
    def f_inp_verts():
        """Use only the localized src verts as input."""
        return tfkl.Lambda(lambda inputs: inputs[0])

    # ------------------------------------------------------------------
    # adjacency update function: called with E
    # ------------------------------------------------------------------
    @staticmethod
    def f_a_up():
        """Map every edge embedding to one adjacency weight in (0, 1).

        The dense layer is created once, so its weights are tracked and
        reused across calls. Only the trailing unit axis is squeezed, so
        size-one batch or vert axes are kept.
        """
        return tfk.Sequential([
            tfkl.Dense(1, 'sigmoid'),
            tfkl.Lambda(lambda y: tf.squeeze(y, axis=-1)),
        ])

    # ------------------------------------------------------------------
    # edge update functions: called with [V_src_loc, V_dst_loc, E] where
    # V_src_loc is [..., N_src, N_dst, d] and V_dst_loc is
    # [..., N_dst, N_src, d]
    # ------------------------------------------------------------------
    @staticmethod
    def f_e_up_const():
        """Leave the edges unchanged."""
        return tfkl.Lambda(lambda inputs: inputs[2])

    class f_e_up_dense(tfkl.Layer):
        """New edges from a dense layer over (src vert, dst vert, edge)."""

        def __init__(self, **kwargs):
            super().__init__(**kwargs)

        def build(self, input_shape):
            _, _, E_shape = input_shape
            self.f_E_new = tfkl.Dense(GraphNet._last_dim(E_shape), 'relu')
            super().build(input_shape)

        def call(self, inputs, training=False):
            V_src_loc, V_dst_loc, E = inputs
            # align the dst copies with the [..., N_src, N_dst, .] layout
            V_dst_loc_perm = tf.einsum('...dsv->...sdv', V_dst_loc)
            return self.f_E_new(tf.concat(
                [V_src_loc, V_dst_loc_perm, E], axis=-1))

    class f_e_up_dense_oneway(tfkl.Layer):
        """New edges from a dense layer over (src vert, edge) only."""

        def __init__(self, **kwargs):
            super().__init__(**kwargs)

        def build(self, input_shape):
            _, _, E_shape = input_shape
            self.f_E_new = tfkl.Dense(GraphNet._last_dim(E_shape), 'relu')
            super().build(input_shape)

        def call(self, inputs, training=False):
            V_src_loc, V_dst_loc, E = inputs
            return self.f_E_new(tf.concat([V_src_loc, E], axis=-1))

    class f_e_up_beta(tfkl.Layer):
        """Gated edge update `beta*E + (1-beta)*E_new`."""

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            # sigmoid gate; a single-unit softmax would always output 1
            self.f_beta = tfkl.Dense(1, 'sigmoid')

        def build(self, input_shape):
            _, _, E_shape = input_shape
            self.f_E_new = tfkl.Dense(GraphNet._last_dim(E_shape), 'relu')
            super().build(input_shape)

        def call(self, inputs, training=False):
            V_src_loc, V_dst_loc, E = inputs
            # align the dst copies with the [..., N_src, N_dst, .] layout
            V_dst_loc_perm = tf.einsum('...dsv->...sdv', V_dst_loc)
            E_new = self.f_E_new(tf.concat(
                [V_src_loc, V_dst_loc_perm, E], axis=-1))
            beta = self.f_beta(tf.concat(
                [V_src_loc, V_dst_loc_perm], axis=-1))
            return beta * E + (1 - beta) * E_new

    class f_e_up_attn(tfkl.Layer):
        """Multi-head attention edge update.

        Queries come from (dst vert, edge), keys and values from
        (src vert, edge); every edge keeps its own attended value.
        """

        def __init__(self, d_key=8, d_val=None, N_heads=8, pre_layer_normalization=True, **kwargs):
            """
            pre-LN (https://arxiv.org/abs/2004.08249)

            d_key: size of each query/key vector.
            d_val: size of each value vector; defaults to `d_key`.
            N_heads: number of attention heads.
            pre_layer_normalization: layer-normalize the query and the
                key/value data first.
            """
            super().__init__(**kwargs)

            self.pre_layer_normalization = pre_layer_normalization
            self.d_key = d_key
            self.d_val = self.d_key if d_val is None else d_val
            self.N_heads = N_heads

            if self.pre_layer_normalization:
                self.q_LN = tfkl.LayerNormalization()
                self.kv_LN = tfkl.LayerNormalization()

            self.f_val = tfkl.Dense(self.N_heads * self.d_val, 'relu')
            self.f_key = tfkl.Dense(self.N_heads * self.d_key, 'relu')
            self.f_query = tfkl.Dense(self.N_heads * self.d_key, 'relu')

        def build(self, input_shape):
            _, _, E_shape = input_shape
            # the output is embedded back into the edge feature space
            self.f_emb_cat = tfkl.Dense(GraphNet._last_dim(E_shape), 'relu')
            super().build(input_shape)

        def call(self, inputs, training=False):
            # unpack inputs
            V_src_loc, V_dst_loc, E = inputs

            # align the dst copies with the [..., N_src, N_dst, .] layout
            V_dst_loc_perm = tf.einsum('...dsv->...sdv', V_dst_loc)

            q_data = tf.concat([V_dst_loc_perm, E], axis=-1)
            kv_data = tf.concat([V_src_loc, E], axis=-1)

            # pre-LN
            if self.pre_layer_normalization:
                q_data = self.q_LN(q_data)
                kv_data = self.kv_LN(kv_data)

            # generate queries, keys, and values for all heads
            queries = self.f_query(q_data)  # [..., N_src, N_dst, N_heads*d_key]
            keys = self.f_key(kv_data)  # [..., N_src, N_dst, N_heads*d_key]
            values = self.f_val(kv_data)  # [..., N_src, N_dst, N_heads*d_val]

            # reshape into separate heads
            queries = GraphNet._split_heads(
                queries, self.N_heads, self.d_key)
            # [..., N_src, N_dst, N_heads, d_key]
            keys = GraphNet._split_heads(keys, self.N_heads, self.d_key)
            # [..., N_src, N_dst, N_heads, d_key]
            values = GraphNet._split_heads(values, self.N_heads, self.d_val)
            # [..., N_src, N_dst, N_heads, d_val]

            # perform multi-head attention: one score per edge and head
            score = tf.reduce_sum(queries * keys, axis=-1)
            score = score / (float(self.d_key) ** 0.5)
            # NOTE(review): weights are normalized over the src axis (the
            # edges entering one dst vert compete), mirroring f_pool_attn --
            # confirm this is the intended normalization axis
            score = tf.nn.softmax(score, axis=-3)
            mha_lookup = score[..., tf.newaxis] * values
            # [..., N_src, N_dst, N_heads, d_val]

            # concatenate heads
            mha_cat = GraphNet._merge_heads(
                mha_lookup, self.N_heads, self.d_val)
            # [..., N_src, N_dst, N_heads*d_val]

            # embed in output space
            return self.f_emb_cat(mha_cat)
            # [..., N_src, N_dst, d_E]

# In [27]:
import random

class MultiGraphNet(tfk.Model):
    """Applies per-relation update models over the relations of a multigraph.

    A relation is a `(src, dst)` pair of graph names. On every step the
    relations are visited in the order produced by `f_update_seq`, and the
    matching update model rewrites the destination verts, the edges and the
    adjacency matrix of that relation inside the multigraph.
    """

    def __init__(self,
        multigraph,
        f_update_seq,
        f_rel_update=None,
        f_rel_update_model=None,
        randomized_update_seq=False,
        f_ret=(lambda x: x),
        **kwargs):
        """
        multigraph: object exposing dicts `Vs`, `Es` and `As`; it is
            updated in place by `cell`.
        f_update_seq (callable): maps the multigraph to the list of
            `(src, dst)` relations to update, in order. It is additionally
            called with `randomized=True` when `randomized_update_seq` is
            set, so custom callables only need to accept that keyword if
            the flag is used.
        f_rel_update (dict<(str,str): GraphNet): update functions
            for each source-destination graph pairs. If `None`, specify
            an `f_rel_update_model` that will be applied to all edges in
            the multigraph.
        f_rel_update_model (GraphNet): updating function used for all
            source-destination graph relations in the case that
            `f_rel_update` is `None`. The same instance is registered for
            every relation (it is not copied).
        randomized_update_seq (bool): request a shuffled update order.
        f_ret (callable): maps the updated multigraph to the return value.
        **kwargs: forwarded to `tf.keras.Model`.

        Raises:
            ValueError: if both `f_rel_update` and `f_rel_update_model`
                are `None`.
        """
        # Keras requires this before any attribute is assigned.
        super().__init__(**kwargs)

        if f_rel_update is None:
            if f_rel_update_model is None:
                raise ValueError(
                    'specify either `f_rel_update` or `f_rel_update_model`')
            f_rel_update = {
                rel: f_rel_update_model for rel in multigraph.Es}

        self.MG = multigraph
        self.update_seq = f_update_seq
        # NOTE(review): Keras may refuse to checkpoint a tracked dict with
        # non-string (tuple) keys -- confirm before relying on saving
        self.f_rel_update = f_rel_update
        self.randomized_update_seq = randomized_update_seq
        self.f_ret = f_ret

    @staticmethod
    def f_update_seq_reg(multigraph, randomized=False):
        """just go through all defined relations"""
        seq = list(multigraph.Es.keys())
        if randomized:
            # shuffles in place and returns None, so do not rebind `seq`
            random.shuffle(seq)
        return seq

    @staticmethod
    def f_update_seq_egocentric(multigraph, randomized=False):
        """first perform intragraph update, then intergraph update

        Only relations that are defined in `multigraph.Es` are returned.
        """
        src_names = list(multigraph.Vs.keys())
        # independent copy so both orders can be shuffled separately
        dst_names = list(src_names)
        if randomized:
            random.shuffle(src_names)
            random.shuffle(dst_names)

        intragraph_pairs = [(name, name) for name in src_names]
        intergraph_pairs = [
            (src_name, dst_name)
            for src_name in src_names
            for dst_name in dst_names
            if src_name != dst_name]

        return [rel for rel in intragraph_pairs + intergraph_pairs
                if rel in multigraph.Es]

    class f_ret_just_graph:
        """Return `(Vs, Es, As)` of one graph and its self-relation."""

        def __init__(self, graph_name):
            self.graph_name = graph_name

        def __call__(self, multigraph):
            return (multigraph.Vs[self.graph_name],
                    multigraph.Es[(self.graph_name, self.graph_name)],
                    multigraph.As[(self.graph_name, self.graph_name)])

    class f_ret_just_root:
        """Return the verts of one graph averaged over the vert axis (-2)."""

        def __init__(self, root_name):
            self.root_name = root_name

        def __call__(self, multigraph):
            return tf.reduce_mean(
                multigraph.Vs[self.root_name],
                axis=-2)

    def cell(self, inputs, training=False):
        """Run one sweep over the update sequence.

        inputs: currently unused; the state lives in the multigraph.
        training: forwarded to every relation update model.

        Returns `f_ret` applied to the (in-place updated) multigraph.
        """
        if self.randomized_update_seq:
            seq = self.update_seq(self.MG, randomized=True)
        else:
            seq = self.update_seq(self.MG)

        for rel in seq:
            src, dst = rel
            # the update model takes its four tensors as one list
            V_dst, E, A = self.f_rel_update[rel](
                [self.MG.Vs[src], self.MG.Vs[dst],
                 self.MG.Es[rel], self.MG.As[rel]],
                training=training)
            self.MG.Vs[dst], self.MG.Es[rel], self.MG.As[rel] = V_dst, E, A
        return self.f_ret(self.MG)

    def call(self, inputs, training=False):
        """Keras entry point; delegates to `cell`."""
        return self.cell(inputs, training=training)

# In [11]:
class MultiGraph:
    """Container for several vertex sets and the relations between them."""

    def __init__(self, Vs=None, Es=None, As=None):
        """
        Vs: dict<str, Tensor> verts of each graph, `[..., N_v, d_v]`
        Es: dict<(str,str), Tensor> edge embeddings of each (src, dst)
            relation, `[..., N_src, N_dst, d_e]`
        As: dict<(str,str), Tensor> adjacency matrix of each (src, dst)
            relation, `[..., N_src, N_dst]`

        Omitted arguments default to a fresh empty dict per instance.
        """
        # `None` sentinels: default dicts in the signature would be shared
        # by every instance created without arguments
        self.Vs = {} if Vs is None else Vs
        self.Es = {} if Es is None else Es
        self.As = {} if As is None else As

    # These take arguments, so they are plain methods rather than properties.
    def N_v(self, name):
        """Number of verts in graph `name`."""
        return tf.shape(self.Vs[name])[-2]

    def d_v(self, name):
        """Vert embedding size of graph `name`."""
        return tf.shape(self.Vs[name])[-1]

    def d_e(self, src, dst):
        """Edge embedding size of the relation `(src, dst)`."""
        return tf.shape(self.Es[(src, dst)])[-1]

    def connect_graphs(self, src, dst, e_emb=(1.0,), density=1.0):
        """Create a random relation from graph `src` to graph `dst`.

        Every (src vert, dst vert) pair is connected when a uniform random
        draw falls below `density`. Connected pairs receive the edge
        embedding `e_emb`; unconnected pairs receive zeros.

        src, dst (str): names of existing graphs in `Vs`.
        e_emb: 1-D edge embedding given to every connected pair.
        density (float): connection threshold for the uniform draw.
        """
        dtype = tfk.backend.floatx()
        src_shape = tf.shape(self.Vs[src])
        dst_shape = tf.shape(self.Vs[dst])
        # shape tensors must be concatenated; `+` would add them elementwise
        adj_shape = tf.concat(
            [src_shape[:-2], src_shape[-2:-1], dst_shape[-2:-1]], axis=0)

        connected = tf.random.uniform(
            shape=adj_shape, minval=0.0, maxval=1.0) < density
        self.As[src, dst] = tf.cast(connected, dtype)

        emb = tf.cast(tf.convert_to_tensor(e_emb), dtype)
        # broadcast [..., N_src, N_dst, 1] * [d_e] -> [..., N_src, N_dst, d_e]
        self.Es[src, dst] = self.As[src, dst][..., tf.newaxis] * emb

    def add_root_network(self,
        root_name,
        intragraph_density=1.0,
        intergraph_density=1.0,
        neighbors=(),
        connection_direction=("src", "dst"),
        N_v=1,
        emb_v=(1.0,)):
        """convenience function to make root node network
        and connect to other graphs. Root networks provide an
        information highway for intragraph vert updates and
        can be used to connect heterogeneous graphs.

        WARNING: the multigraph must have at least one other
        set of verts so we can determine batch size and time
        steps (or any other leading dimensions).

        root_name (str): name of the new root graph
        intragraph_density (float): connection density among root verts
        intergraph_density (float): connection density between the root
            graph and each neighbor
        neighbors (list<str>): neighboring graphs (if any) to
            connect new root network to
        connection_direction: contains "src" to connect root -> neighbor
            and/or "dst" to connect neighbor -> root
        N_v (int): number of root verts
        emb_v: 1-D initial embedding shared by all root verts

        Raises:
            ValueError: if the multigraph has no verts yet.
        """
        if not self.Vs:
            raise ValueError(
                'the multigraph needs at least one set of verts to '
                'determine the leading dimensions of the root network')

        # create root graph verts
        emb = tf.cast(tf.convert_to_tensor(emb_v), tfk.backend.floatx())
        leading_dims = tf.shape(next(iter(self.Vs.values())))[:-2]
        verts_shape = tf.concat(
            [leading_dims, [N_v], tf.shape(emb)], axis=0)
        # tf.fill only accepts scalar values, so broadcast the embedding
        self.Vs[root_name] = tf.broadcast_to(emb, verts_shape)

        # connect graph internally
        self.connect_graphs(root_name, root_name,
            density=intragraph_density)

        # connect with neighbors
        for neighbor in neighbors:
            if "src" in connection_direction:
                self.connect_graphs(root_name, neighbor,
                    density=intergraph_density)
            if "dst" in connection_direction:
                self.connect_graphs(neighbor, root_name,
                    density=intergraph_density)