In [1]:
#In the variants of U-Net, the standard convolution block is substituted with a modified convolution block

In [2]:
#import libraries for model development
from keras import layers

In [None]:
# Res-Unet: replace stand_conv_block with res_conv_block
# Deep residual learning for image recognition. CVPR 2016

In [3]:
def res_conv_block(inputs, filter_size, filter_num, dropout, batch_norm=True):
    """Build a two-convolution residual block with a 1x1 projection shortcut.

    Main path:     Conv2D -> [BatchNorm] -> ReLU -> Conv2D -> [BatchNorm] -> [Dropout]
    Shortcut path: 1x1 Conv2D -> [BatchNorm], applied directly to ``inputs``.
    Both paths are summed with ``layers.add`` and passed through a final ReLU.

    Args:
        inputs: Input tensor. Batch normalization is applied on axis 3, so a
            4-D channels-last tensor is presumably expected -- TODO confirm
            against callers.
        filter_size: Side length of the square kernel used by the two
            main-path convolutions.
        filter_num: Number of filters for every convolution in the block
            (including the 1x1 shortcut projection).
        dropout: Dropout rate applied at the end of the main path. ``None`` or
            any value <= 0 disables dropout.
        batch_norm: Any truthy value enables batch normalization after each
            convolution (main path and shortcut).

    Returns:
        The tensor produced by the final ReLU activation.
    """
    kernel = (filter_size, filter_size)

    # Main path, first convolution stage.
    conv = layers.Conv2D(filter_num, kernel, padding='same')(inputs)
    # Truthiness test (not ``is True``) so flags such as 1 or numpy.bool_
    # are honoured instead of silently skipping normalization.
    if batch_norm:
        conv = layers.BatchNormalization(axis=3)(conv)
    conv = layers.Activation('relu')(conv)

    # Main path, second convolution stage; its ReLU is deferred until after
    # the shortcut has been added.
    conv = layers.Conv2D(filter_num, kernel, padding='same')(conv)
    if batch_norm:
        conv = layers.BatchNormalization(axis=3)(conv)

    # Guard against None so "no dropout" can be expressed without a TypeError.
    if dropout is not None and dropout > 0:
        conv = layers.Dropout(dropout)(conv)

    # Shortcut: a 1x1 convolution projects ``inputs`` to ``filter_num``
    # channels so it can be summed with the main path.
    shortcut = layers.Conv2D(filter_num, kernel_size=(1, 1), padding='same')(inputs)
    if batch_norm:
        shortcut = layers.BatchNormalization(axis=3)(shortcut)

    # Merge the two paths, then apply the block's output activation.
    res_conv = layers.add([shortcut, conv])
    return layers.Activation('relu')(res_conv)


In [None]:
# Squeeze-Unet: replace stand_conv_block with fire_block in the downsampling layers
# SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size

In [None]:
def fire_block(inputs, filter_size, filter_num, dropout, batch_norm=True):
    """Build a fire block: a 1x1 squeeze conv feeding two parallel expand convs.

    Structure:
        squeeze: 1x1 Conv2D -> [BatchNorm] -> ReLU
        left:    1x1 Conv2D on squeeze -> [BatchNorm] -> ReLU
        right:   (filter_size x filter_size) Conv2D on squeeze -> [BatchNorm] -> ReLU
        output:  concatenate([left, right], axis=3) -> [Dropout]

    Args:
        inputs: Input tensor. Batch normalization and concatenation both use
            axis 3, so a 4-D channels-last tensor is presumably expected --
            TODO confirm against callers.
        filter_size: Side length of the square kernel used by the ``right``
            expand branch.
        filter_num: Number of filters for the squeeze convolution and for each
            of the two expand convolutions.
        dropout: Dropout rate applied after concatenation. ``None`` or any
            value <= 0 disables dropout.
        batch_norm: Any truthy value enables batch normalization after each
            convolution.

    Returns:
        The concatenated (and optionally dropout-regularized) tensor.
    """
    # Squeeze: 1x1 convolution applied to the block input.
    squeeze = layers.Conv2D(filter_num, (1, 1), padding='same')(inputs)
    # Truthiness test (not ``is True``) so flags such as 1 or numpy.bool_
    # are honoured instead of silently skipping normalization.
    if batch_norm:
        squeeze = layers.BatchNormalization(axis=3)(squeeze)
    squeeze = layers.Activation('relu')(squeeze)

    # Left expand branch: 1x1 convolution on the squeezed features.
    left = layers.Conv2D(filter_num, (1, 1), padding='same')(squeeze)
    if batch_norm:
        left = layers.BatchNormalization(axis=3)(left)
    left = layers.Activation('relu')(left)

    # Right expand branch: spatial convolution on the same squeezed features.
    right = layers.Conv2D(filter_num, (filter_size, filter_size), padding='same')(squeeze)
    if batch_norm:
        right = layers.BatchNormalization(axis=3)(right)
    right = layers.Activation("relu")(right)

    # Join both expand branches along axis 3. The local is deliberately not
    # named ``fire_block`` so it does not shadow this function.
    merged = layers.concatenate([left, right], axis=3)

    # Guard against None so "no dropout" can be expressed without a TypeError.
    if dropout is not None and dropout > 0:
        merged = layers.Dropout(dropout)(merged)

    return merged