In [186]:
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import layers
from tensorflow.python.keras import initializers

class fGRUCell(tf.keras.layers.Layer):
    '''
    Feedback gated recurrent unit (fGRU) cell.

    One step of the cell takes a drive ``Z`` and a recurrent state ``H`` and
    returns the updated recurrent state. The step has two stages:
    suppression of ``Z`` by ``H`` (producing ``S``), followed by facilitation
    of ``S`` and a gated update of ``H``.

    params:
    input_size: n x n spatial size of the input. Stored for reference only;
                the computation itself does not read it.
    hidden_channels: the number of channels which is constant throughout the
                     processing of each unit. ``Z`` and ``H`` are expected to
                     carry this many channels.
    kernel_size: spatial size of every kernel except U_a, which is 1 x 1.
    padding: padding mode passed to every convolution. The gates are combined
             elementwise with ``H``, so this should stay 'same' unless the
             caller handles the resulting shape change.
    normtype: stored for reference only; ``call`` always uses
              ``instance_norm``.
    channel_sym: if truthy, ``channel_symmetrize`` mirrors the kernels
                 channel-wise; otherwise it is a no-op.
    use_attention: attention is not implemented; any truthy value raises
                   ``NotImplementedError``.
    **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    '''
    def __init__(self, input_size, hidden_channels, kernel_size=3, padding='same',
                 normtype='batchnorm', channel_sym=True, use_attention=0, **kwargs):
        super().__init__(**kwargs)
        self.input_size = input_size
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.normtype = normtype
        self.channel_sym = channel_sym

        if use_attention:
            # Fail at construction time instead of leaving the cell without
            # kernels, which would only surface later as an AttributeError.
            raise NotImplementedError(
                'fGRUCell attention is not implemented yet; use use_attention=0.')

        # Initialize convolutional kernels
        self.U_a = self._make_conv(self.hidden_channels, 1)      # channel-wise selection
        self.U_m = self._make_conv(1, self.kernel_size)          # spatial selection
        self.W_s = self._make_conv(self.hidden_channels, self.kernel_size)
        self.U_f = self._make_conv(self.hidden_channels, self.kernel_size)
        self.W_f = self._make_conv(self.hidden_channels, self.kernel_size)

        # initiate other weights
        self.alpha = tf.Variable(0.1, dtype='float32', name='alpha')
        self.mu = tf.Variable(0.0, dtype='float32', name='mu')
        self.nu = tf.Variable(0.0, dtype='float32', name='nu')
        self.omega = tf.Variable(0.1, dtype='float32', name='omega')
        # Per-channel scale, shift and variance offset used by instance_norm.
        self.delta = tf.Variable(np.zeros(self.hidden_channels) + 0.1,
                                 dtype='float32', name='delta')
        self.omicron = tf.Variable(np.zeros(self.hidden_channels),
                                   dtype='float32', name='omicron')
        self.eta = tf.Variable(np.random.rand(self.hidden_channels),
                               dtype='float32', name='eta')

    def _make_conv(self, filters, kernel_size):
        '''Build one stride-1, orthogonally initialized Conv2D sub-layer.'''
        # Use the public tf.keras namespace so the sub-layers belong to the
        # same Keras implementation as this layer's base class.
        return tf.keras.layers.Conv2D(
            filters=filters,
            kernel_size=kernel_size,
            strides=1,
            padding=self.padding,
            kernel_initializer=tf.keras.initializers.Orthogonal(),
            )

    def channel_symmetrize(self):
        '''
        symmetrize the kernels channel-wise

        For U_a, U_f, W_s and W_f, the weight from input channel i to output
        channel j is made equal to the weight from j to i by mirroring the
        lower triangle (i >= j) onto the upper triangle. U_m has a single
        output channel and is left untouched.

        Conv2D creates its kernel lazily when the layer is built, which is
        why the kernels are not available inside __init__. Any layer that has
        not been built yet is built here for ``hidden_channels`` input
        channels, so this method can be called before the first forward pass.

        This is a one-off assignment: later weight updates are not
        constrained to stay symmetric.
        '''
        if not self.channel_sym:
            return

        for conv in (self.U_a, self.U_f, self.W_s, self.W_f):
            if not conv.built:
                conv.build((None, None, None, self.hidden_channels))
            kernel = conv.kernel  # (k_h, k_w, in_channels, out_channels)
            # band_part works on the two innermost axes, i.e. the channel axes.
            lower = tf.linalg.band_part(kernel, -1, 0)
            diagonal = tf.linalg.band_part(kernel, 0, 0)
            mirrored = tf.transpose(lower - diagonal, perm=[0, 1, 3, 2])
            kernel.assign(lower + mirrored)

    def instance_norm(self, r):
        '''
        Normalize each sample and channel over its spatial axes.

        Param: r, a 4D tensor, b x h x w x c
        Return: a tensor normalized with the same size as r, computed as
                omicron + delta * (r - mean) / sqrt(var + eta), where mean and
                var are taken over h and w separately for every sample and
                channel.
        '''
        # TensorFlow ops keep the computation differentiable and graph
        # compatible, and handle any batch size.
        mean, variance = tf.nn.moments(r, axes=[1, 2], keepdims=True)
        return self.omicron + self.delta * (r - mean) / tf.sqrt(variance + self.eta)

    def call(self, Z, H):
        '''
        Params:
        Z: output from the last layer if fGRU-horizontal, hidden state of the
        current layer at t if fGRU-feedback.
        H: hidden state of the current layer at t-1 if fGRU-horizontal, output
        from the next layer if fGRU-feedback.

        Return: the updated recurrent state, same shape as H.
        '''
        # Bring both inputs to the parameter dtype so float64 NumPy arrays
        # can be mixed with the float32 weights.
        dtype = self.alpha.dtype
        Z = tf.cast(Z, dtype)
        H = tf.cast(H, dtype)

        # Stage 1: suppression
        A_s = self.U_a(H)  # Compute channel-wise selection
        M_s = self.U_m(H)  # Compute spatial selection
        # (note that U_a and U_m are kernels of different sizes and therefore
        # have different functions)
        # M_s has one channel and broadcasts across the channels of A_s.
        G_s = tf.sigmoid(self.instance_norm(A_s * M_s))
        # Compute suppression gate
        C_s = self.instance_norm(self.W_s(H * G_s))
        # compute suppression interactions
        S = tf.nn.relu(Z - tf.nn.relu((self.alpha * H + self.mu) * C_s))
        # Additive and multiplicative suppression of Z

        # Stage 2: facilitation
        G_f = tf.sigmoid(self.instance_norm(self.U_f(S)))
        # Compute channel-wise recurrent updates
        C_f = self.instance_norm(self.W_f(S))
        # Compute facilitation interactions
        H_tilda = tf.nn.relu(self.nu * (C_f + S) + self.omega * (C_f * S))
        # Additive and multiplicative facilitation of S
        Ht = (1 - G_f) * H + G_f * H_tilda
        # Update recurrent state
        return Ht

