<a href="https://colab.research.google.com/github/CelikAbdullah/deep-learning-notebooks/blob/main/Computer%20Vision/models/SE-ResNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [34]:
from tensorflow import keras

# SE-ResNet

## The stem component

In [35]:
def stem(inputs):
  """ Build the stem component of SE-ResNet
      inputs : the input tensor (image batch)
      returns the feature maps produced by the stem
  """
  layers = keras.layers

  # the stem is a plain chain, so its layers are listed in the order they are applied
  stem_layers = (
      # zero padding (black - no signal) ahead of the first convolution, e.g. 224x224 -> 230x230
      layers.ZeroPadding2D(padding=(3, 3)),
      # coarse 7x7 convolution with stride 2, followed by batch normalization and ReLU
      layers.Conv2D(filters=64, kernel_size=7, strides=2, padding='valid', use_bias=False, kernel_initializer='he_normal'),
      layers.BatchNormalization(),
      layers.ReLU(),
      # pad by one pixel, then 3x3 max pooling with stride 2 (halves each spatial dimension)
      layers.ZeroPadding2D(padding=(1, 1)),
      layers.MaxPooling2D((3, 3), strides=(2, 2)),
  )

  features = inputs
  for layer in stem_layers:
    features = layer(features)

  return features

## The learner component

### The Squeeze & Excite Block

