In [1]:
import tensorflow as tf

In [2]:
from functools import partial

In [3]:

# Conv2D factory shared by the whole notebook: 3x3 kernel, "same" padding,
# unit stride/dilation, He-normal init, no bias and no built-in activation.
# Every use in this notebook is followed by a BatchNormalization layer.
# Any of these keywords can be overridden per call (e.g. kernel_size, strides).
MyConv2D = partial(
    tf.keras.layers.Conv2D,
    kernel_size=(3, 3),
    strides=(1, 1),
    dilation_rate=(1, 1),
    padding='same',
    activation=None,
    use_bias=False,
    kernel_initializer='he_normal',
)

In [11]:
class ResU2(tf.keras.layers.Layer):
    """Two-convolution residual unit with a projected shortcut.

    Main path:     conv3x3(strides) -> BN -> activation -> conv3x3 -> BN
    Shortcut path: conv1x1(strides) -> BN
    Output:        activation(main + shortcut)

    Args:
        filters: Number of output channels of every convolution in the unit.
        strides: Stride of the first 3x3 conv and of the 1x1 shortcut conv.
        activation: Anything accepted by ``tf.keras.activations.get``
            (e.g. ``'relu'`` or a callable).
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, filters, strides, activation='relu', **kwargs):
        super().__init__(**kwargs)
        # Constructor arguments are kept so get_config() can rebuild the layer;
        # without it, models containing this layer cannot be saved/cloned.
        self.filters = filters
        self.strides = strides
        self.activation = tf.keras.activations.get(activation)
        self.main_layers = [
            MyConv2D(filters, strides=strides),
            tf.keras.layers.BatchNormalization(),
            self.activation,
            MyConv2D(filters),
            tf.keras.layers.BatchNormalization(),
        ]
        # The shortcut is always a 1x1 projection, even when strides == 1 and
        # the channel count is unchanged.
        # NOTE(review): canonical ResNet-34 uses an identity shortcut in that
        # case; switching would change the model's parameter count.
        self.skip_layers = [
            MyConv2D(filters, kernel_size=(1, 1), strides=strides),
            tf.keras.layers.BatchNormalization(),
        ]

    def call(self, inputs):
        """Sum the main and shortcut paths, then apply the activation."""
        Z = inputs
        for layer in self.main_layers:
            Z = layer(Z)
        skip_Z = inputs
        for layer in self.skip_layers:
            skip_Z = layer(skip_Z)
        Z = Z + skip_Z
        Z = self.activation(Z)
        return Z

    def get_config(self):
        """Return the constructor arguments so Keras can serialize this layer."""
        config = super().get_config()
        config.update({
            'filters': self.filters,
            'strides': self.strides,
            'activation': tf.keras.activations.serialize(self.activation),
        })
        return config

In [12]:
# ResNet-34-style classifier: stem -> stacked ResU2 units -> global pooling -> softmax.
INPUT_SHAPE = [224, 224, 3]
NUM_CLASSES = 10
STEM_FILTERS = 64
# (filters, number of residual units) per stage: 3 + 4 + 6 + 3 units.
STAGES = [(64, 3), (128, 4), (256, 6), (512, 3)]

model = tf.keras.Sequential()
# Declare the input explicitly; passing `input_shape=` to a layer is deprecated
# in newer Keras versions.
model.add(tf.keras.Input(shape=INPUT_SHAPE))

# Stem
model.add(MyConv2D(filters=STEM_FILTERS, kernel_size=(7, 7), strides=2))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Activation('relu'))
model.add(tf.keras.layers.MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='same'))

# Residual units: downsample (stride 2) whenever the channel width changes.
prev_filters = STEM_FILTERS
for filters, n_units in STAGES:
    for _ in range(n_units):
        strides = 1 if filters == prev_filters else 2
        model.add(ResU2(filters, strides))
        prev_filters = filters

# Head: GlobalAvgPool2D already yields a 2D (batch, channels) tensor,
# so no Flatten layer is needed before the classifier.
model.add(tf.keras.layers.GlobalAvgPool2D())
model.add(tf.keras.layers.Dense(NUM_CLASSES, activation='softmax'))

In [13]:
# Print the layer-by-layer summary of the network assembled above
# (print_fn=None is the Keras default: write to stdout).
model.summary(print_fn=None)