In [187]:
# Smoke test: one fGRU step on a random 4 x 4 map with 9 channels.
testCell = fGRUCell([4, 4], 9)

# Random drive of shape (1, 4, 4, 9) and an all-zero initial hidden state
# of the same shape.
image = np.random.random_sample((1, 4, 4, 9))
H = np.zeros_like(image)

out = testCell(image, H)

print(out)

tf.Tensor(
[[[[0.00000000e+00 0.00000000e+00 0.00000000e+00 2.15177788e-04
    0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
    3.59515652e-05]
   [6.59738085e-04 0.00000000e+00 0.00000000e+00 8.62883462e-04
    0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
    3.83849227e-04]
   [1.41518831e-03 0.00000000e+00 9.08419315e-04 1.93786621e-03
    0.00000000e+00 3.72788600e-05 0.00000000e+00 0.00000000e+00
    0.00000000e+00]
   [2.34267602e-04 0.00000000e+00 0.00000000e+00 3.25311703e-04
    9.83635400e-05 0.00000000e+00 0.00000000e+00 0.00000000e+00
    7.22595723e-05]]

  [[0.00000000e+00 2.20990842e-04 0.00000000e+00 9.99127369e-05
    0.00000000e+00 0.00000000e+00 1.47421297e-03 0.00000000e+00
    1.66495647e-05]
   [4.35309834e-04 8.06536351e-04 0.00000000e+00 0.00000000e+00
    7.03845406e-04 2.06682889e-04 4.72256419e-04 7.47388258e-05
    3.60420454e-05]
   [3.49085400e-04 0.00000000e+00 1.07597247e-04 1.51329592e-03
    0.00000000e+00 1.34279559e-04 0

In [188]:
# Symmetrize the channel axes of the U_a, U_f, W_s and W_f kernels in place.
# The kernels only exist once the cell has been called, so this runs after the forward pass above.
testCell.channel_symmetrize()

In [189]:
# Inspect W_f after symmetrization: for each spatial position, entry [i, j] of the
# trailing 9 x 9 channel matrix should equal entry [j, i].
testCell.W_f.kernel

<tf.Variable 'f_gru_cell_72/conv2d_345/kernel:0' shape=(3, 3, 9, 9) dtype=float32, numpy=
array([[[[ 1.46000743e-01, -1.56464670e-02,  2.11676043e-02,
           3.49971503e-02, -2.57256497e-02,  2.10961010e-02,
          -1.86283916e-01, -1.57420233e-01,  1.85161065e-02],
         [-1.56464670e-02, -6.94890227e-03,  1.78442687e-01,
           1.24136321e-01, -6.77111745e-02, -2.03279823e-01,
          -1.48666367e-01,  5.17323017e-02, -1.13463990e-01],
         [ 2.11676043e-02,  1.78442687e-01,  2.26456434e-01,
           2.51750618e-01,  1.30727559e-01,  1.06898747e-01,
          -1.01504643e-02,  5.55311516e-02, -6.97149113e-02],
         [ 3.49971503e-02,  1.24136321e-01,  2.51750618e-01,
           7.76124969e-02, -9.70551223e-02, -5.28101027e-02,
          -8.79670400e-03, -2.14207560e-01, -4.06872071e-02],
         [-2.57256497e-02, -6.77111745e-02,  1.30727559e-01,
          -9.70551223e-02,  2.42406409e-02, -4.65196855e-02,
          -4.25187349e-02,  1.77684814e-01,  3.93419