In [1]:
import tensorflow as tf

In [2]:
from tensorflow.keras import layers

In [3]:
from tensorflow.keras import datasets
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
x_train = tf.pad(x_train, [[0, 0], [2, 2], [2, 2]])/255
x_test = tf.pad(x_test, [[0, 0], [2, 2], [2, 2]])/255
x_train = tf.expand_dims(x_train, axis=3, name=None)
x_test = tf.expand_dims(x_test, axis=3, name=None)
x_train = tf.repeat(x_train, 3, axis=3)
x_test = tf.repeat(x_test, 3, axis=3)
x_val = x_train[-2000:, :, :, :]
y_val = y_train[-2000:]
x_train = x_train[:-2000, :, :, :]
y_train = y_train[:-2000]

In [4]:
class Residual(layers.Layer):
    def __init__(self, num_channels, strides=1):
        super(Residual, self).__init__()
        self.conv1 = layers.Conv2D(num_channels, kernel_size=3, padding='same', strides=strides)
        self.conv2 = layers.Conv2D(num_channels, kernel_size=3, padding='same')
        self.bn1 = layers.BatchNormalization()
        self.bn2 = layers.BatchNormalization()

    def call(self, inp):
        x = self.conv1(inp)
        x = self.bn1(x)
        x = tf.nn.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = tf.nn.relu(x)
        if x.shape == inp.shape:
            x += inp
        return x

In [12]:
parameters = [
    tf.keras.layers.Input(shape=(32, 32, 3)),
    tf.keras.layers.experimental.preprocessing.Resizing(96, 96, interpolation="bilinear", input_shape=x_train.shape[1:]),
    tf.keras.layers.Conv2D(64, 7, strides=2, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(3, strides=2),
    Residual(64),
    Residual(64),
    Residual(128),
    Residual(128),
    Residual(256),
    Residual(256),
    Residual(512),
    Residual(512),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10, activation='softmax')
]

In [13]:
model = tf.keras.Sequential(parameters)

In [14]:
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

In [15]:
model.summary()

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 resizing_1 (Resizing)       (None, 96, 96, 3)         0         
                                                                 
 conv2d_17 (Conv2D)          (None, 48, 48, 64)        9472      
                                                                 
 max_pooling2d_1 (MaxPooling  (None, 23, 23, 64)       0         
 2D)                                                             
                                                                 
 residual_8 (Residual)       (None, 23, 23, 64)        74368     
                                                                 
 residual_9 (Residual)       (None, 23, 23, 64)        74368     
                                                                 
 residual_10 (Residual)      (None, 23, 23, 128)       222464    
                                                      

In [16]:
model.fit(x_train, y_train, validation_data=(x_val, y_val), batch_size=8, epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
  50/7250 [..............................] - ETA: 9:25 - loss: 0.0667 - accuracy: 0.9775


KeyboardInterrupt

