In [1]:
import tensorflow as tf
from tensorflow.keras.layers import (
    Input, Conv2D, DepthwiseConv2D,
    BatchNormalization, ReLU,
    GlobalAveragePooling2D, Dense, Add
)
from tensorflow.keras import Model

In [2]:
def inverted_residual_block(x, out_channels, stride, expansion):
    """Apply one MobileNetV2-style inverted residual (bottleneck) block to ``x``.

    The block chains up to three stages: an optional 1 x 1 expansion
    convolution, a 3 x 3 depthwise convolution and a 1 x 1 projection
    convolution. Every convolution is bias-free and followed by batch
    normalization; a ReLU capped at 6 follows the first two stages only, so
    the projection output stays linear.

    Args:
        x: Input Keras tensor; its last axis is read as the channel count.
        out_channels: Number of filters produced by the projection convolution.
        stride: Stride of the depthwise convolution.
        expansion: Multiplier applied to the input channel count to size the
            expansion convolution. The expansion stage is skipped when it is 1.

    Returns:
        The block's output tensor. The block input is added back (skip
        connection) only when ``stride == 1`` and the input channel count
        equals ``out_channels``.
    """
    shortcut = x
    in_channels = x.shape[-1]

    # Settings shared by both pointwise (1 x 1) convolutions.
    pointwise = dict(kernel_size=1, padding='same', use_bias=False)

    # Expansion phase (1 x 1): widen the channels before depthwise filtering.
    if expansion != 1:
        expand_conv = Conv2D(filters=in_channels * expansion, **pointwise)
        x = expand_conv(x)
        x = BatchNormalization()(x)
        x = ReLU(6.)(x)

    # Depthwise convolution (3 x 3): the only stage that may downsample.
    depthwise_conv = DepthwiseConv2D(
        kernel_size=3,
        strides=stride,
        padding='same',
        use_bias=False,
    )
    x = depthwise_conv(x)
    x = BatchNormalization()(x)
    x = ReLU(6.)(x)

    # Projection phase (1 x 1): no activation after the batch normalization.
    project_conv = Conv2D(filters=out_channels, **pointwise)
    x = project_conv(x)
    x = BatchNormalization()(x)

    # A skip connection needs matching spatial size and channel count.
    if stride != 1 or in_channels != out_channels:
        return x
    return Add()([shortcut, x])

In [3]:
def MobileNet(input_shape=(224, 224, 3), num_classes=100):
    inputs = Input(shape=input_shape)

    # Initial convolution (stem)
    x = Conv2D(
        filters = 32,
        kernel_size = 3,
        strides = 2,
        padding = 'same',
        use_bias = False
    )(inputs)

    x = BatchNormalization()(x)

    x = ReLU(6.)(x)


    # Block configuration
    config = [
        (1,  16, 1, 1),            # (expansion, output_channels, repeats, stride)
        (6,  24, 2, 2),
        (6,  32, 3, 2),
        (6,  64, 4, 2),
        (6,  96, 3, 1),
        (6, 160, 3, 2),
        (6, 320, 1, 1),
    ]

    # Build inverted residual blocks
    for expansion, channels, repeats, stride in config:
        for i in range(repeats):
            s = stride if i == 0 else 1
            x = inverted_residual_block(x, channels, s, expansion)

    # Final convolution
    x = Conv2D(
        filters = 1280,
        kernel_size = 1,
        padding = 'same',
        use_bias = False
    )(x)

    x = BatchNormalization()(x)

    x = ReLU(6.)(x)


    # Classification head
    x = GlobalAveragePooling2D()(x)
    outputs = Dense(num_classes, activation='softmax')(x)

    return Model(inputs, outputs)

In [4]:
# Build the network with its default arguments (224 x 224 x 3 input, 100 classes).
model = MobileNet()
# Print the layer-by-layer summary of the built model.
model.summary()

In [5]:
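# Optional: save a diagram of the model architecture as a PNG file.
# plot_model requires the pydot package and Graphviz; uncomment to run
# once they are installed.
# tf.keras.utils.plot_model(
#     model,
#     to_file="mobilenetv2_architecture.png",
#     show_shapes=True,
#     show_layer_names=True
# )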