In [6]:
import tensorflow as tf
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import concatenate
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras import Model

In [4]:
def encoder_block(inputs, num_filters=32, dropout_prob=0.3, max_pooling=True):
    """Build one U-Net encoder stage: two 3x3 convs -> BatchNorm -> optional Dropout -> optional 2x2 max-pool.

    Args:
        inputs: Input feature tensor. A 4-D channels-last tensor is assumed
            (the decoder concatenates skips on the last axis) -- TODO confirm for other layouts.
        num_filters: Number of filters in each of the two Conv2D layers.
        dropout_prob: Dropout rate; dropout is skipped entirely when this is 0 (or negative).
        max_pooling: If True, downsample the stage output with a 2x2 max-pool.

    Returns:
        (next_layer, skip_connection):
            next_layer: tensor fed to the next (deeper) stage -- pooled if `max_pooling`,
                otherwise the same tensor as `skip_connection`.
            skip_connection: the full-resolution stage output, to be concatenated in the decoder.
    """
    conv = Conv2D(num_filters,
                  kernel_size=3,
                  padding='same',
                  activation='relu',
                  kernel_initializer='HeNormal')(inputs)
    conv = Conv2D(num_filters,
                  kernel_size=3,
                  padding='same',
                  activation='relu',
                  kernel_initializer='HeNormal')(conv)
    # No hard-coded `training=False`: forcing inference mode would make BN normalize with its
    # never-updated moving statistics, turning it into a plain affine transform. Keras sets the
    # training flag automatically in fit()/predict().
    conv = BatchNormalization()(conv)
    if dropout_prob > 0:
        conv = Dropout(dropout_prob)(conv)

    # The skip connection is the fully processed, pre-pooling feature map.
    skip_connection = conv
    if max_pooling:
        next_layer = MaxPooling2D(pool_size=(2, 2))(conv)
    else:
        next_layer = conv

    return next_layer, skip_connection


In [5]:
def decoder_block(prev_layer_input, skip_layer_input, num_filters=32):
    """Build one U-Net decoder stage: 2x transposed-conv upsample -> concat skip -> two 3x3 convs.

    Args:
        prev_layer_input: Output of the previous (deeper) stage.
        skip_layer_input: Matching encoder skip connection. Its spatial size must equal the
            upsampled `prev_layer_input` (i.e. twice that of `prev_layer_input`), otherwise the
            concatenation fails.
        num_filters: Number of filters for the transposed conv and both Conv2D layers.

    Returns:
        The stage's output tensor.
    """
    upsampled = Conv2DTranspose(
        num_filters,
        (3, 3),
        strides=(2, 2),
        padding='same')(prev_layer_input)
    # Axis 3 is the channel axis for 4-D channels-last tensors.
    merge = concatenate([upsampled, skip_layer_input], axis=3)

    conv = Conv2D(num_filters,
                  kernel_size=3,
                  padding='same',
                  activation='relu',
                  kernel_initializer='HeNormal')(merge)
    conv = Conv2D(num_filters,
                  kernel_size=3,
                  padding='same',
                  activation='relu',
                  kernel_initializer='HeNormal')(conv)
    return conv

In [8]:
def UnetGraph(input_size=(128,128,3), n_filters=32, n_classes=3):
    """Assemble a 4-level U-Net segmentation model from `encoder_block` / `decoder_block`.

    Args:
        input_size: Shape of one input sample, e.g. (height, width, channels). Height and width
            must be divisible by 16 (four 2x pooling stages), or may be None for dynamic sizes.
        n_filters: Filters in the first encoder stage; doubled at each deeper stage.
        n_classes: Number of output channels. Softmax is used across classes; with a single
            class a sigmoid is used instead, since softmax over one channel is always 1.0.

    Returns:
        An uncompiled `tf.keras.Model` mapping `inputs` to per-pixel class scores.

    Raises:
        ValueError: If a known spatial dimension is not divisible by 16.
    """
    for dim in input_size[:2]:
        if dim is not None and dim % 16 != 0:
            raise ValueError(
                f"input_size spatial dims must be divisible by 16 for 4 pooling stages, got {input_size}")

    inputs = Input(input_size)

    # ENCODER: each stage returns (output for next stage, skip connection)
    enc1_out, enc1_skip = encoder_block(inputs, n_filters, dropout_prob=0, max_pooling=True)
    enc2_out, enc2_skip = encoder_block(enc1_out, n_filters * 2, dropout_prob=0, max_pooling=True)
    enc3_out, enc3_skip = encoder_block(enc2_out, n_filters * 4, dropout_prob=0, max_pooling=True)
    enc4_out, enc4_skip = encoder_block(enc3_out, n_filters * 8, dropout_prob=0.3, max_pooling=True)
    # Bottleneck: no pooling, so its skip output is unused.
    bottleneck, _ = encoder_block(enc4_out, n_filters * 16, dropout_prob=0.3, max_pooling=False)

    # DECODER: upsample and merge with the matching encoder skip, deepest first.
    dec6 = decoder_block(bottleneck, enc4_skip, n_filters * 8)
    dec7 = decoder_block(dec6, enc3_skip, n_filters * 4)
    dec8 = decoder_block(dec7, enc2_skip, n_filters * 2)
    dec9 = decoder_block(dec8, enc1_skip, n_filters)

    conv9 = Conv2D(n_filters,
                   3,
                   activation='relu',
                   padding='same',
                   kernel_initializer='he_normal')(dec9)
    output_activation = 'sigmoid' if n_classes == 1 else 'softmax'
    conv10 = Conv2D(n_classes, 1, padding='same', activation=output_activation)(conv9)

    model = Model(inputs=inputs, outputs=conv10)

    return model