Pre-Activation Wide-ResNet Network for Spectrum Mining and Modulation Detection.

References:
    Basic ResNet: https://arxiv.org/pdf/1512.03385.pdf
    Pre-Act. ResNet: https://arxiv.org/pdf/1603.05027.pdf
    Wide-ResNet: https://arxiv.org/pdf/1605.07146v1.pdf

Below we create a decorator that wraps the subgraphs for our inference, loss, and optimizer.
This makes sure we only call the graph creation code once. Small, but handy. The decorator can be extended to work with `tf.variable_scope` and arbitrary arguments. Combined with `partial`, this can really cut down on boilerplate code and make the network much more readable/understandable.

In [None]:
import functools

def lazy_init(function):
    '''
    Turn a zero-argument method into a read-only property that is computed
    once per instance.

    The wrapped method runs on the first attribute access; its result is
    stored on the instance under ``'_cache_' + <method name>`` and every
    later access returns that stored value. Used here so that the inference,
    loss, and optimizer subgraphs are each built a single time.
    '''
    cache_name = '_cache_' + function.__name__

    def getter(self):
        already_built = hasattr(self, cache_name)
        if not already_built:
            # first access: run the wrapped method and remember its result
            result = function(self)
            setattr(self, cache_name, result)
        return getattr(self, cache_name)

    # keep the wrapped method's name/docstring, then expose it as a property
    return property(functools.wraps(function)(getter))

Next we create a named tuple for our hyperparameters. This is cleaner and less error-prone than passing in a dictionary. Combined with `FLAGS`, it is the ideal way of handling the multitude of parameters in a Deep Learning project.

We also define some constants for Batch Normalization that have proven to be reliable.

In [None]:
from collections import namedtuple

# Named tuple holding the model hyperparameters.
ResNetParams = namedtuple(
    'ResNetParams',
    'batch_size, num_cls, num_chans_in, height, width, lr, filter_dims,'
    'kernel_dims, strides, n_hidden, relu_alpha')

# Sample hyperparameters. Built from ResNetParams (the tuple defined above)
# with keyword arguments so the eleven values cannot be silently mis-ordered.
h = ResNetParams(
    batch_size=128,
    num_cls=10,
    num_chans_in=1,
    height=20,
    width=512,
    lr=1e-5,
    filter_dims=[16, 16, 32, 64],
    kernel_dims=[3, 3],
    strides=[1, 1, 2, 2],
    n_hidden=128,
    relu_alpha=0.01,
)

# parameters for batch_norm
_BATCH_NORM_DECAY = 0.997
_BATCH_NORM_EPSILON = 1e-5

Now we are ready to build our Pre-Activation ResNet in TensorFlow:

In [6]:
import tensorflow as tf

