In [1]:
import tensorflow as tf
import tensorflow.keras as keras

In [2]:
class ResidualUnit(keras.layers.Layer):
    """Residual block: a two-convolution main path added to a 1x1-convolution shortcut.

    Main path: 3x3 conv (with ``strides``) -> batch norm -> activation ->
    3x3 conv (stride 1) -> batch norm.
    Shortcut path: a 1x1 conv with the same ``strides``; when ``strides > 1``
    it is followed by batch norm.
    The two paths are summed and passed through the activation.

    Args:
        filters: Number of output channels for every convolution in the unit.
        activation: Activation identifier or callable, resolved with
            ``keras.activations.get``.
        strides: Stride of the first main-path convolution and of the
            shortcut convolution.
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, filters, activation='relu', strides=1, **kwargs):
        super().__init__(**kwargs)
        # Kept so get_config() can report the constructor arguments.
        self.filters = filters
        self.strides = strides
        self.activation = keras.activations.get(activation)
        self.main_layers = [
            keras.layers.Conv2D(filters, 3, strides=strides, padding='SAME', use_bias=False),
            keras.layers.BatchNormalization(),
            self.activation,
            keras.layers.Conv2D(filters, 3, strides=1, padding='SAME', use_bias=False),
            keras.layers.BatchNormalization(),
        ]
        # Build exactly one shortcut variant instead of creating a stride-1
        # convolution and discarding it when strides > 1.
        if strides > 1:
            self.skip_layers = [
                keras.layers.Conv2D(filters, 1, strides=strides, padding='SAME', use_bias=False),
                keras.layers.BatchNormalization(),
            ]
        else:
            self.skip_layers = [
                keras.layers.Conv2D(filters, 1, strides=1, padding='SAME', use_bias=False),
            ]

    def call(self, inputs):
        """Run both paths on ``inputs`` and return ``activation(main + skip)``."""
        main_output = inputs
        for layer in self.main_layers:
            main_output = layer(main_output)
        skip_output = inputs
        for layer in self.skip_layers:
            skip_output = layer(skip_output)
        return self.activation(main_output + skip_output)

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized and re-created."""
        config = super().get_config()
        config.update({
            'filters': self.filters,
            'activation': keras.activations.serialize(self.activation),
            'strides': self.strides,
        })
        return config

In [4]:
# Stem: one strided 7x7 convolution, then batch norm and ReLU.
stem_layers = [
    keras.layers.Conv2D(64, 7, strides=2, input_shape=[224, 224, 3], padding='SAME', use_bias=False),
    keras.layers.BatchNormalization(),
    keras.layers.Activation('relu'),
]
model = keras.models.Sequential(stem_layers)

# Four stages of four residual units each. The first unit of a stage whose
# width differs from the previous one uses stride 2; every other unit uses stride 1.
prev_filters = 64
for filters in [64] * 4 + [128] * 4 + [256] * 4 + [512] * 4:
    if filters == prev_filters:
        strides = 1
    else:
        strides = 2
    model.add(ResidualUnit(filters, strides=strides))
    prev_filters = filters

# Classification head: global average pooling into a 10-way softmax.
model.add(keras.layers.GlobalAvgPool2D())
model.add(keras.layers.Dense(10, activation='softmax'))

In [5]:
# Print a per-layer table of layer names, output shapes and parameter counts.
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 112, 112, 64)      9408      
_________________________________________________________________
batch_normalization (BatchNo (None, 112, 112, 64)      256       
_________________________________________________________________
activation (Activation)      (None, 112, 112, 64)      0         
_________________________________________________________________
residual_unit (ResidualUnit) (None, 112, 112, 64)      78336     
_________________________________________________________________
residual_unit_1 (ResidualUni (None, 112, 112, 64)      78336     
_________________________________________________________________
residual_unit_2 (ResidualUni (None, 112, 112, 64)      78336     
_________________________________________________________________
residual_unit_3 (ResidualUni (None, 112, 112, 64)      7