In [7]:
#Construction of Standard Unet

In [8]:
#import libraries for model development
from keras import models, layers
from keras import backend as K

In [9]:
#Definition of convolution block

def conv_block(inputs, filter_size, size, dropout, batch_norm=True):
    """Apply one Conv2D -> (BatchNormalization) -> ReLU -> (Dropout) stage.

    Args:
        inputs: Tensor fed to the convolution. Batch normalization is applied
            on axis 3, so a 4-D channels-last tensor is presumably expected
            (confirm against the caller).
        filter_size: Side length of the square convolution kernel.
        size: Number of convolution filters, i.e. output channels.
        dropout: Dropout rate; a Dropout layer is only added when this is > 0.
        batch_norm: When truthy, batch normalization is inserted between the
            convolution and the activation.

    Returns:
        The output tensor of the last layer that was applied.
    """
    conv = layers.Conv2D(size, (filter_size, filter_size), padding='same')(inputs)

    # Truthiness test instead of `is True`, so flags such as 1 or numpy.True_
    # enable batch normalization as well rather than silently skipping it.
    if batch_norm:
        conv = layers.BatchNormalization(axis=3)(conv)

    conv = layers.Activation('relu')(conv)

    if dropout > 0:
        conv = layers.Dropout(dropout)(conv)

    return conv

In [10]:
#Definition of Unet architecture

def Unet(input_shape, NUM_CLASSES, dropout_rate, batch_norm):
    """Build a U-Net style encoder-decoder segmentation model.

    The encoder runs four conv_block + 2x2 max-pooling stages followed by a
    bottleneck conv_block. The decoder runs four upsample + skip-concatenate +
    conv_block stages, and the head is a 1x1 convolution, batch normalization
    and a softmax activation.

    Args:
        input_shape: Shape handed to layers.Input. Skip connections are
            concatenated on axis 3 and upsampling uses channels_last, so a
            (rows, cols, channels) shape is expected.
        NUM_CLASSES: Number of filters of the final 1x1 convolution, i.e. the
            number of output channels fed to the softmax.
        dropout_rate: Forwarded to every conv_block.
        batch_norm: Forwarded to every conv_block. Note that the batch
            normalization before the final softmax is applied regardless of
            this flag.

    Returns:
        The constructed model, named "Unet". Its summary is printed as a side
        effect.

    Raises:
        ValueError: If input_shape is a tuple/list whose first or second entry
            is an integer that is not a multiple of 16. Four 2x poolings
            followed by four 2x upsamplings only reproduce the skip-connection
            sizes for such inputs.
    """
    # network structure
    FILTER_NUM = 32 # number of filters for the first layer
    FILTER_SIZE = 3 # size of the convolutional filter
    UP_SAMP_SIZE = 2 # size of upsampling filters
    POOL_SIZE = 2 # size of the max-pooling window
    ENCODER_MULTIPLIERS = (1, 2, 4, 8) # filter multiplier of each pooled encoder stage
    BOTTLENECK_MULTIPLIER = 16 # filter multiplier of the deepest, un-pooled stage

    # Fail early with a readable message instead of a shape mismatch deep
    # inside the decoder. Unknown (None) dimensions are left unchecked.
    total_stride = POOL_SIZE ** len(ENCODER_MULTIPLIERS)
    if isinstance(input_shape, (tuple, list)):
        for dim in input_shape[:2]:
            if isinstance(dim, int) and dim % total_stride != 0:
                raise ValueError(
                    "Unet needs spatial dimensions divisible by %d, got input_shape=%r"
                    % (total_stride, tuple(input_shape))
                )

    inputs = layers.Input(input_shape)

    # Downsampling path: convolution + pooling; the pre-pooling feature maps
    # are remembered for the skip connections.
    skips = []
    x = inputs
    for mult in ENCODER_MULTIPLIERS:
        features = conv_block(x, FILTER_SIZE, mult*FILTER_NUM, dropout_rate, batch_norm)
        skips.append(features)
        x = layers.MaxPooling2D(pool_size=(POOL_SIZE, POOL_SIZE))(features)

    # Bottleneck: convolution only
    x = conv_block(x, FILTER_SIZE, BOTTLENECK_MULTIPLIER*FILTER_NUM, dropout_rate, batch_norm)

    # Upsampling path: upsampling + concatenate with the matching encoder
    # output (deepest first) + convolution.
    for mult, skip in zip(reversed(ENCODER_MULTIPLIERS), reversed(skips)):
        x = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE), data_format="channels_last")(x)
        x = layers.concatenate([x, skip], axis=3)
        x = conv_block(x, FILTER_SIZE, mult*FILTER_NUM, dropout_rate, batch_norm)

    # 1*1 convolutional layers
    conv_final = layers.Conv2D(NUM_CLASSES, kernel_size=(1,1))(x)
    # NOTE: this normalization is unconditional; `batch_norm` does not control it.
    conv_final = layers.BatchNormalization(axis=3)(conv_final)
    conv_final = layers.Activation('softmax')(conv_final)  #Change to softmax for multichannel

    # Model
    model = models.Model(inputs, conv_final, name="Unet")

    # summary() prints the table itself and returns None; wrapping it in
    # print() would emit an extra "None" line.
    model.summary()

    return model

In [11]:
#Spatial size of the input images; the trailing 1 in input_shape is the channel axis

SIZE_X, SIZE_Y = 256, 512
input_shape = (SIZE_X, SIZE_Y) + (1,)

#Number of segmentation classes predicted by the model
n_classes = 5

In [12]:
#Create the standard U-Net model (built by Unet), using the class count configured above

model = Unet(input_shape, NUM_CLASSES=n_classes, dropout_rate=0.0, batch_norm=True)

Model: "Unet"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 256, 512, 1  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_10 (Conv2D)             (None, 256, 512, 32  320         ['input_2[0][0]']                
                                )                                                                 
                                                                                                  
 batch_normalization_10 (BatchN  (None, 256, 512, 32  128        ['conv2d_10[0][0]']              
 ormalization)                  )                                                              