# The inverted residual block — custom Keras layers and models via subclassing

Reference (TensorFlow guide on writing custom layers and models): https://www.tensorflow.org/guide/keras/custom_layers_and_models

In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Layer, Conv2D, DepthwiseConv2D, BatchNormalization, ReLU, GlobalAveragePooling2D
import numpy as np
from tensorflow.python.keras import backend

class InvertedResidual(Layer):
    """Inverted residual block (expand -> depthwise -> linear project) as a Keras layer.

    The block applies, in order:
      1. a 1x1 "expansion" convolution to ``input_channels * expansion_factor``
         channels, then batch norm and ``activation``;
      2. a 3x3 depthwise convolution with ``strides`` and 'same' padding,
         then batch norm and ``activation``;
      3. a 1x1 "projection" convolution to ``filters`` channels, then batch
         norm with no activation (linear bottleneck).
    When the block has unit stride and ``filters`` equals the number of input
    channels, the input is added to the output (residual connection).

    Args:
        filters: Number of output channels of the projection convolution.
        strides: Stride of the depthwise convolution (int or pair of ints).
        activation: Callable applied after the first two batch norms.
            ``None`` (the default) creates a new ``ReLU()`` layer for this block.
        expansion_factor: Multiplier applied to the input channel count to get
            the width of the expansion convolution.
        regularizer: Kernel regularizer shared by the three convolutions; any
            value accepted by ``tf.keras.regularizers.get`` (a regularizer
            instance, its name, its serialized config dict, or ``None``).
        trainable, name, **kwargs: Forwarded to ``Layer.__init__``.
    """

    def __init__(self, filters, strides, activation=None, expansion_factor=6,
                 regularizer=None, trainable=True, name=None, **kwargs):
        super().__init__(trainable=trainable, name=name, **kwargs)
        self.filters = filters
        self.strides = strides
        self.expansion_factor = expansion_factor
        # A default of ``ReLU()`` in the signature would be evaluated once and
        # shared by every block; build a fresh one per instance instead.
        self.activation = ReLU() if activation is None else activation
        # Normalize so that config dicts / names (e.g. from get_config) are
        # accepted, and so get_config can serialize the value back.
        self.regularizer = tf.keras.regularizers.get(regularizer)
        # Read the data format from tf.keras, the same setting the tf.keras
        # Conv2D / DepthwiseConv2D layers below use by default.
        self.channel_axis = (
            1 if tf.keras.backend.image_data_format() == 'channels_first' else -1)

    def build(self, input_shape):
        """Create the sub-layers once the input channel count is known."""
        input_channels = int(input_shape[self.channel_axis])  # C
        self.ptwise_conv1 = Conv2D(
            filters=int(input_channels * self.expansion_factor), kernel_size=1,
            kernel_regularizer=self.regularizer, use_bias=False)
        self.dwise = DepthwiseConv2D(
            kernel_size=3, strides=self.strides, padding='same',
            kernel_regularizer=self.regularizer, use_bias=False)
        self.ptwise_conv2 = Conv2D(
            filters=self.filters, kernel_size=1,
            kernel_regularizer=self.regularizer, use_bias=False)
        # Normalize over the channel axis for either data format
        # (BatchNormalization defaults to axis=-1).
        self.bn1 = BatchNormalization(axis=self.channel_axis)
        self.bn2 = BatchNormalization(axis=self.channel_axis)
        self.bn3 = BatchNormalization(axis=self.channel_axis)
        # Decide the skip connection from configuration rather than from static
        # tensor shapes: with unknown spatial dims (None) a strided block would
        # otherwise look shape-compatible and fail when the add runs.
        strides = (self.strides if isinstance(self.strides, (tuple, list))
                   else (self.strides, self.strides))
        self.use_residual = (all(s == 1 for s in strides)
                             and input_channels == self.filters)
        super().build(input_shape)

    def call(self, input_x, training=False):
        """Run the block on ``input_x``; ``training`` is passed to batch norm."""
        # Expansion: 1x1 conv widens the channel dimension.
        x = self.ptwise_conv1(input_x)
        x = self.bn1(x, training=training)
        x = self.activation(x)
        # Spatial filtering: 3x3 depthwise conv, which carries the stride.
        x = self.dwise(x)
        x = self.bn2(x, training=training)
        x = self.activation(x)
        # Projection back to `filters` channels, without activation.
        x = self.ptwise_conv2(x)
        x = self.bn3(x, training=training)
        # Residual connection only when input and output dims match.
        if self.use_residual:
            x = x + input_x
        return x

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer."""
        cfg = super().get_config()
        cfg.update({
            'filters': self.filters,
            'strides': self.strides,
            'regularizer': tf.keras.regularizers.serialize(self.regularizer),
            'expansion_factor': self.expansion_factor,
            # NOTE(review): the activation object is stored as-is rather than
            # as a serialized config; reloading from a saved (JSON) config may
            # need custom handling.
            'activation': self.activation,
        })
        return cfg

In [8]:
# Stand-alone block instance; its sub-layers are created lazily in build() on first call.
layer = InvertedResidual(filters=23, strides=1, expansion_factor=4)

In [10]:
efficientnet_mini = tf.keras.Sequential(
    [
        # Stem: strided 3x3 conv -> BN -> ReLU (224x224x3 -> 112x112x32).
        Conv2D(filters=32, kernel_size=3, strides=2, use_bias=False,
               padding='same', input_shape=(224, 224, 3)),
        BatchNormalization(),
        ReLU(),
    ]
    # Body: inverted residual blocks as (filters, strides). Blocks marked
    # "residual" keep both spatial size and channel count, so they add a skip.
    + [
        InvertedResidual(filters=filters, strides=strides)
        for filters, strides in (
            (32, 1),   # residual
            (64, 2),
            (64, 1),   # residual
            (64, 1),   # residual
            (128, 2),
            (128, 1),  # residual
            (128, 1),  # residual
        )
    ]
    # Head: global pooling only; a Dense classifier is to be added later.
    + [GlobalAveragePooling2D()]
)

In [11]:
# Smoke test: push a batch of one all-ones 224x224x3 image through the network.
efficientnet_mini(np.ones(shape=(1, 224, 224, 3)))

<tf.Tensor: shape=(1, 128), dtype=float32, numpy=
array([[-1.12157075e-04,  2.13146141e-05, -5.97271610e-05,
         9.77436139e-05, -2.05491917e-04, -8.11218852e-05,
         1.01773607e-04,  4.35018534e-04, -8.48429409e-06,
        -2.62034191e-05, -1.81655021e-04,  9.19364975e-05,
         7.27250153e-05, -1.06163010e-04,  1.00529011e-04,
         1.55344984e-04,  3.57772689e-04,  6.38489073e-05,
        -3.06343019e-04,  1.14284136e-04,  1.77486654e-04,
         2.53496866e-04,  3.02614899e-05, -2.59563167e-05,
         3.85745720e-04,  1.96714071e-04, -4.16272724e-06,
         2.37652232e-04,  2.88321986e-04, -1.48028761e-04,
        -1.46496939e-04, -5.48170428e-05, -8.28062184e-05,
        -7.75079316e-05,  6.60948572e-05,  1.06293710e-05,
         1.90470979e-04,  2.06762197e-05, -9.20272068e-05,
         3.79743396e-05, -5.39096363e-05, -5.14145286e-05,
         5.19152763e-06, -1.66228492e-05,  3.37430858e-04,
        -1.14104623e-04,  3.28906244e-05, -4.88774567e-05,
      