In [1]:
import tensorflow as tf

from tensorflow import keras

In [50]:
class ResidualUnit(keras.layers.Layer):
    """Three-convolution residual unit with a convolutional skip branch.

    The main branch stacks three ``Conv2D -> BatchNormalization -> activation``
    groups. The skip branch applies a single such group to the same input.
    The two branch outputs are summed and passed through the activation.

    Args:
        strides: Stride of the first main-branch convolution and of the
            skip-branch convolution. The remaining convolutions use stride 1.
        size_1: Kernel size of the first main-branch convolution.
        size_2: Kernel size of the second main-branch convolution.
        size_3: Kernel size of the third main-branch convolution.
        size_4: Kernel size of the skip-branch convolution.
        filters_1: Number of filters of the first main-branch convolution.
        filters_2: Number of filters of the second main-branch convolution.
        filters_3: Number of filters of the third main-branch convolution.
        filters_4: Number of filters of the skip-branch convolution. Expected
            to equal ``filters_3`` so that the two branches can be added.
        activation: Activation identifier accepted by
            ``keras.activations.get`` (defaults to ``'relu'``).
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self,strides,size_1,size_2,size_3,size_4,filters_1,filters_2,filters_3,filters_4,activation='relu',**kwargs):
        super().__init__(**kwargs)
        # Keep the constructor arguments so get_config() can report them.
        self.strides = strides
        self.size_1 = size_1
        self.size_2 = size_2
        self.size_3 = size_3
        self.size_4 = size_4
        self.filters_1 = filters_1
        self.filters_2 = filters_2
        self.filters_3 = filters_3
        self.filters_4 = filters_4
        self.activation = keras.activations.get(activation)

        # NOTE(review): both branches end with the activation *before* the sum
        # in call(). With a non-negative activation such as the default ReLU,
        # the activation applied after the sum is then an identity. Reference
        # ResNet designs activate only after the sum - confirm this is intended.
        self.main_layers = [
            keras.layers.Conv2D(filters_1, size_1, strides=strides, padding='same', use_bias=False),
            keras.layers.BatchNormalization(),
            self.activation,
            # Fixed: this convolution previously used filters_1, which left the
            # filters_2 argument unused.
            keras.layers.Conv2D(filters_2, size_2, strides=1, padding='same', use_bias=False),
            keras.layers.BatchNormalization(),
            self.activation,
            keras.layers.Conv2D(filters_3, size_3, strides=1, padding='same', use_bias=False),
            keras.layers.BatchNormalization(),
            self.activation,
        ]
        self.skip_layers = [
            keras.layers.Conv2D(filters_4, size_4, strides=strides, padding='same', use_bias=False),
            keras.layers.BatchNormalization(),
            self.activation,
        ]

    def call(self,inputs):
        """Run both branches on ``inputs`` and return the activated sum."""
        Z = inputs
        for layer in self.main_layers:
            Z = layer(Z)
        skip_Z = inputs
        for layer in self.skip_layers:
            skip_Z = layer(skip_Z)
        return self.activation(skip_Z + Z)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({
            'strides': self.strides,
            'size_1': self.size_1,
            'size_2': self.size_2,
            'size_3': self.size_3,
            'size_4': self.size_4,
            'filters_1': self.filters_1,
            'filters_2': self.filters_2,
            'filters_3': self.filters_3,
            'filters_4': self.filters_4,
            'activation': keras.activations.serialize(self.activation),
        })
        return config
            
        

In [60]:
# Assemble the classifier: a convolutional stem, four stages of ResidualUnit
# layers (3/4/6/3 units, the stage layout used by ResNet-50), and a
# global-pooling + softmax head with 1000 outputs.
model = keras.models.Sequential()

# Stem: 7x7 stride-2 convolution, batch norm, ReLU, then 3x3 stride-2 max pool.
model.add(keras.layers.Conv2D(64, 7, strides=2, input_shape=[224, 224, 3], padding='same', use_bias=False))
model.add(keras.layers.BatchNormalization())
model.add(keras.layers.Activation('relu'))
model.add(keras.layers.MaxPool2D(pool_size=3, strides=2, padding='same'))

# One row per stage:
# (inner filters, output filters, number of units, strides of the first unit).
# Only the first unit of a stage may downsample; the rest use strides=1.
stage_specs = [
    (64, 256, 3, 1),
    (128, 512, 4, 2),
    (256, 1024, 6, 2),
    (512, 2048, 3, 2),
]

for inner_filters, output_filters, num_units, first_strides in stage_specs:
    for unit_index in range(num_units):
        model.add(ResidualUnit(
            strides=first_strides if unit_index == 0 else 1,
            size_1=1, size_2=3, size_3=1, size_4=1,
            filters_1=inner_filters, filters_2=inner_filters,
            filters_3=output_filters, filters_4=output_filters,
        ))

# Classification head.
model.add(keras.layers.GlobalAvgPool2D())
# NOTE(review): GlobalAvgPool2D already returns a (batch, channels) tensor, so
# this Flatten does not change the data; kept so the layer list is unchanged.
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(1000, activation='softmax'))


In [61]:
# Build the model, then print its per-layer output shapes and parameter counts.
# NOTE(review): the first layer added to `model` is given `input_shape`, so this
# explicit build() is probably redundant - confirm before removing.
model.build()
model.summary()

Model: "sequential_23"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_362 (Conv2D)          (None, 112, 112, 64)      9408      
_________________________________________________________________
batch_normalization_362 (Bat (None, 112, 112, 64)      256       
_________________________________________________________________
activation_23 (Activation)   (None, 112, 112, 64)      0         
_________________________________________________________________
max_pooling2d_23 (MaxPooling (None, 56, 56, 64)        0         
_________________________________________________________________
residual_unit_86 (ResidualUn (None, 56, 56, 256)       76288     
_________________________________________________________________
residual_unit_87 (ResidualUn (None, 56, 56, 256)       137728    
_________________________________________________________________
residual_unit_88 (ResidualUn (None, 56, 56, 256)     