# In [None]:
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model

# Side length of the square (size, size, 3) model input. The encoder applies
# four 2x2 max-poolings (a /16 reduction) before the ASPP block and the decoder
# concatenates upsampled maps with encoder skips, so this should be divisible
# by 16.
size = 256

def conv_block(x, num_filters):
    """Apply two consecutive (3x3 Conv2D -> BatchNormalization -> ReLU) stages.

    Args:
        x: Input Keras tensor.
        num_filters: Number of output channels of both convolutions.

    Returns:
        Output tensor of the second ReLU. Height and width are unchanged,
        because both convolutions use ``padding="same"`` with unit stride.
    """
    out = x
    # Two identical stages; layers are created in the same order
    # (Conv, BN, ReLU, Conv, BN, ReLU) as an unrolled version would.
    for _ in range(2):
        out = Conv2D(num_filters, (3, 3), padding="same")(out)
        out = BatchNormalization()(out)
        out = Activation("relu")(out)
    return out

def aspp_block(x, num_filters, atrous_rates=(6, 12, 18)):
    """Atrous Spatial Pyramid Pooling (ASPP) block.

    The block has two parts:

    * The block input plus one dilated 3x3 Conv->BN->ReLU branch per rate in
      ``atrous_rates`` are concatenated and fused by a 1x1 Conv->BN->ReLU.
    * An image-level pooling branch (global average pool -> 1x1 Conv->BN->ReLU
      -> bilinear upsample back to the input's height/width) is added to the
      fused result.

    Args:
        x: 4-D channels-last Keras tensor whose height and width are known
            statically.
        num_filters: Output channels of every branch and of the block.
        atrous_rates: Dilation rates for the 3x3 branches. Defaults to
            (6, 12, 18).

    Returns:
        Tensor with the same height/width as ``x`` and ``num_filters`` channels.

    Raises:
        ValueError: If the height or width of ``x`` is not statically known,
            so the pooling branch cannot be upsampled to match it.
    """
    # Size the pooling-branch upsample from the actual feature map instead of
    # a global input size, so the final Add always sees matching shapes.
    height, width = x.shape[1], x.shape[2]
    if height is None or width is None:
        raise ValueError(
            "aspp_block needs a static spatial shape to upsample the "
            f"image-pooling branch, got {tuple(x.shape)}"
        )

    # Image-level pooling branch: (B, C) -> (B, 1, 1, C) -> 1x1 conv -> (B, H, W, F).
    # -1 lets Reshape take the channel count from x instead of assuming it
    # equals num_filters.
    aspp_pooling = GlobalAveragePooling2D()(x)
    aspp_pooling = Reshape((1, 1, -1))(aspp_pooling)
    aspp_pooling = Conv2D(num_filters, (1, 1), padding="same", use_bias=False)(aspp_pooling)
    aspp_pooling = BatchNormalization()(aspp_pooling)
    aspp_pooling = Activation("relu")(aspp_pooling)
    aspp_pooling = UpSampling2D((height, width), interpolation="bilinear")(aspp_pooling)

    # NOTE: with build_model's 256x256 input the map here is 16x16. At rate 18,
    # every off-center tap of a 3x3 kernel lands in the zero padding, so that
    # branch behaves like a 1x1 conv. Pass smaller rates for small feature maps.
    concat_layers = [x]
    for rate in atrous_rates:
        aspp_branch = Conv2D(num_filters, (3, 3), padding="same", dilation_rate=rate)(x)
        aspp_branch = BatchNormalization()(aspp_branch)
        aspp_branch = Activation("relu")(aspp_branch)
        concat_layers.append(aspp_branch)

    # Fuse the input and the atrous branches back down to num_filters channels.
    x = Concatenate()(concat_layers)
    x = Conv2D(num_filters, (1, 1), padding="same", use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    return Add()([x, aspp_pooling])

def build_model():
    """Build a U-Net-style encoder/decoder with an ASPP bottleneck.

    The model has four parts:

    * Four encoder stages (``conv_block`` + 2x2 max-pool) with 16/32/48/64
      filters.
    * A ``conv_block`` bridge followed by ``aspp_block``.
    * Four decoder stages (2x upsample, concatenate the matching encoder
      skip, ``conv_block``).
    * A 1x1 convolution with sigmoid activation.

    The input is ``(size, size, 3)``, using the module-level ``size``.

    Returns:
        A ``Model`` whose output is a single-channel sigmoid map with the same
        height/width as the input (presumably a per-pixel binary mask).

    Raises:
        ValueError: If ``size`` is not divisible by the encoder's total
            downsampling factor, which would make skip concatenation fail.
    """
    num_filters = [16, 32, 48, 64]

    # Each encoder stage halves the resolution; the decoder must be able to
    # double it back to exactly match every skip connection.
    downsample = 2 ** len(num_filters)
    if size % downsample:
        raise ValueError(f"input size {size} must be divisible by {downsample}")

    # Use the module-level `size` (also what aspp_block's surroundings assume)
    # instead of a shadowing local that could silently drift out of sync.
    inputs = Input((size, size, 3))

    skip_x = []
    x = inputs
    ## Encoder: keep each stage's pre-pool output as a skip connection.
    for f in num_filters:
        x = conv_block(x, f)
        skip_x.append(x)
        x = MaxPooling2D((2, 2))(x)

    ## Bridge
    x = conv_block(x, num_filters[-1])

    ## ASPP Block
    x = aspp_block(x, num_filters[-1])

    ## Decoder: walk the stages deepest-first, pairing each with its skip.
    for f, skip in zip(reversed(num_filters), reversed(skip_x)):
        x = UpSampling2D((2, 2))(x)
        x = Concatenate()([x, skip])
        x = conv_block(x, f)

    ## Output: one sigmoid channel per pixel.
    x = Conv2D(1, (1, 1), padding="same")(x)
    x = Activation("sigmoid")(x)

    return Model(inputs, x)

if __name__ == "__main__":
    # Build the network and print its layer-by-layer summary. `model` is kept
    # as a module-level name because later notebook cells may use it.
    model = build_model()
    model.summary()