In [2]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import numpy as np

In [9]:
def double_conv_block(x, n_filters):

    # Conv2D then ReLU activation
    x = layers.Conv2D(n_filters, 3, padding = "same", activation = "relu", kernel_initializer = "he_normal")(x)
    # Conv2D then ReLU activation
    x = layers.Conv2D(n_filters, 3, padding = "same", activation = "relu", kernel_initializer = "he_normal")(x)

    return x 



def downsample_block(x, n_filters):
    f = double_conv_block(x, n_filters)
    p = layers.MaxPool2D(2)(f)
    p = layers.Dropout(0.3)(p)

    return f, p

def upsample_block(x, conv_features, n_filters):
    # upsample
    x = layers.Conv2DTranspose(n_filters, 3, 2, padding="same")(x)
    # concatenate
    x = layers.concatenate([x, conv_features])
    # dropout
    x = layers.Dropout(0.3)(x)
    # Conv2D twice with ReLU activation
    x = double_conv_block(x, n_filters)
    return x


 # 
 # 
def get_model():
    inputs = layers.Input(shape=(256,256,1))

    # encoder: contracting path - downsample
    # 1 - downsample
    f1, p1 = downsample_block(inputs, 32)
    # 2 - downsample
    f2, p2 = downsample_block(p1, 64)
    # 3 - downsample
    f3, p3 = downsample_block(p2, 128)
    # 4 - downsample
    f4, p4 = downsample_block(p3, 256)

    # 5 - bottleneck
    bottleneck = double_conv_block(p4, 512)

    # decoder: expanding path - upsample
    # 6 - upsample
    u6 = upsample_block(bottleneck, f4, 512)
    # 7 - upsample
    u7 = upsample_block(u6, f3, 256)
    # 8 - upsample
    u8 = upsample_block(u7, f2, 128)
    # 9 - upsample
    u9 = upsample_block(u8, f1, 64)
    # - upsample
    u10 = upsample_block(u8, f1, 32)

    # outputs
    outputs = layers.Conv2D(3, 1, padding="same", activation = "softmax")(u10)

    # unet model with Keras Functional API
    unet_model = tf.keras.Model(inputs, outputs, name="U-Net")
    unet_model.summary()
    return unet_model





In [10]:
get_model()

Model: "U-Net"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_3 (InputLayer)           [(None, 256, 256, 1  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_33 (Conv2D)             (None, 256, 256, 32  320         ['input_3[0][0]']                
                                )                                                                 
                                                                                                  
 conv2d_34 (Conv2D)             (None, 256, 256, 32  9248        ['conv2d_33[0][0]']              
                                )                                                             

 conv2d_47 (Conv2D)             (None, 128, 128, 12  221312      ['dropout_20[0][0]']             
                                8)                                                                
                                                                                                  
 conv2d_48 (Conv2D)             (None, 128, 128, 12  147584      ['conv2d_47[0][0]']              
                                8)                                                                
                                                                                                  
 conv2d_transpose_11 (Conv2DTra  (None, 256, 256, 32  36896      ['conv2d_48[0][0]']              
 nspose)                        )                                                                 
                                                                                                  
 concatenate_11 (Concatenate)   (None, 256, 256, 64  0           ['conv2d_transpose_11[0][0]',    
          

<keras.engine.functional.Functional at 0x1f2047b4c10>