<h2>Imports</h2>

In [2]:
##Imports##
import tensorflow as tf
from tensorflow import keras
from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D, concatenate, Dense, Flatten

<h2>Conv Module</h2>

In [3]:
class ConvModule(keras.layers.Layer):
    """Conv2D -> BatchNormalization -> ReLU building block.

    Args:
        kernel_num: number of output filters of the convolution.
        kernel_size: convolution window size, forwarded to ``Conv2D``.
        strides: convolution strides, forwarded to ``Conv2D``.
        padding: convolution padding mode; ``'same'`` unless overridden.
    """

    def __init__(self, kernel_num, kernel_size, strides, padding='same'):
        super().__init__()
        conv_options = dict(
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
        )
        self.conv = keras.layers.Conv2D(kernel_num, **conv_options)
        self.bn = keras.layers.BatchNormalization()

    def call(self, input_tensor, training=False):
        """Convolve, batch-normalize (mode selected by ``training``), then apply ReLU."""
        normalized = self.bn(self.conv(input_tensor), training=training)
        return tf.nn.relu(normalized)

<h2>Inception Module</h2>

In [4]:
class InceptionModule(keras.layers.Layer):
    """Two parallel ConvModule branches (1x1 and 3x3) concatenated on the last axis.

    Args:
        kernel_size1x1: number of filters in the 1x1 branch (despite the name,
            this is a filter count, not a kernel size).
        kernel_size3x3: number of filters in the 3x3 branch (also a filter count).

    Both branches use stride 1 and ConvModule's default ``'same'`` padding, so
    the spatial size is preserved and the output has
    ``kernel_size1x1 + kernel_size3x3`` channels.
    """

    def __init__(self, kernel_size1x1, kernel_size3x3):
        super().__init__()
        self.conv1 = ConvModule(
            kernel_size1x1, kernel_size=(1, 1), strides=(1, 1)
        )
        self.conv2 = ConvModule(
            kernel_size3x3, kernel_size=(3, 3), strides=(1, 1)
        )
        self.cat = keras.layers.Concatenate()

    def call(self, input_tensor, training=False):
        """Run both branches on ``input_tensor`` and concatenate their outputs."""
        # Forward `training` explicitly (as DownsampleModule does) so each
        # branch's BatchNormalization uses the intended mode even when this
        # method is invoked directly instead of through the layer's __call__.
        x_1x1 = self.conv1(input_tensor, training=training)
        x_3x3 = self.conv2(input_tensor, training=training)
        return self.cat([x_1x1, x_3x3])

<h2>Downsample Module</h2>

In [5]:
class DownsampleModule(keras.layers.Layer):
    """Stride-2 reduction block: a strided conv branch and a max-pool branch, concatenated.

    Args:
        kernel_size: number of filters in the conv branch (despite the name,
            this is a filter count; the conv window is fixed at 3x3).

    Output channels are ``kernel_size`` plus the input's channel count,
    because the pool branch passes the input channels through unchanged.
    """

    def __init__(self, kernel_size):
        super().__init__()
        # Conv branch: 3x3 window, stride 2, no padding.
        self.conv3 = ConvModule(
            kernel_size, kernel_size=(3, 3), strides=(2, 2), padding="valid"
        )
        # Pool branch: 3x3 window, stride 2, default (unpadded) padding, so its
        # spatial output matches the conv branch and the two can be concatenated.
        self.pool = keras.layers.MaxPooling2D(pool_size=(3, 3), strides=(2, 2))
        self.cat = keras.layers.Concatenate()

    def call(self, input_tensor, training=False):
        """Downsample ``input_tensor`` via both branches and join them on the last axis."""
        branches = [
            self.conv3(input_tensor, training=training),
            self.pool(input_tensor),
        ]
        return self.cat(branches)

<h2>Inception (Small)</h2>

