In [21]:
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input
from tensorflow.keras.models import Model


def conv_block(inputs, num_filters):
    """
    num_filters: define the number of channels in the output feature map, generally it is equall to the number of classes
    """

    #1st layer is convolution layer with 3*3 kernal size and padding="same" i.e the output feature map will be same as input feature map, here inputs is the input image
    x=Conv2D(num_filters, 3, padding="same")(inputs)
    x=BatchNormalization()(x)
    x=Activation("relu")(x)

    x=Conv2D(num_filters, 3, padding="same")(x)
    x=BatchNormalization()(x)
    x=Activation("relu")(x)

    return x


def encoder_block(inputs, num_filters):
    x=conv_block(inputs, num_filters)
    p=MaxPool2D((2, 2))(x) # the height and width of feature map will reduce to half
    return x,p #we return two output as we need to have some skip connections


def decoder_block(inputs,skip_connection, num_filters):
    x = Conv2DTranspose(num_filters, 2, strides=2, padding="same")(inputs) #strides=2 signifies that the output feature map will be twice the input
    x = Concatenate()([x, skip_connection])
    x = conv_block(x, num_filters)
    return x

def Unet(input_shape, number_of_classes):
    inputs=Input(input_shape)

    #we call the encoder block
    skip_connection_1, pooling_block_1=encoder_block(inputs, num_filters=64)
    skip_connection_2, pooling_block_2=encoder_block(pooling_block_1, num_filters=128)
    skip_connection_3, pooling_block_3=encoder_block(pooling_block_2, num_filters=256)
    skip_connection_4, pooling_block_4=encoder_block(pooling_block_3, num_filters=512)

    #defining the bridge block between the encoder and decoder block
    bridge_block_1=conv_block(pooling_block_4, num_filters=1024 )


    #calling the decoder block
    decoder_1=decoder_block(bridge_block_1, skip_connection_4, 512)
    decoder_2=decoder_block(decoder_1, skip_connection_3, 256)
    decoder_3=decoder_block(decoder_2, skip_connection_2, 128)
    decoder_4=decoder_block(decoder_3, skip_connection_1, 64)
    print(decoder_4.shape)

    #defining the output layer
    output=Conv2D(number_of_classes, 1, padding="same", activation="softmax")(decoder_4)
    model=Model(inputs, output)
    return model


In [23]:
input_shape=(256, 256, 3)
model = Unet(input_shape, 20)
model.summary()

(None, 256, 256, 64)
Model: "model_3"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_11 (InputLayer)          [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_145 (Conv2D)            (None, 256, 256, 64  1792        ['input_11[0][0]']               
                                )                                                                 
                                                                                                  
 batch_normalization_142 (Batch  (None, 256, 256, 64  256        ['conv2d_145[0][0]']             
 Normalization)                 )                                      