In [1]:
import tensorflow as tf

In [2]:
# Default ResNet variant: 2 selects the pre-activation blocks (*_v2),
# 1 the original post-activation blocks (*_v1).
DEFAULT_VERSION = 2

# Hyper-parameters shared by every batch_normalize() call.
_BATCH_NORM_DECAY = 0.997    # passed as `momentum` of the moving averages
_BATCH_NORM_EPSILON = 1e-5   # passed as `epsilon`

# dtypes accepted by ResNet_Model.  Variables requested in one of
# CASTABLE_TYPES are stored as float32 and cast on read
# (see ResNet_Model._custom_dtype_getter).
DEFAULT_DTYPE = tf.float32
CASTABLE_TYPES = (tf.float16,)
ALLOWED_TYPES = (DEFAULT_DTYPE,) + CASTABLE_TYPES

In [3]:
def batch_normalize(inputs, is_training, dFormat):
    """Fused batch normalization over the channel axis.

    Args:
        inputs: 4-D tensor laid out according to ``dFormat``.
        is_training: forwarded as ``training`` to ``tf.layers.batch_normalization``.
        dFormat: 'channels_first' normalizes axis 1; anything else axis 3.

    Returns:
        The normalized tensor.
    """
    channel_axis = 3
    if dFormat == 'channels_first':
        channel_axis = 1
    return tf.layers.batch_normalization(
        inputs,
        axis=channel_axis,
        momentum=_BATCH_NORM_DECAY,
        epsilon=_BATCH_NORM_EPSILON,
        training=is_training,
        fused=True,
    )

def fixed_pad(inputs, ksize, dFormat):
    """Pad the two spatial dimensions so padding does not depend on input size.

    A total of ``ksize - 1`` rows/columns is added per spatial axis; when that
    total is odd, the extra one goes after (bottom/right) rather than before.

    Args:
        inputs: 4-D tensor.
        ksize: convolution kernel size the padding is computed for.
        dFormat: 'channels_first' pads axes 2 and 3; anything else pads axes 1 and 2.

    Returns:
        The result of ``tf.pad`` with the computed paddings.
    """
    before = (ksize - 1) // 2
    after = (ksize - 1) - before
    if dFormat != 'channels_first':
        paddings = [[0, 0], [before, after], [before, after], [0, 0]]
    else:
        paddings = [[0, 0], [0, 0], [before, after], [before, after]]
    return tf.pad(inputs, paddings)

def conv_fixed_pad(inputs, n_filters, k_size, strides, dFormat):
    """2-D convolution whose padding is independent of the input size.

    For ``strides > 1`` the input is explicitly padded with ``fixed_pad`` and
    the convolution uses 'VALID' padding; with ``strides == 1`` 'SAME'
    padding is used directly.

    Args:
        inputs: 4-D tensor laid out according to ``dFormat``.
        n_filters: number of output filters.
        k_size: (square) kernel size.
        strides: stride applied in both spatial dimensions.
        dFormat: 'channels_first' or 'channels_last'.

    Returns:
        The output tensor of ``tf.layers.conv2d``.
    """
    if strides > 1:
        # The kernel size parameter is `k_size`; the previous `ksize` was an
        # undefined name and raised NameError for every strided convolution.
        inputs = fixed_pad(inputs, k_size, dFormat)

    return tf.layers.conv2d(inputs=inputs, filters=n_filters, kernel_size=k_size,
                            strides=[strides] * 2,
                            padding='SAME' if strides == 1 else 'VALID',
                            data_format=dFormat)

@tf.contrib.framework.add_arg_scope
def conv_bn(inputs, n_filters, k_size, strides, is_training, dFormat, activation):
    """Fixed-padding convolution, then batch norm, then an optional ReLU.

    Registered with ``add_arg_scope`` so ``tf.contrib.framework.arg_scope``
    can supply its keyword arguments.

    Args:
        inputs: 4-D tensor laid out according to ``dFormat``.
        n_filters, k_size, strides: forwarded to ``conv_fixed_pad``.
        is_training: forwarded to ``batch_normalize``.
        dFormat: 'channels_first' or 'channels_last'.
        activation: 'RELU' applies ``tf.nn.relu``; any other value returns the
            normalized tensor unchanged.
    """
    conv_out = conv_fixed_pad(inputs, n_filters, k_size, strides, dFormat)
    normalized = batch_normalize(conv_out, is_training, dFormat)
    if activation != 'RELU':
        return normalized
    return tf.nn.relu(normalized)

