# Import Libraries

In [1]:
import tensorflow as tf
from tensorflow.keras.layers import GlobalAveragePooling2D, GlobalMaxPooling2D, Reshape, Dense, Input
from tensorflow.keras.layers import Activation, Concatenate, Conv2D, Multiply

# Channel Attention Module

In [2]:
def Channel_Attention_Module(input, ratio=16):
  """Channel attention sub-module of CBAM.

  Squeezes the spatial dimensions with both global average pooling and
  global max pooling, feeds each pooled descriptor through a shared
  two-layer MLP, sums the two results and applies a sigmoid to obtain one
  weight per channel. The input is then rescaled channel-wise by those
  weights.

  Args:
    input: rank-4 feature map; the channel count is read from the last axis
      (assumes channels-last (batch, height, width, channels) layout — the
      pooling layers are used with their default data_format).
    ratio: reduction ratio of the shared MLP's hidden layer.

  Returns:
    Tensor with the same shape as `input`.
  """
  # Unpacking also enforces that `input` is rank 4.
  _, _, _, channel = input.shape

  # Guard against channel < ratio, which would otherwise produce a
  # zero-width hidden layer (no usable attention signal).
  hidden_units = max(channel // ratio, 1)

  # Shared MLP: the SAME two Dense layers are applied to both pooled branches.
  shared_dense_1 = Dense(hidden_units, activation='relu', use_bias=False)
  shared_dense_2 = Dense(channel, use_bias=False)

  # Global Average Pooling branch -> (batch, channel)
  avg_branch = shared_dense_2(shared_dense_1(GlobalAveragePooling2D()(input)))

  # Global Max Pooling branch -> (batch, channel)
  max_branch = shared_dense_2(shared_dense_1(GlobalMaxPooling2D()(input)))

  # Add both branches and squash to (0, 1). Then reshape to
  # (batch, 1, 1, channel) so the weights broadcast explicitly over H and W.
  weights = Activation('sigmoid')(avg_branch + max_branch)
  weights = Reshape((1, 1, channel))(weights)

  return Multiply()([input, weights])

# Spatial Attention Module

In [3]:
def Spatial_Attention_Module(input, kernel_size=7):
  """Spatial attention sub-module of CBAM.

  Pools the input across its last (channel) axis with both mean and max,
  stacks the two single-channel maps, and convolves them into a single
  sigmoid-activated map. That map rescales every spatial position of the
  input.

  Args:
    input: rank-4 feature map; pooling is done over the last axis
      (assumes channels-last layout — TODO confirm for all callers).
    kernel_size: size of the convolution kernel (default 7, unchanged from
      the original hard-coded value).

  Returns:
    Tensor with the same shape as `input`.
  """
  # Average Pooling across channels. keepdims=True keeps a trailing axis of
  # size 1. (Previously the pooled result was discarded and `input` itself
  # was expanded, producing a rank-5 tensor.)
  avepool = tf.reduce_mean(input, axis=-1, keepdims=True)

  # Max Pooling across channels -> (batch, H, W, 1)
  maxpool = tf.reduce_max(input, axis=-1, keepdims=True)

  # Concatenate Average and Max Pooling -> (batch, H, W, 2)
  concat = Concatenate()([avepool, maxpool])

  # One output filter with sigmoid: a (batch, H, W, 1) attention map.
  conv = Conv2D(1, kernel_size=kernel_size, padding='same', activation='sigmoid')(concat)

  # The size-1 channel axis broadcasts across all input channels.
  output = Multiply()([input, conv])

  return output

# CBAM (Convolutional Block Attention Module)

In [4]:
def CBAM(input, ratio=16):
  """Convolutional Block Attention Module.

  Applies channel attention first, then spatial attention to the
  channel-refined features.

  Args:
    input: feature map passed to Channel_Attention_Module.
    ratio: reduction ratio forwarded to Channel_Attention_Module
      (default 16, the same value the original call used implicitly).

  Returns:
    The result of Spatial_Attention_Module applied to the output of
    Channel_Attention_Module.
  """
  channel_refined = Channel_Attention_Module(input, ratio=ratio)
  attention = Spatial_Attention_Module(channel_refined)

  return attention