In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import layers
from tensorflow.python.keras import initializers

class InstanceNorm(tf.keras.layers.Layer):
    def __init__(self, hidden_channels):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.omicron = tf.Variable(np.zeros(self.hidden_channels), dtype='float32')
        self.eta = tf.Variable(np.random.rand(self.hidden_channels), dtype='float32')
        self.delta = tf.Variable(np.zeros(self.hidden_channels)+0.1, dtype='float32')

    def call(self, r):
        '''
        Param: r, a 4D tensor, b x h x w x c, where b = 1
        Return: a tensor normalized with the same size as r.
        '''                
        return tf.convert_to_tensor([self.omicron + self.delta * (r[0] - tf.math.reduce_mean(r[0], axis=(0, 1)))\
                         /(tf.math.sqrt(tf.math.reduce_variance(r[0], axis=(0, 1))+self.eta))])

class fGRU(tf.keras.layers.Layer):
    '''
    Generates an fGRUCell
    params:
    hidden_channels: the number of channels which is constant throughout the
                     processing of each unit
    '''
    def __init__(self, input_shape, kernel_size=3, padding='same', use_attention=0, channel_sym = False):
        # channel_sym assigned False for speed. Saves 30 seconds.

        super().__init__()
        self.hidden_channels = input_shape[-1]
        self.kernel_size = kernel_size
        self.padding = padding
        self.channel_sym = channel_sym
        self.use_attention = use_attention
        self.input_shape_ = input_shape

        if self.use_attention:
            # TODO: implement attention
            pass
        else:
            # Initialize convolutional kernels
            self.U_a = layers.Conv2D(
                filters=self.hidden_channels,
                kernel_size=1, 
                strides=1, 
                padding=self.padding,
                kernel_initializer=initializers.Orthogonal(),
                )
            
            self.U_m = layers.Conv2D(
                filters=1,
                kernel_size=self.kernel_size, 
                strides=1, 
                padding=self.padding,
                kernel_initializer=initializers.Orthogonal(),
                )
            
            self.W_s = layers.Conv2D(
                filters=self.hidden_channels,
                kernel_size=self.kernel_size, 
                strides=1, 
                padding=self.padding,
                kernel_initializer=initializers.Orthogonal(),
                )
            
            self.U_f = layers.Conv2D(
                filters=self.hidden_channels,
                kernel_size=self.kernel_size, 
                strides=1, 
                padding=self.padding,
                kernel_initializer=initializers.Orthogonal(),
                )
            
            self.W_f = layers.Conv2D(
                filters=self.hidden_channels,
                kernel_size=self.kernel_size, 
                strides=1, 
                padding=self.padding,
                kernel_initializer=initializers.Orthogonal(),
                )
        self.build(self.input_shape_)

        # initiate other weights
        self.alpha = tf.Variable(0.1, dtype='float32')
        self.mu = tf.Variable(0, dtype='float32')
        self.nu = tf.Variable(0, dtype='float32')
        self.omega = tf.Variable(0.1, dtype='float32')

    def channel_symmetrize(self):
        '''
        symmetrize the kernels channel-wise
        Somehow, if I write it in init, there will be the following error:
        'Conv2D' does not have attribute 'kernel'.
        '''
        if self.channel_sym: 
            for i in range(self.hidden_channels):
                for j in range(i, self.hidden_channels):
                    self.U_a.kernel[:,:,i,j].assign(self.U_a.kernel[:,:,j,i])
                    self.U_f.kernel[:,:,i,j].assign(self.U_f.kernel[:,:,j,i])
                    self.W_s.kernel[:,:,i,j].assign(self.W_s.kernel[:,:,j,i])
                    self.W_f.kernel[:,:,i,j].assign(self.W_f.kernel[:,:,j,i])

    def build(self, input_shape):
        self.U_a.build(input_shape)
        self.U_m.build(input_shape)
        self.U_f.build(input_shape)
        self.W_s.build(input_shape)
        self.W_f.build(input_shape)
        if self.channel_sym:
            self.channel_symmetrize()
        
        # initialize instance norm layers
        self.iN1 = InstanceNorm(self.hidden_channels)
        self.iN2 = InstanceNorm(self.hidden_channels)
        self.iN3 = InstanceNorm(self.hidden_channels)
        self.iN4 = InstanceNorm(self.hidden_channels)


    def call(self, z, h):
        '''
        Params: 
        Z: output from the last layer if fGRU-horizontal, hidden state of the
        current layer at t if fGRU-feedback.
        H: hidden state of the current layer at t-1 if fGRU-horizontal, output
        from the next layer if fGRU-feedback.
        '''

        # Stage 1: suppression
        a_s = self.U_a(h) # Compute channel-wise selection
        m_s = self.U_m(h) # Compute spatial selection
        # (note that U_a and U_m are kernels of different sizes and therefore
        # have different functions)

        m_s_expanded = tf.transpose(tf.convert_to_tensor([tf.transpose(m_s)[0]]*self.hidden_channels))
        g_s = tf.sigmoid(self.iN1(a_s * m_s_expanded))
        # Compute suppression gate
        c_s = self.iN2(self.W_s(h * g_s))
        # compute suppression interactions
        S = tf.keras.activations.relu(z - tf.keras.activations.relu((self.alpha * h + self.mu)*c_s))
        # Additive and multiplicative suppression of Z

        # Stage 2: facilitation
        g_f = tf.sigmoid(self.iN3(self.U_f(S)))
        # Compute channel-wise recurrent updates
        c_f = self.iN4(self.W_f(S))
        # Compute facilitation interactions
        h_tilda = tf.keras.activations.relu(self.nu*(c_f + S) + self.omega*(c_f * S))
        # Additive and multiplicative facilitation of S
        print('successfully went to h_tilda')
        ht = (1 - g_f) * h + g_f * h_tilda
        # Update recurrent state
        return ht

