In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import layers
from tensorflow.python.keras import initializers

class InstanceNorm(tf.keras.layers.Layer):
    def __init__(self, hidden_channels):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.omicron = tf.Variable(np.zeros(self.hidden_channels), dtype='float32')
        self.eta = tf.Variable(np.random.rand(self.hidden_channels), dtype='float32')
        self.delta = tf.Variable(np.zeros(self.hidden_channels)+0.1, dtype='float32')

    def call(self, r):
        '''
        Param: r, a 4D tensor, b x h x w x c, where b = 1
        Return: a tensor normalized with the same size as r.
        '''                
        return tf.convert_to_tensor([self.omicron + self.delta * (r[0] - tf.math.reduce_mean(r[0], axis=(0, 1)))\
                         /(tf.math.sqrt(tf.math.reduce_variance(r[0], axis=(0, 1))+self.eta))])

class fGRU(tf.keras.layers.Layer):
    '''
    Generates an fGRUCell
    params:
    hidden_channels: the number of channels which is constant throughout the
                     processing of each unit
    '''
    def __init__(self, input_shape, kernel_size=3, padding='same', use_attention=0, channel_sym = False):
        # channel_sym assigned False for speed. Saves 30 seconds.

        super().__init__()
        self.hidden_channels = input_shape[-1]
        self.kernel_size = kernel_size
        self.padding = padding
        self.channel_sym = channel_sym
        self.use_attention = use_attention
        self.input_shape_ = input_shape

        if self.use_attention:
            # TODO: implement attention
            pass
        else:
            # Initialize convolutional kernels
            self.U_a = layers.Conv2D(
                filters=self.hidden_channels,
                kernel_size=1, 
                strides=1, 
                padding=self.padding,
                kernel_initializer=initializers.Orthogonal(),
                )
            
            self.U_m = layers.Conv2D(
                filters=1,
                kernel_size=self.kernel_size, 
                strides=1, 
                padding=self.padding,
                kernel_initializer=initializers.Orthogonal(),
                )
            
            self.W_s = layers.Conv2D(
                filters=self.hidden_channels,
                kernel_size=self.kernel_size, 
                strides=1, 
                padding=self.padding,
                kernel_initializer=initializers.Orthogonal(),
                )
            
            self.U_f = layers.Conv2D(
                filters=self.hidden_channels,
                kernel_size=self.kernel_size, 
                strides=1, 
                padding=self.padding,
                kernel_initializer=initializers.Orthogonal(),
                )
            
            self.W_f = layers.Conv2D(
                filters=self.hidden_channels,
                kernel_size=self.kernel_size, 
                strides=1, 
                padding=self.padding,
                kernel_initializer=initializers.Orthogonal(),
                )
        self.build(self.input_shape_)

        # initiate other weights
        self.alpha = tf.Variable(0.1, dtype='float32')
        self.mu = tf.Variable(0, dtype='float32')
        self.nu = tf.Variable(0, dtype='float32')
        self.omega = tf.Variable(0.1, dtype='float32')

    def channel_symmetrize(self):
        '''
        symmetrize the kernels channel-wise
        Somehow, if I write it in init, there will be the following error:
        'Conv2D' does not have attribute 'kernel'.
        '''
        if self.channel_sym: 
            for i in range(self.hidden_channels):
                for j in range(i, self.hidden_channels):
                    self.U_a.kernel[:,:,i,j].assign(self.U_a.kernel[:,:,j,i])
                    self.U_f.kernel[:,:,i,j].assign(self.U_f.kernel[:,:,j,i])
                    self.W_s.kernel[:,:,i,j].assign(self.W_s.kernel[:,:,j,i])
                    self.W_f.kernel[:,:,i,j].assign(self.W_f.kernel[:,:,j,i])

    def build(self, input_shape):
        self.U_a.build(input_shape)
        self.U_m.build(input_shape)
        self.U_f.build(input_shape)
        self.W_s.build(input_shape)
        self.W_f.build(input_shape)
        if self.channel_sym:
            self.channel_symmetrize()
        
        # initialize instance norm layers
        self.iN1 = InstanceNorm(self.hidden_channels)
        self.iN2 = InstanceNorm(self.hidden_channels)
        self.iN3 = InstanceNorm(self.hidden_channels)
        self.iN4 = InstanceNorm(self.hidden_channels)


    def call(self, z, h):
        '''
        Params: 
        Z: output from the last layer if fGRU-horizontal, hidden state of the
        current layer at t if fGRU-feedback.
        H: hidden state of the current layer at t-1 if fGRU-horizontal, output
        from the next layer if fGRU-feedback.
        '''

        # Stage 1: suppression
        a_s = self.U_a(h) # Compute channel-wise selection
        m_s = self.U_m(h) # Compute spatial selection
        # (note that U_a and U_m are kernels of different sizes and therefore
        # have different functions)

        m_s_expanded = tf.transpose(tf.convert_to_tensor([tf.transpose(m_s)[0]]*self.hidden_channels))
        g_s = tf.sigmoid(self.iN1(a_s * m_s_expanded))
        # Compute suppression gate
        c_s = self.iN2(self.W_s(h * g_s))
        # compute suppression interactions
        S = tf.keras.activations.relu(z - tf.keras.activations.relu((self.alpha * h + self.mu)*c_s))
        # Additive and multiplicative suppression of Z

        # Stage 2: facilitation
        g_f = tf.sigmoid(self.iN3(self.U_f(S)))
        # Compute channel-wise recurrent updates
        c_f = self.iN4(self.W_f(S))
        # Compute facilitation interactions
        h_tilda = tf.keras.activations.relu(self.nu*(c_f + S) + self.omega*(c_f * S))
        # Additive and multiplicative facilitation of S
        ht = (1 - g_f) * h + g_f * h_tilda
        # Update recurrent state
        return ht

