In [91]:
import tensorflow as tf
import numpy as np
from datetime import datetime
from resnet import SimpleResNetModule

tf.executing_eagerly = True

from matplotlib import pyplot as plt
%matplotlib inline

import scipy.stats as stats

# Overview

This notebook illustrates the implementation of CNNs with skip connections (i.e. a ResNet) using the `keras.Model` API. Skip connections are useful for avoiding vanishing gradients.

I also compare against a simple sequential CNN using image classification on `fashion_mnist`. Test accuracy of the Resnet is 87%, vs 10% for the sequential model.

In [123]:
class ResNetModel(tf.keras.Model):
    def __init__(self,
                 hparams = [
                    {"filters": 4, "kernel_size": 3, "activation": "relu"},
                    {"filters": 4, "kernel_size": 3, "activation": "relu"},
                 ],
                 name = None):
        super().__init__(name = name)
        self.hparams = hparams
        self.dense = tf.keras.layers.Dense(10, activation = "softmax")
        self.flatten = tf.keras.layers.Flatten()
        self.blocks = [SimpleResNetModule(hparams = block_hparams) for block_hparams in self.hparams]
    
    def call(self, x):
        for block in self.blocks:
            x = block(x)
        return self.dense(self.flatten(x))
    
class SimpleCNNModel(tf.keras.Model):
    def __init__(self,
                 name = None):
        super().__init__(name = name)
        self.convs = [tf.keras.layers.Conv2D(
            4,
            kernel_size = 3,
            activation = "relu",
            padding = "same",
            strides = 1 + l % 2) for l in range(4)]
        self.dense = tf.keras.layers.Dense(10, activation = "softmax")
        self.flatten = tf.keras.layers.Flatten()
    
    def call(self, x):
        for layer in self.convs:
            x = layer(x)
        return self.dense(self.flatten(x))

In [124]:
model = ResNetModel()
model.compile(
    # By default, fit() uses tf.function().  You can
    # turn that off for debugging, but it is on now.
    run_eagerly=False,

    # Using a built-in optimizer, configuring as an object
    optimizer=tf.keras.optimizers.Adagrad(learning_rate=1e-1),

    # Keras comes with built-in MSE error
    # However, you could use the loss function
    # defined above
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy']
)

simpleCNNModel = SimpleDenseModel()
simpleCNNModel.compile(
    # By default, fit() uses tf.function().  You can
    # turn that off for debugging, but it is on now.
    run_eagerly=False,

    # Using a built-in optimizer, configuring as an object
    optimizer=tf.keras.optimizers.Adagrad(learning_rate=1e-1),

    # Keras comes with built-in MSE error
    # However, you could use the loss function
    # defined above
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy']
)

In [125]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

x_train, x_test = x_train / 255.0, x_test / 255.0
x_train = x_train[..., tf.newaxis].astype("float32")
x_test = x_test[..., tf.newaxis].astype("float32")

print("Training Resnet...")
model.fit(
    x = x_train,
    y = y_train,
    epochs = 3,
    batch_size = 32,
) 

print("Training Simple Dense...")
simpleCNNModel.fit(
    x = x_train,
    y = y_train,
    epochs = 3,
    batch_size = 32,
) 

Training Resnet...
Epoch 1/3
Epoch 2/3
Epoch 3/3
Training Simple Dense...
Epoch 1/3
Epoch 2/3
Epoch 3/3


<keras.callbacks.History at 0xb56b9cac8>

In [127]:
print("[Resnet] Test accuracy: %.2f%%"%(tf.keras.metrics.Accuracy()(np.argmax(model(x_test), axis = 1), y_test).numpy()*100))
print("[SimpleCNN] Test accuracy: %.2f%%"%(tf.keras.metrics.Accuracy()(np.argmax(simpleCNNModel(x_test), axis = 1), y_test).numpy()*100))


[Resnet] Test accuracy: 86.81%
[SimpleCNN] Test accuracy: 10.00%


In [129]:
## Comparing model summary (similar # of params, very different accuracies.)

model.summary(), simpleCNNModel.summary()

Model: "res_net_model_31"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_3241 (Dense)           multiple                  1970      
_________________________________________________________________
flatten_1628 (Flatten)       multiple                  0         
_________________________________________________________________
simple_res_net_module_20 (Si multiple                  224       
_________________________________________________________________
simple_res_net_module_21 (Si multiple                  440       
Total params: 2,634
Trainable params: 2,634
Non-trainable params: 0
_________________________________________________________________
Model: "simple_dense_model_6"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_26 (Conv2D)           multiple                  40        
__________________

(None, None)