In [1]:
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, UpSampling2D, Concatenate, Input, ZeroPadding2D
from tensorflow.keras.models import Model

def batchnorm_relu(inputs):
    """Apply batch normalization followed by a ReLU non-linearity.

    Args:
        inputs: Keras tensor to normalize and activate.

    Returns:
        The activated tensor (same shape as ``inputs``).
    """
    # Instantiate the two layers in the same order as they are applied.
    normalize = BatchNormalization()
    relu = Activation("relu")
    return relu(normalize(inputs))

def residual_block(inputs, num_filters, strides=1):
    """Build a pre-activation residual unit with a 1x1 projection shortcut.

    The main path is (BN -> ReLU -> 3x3 conv) applied twice; only the first
    convolution uses ``strides``. The shortcut is a 1x1 convolution with the
    same stride so both paths have matching shapes before they are summed.

    Args:
        inputs: Keras tensor entering the block.
        num_filters: Output channel count of every convolution in the block.
        strides: Stride of the first 3x3 conv and of the shortcut conv
            (with ``padding="same"``, 2 halves the spatial size, rounding up).

    Returns:
        Element-wise sum of the main path and the shortcut path.
    """
    # Main path.
    pre = batchnorm_relu(inputs)
    main = Conv2D(num_filters, 3, padding="same", strides=strides)(pre)
    mid = batchnorm_relu(main)
    main = Conv2D(num_filters, 3, padding="same", strides=1)(mid)

    # Projection shortcut taken from the un-normalized input.
    shortcut = Conv2D(num_filters, 1, padding="same", strides=strides)(inputs)

    return main + shortcut

def decoder_block(inputs, skip_features, num_filters):
    """Upsample, fuse with an encoder skip connection, and refine.

    Args:
        inputs: Keras tensor from the previous (coarser) decoder stage.
        skip_features: Encoder tensor whose spatial size is twice that of
            ``inputs``; concatenated along the channel axis.
        num_filters: Output channel count of the residual block.

    Returns:
        Output tensor of a stride-1 ``residual_block``.
    """
    upsampled = UpSampling2D(size=(2, 2))(inputs)
    fused = Concatenate(axis=-1)([upsampled, skip_features])
    refined = residual_block(fused, num_filters, strides=1)
    return refined

def build_resunet(input_shape):
    """Build a ResUNet (residual U-Net) for single-channel sigmoid segmentation.

    Architecture: a residual encoder (64 -> 128 -> 256 channels), a 512-channel
    bridge, and a mirrored decoder with skip connections, followed by a 1x1
    sigmoid classifier producing one output channel.

    Args:
        input_shape: ``(height, width, channels)`` of a single input sample.
            Height and width may be ``None`` (resolved at call time), but any
            static value must be divisible by 8.

    Returns:
        An uncompiled ``tf.keras.Model`` mapping the input to a
        ``(height, width, 1)`` probability map.

    Raises:
        ValueError: If a static height or width is not a multiple of 8.
    """
    # Three stride-2 stages (with "same" padding, n -> ceil(n/2)) are undone by
    # three 2x upsamplings, each concatenated with the matching encoder output.
    # The shapes only line up when H and W are multiples of 2**3. Check early
    # instead of failing deep inside Concatenate with an opaque shape error.
    for dim in tuple(input_shape)[:2]:
        if dim is not None and dim % 8 != 0:
            raise ValueError(
                "build_resunet: input height and width must be divisible by 8, "
                f"got input_shape={input_shape!r}"
            )

    inputs = Input(input_shape)

    # Encoder 1: plain conv stem (no leading BN-ReLU) plus a 1x1 shortcut.
    x = Conv2D(64, 3, padding="same", strides=1)(inputs)
    x = batchnorm_relu(x)
    x = Conv2D(64, 3, padding="same", strides=1)(x)
    s = Conv2D(64, 1, padding="same", strides=1)(inputs)
    s1 = x + s

    # Encoders 2 and 3: each halves the resolution.
    s2 = residual_block(s1, 128, strides=2)
    s3 = residual_block(s2, 256, strides=2)

    # Bridge: lowest resolution (input / 8).
    b = residual_block(s3, 512, strides=2)

    # Decoders 1-3: upsample and fuse with the matching encoder output.
    d1 = decoder_block(b, s3, 256)
    d2 = decoder_block(d1, s2, 128)
    d3 = decoder_block(d2, s1, 64)

    # Classifier: per-pixel sigmoid probability.
    outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d3)

    model = Model(inputs, outputs)
    return model

if __name__ == "__main__":
    # Smoke-build the network for 256x256 RGB inputs and print its layer table.
    # The global is kept as `model` so later notebook cells can reuse it.
    model = build_resunet(input_shape=(256, 256, 3))
    model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 256, 256, 64  1792        ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 batch_normalization (BatchNorm  (None, 256, 256, 64  256        ['conv2d[0][0]']                 
 alization)                     )                                                             

 [... intermediate layers omitted from this excerpt ...]
                                                                                                   
 tf.__operators__.add_3 (TFOpLa  (None, 32, 32, 512)  0          ['conv2d_10[0][0]',              
 mbda)                                                            'conv2d_11[0][0]']              
                                                                                                  
 up_sampling2d (UpSampling2D)   (None, 64, 64, 512)  0           ['tf.__operators__.add_3[0][0]'] 
                                                                                                  
 concatenate (Concatenate)      (None, 64, 64, 768)  0           ['up_sampling2d[0][0]',          
                                                                  'tf.__operators__.add_2[0][0]'] 
                                                                                                  
 batch_normalization_7 (BatchNo  (None, 64, 64, 768)  3072       ['concatenate[0][0]']            
 rmalization)                                                                                      

 [... intermediate layers omitted from this excerpt ...]
                                                                                                  
 tf.__operators__.add_6 (TFOpLa  (None, 256, 256, 64  0          ['conv2d_19[0][0]',              
 mbda)                          )                                 'conv2d_20[0][0]']              
                                                                                                  
 conv2d_21 (Conv2D)             (None, 256, 256, 1)  65          ['tf.__operators__.add_6[0][0]'] 
                                                                                                  
Total params: 8,227,393
Trainable params: 8,220,993
Non-trainable params: 6,400
__________________________________________________________________________________________________
