In [1]:
import os
import tensorflow as tf
import numpy as np
from tensorflow import keras

In [2]:
tf.random.set_seed(22)
np.random.seed(22)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
assert tf.__version__.startswith("2.")

In [3]:
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train, x_test = x_train.astype(np.float32) / 255., x_test.astype(np.float32) / 255.
x_train, x_test = np.expand_dims(x_train, axis = 3), np.expand_dims(x_test, axis = 3)

In [4]:
db_train = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(256)
db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(256)

In [5]:
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)

(60000, 28, 28, 1) (60000,)
(10000, 28, 28, 1) (10000,)


In [6]:
class ConvBNReLU(keras.Model):
    def __init__(self, ch, kernelsz = 3, strides = 1, padding = "same"):
        super(ConvBNReLU, self).__init__()
        
        self.model = keras.models.Sequential([
            keras.layers.Conv2D(ch, kernelsz, strides = strides, padding = padding),
            keras.layers.BatchNormalization(),
            keras.layers.ReLU()
        ])
    
    def call(self, x, training = None):
        x = self.model(x, training = training)
        return x

In [7]:
class InceptionBlk(keras.Model):
    def __init__(self, ch, strides = 1):
        super(InceptionBlk, self).__init__()
        self.ch = ch
        self.strides = strides
        
        self.conv1 = ConvBNReLU(ch, strides = strides)
        self.conv2 = ConvBNReLU(ch, kernelsz = 3, strides = strides)
        self.conv3_1 = ConvBNReLU(ch, kernelsz = 3, strides = strides)
        self.conv3_2 = ConvBNReLU(ch, kernelsz = 3, strides = 1)
        
        self.pool = keras.layers.MaxPooling2D(3, strides = 1, padding = "same")
        self.pool_conv = ConvBNReLU(ch, strides = strides)
        
    def call(self, x, training = None):
        x1 = self.conv1(x, training = training)
        x2 = self.conv2(x, training = training)
        x3_1 = self.conv3_1(x, training = training)
        x3_2 = self.conv3_2(x3_1, training = training)
        x4 = self.pool(x)
        x4 = self.pool_conv(x4, training = training)
        x = tf.concat([x1, x2, x3_2, x4], axis = 3)
        return x

In [8]:
class Inception(keras.Model):
    def __init__(self, num_layers, num_classes, init_ch = 16, **kwargs):
        super(Inception, self).__init__(**kwargs)
        self.in_channels = init_ch
        self.out_channels = init_ch
        self.num_layers = num_layers
        self.init_ch = init_ch
        
        self.conv1 = ConvBNReLU(init_ch)
        self.blocks = keras.models.Sequential(name = "DYNAMIC-BLOCKS")
        for block_id in range(num_layers):
            for layer_id in range(2):
                if layer_id == 0:
                    block = InceptionBlk(self.out_channels, strides = 2)
                else:
                    block = InceptionBlk(self.out_channels, strides = 1)
                    
                self.blocks.add(block)
            self.out_channels *= 2
        self.avg_pool = keras.layers.GlobalAveragePooling2D()
        self.fc = keras.layers.Dense(num_classes)
    
    def call(self, x, training = None):
        out = self.conv1(x, training = training)
        out = self.blocks(out, training = training)
        out = self.avg_pool(out)
        out = self.fc(out)
        return out

In [9]:
BATCH_SIZE = 128
EPOCHS = 10
model = Inception(2, 10)
model.build(input_shape = (None, 28, 28, 1))
model.summary()

Model: "inception"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv_bn_re_lu (ConvBNReLU)   multiple                  224       
_________________________________________________________________
DYNAMIC-BLOCKS (Sequential)  multiple                  292704    
_________________________________________________________________
global_average_pooling2d (Gl multiple                  0         
_________________________________________________________________
dense (Dense)                multiple                  1290      
Total params: 294,218
Trainable params: 293,226
Non-trainable params: 992
_________________________________________________________________


In [10]:
optimizer = keras.optimizers.Adam(learning_rate = 1e-3)
criterion = keras.losses.CategoricalCrossentropy(from_logits = True)
acc_meter = keras.metrics.Accuracy()

In [11]:
for epoch in range(EPOCHS):
    for step, (x, y) in enumerate(db_train):
        with tf.GradientTape() as tape:
            logits = model(x)
            loss = criterion(tf.one_hot(y, depth = 10), logits)
        
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        
        if step % 10 == 0:
            print("Step: ", step, "Loss: ", loss.numpy())
    
    acc_meter.reset_states()
    for x, y in db_test:
        logits = model(x, training = False)
        pred = tf.argmax(logits, axis = 1)
        acc_meter.update_state(y, pred)
    
    print("Epochs: ", epoch, "Evaluation Accuracy: ", acc_meter.result().numpy())

Step:  0 Loss:  2.3067923
Step:  10 Loss:  2.104604
Step:  20 Loss:  1.5719802
Step:  30 Loss:  1.3793477
Step:  40 Loss:  0.9467964
Step:  50 Loss:  0.8574427
Step:  60 Loss:  0.58914363
Step:  70 Loss:  0.35742673
Step:  80 Loss:  0.5373102
Step:  90 Loss:  0.4555395
Step:  100 Loss:  0.38353968
Step:  110 Loss:  0.34024346
Step:  120 Loss:  0.34489682
Step:  130 Loss:  0.20710012
Step:  140 Loss:  0.25235802
Step:  150 Loss:  0.21698198
Step:  160 Loss:  0.2321302
Step:  170 Loss:  0.18356465
Step:  180 Loss:  0.3122157
Step:  190 Loss:  0.13903624
Step:  200 Loss:  0.2680899
Step:  210 Loss:  0.18700425
Step:  220 Loss:  0.1676037
Step:  230 Loss:  0.03278955
Epochs:  0 Evaluation Accuracy:  0.9546
Step:  0 Loss:  0.17778103
Step:  10 Loss:  0.15723905
Step:  20 Loss:  0.21589167
Step:  30 Loss:  0.2675207
Step:  40 Loss:  0.3387717
Step:  50 Loss:  0.3005675
Step:  60 Loss:  0.2221539
Step:  70 Loss:  0.12473531
Step:  80 Loss:  0.14233962
Step:  90 Loss:  0.12571849
Step:  100 Lo