In [2]:
from tensorflow.keras.layers import (
    Activation,
    Add,
    BatchNormalization,
    Concatenate,
    Conv2D,
    Input,
    MaxPool2D,
    UpSampling2D,
)
from tensorflow.keras.models import Model

In [3]:
def batchnorm_relu(inputs):
    """Apply batch normalization followed by a ReLU activation.

    Args:
        inputs: Symbolic Keras tensor to normalize and activate.

    Returns:
        The tensor produced by ``BatchNormalization`` then ``Activation("relu")``.
    """
    normalized = BatchNormalization()(inputs)
    return Activation("relu")(normalized)

In [4]:
def residual_block(inputs, num_filters, strides=1):
    """Pre-activation residual block with a projected shortcut.

    Main path: BN -> ReLU -> 3x3 conv (using ``strides``) -> BN -> ReLU ->
    3x3 conv (stride 1). The shortcut is a 1x1 convolution of ``inputs`` with
    the same ``num_filters`` and ``strides``, so both branches yield tensors
    of matching shape before they are summed.

    Args:
        inputs: Symbolic Keras tensor feeding the block.
        num_filters: Number of filters used by every convolution in the block.
        strides: Stride of the first 3x3 convolution and of the shortcut
            projection. Defaults to 1.

    Returns:
        The element-wise sum of the main path and the shortcut branch.
    """
    # Main path: two (BN + ReLU + 3x3 conv) stages; only the first may stride.
    x = batchnorm_relu(inputs)
    x = Conv2D(num_filters, 3, padding="same", strides=strides)(x)
    x = batchnorm_relu(x)
    x = Conv2D(num_filters, 3, padding="same", strides=1)(x)

    # Shortcut: a 1x1 projection (not a pure identity) so the channel count
    # and spatial size match the main path even when they change in the block.
    s = Conv2D(num_filters, 1, padding="same", strides=strides)(inputs)

    # Merge with an explicit Add layer instead of the raw `+` operator so the
    # sum appears as a regular Keras layer in the functional graph.
    return Add()([x, s])

In [5]:
def decoder_block(inputs, skip_features, num_filters):
    """Decoder block: upsample, merge with the skip tensor, then refine.

    Args:
        inputs: Symbolic Keras tensor coming from the previous (coarser) stage.
        skip_features: Encoder tensor concatenated with the upsampled input.
        num_filters: Number of filters passed to the residual block.

    Returns:
        Output of ``residual_block`` (stride 1) applied to the merged tensor.
    """
    upsampled = UpSampling2D((2, 2))(inputs)
    merged = Concatenate()([upsampled, skip_features])
    return residual_block(merged, num_filters, strides=1)

In [6]:
def build_resunet(input_shape, num_classes):
    """Build the RESUNET segmentation model.

    The network has one plain encoder stage, two strided residual encoder
    stages, a strided residual bridge, and three decoder stages that each
    upsample by 2 and concatenate the matching encoder output.

    Args:
        input_shape: Shape tuple passed to ``Input`` (batch dimension excluded).
            Because the network downsamples three times with stride 2 and then
            upsamples three times by 2, the skip concatenations only line up
            when the spatial dimensions are divisible by 8.
        num_classes: Number of channels produced by the final 1x1 convolution.

    Returns:
        A Keras ``Model`` named ``"RESUNET"`` mapping ``inputs`` to the
        per-pixel class probabilities.
    """
    inputs = Input(input_shape)

    # Encoder 1: conv -> BN/ReLU -> conv, plus a 1x1 projection shortcut.
    x = Conv2D(64, 3, padding="same", strides=1)(inputs)
    x = batchnorm_relu(x)
    x = Conv2D(64, 3, padding="same", strides=1)(x)
    s = Conv2D(64, 1, padding="same")(inputs)
    # Explicit Add layer instead of the raw `+` operator so the merge appears
    # as a regular Keras layer in the functional graph.
    s1 = Add()([x, s])

    # Encoder 2, 3: each halves the spatial resolution via stride 2.
    s2 = residual_block(s1, 128, strides=2)
    s3 = residual_block(s2, 256, strides=2)

    # Bridge
    b = residual_block(s3, 512, strides=2)

    # Decoder 1, 2, 3: upsample and fuse with the matching encoder output.
    x = decoder_block(b, s3, 256)
    x = decoder_block(x, s2, 128)
    x = decoder_block(x, s1, 64)

    # Classifier: softmax over a single channel is constantly 1.0, so a
    # one-channel (binary) head uses sigmoid; multi-class heads are unchanged.
    activation = "sigmoid" if num_classes == 1 else "softmax"
    outputs = Conv2D(num_classes, 1, padding="same", activation=activation)(x)

    # Model
    return Model(inputs, outputs, name="RESUNET")

In [7]:
# Build the network for 600x600 inputs with 512 channels and 11 output classes.
# NOTE(review): 512 input channels is unusual for image data — confirm that
# (600, 600, 512) is the intended (height, width, channels) shape.
shape = (600, 600, 512)
model = build_resunet(shape, 11)

# Print the layer-by-layer summary of the constructed model.
model.summary()

Model: "RESUNET"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 600, 600, 5  0           []                               
                                12)]                                                              
                                                                                                  
 conv2d (Conv2D)                (None, 600, 600, 64  294976      ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 batch_normalization (BatchNorm  (None, 600, 600, 64  256        ['conv2d[0][0]']                 
 alization)                     )                                                           