In [1]:
import tensorflow as tf
from tensorflow import keras

In [9]:
class InceptionModule(keras.layers.Layer):
    def __init__(self,filter_1x1, reduce_3x3, filter_3x3, reduce_5x5, filter_5x5, pool_proj,**kwargs):
        super().__init__(**kwargs)
        self.conv1x1 = keras.layers.Conv2D(filter_1x1,1,strides=1,padding='same')
        self.reduce3x3 = keras.layers.Conv2D(reduce_3x3,1,strides=1,padding='same')
        self.conv3x3 = keras.layers.Conv2D(filter_3x3,3,strides=1,padding='same')
        self.reduce5x5 = keras.layers.Conv2D(reduce_5x5,1,strides=1,padding='same')
        self.conv5x5 = keras.layers.Conv2D(filter_5x5,5,strides=1,padding='same')
        self.maxpool = keras.layers.MaxPooling2D(3,strides=1,padding='same')
        self.proj = keras.layers.Conv2D(pool_proj,1,strides=1,padding='same')
        self.relu = keras.layers.Activation('relu')
        self.bn_1 = keras.layers.BatchNormalization()
        self.bn_r3 = keras.layers.BatchNormalization()
        self.bn_3 = keras.layers.BatchNormalization()
        self.bn_r5 = keras.layers.BatchNormalization()
        self.bn_5 = keras.layers.BatchNormalization()
        self.bn_pr = keras.layers.BatchNormalization()
    def call(self,inputs):
        conv1 = self.conv1x1(inputs)
        conv1 = self.bn_1(conv1)
        conv1 = self.relu(conv1)
        
        conv3 = self.reduce3x3(inputs)
        conv3 = self.bn_r3(conv3)
        conv3 = self.relu(conv3)
        conv3 = self.conv3x3(conv3)
        conv3 = self.bn_3(conv3)
        conv3 = self.relu(conv3)
        
        conv5 = self.reduce5x5(inputs)
        conv5 = self.bn_r5(conv5)
        conv5 = self.relu(conv5)
        conv5 = self.conv5x5(conv5)
        conv5 = self.bn_5(conv5)
        conv5 = self.relu(conv5)
        
        pool = self.maxpool(inputs)
        pool = self.proj(pool)
        pool = self.bn_pr(pool)
        pool = self.relu(pool)
        
        output = tf.concat([conv1,conv3,conv5,pool],axis=3)
        
        return output

In [16]:
model = keras.models.Sequential()
model.add(keras.layers.Conv2D(64,7,activation='relu',strides=2,padding='same',input_shape=[32,32,3],use_bias=False))
model.add(keras.layers.MaxPooling2D(pool_size=3,strides=2,padding='same'))
model.add(keras.layers.Lambda(lambda X : tf.nn.local_response_normalization(X,alpha=0.00002,beta=0.75,bias=1)))
model.add(keras.layers.Conv2D(64,1,strides=1,activation='relu',padding='same',use_bias=False))
model.add(keras.layers.Conv2D(192,3,strides=1,activation='relu',padding='same',use_bias=False))
model.add(keras.layers.Lambda(lambda X : tf.nn.local_response_normalization(X,alpha=0.00002,beta=0.75,bias=1)))
model.add(keras.layers.MaxPooling2D(pool_size=3,strides=2,padding='same'))
model.add(InceptionModule(64,96,128,16,32,32))
model.add(InceptionModule(128,128,192,32,96,64))
model.add(keras.layers.MaxPooling2D(pool_size=3,strides=2,padding='same'))
model.add(InceptionModule(192,96,208,16,48,64))
model.add(InceptionModule(160,112,224,24,64,64))
model.add(InceptionModule(128,128,256,24,64,64))
model.add(InceptionModule(112,144,288,32,64,64))
model.add(InceptionModule(256,160,320,32,128,128))
model.add(keras.layers.MaxPooling2D(pool_size=3,strides=2,padding='same'))
model.add(InceptionModule(256,160,320,32,128,128))
model.add(InceptionModule(384,192,384,48,128,128))
model.add(keras.layers.GlobalAveragePooling2D())
model.add(keras.layers.Dropout(0.4))
model.add(keras.layers.Dense(10,activation='softmax'))

In [17]:
model.compile(loss='sparse_categorical_crossentropy',optimizer='nadam',metrics=['accuracy'])

In [18]:
model.summary()

Model: "sequential_7"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_130 (Conv2D)         (None, 16, 16, 64)        9408      
                                                                 
 max_pooling2d_33 (MaxPoolin  (None, 8, 8, 64)         0         
 g2D)                                                            
                                                                 
 lambda_12 (Lambda)          (None, 8, 8, 64)          0         
                                                                 
 conv2d_131 (Conv2D)         (None, 8, 8, 64)          4096      
                                                                 
 conv2d_132 (Conv2D)         (None, 8, 8, 192)         110592    
                                                                 
 lambda_13 (Lambda)          (None, 8, 8, 192)         0         
                                                      