<h1>Implementation of the MobileNet (2017) paper using TensorFlow</h1>

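arXiv link to the paper (Howard et al., 2017, "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications"): https://arxiv.org/pdf/1704.04861.pdf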

In [1]:
from tensorflow.keras.layers import Conv2D, DepthwiseConv2D, ReLU, BatchNormalization, add,Softmax, AveragePooling2D, Dense, Input, GlobalAveragePooling2D
from tensorflow.keras.models import Model

In [2]:
def depth_block(x, strides):
    """Depthwise half of a depthwise-separable convolution.

    Applies a bias-free 3x3 ``DepthwiseConv2D`` with 'same' padding,
    followed by batch normalization and a ReLU activation.

    Args:
        x: Input tensor.
        strides: Stride of the depthwise convolution. ``(2, 2)`` downsamples
            spatially (output size is ceil(input / 2) with 'same' padding);
            ``(1, 1)`` keeps the spatial size.

    Returns:
        The activated output tensor.
    """
    # Layers are constructed in the same order they are applied, so Keras
    # auto-generated layer names follow the network's data flow.
    stages = (
        DepthwiseConv2D(kernel_size=3, strides=strides, padding='same', use_bias=False),
        BatchNormalization(),
        ReLU(),
    )
    for stage in stages:
        x = stage(x)
    return x

def single_conv_block(x, filters):
    """Pointwise half of a depthwise-separable convolution.

    Applies a bias-free 1x1 ``Conv2D`` that projects the input to
    ``filters`` channels, followed by batch normalization and a ReLU.

    Args:
        x: Input tensor.
        filters: Number of output channels of the 1x1 convolution.

    Returns:
        The activated output tensor.
    """
    pointwise = Conv2D(filters, kernel_size=1, use_bias=False)
    normalize = BatchNormalization()
    activate = ReLU()
    return activate(normalize(pointwise(x)))

In [15]:
def combo_layer(x, filters, strides):
    """One depthwise-separable block: a 3x3 depthwise stage, then a 1x1 pointwise stage.

    Args:
        x: Input tensor.
        filters: Output channels of the pointwise (1x1) convolution.
        strides: Stride applied by the depthwise (3x3) convolution.

    Returns:
        Output tensor of the pointwise stage.
    """
    # The inner call runs first, so the depthwise layers are created before the pointwise ones.
    return single_conv_block(depth_block(x, strides), filters)

In [16]:
def MobileNet(input_shape=(224, 224, 3), n_classes=1000, alpha=1.0):
    """Build the MobileNet (v1) classifier described in Table 1 of the paper.

    The network is made of:
      * a stride-2 3x3 stem convolution,
      * 13 depthwise-separable blocks (``combo_layer``),
      * global average pooling,
      * a softmax ``Dense`` classifier.

    Args:
        input_shape: Shape of one input image, excluding the batch axis.
            Changing the spatial size plays the role of the paper's
            resolution multiplier.
        n_classes: Number of units in the final softmax layer.
        alpha: Width multiplier (paper Section 3.3). The channel count of
            every convolution is scaled by ``alpha`` and truncated to an
            int (minimum 1). The default ``1.0`` reproduces the baseline
            widths exactly.

    Returns:
        An uncompiled Keras ``Model`` mapping images to class probabilities.

    Raises:
        ValueError: If ``alpha`` is not positive.
    """
    if alpha <= 0:
        raise ValueError(f"alpha must be positive, got {alpha!r}")

    def width(filters):
        # Apply the width multiplier; never drop a layer to zero channels.
        return max(1, int(filters * alpha))

    # (pointwise filters, depthwise stride) for each depthwise-separable block.
    block_specs = [
        (64, (1, 1)),
        (128, (2, 2)), (128, (1, 1)),
        (256, (2, 2)), (256, (1, 1)),
        (512, (2, 2)), *[(512, (1, 1))] * 5,
        (1024, (2, 2)),
        # Table 1 of the paper lists s2 for this depthwise layer. However, the
        # 7x7 input size of the following layer implies stride 1, so s1 is used here.
        (1024, (1, 1)),
    ]

    # Named `inputs` to avoid shadowing the builtin `input`.
    inputs = Input(input_shape)
    x = Conv2D(width(32), 3, strides=(2, 2), padding='same', use_bias=False)(inputs)
    x = BatchNormalization()(x)
    x = ReLU()(x)

    for filters, strides in block_specs:
        x = combo_layer(x, width(filters), strides=strides)

    x = GlobalAveragePooling2D()(x)
    outputs = Dense(n_classes, activation='softmax')(x)
    return Model(inputs, outputs)

In [17]:
# Build the model for 224x224 RGB inputs and 1000 output classes, then print its layer summary.
input_shape = (224, 224, 3)
n_classes = 1000

model = MobileNet(input_shape=input_shape, n_classes=n_classes)
model.summary()

Model: "model_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_6 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
conv2d_70 (Conv2D)           (None, 112, 112, 32)      864       
_________________________________________________________________
batch_normalization_135 (Bat (None, 112, 112, 32)      128       
_________________________________________________________________
re_lu_135 (ReLU)             (None, 112, 112, 32)      0         
_________________________________________________________________
depthwise_conv2d_65 (Depthwi (None, 112, 112, 32)      288       
_________________________________________________________________
batch_normalization_136 (Bat (None, 112, 112, 32)      128       
_________________________________________________________________
re_lu_136 (ReLU)             (None, 112, 112, 32)      0   