In [None]:
import tensorflow as tf
from tensorflow.keras import layers
class RecurrentConvNextTiny(tf.keras.Model):
  """Per-image ConvNeXt-Tiny features fed through a GRU and a dense classifier.

  Every slice along axis 1 of the input is passed through an ImageNet-pretrained
  ConvNeXt-Tiny backbone (no classification head), flattened, and stacked back
  into a sequence. A GRU consumes that sequence and two ReLU dense layers plus
  a 5-way softmax layer produce the final prediction.
  """

  def __init__(self):
    super().__init__()
    # include_preprocessing=True keeps the backbone's own input preprocessing
    # inside the model, so callers do not apply a separate preprocessing step.
    self.backbone = tf.keras.applications.convnext.ConvNeXtTiny(
        model_name='convnext_tiny',
        include_top=False,
        include_preprocessing=True,
        weights='imagenet',
        input_tensor=None,
        input_shape=(224, 224, 3)
    )
    # Built once here instead of being re-created on every loop iteration.
    self.flatten = layers.Flatten()

    self.gru = layers.GRU(128, activation='tanh')
    self.dense1 = layers.Dense(64, activation="relu")
    self.dense2 = layers.Dense(64, activation="relu")

    self.output_layer = layers.Dense(5, activation="softmax")

  def call(self, inputs, training=False):
    """Run the forward pass.

    Implemented as `call` (not `__call__`) so that Keras' own
    `Model.__call__` wrapper still runs around it.

    Args:
      inputs: Tensor indexed as `inputs[:, i]`, where each slice is handed to
        the backbone declared with `input_shape=(224, 224, 3)`. The size of
        axis 1 must be statically known because it drives a Python loop.
      training: Forwarded to the backbone and the GRU so that any
        training-dependent behaviour follows the caller's mode.

    Returns:
      The output of the final `Dense(5, activation="softmax")` layer.
    """
    recurrent_inputs = []
    for i in range(inputs.shape[1]):
      image_features = self.backbone(inputs[:, i], training=training)
      recurrent_inputs.append(self.flatten(image_features))

    # Re-assemble the per-slice feature vectors along axis 1 for the GRU.
    rnn_input = tf.stack(recurrent_inputs, axis=1)

    x = self.gru(rnn_input, training=training)
    x = self.dense1(x)
    x = self.dense2(x)
    return self.output_layer(x)

In [None]:
# Instantiate the model; the constructor builds ConvNeXt-Tiny with weights='imagenet'.
model = RecurrentConvNextTiny()

In [None]:
# Random stand-in batch: 7 sequences of 5 images, each 224x224 with 3 channels.
sample_input = tf.random.normal(shape=(7, 5, 224, 224, 3))

In [None]:
# Smoke test: run inference on the random batch; the displayed result has one
# row of 5 softmax outputs per sequence in the batch.
model.predict(sample_input)



array([[0.271167  , 0.33411783, 0.13445558, 0.1783922 , 0.08186743],
       [0.26734263, 0.31150568, 0.14871182, 0.18566348, 0.0867764 ],
       [0.2779697 , 0.31962883, 0.12939364, 0.18865162, 0.08435615],
       [0.2533247 , 0.3211682 , 0.15052849, 0.18983278, 0.08514585],
       [0.25616467, 0.32577506, 0.14166562, 0.19265145, 0.08374324],
       [0.26658738, 0.32319444, 0.13510163, 0.191401  , 0.08371557],
       [0.268578  , 0.31550312, 0.14008923, 0.18895923, 0.08687049]],
      dtype=float32)

# V2: same model, with the per-image Python loop replaced by a `TimeDistributed` wrapper

In [None]:
class RecurrentConvNextTiny(tf.keras.Model):
  """ConvNeXt-Tiny + GRU classifier using `TimeDistributed` for the backbone.

  An ImageNet-pretrained ConvNeXt-Tiny backbone (no classification head)
  followed by `Flatten` is wrapped in `TimeDistributed`, so it is applied to
  every slice along axis 1 of the input. A GRU consumes the resulting sequence
  and two ReLU dense layers plus a 5-way softmax layer produce the prediction.
  """

  def __init__(self):
    super().__init__()
    # include_preprocessing=True keeps the backbone's own input preprocessing
    # inside the model, so callers do not apply a separate preprocessing step.
    backbone = tf.keras.applications.convnext.ConvNeXtTiny(
        model_name='convnext_tiny',
        include_top=False,
        include_preprocessing=True,
        weights='imagenet',
        input_tensor=None,
        input_shape=(224, 224, 3)
    )
    flatten = tf.keras.layers.Flatten()

    # Backbone + flatten applied independently to each slice along axis 1.
    self.bottom = tf.keras.layers.TimeDistributed(tf.keras.Sequential([
        backbone,
        flatten
    ]))

    self.gru = layers.GRU(128, activation='tanh')
    self.dense1 = layers.Dense(64, activation="relu")
    self.dense2 = layers.Dense(64, activation="relu")

    self.output_layer = layers.Dense(5, activation="softmax")

  def call(self, inputs, training=False):
    """Run the forward pass.

    Implemented as `call` (not `__call__`) so that Keras' own
    `Model.__call__` wrapper still runs around it.

    Args:
      inputs: Tensor whose slices along axis 1 are each fed to the backbone
        declared with `input_shape=(224, 224, 3)`.
      training: Forwarded to the time-distributed backbone and the GRU so that
        any training-dependent behaviour follows the caller's mode.

    Returns:
      The output of the final `Dense(5, activation="softmax")` layer.
    """
    x = self.bottom(inputs, training=training)
    x = self.gru(x, training=training)
    x = self.dense1(x)
    x = self.dense2(x)
    return self.output_layer(x)

In [None]:
# Re-instantiate using the V2 class defined just above, which shadows the
# earlier class of the same name.
model = RecurrentConvNextTiny()

In [None]:
# Run the V2 model on the `sample_input` batch created in an earlier cell.
model.predict(sample_input)



array([[0.19394262, 0.12502167, 0.1794809 , 0.2600652 , 0.24148965],
       [0.1759656 , 0.12841065, 0.17844352, 0.2809659 , 0.23621438],
       [0.180843  , 0.13342293, 0.16407986, 0.28569695, 0.23595725],
       [0.18915747, 0.12935854, 0.17224367, 0.26526812, 0.24397215],
       [0.16939059, 0.1370537 , 0.1690742 , 0.28534546, 0.23913603],
       [0.17971508, 0.13906918, 0.15674274, 0.28176472, 0.24270831],
       [0.18937613, 0.12763949, 0.17623691, 0.27167746, 0.23507008]],
      dtype=float32)