In [2]:
import tensorflow as tf
import tensorflow_datasets as tfds

import matplotlib.pyplot as plt
import numpy as np


def conv_block(inputs, filters, kernel_size=3):
    """Apply two successive Conv2D -> BatchNormalization -> ReLU stages.

    Args:
        inputs: Tensor fed into the first convolution.
        filters: Number of filters used by both convolutions.
        kernel_size: Convolution kernel size (default 3).

    Returns:
        The tensor produced by the second ReLU activation.
    """
    layers = tf.keras.layers
    features = inputs
    for _ in range(2):
        # "same" padding keeps the spatial size unchanged through the block.
        conv = layers.Conv2D(
            filters,
            kernel_size,
            padding="same",
            kernel_initializer="he_normal",
        )
        features = conv(features)
        features = layers.BatchNormalization()(features)
        features = layers.Activation("relu")(features)
    return features


def encoder_block(inputs, filters=16):
    """Run one encoder stage: conv block, 2x2 max-pooling, then dropout.

    Args:
        inputs: Tensor entering this stage.
        filters: Number of filters for the stage's conv block (default 16).

    Returns:
        A ``(skip, pooled)`` pair: ``skip`` is the conv block output taken
        before pooling (used later for concatenation in the decoder) and
        ``pooled`` is the max-pooled, dropout-regularised tensor passed on
        to the next stage.
    """
    skip = conv_block(inputs, filters=filters)
    pooled = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(skip)
    pooled = tf.keras.layers.Dropout(rate=0.2)(pooled)
    return skip, pooled


def encoder(inputs):
    """Stack four encoder stages with 16, 32, 64 and 128 filters.

    Args:
        inputs: Tensor fed into the first encoder stage.

    Returns:
        A pair ``(pooled, skips)`` where ``pooled`` is the output of the last
        stage and ``skips`` is a 4-tuple of the pre-pooling feature maps of
        each stage, ordered from the first stage to the last.
    """
    base_filters = 16
    skips = []
    current = inputs
    # Each stage doubles the filter count of the previous one.
    for multiplier in (1, 2, 4, 8):
        skip, current = encoder_block(current, base_filters * multiplier)
        skips.append(skip)
    return current, tuple(skips)


def bottleneck(inputs):
    """Apply a 256-filter conv block between the encoder and the decoder.

    Args:
        inputs: Tensor coming out of the encoder.

    Returns:
        The output tensor of the conv block.
    """
    return conv_block(inputs, 256)


def decoder_block(inputs, conv_out, filters=128, kernel_size=3, strides=2):
    """Run one decoder stage: upsample, merge with a skip tensor, convolve.

    Args:
        inputs: Tensor to upsample with a transposed convolution.
        conv_out: Encoder feature map concatenated with the upsampled tensor.
        filters: Number of filters for the transposed convolution and the
            following conv block (default 128).
        kernel_size: Kernel size shared by the transposed convolution and the
            conv block (default 3).
        strides: Stride of the transposed convolution (default 2).

    Returns:
        The output tensor of the conv block.
    """
    upsample = tf.keras.layers.Conv2DTranspose(
        filters,
        kernel_size,
        strides=strides,
        padding="same",
    )
    upsampled = upsample(inputs)
    merged = tf.keras.layers.concatenate([upsampled, conv_out])
    merged = tf.keras.layers.Dropout(rate=0.2)(merged)
    return conv_block(merged, filters, kernel_size)


def decoder(inputs, convs, out_channels, filters=128):
    """Build the decoder path and the final per-pixel softmax layer.

    Four decoder stages are chained; each one receives the matching encoder
    skip tensor (deepest first) and uses half as many filters as the stage
    before it.

    Args:
        inputs: Tensor coming out of the bottleneck.
        convs: Sequence of four encoder skip tensors ``(f1, f2, f3, f4)``,
            ordered from the first encoder stage to the last.
        out_channels: Number of channels of the final 1x1 softmax convolution.
        filters: Filter count of the first decoder stage (default 128).

    Returns:
        The output tensor of the final 1x1 convolution with softmax activation.
    """
    f1, f2, f3, f4 = convs

    # Keras layers expect an integer filter count. True division ("/") would
    # hand them floats such as 64.0, so use integer arithmetic throughout.
    base_filters = int(filters)

    c6 = decoder_block(inputs, f4, base_filters)
    c7 = decoder_block(c6, f3, base_filters // 2)
    c8 = decoder_block(c7, f2, base_filters // 4)
    c9 = decoder_block(c8, f1, base_filters // 8)

    outputs = tf.keras.layers.Conv2D(out_channels, 1, activation="softmax")(c9)
    return outputs


In [3]:
# Number of channels produced by the model's final softmax layer.
OUTPUT_CHANNELS = 4
def unet(input_shape=(128, 128, 1), out_channels=None):
    """Assemble the encoder, bottleneck and decoder into a Keras model.

    Calling ``unet()`` with no arguments builds the same model as before:
    a ``(128, 128, 1)`` input and ``OUTPUT_CHANNELS`` output channels.

    Args:
        input_shape: Shape of a single input sample, without the batch
            dimension. The encoder pools by 2 four times and the decoder
            upsamples by 2 four times, so the spatial dimensions should be
            divisible by 16 for the skip concatenations to line up.
        out_channels: Number of channels of the final softmax layer. When
            ``None``, the module-level ``OUTPUT_CHANNELS`` is used.

    Returns:
        A ``tf.keras.Model`` mapping the input tensor to the decoder output.
        No input rescaling is applied inside the model.
    """
    if out_channels is None:
        # Resolved at call time so later changes to the constant are honoured.
        out_channels = OUTPUT_CHANNELS

    inputs = tf.keras.layers.Input(shape=input_shape)
    encoder_output, convs = encoder(inputs)
    bottle_neck = bottleneck(encoder_output)
    outputs = decoder(bottle_neck, convs, out_channels)
    return tf.keras.Model(inputs=inputs, outputs=outputs)

# Build the network and print its layer-by-layer summary.
model = unet()
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 128, 128, 1  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 128, 128, 16  160         ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 batch_normalization (BatchNorm  (None, 128, 128, 16  64         ['conv2d[0][0]']                 
 alization)                     )                                                             

                                                                                                  
 batch_normalization_8 (BatchNo  (None, 8, 8, 256)   1024        ['conv2d_8[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 activation_8 (Activation)      (None, 8, 8, 256)    0           ['batch_normalization_8[0][0]']  
                                                                                                  
 conv2d_9 (Conv2D)              (None, 8, 8, 256)    590080      ['activation_8[0][0]']           
                                                                                                  
 batch_normalization_9 (BatchNo  (None, 8, 8, 256)   1024        ['conv2d_9[0][0]']               
 rmalization)                                                                                     
          

                                )                                 'activation_1[0][0]']           
                                                                                                  
 dropout_7 (Dropout)            (None, 128, 128, 32  0           ['concatenate_3[0][0]']          
                                )                                                                 
                                                                                                  
 conv2d_16 (Conv2D)             (None, 128, 128, 16  4624        ['dropout_7[0][0]']              
                                )                                                                 
                                                                                                  
 batch_normalization_16 (BatchN  (None, 128, 128, 16  64         ['conv2d_16[0][0]']              
 ormalization)                  )                                                                 
          