In [4]:
def _res_block_v1(inputs, n_filters, strides, is_training, projection_shortcut, dFormat):
    """Basic post-activation (v1) residual block: two 3x3 conv+BN layers.

    Main path: conv(strides)->BN->ReLU, conv(1)->BN.  The shortcut is the
    input, or ``projection_shortcut(input)`` followed by BN when a projection
    is given.  The block returns ReLU(main + shortcut).
    """
    residual = inputs
    if projection_shortcut is not None:
        residual = batch_normalize(inputs=projection_shortcut(residual),
                                   is_training=is_training, dFormat=dFormat)

    # Arguments common to both 3x3 convolutions.
    shared = dict(n_filters=n_filters, k_size=3, is_training=is_training, dFormat=dFormat)
    hidden = conv_bn(inputs=inputs, strides=strides, activation='RELU', **shared)
    hidden = conv_bn(inputs=hidden, strides=1, activation=None, **shared)

    return tf.nn.relu(hidden + residual)

def _res_block_v2(inputs, n_filters, strides, is_training, projection_shortcut, dFormat):
    shortcut = inputs
    inputs = batch_normalize(inputs=inputs, is_training=is_training, dFormat=dFormat)
    inputs = tf.nn.relu(inputs)
    
    if projection_shortcut is not None:
        shortcut = projection_shortcut(inputs)

    inputs = conv_bn(inputs=inputs, n_filters=n_filters, k_size=3, strides=strides, 
                     is_training=is_training, dFormat=dFormat, activation='RELU')
    inputs = conv_fixed_pad(inputs=inputs, n_filters=n_filters, k_size=3, strides=1, dFormat=dFormat)
    
    return inputs + shortcut

def _bottleneck_res_block_v1(inputs, n_filters, strides, is_training, projection_shortcut, dFormat):
    """Bottleneck post-activation (v1) residual block.

    Main path: 1x1 conv_bn(ReLU), 3x3 conv_bn(strides, ReLU), then a 1x1
    conv_bn with 4 * n_filters outputs and no activation.  The shortcut is the
    input, or ``projection_shortcut(input)`` followed by BN.  Returns
    ReLU(main + shortcut).
    """
    shortcut = inputs
    if projection_shortcut is not None:
        shortcut = batch_normalize(inputs=projection_shortcut(inputs),
                                   is_training=is_training, dFormat=dFormat)

    common = {'n_filters': n_filters, 'is_training': is_training, 'dFormat': dFormat}
    out = inputs
    # (kernel size, stride) of the two ReLU-activated layers, in order.
    for kernel, stride in ((1, 1), (3, strides)):
        out = conv_bn(inputs=out, k_size=kernel, strides=stride, activation='RELU', **common)

    # Expansion layer: 4x depth, activation deferred until after the addition.
    out = conv_bn(inputs=out, n_filters=4 * n_filters, k_size=1, strides=1,
                  is_training=is_training, dFormat=dFormat, activation=None)

    return tf.nn.relu(out + shortcut)

def _bottleneck_res_block_v2(inputs, n_filters, strides, is_training, projection_shortcut, dFormat):
    """Bottleneck pre-activation (v2) residual block.

    BN->ReLU is applied to the input first.  The projection (if any) is taken
    from that pre-activated tensor.  Main path: 1x1 conv_bn(ReLU), 3x3
    conv_bn(strides, ReLU), then a plain 1x1 conv to 4 * n_filters.  Returns
    main + shortcut with no final activation.
    """
    activated = batch_normalize(inputs=inputs, is_training=is_training, dFormat=dFormat)
    activated = tf.nn.relu(activated)
    if projection_shortcut is None:
        shortcut = inputs
    else:
        shortcut = projection_shortcut(activated)

    with tf.contrib.framework.arg_scope([conv_bn], n_filters=n_filters, is_training=is_training,
                                        dFormat=dFormat, activation='RELU'):
        reduced = conv_bn(inputs=activated, k_size=1, strides=1)
        spatial = conv_bn(inputs=reduced, k_size=3, strides=strides)

    expanded = conv_fixed_pad(inputs=spatial, n_filters=4 * n_filters, k_size=1,
                              strides=1, dFormat=dFormat)
    return expanded + shortcut

