In [None]:
import tensorflow as tf

In [None]:
def convnext(inputs, filters: int, layer_scale_init_value: float = 1e-6):
    """Apply one ConvNeXt block to ``inputs``.

    Layout: 7x7 depthwise conv -> LayerNorm -> Dense(4*filters) -> GELU
    -> Dense(filters) -> learnable per-channel layer scale -> residual add.

    Args:
        inputs: Keras tensor in channels-last layout (the Dense layers act on
            the last axis). Its last dimension must equal ``filters`` because
            the depthwise conv uses ``groups=filters`` and the output is added
            back onto ``inputs``.
        filters: Channel count of ``inputs`` and of the block output.
        layer_scale_init_value: Initial value of the trainable per-channel
            scale applied to the residual branch. A value <= 0 (or None)
            disables layer scale.

    Returns:
        Keras tensor with the same shape as ``inputs``.

    Raises:
        ValueError: If the static channel dimension of ``inputs`` is known and
            differs from ``filters``.
    """
    channels = inputs.shape[-1]
    if channels is not None and channels != filters:
        raise ValueError(
            f"convnext block expects {filters} input channels, got {channels}")

    # groups == filters turns this into a depthwise (per-channel) 7x7 conv.
    x = tf.keras.layers.Conv2D(filters=filters, kernel_size=7, strides=1,
                               groups=filters, padding="same",
                               kernel_initializer='he_normal')(inputs)
    x = tf.keras.layers.LayerNormalization(epsilon=1e-6)(x)
    # Inverted-bottleneck MLP applied pointwise over the channel axis.
    x = tf.keras.layers.Dense(filters * 4)(x)
    x = tf.keras.layers.Activation("gelu")(x)
    x = tf.keras.layers.Dense(filters)(x)

    if layer_scale_init_value is not None and layer_scale_init_value > 0:
        # Learnable layer scale (gamma). A bias-free 1x1 depthwise conv
        # multiplies each channel by its own trainable scalar:
        # y[..., c] = gamma[c] * x[..., c], with gamma initialised to
        # layer_scale_init_value. The previous fixed Lambda(g * 1e-6) could
        # never grow, permanently muting the residual branch.
        x = tf.keras.layers.DepthwiseConv2D(
            kernel_size=1,
            use_bias=False,
            depthwise_initializer=tf.keras.initializers.Constant(
                layer_scale_init_value),
        )(x)

    return tf.keras.layers.Add()([inputs, x])

In [None]:
def model(num_classes: int, depths=(3, 3, 9, 3), dims=(96, 192, 384, 768)):
    """Build a ConvNeXt image classifier (defaults give ConvNeXt-T).

    Architecture:
    - Stem: 4x4 stride-4 conv, then LayerNorm.
    - For each stage i, stages after the first start with a downsampling
      layer: LayerNorm, then a 2x2 stride-2 conv to ``dims[i]`` channels.
    - ``depths[i]`` ConvNeXt blocks of width ``dims[i]`` follow.
    - Head: global average pooling, LayerNorm, then a linear Dense layer.

    Args:
        num_classes: Number of output units of the final Dense layer.
        depths: Number of ConvNeXt blocks per stage.
        dims: Channel width per stage; must have the same length as ``depths``.

    Returns:
        An uncompiled ``tf.keras.Model`` taking ``(None, None, 3)`` images and
        returning raw logits (no softmax is applied).

    Raises:
        ValueError: If ``depths`` and ``dims`` differ in length or are empty.
    """
    if len(depths) != len(dims):
        raise ValueError(
            f"depths and dims must have equal length, got {len(depths)} and {len(dims)}")
    if not dims:
        raise ValueError("at least one stage is required")

    inputs = tf.keras.layers.Input(shape=(None, None, 3))

    # Stem: non-overlapping 4x4 patches, then LayerNorm (as in the paper).
    x = tf.keras.layers.Conv2D(filters=dims[0], kernel_size=4, strides=4,
                               padding='same', kernel_initializer='he_normal')(inputs)
    x = tf.keras.layers.LayerNormalization(epsilon=1e-6)(x)

    for stage, (depth, width) in enumerate(zip(depths, dims)):
        if stage > 0:
            # Downsampling between stages: the LayerNorm goes *before* the
            # stride-2 conv, where the spatial resolution changes.
            x = tf.keras.layers.LayerNormalization(epsilon=1e-6)(x)
            x = tf.keras.layers.Conv2D(filters=width, kernel_size=2, strides=2,
                                       padding='same', kernel_initializer='he_normal')(x)
        for _ in range(depth):
            x = convnext(inputs=x, filters=width)

    # Head: pool -> LayerNorm -> linear classifier producing logits.
    x = tf.keras.layers.GlobalAvgPool2D()(x)
    x = tf.keras.layers.LayerNormalization(epsilon=1e-6)(x)
    outputs = tf.keras.layers.Dense(num_classes)(x)

    return tf.keras.Model(inputs=inputs, outputs=outputs)