<a href="https://colab.research.google.com/github/CelikAbdullah/deep-learning-notebooks/blob/main/Computer%20Vision/models/ResNeXt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [37]:
from tensorflow import keras

# ResNeXt

## The stem component

In [38]:
def stem(inputs):
  """ Build the stem component of ResNeXt
      inputs : input tensor fed to the first convolution

      Returns the tensor produced by the last stem layer.
  """
  layers = keras.layers

  # the stem is a plain chain, so declare it as a pipeline and apply it in order:
  # strided 7x7 conv -> batch norm -> ReLU -> strided 3x3 max pooling
  pipeline = (
      layers.Conv2D(filters=64, kernel_size=7, strides=2, padding='same', kernel_initializer='he_normal', use_bias=False),
      layers.BatchNormalization(),
      layers.ReLU(),
      layers.MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding='same'),
  )

  features = inputs
  for layer in pipeline:
    features = layer(features)

  return features

## The learner component

### The Identity Block

In [39]:
def identity_block(x, filters_in, filters_out, cardinality=32):
  """ Build a ResNeXt block with an identity link
      x          : input to block
      filters_in : number of filters (channels) at the input convolution;
                   each group gets filters_in // cardinality of them, so
                   channels left over by the integer division are not used
      filters_out: number of filters (channels) at the output convolution;
                   the result is added to x, so it has to match x's shape
      cardinality: number of parallel groups in the split-transform-merge layer
  """

  # Save the input
  shortcut = x

  # a 1x1 conv layer for dimensionality reduction followed by a batch normalization and ReLU activation layer
  x = keras.layers.Conv2D(filters=filters_in, kernel_size=1, strides=1, padding='same', kernel_initializer='he_normal', use_bias=False)(shortcut)
  x = keras.layers.BatchNormalization()(x)
  x = keras.layers.ReLU()(x)

  # compute the cardinality (wide) layer (split-transform)
  filters_card = filters_in // cardinality
  groups = []
  for i in range(cardinality):
    start = i * filters_card
    stop = start + filters_card
    # Bind the slice bounds as lambda defaults. A lambda that merely closes over the
    # loop variable is evaluated late: Keras calls the function again every time the
    # model runs, and by then the loop variable holds its final value, so every
    # group would read the same (last) channel slice.
    split = keras.layers.Lambda(lambda z, start=start, stop=stop: z[:, :, :, start:stop])(x)
    groups.append(keras.layers.Conv2D(filters=filters_card, kernel_size=3, strides=1, padding='same', kernel_initializer='he_normal', use_bias=False)(split))

  # concatenate the outputs of the cardinality layer together (merge)
  x = keras.layers.Concatenate()(groups)
  # apply BN + ReLU activation
  x = keras.layers.BatchNormalization()(x)
  x = keras.layers.ReLU()(x)

  # a 1x1 conv layer for dimensionality restoration followed by batch normalization
  x = keras.layers.Conv2D(filters=filters_out, kernel_size=1, strides=1, padding='same', kernel_initializer='he_normal', use_bias=False)(x)
  x = keras.layers.BatchNormalization()(x)

  # Identity Link: Add the shortcut (input) to the output of the block
  x = keras.layers.Add()([shortcut, x])
  # after adding, apply ReLU activation
  x = keras.layers.ReLU()(x)

  return x

### The Projection Block

In [40]:
def projection_block(x, filters_in, filters_out, cardinality=32, strides=(2, 2)):
  """ Build a ResNeXt block with projection shortcut
      x          : input to the block
      filters_in : number of filters (channels) at the input convolution;
                   each group gets filters_in // cardinality of them, so
                   channels left over by the integer division are not used
      filters_out: number of filters (channels) at the output convolution
      cardinality: number of parallel groups in the split-transform-merge layer
      strides    : strides of the shortcut convolution and of the grouped 3x3
                   convolutions (i.e., (2, 2) vs (1, 1))
  """

  # a 1x1 conv layer to build the projection shortcut
  # it outputs filters_out channels with the same strides as the main path so
  # that both branches have the same shape when they are added
  # (no bias: the batch normalization that follows makes it redundant, as in
  #  every other convolution of this block)
  shortcut = keras.layers.Conv2D(filters=filters_out, kernel_size=1, strides=strides, padding='same', kernel_initializer='he_normal', use_bias=False)(x)
  shortcut = keras.layers.BatchNormalization()(shortcut)

  # a 1x1 conv layer for dimensionality reduction
  x = keras.layers.Conv2D(filters=filters_in, kernel_size=1, strides=1, padding='same', kernel_initializer='he_normal', use_bias=False)(x)
  x = keras.layers.BatchNormalization()(x)
  x = keras.layers.ReLU()(x)

  # compute the cardinality (wide) layer (split-transform)
  filters_card = filters_in // cardinality
  groups = []
  for i in range(cardinality):
    start = i * filters_card
    stop = start + filters_card
    # Bind the slice bounds as lambda defaults. A lambda that merely closes over the
    # loop variable is evaluated late: Keras calls the function again every time the
    # model runs, and by then the loop variable holds its final value, so every
    # group would read the same (last) channel slice.
    split = keras.layers.Lambda(lambda z, start=start, stop=stop: z[:, :, :, start:stop])(x)
    groups.append(keras.layers.Conv2D(filters=filters_card, kernel_size=3, strides=strides, padding='same', kernel_initializer='he_normal', use_bias=False)(split))

  # concatenate the outputs of the cardinality layer together (merge)
  x = keras.layers.Concatenate()(groups)
  x = keras.layers.BatchNormalization()(x)
  x = keras.layers.ReLU()(x)

  # a 1x1 conv layer for dimensionality restoration
  x = keras.layers.Conv2D(filters=filters_out, kernel_size=1, strides=1, padding='same', kernel_initializer='he_normal', use_bias=False)(x)
  x = keras.layers.BatchNormalization()(x)

  # add the projected shortcut to the output of the block, then apply ReLU
  x = keras.layers.Add()([shortcut, x])
  x = keras.layers.ReLU()(x)

  return x

