In [53]:
import tensorflow as tf
import numpy as np
from keras.layers import Conv2D, MaxPooling2D, Flatten, UpSampling2D, Conv2DTranspose, BatchNormalization, Input, concatenate, Cropping2D
from keras.models import Model

In [3]:
def EncoderBlock(inputs, filters, max_pooling=True, skip=True):
  """Two 3x3 Conv2D+ReLU layers followed by batch normalisation, optionally downsampled.

  Args:
    inputs: Keras tensor feeding the block.
    filters: number of filters in each of the two Conv2D layers.
    max_pooling: if True, the tensor handed to the next block is 2x2 /
      stride-2 max-pooled; otherwise the normalised features are passed on as-is.
    skip: if True, also return the (un-pooled) normalised features so a decoder
      block can use them as a skip connection.

  Returns:
    ``(next_layer, skip_conn)`` when ``skip`` is True, otherwise only ``next_layer``.
  """
  features = Conv2D(filters, kernel_size=(3, 3), padding='same', activation='relu',
                    kernel_initializer='HeNormal')(inputs)
  features = Conv2D(filters, kernel_size=(3, 3), padding='same', activation='relu',
                    kernel_initializer='HeNormal')(features)
  # Do NOT pass training=False here: in the functional API that value is baked
  # into the graph, so the layer would always run in inference mode and its
  # moving mean/variance would never be updated by fit(). Leaving `training`
  # unset lets Keras use batch statistics while training and the moving
  # statistics at inference time.
  features = BatchNormalization()(features)

  next_layer = MaxPooling2D((2, 2), strides=2)(features) if max_pooling else features

  if skip:
    return next_layer, features
  return next_layer

In [54]:
def DecoderBlock(prev_layer, filters, skip_conn):
  """Upsample `prev_layer` 2x, concatenate the encoder skip connection, then apply two 3x3 convs.

  Args:
    prev_layer: Keras tensor from the previous (coarser) stage.
    filters: number of filters for the transposed conv and both 3x3 convs.
    skip_conn: encoder features at (roughly) twice the spatial size of
      `prev_layer`. When it is larger than the upsampled tensor, its trailing
      rows/columns are cropped so the two tensors line up.

  Returns:
    Keras tensor with `filters` channels at the upsampled resolution.

  Raises:
    ValueError: if `skip_conn` is spatially smaller than the upsampled tensor.
  """
  up = Conv2DTranspose(filters, kernel_size=(3, 3), strides=(2, 2), padding='same')(prev_layer)

  # Height (axis 1) and width (axis 2) are handled independently so that
  # non-square inputs work. A mismatch comes from MaxPooling2D flooring an odd
  # size (e.g. 71 -> 35, which upsamples back to only 70): the row/column it
  # drops is the LAST one. Because every conv here uses padding='same', the
  # upsampled map stays aligned with the skip map's top-left corner, so the
  # excess must be cropped from the bottom/right, not from the top/left.
  excess = []
  for axis in (1, 2):
    skip_size, up_size = skip_conn.shape[axis], up.shape[axis]
    if skip_size is None or up_size is None:
      # Size unknown at graph-construction time: nothing to crop statically;
      # concatenate will validate the shapes at run time.
      excess.append(0)
      continue
    if skip_size < up_size:
      raise ValueError(
          f"skip connection is smaller than the upsampled tensor along axis {axis}: "
          f"{skip_size} < {up_size}")
    excess.append(skip_size - up_size)

  if any(excess):
    # Cropping2D takes ((top, bottom), (left, right)).
    crop = Cropping2D(cropping=((0, excess[0]), (0, excess[1])))(skip_conn)
  else:
    crop = skip_conn

  merged = concatenate([up, crop], axis=-1)
  out = Conv2D(filters, kernel_size=(3, 3), padding='same', activation='relu',
               kernel_initializer='HeNormal')(merged)
  out = Conv2D(filters, kernel_size=(3, 3), padding='same', activation='relu',
               kernel_initializer='HeNormal')(out)
  return out



In [51]:
def UNET(shape, num_classes=2, base_filters=64, depth=4):
  """Build a U-Net-style encoder/decoder model.

  The defaults reproduce the original layout: encoder widths 64/128/256/512,
  a 1024-filter bottleneck, mirrored decoder widths and a 2-channel output.

  Args:
    shape: input shape without the batch axis, e.g. ``(572, 572, 3)``.
    num_classes: number of output channels of the final 1x1 convolution.
      No activation is applied, so the outputs are unnormalised scores.
    base_filters: filters in the first encoder stage; each deeper stage doubles it.
    depth: number of downsampling (and matching upsampling) stages.

  Returns:
    An uncompiled ``Model`` mapping ``inputs`` to the 1x1-conv outputs.
  """
  inputs = Input(shape=shape)

  x = inputs
  skips = []
  for level in range(depth):
    x, skip_conn = EncoderBlock(x, filters=base_filters * 2 ** level)
    skips.append(skip_conn)

  # Bottleneck: deepest stage, neither pooled nor used as a skip connection.
  x = EncoderBlock(x, filters=base_filters * 2 ** depth, max_pooling=False, skip=False)

  # The decoder mirrors the encoder, consuming skip connections deepest-first.
  for level in reversed(range(depth)):
    x = DecoderBlock(x, filters=base_filters * 2 ** level, skip_conn=skips[level])

  outputs = Conv2D(num_classes, kernel_size=1)(x)
  return Model(inputs=inputs, outputs=outputs)

In [55]:
model = UNET(shape=(572, 572, 3))
# summary() prints the table itself and returns None; wrapping it in print()
# would append a stray "None" line to the output.
model.summary()

Model: "model_3"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_19 (InputLayer)          [(None, 572, 572, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_229 (Conv2D)            (None, 572, 572, 64  1792        ['input_19[0][0]']               
                                )                                                                 
                                                                                                  
 conv2d_230 (Conv2D)            (None, 572, 572, 64  36928       ['conv2d_229[0][0]']             
                                )                                                           