# In [1]:
import tensorflow as tf

# In [9]:
'''
Pre-Activation Wide-ResNet Network for Spectrum Mining and Modulation Detection.

References:
    Basic ResNet: https://arxiv.org/pdf/1512.03385.pdf
    Pre-Act. ResNet: https://arxiv.org/pdf/1603.05027.pdf
    Wide-ResNet: https://arxiv.org/pdf/1605.07146v1.pdf
'''
import tensorflow as tf
from collections import namedtuple

# Immutable record of the model hyperparameters. Field order matters for
# positional construction, so it is spelled out one name at a time.
HyperParams = namedtuple(
    'HyperParams',
    [
        'batch_size', 'num_chans_in', 'height', 'width', 'filter_dims',
        'lr', 'kernel_dims', 'strides', 'n_hidden', 'relu_alpha',
    ],
)

# Sample hyperparameters; also used as the default argument of the model.
h = HyperParams(
    batch_size=128,
    num_chans_in=1,
    height=20,
    width=512,
    filter_dims=[16, 16, 16, 32, 64],
    lr=1e-5,
    kernel_dims=[3, 3],
    strides=[1, 1, 1, 2, 2],
    n_hidden=128,
    relu_alpha=0.01,
)

# Batch-normalization settings: epsilon added to the variance, and the
# momentum (decay) of the moving statistics.
_BATCH_NORM_EPSILON = 1e-5
_BATCH_NORM_DECAY = 0.997