2023-05-07 23:13:53.884854: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
testCell = fGRU([2, 24, 24, 64])

image = np.random.rand(2, 24, 24, 64)
H = np.random.rand(2, 24, 24, 64)

out = testCell(image, H)

print(out)

successfully went to h_tilda
tf.Tensor(
[[[[0.17670536 0.05386674 0.08195007 ... 0.19052775 0.50975573
    0.41383624]
   [0.23866518 0.07530871 0.23960882 ... 0.4768745  0.32390657
    0.33147368]
   [0.45453075 0.19438295 0.25265265 ... 0.17036861 0.06937771
    0.34412903]
   ...
   [0.43183276 0.41736105 0.48195222 ... 0.16671437 0.11967142
    0.01121795]
   [0.36706245 0.04296055 0.03496937 ... 0.11565843 0.01695942
    0.25038895]
   [0.36113074 0.1747229  0.5105558  ... 0.46054152 0.37540382
    0.13217406]]

  [[0.34135628 0.4513697  0.16396527 ... 0.05932774 0.42069256
    0.04274375]
   [0.37440076 0.15505584 0.4341617  ... 0.08102697 0.27423388
    0.21334124]
   [0.22877148 0.47010973 0.03322793 ... 0.41577843 0.03750823
    0.14296333]
   ...
   [0.45461667 0.18557462 0.02443766 ... 0.06333341 0.04295073
    0.46256605]
   [0.40788773 0.50568837 0.12842602 ... 0.28507867 0.02897747
    0.09211819]
   [0.41338804 0.12329229 0.26920077 ... 0.29553512 0.35680038
    0.315370

In [6]:
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import layers
import fGRU

class GammaNetBlock(layers.Layer):
    '''
    Generate a block in gamma-net
    '''
    def __init__(self, batch_size, input_shape, layers_config):
        '''
        params: 
        input_channels: int, the number of input channels
        hidden_channels: int, the number of channels within the block
        layers: a list of tuples, specifying what kind of layers it contains:
                Conv2D: ('c', [kernel_size, strides])
                TransposedConv2D: ('t', [kernel_size, strides])
                fGru: ('f', [input_shape, use_attention])
                maxPool: ('m', [kernel_size, strides])
                instanceNorm: ('i')
        '''
        super().__init__()
        self.batch_size = batch_size

        self.input_shape_ = [batch_size, input_shape[0], input_shape[1], input_shape[2]]
        # here, input_shape_ is in [batch_size, height, width, channel_size]

        self.hidden_channels = self.input_shape_[-1]
        self.fgru = None
        self.hidden_state = None
        self.layers_config = layers_config
        self.layers = []

        for layer in layers_config:
        # populate the blocks with layers
            if layer[0] == 'c':
                kernel_size = layer[1][0]
                strides = layer[1][1]
                self.layers.append(layers.Conv2D(
                    filters=self.hidden_channels, 
                    kernel_size=kernel_size, 
                    strides=strides,
                    padding='same',
                    activation='ReLU'
                    ))
                
            elif layer[0] == 't':
                kernel_size = layer[1][0]
                strides = layer[1][1]
                self.layers.append(layers.Conv2DTranspose(
                    filters=self.hidden_channels, 
                    kernel_size=kernel_size, 
                    strides=strides,
                    padding='same' 
                    ))
                
            elif layer[0] == 'f':
                kernel_size = layer[1][0]
                use_attention = layer[1][1]
                self.fgru=fGRU.fGRU(
                    input_shape=self.input_shape_, 
                    kernel_size=kernel_size,
                    use_attention = use_attention,
                    )
                self.layers.append(self.fgru)

            elif layer[0] == 'm':
                pool_size = layer[1][0]
                strides = layer[1][1]
                self.layers.append(layers.MaxPool2D(
                    pool_size=pool_size, 
                    strides=strides,
                    padding='valid' 
                    ))
                
            elif layer[0] == 'i':
                self.layers.append(fGRU.InstanceNorm(self.hidden_channels))

    def call(self, x, h):
        z = x
        for layer in self.layers:
            if layer == self.fgru:
                print('successfully went to fgru') # for debugging
                z = layer(z, h)
                print('successfully got the output of fgru') # for debugging
                self.hidden_state = z
            else:
                z = layer(z)
        return z

# in each layer, there are three lists: 
# input shape (without batch_size), bottom-up unit, top-down unit (if any)
default_config = [
    [[384, 384, 24], 
     [('c', [3, 1]), ('c', [3, 1]), ('f', [9, False]), ('m', [2, 2])],
     [('t', [4, 2]),('c', [3, 1]),('i'),('f', [1, False])]], # first layer
    [[192, 192, 28], 
     [('c', [3, 1]), ('f', [7, False]), ('m', [2, 2])],
     [('t', [4, 2]),('c', [3, 1]),('i'),('f', [1, False])]], # second layer
    [[96, 96, 36], 
     [('c', [3, 1]), ('f', [5, False]), ('m', [2, 2])],
     [('t', [4, 2]),('c', [3, 1]),('i'),('f', [1, False])]], # third layer
    [[48, 48, 48], 
     [('c', [3, 1]), ('f', [3, False]), ('m', [2, 2])],
     [('t', [4, 2]),('c', [3, 1]),('i'),('f', [1, False])]], # forth layer
    [[24, 24, 64], 
     [('c', [3, 1]), ('f', [3, False])]], # fifth layer
    [[384, 384, 24], [('i'), ('c', [5, 1])]] # readout layer
    ]

class GammaNet(tf.keras.Model):
    '''
    Gamma-net class
    '''
    def __init__(self, batch_size=1, steps=1, blocks_config = default_config):
        super().__init__()
        self.batch_size = batch_size
        self.n_layers = len(blocks_config) - 1
        self.steps = steps
        self.blocks_config = blocks_config
        self.blocks = [] # stores gammanetblocks, number of items equals number 
                         # of layers, each layer contains one or two blocks.

        for i in range(self.n_layers + 1):
        # for all layers:
            block_config = self.blocks_config[i]
            input_shape = block_config[0]
            block = []
            for j in range(1, len(block_config)):
                block.append(GammaNetBlock(self.batch_size, input_shape, block_config[j]))
            self.blocks.append(block)

    def call(self, x):
        for _ in range(self.steps):
            z = x 
            # In the paper, this assignment appears before the time loop,
            # and updates z on the first layer with ReLU and Conv every time 
            # step. 
            # This doesn't make much sense, because at time t, the input
            # to the first layer would already gone through t-1 ReLU and Convs,
            # but when you consider human brain, every second comes a fresh image
            # from the very bottom of the visual path.
            for l in range(self.n_layers):
            # bottom-up
                if l == self.n_layers-1:
                    h = self.blocks[l][0].hidden_state
                else: h = self.blocks[l][1].hidden_state
                if h == None:
                # if no initial hidden_state, assign h as 0.
                # note that the input_shape_ of gammaNetBlock objects contains
                # batch_size aat the begginning already.
                    h = tf.zeros(self.blocks[l][0].input_shape_)
                z = self.blocks[l][0](z, h)
            
            for l in range(self.n_layers-2, -1, -1):
            # top-down
                h = self.blocks[l][0].hidden_state
                z = self.blocks[l][1](z, h)
        
        out = self.blocks[-1][0](z, None)
        print('went to final output') # for debugging
        return out

In [7]:
testNet = GammaNet()
print(testNet.blocks[0])

ListWrapper([<__main__.GammaNetBlock object at 0x7fd5d8ecfd00>, <__main__.GammaNetBlock object at 0x7fd5b9292d30>])


In [8]:
image = np.random.rand(1, 384, 384, 24)
out = testNet(image)
print(out)

successfully went to fgru
successfully went to h_tilda
successfully got the output of fgru
successfully went to fgru
successfully went to h_tilda
successfully got the output of fgru
successfully went to fgru
successfully went to h_tilda
successfully got the output of fgru
successfully went to fgru
successfully went to h_tilda
successfully got the output of fgru
successfully went to fgru
successfully went to h_tilda
successfully got the output of fgru
successfully went to fgru
successfully went to h_tilda
successfully got the output of fgru
successfully went to fgru
successfully went to h_tilda
successfully got the output of fgru
successfully went to fgru
successfully went to h_tilda
successfully got the output of fgru
successfully went to fgru
successfully went to h_tilda
successfully got the output of fgru
went to final output
tf.Tensor(
[[[[2.55350574e-06 6.92129902e-07 6.59590205e-06 ... 0.00000000e+00
    9.53527888e-06 0.00000000e+00]
   [8.18234548e-06 8.07233209e-06 0.00000000e+

In [9]:
testNet.build([1, 384, 384, 24])
testNet.compile(optimizer=tf.optimizers.Adam(), loss='mse')
testNet.summary()

successfully went to fgru
successfully went to h_tilda
successfully got the output of fgru
successfully went to fgru
successfully went to h_tilda
successfully got the output of fgru
successfully went to fgru
successfully went to h_tilda
successfully got the output of fgru
successfully went to fgru
successfully went to h_tilda
successfully got the output of fgru
successfully went to fgru
successfully went to h_tilda
successfully got the output of fgru
successfully went to fgru
successfully went to h_tilda
successfully got the output of fgru
successfully went to fgru
successfully went to h_tilda
successfully got the output of fgru
successfully went to fgru
successfully went to h_tilda
successfully got the output of fgru
successfully went to fgru
successfully went to h_tilda
successfully got the output of fgru
went to final output
Model: "gamma_net"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
Total par

In [10]:
# "tensor cannot be accessed" in gradient descent.
testNet.fit(np.random.rand(1, 384, 384, 24), np.random.rand(1, 384, 384, 24), epochs=2, batch_size=1)

Epoch 1/2
successfully went to fgru
successfully went to h_tilda
successfully got the output of fgru
successfully went to fgru
successfully went to h_tilda
successfully got the output of fgru
successfully went to fgru
successfully went to h_tilda
successfully got the output of fgru
successfully went to fgru
successfully went to h_tilda
successfully got the output of fgru
successfully went to fgru
successfully went to h_tilda
successfully got the output of fgru
successfully went to fgru
successfully went to h_tilda
successfully got the output of fgru
successfully went to fgru
successfully went to h_tilda
successfully got the output of fgru
successfully went to fgru
successfully went to h_tilda
successfully got the output of fgru
successfully went to fgru
successfully went to h_tilda
successfully got the output of fgru
went to final output
successfully went to fgru
successfully went to h_tilda
successfully got the output of fgru
successfully went to fgru
successfully went to h_tilda
succ

TypeError: <tf.Tensor 'gamma_net/gamma_net_block_1/f_gru_2/add_3:0' shape=(1, 384, 384, 24) dtype=float32> is out of scope and cannot be used here. Use return values, explicit Python locals or TensorFlow collections to access it.
Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.

<tf.Tensor 'gamma_net/gamma_net_block_1/f_gru_2/add_3:0' shape=(1, 384, 384, 24) dtype=float32> was defined here:
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/runpy.py", line 194, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/ipykernel_launcher.py", line 16, in <module>
      app.launch_new_instance()
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/traitlets/config/application.py", line 846, in launch_instance
      app.start()
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/ipykernel/kernelapp.py", line 677, in start
      self.io_loop.start()
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tornado/platform/asyncio.py", line 199, in start
      self.asyncio_loop.run_forever()
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/asyncio/base_events.py", line 570, in run_forever
      self._run_once()
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/asyncio/base_events.py", line 1859, in _run_once
      handle._run()
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/asyncio/events.py", line 81, in _run
      self._context.run(self._callback, *self._args)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 471, in dispatch_queue
      await self.process_one()
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 460, in process_one
      await dispatch(*args)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 367, in dispatch_shell
      await result
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 662, in execute_request
      reply_content = await reply_content
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/ipykernel/ipkernel.py", line 360, in do_execute
      res = shell.run_cell(code, store_history=store_history, silent=silent)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/ipykernel/zmqshell.py", line 532, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2863, in run_cell
      result = self._run_cell(
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2909, in _run_cell
      return runner(coro)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3106, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3309, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3369, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/var/folders/t4/7j6s3sk94lg03x70m0lzjrdm0000gn/T/ipykernel_59327/1794143762.py", line 1, in <cell line: 1>
      testNet.fit(np.random.rand(1, 384, 384, 24), np.random.rand(1, 384, 384, 24), epochs=2, batch_size=1)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/keras/src/engine/training.py", line 1742, in fit
      tmp_logs = self.train_function(iterator)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler
      return fn(*args, **kwargs)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 840, in __call__
      result = self._call(*args, **kwds)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 888, in _call
      self._initialize(args, kwds, add_initializers_to=initializers)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 709, in _initialize
      self._variable_creation_fn    # pylint: disable=protected-access
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 176, in _get_concrete_function_internal_garbage_collected
      concrete_function, _ = self._maybe_define_concrete_function(args, kwargs)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 171, in _maybe_define_concrete_function
      return self._maybe_define_function(args, kwargs)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 401, in _maybe_define_function
      concrete_function = self._create_concrete_function(
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 305, in _create_concrete_function
      func_graph_module.func_graph_from_py_func(
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 1094, in func_graph_from_py_func
      func_outputs = python_func(*func_args, **func_kwargs)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 613, in wrapped_fn
      out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 1069, in autograph_handler
      return autograph.converted_call(
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/keras/src/engine/training.py", line 1338, in train_function
      return step_function(self, iterator)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/keras/src/engine/training.py", line 1322, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py", line 1671, in run
      return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py", line 3248, in call_for_each_replica
      return self._call_for_each_replica(fn, args, kwargs)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py", line 4046, in _call_for_each_replica
      return fn(*args, **kwargs)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/keras/src/engine/training.py", line 1303, in run_step
      outputs = model.train_step(data)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/keras/src/engine/training.py", line 1080, in train_step
      y_pred = self(x, training=True)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/keras/src/engine/training.py", line 569, in __call__
      return super().__call__(*args, **kwargs)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/keras/src/engine/base_layer.py", line 1150, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "/var/folders/t4/7j6s3sk94lg03x70m0lzjrdm0000gn/T/ipykernel_59327/2511607014.py", line 141, in call
      for _ in range(self.steps):
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 454, in for_stmt
      for_fn(iter_, extra_test, body, get_state, set_state, symbol_names, opts)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 505, in _py_for_stmt
      body(target)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 471, in protected_body
      original_body(protected_iter)
    File "/var/folders/t4/7j6s3sk94lg03x70m0lzjrdm0000gn/T/ipykernel_59327/2511607014.py", line 162, in call
      for l in range(self.n_layers-2, -1, -1):
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 454, in for_stmt
      for_fn(iter_, extra_test, body, get_state, set_state, symbol_names, opts)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 505, in _py_for_stmt
      body(target)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 471, in protected_body
      original_body(protected_iter)
    File "/var/folders/t4/7j6s3sk94lg03x70m0lzjrdm0000gn/T/ipykernel_59327/2511607014.py", line 165, in call
      z = self.blocks[l][1](z, h)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py", line 1045, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/var/folders/t4/7j6s3sk94lg03x70m0lzjrdm0000gn/T/ipykernel_59327/2511607014.py", line 81, in call
      for layer in self.layers:
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 454, in for_stmt
      for_fn(iter_, extra_test, body, get_state, set_state, symbol_names, opts)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 505, in _py_for_stmt
      body(target)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 471, in protected_body
      original_body(protected_iter)
    File "/var/folders/t4/7j6s3sk94lg03x70m0lzjrdm0000gn/T/ipykernel_59327/2511607014.py", line 82, in call
      if layer == self.fgru:
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1269, in if_stmt
      _py_if_stmt(cond, body, orelse)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1322, in _py_if_stmt
      return body() if cond else orelse()
    File "/var/folders/t4/7j6s3sk94lg03x70m0lzjrdm0000gn/T/ipykernel_59327/2511607014.py", line 84, in call
      z = layer(z, h)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/keras/src/engine/base_layer.py", line 1150, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "/Users/yixiangsun/Brown/2023spring/CSCI1952Q/RecurrentFeedbackCNN/fGRU.py", line 151, in call
      ht = (1 - g_f) * h + g_f * h_tilda
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler
      return fn(*args, **kwargs)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/ops/math_ops.py", line 1466, in binary_op_wrapper
      return func(x, y, name=name)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler
      return fn(*args, **kwargs)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py", line 1176, in op_dispatch_handler
      return dispatch_target(*args, **kwargs)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/ops/math_ops.py", line 1837, in _add_dispatch
      return gen_math_ops.add_v2(x, y, name=name)
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/ops/gen_math_ops.py", line 475, in add_v2
      _, _, _op, _outputs = _op_def_library._apply_op_helper(
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/framework/op_def_library.py", line 795, in _apply_op_helper
      op = g._create_op_internal(op_type_name, inputs, dtypes=None,
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 670, in _create_op_internal
      return super()._create_op_internal(  # pylint: disable=protected-access
    File "/Users/yixiangsun/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 3383, in _create_op_internal
      ret = Operation.from_node_def(

The tensor <tf.Tensor 'gamma_net/gamma_net_block_1/f_gru_2/add_3:0' shape=(1, 384, 384, 24) dtype=float32> cannot be accessed from here, because it was defined in FuncGraph(name=train_function, id=140554869249648), which is out of scope.