In [5]:
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras.layers import Conv2D, DepthwiseConv2D, BatchNormalization, GlobalAveragePooling2D, Activation, Dense
from tensorflow.keras.datasets import cifar10
import os


def mobile(x):
    def dw(x, pad, f, st):
        x = DepthwiseConv2D(kernel_size = (3, 3), padding = 'same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x) 
        x = Conv2D(filters = f, kernel_size = (1, 1), strides = st, padding = 'same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        return x

    x = Conv2D(filters = 32, kernel_size = (3, 3), strides = 2, padding = 'same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = dw(x, 'same', 64, 1)
    x = dw(x, 'valid', 128, 2)
    x = dw(x, 'same', 128, 1)
    x = dw(x, 'same', 256, 2)
    x = dw(x, 'same', 256, 1)
    x = dw(x, 'valid', 512, 2)
    x = dw(x, 'same', 512, 1)
    x = dw(x, 'same', 512, 1)
    x = dw(x, 'same', 512, 1)
    x = dw(x, 'same', 512, 1)
    x = dw(x, 'same', 512, 1)
    x = dw(x, 'valid', 1024, 2)
    x = dw(x, 'same', 1024, 1)

    x = GlobalAveragePooling2D()(x)

    return x

inputs = keras.Input(shape = (224, 224, 3))
x = mobile(inputs)
outputs = Dense(10, activation="softmax")(x)

model = keras.Model(inputs=inputs, outputs=outputs)

In [6]:
model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 conv2d_1 (Conv2D)           (None, 112, 112, 32)      896       
                                                                 
 batch_normalization_1 (Batc  (None, 112, 112, 32)     128       
 hNormalization)                                                 
                                                                 
 activation (Activation)     (None, 112, 112, 32)      0         
                                                                 
 depthwise_conv2d (Depthwise  (None, 112, 112, 32)     320       
 Conv2D)                                                         
                                                                 
 batch_normalization_2 (Batc  (None, 112, 112, 32)     128   