In [26]:
from keras.models import Model
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Cropping2D, Concatenate
from keras import optimizers

In [46]:
# Dropout layers are intentionally omitted; they can easily be added later on.

def _unet_crops(size, depth):
    """Return the per-level ``(before, after)`` skip-connection crops along one spatial axis.

    Mirrors the size arithmetic of the network built by ``unet``:
    each conv block shrinks a side by 4 (two 3x3 'valid' convolutions),
    each 2x2 max-pool floors it to half, and each upsampling doubles it.
    The returned list is ordered deepest level first, i.e. in decoder order.

    Args:
        size: input length along one spatial axis (height or width).
        depth: number of pooling levels in the contracting path.

    Returns:
        A list of ``depth`` tuples ``(crop_before, crop_after)``.

    Raises:
        ValueError: if ``size`` is too small for every feature map to stay non-empty.
    """
    original_size = size

    def shrink(n):
        # Two 3x3 'valid' convolutions remove 2 pixels each along every axis.
        n -= 4
        if n < 1:
            raise ValueError(
                'Spatial size {} is too small for a depth-{} U-Net with valid '
                'convolutions.'.format(original_size, depth))
        return n

    skip_sizes = []
    for _ in range(depth):
        size = shrink(size)
        skip_sizes.append(size)
        size //= 2  # 2x2 max pooling floors odd sizes
    size = shrink(size)  # bottleneck conv block

    crops = []
    for skip_size in reversed(skip_sizes):
        size *= 2  # 2x upsampling
        excess = skip_size - size
        # Center crop; an odd excess puts the extra pixel on the trailing side.
        crops.append((excess // 2, excess - excess // 2))
        size = shrink(size)
    return crops


def _conv_block(x, filters):
    """Apply two 3x3 'valid' ReLU convolutions with ``filters`` channels each."""
    for _ in range(2):
        x = Conv2D(filters, 3, activation='relu', padding='valid',
                   kernel_initializer='he_normal')(x)
    return x


def _up_block(x, skip, filters, cropping):
    """Upsample ``x``, merge it with the cropped ``skip`` features, then run a conv block.

    The "up-convolution" is a 2x upsampling followed by a 2x2 'same' convolution
    with ``filters`` channels. ``skip`` is center-cropped with ``cropping`` to the
    upsampled size and concatenated (skip first) along the channel axis.
    """
    upsampled = UpSampling2D((2, 2))(x)
    upconv = Conv2D(filters, 2, activation='relu', padding='same',
                    kernel_initializer='he_normal')(upsampled)
    cropped = Cropping2D(cropping=cropping)(skip)
    merged = Concatenate()([cropped, upconv])
    return _conv_block(merged, filters)


def unet(input_size=(572, 572, 1), num_classes=2, learning_rate=1e-4):
    """Build, summarize and compile a U-Net with unpadded ('valid') convolutions.

    Follows the original U-Net layout: four contracting levels
    (64/128/256/512 filters), a 1024-filter bottleneck, and four expanding
    levels whose skip connections are center-cropped before concatenation.
    Dropout is intentionally omitted.

    Shapes for the default 572x572 input:
    encoder 568 -> 284 -> 280 -> 140 -> 136 -> 68 -> 64 -> 32, bottleneck 28,
    decoder 56 -> 52 -> 104 -> 100 -> 200 -> 196 -> 392 -> 388 (output 388x388).

    Args:
        input_size: ``(height, width, channels)`` of the input. The crop
            amounts are derived from height and width. An unknown (``None``)
            spatial dim falls back to the reference 572 geometry.
        num_classes: number of output channels of the final 1x1 convolution.
        learning_rate: Adam learning rate.

    Returns:
        The compiled Keras ``Model``. ``model.summary()`` is printed as a side effect.

    Raises:
        ValueError: if the spatial input size is too small for this depth.
    """
    encoder_filters = (64, 128, 256, 512)
    depth = len(encoder_filters)

    inputs = Input(input_size)

    # Derive the skip-connection crops instead of hard-coding them for 572x572.
    # None dims reproduce the crops of the reference 572x572 configuration.
    height = input_size[0] if input_size[0] is not None else 572
    width = input_size[1] if input_size[1] is not None else 572
    # Each entry is ((top, bottom), (left, right)), deepest level first.
    crops = list(zip(_unet_crops(height, depth), _unet_crops(width, depth)))

    # Contracting path: keep each level's features for the skip connections.
    skips = []
    x = inputs
    for filters in encoder_filters:
        x = _conv_block(x, filters)
        skips.append(x)
        x = MaxPooling2D((2, 2))(x)

    # Bottleneck: the last block before upsampling.
    x = _conv_block(x, 2 * encoder_filters[-1])

    # Expanding path: deepest skip connection first.
    for filters, skip, cropping in zip(reversed(encoder_filters), reversed(skips), crops):
        x = _up_block(x, skip, filters, cropping)

    # binary_crossentropy expects probabilities in [0, 1]. The previous ReLU
    # output was unbounded (and often exactly 0), so use a sigmoid.
    outputs = Conv2D(num_classes, 1, activation='sigmoid', padding='valid',
                     kernel_initializer='he_normal')(x)

    model = Model(inputs=inputs, outputs=outputs)
    model.summary()

    # Positional learning rate: the keyword is `lr` in Keras <= 2.2 but
    # `learning_rate` in Keras >= 2.3 / tf.keras / Keras 3 (which reject `lr`).
    model.compile(optimizer=optimizers.Adam(learning_rate),
                  loss='binary_crossentropy', metrics=['accuracy'])

    return model

In [48]:
# Build the U-Net for the reference 572x572 single-channel input (prints its summary).
model = unet(input_size=(572, 572, 1))

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_20 (InputLayer)           (None, 572, 572, 1)  0                                            
__________________________________________________________________________________________________
conv2d_211 (Conv2D)             (None, 570, 570, 64) 640         input_20[0][0]                   
__________________________________________________________________________________________________
conv2d_212 (Conv2D)             (None, 568, 568, 64) 36928       conv2d_211[0][0]                 
__________________________________________________________________________________________________
max_pooling2d_63 (MaxPooling2D) (None, 284, 284, 64) 0           conv2d_212[0][0]                 
__________________________________________________________________________________________________
conv2d_213