In [2]:
import tensorflow as tf
import keras

In [19]:
class ResidualUnit(keras.layers.Layer):
    """ResNet-style residual unit: two 3x3 Conv-BN stages plus a shortcut.

    Main path:  Conv(3x3, `strides`) -> BN -> activation -> Conv(3x3) -> BN.
    Shortcut:   identity, or a 1x1 Conv -> BN projection whenever the main
                path changes the spatial size (`strides > 1`) or the channel
                count (input channels != `filters`).
    Output:     activation(main + shortcut).

    Args:
        filters: Number of output channels of the unit.
        strides: Stride of the first 3x3 conv and of the projection shortcut.
        activation: Activation identifier passed to `keras.layers.Activation`.
        **kwargs: Forwarded to `keras.layers.Layer` (e.g. `name`, `dtype`).
    """

    def __init__(self, filters, strides=1, activation='relu', **kwargs):
        super().__init__(**kwargs)
        # Keep the constructor arguments so get_config() can rebuild the layer.
        self.filters = filters
        self.strides = strides
        self.activation = activation
        # A single Activation layer, reused inside the main path and after the sum.
        self.activation_layer = keras.layers.Activation(activation)
        self.main_layers = [
            keras.layers.Conv2D(filters, kernel_size=3, strides=strides,
                                padding='same', use_bias=False),
            keras.layers.BatchNormalization(),
            self.activation_layer,
            keras.layers.Conv2D(filters, kernel_size=3, strides=1,
                                padding='same', use_bias=False),
            keras.layers.BatchNormalization(),
        ]
        self.skip_layers = []
        if strides > 1:
            self.skip_layers = self._projection_layers()

    def _projection_layers(self):
        """Return a 1x1 Conv -> BN shortcut matching the main path's output shape."""
        return [
            keras.layers.Conv2D(self.filters, kernel_size=1, strides=self.strides,
                                padding='same', use_bias=False),
            keras.layers.BatchNormalization(),
        ]

    def build(self, input_shape):
        # With strides == 1 the identity shortcut only lines up when the input
        # already has `filters` channels; otherwise `z + skip_z` in call() would
        # fail with a shape mismatch, so project the shortcut instead.
        in_channels = input_shape[-1]
        if (not self.skip_layers and in_channels is not None
                and in_channels != self.filters):
            self.skip_layers = self._projection_layers()
        super().build(input_shape)

    def call(self, inputs):
        z = inputs
        for layer in self.main_layers:
            z = layer(z)
        skip_z = inputs
        for layer in self.skip_layers:
            skip_z = layer(skip_z)
        return self.activation_layer(z + skip_z)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created on load."""
        config = super().get_config()
        # `activation` is stored exactly as given; string identifiers serialize
        # directly, while a custom callable may need `custom_objects` on load.
        config.update({
            'filters': self.filters,
            'strides': self.strides,
            'activation': self.activation,
        })
        return config


In [20]:
# ResNet-34: a 7x7/2 conv stem and 3x3/2 max pool, then 3, 4, 6 and 3 residual
# units (2 convs each) with 64, 128, 256 and 512 filters, global average pooling,
# and a softmax classifier: 1 + 16 * 2 + 1 = 34 weighted layers.
INPUT_SHAPE = (224, 224, 3)
N_CLASSES = 10
STAGES = [(64, 3), (128, 4), (256, 6), (512, 3)]  # (filters, number of residual units)

model = keras.models.Sequential()
# Declare the input with keras.Input rather than the deprecated `input_shape=` layer kwarg.
model.add(keras.Input(shape=INPUT_SHAPE))
# 'same' padding halves 224 -> 112 exactly; no bias because BatchNormalization follows.
model.add(keras.layers.Conv2D(filters=64, kernel_size=7, strides=2,
                              padding='same', use_bias=False))
model.add(keras.layers.BatchNormalization())
model.add(keras.layers.Activation("relu"))
model.add(keras.layers.MaxPool2D(pool_size=3, strides=2, padding='same'))

prev_filters = 64
for filters, n_units in STAGES:
    for _ in range(n_units):
        # Downsample (stride 2) only on the first unit of a stage that widens the channels.
        strides = 1 if filters == prev_filters else 2
        model.add(ResidualUnit(filters=filters, strides=strides))
        prev_filters = filters

# GlobalAveragePooling2D already yields a (batch, channels) tensor, so no Flatten is needed.
model.add(keras.layers.GlobalAveragePooling2D())
model.add(keras.layers.Dense(N_CLASSES, activation='softmax'))


In [21]:
# Display the architecture of the model built in the previous cell.
model.summary()