In [4]:
import tensorflow as tf
import tensorflow_datasets as tfds

import matplotlib.pyplot as plt
import numpy as np


def conv_block(inputs, filters, kernel_size=3):
    """Apply two consecutive Conv2D -> BatchNormalization -> ReLU stages.

    Args:
        inputs: Keras tensor to transform.
        filters: number of output channels of each Conv2D.
        kernel_size: Conv2D kernel size (default 3).

    Returns:
        The Keras tensor produced by the second ReLU activation. Spatial size
        is preserved because the convolutions use "same" padding.
    """
    layers = tf.keras.layers
    net = inputs
    for _ in range(2):
        conv = layers.Conv2D(filters, kernel_size, padding="same",
                             kernel_initializer="he_normal")
        net = conv(net)
        net = layers.BatchNormalization()(net)
        net = layers.Activation("relu")(net)
    return net


def encoder_block(inputs, filters=64):
    """One downsampling stage: conv block, 2x2 max-pool, then 20% dropout.

    Args:
        inputs: Keras tensor entering this stage.
        filters: number of filters for the stage's conv block (default 64).

    Returns:
        (features, downsampled): `features` is the conv block output (kept for
        the decoder's skip connection); `downsampled` is the pooled tensor
        after dropout, which feeds the next stage.
    """
    features = conv_block(inputs, filters=filters)
    downsampled = tf.keras.layers.MaxPooling2D((2, 2))(features)
    downsampled = tf.keras.layers.Dropout(0.2)(downsampled)
    return features, downsampled


def encoder(inputs):
    """Run four encoder stages with 64, 128, 256 and 512 filters.

    Args:
        inputs: the model's input tensor.

    Returns:
        (deepest, skips): `deepest` is the pooled output of the last stage;
        `skips` is a 4-tuple of each stage's pre-pooling features, ordered
        shallowest first, consumed by the decoder.
    """
    base_filters = 64
    skips = []
    x = inputs
    # Stages are built strictly in order, doubling the filter count each time.
    for depth in range(4):
        skip, x = encoder_block(x, base_filters * 2 ** depth)
        skips.append(skip)
    return x, tuple(skips)


def bottleneck(inputs):
    """Bridge between encoder and decoder: a single 1024-filter conv block.

    Args:
        inputs: the deepest pooled encoder output.

    Returns:
        The conv block's output tensor.
    """
    return conv_block(inputs, 1024)


def decoder_block(inputs, conv_out, filters=512, kernel_size=3, strides=2):
    """One upsampling stage of the decoder.

    The stage upsamples `inputs` with a transposed convolution and
    concatenates the encoder skip tensor `conv_out` on the last axis. It then
    applies 20% dropout and a conv block.

    Args:
        inputs: tensor coming from the previous (deeper) decoder stage.
        conv_out: matching encoder skip tensor; its spatial size must equal
            the upsampled tensor's for the concatenation to succeed.
        filters: filters for the transposed conv and the conv block.
        kernel_size: kernel size for the transposed conv and the conv block.
        strides: transposed-conv stride (upsampling factor).

    Returns:
        The conv block's output tensor.
    """
    layers = tf.keras.layers
    upsampled = layers.Conv2DTranspose(filters, kernel_size, strides,
                                       padding="same")(inputs)
    merged = layers.concatenate([upsampled, conv_out])
    merged = layers.Dropout(0.2)(merged)
    return conv_block(merged, filters, kernel_size)


def decoder(inputs, convs, out_channels, filters=512):
    """Decode bottleneck features back to input resolution.

    Args:
        inputs: bottleneck tensor.
        convs: (f1, f2, f3, f4) encoder skip tensors, ordered shallowest first.
        out_channels: number of channels of the final softmax output.
        filters: filter count of the first (deepest) decoder block; each
            subsequent block uses half the previous count.

    Returns:
        Tensor with `out_channels` channels, softmax-activated over the last axis.

    Raises:
        ValueError: if `filters` is too small to be halved three times
            without reaching zero.
    """
    if filters // 8 < 1:
        raise ValueError(f"filters must be >= 8, got {filters}")

    f1, f2, f3, f4 = convs

    # Integer division: Keras layer filter counts must be ints. The previous
    # `/` passed floats such as 256.0 into Conv2DTranspose/Conv2D.
    c6 = decoder_block(inputs, f4, filters)
    c7 = decoder_block(c6, f3, filters // 2)
    c8 = decoder_block(c7, f2, filters // 4)
    c9 = decoder_block(c8, f1, filters // 8)

    # 1x1 conv projects features to `out_channels` with a softmax across them.
    outputs = tf.keras.layers.Conv2D(out_channels, 1, activation="softmax")(c9)
    return outputs


In [5]:
OUTPUT_CHANNELS = 4


def unet(input_shape=(128, 128, 1), out_channels=None):
    """Build the (uncompiled) U-Net Keras model.

    Args:
        input_shape: (height, width, channels) of the input images. Height and
            width must be divisible by 16 (or None): the encoder halves the
            spatial size four times, and the decoder's skip concatenations
            require the doubled shapes to match exactly.
        out_channels: number of softmax output channels. When None, the
            module-level OUTPUT_CHANNELS is read at call time.

    Returns:
        A tf.keras.Model mapping `input_shape` images to `out_channels` maps.

    Raises:
        ValueError: if a known spatial dimension is not divisible by 16.
    """
    if out_channels is None:
        out_channels = OUTPUT_CHANNELS
    for dim in input_shape[:2]:
        if dim is not None and dim % 16:
            raise ValueError(
                f"input height and width must be divisible by 16, got {tuple(input_shape[:2])}"
            )

    # NOTE(review): no input rescaling is applied here (an x/255 Lambda used to
    # be commented out at this point); presumably normalization happens in the
    # input pipeline. Confirm.
    inputs = tf.keras.layers.Input(shape=input_shape)
    encoder_output, convs = encoder(inputs)
    bottle_neck = bottleneck(encoder_output)
    outputs = decoder(bottle_neck, convs, out_channels)
    return tf.keras.Model(inputs=inputs, outputs=outputs)


model = unet()
model.summary()

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 128, 128, 1  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_19 (Conv2D)             (None, 128, 128, 64  640         ['input_2[0][0]']                
                                )                                                                 
                                                                                                  
 batch_normalization_18 (BatchN  (None, 128, 128, 64  256        ['conv2d_19[0][0]']              
 ormalization)                  )                                                           

                                                                                                  
 batch_normalization_26 (BatchN  (None, 8, 8, 1024)  4096        ['conv2d_27[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_26 (Activation)     (None, 8, 8, 1024)   0           ['batch_normalization_26[0][0]'] 
                                                                                                  
 conv2d_28 (Conv2D)             (None, 8, 8, 1024)   9438208     ['activation_26[0][0]']          
                                                                                                  
 batch_normalization_27 (BatchN  (None, 8, 8, 1024)  4096        ['conv2d_28[0][0]']              
 ormalization)                                                                                    
          

 concatenate_7 (Concatenate)    (None, 128, 128, 12  0           ['conv2d_transpose_7[0][0]',     
                                8)                                'activation_19[0][0]']          
                                                                                                  
 dropout_15 (Dropout)           (None, 128, 128, 12  0           ['concatenate_7[0][0]']          
                                8)                                                                
                                                                                                  
 conv2d_35 (Conv2D)             (None, 128, 128, 64  73792       ['dropout_15[0][0]']             
                                )                                                                 
                                                                                                  
 batch_normalization_34 (BatchN  (None, 128, 128, 64  256        ['conv2d_35[0][0]']              
 ormalizat