In [5]:
def block_layer(inputs, n_blocks, block_func, bottleneck,
                  n_filters, strides, is_training, dFormat, output_name):
    """Build one ResNet stage: ``n_blocks`` residual blocks sharing ``n_filters``.

    Only the first block uses ``strides`` and a 1x1 projection shortcut, which
    matches the shortcut's depth and spatial size to the block output.  The
    remaining ``n_blocks - 1`` blocks use stride 1 and identity shortcuts.

    Args:
        inputs: input tensor.
        n_blocks: total number of blocks in this stage, including the first one.
        block_func: one of the ``_*res_block_v*`` builders.
        bottleneck: if truthy, the blocks output 4 * n_filters channels.
        n_filters: base filter count of every block in the stage.
        strides: stride of the first block (and of its projection).
        is_training: forwarded to the blocks.
        dFormat: 'channels_first' or 'channels_last'.
        output_name: name given to the stage output via ``tf.identity``.

    Returns:
        The stage output tensor.
    """
    n_filters_out = 4 * n_filters if bottleneck else n_filters

    def projection_shortcut(inputs):
        # Previously referenced the undefined name `filters_out`.
        return conv_fixed_pad(inputs=inputs, n_filters=n_filters_out, k_size=1,
                              strides=strides, dFormat=dFormat)

    # The first block downsamples and projects the shortcut.
    inputs = block_func(inputs, n_filters, strides, is_training, projection_shortcut, dFormat)

    # n_blocks is an int count and the first block is already built above.
    for _ in range(1, n_blocks):
        inputs = block_func(inputs, n_filters, 1, is_training, None, dFormat)

    return tf.identity(inputs, output_name)

