In [2]:
import tensorflow as tf

In [16]:
class inception(tf.keras.layers.Layer):
    def __init__(self, filter1, filter3r, filter3, filter5r, filter5, maxPool):
        super().__init__()
        self.c1 = tf.keras.layers.Conv2D(filter1, (1, 1), padding='same', activation='relu')
        self.c2 = tf.keras.Sequential([
            tf.keras.layers.Conv2D(filter3r, (1, 1), padding='same', activation='relu'),
            tf.keras.layers.Conv2D(filter3, (1, 1), padding='same', activation='relu')
            ])
        self.c3 = tf.keras.Sequential([
            tf.keras.layers.Conv2D(filter5r, (1, 1), padding='same', activation='relu'),
            tf.keras.layers.Conv2D(filter5, (1, 1), padding='same', activation='relu')
            ])
        self.c4 = tf.keras.Sequential([
            tf.keras.layers.MaxPool2D((3, 3), strides=1, padding='same'),
            tf.keras.layers.Conv2D(maxPool, (1, 1), padding='same', activation='relu')
            ])
    
    def call(self, inp):
        out1 = self.c1(inp)
        out2 = self.c2(inp)
        out3 = self.c3(inp)
        out4 = self.c4(inp)
        return tf.concat([out1, out2, out3, out4], axis=3)

In [19]:
from tensorflow.keras import datasets
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
x_train = tf.pad(x_train, [[0, 0], [2, 2], [2, 2]])/255
x_test = tf.pad(x_test, [[0, 0], [2, 2], [2, 2]])/255
x_train = tf.expand_dims(x_train, axis=3, name=None)
x_test = tf.expand_dims(x_test, axis=3, name=None)
x_train = tf.repeat(x_train, 3, axis=3)
x_test = tf.repeat(x_test, 3, axis=3)
x_val = x_train[-2000:, :, :, :]
y_val = y_train[-2000:]
x_train = x_train[:-2000, :, :, :]
y_train = y_train[:-2000]

In [21]:
x_train.shape

TensorShape([58000, 32, 32, 3])

In [38]:
parameters = [
    tf.keras.layers.Input(shape=(32, 32, 3)),
    tf.keras.layers.experimental.preprocessing.Resizing(224, 224, interpolation="bilinear", input_shape=x_train.shape[1:]),
    tf.keras.layers.Conv2D(64, 7, strides=2, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(3, strides=2),
    tf.keras.layers.Conv2D(64, 1, strides=1, padding='same', activation='relu'),
    tf.keras.layers.Conv2D(192, 3, strides=1, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(3, strides=2),
    inception(64, 96, 128, 16, 32, 32),
    inception(128, 128, 192, 32, 96, 64),
    tf.keras.layers.MaxPooling2D(3, strides= (2, 2)),
    inception(192, 96, 208, 16, 48, 64),
    inception(160, 112, 224, 24, 64, 64),
    inception(128, 128, 256, 24, 64, 64),
    inception(112, 144, 288, 32, 64, 64),
    inception(256, 160, 320, 32, 128, 128),
    tf.keras.layers.MaxPooling2D(3, strides=2),
    inception(256, 160, 320, 32, 128, 128),
    inception(384, 192, 384, 48, 128, 128),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10, activation='softmax')
]

In [39]:
model = tf.keras.Sequential(parameters)

In [40]:
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

In [41]:
model.summary()

Model: "sequential_250"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 resizing (Resizing)         (None, 224, 224, 3)       0         
                                                                 
 conv2d_460 (Conv2D)         (None, 112, 112, 64)      9472      
                                                                 
 max_pooling2d_106 (MaxPooli  (None, 55, 55, 64)       0         
 ng2D)                                                           
                                                                 
 conv2d_461 (Conv2D)         (None, 55, 55, 64)        4160      
                                                                 
 conv2d_462 (Conv2D)         (None, 55, 55, 192)       110784    
                                                                 
 max_pooling2d_107 (MaxPooli  (None, 27, 27, 192)      0         
 ng2D)                                              

In [43]:
model.fit(x_train, y_train, validation_data=(x_val, y_val), batch_size=64, epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x18871a174f0>