### The ResNeXt Group

In [41]:
def group(x, filters_in, filters_out, n_blocks, cardinality=32, strides=(2, 2)):
  """ Build a Residual group
      x          : input to the group
      filters_in : number of filters  (channels) at the input convolution
      filters_out: number of filters (channels) at the output convolution
      n_blocks   : total number of blocks in the group: one projection block
                   followed by n_blocks - 1 identity blocks
      cardinality: width of group convolution
      strides    : strides handed to the projection block (whether it is strided)
  """
  # create the projection block
  # it changes the channel count to filters_out and, when strided, shrinks the
  # feature maps for the rest of the group
  x = projection_block(x, filters_in, filters_out, strides=strides, cardinality=cardinality)

  # create the remaining blocks
  # the projection block above already counts as one of the n_blocks, so only
  # n_blocks - 1 identity blocks follow (e.g. 3/4/6/3 blocks for ResNeXt 50)
  for _ in range(n_blocks - 1):
    x = identity_block(x, filters_in, filters_out, cardinality=cardinality)

  return x

### Putting it all together

In [42]:
def learner(x, groups, cardinality=32):
  """ Build the Learner
      x          : input to the learner
      groups     : list of groups: filters in, filters out, number of blocks
                   (read only; the caller's list is left untouched)
      cardinality: width of group convolution
  """
  # work on a copy: popping from the caller's list would shorten the shared
  # meta-parameter table and break every later build that uses it
  specs = list(groups)

  # create the first ResNeXt group (not-strided)
  filters_in, filters_out, n_blocks = specs[0]
  x = group(x, filters_in, filters_out, n_blocks, strides=(1, 1), cardinality=cardinality)

  # create the remaining ResNeXt groups
  for filters_in, filters_out, n_blocks in specs[1:]:
    x = group(x, filters_in, filters_out, n_blocks, cardinality=cardinality)
  return x

## The task component

In [43]:
def task(x, classes):
  """ Construct the Classifier
      x         : input to the classifier
      classes   : number of output classes

      Returns the output of the softmax Dense layer.
  """
  # collapse the feature maps with global average pooling
  pooled = keras.layers.GlobalAveragePooling2D()(x)

  # final Dense layer with one softmax unit per class
  classifier = keras.layers.Dense(units=classes, activation='softmax', kernel_initializer='he_normal')
  return classifier(pooled)

## ResNeXt model

In [44]:
def build_model(groups, shape=(224,224,3), cardinality=32, classes=1000, n_layers=50):
  """ Assemble the ResNeXt model: stem -> learner -> task
      groups     : dict mapping a network depth (e.g. 50) to its list of groups:
                   filters in, filters out, number of blocks
      shape      : shape of the input tensor (without the batch dimension)
      cardinality: width of group convolution
      classes    : number of output classes
      n_layers   : key of the entry in groups to build (defaults to 50)

      Returns the keras.Model mapping the input tensor to the class outputs.
      Raises KeyError when groups has no entry for n_layers.
  """
  # the input tensor (honours the shape argument instead of a fixed 224x224x3)
  inputs = keras.Input(shape=shape)

  # the stem component
  x = stem(inputs)

  # the learner component
  # hand over a copy so the caller's meta-parameter table is never modified
  x = learner(x, list(groups[n_layers]), cardinality)

  # the task component
  outputs = task(x, classes)

  # Instantiate the Model
  return keras.Model(inputs, outputs)

In [45]:
# Meta-parameters, keyed by network depth.
# Every entry is one tuple per group: (filters in, filters out, number of blocks).
groups = {
    # ResNeXt 50
    50: [
        (128, 256, 3),
        (256, 512, 4),
        (512, 1024, 6),
        (1024, 2048, 3),
    ],
    # ResNeXt 101
    101: [
        (128, 256, 3),
        (256, 512, 4),
        (512, 1024, 23),
        (1024, 2048, 3),
    ],
    # ResNeXt 152
    152: [
        (128, 256, 3),
        (256, 512, 8),
        (512, 1024, 36),
        (1024, 2048, 3),
    ],
}

In [46]:
# Build the ResNeXt model from the `groups` meta-parameter table, leaving every
# other build_model argument at its default value.
resnext_model = build_model(groups=groups)

# Print the layer-by-layer summary (layer type, output shape, parameter count
# and incoming connections).
resnext_model.summary()


Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_3 (InputLayer)        [(None, 224, 224, 3)]        0         []                            
                                                                                                  
 conv2d_38 (Conv2D)          (None, 112, 112, 64)         9408      ['input_3[0][0]']             
                                                                                                  
 batch_normalization_6 (Bat  (None, 112, 112, 64)         256       ['conv2d_38[0][0]']           
 chNormalization)                                                                                 
                                                                                                  
 re_lu_4 (ReLU)              (None, 112, 112, 64)         0         ['batch_normalization_6[0]