In [6]:
class ResNet_Model():
    """ResNet classifier (v1 or v2, basic or bottleneck blocks) built from tf.layers.

    Calling an instance builds the graph inside a 'resnet_model' variable
    scope and returns the logits.

    Args:
        resnet_size: nominal network depth; stored on the instance but not
            used to build the graph.
        final_size: channel count of the last block layer's output; the pooled
            features are reshaped to [-1, final_size] before the dense layer.
        n_classes: number of output logits.
        n_filters_initial: filters of the initial convolution and of the first
            block layer; block layer i uses n_filters_initial * 2**i.
        block_sizes: number of residual blocks in each block layer.
        block_strides: stride of the first block of each block layer.
        bottleneck: if truthy, use bottleneck blocks (output depth 4 * filters).
        first_conv_size: kernel size of the initial convolution.
        first_conv_strides: stride of the initial convolution.
        first_pool_size: pool size of the initial max-pool; falsy disables it.
        first_pool_strides: stride of the initial max-pool.
        version: 1 (post-activation blocks) or 2 (pre-activation blocks).
        data_format: 'channels_first' or 'channels_last'.  None selects
            'channels_first' when TensorFlow is built with CUDA and
            'channels_last' otherwise.
        dtype: compute dtype; must be one of ALLOWED_TYPES.

    Raises:
        ValueError: if version, data_format or dtype is not supported.
    """

    def __init__(self, resnet_size, final_size, n_classes,
                 n_filters_initial, block_sizes, block_strides, bottleneck,
                 first_conv_size, first_conv_strides, first_pool_size, first_pool_strides,
                 version=DEFAULT_VERSION, data_format=None, dtype=DEFAULT_DTYPE):

        self.resnet_size = resnet_size
        self.final_size = final_size
        self.n_classes = n_classes
        self.n_filters_initial = n_filters_initial
        self.block_sizes = block_sizes
        self.block_strides = block_strides
        self.bottleneck = bottleneck
        self.first_conv_size = first_conv_size
        self.first_conv_strides = first_conv_strides
        self.first_pool_size = first_pool_size
        self.first_pool_strides = first_pool_strides

        if version not in (1, 2):
            raise ValueError('Resnet version should be 1 or 2.')
        self.version = version

        if data_format is None:
            # Default used by the TF official ResNet: NCHW on CUDA builds, NHWC otherwise.
            data_format = 'channels_first' if tf.test.is_built_with_cuda() else 'channels_last'
        # Previously misspelled as 'channes_first', which rejected the valid format.
        if data_format not in ('channels_first', 'channels_last'):
            raise ValueError('Data format should be "channels_first" or "channels_last"')
        self.data_format = data_format

        if dtype not in ALLOWED_TYPES:
            raise ValueError('dtype must be one of: {}'.format(ALLOWED_TYPES))
        self.dtype = dtype

        # Truthiness (not `== True`) so the block type agrees with block_layer,
        # which sizes the projection with `4*n_filters if bottleneck`.
        block_funcs = {
            (False, 1): _res_block_v1,
            (False, 2): _res_block_v2,
            (True, 1): _bottleneck_res_block_v1,
            (True, 2): _bottleneck_res_block_v2,
        }
        self.block_func = block_funcs[(bool(bottleneck), version)]

    def _custom_dtype_getter(self, getter, name, shape=None, dtype=DEFAULT_DTYPE, *args, **kwargs):
        """Variable getter that keeps low-precision variables in float32.

        Variables requested in a CASTABLE_TYPES dtype are created as float32
        and returned cast to the requested dtype; other dtypes pass through.
        """
        if dtype in CASTABLE_TYPES:
            var = getter(name, shape, tf.float32, *args, **kwargs)
            return tf.cast(var, dtype=dtype, name=name + '_cast')
        return getter(name, shape, dtype, *args, **kwargs)

    def _model_variable_scope(self):
        """Variable scope 'resnet_model' using ``_custom_dtype_getter``."""
        return tf.variable_scope('resnet_model', custom_getter=self._custom_dtype_getter)

    def __call__(self, inputs, is_training):
        """Build the network on ``inputs`` and return the logits.

        Args:
            inputs: image batch, assumed channels_last (it is transposed with
                [0, 3, 1, 2] when data_format is 'channels_first').
            is_training: forwarded to every batch normalization.

        Returns:
            Logits tensor with n_classes units per example.
        """
        with self._model_variable_scope():
            if self.data_format == 'channels_first':
                inputs = tf.transpose(inputs, [0, 3, 1, 2])

            inputs = conv_fixed_pad(inputs=inputs, n_filters=self.n_filters_initial,
                                    k_size=self.first_conv_size,
                                    strides=self.first_conv_strides, dFormat=self.data_format)

            # v1 normalizes and activates the stem conv output.  In v2 the
            # first block's pre-activation BN + ReLU takes this role.
            if self.version == 1:
                inputs = tf.nn.relu(batch_normalize(inputs, is_training, self.data_format))

            if self.first_pool_size:
                inputs = tf.layers.max_pooling2d(inputs=inputs, pool_size=self.first_pool_size,
                                                 strides=self.first_pool_strides, padding='SAME',
                                                 data_format=self.data_format)
                inputs = tf.identity(inputs, 'initial_maxpool')

            for i, n_blocks in enumerate(self.block_sizes):
                n_filters = self.n_filters_initial * (2 ** i)
                inputs = block_layer(inputs=inputs, n_blocks=n_blocks, block_func=self.block_func,
                                     bottleneck=self.bottleneck, n_filters=n_filters,
                                     strides=self.block_strides[i], is_training=is_training,
                                     dFormat=self.data_format,
                                     output_name='block_layer_{}'.format(i))

            # v2 blocks return `main + shortcut` without activation, so a final
            # BN + ReLU is needed.  v1 blocks already end with a ReLU.
            if self.version == 2:
                inputs = tf.nn.relu(batch_normalize(inputs, is_training, self.data_format))

            axes = [2, 3] if self.data_format == 'channels_first' else [1, 2]
            inputs = tf.reduce_mean(inputs, axes, keepdims=True)
            inputs = tf.identity(inputs, 'final_reduce_mean')

            inputs = tf.reshape(inputs, [-1, self.final_size])
            inputs = tf.layers.dense(inputs=inputs, units=self.n_classes)
            inputs = tf.identity(inputs, 'final_dense')
            return inputs