<h1>MobileNet (2017) paper implementation using TensorFlow</h1>

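arXiv link to the paper ("MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications", Howard et al., 2017): https://arxiv.org/pdf/1704.04861.pdf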

In [1]:
from tensorflow.keras.layers import Conv2D, DepthwiseConv2D, ReLU, BatchNormalization, add,Softmax, AveragePooling2D, Dense, Input, GlobalAveragePooling2D
from tensorflow.keras.models import Model

In [2]:
def depth_block(x, strides):
    """Depthwise half of a MobileNet depthwise-separable convolution.

    Applies a 3x3 depthwise convolution ('same' padding, no bias), then
    batch normalization, then a ReLU activation.

    Args:
        x: Keras tensor to transform.
        strides: stride of the depthwise convolution, e.g. (1, 1) to keep the
            spatial size or (2, 2) to downsample it.

    Returns:
        The activated Keras tensor.
    """
    # Bias is omitted because the following BatchNormalization supplies its own shift.
    stages = (
        DepthwiseConv2D(3, strides=strides, padding='same', use_bias=False),
        BatchNormalization(),
        ReLU(),
    )
    for stage in stages:
        x = stage(x)
    return x

def single_conv_block(x, filters):
    """Pointwise (1x1) half of a MobileNet depthwise-separable convolution.

    Applies a 1x1 convolution with `filters` output channels (no bias), then
    batch normalization, then a ReLU activation.

    Args:
        x: Keras tensor to transform.
        filters: number of output channels of the 1x1 convolution.

    Returns:
        The activated Keras tensor.
    """
    pointwise = Conv2D(filters, 1, use_bias=False)
    norm = BatchNormalization()
    activation = ReLU()
    return activation(norm(pointwise(x)))

In [4]:
def combo_layer(x, repetition, num_repeats=None):
    """Build one MobileNet stage at a fixed channel width, ending in a downsample.

    Layer sequence:
        1x1 conv (repetition channels)
        num_repeats x [3x3 depthwise stride 1, 1x1 conv (repetition channels)]
        3x3 depthwise stride 2

    Args:
        x: input Keras tensor.
        repetition: number of filters for every 1x1 convolution in the stage.
            Despite its name, this is a channel count, not a repeat count.
        num_repeats: number of stride-1 (depthwise, pointwise) pairs in the
            middle of the stage. Defaults to None, which keeps the original
            behaviour: 5 when `repetition == 512` (the paper's five repeated
            512-channel blocks) and 1 otherwise.

    Returns:
        The Keras tensor after the final stride-2 depthwise block.

    Raises:
        ValueError: if `num_repeats` is negative.
    """
    if num_repeats is None:
        num_repeats = 5 if repetition == 512 else 1
    if num_repeats < 0:
        raise ValueError(f"num_repeats must be >= 0, got {num_repeats}")

    x = single_conv_block(x, repetition)
    for _ in range(num_repeats):
        x = depth_block(x, strides=(1, 1))
        x = single_conv_block(x, repetition)
    # The stage ends with a spatial downsample; the next stage's first 1x1 conv
    # widens the channels.
    x = depth_block(x, strides=(2, 2))
    return x

In [5]:
def MobileNet(input_shape=(224,224,3),n_classes = 1000):
    """Build MobileNet v1 (Howard et al., 2017) as a Keras functional model.

    Follows Table 1 of the paper:
        - a stride-2 3x3 stem convolution;
        - 13 depthwise-separable blocks;
        - global average pooling;
        - a softmax Dense classifier.

    Args:
        input_shape: shape of a single input image, excluding the batch axis.
        n_classes: number of units in the final softmax Dense layer.

    Returns:
        An uncompiled tf.keras Model mapping images to class probabilities.
    """
    inputs = Input(input_shape)

    # Stem: standard 3x3 convolution with stride 2 (224 -> 112 for the default input).
    x = Conv2D(32, 3, strides=(2, 2), padding='same', use_bias=False)(inputs)
    x = BatchNormalization()(x)
    x = ReLU()(x)

    # First depthwise-separable block (32 -> 64 channels), then downsample.
    x = depth_block(x, strides=(1, 1))
    x = single_conv_block(x, 64)
    x = depth_block(x, strides=(2, 2))

    # Stages of 128, 256 and 512 channels; each ends with a stride-2 depthwise conv.
    for stage_filters in (128, 256, 512):
        x = combo_layer(x, stage_filters)

    # Final 1024-channel block. Table 1 labels this depthwise conv "s2", but the
    # listed output stays 7x7x1024, and reference implementations use stride 1.
    # Stride 2 here would shrink the map to 4x4 before pooling.
    x = single_conv_block(x, 1024)
    x = depth_block(x, strides=(1, 1))
    x = single_conv_block(x, 1024)

    x = GlobalAveragePooling2D()(x)
    outputs = Dense(n_classes, activation='softmax')(x)
    return Model(inputs, outputs)

In [6]:
# Build the ImageNet-sized configuration (224x224x3 input, 1000 classes)
# and print a per-layer summary.
input_shape = (224, 224, 3)
n_classes = 1000

model = MobileNet(input_shape=input_shape, n_classes=n_classes)
model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
conv2d (Conv2D)              (None, 112, 112, 32)      864       
_________________________________________________________________
batch_normalization (BatchNo (None, 112, 112, 32)      128       
_________________________________________________________________
re_lu (ReLU)                 (None, 112, 112, 32)      0         
_________________________________________________________________
depthwise_conv2d (DepthwiseC (None, 112, 112, 32)      288       
_________________________________________________________________
batch_normalization_1 (Batch (None, 112, 112, 32)      128       
_________________________________________________________________
re_lu_1 (ReLU)               (None, 112, 112, 32)      0     