class ResNet(object):
    '''
    Pre-activation ResNet graph built with the TensorFlow 1.x layers API.

    Constructing an instance creates the input placeholders and wires up the
    forward pass, loss and optimizer ops, exposed as the attributes
    `inputs`, `is_training`, `inference`, `loss` and `optim`.

    All tensors are laid out channels-first: (batch, channels, height, width).
    '''

    # Scope names of the first four residual blocks. Kept as explicit names so
    # that variable names stay stable; deeper models fall back to a numbered
    # scope name.
    _BLOCK_SCOPES = ('first_block', 'second_block', 'third_block',
                     'fourth_block')

    def __init__(self, hps=h):
        '''
        Build the full graph from a set of hyperparameters.

        Args:
            hps: HyperParams tuple. The fields read here are batch_size,
                num_chans_in, height, width, filter_dims, kernel_dims,
                strides, n_hidden and relu_alpha.

        Raises:
            ValueError: if filter_dims is empty or its length differs from
                that of strides (one entry of each is consumed per layer).
        '''
        # parameters and inputs
        self.batch_size = hps.batch_size
        num_chans = hps.num_chans_in
        height = hps.height
        width = hps.width

        # placeholder for inputs and training flag; the batch dimension is
        # left open so any batch size can be fed
        self.inputs = tf.placeholder(tf.float32,
            shape=(None, num_chans, height, width), name='x_placeholder')
        self.is_training = tf.placeholder_with_default(True, [],
            name='is_training')

        # network parameters (copied so the shared hps lists are not aliased)
        self.n_filts = list(hps.filter_dims)
        self.kernel_dim = tuple(hps.kernel_dims)
        self.strides = list(hps.strides)
        self.n_hidden = hps.n_hidden
        self.relu_alpha = hps.relu_alpha
        if not self.n_filts or len(self.n_filts) != len(self.strides):
            raise ValueError(
                'filter_dims and strides must be non-empty and of equal '
                'length, got {} and {}'.format(self.n_filts, self.strides))

        # initializers: the contrib functions are factories, so they must be
        # called to obtain the initializer object the layers expect
        self.conv_init = tf.contrib.layers.xavier_initializer_conv2d()
        self.dense_init = tf.contrib.layers.xavier_initializer()

        # forward pass through resnet
        self.inference = self._get_inference()
        # NOTE(review): _get_loss and _get_optimizer are not defined in the
        # portion of the file visible here -- confirm they exist on the class.
        # get loss from output
        self.loss = self._get_loss()
        # get optimizer step to minimize loss
        self.optim = self._get_optimizer()

    def _get_inference(self):
        '''Run the inputs through the first layer and the residual stack.'''
        x = self._get_first_layer()
        x = self._build_model(x)
        return x

    def _get_first_layer(self):
        '''
        Return the first convolutional layer common to all ResNet models.

        Note: Slightly different from layer + pooling in original resnet, but
            this performs better.
        '''
        with tf.variable_scope('first_layer'):
            # data_format must be given explicitly: the placeholder is
            # channels-first, while the layer defaults to channels-last.
            x = tf.layers.conv2d(self.inputs, filters=self.n_filts[0],
                kernel_size=self.kernel_dim, strides=self.strides[0],
                padding='same', data_format='channels_first',
                kernel_initializer=self.conv_init)
            return x

    def _build_model(self, x):
        '''
        Builds the residual blocks, global average pooling and dense output.

        One residual block is created for every entry of n_filts / strides
        after the first (the first entry belongs to the initial conv layer).

        Args:
            x: output tensor of the first layer, channels-first.

        Returns:
            Tensor of shape (batch, n_hidden).
        '''
        block_params = zip(self.n_filts[1:], self.strides[1:])
        for i, (n_filts, strides) in enumerate(block_params, start=1):
            if i <= len(self._BLOCK_SCOPES):
                scope = self._BLOCK_SCOPES[i - 1]
            else:
                scope = 'resnet_block_{:d}'.format(i)
            with tf.variable_scope(scope):
                x = self._resnet_block(x, n_filts, strides, self.is_training)

        with tf.variable_scope('GAP_block'):
            # NOTE: assumes channels_first; averaging over height and width
            # leaves a (batch, channels) tensor.
            x = tf.reduce_mean(x, [2, 3])

        with tf.variable_scope('dense_out'):
            # The pooled tensor is already 2-D, so no reshape is needed.
            # Reshaping to a fixed batch size would make the feature
            # dimension unknown, because the placeholder batch is None.
            x = tf.layers.dense(x, self.n_hidden, activation=self._leaky_relu,
                    kernel_initializer=self.dense_init)

        return x

    def _resnet_block(self, inputs, in_filts, strides, is_training):
        '''
        ResNet block with two pre-activated convolutions.

        The skip connection is projected through a 1x1 conv whenever the
        block changes the spatial size or the number of channels, so that it
        can be added to the block output.

        Args:
            inputs: channels-first input tensor.
            in_filts: base number of filters; doubled when strides == 2.
            strides: stride of the first convolution (and of the projection).
            is_training: boolean tensor forwarded to batch normalization.

        Returns:
            Sum of the convolved inputs and the (possibly projected) inputs.
        '''
        # store original x for skip-connection
        orig_in = inputs

        # if we are down-sampling, need to double amount of filters
        if strides == 2:
            in_filts *= 2

        # first pre-activation block; the only one that downsamples
        inputs = self._batch_norm_relu(inputs, is_training)
        inputs = tf.layers.conv2d(inputs, in_filts, self.kernel_dim, strides,
            padding='same', data_format='channels_first',
            kernel_initializer=self.conv_init)

        # second pre-activation block; stride 1 so the main path is
        # downsampled exactly as much as the projected skip connection
        inputs = self._batch_norm_relu(inputs, is_training)
        inputs = tf.layers.conv2d(inputs, in_filts, self.kernel_dim, 1,
            padding='same', data_format='channels_first',
            kernel_initializer=self.conv_init)

        # add back original inputs, with projection if shapes differ
        in_chans = orig_in.get_shape().as_list()[1]
        if strides != 1 or in_chans != in_filts:
            # 1x1 conv with downsample strides and filters
            orig_in = tf.layers.conv2d(orig_in, in_filts, (1, 1), strides,
                padding='same', data_format='channels_first',
                kernel_initializer=self.conv_init)

        return inputs + orig_in

    def _batch_norm_relu(self, inputs, is_training):
        '''
        Passes inputs through batch normalization and leaky-relu.

        Args:
            inputs: channels-first tensor.
            is_training: boolean tensor selecting batch vs. moving statistics.
        '''
        # axis=1 is the channel axis for channels-first data.
        # NOTE(review): the moving-average updates are added to
        # tf.GraphKeys.UPDATE_OPS and must be run with the train op --
        # confirm the optimizer step does this.
        inputs = tf.layers.batch_normalization(inputs, axis=1,
            momentum=_BATCH_NORM_DECAY, epsilon=_BATCH_NORM_EPSILON,
            center=True, scale=True, training=is_training, fused=True)
        return self._leaky_relu(inputs)

    def _leaky_relu(self, inputs, alpha=None):
        '''
        Leaky ReLU computed as max(alpha * inputs, inputs).

        Args:
            inputs: tensor to activate.
            alpha: slope for negative inputs; defaults to self.relu_alpha.
        '''
        if alpha is None:
            alpha = self.relu_alpha
        return tf.maximum(alpha * inputs, inputs)
