In [15]:
# Construction of Standard U-Net
# U-Net: deep learning for cell counting, detection, and morphometry. Nat Methods 16, 67–70 (2019)

In [16]:
#import libraries for model development
from keras import models, layers

In [17]:
#Definition of standard convolution block

def stand_conv_block(inputs, filter_size, filter_num, dropout, batch_norm=True):
    """Build a standard double-convolution block.

    The block stacks two identical stages, each consisting of
    Conv2D ('same' padding) -> optional BatchNormalization -> ReLU,
    optionally followed by a single Dropout layer on the activated output.

    Args:
        inputs: Input tensor. Batch normalization is applied on axis 3, so a
            4-D channels-last tensor is assumed -- TODO confirm with callers.
        filter_size: Side length of the square convolution kernel.
        filter_num: Number of filters used by both convolutions.
        dropout: Dropout rate; a Dropout layer is added only when this is > 0.
        batch_norm: When truthy, each convolution is followed by
            BatchNormalization.

    Returns:
        The output tensor of the block.
    """
    conv_output = inputs

    # Two identical Conv -> (BN) -> ReLU stages
    for _ in range(2):
        conv_output = layers.Conv2D(
            filter_num, (filter_size, filter_size), padding='same'
        )(conv_output)

        # Batch normalization operation
        if batch_norm:
            conv_output = layers.BatchNormalization(axis=3)(conv_output)

        conv_output = layers.Activation('relu')(conv_output)

    # Dropout operation: applied to the *activated* output so that enabling
    # dropout does not bypass the second ReLU.
    if dropout > 0:
        conv_output = layers.Dropout(dropout)(conv_output)

    return conv_output

In [18]:
#Definition of Unet architecture

def Unet(input_size, class_num, dropout, batch_norm=True, filter_num=64, filter_size=3):
    """Build a standard U-Net with four downsampling and four upsampling steps.

    The encoder applies a double-convolution block followed by 2x2 max pooling
    at each step, doubling the number of filters each time. After a bottleneck
    block, the decoder upsamples, concatenates the matching encoder output
    (skip connection) on axis 3 and applies another double-convolution block.
    A final 1x1 convolution, batch normalization and softmax produce the
    per-pixel class scores.

    Args:
        input_size: Shape passed to `layers.Input`, e.g. (height, width, channels).
            Presumably height and width must be divisible by 16 so that the
            upsampled tensors match the skip connections -- verify for new sizes.
        class_num: Number of output channels of the final 1x1 convolution.
        dropout: Dropout rate forwarded to every convolution block.
        batch_norm: Whether the convolution blocks use batch normalization.
        filter_num: Number of filters in the first block; deeper blocks use
            2x, 4x, 8x and 16x this value.
        filter_size: Side length of the square kernels in the convolution blocks.

    Returns:
        The Keras model named "U-Net". The model summary is printed as a side
        effect.
    """
    # parameters of network configuration
    up_samp_size = 2  # size of upsampling filters (mirrors the 2x2 pooling)
    depth = 4  # number of downsampling / upsampling steps

    inputs = layers.Input(input_size)

    # Downsampling: keep each block output for the skip connections
    skip_connections = []
    x = inputs
    for level in range(depth):
        conv = stand_conv_block(x, filter_size, (2 ** level) * filter_num, dropout, batch_norm)
        skip_connections.append(conv)
        x = layers.MaxPooling2D(pool_size=(2, 2))(conv)

    # Bottleneck block at the lowest resolution
    x = stand_conv_block(x, filter_size, (2 ** depth) * filter_num, dropout, batch_norm)

    # Upsampling: walk the skip connections from deepest to shallowest
    for level in reversed(range(depth)):
        x = layers.UpSampling2D(size=(up_samp_size, up_samp_size), data_format="channels_last")(x)
        x = layers.concatenate([x, skip_connections[level]], axis=3)
        x = stand_conv_block(x, filter_size, (2 ** level) * filter_num, dropout, batch_norm)

    # 1*1 convolution
    conv_final = layers.Conv2D(class_num, kernel_size=(1, 1))(x)
    # NOTE(review): this normalization is applied regardless of `batch_norm`;
    # kept as-is to preserve the existing architecture -- confirm it is intended.
    conv_final = layers.BatchNormalization(axis=3)(conv_final)
    conv_final = layers.Activation('softmax')(conv_final)

    # Model
    model = models.Model(inputs, conv_final, name="U-Net")

    # print model summary for details; summary() prints the table itself and
    # returns None, so it must not be wrapped in print()
    model.summary()

    return model

In [19]:
# Spatial dimensions of the input images and the resulting
# single-channel input shape handed to the network
input_size_x, input_size_y = 256, 512
input_size = (input_size_x, input_size_y, 1)

In [20]:
# Number of pixel labels: keyhole, pore, substrate, background and powder
class_num = 5

# Dropout rate used when building the model
dropout = 0.0

In [21]:
# Instantiate the standard U-Net with batch normalization enabled

model = Unet(
    input_size=input_size,
    class_num=class_num,
    dropout=dropout,
    batch_norm=True,
)

Model: "U-Net"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_3 (InputLayer)           [(None, 256, 512, 1  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_38 (Conv2D)             (None, 256, 512, 64  640         ['input_3[0][0]']                
                                )                                                                 
                                                                                                  
 batch_normalization_38 (BatchN  (None, 256, 512, 64  256        ['conv2d_38[0][0]']              
 ormalization)                  )                                                             