In [22]:
import tensorflow as tf
import tensorflow_datasets as tfds

import matplotlib.pyplot as plt
import numpy as np


In [23]:
def conv_block(inputs, filters, kernel_size=3):
    x= inputs
    for i in range(2):
        x= tf.keras.layers.Conv2D(filters, kernel_size, kernel_initializer="he_normal",
                                 padding="same")(x)
        x= tf.keras.layers.Activation("relu")(x)
    return x
    

In [24]:
def encoder_block(inputs, filters=64):
    f= conv_block(inputs, filters=filters) ##output for the concat
    p= tf.keras.layers.MaxPooling2D((2,2))(f)
    p= tf.keras.layers.Dropout(0.2)(p)
    return f, p
    

In [25]:
def encoder(inputs):
    filters=64
    f1, p1 = encoder_block(inputs, filters)
    f2, p2 = encoder_block(p1, filters*2)
    f3, p3 = encoder_block(p2, filters*4)
    f4, p4 = encoder_block(p3, filters*8)
    return p4, (f1, f2, f3, f4)

In [26]:
def bottleneck(inputs):
    bottle_neck = conv_block(inputs, 1024)
    return bottle_neck

In [27]:
def decoder_block(inputs, conv_out, filters=512, kernel_size=3, strides=2):
    u= tf.keras.layers.Conv2DTranspose(filters, kernel_size, strides, padding="same")(inputs)
    cat= tf.keras.layers.concatenate([u, conv_out])
    cat= tf.keras.layers.Dropout(0.2)(cat)
    cat= conv_block(cat, filters, kernel_size)
    return cat

In [28]:
def decoder(inputs, convs, out_channels,filters=512,):
    f1, f2, f3, f4 = convs
    
    c6 = decoder_block(inputs, f4, filters)
    c7 = decoder_block(c6, f3, filters/2)
    c8= decoder_block(c7, f2, filters/4)
    c9= decoder_block(c8, f1, filters/8)
    
    outputs= tf.keras.layers.Conv2D(out_channels, 1, activation="sigmoid")(c9)
    return outputs

In [29]:
OUTPUT_CHANNELS=1
def unet():
    
    inputs= tf.keras.layers.Input(shape=(128, 128, 3))
    inputs= tf.keras.layers.Lambda(lambda x: x/255)(inputs)
    encoder_output, convs = encoder(inputs)
    bottle_neck= bottleneck(encoder_output)
    outputs= decoder(bottle_neck, convs, OUTPUT_CHANNELS)
    model= tf.keras.Model(inputs=inputs, outputs=outputs)
    return model

model= unet()
model.summary()

Model: "model_2"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_6 (InputLayer)           [(None, 128, 128, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_38 (Conv2D)             (None, 128, 128, 64  1792        ['input_6[0][0]']                
                                )                                                                 
                                                                                                  
 activation_36 (Activation)     (None, 128, 128, 64  0           ['conv2d_38[1][0]']              
                                )                                                           

 activation_47 (Activation)     (None, 16, 16, 512)  0           ['conv2d_49[1][0]']              
                                                                                                  
 conv2d_transpose_9 (Conv2DTran  (None, 32, 32, 256)  1179904    ['activation_47[1][0]']          
 spose)                                                                                           
                                                                                                  
 concatenate_9 (Concatenate)    (None, 32, 32, 512)  0           ['conv2d_transpose_9[1][0]',     
                                                                  'activation_41[1][0]']          
                                                                                                  
 dropout_21 (Dropout)           (None, 32, 32, 512)  0           ['concatenate_9[1][0]']          
                                                                                                  
 conv2d_50