In [12]:
# Importing the libraries
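# NOTE(review): the `from keras.…` imports below resolve the top-level `keras` package through the import system, not the `keras` name bound on the next line — confirm both refer to the same Keras installation.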
from tensorflow import keras
from keras.layers import Conv2D, Input, BatchNormalization, Activation, MaxPooling2D, GlobalAveragePooling2D, Dense
from keras.models import Model
from keras.applications import ResNet50
from keras.layers import Conv2DTranspose
from keras.layers import Concatenate

In [11]:
# Convolution block: two stacked (Conv2D -> BatchNormalization -> ReLU) stages.
# Batch normalization is applied between each convolution and its ReLU.

def conv_block(inputs, num_filters):
    """Apply two 3x3, 'same'-padded Conv2D -> BatchNormalization -> ReLU stages.

    Args:
        inputs: Keras tensor to transform.
        num_filters: number of filters used by both Conv2D layers.

    Returns:
        Keras tensor with ``num_filters`` channels. Padding is 'same' and the
        Conv2D stride is left at its default, so the spatial size is unchanged.
    """
    features = inputs
    # Layers are created in the same order as a hand-unrolled version:
    # conv, bn, relu, conv, bn, relu.
    for _ in range(2):
        features = Conv2D(num_filters, 3, padding='same')(features)
        features = BatchNormalization()(features)
        features = Activation('relu')(features)
    return features

In [13]:
# Decoder block: upsample 2x, fuse with the matching encoder skip tensor, refine.

def decoder_block(inputs, skip_features, num_filters):
    """Upsample ``inputs`` by 2, concatenate ``skip_features``, then run ``conv_block``.

    Args:
        inputs: Keras tensor from the previous (coarser) decoder stage.
        skip_features: encoder Keras tensor whose spatial size must equal the
            upsampled ``inputs`` for the concatenation to succeed.
        num_filters: filters for the transposed convolution and ``conv_block``.

    Returns:
        Keras tensor produced by ``conv_block`` on the concatenated features.
    """
    # A 2x2 kernel with stride 2 and 'same' padding doubles height and width.
    upsampled = Conv2DTranspose(num_filters, (2, 2), strides=2, padding='same')(inputs)
    # Concatenate joins along the last axis (channels).
    merged = Concatenate()([upsampled, skip_features])
    return conv_block(merged, num_filters)

In [3]:
def build_resnet50_unet(input_shape, num_classes=1, weights='imagenet', verbose=True):
    """Build a U-Net whose encoder is a ResNet50 backbone.

    The encoder is Keras' ``ResNet50`` without its classification head. Four
    encoder tensors serve as skip connections. The output of
    ``conv4_block6_out`` (1/16 of the input resolution) is the bridge. Four
    decoder blocks then upsample back to the input resolution.

    Args:
        input_shape: ``(height, width, channels)``, channels-last. Height and
            width must be divisible by 16 (or ``None``) so that every
            upsampled decoder tensor matches its skip connection.
        num_classes: number of output channels. ``1`` (default) gives a
            single sigmoid mask. Larger values give a softmax over
            ``num_classes`` channels.
        weights: forwarded to ``ResNet50``. ``'imagenet'`` (default) loads
            pretrained weights; ``None`` gives random initialisation.
        verbose: if True (default), print the backbone summary and the
            skip-connection shapes.

    Returns:
        An uncompiled ``Model`` named ``'ResNet50_Unet'``.

    Raises:
        ValueError: if a known spatial dimension is not divisible by 16, or
            if ``num_classes`` is less than 1.

    Note:
        No ResNet50 ``preprocess_input`` is applied inside the model. With
        ImageNet weights the caller presumably has to preprocess images
        accordingly. TODO confirm in the data pipeline.
    """
    # Each stride-2 stage in ResNet50 rounds odd sizes up, so with sizes not
    # divisible by 16 a 2x-upsampled tensor never matches its skip tensor and
    # Concatenate fails deep inside graph construction. Fail early instead.
    for dim in input_shape[:2]:
        if dim is not None and dim % 16 != 0:
            raise ValueError(
                f"input_shape height/width must be divisible by 16, got {input_shape}"
            )
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    inputs = Input(input_shape)

    # Pretrained ResNet50 backbone. include_top=False drops the classification
    # layers and keeps only the convolutional part. input_tensor=inputs makes
    # the backbone consume our Input directly.
    resnet50 = ResNet50(include_top=False, weights=weights, input_tensor=inputs)

    # Encoder skip connections. Shape comments assume a 512x512x3 input.
    # The raw input is used as-is. Looking it up by layer name ("input_1")
    # breaks whenever Keras auto-numbers the Input differently, e.g. when
    # this function is called more than once in the same session.
    s1 = inputs                                          # 512x512x3
    s2 = resnet50.get_layer("conv1_relu").output         # 256x256x64
    s3 = resnet50.get_layer("conv2_block3_out").output   # 128x128x256
    s4 = resnet50.get_layer("conv3_block4_out").output   # 64x64x512

    # Bridge (bottleneck)
    b1 = resnet50.get_layer("conv4_block6_out").output   # 32x32x1024

    if verbose:
        resnet50.summary()
        print(s1.shape, s2.shape, s3.shape, s4.shape)

    # Decoder: one block per skip connection
    d1 = decoder_block(b1, s4, 512)   # 64x64x512
    d2 = decoder_block(d1, s3, 256)   # 128x128x256
    d3 = decoder_block(d2, s2, 128)   # 256x256x128
    d4 = decoder_block(d3, s1, 64)    # 512x512x64

    # 1x1 convolution head: sigmoid for a single binary mask, softmax across
    # channels for multi-class segmentation.
    activation = 'sigmoid' if num_classes == 1 else 'softmax'
    outputs = Conv2D(num_classes, 1, padding='same', activation=activation)(d4)

    return Model(inputs, outputs, name='ResNet50_Unet')

    

In [4]:
# Network input shape as (height, width, channels): 512x512 images with 3 channels
input_shape = (512,) * 2 + (3,)

In [5]:
# Build the ResNet50 U-Net for the configured input shape
model = build_resnet50_unet(input_shape=input_shape)

Model: "resnet50"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 512, 512, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv1_pad (ZeroPadding2D)      (None, 518, 518, 3)  0           ['input_1[0][0]']                
                                                                                                  
 conv1_conv (Conv2D)            (None, 256, 256, 64  9472        ['conv1_pad[0][0]']              
                                )                                                                 
                                                                                           