In [2]:
import keras  
from keras.layers import Input, Conv2D,GlobalAveragePooling2D, Dense, BatchNormalization, Activation
from keras.models import Model
from keras.layers import DepthwiseConv2D 

In [7]:
gb_depthwise_no = 0


def depthwise_separable(x, feature_maps, stride=1, depthwise_no=None):
    """Apply one depthwise-separable convolution block to ``x``.

    The block is: 3x3 depthwise convolution (with the requested stride and
    'same' padding) -> batch normalization -> ReLU -> 1x1 convolution with
    ``feature_maps`` filters -> batch normalization -> ReLU.

    Args:
        x: Input tensor the layers are applied to (presumably a Keras
            tensor; the caller in this file passes the output of other
            Keras layers).
        feature_maps: Number of filters of the 1x1 convolution; it is
            passed through ``int()``.
        stride: Stride used in both spatial dimensions by the depthwise
            convolution.
        depthwise_no: Index used to build the layer names
            ``depthwise<no>conv1`` / ``depthwise<no>conv2``. When ``None``,
            the module-level counter ``gb_depthwise_no`` is incremented and
            its new value is used.

    Returns:
        The output of the final ReLU activation.
    """
    global gb_depthwise_no

    if depthwise_no is None:
        # No explicit index given: draw the next one from the global counter.
        gb_depthwise_no += 1
        depthwise_no = gb_depthwise_no
    prefix = "depthwise{}".format(depthwise_no)

    def _bn_relu(tensor):
        # Batch normalization followed by ReLU, shared by both conv stages.
        normalized = BatchNormalization()(tensor)
        return Activation('relu')(normalized)

    # Stage 1: per-channel 3x3 spatial filtering.
    spatial = DepthwiseConv2D(
        (3, 3),
        strides=(stride, stride),
        padding='same',
        name=prefix + "conv1",
    )(x)
    spatial = _bn_relu(spatial)

    # Stage 2: 1x1 convolution that sets the number of output channels.
    pointwise = Conv2D(
        int(feature_maps),
        (1, 1),
        strides=(1, 1),
        padding='same',
        name=prefix + "conv2",
    )(spatial)
    return _bn_relu(pointwise)

In [8]:
def MobileNet(img_input, shallow=False, classes=10):
    """Build the MobileNet layer graph on top of ``img_input``.

    The network is a strided 3x3 convolution (32 filters) followed by a
    stack of depthwise-separable blocks, global average pooling and a
    softmax classifier.

    Args:
        img_input: Input tensor the first convolution is applied to.
        shallow: When true, the five repeated 512-channel stride-1 blocks
            are left out.
        classes: Number of units of the final softmax ``Dense`` layer.

    Returns:
        The output tensor of the softmax ``Dense`` layer named "Dense1".
    """
    # Stem: standard convolution, then batch norm and ReLU.
    x = Conv2D(32, (3, 3), strides=(2, 2), padding='same', name="conv1")(img_input)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # (feature_maps, stride) for blocks 1..6. The last entry is 512 channels;
    # the previous value 5112 was a typo that made the network ~10x wider.
    front_blocks = (
        (64, 1),
        (128, 2),
        (128, 1),
        (256, 2),
        (256, 1),
        (512, 2),
    )
    for block_no, (feature_maps, stride) in enumerate(front_blocks, start=1):
        x = depthwise_separable(
            x, feature_maps=feature_maps, stride=stride, depthwise_no=block_no
        )

    if not shallow:
        # Five identical 512-channel blocks, numbered 7..11.
        for i in range(5):
            x = depthwise_separable(x, feature_maps=512, stride=1, depthwise_no=7 + i)

    # NOTE: the base index is 12 (or 7 when shallow) and the blocks below use
    # base+1 and base+2, so the base index itself is never used in a layer
    # name. Kept as-is so existing layer names do not change.
    depthwise_no = 7 + (5 if not shallow else 0)

    x = depthwise_separable(x, feature_maps=1024, stride=2, depthwise_no=depthwise_no + 1)
    x = depthwise_separable(x, feature_maps=1024, stride=1, depthwise_no=depthwise_no + 2)

    x = GlobalAveragePooling2D()(x)
    out = Dense(classes, activation='softmax', name="Dense1")(x)
    return out
# Build the network for 32x32 inputs with 3 channels, using MobileNet's
# default arguments (shallow=False, classes=10).
img_input=Input(shape=(32,32,3), name="Input")
output = MobileNet(img_input)
model=Model(img_input,output)
# Print the layer-by-layer summary (output shapes and parameter counts).
model.summary()

# Save the model as an image.
# NOTE(review): the file name says "MobileNetV2" while the builder called
# above is named MobileNet — confirm the intended name.
# NOTE(review): presumably the "model_png" directory must already exist —
# TODO confirm.
from keras.utils import plot_model
plot_model(model,to_file='model_png/209MobileNetV2.png', show_layer_names=True, show_shapes=True) 

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
Input (InputLayer)           (None, 32, 32, 3)         0         
_________________________________________________________________
conv1 (Conv2D)               (None, 16, 16, 32)        896       
_________________________________________________________________
batch_normalization_4 (Batch (None, 16, 16, 32)        128       
_________________________________________________________________
activation_4 (Activation)    (None, 16, 16, 32)        0         
_________________________________________________________________
depthwise1conv1 (DepthwiseCo (None, 16, 16, 32)        320       
_________________________________________________________________
batch_normalization_5 (Batch (None, 16, 16, 32)        128       
_________________________________________________________________
activation_5 (Activation)    (None, 16, 16, 32)        0         
__________