### Import Libraries

In [28]:
import tensorflow as tf
from tensorflow.keras.models import Model
import tensorflow.keras.layers as L
from tensorflow.keras.layers import PReLU

Reference for the available Keras activation functions: \
https://www.tensorflow.org/api_docs/python/tf/keras/activations

### Building Blocks

In [29]:
# ReLU activation
def conv_block_relu(x, num_filters):
    """Apply two identical (3x3 Conv2D -> BatchNormalization -> ReLU) stages.

    Args:
        x: input Keras tensor.
        num_filters: number of filters used by both convolutions.

    Returns:
        The output tensor of the second stage.
    """
    out = x
    # Each iteration creates fresh layers, so the two stages do not share weights.
    for _ in range(2):
        out = L.Conv2D(num_filters, 3, padding="same")(out)
        out = L.BatchNormalization()(out)
        out = L.Activation("relu")(out)
    return out

In [30]:
# PReLU activation
def conv_block_prelu(x, num_filters):
    """Apply two identical (3x3 Conv2D -> BatchNormalization -> PReLU) stages.

    Same structure as ``conv_block_relu`` but with a learnable PReLU
    activation. A new ``PReLU`` layer is created per stage, so each stage
    learns its own slope parameters.

    Args:
        x: input Keras tensor.
        num_filters: number of filters used by both convolutions.

    Returns:
        The output tensor of the second stage.
    """
    out = x
    for _ in range(2):
        out = L.Conv2D(num_filters, 3, padding="same")(out)
        out = L.BatchNormalization()(out)
        out = PReLU()(out)
    return out

In [31]:
def encoder_block(x, num_filters):
    """Encoder stage: a PReLU conv block followed by 2x2 max pooling.

    Args:
        x: input Keras tensor.
        num_filters: number of filters for the conv block.

    Returns:
        A ``(features, pooled)`` tuple. ``features`` is the conv-block output,
        which ``AttentionUnet`` passes to the decoder as a skip connection.
        ``pooled`` is the max-pooled tensor fed to the next stage.
    """
    features = conv_block_prelu(x, num_filters)
    pooled = L.MaxPooling2D((2, 2))(features)
    return features, pooled

In [32]:
def attention_gate(g, s, num_filters):
    """Additive attention gate applied to a skip-connection tensor.

    The gating signal ``g`` and the skip tensor ``s`` are each projected with
    a 1x1 Conv2D + BatchNormalization to ``num_filters`` channels. The
    projections are summed and passed through ReLU. The result is projected
    again with a 1x1 Conv2D and squashed by a sigmoid. These coefficients
    then multiply ``s`` element-wise.

    NOTE(review): the final projection has ``num_filters`` channels
    (per-channel coefficients), so ``s`` presumably needs ``num_filters``
    channels, or a broadcast-compatible shape, for the product. ``g`` and
    ``s`` also need matching spatial sizes for the sum. Confirm at call sites.

    Args:
        g: gating tensor (in this notebook, the upsampled decoder features).
        s: skip-connection tensor to be re-weighted.
        num_filters: channel count of the intermediate projections.

    Returns:
        ``s`` scaled by the sigmoid attention coefficients.
    """
    def _project(t):
        # 1x1 conv + BN projection; a fresh pair of layers per call.
        t = L.Conv2D(num_filters, 1, padding="same")(t)
        return L.BatchNormalization()(t)

    gate = _project(g)   # built first to keep the original layer creation order
    skip = _project(s)

    coeffs = L.Activation("relu")(gate + skip)
    coeffs = L.Conv2D(num_filters, 1, padding="same")(coeffs)
    coeffs = L.Activation("sigmoid")(coeffs)
    return coeffs * s

In [33]:
def decoder_block(x, s, num_filters):
    """Decoder stage: upsample, gate the skip connection, concatenate, convolve.

    Args:
        x: tensor from the previous (coarser) decoder or bottleneck stage.
        s: skip-connection tensor from the matching encoder stage.
        num_filters: filters for both the attention gate and the conv block.

    Returns:
        Output of ``conv_block_prelu`` applied to ``[upsampled x, gated s]``.
    """
    # Bilinear UpSampling2D with the layer's default size.
    upsampled = L.UpSampling2D(interpolation="bilinear")(x)
    # Re-weight the skip tensor using the upsampled features as the gate.
    gated_skip = attention_gate(upsampled, s, num_filters)
    merged = L.Concatenate()([upsampled, gated_skip])
    return conv_block_prelu(merged, num_filters)

In [34]:
def AttentionUnet(input_shape, num_classes):
    """Build a three-level Attention U-Net as a Keras functional ``Model``.

    Architecture:
      * three encoder stages with 64 / 128 / 256 filters (PReLU conv block + 2x2 max pool);
      * a 512-filter PReLU conv block as the bottleneck;
      * three decoder stages with 256 / 128 / 64 filters, each upsampling and
        concatenating an attention-gated skip connection;
      * a 1x1 Conv2D with softmax producing ``num_classes`` channels.

    Args:
        input_shape: shape passed to ``L.Input`` (excludes the batch dimension).
            Presumably ``(height, width, channels)`` with height/width divisible
            by 8, so the three poolings and upsamplings line up. Confirm.
        num_classes: number of output channels (classes) of the softmax head.

    Returns:
        An uncompiled ``Model`` named "Attention-UNet".
    """
    inputs = L.Input(input_shape)

    # Encoder: each stage returns (skip features, pooled features).
    s1, p1 = encoder_block(inputs, 64)
    s2, p2 = encoder_block(p1, 128)
    s3, p3 = encoder_block(p2, 256)

    # Bottleneck: use the same PReLU conv block as the encoder/decoder stages.
    # No plain `conv_block` is defined in this notebook, so referencing it
    # only worked with a stale kernel and fails on Restart & Run All.
    b1 = conv_block_prelu(p3, 512)

    # Decoder: upsample and merge with attention-gated skip connections.
    d1 = decoder_block(b1, s3, 256)
    d2 = decoder_block(d1, s2, 128)
    d3 = decoder_block(d2, s1, 64)

    # Output: per-pixel softmax over the classes.
    outputs = L.Conv2D(num_classes, 1, padding="same", activation="softmax")(d3)

    return Model(inputs, outputs, name="Attention-UNet")

In [35]:
# Build the model and print a layer-by-layer summary.
# NOTE(review): the last dimension (256) is the number of input channels;
# RGB images would use 3. Confirm this matches the dataset.
input_shape = (256, 256, 256)
model = AttentionUnet(input_shape=input_shape, num_classes=11)  # 11 output classes
model.summary()

Model: "Attention-UNet"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_10 (InputLayer)          [(None, 256, 256, 2  0           []                               
                                56)]                                                              
                                                                                                  
 conv2d_113 (Conv2D)            (None, 256, 256, 64  147520      ['input_10[0][0]']               
                                )                                                                 
                                                                                                  
 batch_normalization_96 (BatchN  (None, 256, 256, 64  256        ['conv2d_113[0][0]']             
 ormalization)                  )                                                    