In [6]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist

In [7]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()

In [8]:
x_train = x_train.reshape(-1,28,28,1).astype("float32") / 255.0
x_test = x_test.reshape(-1,28,28,1).astype("float32") / 255.0

In [13]:
class CNNBlock(layers.Layer):
    def __init__(self, out_channels, kernel_size=3):
        super(CNNBlock, self).__init__()
        self.conv = layers.Conv2D(out_channels, kernel_size, padding='same')
        self.bn = layers.BatchNormalization()
        
    def call(self, input_tensor, training=False):
        x = self.conv(input_tensor)
        x = self.bn(x, training=training)
        x = tf.nn.relu(x)
        return x

In [14]:
model = keras.Sequential(
[
    CNNBlock(32),
    CNNBlock(64),
    CNNBlock(128),
    layers.Flatten(),
    layers.Dense(10),
])

In [17]:
model.compile(
    optimizer = keras.optimizers.Adam(),
    loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

In [18]:
model.fit(x_train,y_train, batch_size=64, epochs=3, verbose=2)
model.evaluate(x_test,y_test, batch_size=64, verbose=2)

Epoch 1/3
938/938 - 335s - loss: 0.5471 - accuracy: 0.9483 - 335s/epoch - 358ms/step
Epoch 2/3
938/938 - 341s - loss: 0.0868 - accuracy: 0.9827 - 341s/epoch - 364ms/step
Epoch 3/3
938/938 - 370s - loss: 0.0328 - accuracy: 0.9904 - 370s/epoch - 395ms/step
157/157 - 8s - loss: 0.0534 - accuracy: 0.9857 - 8s/epoch - 50ms/step


[0.0534060075879097, 0.9857000112533569]

In [19]:
#ResBlock
class ResBlock(layers.Layer):
    def __init__(self, channels):
        
        super(ResBlock, self).__init__()
        self.ccn1 = CNNBlock(channels[0])
        self.ccn2 = CNNBlock(channels[1])
        self.ccn3 = CNNBlock(channels[2])
        self.pooling = layers.MaxPooling2D()
        self.identity_mapping = layers.Conv2D(channels[1], 1, padding='same')
        
    def call(self, input_tensor, training=False):
        x = self.ccn1(input_tensor, training=training)
        x = self.ccn2(x, training=training)
        x = self.ccn3(x + self.identity_mapping(input_tensor), training=training)
        return self.pooling(x)

In [21]:
class ResNet_like(keras.Model):
    def __init__(self, num_classes=10):
        
        super(ResNet_like, self).__init__()
        self.block1 = ResBlock([32,32,64])
        self.block2 = ResBlock([128,128,256])
        self.block3 = ResBlock([128,256,512])
        self.pool = layers.GlobalAveragePooling2D()
        self.classifier = layers.Dense(num_classes)
        
    def call(self, input_tensor, training=False):
        x = self.block1(input_tensor, training=training)
        x = self.block2(x, training=training)
        x = self.block3(x, training=training)
        x = self.pool(x)
        return self.classifier(x)
        

In [24]:
model = ResNet_like(num_classes=10)
model.compile(
    optimizer = keras.optimizers.Adam(),
    loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

In [26]:
model.fit(x_train,y_train, batch_size=64, epochs=20, verbose=2)
model.evaluate(x_test,y_test, batch_size=64, verbose=2)

Epoch 1/20
938/938 - 854s - loss: 0.0519 - accuracy: 0.9842 - 854s/epoch - 910ms/step
Epoch 2/20
938/938 - 878s - loss: 0.0319 - accuracy: 0.9897 - 878s/epoch - 936ms/step
Epoch 3/20
938/938 - 852s - loss: 0.0264 - accuracy: 0.9917 - 852s/epoch - 908ms/step
Epoch 4/20
938/938 - 859s - loss: 0.0234 - accuracy: 0.9925 - 859s/epoch - 915ms/step
Epoch 5/20
938/938 - 805s - loss: 0.0188 - accuracy: 0.9938 - 805s/epoch - 858ms/step
Epoch 6/20
938/938 - 809s - loss: 0.0183 - accuracy: 0.9941 - 809s/epoch - 862ms/step
Epoch 7/20
938/938 - 810s - loss: 0.0157 - accuracy: 0.9949 - 810s/epoch - 864ms/step
Epoch 8/20
938/938 - 798s - loss: 0.0173 - accuracy: 0.9945 - 798s/epoch - 851ms/step
Epoch 9/20
938/938 - 850s - loss: 0.0119 - accuracy: 0.9962 - 850s/epoch - 906ms/step
Epoch 10/20
938/938 - 803s - loss: 0.0119 - accuracy: 0.9961 - 803s/epoch - 856ms/step
Epoch 11/20
938/938 - 789s - loss: 0.0094 - accuracy: 0.9969 - 789s/epoch - 841ms/step
Epoch 12/20
938/938 - 813s - loss: 0.0094 - accuracy

[0.07835865020751953, 0.9824000000953674]