In [6]:
class MiniInception(keras.Model):
    """Small Inception-style image classifier.

    Layout: a 3x3 ConvModule (96 filters), 2 Inception modules, a downsample,
    4 Inception modules, a downsample, 2 Inception modules, 7x7 average
    pooling, flatten, and a softmax Dense classifier.

    Args:
        num_classes: number of output units of the softmax classifier.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Stem: the first conv module.
        self.conv_block = ConvModule(96, (3, 3), (1, 1))

        # Stage 1: two inception modules and one downsample module.
        self.inception_block1 = InceptionModule(32, 32)
        self.inception_block2 = InceptionModule(32, 48)
        self.downsample_block1 = DownsampleModule(80)

        # Stage 2: four inception modules and one downsample module.
        self.inception_block3 = InceptionModule(112, 48)
        self.inception_block4 = InceptionModule(96, 64)
        self.inception_block5 = InceptionModule(80, 80)
        self.inception_block6 = InceptionModule(48, 96)
        self.downsample_block2 = DownsampleModule(96)

        # Stage 3: two inception modules.
        self.inception_block7 = InceptionModule(176, 160)
        self.inception_block8 = InceptionModule(176, 160)

        # Fixed 7x7 window: for 32x32 inputs the two unpadded stride-2
        # downsamples shrink the feature map 32 -> 15 -> 7, so this pools it to
        # 1x1. Other input sizes would need a different window (or global pooling).
        self.avg_pool = keras.layers.AveragePooling2D((7, 7))

        # Model tail.
        self.flat = keras.layers.Flatten()
        # NOTE: "classfier" is misspelled but kept as-is; renaming the
        # attribute may break loading of previously saved weights/checkpoints.
        self.classfier = keras.layers.Dense(
            num_classes, activation='softmax'
        )

    def call(self, input_tensor, training=False, **kwargs):
        """Forward pass; ``training`` is forwarded to every sub-block that has BatchNormalization."""
        x = self.conv_block(input_tensor, training=training)

        # Same order as constructed: stage 1, stage 2, stage 3.
        feature_blocks = (
            self.inception_block1,
            self.inception_block2,
            self.downsample_block1,
            self.inception_block3,
            self.inception_block4,
            self.inception_block5,
            self.inception_block6,
            self.downsample_block2,
            self.inception_block7,
            self.inception_block8,
        )
        for block in feature_blocks:
            x = block(x, training=training)

        x = self.avg_pool(x)
        x = self.flat(x)
        return self.classfier(x)

    def build_graph(self, raw_shape):
        """Wrap this model's layers in a functional ``keras.Model`` (e.g. for ``summary()``).

        Args:
            raw_shape: per-example input shape, without the batch dimension.
        """
        x = keras.Input(shape=raw_shape)
        # call() forwards `training` to the sub-blocks. Passing None (not the
        # default False) avoids fixing inference mode into the functional
        # graph; the mode is then resolved when that graph model is run.
        return keras.Model(inputs=[x], outputs=self.call(x, training=None))

In [7]:
# Per-example input shape (height, width, channels); no batch dimension.
raw_input = (32, 32, 3)

# Instantiate the model (its weights are not created yet).
cm = MiniInception()

# One forward call creates the weights. The dummy batch has batch size 0;
# presumably only the per-example shape matters for weight creation.
y = cm(tf.ones(shape=(0,) + raw_input))

# Wrap the model in a functional graph and print its layer summary.
cm.build_graph(raw_input).summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 32, 32, 3)]       0         
                                                                 
 conv_module (ConvModule)    (None, 32, 32, 96)        3072      
                                                                 
 inception_module (Inceptio  (None, 32, 32, 64)        31040     
 nModule)                                                        
                                                                 
 inception_module_1 (Incept  (None, 32, 32, 80)        30096     
 ionModule)                                                      
                                                                 
 downsample_module (Downsam  (None, 15, 15, 160)       58000     
 pleModule)                                                      
                                                             