In [1]:
import tensorflow as tf

from functools import partial


# ResNet-34 CNN

Implementing the ResNet-34 CNN using Keras.

**This is done only for teaching purposes to understand CNN better. If you need to build a CNN on top of ResNet, load it
from keras and use transfer learning instead**

In [2]:
# Conv2D factory pre-loaded with the defaults shared by every convolution in this notebook.
# Any of these keyword arguments can still be overridden at the call site.
DefaultConv2D = partial(
  tf.keras.layers.Conv2D,
  kernel_size=3,
  strides=1,
  padding="same",
  kernel_initializer="he_normal",
  # Every conv built here is followed by BatchNormalization, which has its own offset term
  use_bias=False,
)

class ResidualUnit(tf.keras.layers.Layer):
  """Residual unit made of two convolutions plus a shortcut connection.

  Main path: conv (with `strides`) -> batch norm -> activation -> conv -> batch norm.
  Skip path: the identity when `strides == 1`; otherwise a 1x1 conv using the same
  stride, followed by batch norm. The two paths are added and the sum is activated.

  Args:
    filters: Number of filters used by every convolution in the unit.
    strides: Stride of the first main-path convolution, and of the 1x1 skip
      convolution when it is greater than 1.
    activation: Activation name or callable, resolved with `tf.keras.activations.get`.
    **kwargs: Forwarded to `tf.keras.layers.Layer`.

  Note:
    With `strides == 1` the skip path is the identity, so the input must already
    have `filters` channels for the final addition to be shape-compatible.
  """

  def __init__(self, filters, strides=1, activation="relu", **kwargs):
    super().__init__(**kwargs)
    # Keep the constructor arguments so get_config() can report them.
    self.filters = filters
    self.strides = strides
    self.activation = tf.keras.activations.get(activation)
    self.main_layers = [
      DefaultConv2D(filters, strides=strides),
      tf.keras.layers.BatchNormalization(),
      self.activation,
      DefaultConv2D(filters),
      tf.keras.layers.BatchNormalization()
    ]

    self.skip_layers = []

    # When the stride is greater than 1, the main path shrinks the spatial dimensions of its output
    # (e.g. strides=2 cuts height and width in half), so the output can no longer be added to the raw input.
    # To fix the mismatch, the input goes through a single conv layer with the same stride and a kernel size of 1.
    if strides > 1:
      self.skip_layers = [
        # Downsamples the input with the same stride as the main path and projects it to `filters` channels
        DefaultConv2D(filters, kernel_size=1, strides=strides),
        tf.keras.layers.BatchNormalization()
      ]

  def call(self, inputs):
    """Runs both paths on `inputs` and returns the activated sum."""
    # Forward prop thru the block's main layers
    Z = inputs
    for layer in self.main_layers:
      Z = layer(Z)

    # Forward prop thru the block's skip layers (it'll have layers only if the stride is > 1).
    # This will ensure compatibility of the inputs and outputs spatial dimensions.
    skip_Z = inputs
    for layer in self.skip_layers:
      skip_Z = layer(skip_Z)

    # Add the shortcut to the main-path output, then activate the sum
    return self.activation(Z + skip_Z)

  def get_config(self):
    """Returns the constructor arguments so the layer can be serialized and rebuilt."""
    config = super().get_config()
    config.update({
      "filters": self.filters,
      "strides": self.strides,
      "activation": tf.keras.activations.serialize(self.activation),
    })
    return config

In [5]:
# Stem: 7x7 stride-2 convolution -> batch norm -> ReLU -> 3x3 stride-2 max pooling
model = tf.keras.Sequential([
  DefaultConv2D(64, kernel_size=7, strides=2, input_shape=[224, 224, 3]),
  tf.keras.layers.BatchNormalization(),
  tf.keras.layers.Activation("relu"),
  tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding="same")
])

# Stack of residual units, grouped by filter count
prev_filters = 64
# [64, 64, 64, 128, 128, 128, 128, 256, 256, 256, 256, 256, 256, 512, 512, 512]
filters_list = [64] * 3 + [128] * 4 + [256] * 6 + [512] * 3
for filters in filters_list:
  # When moving to a new filters size, we use strides=2, else strides=1.
  # strides=2 is also what makes ResidualUnit add its 1x1 projection on the skip path,
  # which is required whenever the number of channels changes.
  strides = 1 if filters == prev_filters else 2
  model.add(ResidualUnit(filters, strides=strides))
  prev_filters = filters

# Classification head
model.add(tf.keras.layers.GlobalAvgPool2D())
# GlobalAvgPool2D already outputs one value per channel, so this Flatten leaves the shape unchanged;
# it is kept to preserve the original layer stack.
model.add(tf.keras.layers.Flatten())
# softmax (not relu) so the 10 outputs form a probability distribution over the classes;
# relu would give unbounded, non-normalized scores and zero gradient for negative pre-activations.
model.add(tf.keras.layers.Dense(10, activation="softmax"))
