In [1]:
# Ask the notebook host to use TensorFlow 2.x via the `%tensorflow_version` line magic.
# NOTE(review): this magic is presumably Google Colab-specific. In a plain
# Jupyter/IPython kernel it is not defined, and the resulting exception is caught
# and printed (instead of aborting the cell) — confirm this is the intended fallback.
try:
    %tensorflow_version 2.x
except Exception as e:
    # Best-effort: report why the version switch failed and keep going.
    print(e)

In [2]:
import datetime

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.layers import (
    Activation,
    AveragePooling2D,
    BatchNormalization,
    Conv2D,
    Dense,
    DepthwiseConv2D,
    Flatten,
    GlobalAveragePooling2D,
    MaxPool2D,
    ZeroPadding2D,
)
from tensorflow.keras.models import Model

In [3]:
class DepthwiseConvolution(Model):
    """Depthwise-separable convolution block used by MobileNet.

    Pipeline::

        [zero-pad if strided] -> depthwise conv -> BN -> act
                              -> pointwise 1x1 conv -> BN -> act

    Args:
        filters: Number of output channels of the pointwise convolution
            before width scaling.
        kernel_size: Kernel size of the depthwise convolution.
        strides: Stride of the depthwise convolution, as an int or a
            ``(row, col)`` pair. ``None`` is treated as ``(1, 1)``.
        alpha: Width multiplier; the pointwise convolution produces
            ``int(alpha * filters)`` channels.
        activation: Activation applied after each batch normalization
            (anything accepted by ``tf.keras.layers.Activation``; ``None``
            means linear).
    """

    def __init__(self, filters=None, kernel_size=None, strides=None, alpha=1, activation=None):
        super(DepthwiseConvolution, self).__init__()
        # Normalize strides so int / list / None inputs all compare correctly below.
        if strides is None:
            strides = (1, 1)
        elif isinstance(strides, int):
            strides = (strides, strides)
        else:
            strides = tuple(strides)

        if strides == (1, 1):
            # 'same' padding already preserves the spatial size; adding an explicit
            # ZeroPadding2D as well would grow the feature map by 2 px per block.
            self.pad = None
            self.dwconv = DepthwiseConv2D(kernel_size, padding='same')
        else:
            # Strided: pad by 1 px on each side, then run a 'valid' convolution.
            # NOTE(review): 1 px of padding matches a 3x3 kernel; revisit for other sizes.
            self.pad = ZeroPadding2D()
            self.dwconv = DepthwiseConv2D(kernel_size, strides=strides, padding='valid')

        # Conv2D requires an integer channel count; alpha may be fractional (e.g. 0.75).
        self.pwconv = Conv2D(int(alpha * filters), 1, strides=(1, 1))
        # Two independent BN layers: the depthwise output keeps the input's channel
        # count while the pointwise output has int(alpha * filters) channels, so a
        # single shared BN layer cannot normalize both.
        self.dw_bn = BatchNormalization()
        self.pw_bn = BatchNormalization()
        # Activation layers are stateless, so one instance can be reused safely.
        self.act = Activation(activation)

    def call(self, inputs):
        """Apply the depthwise-separable block to ``inputs``.

        NOTE(review): assumes channels-last 4-D input (batch, height, width,
        channels), the Keras default data format — confirm if configured otherwise.
        """
        x = inputs if self.pad is None else self.pad(inputs)
        x = self.dwconv(x)
        x = self.dw_bn(x)
        x = self.act(x)
        x = self.pwconv(x)
        x = self.pw_bn(x)
        return self.act(x)

In [4]:
class MobileNet(Model):
    """MobileNet (v1) image classifier built from depthwise-separable blocks.

    Layer sequence: zero-padded stride-2 stem convolution (+BN, ReLU), thirteen
    depthwise-separable blocks, global average pooling, a hidden dense layer
    and a classification head.

    Args:
        kernel_size: Kernel size of the stem convolution.
        num_classes: Number of units in the classification head.
        alpha: Width multiplier applied to every convolution's channel count.
        pho: Resolution multiplier (rho in the MobileNet paper). Not used inside
            the model: resolution scaling is done by feeding smaller images,
            which global pooling supports. Kept for backward compatibility.
        hidden_units: Width of the dense layer before the classification head.
        hidden_activation: Activation of the hidden dense layer.
        output_activation: Activation of the classification head.
    """

    # (filters, strides) for each depthwise-separable block, following Table 1 of
    # the MobileNet paper. The last 1024 block uses stride 1: with stride 2 a
    # 224x224 input would end at 4x4 instead of the expected 7x7 feature map.
    BLOCK_SPECS = (
        (64, (1, 1)),
        (128, (2, 2)),
        (128, (1, 1)),
        (256, (2, 2)),
        (256, (1, 1)),
        (512, (2, 2)),
        (512, (1, 1)),
        (512, (1, 1)),
        (512, (1, 1)),
        (512, (1, 1)),
        (512, (1, 1)),
        (1024, (2, 2)),
        (1024, (1, 1)),
    )

    def __init__(self, kernel_size=3, num_classes=1000, alpha=1, pho=1,
                 hidden_units=1000, hidden_activation='relu',
                 output_activation='softmax'):
        super(MobileNet, self).__init__()
        # Stem: pad + stride-2 conv -> BN -> ReLU (224x224 input -> 112x112).
        self.pad = ZeroPadding2D()
        # int(): Conv2D needs an integer channel count when alpha is fractional.
        self.conv1 = Conv2D(int(alpha * 32), kernel_size, strides=(2, 2))
        self.bn1 = BatchNormalization()
        self.act1 = Activation('relu')
        # One distinct layer object per block. Calling the same layer repeatedly
        # (as a loop over a single block would) shares its weights across repeats.
        self.dw_blocks = [
            DepthwiseConvolution(filters, 3, strides, alpha, 'relu')
            for filters, strides in self.BLOCK_SPECS
        ]
        # Global pooling equals a 7x7 average pool for 224x224 inputs and also
        # works for other input resolutions.
        self.avgpool = GlobalAveragePooling2D()
        self.dense = Dense(hidden_units, activation=hidden_activation)
        # Not named `output`: that is a read-only property on keras Model.
        self.classifier = Dense(num_classes, activation=output_activation)

    def call(self, inputs):
        """Run the network on a batch of images and return class scores.

        NOTE(review): assumes channels-last image batches (batch, height, width,
        channels) — confirm against the input pipeline.
        """
        x = self.pad(inputs)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)
        for block in self.dw_blocks:
            x = block(x)
        x = self.avgpool(x)
        x = self.dense(x)
        return self.classifier(x)