Pre-Activation Wide-ResNet Network for digit classification.

References:
    [1] Basic ResNet: https://arxiv.org/pdf/1512.03385.pdf
    [2] Pre-Act. ResNet: https://arxiv.org/pdf/1603.05027.pdf
    [3] Wide-ResNet: https://arxiv.org/pdf/1605.07146v1.pdf
    
In this post we build a Residual Network (ResNet) in TensorFlow (1.x API) with Python to classify the handwritten digits of the popular MNIST database. We create a pre-activation ResNet, which moves batch normalization and the activation in front of each convolution and was shown to perform better than the original ResNet [2]. The network is written as a Python class that follows some emerging best practices for structuring TensorFlow models.
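
To make the difference concrete, here is a minimal NumPy sketch of one residual unit in each style. It is a simplification, not the model built below: dense matrix products stand in for the convolutions, batch normalization has no learned scale or shift, plain ReLU is used, and the helper names (`batch_norm`, `original_unit`, `preact_unit`) are hypothetical. The original unit applies weight, BN, ReLU, weight, BN, adds the shortcut, and then applies a final ReLU; the pre-activation unit [2] applies BN and ReLU before each weight layer and leaves the sum untouched.

```python
import numpy as np


def batch_norm(x, eps=1e-5):
    '''Per-feature standardization over the batch (no learned scale/shift).'''
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)


def relu(x):
    return np.maximum(x, 0.0)


def original_unit(x, w1, w2):
    '''Original ResNet ordering: weight -> BN -> ReLU -> weight -> BN, add, ReLU.'''
    h = relu(batch_norm(x @ w1))
    h = batch_norm(h @ w2)
    return relu(x + h)


def preact_unit(x, w1, w2):
    '''Pre-activation ordering: (BN -> ReLU -> weight) twice, then add.'''
    h = relu(batch_norm(x)) @ w1
    h = relu(batch_norm(h)) @ w2
    return x + h


rng = np.random.default_rng(0)
x = rng.normal(size=(8, 4))   # batch of 8 examples with 4 features
w_zero = np.zeros((4, 4))     # residual branch switched off

# The pre-activation unit passes x through unchanged ...
print(np.allclose(preact_unit(x, w_zero, w_zero), x))                   # True
# ... while the original unit still clips negatives after the addition
print(np.allclose(original_unit(x, w_zero, w_zero), np.maximum(x, 0)))  # True
```

With the residual weights set to zero, only the pre-activation unit returns its input exactly: the shortcut is a clean identity path, which is the property [2] credits for easier training of deep networks.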

First we define some utility functions that carry over well across networks and tasks, so future projects can get off the ground more quickly. Then we show the full model code, with comments explaining its overall structure and most important parts.

Below we create a decorator that wraps specific portions of our TensorFlow graph. It makes sure that the code which creates the nodes and operations for the predictions, the loss function, and the optimizer runs only once: the first access builds the subgraph and caches the result on the object, and every later access returns the cached tensor or op. This is a form of lazy initialization, where an object or data structure is created only when it is first needed. The decorator is small yet handy, and it can be extended to work with TensorFlow's variable scopes and arbitrary argument lists. Combined with helper functions for common layers and operations, it cuts down on boilerplate code and makes the network much more readable.

In [None]:
import functools

def lazy_init(function):
    '''
    Lazy initialization for inference, loss, and optimizer in a graph.
    '''
    attribute = '_cache_' + function.__name__

    @property
    @functools.wraps(function)
    def decorator(self):
        if not hasattr(self, attribute):
            setattr(self, attribute, function(self))
        return getattr(self, attribute)

    return decorator
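
To see the caching behavior in isolation, here is a plain-Python sketch with no TensorFlow involved. `Model`, `prediction`, and `build_calls` are hypothetical names, and the list returned by `prediction` merely stands in for a graph tensor.

```python
import functools


def lazy_init(function):
    '''Same decorator as above: run `function` once, then return the cached result.'''
    attribute = '_cache_' + function.__name__

    @property
    @functools.wraps(function)
    def decorator(self):
        if not hasattr(self, attribute):
            setattr(self, attribute, function(self))
        return getattr(self, attribute)

    return decorator


class Model(object):
    '''Hypothetical stand-in for a graph-building class.'''

    def __init__(self):
        self.build_calls = 0  # counts how often the "subgraph" is built

    @lazy_init
    def prediction(self):
        self.build_calls += 1
        return ['op_%d' % self.build_calls]  # placeholder for a TF tensor


model = Model()
first = model.prediction   # first access: builds and caches
second = model.prediction  # later access: returns the cached object
print(model.build_calls)   # 1
print(first is second)     # True
```

Because the result is stored under `_cache_<method name>`, each decorated method must have a unique name within the class.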

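Next we create a named tuple for our hyperparameters. This is cleaner and less error-prone than passing in a dictionary: a missing or misspelled field raises an error when the tuple is created, and fields cannot be overwritten by accident. Combined with command-line flags (`tf.app.flags.FLAGS` in TensorFlow 1.x), it is a convenient way of handling the many parameters of a deep learning project.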

In [None]:
from collections import namedtuple

# tuple for model hyperparameters
HyperParams = namedtuple('HyperParams',
                         'batch_size, num_cls, num_chans_in, height, width, filter_dims, '
                         'lr, kernel_dims, strides, n_hidden, relu_alpha')
# sample hyperparameters (MNIST images are 28x28 with a single channel)
h = HyperParams(batch_size=128, num_cls=10, num_chans_in=1, height=28, width=28,
                filter_dims=[16, 16, 16, 32, 64], lr=1e-5, kernel_dims=[3, 3],
                strides=[1, 1, 1, 2, 2], n_hidden=128, relu_alpha=0.01)
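
A small sketch of what the named tuple buys us over a dict. It repeats the field list so it runs on its own, assumes MNIST's 28x28 single-channel images, and `h_fast` is a hypothetical name for a second experiment's settings; the combination with command-line flags is not shown.

```python
from collections import namedtuple

HyperParams = namedtuple('HyperParams',
                         'batch_size, num_cls, num_chans_in, height, width, filter_dims, '
                         'lr, kernel_dims, strides, n_hidden, relu_alpha')

h = HyperParams(batch_size=128, num_cls=10, num_chans_in=1, height=28, width=28,
                filter_dims=[16, 16, 16, 32, 64], lr=1e-5, kernel_dims=[3, 3],
                strides=[1, 1, 1, 2, 2], n_hidden=128, relu_alpha=0.01)

# A misspelled field is rejected when the tuple is built ...
try:
    HyperParams(**dict(h._asdict(), learning_rate=1e-3))
except TypeError as err:
    print('rejected:', err)

# ... and fields cannot be reassigned by accident
try:
    h.lr = 1e-3
except AttributeError as err:
    print('read-only:', err)

# _replace returns a modified copy, e.g. for a new experiment
h_fast = h._replace(lr=1e-3)
print(h.lr, h_fast.lr)  # 1e-05 0.001
```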

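We also define two constants for batch normalization: the decay (momentum) of the moving averages of the batch mean and variance, and a small epsilon added to the variance to avoid division by zero. These values are commonly used and have proven reliable in practice.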

In [None]:
# parameters for batch_norm
_BATCH_NORM_DECAY = 0.997
_BATCH_NORM_EPSILON = 1e-5
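
To show what the two constants control, here is a NumPy sketch of the bookkeeping batch normalization does; it is an illustration, not TensorFlow's implementation. `update_moving_stats` and `normalize` are hypothetical helpers, and the constant batch is made-up data. The update rule `moving = decay * moving + (1 - decay) * batch_stat` is how `tf.layers.batch_normalization` uses its `momentum` argument, with the running mean and variance starting at 0 and 1.

```python
import numpy as np

_BATCH_NORM_DECAY = 0.997
_BATCH_NORM_EPSILON = 1e-5


def update_moving_stats(moving_mean, moving_var, batch, decay=_BATCH_NORM_DECAY):
    '''Exponential moving average of per-channel batch statistics (training).'''
    moving_mean = decay * moving_mean + (1.0 - decay) * batch.mean(axis=0)
    moving_var = decay * moving_var + (1.0 - decay) * batch.var(axis=0)
    return moving_mean, moving_var


def normalize(x, mean, var, eps=_BATCH_NORM_EPSILON):
    '''Standardize x per channel; eps avoids dividing by zero when var is ~0.'''
    return (x - mean) / np.sqrt(var + eps)


# Running statistics start at mean 0, variance 1
moving_mean, moving_var = np.zeros(3), np.ones(3)
batch = np.full((128, 3), 2.0)  # every batch has mean 2 and variance 0
for _ in range(1000):
    moving_mean, moving_var = update_moving_stats(moving_mean, moving_var, batch)

print(moving_mean)  # approaches 2.0 as 1 - 0.997**steps grows toward 1
# A constant batch has zero variance; eps keeps the result finite
print(normalize(batch, batch.mean(axis=0), batch.var(axis=0))[0])  # [0. 0. 0.]
```

A decay of 0.997 averages over roughly the last 1 / (1 - 0.997), about 333, batches, so the running statistics used at test time change slowly but smoothly.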

Now we are ready to build our pre-activation ResNet in TensorFlow:

In [5]:
import tensorflow as tf

class ResNet(object):
    def __init__(self, inputs, labels, is_training, hps):
        '''
        Builds a pre-activation ResNet: a first conv layer, four residual
        blocks, global average pooling, and two dense layers.
        Uses a variance-scaling initializer for conv and dense weights.
        Leaky-relu as activation function.
        '''
        # dimensions and sizes for input and label shapes
        self.batch_size = hps.batch_size
        self.num_cls = hps.num_cls

        # store placeholders for inputs, labels, and is_training flag
        self.inputs = inputs
        self.labels = labels
        self.is_training = is_training

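        # check data format for convolutions and batch norm:
        # channels_first (NCHW) is usually faster on GPUs with cuDNN,
        # channels_last (NHWC) on CPUs. Note that this checks whether
        # TensorFlow was built with CUDA, not whether a GPU is available.
        self.data_format = 'channels_first' if tf.test.is_built_with_cuda() \
                                            else 'channels_last'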

        # network hyperparameters
        self.n_filts = hps.filter_dims
        self.kernel_dim = hps.kernel_dims
        self.strides = hps.strides
        self.n_hidden = hps.n_hidden
        self.relu_alpha = hps.relu_alpha
        # learning rate
        self.lr = hps.lr
        # initializer for weight layers
        self.weight_init = tf.variance_scaling_initializer()

        # build the subgraphs now; lazy_init runs each method once and caches it
        self.logits
        self.loss
        self.train_op

    @lazy_init
    def logits(self):
        x = self._build_first_layer()
        x = self._build_resnet(x)
        return x

    def _build_first_layer(self):
        '''
        Return the first convolutional layer common to all ResNet models.

        Note: Unlike the 7x7 conv + max-pooling stem of the original ResNet,
            we use a single 3x3 conv with stride 1, which performs better
            on these small images.
        '''
        with tf.variable_scope('first_layer'):
            inputs = self.inputs
            if self.data_format == 'channels_first':
                # NHWC -> NCHW: move channels to axis 1, keep height before width
                inputs = tf.transpose(inputs, [0, 3, 1, 2])
            x = self._conv_layer(inputs, 16, 3, 1)
            return x

    def _build_resnet(self, x):
        '''Builds the residual blocks.'''

        with tf.variable_scope('first_enc_block'):
            x = self._resnet_block(x, 16, self.n_filts[0], self.strides[0])

        with tf.variable_scope('second_enc_block'):
            x = self._resnet_block(x, self.n_filts[0], self.n_filts[1],
                self.strides[1])

        with tf.variable_scope('third_enc_block'):
            x = self._resnet_block(x, self.n_filts[1], self.n_filts[2],
                self.strides[2])

        with tf.variable_scope('fourth_enc_block'):
            x = self._resnet_block(x, self.n_filts[2], self.n_filts[3],
                self.strides[3])

        with tf.variable_scope('GAP_block'):
            gap_dims = [2, 3] if self.data_format == 'channels_first' \
                                else [1, 2]
            x = tf.reduce_mean(x, gap_dims)

        with tf.variable_scope('dense'):
            x = tf.reshape(x, [self.batch_size, -1])
            x = tf.layers.dense(x, self.n_hidden, activation=self._leaky_relu,
                    kernel_initializer=self.weight_init)

        with tf.variable_scope('logits'):
            x = tf.layers.dense(x, self.num_cls,
                kernel_initializer=self.weight_init)

        return x

    @lazy_init
    def loss(self):
        with tf.variable_scope('loss'):
            x_ent = tf.nn.softmax_cross_entropy_with_logits(
                labels=self.labels,
                logits=self.logits)
            loss = tf.reduce_mean(x_ent)
            tf.summary.scalar('loss', loss)
        return loss

    @lazy_init
    def train_op(self):
        with tf.variable_scope('optimizer'):
            # learning rate schedule
            global_step = tf.Variable(0, trainable=False)
            learning_rate = tf.train.exponential_decay(
                learning_rate=self.lr, global_step=global_step,
                decay_steps=10000, decay_rate=0.97, staircase=True)
            tf.summary.scalar('learn_rate', learning_rate)
            # Nesterov momentum optimizer
            optim = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                momentum=.9, use_nesterov=True)
            # run batch_norm's moving mean/variance updates before each training step
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optim.minimize(self.loss, global_step=global_step)
        self.summaries = tf.summary.merge_all()
        return train_op

    def _resnet_block(self, inputs, in_filts, out_filts, strides):
        '''
        Pre-activation ResNet block with two convolutions.
        Projects the original inputs through a 1x1 conv when the block changes
        the number of filters or downsamples, so the skip-connection matches.
        '''
        # store original x for skip-connection
        orig_in = inputs

        # project input if the filter count or the spatial size changes
        if in_filts != out_filts or strides != 1:
            # 1x1 conv with the block's stride and output number of filters
            orig_in = self._conv_layer(orig_in, out_filts, 1, strides)

        # first pre-activation unit; only this conv downsamples
        inputs = self._batch_norm_relu(inputs)
        inputs = self._conv_layer(inputs, out_filts, self.kernel_dim, strides)

        # second pre-activation unit, stride 1
        inputs = self._batch_norm_relu(inputs)
        inputs = self._conv_layer(inputs, out_filts, self.kernel_dim, 1)

        # add back original (possibly projected) inputs
        inputs += orig_in

        return inputs

    def _batch_norm_relu(self, inputs):
        '''
        Passes inputs through batch normalization and leaky-relu.
        '''
        inputs = tf.layers.batch_normalization(inputs,
            axis=1 if self.data_format == 'channels_first' else 3,
            momentum=_BATCH_NORM_DECAY, epsilon=_BATCH_NORM_EPSILON,
            center=True, scale=True, training=self.is_training, fused=True)
        inputs = self._leaky_relu(inputs)
        return inputs

    def _conv_layer(self, inputs, n_filts, kernel, strides, padding='same'):
        '''2D convolution using the model's data format and weight initializer.'''
        return tf.layers.conv2d(inputs, n_filts, kernel, strides, padding=padding,
            data_format=self.data_format, kernel_initializer=self.weight_init)

    def _leaky_relu(self, inputs):
        '''Leaky-relu: max(alpha * x, x), valid for 0 < alpha < 1.'''
        return tf.maximum(self.relu_alpha * inputs, inputs)
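
The `data_format` switch decides how the input is transposed and which axes batch normalization and global average pooling use. The NumPy sketch below checks that bookkeeping outside the graph. NumPy stands in for TensorFlow, random data stands in for MNIST, and `same_conv_size` is a hypothetical helper. It verifies the NHWC to NCHW permutation, confirms that both layouts give the same pooled features, and traces the spatial size for a 28x28 input through the first layer and four blocks with the sample strides 1, 1, 1, 2.

```python
import math

import numpy as np


def same_conv_size(size, stride):
    '''Spatial output size of a convolution with padding='same'.'''
    return int(math.ceil(size / float(stride)))


# Hypothetical batch of four MNIST-sized images in NHWC ('channels_last') layout
nhwc = np.random.default_rng(0).normal(size=(4, 28, 28, 1))

# NHWC -> NCHW ('channels_first'): move channels to axis 1, keep height before width
nchw = np.transpose(nhwc, [0, 3, 1, 2])
print(nchw.shape)  # (4, 1, 28, 28)

# For a non-square image, the permutation [0, 3, 2, 1] yields NCWH instead
print(np.transpose(np.zeros((4, 20, 30, 1)), [0, 3, 2, 1]).shape)  # (4, 1, 30, 20)

# Global average pooling reduces the spatial axes of whichever layout is used
gap_last = nhwc.mean(axis=(1, 2))
gap_first = nchw.mean(axis=(2, 3))
print(gap_last.shape, np.allclose(gap_last, gap_first))  # (4, 1) True

# Spatial size after the 3x3 stride-1 first layer and four blocks
size = same_conv_size(28, 1)
for stride in [1, 1, 1, 2]:
    size = same_conv_size(size, stride)
print(size)  # 14
```

With `padding='same'`, a convolution's output size is `ceil(input / stride)`, so with these strides only the last block halves the resolution before pooling.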
