In [2]:
import tensorflow as tf
tfk = tf.keras
tfkl = tfk.layers

import tensorflow_probability as tfp
tfd = tfp.distributions
tfpl = tfp.layers
tfb = tfp.bijectors

In [3]:
# Show the logical devices TensorFlow can see; as the last expression in the
# notebook cell, the returned list is displayed as the cell output.
tf.config.list_logical_devices()

[LogicalDevice(name='/device:CPU:0', device_type='CPU'),
 LogicalDevice(name='/device:XLA_CPU:0', device_type='XLA_CPU'),
 LogicalDevice(name='/device:GPU:0', device_type='GPU'),
 LogicalDevice(name='/device:XLA_GPU:0', device_type='XLA_GPU')]

In [None]:
class GraphNet(tfk.Model):
    """A single graph-update step assembled from pluggable sub-layers.

    Each sub-layer is called with ``training=training``:

    * ``f_inp([V_src_loc, E])``   -> per-(src, dst) messages,
    * ``f_pool([V_dst, inp])``    -> messages pooled over sources per dst vertex,
    * ``f_v_up([V_dst, V_dst_new])`` -> updated dst vertex features,
    * ``f_e_up([V_src, V_dst, E])``  -> updated edge features,
    * ``f_adj_up(E)``             -> updated adjacency matrix.

    Ready-made choices for each slot are exposed as attributes of this class
    (``f_pool_sum``, ``f_pool_attn``, ``f_v_up_beta``, ``f_a_up``, ...).

    Expected tensor layout (inferred from the einsum subscripts below and the
    default sub-layers): ``V_src`` is ``[..., N_src, d_v]``, ``A`` is
    ``[..., N_src, N_dst]``, ``E`` is ``[..., N_src, N_dst, d_e]`` and
    ``V_dst`` is ``[..., N_dst, d]``.
    """

    def __init__(self,
                 f_inp,
                 f_pool,
                 f_v_up,
                 f_e_up,
                 f_adj_up,
                 **kwargs):
        """Store the sub-layers used by ``call``.

        Args:
            f_inp: builds messages from ``[V_src_loc, E]``.
            f_pool: pools messages, called on ``[V_dst, inp]``.
            f_v_up: updates vertices, called on ``[V_dst, V_dst_new]``.
            f_e_up: updates edges, called on ``[V_src, V_dst, E]``.
            f_adj_up: maps updated edges ``E`` to a new adjacency matrix.
            **kwargs: forwarded to ``tf.keras.Model``.
        """
        # kwargs must be unpacked; passing the dict positionally hands it to
        # Model.__init__ as a stray positional argument.
        super().__init__(**kwargs)

        self.f_inp = f_inp
        self.f_pool = f_pool
        self.f_v_up = f_v_up
        self.f_e_up = f_e_up
        self.f_adj_up = f_adj_up

    def call(self, inputs, training=False):
        """Run one update step.

        Args:
            inputs: the tuple/list ``(V_src, V_dst, E, A)``.
            training: forwarded to every sub-layer.

        Returns:
            ``(V_dst, E, A)`` after the vertex, edge and adjacency updates.
        """
        V_src, V_dst, E, A = inputs

        # Copy each source vertex onto every (src, dst) pair, scaled by the
        # adjacency entry A[..., s, d].
        V_src_loc = tf.einsum('...sv,...sd->...sdv', V_src, A)

        # Per-(src, dst) messages, re-laid-out as [..., N_dst, N_src, :] so the
        # pooling step can reduce over sources (axis -2).
        inp = self.f_inp([V_src_loc, E], training=training)
        inp = tf.einsum('...sdv->...dsv', inp)

        # Pool the messages arriving at each dst vertex.
        V_dst_new = self.f_pool([V_dst, inp], training=training)

        # Update dst vertices, then edges (seeing the updated V_dst).
        V_dst = self.f_v_up([V_dst, V_dst_new], training=training)
        E = self.f_e_up([V_src, V_dst, E], training=training)

        # Use the layer passed to __init__; ``f_a_up`` is only the
        # zero-argument factory for a default adjacency head.
        A = self.f_adj_up(E, training=training)

        return V_dst, E, A

    @staticmethod
    def _pair_features(V_src, V_dst, E):
        """Broadcast vertex features onto the (src, dst) pair grid of ``E``.

        Returns ``(V_src_pair, V_dst_pair)`` with
        ``V_src_pair[..., s, d, :] == V_src[..., s, :]`` and
        ``V_dst_pair[..., s, d, :] == V_dst[..., d, :]``; both share the
        leading ``[..., N_src, N_dst]`` dims of ``E`` so they can be
        concatenated with it along the last axis. Assumes ``E`` has the same
        dtype as the vertex tensors.
        """
        ones = tf.ones_like(E[..., 0])  # [..., N_src, N_dst]
        V_src_pair = tf.einsum('...sv,...sd->...sdv', V_src, ones)
        V_dst_pair = tf.einsum('...dv,...sd->...sdv', V_dst, ones)
        return V_src_pair, V_dst_pair

    # ---- f_pool: [V_dst, inp] -> V_dst_new --------------------------------
    # Keras ``Lambda`` passes the whole input list to its function as ONE
    # argument, so every lambda below takes ``inputs`` and indexes into it.

    @staticmethod
    def f_pool_sum():
        """Sum the messages ``inputs[1]`` over sources (axis -2)."""
        return tfkl.Lambda(lambda inputs: tf.reduce_sum(inputs[1], axis=-2))

    @staticmethod
    def f_pool_ave():
        """Average the messages ``inputs[1]`` over sources (axis -2)."""
        return tfkl.Lambda(lambda inputs: tf.reduce_mean(inputs[1], axis=-2))

    @staticmethod
    def f_pool_prod():
        """Multiply the messages ``inputs[1]`` over sources (axis -2)."""
        return tfkl.Lambda(lambda inputs: tf.reduce_prod(inputs[1], axis=-2))

    class f_pool_attn(tfkl.Layer):
        """Pool messages with single-head scaled dot-product attention.

        Queries are projected from ``V_dst``; keys and values from the
        messages ``inp`` (``[..., N_dst, N_src, :]``). Weights are a softmax
        over sources; the output is ``[..., N_dst, d_val]``.
        """

        def __init__(self, d_val, d_key=8, pre_layer_normalization=True,
                     **kwargs):
            """
            Args:
                d_val: width of the value projection (= output width).
                d_key: width of the query/key projections.
                pre_layer_normalization: layer-normalize both inputs before
                    projecting (pre-LN, https://arxiv.org/abs/2004.08249).
                **kwargs: forwarded to ``tf.keras.layers.Layer``.
            """
            # Zero-arg super(): the nested class name is not in method scope.
            super().__init__(**kwargs)

            self.pre_layer_normalization = pre_layer_normalization
            if self.pre_layer_normalization:
                self.V_dst_LN = tfkl.LayerNormalization()
                self.inp_LN = tfkl.LayerNormalization()

            self.d_val = d_val
            self.d_key = d_key
            self.f_val = tfkl.Dense(d_val, 'relu')
            self.f_key = tfkl.Dense(d_key, 'relu')
            self.f_query = tfkl.Dense(d_key, 'relu')

        def call(self, inputs, training=False):
            V_dst, inp = inputs

            if self.pre_layer_normalization:
                V_dst = self.V_dst_LN(V_dst)
                inp = self.inp_LN(inp)

            queries = self.f_query(V_dst)  # [..., N_dst, d_key]
            keys = self.f_key(inp)         # [..., N_dst, N_src, d_key]
            vals = self.f_val(inp)         # [..., N_dst, N_src, d_val]

            # Keep the batch dims ("...") in the output subscripts.
            score = tf.einsum('...dq,...dsq->...ds', queries, keys)
            # d_key is a Python int; tf.sqrt needs a floating-point tensor.
            score = score / tf.sqrt(tf.cast(self.d_key, score.dtype))
            score = tf.nn.softmax(score, axis=-1)  # normalize over sources

            return tf.einsum('...ds,...dsv->...dv', score, vals)

    # ---- f_v_up: [V_dst, V_dst_new] -> V_dst ------------------------------

    @staticmethod
    def f_v_up_sum():
        """Residual update ``V_dst + V_dst_new``."""
        return tfkl.Add()

    @staticmethod
    def f_v_up_direct():
        """Replace ``V_dst`` by ``V_dst_new``."""
        return tfkl.Lambda(lambda inputs: inputs[1])

    class f_v_up_beta(tfkl.Layer):
        """Gated update ``beta * V_dst + (1 - beta) * V_dst_new``.

        ``beta`` in (0, 1) is predicted per vertex from ``V_dst_new``;
        ``V_dst`` and ``V_dst_new`` must be elementwise-compatible.
        """

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            # sigmoid, not softmax: softmax over a single unit is always 1,
            # which would freeze the gate at beta == 1.
            self.f_beta = tfkl.Dense(1, 'sigmoid')

        def call(self, inputs, training=False):
            V_dst, V_dst_new = inputs
            beta = self.f_beta(V_dst_new)
            return beta * V_dst + (1 - beta) * V_dst_new

    class f_v_up_alphabeta(tfkl.Layer):
        """Update ``alpha * V_dst + beta * V_dst_new`` with independent gates.

        ``alpha`` is predicted from ``V_dst`` and ``beta`` from
        ``V_dst_new``, each in (0, 1).
        """

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            # sigmoid, not softmax: a one-unit softmax is identically 1.
            self.f_beta = tfkl.Dense(1, 'sigmoid')
            self.f_alpha = tfkl.Dense(1, 'sigmoid')

        def call(self, inputs, training=False):
            V_dst, V_dst_new = inputs
            alpha = self.f_alpha(V_dst)
            beta = self.f_beta(V_dst_new)
            return alpha * V_dst + beta * V_dst_new

    # ---- f_inp: [V_src_loc, E] -> per-(src, dst) messages ------------------

    @staticmethod
    def f_inp_concat():
        """Concatenate localized source features and edges (last axis)."""
        return tfkl.Concatenate()

    @staticmethod
    def f_inp_edges():
        """Use only the edge features ``E`` as messages."""
        return tfkl.Lambda(lambda inputs: inputs[1])

    @staticmethod
    def f_inp_verts():
        """Use only the localized source features ``V_src_loc`` as messages."""
        return tfkl.Lambda(lambda inputs: inputs[0])

    # ---- f_adj_up: E -> A ---------------------------------------------------

    @staticmethod
    def f_a_up():
        """Default adjacency head: one sigmoid weight per (src, dst) edge.

        Maps ``E`` (``[..., N_src, N_dst, d_e]``) to ``[..., N_src, N_dst]``.
        The Dense layer is created once here, so its weights are tracked and
        reused across calls, and only the unit axis is squeezed.
        """
        return tfk.Sequential([
            tfkl.Dense(1, 'sigmoid'),
            tfkl.Lambda(lambda y: tf.squeeze(y, axis=-1)),
        ])

    # ---- f_e_up: [V_src, V_dst, E] -> E -------------------------------------

    @staticmethod
    def f_e_up_const():
        """Leave the edges unchanged."""
        return tfkl.Lambda(lambda inputs: inputs[2])

    class f_e_up_dense(tfkl.Layer):
        """New edges ``Dense(d_E, relu)([V_src, V_dst, E])`` per (src, dst) pair."""

        def __init__(self, d_E, **kwargs):
            super().__init__(**kwargs)
            self.f_E_new = tfkl.Dense(d_E, 'relu')

        def call(self, inputs, training=False):
            V_src, V_dst, E = inputs
            # Vertices must be broadcast onto the pair grid before they can be
            # concatenated with the rank-higher edge tensor.
            V_src_pair, V_dst_pair = GraphNet._pair_features(V_src, V_dst, E)
            return self.f_E_new(
                tf.concat([V_src_pair, V_dst_pair, E], axis=-1))

    class f_e_up_dense_oneway(tfkl.Layer):
        """New edges ``Dense(d_E, relu)([V_src, E])``; ignores ``V_dst``."""

        def __init__(self, d_E, **kwargs):
            super().__init__(**kwargs)
            self.f_E_new = tfkl.Dense(d_E, 'relu')

        def call(self, inputs, training=False):
            V_src, V_dst, E = inputs
            V_src_pair, _ = GraphNet._pair_features(V_src, V_dst, E)
            return self.f_E_new(tf.concat([V_src_pair, E], axis=-1))

    class f_e_up_beta(tfkl.Layer):
        """Gated edge update ``beta * E + (1 - beta) * E_new``.

        ``beta`` in (0, 1) is predicted per pair from ``[V_src, V_dst]`` and
        ``E_new = Dense(d_E, relu)([V_src, E])``. ``d_E`` must equal the
        feature width of ``E`` for the blend to be elementwise.
        """

        def __init__(self, d_E, **kwargs):
            super().__init__(**kwargs)
            # sigmoid, not softmax: a one-unit softmax is identically 1.
            self.f_beta = tfkl.Dense(1, 'sigmoid')
            self.f_E_new = tfkl.Dense(d_E, 'relu')

        def call(self, inputs, training=False):
            V_src, V_dst, E = inputs
            V_src_pair, V_dst_pair = GraphNet._pair_features(V_src, V_dst, E)
            beta = self.f_beta(tf.concat([V_src_pair, V_dst_pair], axis=-1))
            E_new = self.f_E_new(tf.concat([V_src_pair, E], axis=-1))
            # Blend old and new EDGES (mirrors f_v_up_beta); mixing in V_dst
            # would combine tensors of different rank.
            return beta * E + (1 - beta) * E_new

    class f_e_up_attn(tfkl.Layer):
        """Scaled dot-product cross-attention from dst to src vertices.

        NOTE(review): keys/values are projected from ``V_src`` (the vertex
        input matching the ``[..., N, d_key]`` layout) and ``E`` is ignored.
        The result is ``[..., N_dst, d_val]`` -- vertex-shaped, not the
        ``[..., N_src, N_dst, d_e]`` edge layout ``GraphNet.call`` feeds to
        ``f_adj_up``. TODO: confirm the intended edge-update semantics before
        using this layer as ``f_e_up``.
        """

        def __init__(self, d_val, d_key=8, pre_layer_normalization=True,
                     **kwargs):
            """
            Args:
                d_val: width of the value projection (= output width).
                d_key: width of the query/key projections.
                pre_layer_normalization: layer-normalize the inputs before
                    projecting (pre-LN, https://arxiv.org/abs/2004.08249).
                **kwargs: forwarded to ``tf.keras.layers.Layer``.
            """
            super().__init__(**kwargs)

            self.pre_layer_normalization = pre_layer_normalization
            if self.pre_layer_normalization:
                self.V_dst_LN = tfkl.LayerNormalization()
                self.inp_LN = tfkl.LayerNormalization()

            self.d_val = d_val
            self.d_key = d_key
            self.f_val = tfkl.Dense(d_val, 'relu')
            self.f_key = tfkl.Dense(d_key, 'relu')
            self.f_query = tfkl.Dense(d_key, 'relu')

        def call(self, inputs, training=False):
            V_src, V_dst, _ = inputs  # edges are unused, see class NOTE
            inp = V_src

            if self.pre_layer_normalization:
                V_dst = self.V_dst_LN(V_dst)
                inp = self.inp_LN(inp)

            queries = self.f_query(V_dst)  # [..., N_dst, d_key]
            keys = self.f_key(inp)         # [..., N_src, d_key]
            vals = self.f_val(inp)         # [..., N_src, d_val]

            score = tf.einsum('...dk,...sk->...ds', queries, keys)
            # d_key is a Python int; tf.sqrt needs a floating-point tensor.
            score = score / tf.sqrt(tf.cast(self.d_key, score.dtype))
            score = tf.nn.softmax(score, axis=-1)  # normalize over sources

            return tf.einsum('...ds,...sv->...dv', score, vals)