## Imports

In [1]:
import tensorflow as tf
from tensorflow.keras.datasets import mnist
(X_orig, Y_train), (X_test, Y_test) = mnist.load_data()
X_orig = X_orig / 255.0
X_test = X_test / 255.0

## Model

In [6]:
inputs = tf.keras.Input(shape=(28, 28, 1))
x = tf.keras.layers.RandomContrast(0.4)(inputs)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.RandomRotation(0.02)(x)

x_skip = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Conv2D(filters=16, kernel_size=3, padding='same')(x_skip)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Conv2D(filters=16, kernel_size=3, padding='same')(x)
x = tf.keras.layers.BatchNormalization()(x)

x = tf.keras.layers.Add()([x, x_skip])
x = tf.keras.layers.MaxPooling2D((2,2))(x)
x = tf.keras.layers.BatchNormalization()(x)

x = tf.keras.layers.BatchNormalization()(x)
x_skip = tf.keras.layers.Conv2D(filters=32, kernel_size=1, padding='same')(x)
x = tf.keras.layers.Conv2D(filters=32, kernel_size=3, padding='same')(x_skip)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Conv2D(filters=32, kernel_size=3, padding='same')(x)
x = tf.keras.layers.BatchNormalization()(x)

x = tf.keras.layers.Add()([x, x_skip])
x = tf.keras.layers.MaxPooling2D((2,2))(x)
x = tf.keras.layers.BatchNormalization()(x)

x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
outputs = tf.keras.layers.Dense(10)(x)

model = tf.keras.Model(inputs=inputs, outputs=outputs)

In [7]:
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])
history = model.fit(X_orig, Y_train, validation_split=0.1, epochs=4, batch_size=2048, verbose=2)
model.evaluate(X_test, Y_test, verbose=2)

Epoch 1/4
27/27 - 2s - loss: 0.7121 - accuracy: 0.7839 - val_loss: 0.7733 - val_accuracy: 0.8302 - 2s/epoch - 73ms/step
Epoch 2/4
27/27 - 1s - loss: 0.1745 - accuracy: 0.9484 - val_loss: 0.5998 - val_accuracy: 0.9205 - 971ms/epoch - 36ms/step
Epoch 3/4
27/27 - 1s - loss: 0.1120 - accuracy: 0.9672 - val_loss: 0.5166 - val_accuracy: 0.9232 - 965ms/epoch - 36ms/step
Epoch 4/4
27/27 - 1s - loss: 0.0831 - accuracy: 0.9759 - val_loss: 0.4862 - val_accuracy: 0.8992 - 959ms/epoch - 36ms/step
313/313 - 1s - loss: 0.5251 - accuracy: 0.8795 - 818ms/epoch - 3ms/step


[0.5250759124755859, 0.8794999718666077]