**Attention UNet Architecture**

Description for Code:

This file implements an Attention UNet architecture tailored for geospatial image segmentation tasks. It integrates attention mechanisms to enhance feature localization, which is particularly useful for complex datasets such as satellite images and DEM layers. The code is structured to handle multi-modal inputs and is optimized for tasks such as flood, landslide, and glacier-extent segmentation. It is configurable for experimentation with different backbone networks (e.g., ResNet50) and loss functions (e.g., focal loss, Dice loss).


In [2]:
from tensorflow.keras.layers import Conv2D, Activation, concatenate, Conv2DTranspose,MaxPooling2D,Input,Cropping2D,Lambda,Dropout,BatchNormalization,Add,Multiply,UpSampling2D
from tensorflow.keras.models import Model
import tensorflow.keras.backend as K
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras.metrics import MeanIoU
#from focal_loss import BinaryFocalLoss

In [4]:
#Attention Block
def attention_block(g, x, filter_no):
    """Build an additive attention gate that re-weights an encoder skip tensor.

    Both inputs are projected to ``filter_no`` channels with 1x1 convolutions,
    summed, passed through ReLU and collapsed to a single-channel sigmoid map.
    That map scales ``x``; the scaled tensor then goes through one more 1x1
    convolution and batch normalisation.

    Args:
        g: Gating tensor coming from the decoder path.
        x: Skip-connection tensor coming from the encoder path. It must
            already have the same spatial size as ``g``: the two projections
            are combined with ``Add`` and no resampling is done here.
        filter_no: Number of channels used for the intermediate projections
            and for the returned tensor.

    Returns:
        The attention-weighted ``x`` after a 1x1 convolution and batch
        normalisation, with ``filter_no`` channels.
    """
    # Project the gating signal and the skip features into a common
    # filter_no-channel space so they can be added element-wise.
    theta_g = Conv2D(
        filter_no,
        kernel_size=(1, 1),
        kernel_initializer='he_normal',
        padding='same',
    )(g)
    phi_x = Conv2D(
        filter_no,
        kernel_size=(1, 1),
        strides=(1, 1),
        kernel_initializer='he_normal',
        padding='same',
    )(x)

    # Additive attention: sum the projections, then apply the non-linearity.
    combined = Add()([theta_g, phi_x])
    activated = Activation('relu')(combined)

    # Collapse to one channel and squash to (0, 1) to obtain per-pixel
    # attention coefficients.
    attention = Conv2D(
        1,
        kernel_size=(1, 1),
        kernel_initializer='he_normal',
        padding='same',
    )(activated)
    attention_sig = Activation('sigmoid')(attention)

    # Scale the skip features by the single-channel attention map before
    # they are used as the skip connection.
    weighted_x = Multiply()([attention_sig, x])

    # Final 1x1 refinement followed by normalisation.
    refined_output = Conv2D(
        filter_no,
        (1, 1),
        kernel_initializer='he_normal',
        padding='same',
    )(weighted_x)
    refined_output = BatchNormalization()(refined_output)

    return refined_output

In [5]:
# The Modified Attention_UNet Architecture
#
# Builds a 4-level Attention U-Net for single-channel (binary) segmentation,
# compiles it with Adam + binary cross-entropy and prints the layer summary.

img_width = 256
img_height = 256
bands = 3


def _conv_block(tensor, filters, dropout_rate):
    """Apply Conv-BN-ReLU, dropout, then Conv-BN-ReLU using 3x3 kernels."""
    tensor = Conv2D(filters, (3, 3), kernel_initializer='he_normal', padding='same')(tensor)
    tensor = BatchNormalization(axis=3)(tensor)
    tensor = Activation('relu')(tensor)
    tensor = Dropout(dropout_rate)(tensor)
    tensor = Conv2D(filters, (3, 3), kernel_initializer='he_normal', padding='same')(tensor)
    tensor = BatchNormalization(axis=3)(tensor)
    return Activation('relu')(tensor)


# ENCODER
# Channels-last layout is (height, width, channels).
# NOTE: the name `input` shadows the builtin; it is kept unchanged in case
# later cells refer to it.
input = Input((img_height, img_width, bands))

# Encoder block 1
s1 = _conv_block(input, 16, 0.1)
p1 = MaxPooling2D(pool_size=(2, 2))(s1)

# Encoder block 2
s2 = _conv_block(p1, 32, 0.1)
p2 = MaxPooling2D(pool_size=(2, 2))(s2)

# Encoder block 3
s3 = _conv_block(p2, 64, 0.2)
p3 = MaxPooling2D(pool_size=(2, 2))(s3)

# Encoder block 4
s4 = _conv_block(p3, 128, 0.2)
p4 = MaxPooling2D(pool_size=(2, 2))(s4)

# Base block
s5 = _conv_block(p4, 256, 0.3)

# DECODER
# Each stage: upsample x2, gate the matching encoder features with the
# upsampled tensor, concatenate both, then run a conv block.

# Decoder block 1
u1 = Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(s5)
a1 = attention_block(u1, s4, filter_no=128)
u1 = concatenate([u1, a1])
s6 = _conv_block(u1, 128, 0.2)

# Decoder block 2
u2 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(s6)
a2 = attention_block(u2, s3, filter_no=64)
u2 = concatenate([u2, a2])
s7 = _conv_block(u2, 64, 0.2)

# Decoder block 3
u3 = Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same')(s7)
a3 = attention_block(u3, s2, filter_no=32)
u3 = concatenate([u3, a3])
s8 = _conv_block(u3, 32, 0.1)

# Decoder block 4
u4 = Conv2DTranspose(16, (2, 2), strides=(2, 2), padding='same')(s8)
a4 = attention_block(u4, s1, filter_no=16)
u4 = concatenate([u4, a4])
s9 = _conv_block(u4, 16, 0.1)

# One sigmoid unit per pixel: probability of the foreground class.
output = Conv2D(1, (1, 1), activation='sigmoid')(s9)

model = Model(inputs=input, outputs=output)

# MeanIoU expects integer class ids, so feeding it raw sigmoid probabilities
# gives a misleading score. BinaryIoU thresholds the probabilities first and,
# with both class ids selected, reports the mean IoU over background and
# foreground. The metric name is pinned to MeanIoU's default so that
# history/log keys stay the same.
from tensorflow.keras.metrics import BinaryIoU

mean_iou = BinaryIoU(target_class_ids=[0, 1], threshold=0.5, name='mean_io_u')

# To experiment with focal loss, install `focal_loss` and pass
# BinaryFocalLoss(gamma=2) as `loss` instead.

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy', mean_iou])

model.summary()