In [36]:
def squeeze_excite_block(x, ratio=16):
  """ Create a Squeeze and Excite block
      x     : input to the block (feature maps with the channels on the last axis)
      ratio : amount of filter reduction during squeeze
      returns the input scaled channel-wise by the squeeze/excitation output
  """
  # number of filters (channels) of the incoming feature maps
  channels = x.shape[-1]

  # squeeze: global average pooling reduces every feature map to a single value ...
  scale = keras.layers.GlobalAveragePooling2D()(x)
  # ... which is reshaped into 1x1 feature maps (1x1xC)
  scale = keras.layers.Reshape((1, 1, channels))(scale)

  # dimensionality reduction: shrink to C/r units (1x1xC/r)
  scale = keras.layers.Dense(units=channels // ratio, activation='relu', kernel_initializer='he_normal', use_bias=False)(scale)

  # excitation: restore the number of filters (1x1xC); sigmoid keeps each value in (0, 1)
  scale = keras.layers.Dense(units=channels, activation='sigmoid', kernel_initializer='he_normal', use_bias=False)(scale)

  # multiply the block input (WxHxC) with the per-channel scale
  return keras.layers.Multiply()([x, scale])

### The Identity Block

In [37]:
def identity_block(x, n_filters, ratio=16):
  """ Create a Bottleneck Residual Block with Identity Link
      x        : input into the block; also used unchanged as the identity link
      n_filters: number of filters of the two inner convolutions (the block outputs 4X as many)
      ratio    : amount of filter reduction during squeeze
  """
  # one entry per convolution: (filters, kernel size, padding, ReLU after batch norm?)
  conv_specs = (
      (n_filters, 1, 'valid', True),       # 1x1 conv for dimensionality reduction
      (n_filters, 3, 'same', True),        # 3x3 conv serving as the bottleneck layer
      (n_filters * 4, 1, 'valid', False),  # 1x1 conv restoring dimensionality (4X filters)
  )

  residual = x
  for filters, kernel_size, padding, with_relu in conv_specs:
    residual = keras.layers.Conv2D(filters=filters, kernel_size=kernel_size, strides=1, padding=padding, use_bias=False, kernel_initializer='he_normal')(residual)
    residual = keras.layers.BatchNormalization()(residual)
    if with_relu:
      residual = keras.layers.ReLU()(residual)

  # re-weight the channels with the squeeze and excitation block
  residual = squeeze_excite_block(residual, ratio)

  # add the identity link (the untouched input) to the residual output, then activate
  merged = keras.layers.Add()([x, residual])
  return keras.layers.ReLU()(merged)

### The Projection Block

In [38]:
def projection_block(x, n_filters, strides=(2,2), ratio=16):
  """ Create Bottleneck Residual Block with Projection Shortcut
      The block outputs 4X as many filters as n_filters.
      x        : input into the block
      n_filters: number of filters of the two inner convolutions
      strides  : strides of the entry convolution and of the shortcut (i.e., (2, 2) vs (1, 1))
      ratio    : amount of filter reduction during squeeze
  """
  layers = keras.layers
  out_filters = 4 * n_filters
  # settings shared by every convolution in this block
  conv_args = dict(use_bias=False, kernel_initializer='he_normal')

  # projection shortcut: a 1x1 conv that brings the input to the output's
  # filter count (and spatial size when strided) so both can be added
  shortcut = layers.Conv2D(filters=out_filters, kernel_size=1, strides=strides, **conv_args)(x)
  shortcut = layers.BatchNormalization()(shortcut)

  # 1x1 conv for dimensionality reduction; it carries the stride,
  # so feature pooling happens here when strides=(2, 2)
  residual = layers.Conv2D(filters=n_filters, kernel_size=1, strides=strides, **conv_args)(x)
  residual = layers.BatchNormalization()(residual)
  residual = layers.ReLU()(residual)

  # 3x3 conv acting as the bottleneck layer
  residual = layers.Conv2D(filters=n_filters, kernel_size=3, strides=1, padding='same', **conv_args)(residual)
  residual = layers.BatchNormalization()(residual)
  residual = layers.ReLU()(residual)

  # 1x1 conv for dimensionality restoration (4X filters); no ReLU before the addition
  residual = layers.Conv2D(filters=out_filters, kernel_size=1, strides=1, **conv_args)(residual)
  residual = layers.BatchNormalization()(residual)

  # re-weight the channels with the squeeze and excitation block
  residual = squeeze_excite_block(residual, ratio)

  # add the projection shortcut to the residual output, then activate
  merged = layers.Add()([residual, shortcut])
  return layers.ReLU()(merged)

### The SE-ResNet Group

In [39]:
def group(x, n_filters, n_blocks, ratio, strides=(2, 2)):
  """ Build the Squeeze-Excite Group
      x        : input to the group
      n_filters: number of filters
      n_blocks : number of blocks (one projection block plus n_blocks-1 identity blocks)
      ratio    : amount of filter reduction during squeeze
      strides  : whether projection block is strided
  """
  # entry block: a projection shortcut is needed because the number of
  # filters changes between groups
  features = projection_block(x, n_filters, strides, ratio)

  # every further block keeps the shape, so an identity link is enough
  for _ in range(1, n_blocks):
    features = identity_block(features, n_filters, ratio)

  return features

### Putting it all together

In [40]:
def learner(x, groups, ratio):
  """ Build the learner component of the SE-ResNet.
      x     : input to the learner
      groups: sequence of groups, each a (number of filters, number of blocks) pair;
              the sequence is only read, never modified
      ratio : amount of filter reduction in squeeze
      raises IndexError if groups is empty
  """
  # work on a private copy: removing the first entry from the caller's list
  # would shorten it and corrupt any later build that reuses the same configuration
  groups = list(groups)

  # the first residual block group is not strided
  n_filters, n_blocks = groups[0]
  x = group(x, n_filters, n_blocks, ratio, strides=(1, 1))

  # the remaining residual block groups are strided
  for n_filters, n_blocks in groups[1:]:
    x = group(x, n_filters, n_blocks, ratio)

  return x

## The task component

In [41]:
def classifier(x, n_classes):
  """ Create the Classifier Group
      x         : input to the classifier (output of the last residual group)
      n_classes : number of output classes
      returns the softmax class probabilities
  """
  # collapse each feature map to a single value with global average pooling
  pooled = keras.layers.GlobalAveragePooling2D()(x)

  # final fully connected layer with one softmax unit per class
  softmax_head = keras.layers.Dense(units=n_classes, activation='softmax', kernel_initializer='he_normal')
  return softmax_head(pooled)

## The SE-ResNet model

In [42]:
def build_se_resnet(groups, shape=(224, 224, 3), ratio=16, classes=1000, depth=50):
  """ Construct an SE-ResNet model
      groups : dict mapping a network depth to its list of groups,
               each group being a (number of filters, number of blocks) pair
      shape  : shape of the input images (height, width, channels)
      ratio  : amount of filter reduction in the squeeze operation
      classes: number of output classes
      depth  : key of the entry in groups to build (defaults to 50)
      raises KeyError if groups has no entry for depth
  """
  # The input tensor - honours the requested shape
  inputs = keras.Input(shape=shape)
  # The stem component
  x = stem(inputs)
  # The learner component - hand over a copy so the caller's group
  # configuration cannot be modified downstream
  x = learner(x, list(groups[depth]), ratio)
  # The task component
  outputs = classifier(x, classes)

  # return the SE-ResNet model
  return keras.Model(inputs=inputs, outputs=outputs)


In [43]:
# Meta-parameter: groups per network depth, each group being
# a (number of filters, number of blocks) pair.
# All depths share the same filter sizes; only the block counts differ.
groups = {
    depth: list(zip((64, 128, 256, 512), block_counts))
    for depth, block_counts in (
        (50, (3, 4, 6, 3)),     # SE-ResNet50
        (101, (3, 4, 23, 3)),   # SE-ResNet101
        (152, (3, 8, 36, 3)),   # SE-ResNet152
    )
}

# Meta-parameter: Amount of filter reduction in squeeze operation
ratio = 16

In [44]:
# create the SE-ResNet model, passing the squeeze ratio meta-parameter
# explicitly so that changing `ratio` actually affects the model
se_resnet_model = build_se_resnet(groups, ratio=ratio)

# print a summary
se_resnet_model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_4 (InputLayer)        [(None, 224, 224, 3)]        0         []                            
                                                                                                  
 zero_padding2d_6 (ZeroPadd  (None, 230, 230, 3)          0         ['input_4[0][0]']             
 ing2D)                                                                                           
                                                                                                  
 conv2d_13 (Conv2D)          (None, 112, 112, 64)         9408      ['zero_padding2d_6[0][0]']    
                                                                                                  
 batch_normalization_13 (Ba  (None, 112, 112, 64)         256       ['conv2d_13[0][0]']       