In [2]:
testCell = fGRU([2, 24, 24, 64])

image = np.random.rand(2, 24, 24, 64)
H = np.random.rand(2, 24, 24, 64)

out = testCell(image, H)

print(out)

2023-05-09 00:51:43.846934: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


tf.Tensor(
[[[[0.07984334 0.11455605 0.163245   ... 0.17957237 0.0634892
    0.4242353 ]
   [0.01165117 0.3060048  0.45316622 ... 0.0485071  0.32089683
    0.4126587 ]
   [0.3190008  0.18126743 0.1159981  ... 0.08572371 0.24778576
    0.3925361 ]
   ...
   [0.4006103  0.29275915 0.37584186 ... 0.4716049  0.462499
    0.35205117]
   [0.02333889 0.20960535 0.43413654 ... 0.4586685  0.16885658
    0.3135927 ]
   [0.33670387 0.33334872 0.18522047 ... 0.2495065  0.07325691
    0.3886185 ]]

  [[0.05987698 0.35893238 0.18363668 ... 0.04022368 0.12781082
    0.13562593]
   [0.07062237 0.3909683  0.48119456 ... 0.1932039  0.02582544
    0.15836316]
   [0.00128846 0.27886367 0.21333492 ... 0.446795   0.17821662
    0.4759597 ]
   ...
   [0.28442866 0.17637363 0.49467507 ... 0.39762995 0.22101052
    0.1840567 ]
   [0.35709444 0.37761354 0.04195167 ... 0.23620974 0.21628556
    0.17500621]
   [0.21945629 0.07619277 0.3555263  ... 0.24796298 0.16033842
    0.21940245]]

  [[0.41433257 0.48755026 

In [64]:
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import layers
import fGRU

class GammaNetBlock(layers.Layer):
    '''
    Generate a block in gamma-net
    '''
    def __init__(self, batch_size, input_shape, layers_config):
        '''
        params: 
        input_channels: int, the number of input channels
        hidden_channels: int, the number of channels within the block
        layers: a list of tuples, specifying what kind of layers it contains:
                Conv2D: ('c', [kernel_size, strides])
                TransposedConv2D: ('t', [kernel_size, strides])
                fGru: ('f', [input_shape, use_attention])
                maxPool: ('m', [kernel_size, strides])
                instanceNorm: ('i')
                denselayer: ('d', [unit])
        '''
        super().__init__()
        self.batch_size = batch_size

        self.input_shape_ = [batch_size, input_shape[0], input_shape[1], input_shape[2]]
        # here, input_shape_ is in [batch_size, height, width, channel_size]

        self.hidden_channels = self.input_shape_[-1]
        self.fgru = None
        self.hidden_state = tf.Variable(tf.zeros(self.input_shape_), trainable=False)
        self.layers_config = layers_config
        self.layers = []

        for layer in layers_config:
        # populate the blocks with layers
            if layer[0] == 'c':
                kernel_size = layer[1][0]
                strides = layer[1][1]
                self.layers.append(layers.Conv2D(
                    filters=self.hidden_channels, 
                    kernel_size=kernel_size, 
                    strides=strides,
                    padding='same',
                    activation='ReLU'
                    ))
                
            elif layer[0] == 't':
                kernel_size = layer[1][0]
                strides = layer[1][1]
                self.layers.append(layers.Conv2DTranspose(
                    filters=self.hidden_channels, 
                    kernel_size=kernel_size, 
                    strides=strides,
                    padding='same' 
                    ))
                
            elif layer[0] == 'f':
                kernel_size = layer[1][0]
                use_attention = layer[1][1]
                self.fgru=fGRU.fGRU(
                    input_shape=self.input_shape_, 
                    kernel_size=kernel_size,
                    use_attention = use_attention,
                    )
                self.layers.append(self.fgru)

            elif layer[0] == 'm':
                pool_size = layer[1][0]
                strides = layer[1][1]
                self.layers.append(layers.MaxPool2D(
                    pool_size=pool_size, 
                    strides=strides,
                    padding='valid' 
                    ))
                
            elif layer[0] == 'i':
                self.layers.append(fGRU.InstanceNorm(self.hidden_channels))

            elif layer[0] == 'd':
                unit = layer[1][0]
                if layer[1][1] == 's':
                    self.layers.append(layers.Dense(unit, activation='softmax'))
                elif layer[1][1] == 'l':
                    self.layers.append(layers.Dense(unit, activation='leaky_relu'))

            elif layer[0] == 'l':
                self.layers.append(layers.Flatten())

    def call(self, x, h):
        z = x
        for layer in self.layers:
            if layer == self.fgru:
                z = layer(z, h)
                self.hidden_state.assign(z)
            else:
                z = layer(z)
        return z

# in each layer, there are three lists: 
# input shape (without batch_size), bottom-up unit, top-down unit (if any)
default_config = [
    [[384, 384, 24], 
     [('c', [3, 1]), ('c', [3, 1]), ('f', [9, False]), ('m', [2, 2])],
     [('t', [4, 2]),('c', [3, 1]),('i'),('f', [1, False])]], # first layer
    [[192, 192, 28], 
     [('c', [3, 1]), ('f', [7, False]), ('m', [2, 2])],
     [('t', [4, 2]),('c', [3, 1]),('i'),('f', [1, False])]], # second layer
    [[96, 96, 36], 
     [('c', [3, 1]), ('f', [5, False]), ('m', [2, 2])],
     [('t', [4, 2]),('c', [3, 1]),('i'),('f', [1, False])]], # third layer
    [[48, 48, 48], 
     [('c', [3, 1]), ('f', [3, False]), ('m', [2, 2])],
     [('t', [4, 2]),('c', [3, 1]),('i'),('f', [1, False])]], # forth layer
    [[24, 24, 64], 
     [('c', [3, 1]), ('f', [3, False])]], # fifth layer
    [[384, 384, 24], [('i'), ('c', [5, 1])]] # readout layer
    ]

class GammaNet(tf.keras.Model):
    '''
    Gamma-net class
    '''
    def __init__(self, batch_size=1, steps=1, blocks_config = default_config, mode='segmentation'):
        super().__init__()
        self.batch_size = batch_size
        self.n_layers = len(blocks_config) - 1
        self.steps = steps
        self.blocks_config = blocks_config
        self.blocks = [] # stores gammanetblocks, number of items equals number 
                         # of layers, each layer contains one or two blocks.
        self.mode = mode

        for i in range(self.n_layers + 1):
        # for all layers:
            block_config = self.blocks_config[i]
            input_shape = block_config[0]
            block = []
            for j in range(1, len(block_config)):
                block.append(GammaNetBlock(self.batch_size, input_shape, block_config[j]))
            self.blocks.append(block)
            
    def get_output(self, x):
        for _ in range(self.steps):
            z = x 
            # In the paper, this assignment appears before the time loop,
            # and updates z on the first layer with ReLU and Conv every time 
            # step. 
            # This doesn't make much sense, because at time t, the input
            # to the first layer would already gone through t-1 ReLU and Convs,
            # but when you consider human brain, every second comes a fresh image
            # from the very bottom of the visual path.
            for l in range(self.n_layers):
            # bottom-up
                if l == self.n_layers-1:
                    h = self.blocks[l][0].hidden_state
                else: h = self.blocks[l][1].hidden_state
                if h == None:
                # if no initial hidden_state, assign h as 0.
                # note that the input_shape_ of gammaNetBlock objects contains
                # batch_size aat the begginning already.
                    h = tf.zeros(self.blocks[l][0].input_shape_)
                z = self.blocks[l][0](z, h)
            
            for l in range(self.n_layers-2, -1, -1):
            # top-down
                h = self.blocks[l][0].hidden_state
                z = self.blocks[l][1](z, h)

    def call(self, x):
        for _ in range(self.steps):
            z = x 
            # In the paper, this assignment appears before the time loop,
            # and updates z on the first layer with ReLU and Conv every time 
            # step. 
            # This doesn't make much sense, because at time t, the input
            # to the first layer would already gone through t-1 ReLU and Convs,
            # but when you consider human brain, every second comes a fresh image
            # from the very bottom of the visual path.
            for l in range(self.n_layers):
            # bottom-up
                if l == self.n_layers-1:
                    pos = 0
                else: pos = 1
                h = tf.zeros(self.blocks[l][pos].input_shape_)
                z = self.blocks[l][0](z, h)
            
            for l in range(self.n_layers-2, -1, -1):
            # top-down
                h = self.blocks[l][0].hidden_state
                z = self.blocks[l][1](z, h)
        if self.mode == 'segmentation':
            out = self.blocks[-1][0](z, None)
        elif self.mode == 'classification': 
            out = self.blocks[-1][0](self.blocks[self.n_layers-1][0].hidden_state, None)
        return out

In [39]:
testNet = GammaNet()

In [40]:
image = np.random.rand(1, 384, 384, 24)
out = testNet(image)
print(out)

tf.Tensor(
[[[[0.00000000e+00 4.32764318e-05 8.10029633e-07 ... 0.00000000e+00
    0.00000000e+00 4.34857648e-05]
   [0.00000000e+00 5.87533323e-05 0.00000000e+00 ... 2.43774048e-05
    0.00000000e+00 4.58322975e-05]
   [0.00000000e+00 1.41478358e-05 9.41691724e-07 ... 0.00000000e+00
    1.75382102e-05 4.51710512e-05]
   ...
   [0.00000000e+00 0.00000000e+00 1.10112251e-06 ... 0.00000000e+00
    2.31475042e-05 1.80980605e-05]
   [0.00000000e+00 1.63596487e-05 0.00000000e+00 ... 0.00000000e+00
    9.18468140e-06 7.59661179e-06]
   [7.10298173e-06 1.40269894e-05 0.00000000e+00 ... 0.00000000e+00
    0.00000000e+00 0.00000000e+00]]

  [[0.00000000e+00 1.18805656e-05 0.00000000e+00 ... 5.79135594e-05
    0.00000000e+00 5.49850993e-05]
   [0.00000000e+00 0.00000000e+00 0.00000000e+00 ... 5.22956398e-05
    0.00000000e+00 2.55932973e-05]
   [0.00000000e+00 4.80036288e-05 0.00000000e+00 ... 4.30816144e-05
    0.00000000e+00 9.49998503e-05]
   ...
   [3.36376002e-06 0.00000000e+00 0.00000000e+

In [41]:
testCell = fGRU.fGRU([2, 24, 24, 64])

image = np.random.rand(2, 24, 24, 64)
H = np.random.rand(2, 24, 24, 64)

out = testCell(image, H)

print(out)

tf.Tensor(
[[[[3.66100252e-01 2.87641823e-01 2.52678208e-02 ... 4.29280251e-01
    3.48325878e-01 1.40824050e-01]
   [1.43614843e-01 3.41969699e-01 6.24147616e-02 ... 3.15974355e-01
    4.14131135e-01 7.29575008e-02]
   [2.32048586e-01 2.33579293e-01 4.46975529e-01 ... 3.33756924e-01
    4.10192937e-01 2.35022187e-01]
   ...
   [7.10494742e-02 1.25166088e-01 3.13693613e-01 ... 3.57219875e-02
    4.25390750e-01 4.12459195e-01]
   [3.81193757e-01 4.34275091e-01 3.60676289e-01 ... 1.59671694e-01
    1.57082990e-01 1.70408547e-01]
   [3.07462156e-01 3.64185236e-02 4.73077059e-01 ... 4.07286823e-01
    2.90528893e-01 4.81497020e-01]]

  [[4.41718400e-01 4.30781811e-01 1.14224508e-01 ... 9.48649645e-03
    2.93744475e-01 2.43767917e-01]
   [1.74163237e-01 3.92020971e-01 6.54795468e-02 ... 3.65524381e-01
    7.87539110e-02 4.44264084e-01]
   [2.95827091e-01 4.03627694e-01 8.12440887e-02 ... 1.24102861e-01
    4.53640610e-01 4.17754203e-01]
   ...
   [4.13201094e-01 2.87497520e-01 3.99560213e-

In [42]:
testNet.build([1, 384, 384, 24])
testNet.compile(optimizer=tf.optimizers.Adam(), loss='mse')
testNet.summary()

Model: "gamma_net_12"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
Total params: 14,400,117
Trainable params: 797,301
Non-trainable params: 13,602,816
_________________________________________________________________


In [69]:
from types import SimpleNamespace

import numpy as np
import tensorflow as tf

## Run functions eagerly to allow numpy conversions.
## Enable experimental debug mode to suppress warning (feel free to remove second line)
tf.config.run_functions_eagerly(True)
tf.data.experimental.enable_debug_mode()


###############################################################################################


def get_data():
    """
    Loads CIFAR10 training and testing datasets

    :return X0: training images,
            Y0: training labels,
            X1: testing images,
            Y1: testing labels
            D0: TF Dataset training subset
            D1: TF Dataset testing subset
        D_info: TF Dataset metadata
    """

    ## This process may take a bit to load the first time; should get much faster
    import tensorflow_datasets as tfds

    ## Overview of dataset downloading: https://www.tensorflow.org/datasets/catalog/overview
    ## CIFAR-10 Dataset https://www.tensorflow.org/datasets/catalog/cifar10
    (D0, D1), D_info = tfds.load(
        "cifar10", as_supervised=True, split=["train[:50%]", "test"], with_info=True
    )

    X0, X1 = [np.array([r[0] for r in tfds.as_numpy(D)]) for D in (D0, D1)]
    Y0, Y1 = [np.array([r[1] for r in tfds.as_numpy(D)]) for D in (D0, D1)]

    return X0, Y0, X1, Y1, D0, D1, D_info

X0, Y0, X1, Y1, D0, D1, D_info = get_data()

In [70]:
classification_config = [
    [[32, 32, 3], 
     [('c', [3, 1]), ('c', [3, 1]), ('f', [9, False]), ('m', [2, 2])],
     [('t', [4, 2]),('c', [3, 1]),('i'),('f', [1, False])]], # first layer
    [[16, 16, 5], 
     [('c', [3, 1]), ('f', [7, False]), ('m', [2, 2])],
     [('t', [4, 2]),('c', [3, 1]),('i'),('f', [1, False])]], # second layer
    [[8, 8, 10], 
     [('c', [3, 1]), ('f', [5, False]), ('m', [2, 2])],
     [('t', [4, 2]),('c', [3, 1]),('i'),('f', [1, False])]], # third layer
    [[4, 4, 20], 
     [('c', [3, 1]), ('f', [3, False])],], # forth layer # fifth layer
    [[4, 4, 20], [('i'), ('l'), ('d', [100, 'l']), ('d', [10, 's'])]] # readout layer
    ]

In [71]:
model = GammaNet(steps = 4, blocks_config = classification_config, mode='classification')
model.build([1, 32, 32, 3])
model.compile(optimizer=tf.optimizers.Adam(), loss=tf.keras.losses.CategoricalCrossentropy)

In [74]:
model.summary()

Model: "gamma_net_8"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
Total params: 79,017
Trainable params: 68,393
Non-trainable params: 10,624
_________________________________________________________________


In [72]:
Y0 = np.array(tf.cast(Y0, 'float32'))
Y1 = np.array(tf.cast(Y1, 'float32'))
X0 = np.array(tf.cast(X0, 'float32'))
X1 = np.array(tf.cast(X1, 'float32'))

In [73]:
print(X0.shape)
output_prep_fn = tf.keras.layers.CategoryEncoding(
        num_tokens=10, output_mode="one_hot"
    )

Y0 = output_prep_fn(Y0)
print(Y0.shape)
history = model.fit(X0, Y0, batch_size=1, epochs=2, validation_data = (X1, output_prep_fn(Y1)))

(25000, 32, 32, 3)
(25000, 10)
Epoch 1/2


ValueError: Attempt to convert a value (<keras.losses.CategoricalCrossentropy object at 0x16c55c760>) with an unsupported type (<class 'keras.losses.CategoricalCrossentropy'>) to a Tensor.