<a href="https://colab.research.google.com/github/Y-Noor/attention-unet/blob/main/attentionunet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [17]:
import tensorflow as tf
from tensorflow.keras import layers, models, backend as K

In [30]:
def attention_gate(x, g, inter_shape):
    """Scale the skip-connection features ``x`` by an attention map driven by ``g``.

    Additive attention:

    1. ``x`` and ``g`` are each projected to ``inter_shape`` channels by a 1x1
       convolution.
    2. The projections are summed and passed through ReLU.
    3. The result is reduced to a single channel by another 1x1 convolution and
       squashed with a sigmoid.
    4. That one-channel map multiplies ``x``.

    Args:
        x: Feature map from the encoder (the skip connection).
        g: Gating signal from the decoder. Its projection is added directly to
            the projection of ``x``, so the spatial sizes must match.
        inter_shape: Number of filters in the intermediate 1x1 projections.

    Returns:
        ``x`` multiplied element-wise by the sigmoid attention map. The
        single-channel map is broadcast across ``x``'s channels (presumably
        channels-last, the Keras default — TODO confirm if data_format changes).
    """
    def project(tensor, filters):
        # 1x1 conv with stride 1 and 'same' padding: a per-pixel linear map.
        return layers.Conv2D(filters, kernel_size=1, strides=1, padding='same')(tensor)

    # The list is evaluated left to right, so the x projection is built before g's.
    summed = layers.add([project(x, inter_shape), project(g, inter_shape)])
    hidden = layers.Activation('relu')(summed)

    # Collapse to one channel, then map to (0, 1) attention weights.
    alpha = layers.Activation('sigmoid')(project(hidden, 1))

    return layers.multiply([x, alpha])


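From what I understand, each encoder level follows the pattern:

> conv

> conv -> send to the matching decoder level through an attention gate

> pool

and each decoder level follows:
> upsample (stride-2 transposed convolution)

> gate the encoder features with the attention gate, using the upsampled signal as the gating signal

> concatenate the upsampled signal with the gated encoder features, then conv -> conv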


In [31]:
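# Does not work: Keras may be unable to infer the output shape of a Lambda layer that wraps tf.image.resize, so the output_shape argument must be passed to the Lambda layer explicitly.
# Note: as written, size=(n//2, n//2) halves the resolution (downsampling rather than upsampling), and `n` is not defined in this cell.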
# def upsampling(x):
#     return tf.image.resize(x, size=(n//2, n//2), method='nearest')

# # upsampled = layers.Lambda(upsampling)(input_tensor)


In [34]:
def _double_conv(tensor, filters):
    """Apply two 3x3, 'same'-padded, ReLU convolutions with ``filters`` channels."""
    for _ in range(2):
        tensor = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
    return tensor


def attention_unet(input_shape, num_classes, base_filters=64, depth=4):
    """Build an Attention U-Net segmentation model.

    The encoder applies ``depth`` levels of (conv, conv, 2x2 max-pool). The
    filter count starts at ``base_filters`` and doubles per level, followed by
    a bottleneck double conv. Each decoder level does the following:

    1. Upsamples with a stride-2 transposed convolution.
    2. Gates the matching encoder features with ``attention_gate``, using the
       upsampled tensor as the gating signal.
    3. Concatenates the upsampled and gated tensors.
    4. Applies another double conv.

    The defaults (``base_filters=64``, ``depth=4``) build the 64-128-256-512
    encoder / 1024 bottleneck architecture.

    Args:
        input_shape: Input shape without the batch axis, e.g. ``(128, 128, 3)``.
            Assumes channels-last. Height and width must be divisible by
            ``2 ** depth`` (``None`` dims are not checked) so that every
            upsampled tensor matches its skip connection.
        num_classes: Number of output channels. With 1 the head uses a sigmoid
            (binary segmentation; pair with binary cross-entropy). Otherwise it
            uses a per-pixel softmax over the class channels.
        base_filters: Filters at the first encoder level.
        depth: Number of pooling / upsampling levels.

    Returns:
        An uncompiled ``tf.keras`` ``Model``.

    Raises:
        ValueError: If a known height/width is not divisible by ``2 ** depth``.
    """
    factor = 2 ** depth
    for dim in tuple(input_shape)[:2]:
        # Otherwise the transposed-conv output and the skip tensor disagree in
        # size, which only surfaces later as an opaque add/concatenate error.
        if dim is not None and dim % factor:
            raise ValueError(
                f"input height/width must be divisible by {factor} (2 ** depth); "
                f"got input_shape={tuple(input_shape)}"
            )

    inputs = layers.Input(input_shape)

    # Encoder: keep each level's pre-pooling features as a skip connection.
    skips = []
    x = inputs
    for level in range(depth):
        x = _double_conv(x, base_filters * 2 ** level)
        skips.append(x)
        x = layers.MaxPooling2D(pool_size=(2, 2))(x)

    # Lowest depth (bottleneck).
    x = _double_conv(x, base_filters * 2 ** depth)

    # Decoder: upsample, attention-gate the skip, concatenate, double conv.
    for level in reversed(range(depth)):
        filters = base_filters * 2 ** level
        up = layers.Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
        gated = attention_gate(skips[level], up, filters)
        x = layers.concatenate([up, gated], axis=3)
        x = _double_conv(x, filters)

    # Softmax over a single channel is identically 1.0, so a one-class head
    # must use a sigmoid to produce a learnable per-pixel probability.
    activation = 'sigmoid' if num_classes == 1 else 'softmax'
    outputs = layers.Conv2D(num_classes, (1, 1), activation=activation)(x)

    return models.Model(inputs=inputs, outputs=outputs)

In [35]:
# Single-channel (binary) segmentation model for 128x128 RGB inputs.
model = attention_unet(input_shape=(128, 128, 3), num_classes=1)

# One output channel pairs with binary cross-entropy.
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

model.summary()