class ResNet(object):
    '''
    Pre-activation ResNet classifier expressed as a TensorFlow graph.

    The graph is one 3x3 convolution, four residual blocks, global average
    pooling over the two spatial axes, a hidden dense layer and a logits
    layer. Constructing an instance builds the `inference`, `loss` and
    `train_op` subgraphs; each is a lazily-built, cached property.
    '''

    def __init__(self, inputs, labels, is_training, hps):
        '''
        Builds the pre-activation ResNet graph.
        Using Xavier initializer for conv and dense weights.
        Leaky-relu as activation function.

        Args:
            inputs: input tensor; treated as channels-last 4-D
                (batch norm on axis 3, pooling over axes 1 and 2).
            labels: label tensor passed to
                `tf.nn.softmax_cross_entropy_with_logits`.
            is_training: flag forwarded to batch normalization as `training`.
            hps: hyperparameter object exposing `batch_size`, `num_cls`,
                `filter_dims`, `kernel_dims`, `strides`, `n_hidden`,
                `relu_alpha` and `lr`.
        '''
        # parameters and inputs
        # (the input-geometry fields of hps -- channels, height, width -- are
        # not needed to build the graph, so they are not read here)
        self.batch_size = hps.batch_size
        self.num_cls = hps.num_cls

        # placeholder for inputs, labels, and is_training flag
        self.inputs = inputs
        self.labels = labels
        self.is_training = is_training

        # network parameters
        self.n_filts = hps.filter_dims
        # self.n_filts = [16, 16, 80, 160, 320] # params for wide-resnet(5,4)
        self.kernel_dim = hps.kernel_dims
        self.strides = hps.strides
        self.n_hidden = hps.n_hidden
        self.relu_alpha = hps.relu_alpha
        # learning rate
        self.lr = hps.lr
        # initializers for weight layers
        self.conv_init = tf.contrib.layers.xavier_initializer_conv2d()
        self.dense_init = tf.contrib.layers.xavier_initializer()

        # subgraphs: touching each lazy property builds (and caches) it now
        self.inference
        self.loss
        self.train_op

    @lazy_init
    def inference(self):
        '''Logits tensor: first conv layer followed by the residual stack.'''
        x = self._build_first_layer()
        x = self._build_resnet(x)
        return x

    def _build_first_layer(self):
        '''
        Return the first convolutional layer common to all ResNet models.

        Applies a 16-filter 3x3 convolution with stride 1 to `self.inputs`.

        Note: Slightly different from layer + pooling in original resnet, but
            this performs better.
        '''
        with tf.variable_scope('first_layer'):
            x = self._conv_layer(self.inputs, 16, 3, 1)
            return x

    def _build_resnet(self, x):
        '''
        Builds the residual blocks, pooling, and dense head.

        Args:
            x: output of the first convolutional layer.

        Returns:
            The logits tensor produced by the final dense layer
            (`self.num_cls` units, no activation).
        '''

        with tf.variable_scope('first_block'):
            x = self._resnet_block(x, self.n_filts[0], self.strides[0],
                                   self.is_training)

        with tf.variable_scope('second_block'):
            x = self._resnet_block(x, self.n_filts[1], self.strides[1],
                                   self.is_training)

        with tf.variable_scope('third_block'):
            x = self._resnet_block(x, self.n_filts[2], self.strides[2],
                                   self.is_training)

        with tf.variable_scope('fourth_block'):
            x = self._resnet_block(x, self.n_filts[3], self.strides[3],
                                   self.is_training)

        with tf.variable_scope('GAP_block'):
            # global average pooling over the two spatial axes; the result is
            # already 2-D: [batch, channels]
            x = tf.reduce_mean(x, [1, 2])

        with tf.variable_scope('dense'):
            # No reshape to [self.batch_size, -1] here: the pooled tensor is
            # already rank 2, and forcing the configured batch size would
            # fail -- or silently mix samples -- for any other batch size.
            x = tf.layers.dense(x, self.n_hidden, activation=self._leaky_relu,
                                kernel_initializer=self.dense_init)

        with tf.variable_scope('logits'):
            x = tf.layers.dense(x, self.num_cls,
                                kernel_initializer=self.dense_init)

        return x

    @lazy_init
    def loss(self):
        '''
        Mean softmax cross-entropy between `self.labels` and the logits.
        Also registers a 'loss' scalar summary.
        '''
        with tf.variable_scope('loss'):
            x_ent = tf.nn.softmax_cross_entropy_with_logits(
                labels=self.labels,
                logits=self.inference)
            loss = tf.reduce_mean(x_ent)
            tf.summary.scalar('loss', loss)
        return loss

    @lazy_init
    def train_op(self):
        '''
        Training op: Nesterov momentum on `self.loss` with a staircase
        exponentially-decayed learning rate.

        Side effect: sets `self.summaries` to the merge of all summaries
        registered so far.
        '''
        with tf.variable_scope('optimizer'):
            # learning rate schedule
            global_step = tf.Variable(0, trainable=False)
            learning_rate = tf.train.exponential_decay(
                learning_rate=self.lr, global_step=global_step,
                decay_steps=10000, decay_rate=0.97, staircase=True)
            tf.summary.scalar('learn_rate', learning_rate)
            # Nesterov momentum optimizer
            optim = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=.9, use_nesterov=True)
            # run the batch_norm update ops (collected in UPDATE_OPS) before
            # every training step
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optim.minimize(self.loss, global_step=global_step)
        self.summaries = tf.summary.merge_all()
        return train_op

    def _resnet_block(self, inputs, num_filts, strides, is_training):
        '''
        ResNet block with two convolutions.
        Projects original inputs through 1x1 conv if strides downsample.

        Args:
            inputs: input tensor of the block.
            num_filts: number of convolution filters; doubled when
                `strides == 2`.
            strides: stride used by BOTH convolutions of the block.
            is_training: flag forwarded to batch normalization.

        Returns:
            Sum of the two-convolution path and the (possibly projected)
            skip connection.

        NOTE(review): when `strides != 2` the skip connection is added
        unprojected, so `num_filts` presumably has to equal the channel
        count of `inputs` in that case -- confirm against the configs used.
        '''
        # store original x for skip-connection
        orig_in = inputs

        # if downsampling, then double number of filters and project original x
        if strides == 2:
            num_filts *= 2
            # both convolutions below are strided, so the main path shrinks
            # by strides * strides; the projection uses 2*strides (= 4) to
            # produce a matching shape
            orig_in = self._conv_layer(orig_in, num_filts, 1, 2*strides)

        # first pre-activation block
        inputs = self._batch_norm_relu(inputs, is_training)
        inputs = self._conv_layer(inputs, num_filts, self.kernel_dim, strides)

        # second pre-activation block
        inputs = self._batch_norm_relu(inputs, is_training)
        inputs = self._conv_layer(inputs, num_filts, self.kernel_dim, strides)

        # add back original (possibly projected) inputs
        inputs += orig_in

        return inputs

    def _batch_norm_relu(self, inputs, is_training):
        '''
        Passes inputs through batch normalization and leaky-relu.

        Normalizes over axis 3 using the module-level
        `_BATCH_NORM_DECAY` / `_BATCH_NORM_EPSILON` constants.
        '''
        inputs = tf.layers.batch_normalization(
            inputs, axis=3,
            momentum=_BATCH_NORM_DECAY, epsilon=_BATCH_NORM_EPSILON,
            center=True, scale=True,
            training=is_training, fused=True)
        inputs = self._leaky_relu(inputs)
        return inputs

    def _conv_layer(self, inputs, n_filts, kernel, strides, padding='same',
                    data_format='channels_last'):
        '''
        2-D convolution initialized with `self.conv_init`.

        Thin wrapper around `tf.layers.conv2d`; no activation is applied.
        '''
        return tf.layers.conv2d(inputs, n_filts, kernel, strides,
                                padding=padding, data_format=data_format,
                                kernel_initializer=self.conv_init)

    def _leaky_relu(self, inputs):
        '''Element-wise max(relu_alpha * x, x).'''
        return tf.maximum(self.relu_alpha * inputs, inputs)