In [13]:
from keras.layers import (Dense, Input, Activation, Flatten, Conv2D, ELU, UpSampling2D,
                          MaxPooling2D, GlobalAveragePooling2D, BatchNormalization, add)
from keras.models import Model
from keras.regularizers import l2
from keras import backend as K

In [58]:
def downsizeMapping(inputTensor, filters):
    """
    Residual building block where input tensor dimensions are halved, but
    feature map dimensions double
    """
    if K.image_data_format() == 'channels_last':
        bn_axis = 3
    else:
        bn_axis = 1
    x = Conv2D(filters, (3, 3), strides=(1,1), padding='same', kernel_regularizer=l2(0.0001), bias_regularizer=l2(0.0001))(inputTensor)
    x = BatchNormalization(axis=bn_axis)(x)
    x = ELU()(x)

    x = Conv2D(filters, (3, 3), padding='same', kernel_regularizer=l2(0.0001), bias_regularizer=l2(0.0001))(x)
    x = BatchNormalization(axis=bn_axis)(x)
    x = ELU()(x)

    inputTensor = Conv2D(filters, (1, 1), strides=(1,1), kernel_regularizer=l2(0.0001), bias_regularizer=l2(0.0001))(inputTensor)
    x = add([x, inputTensor])
    x = ELU()(x)

    return x

In [59]:
def residualMapping(inputTensor, filters, filter_size):
    """
    Residual building block where input and output tensor are the same
    dimensions
    """
    stride = (2,2)
    
    if K.image_data_format() == 'channels_last':
        bn_axis = 3
    else:
        bn_axis = 1

    x = Conv2D(filters, filter_size, strides=stride, padding='same', kernel_regularizer=l2(0.0001), bias_regularizer=l2(0.0001))(inputTensor)
    x = BatchNormalization(axis=bn_axis)(x)
    x = ELU()(x)
        
    x = Conv2D(filters, filter_size, strides=stride, padding='same', kernel_regularizer=l2(0.0001), bias_regularizer=l2(0.0001))(x)
    x = BatchNormalization(axis=bn_axis)(x)
    x = ELU()(x)

    x = UpSampling2D((4,4))(x)
    
    x = add([x, inputTensor])
    x = ELU()(x)

    return x


In [78]:
def resnet_mod():
    bn_axis = 1
    inputShape = (256, 256, 1)
    inputTensor = Input(inputShape)

    filters = 32
    x = Conv2D(filters, (3,3), strides=(1,1), padding='same',
               kernel_regularizer=l2(0.0001), bias_regularizer=l2(0.0001))(inputTensor)
    x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
    x = ELU()(x)

    x = residualMapping(x, filters, (3,3))
    x = residualMapping(x, filters, (3,3))
    x = MaxPooling2D((2, 2), padding='same')(x) 
    x = residualMapping(x, filters, (4,4))
    x = residualMapping(x, filters, (4,4))
    x = MaxPooling2D((2, 2), padding='same')(x) 
    x = residualMapping(x, filters, (5,5))
    x = residualMapping(x, filters, (5,5))
    x = MaxPooling2D((2, 2), padding='same')(x) 
    x = residualMapping(x, filters, (6,6))
    
    
    
    filters = int(filters/2)
    x = downsizeMapping(x, filters)
    x = residualMapping(x, filters, (5,5))
    x = residualMapping(x, filters, (5,5))
    x = UpSampling2D((2, 2))(x)
    x = residualMapping(x, filters, (4,4))
    x = residualMapping(x, filters, (4,4))
    filters = int(filters/2)
    x = UpSampling2D((2, 2))(x)
    x = downsizeMapping(x, filters)    
    x = residualMapping(x, filters, (3,3))
    x = UpSampling2D((2, 2))(x)
    x = Conv2D(2, (3,3), strides=(1,1), padding='same',
               activation='linear',
               kernel_regularizer=l2(0.0001), bias_regularizer=l2(0.0001))(x)
    model = Model(inputTensor, x)
    
    return model

In [79]:
model = resnet_mod()

In [80]:
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_24 (InputLayer)           (None, 256, 256, 1)  0                                            
__________________________________________________________________________________________________
conv2d_410 (Conv2D)             (None, 256, 256, 32) 320         input_24[0][0]                   
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization)   (None, 256, 256, 32) 1024        conv2d_410[0][0]                 
__________________________________________________________________________________________________
elu_528 (ELU)                   (None, 256, 256, 32) 0           bn_conv1[0][0]                   
__________________________________________________________________________________________________
conv2d_411