In [31]:
import tensorflow as tf
from tensorflow.keras.layers import GlobalAveragePooling2D, GlobalMaxPool2D, Activation, AveragePooling2D, MaxPool2D, Dense, Conv2D,Lambda

In [66]:
class Channel_Attention(tf.keras.layers.Layer) : # Channel attention module assuming the input dimensions to have channels-last order
  """Channel attention: one sigmoid weight per channel.

  The input is squeezed spatially with global average pooling and global max
  pooling. Both pooled vectors go through the same two-layer MLP
  (C -> C // ratio -> C), the two results are summed, and a sigmoid is applied.

  Args:
    C: number of units produced by the second Dense layer, i.e. the length of
      the returned attention vector. It is expected to equal the channel
      count of the input when the result is multiplied back onto the input.
    ratio: reduction ratio of the MLP bottleneck.
    **kwargs: forwarded to `tf.keras.layers.Layer` (e.g. `name`, `dtype`).

  Call input:  4D tensor, channels-last.
  Call output: 2D tensor of shape (batch, C) with values in (0, 1).
  """
  def __init__(self,C,ratio,**kwargs) :
    super(Channel_Attention,self).__init__(**kwargs)
    self.C = C
    self.ratio = ratio
    # `C/ratio` is a float in Python 3; Dense expects an integer unit count.
    # Floor-divide and clamp to 1 so the bottleneck never collapses to zero
    # units when C < ratio.
    hidden_units = max(int(C // ratio), 1)
    self.avg_pool = GlobalAveragePooling2D()
    self.max_pool = GlobalMaxPool2D()
    self.activation = Activation('sigmoid')
    # fc1/fc2 form the MLP that is shared by the average- and max-pooled paths.
    self.fc1 = Dense(hidden_units, activation = 'relu')
    self.fc2 = Dense(C)

  def _shared_mlp(self,pooled) :
    """Apply the shared bottleneck MLP to one pooled descriptor."""
    return self.fc2(self.fc1(pooled))

  def call(self,x) :
    avg_branch = self._shared_mlp(self.avg_pool(x))
    max_branch = self._shared_mlp(self.max_pool(x))
    # Merge both descriptors before squashing to (0, 1).
    channel_att = self.activation(tf.math.add(max_branch,avg_branch))
    return channel_att

  def get_config(self) :
    """Return the constructor arguments so the layer can be re-created."""
    config = super(Channel_Attention,self).get_config()
    config.update({'C': self.C, 'ratio': self.ratio})
    return config

In [67]:
class Spatial_Attention(tf.keras.layers.Layer) : # spatial attention module assuming the input dimensions to have channels-last order
  """Spatial attention: one sigmoid weight per spatial location.

  The input is reduced along the channel (last) axis with a mean and a max,
  the two single-channel maps are concatenated, and a single-filter
  convolution with sigmoid activation turns them into the attention map.

  Args:
    kernel_size: kernel size of the convolution; defaults to 7, which matches
      the previously hard-coded (7,7) kernel.
    **kwargs: forwarded to `tf.keras.layers.Layer` (e.g. `name`, `dtype`).

  Call input:  4D tensor, channels-last.
  Call output: tensor with the same batch/spatial dimensions and 1 channel.
  """
  def __init__(self,kernel_size=7,**kwargs) :
    super(Spatial_Attention,self).__init__(**kwargs)
    self.kernel_size = kernel_size
    # 'same' padding keeps the spatial size so the map can be multiplied back
    # onto the features it was computed from.
    self.conv2d = Conv2D(1,kernel_size,padding='same',activation='sigmoid')

  def call(self,x) :
    # Plain TF reductions replace the former Lambda layers around
    # tf.keras.backend.mean/max: Lambda layers complicate serialization and
    # the legacy backend helpers are not guaranteed to exist in newer Keras.
    avg_map = tf.reduce_mean(x,axis=-1,keepdims=True)  # avg-pooling along channel axis
    max_map = tf.reduce_max(x,axis=-1,keepdims=True)   # max-pooling along channel axis
    pooled = tf.concat([avg_map,max_map],axis=-1)
    spatial_att = self.conv2d(pooled)
    return spatial_att

  def get_config(self) :
    """Return the constructor arguments so the layer can be re-created."""
    config = super(Spatial_Attention,self).get_config()
    config.update({'kernel_size': self.kernel_size})
    return config

In [68]:
class CBAM(tf.keras.layers.Layer) : # convolutional block attention module assuming the input dimensions to have channels-last order 
  """Convolutional block attention: channel attention, then spatial attention.

  The input is first rescaled per channel by `Channel_Attention`; the result
  is then rescaled per spatial location by `Spatial_Attention`.

  Args:
    C: channel count handed to `Channel_Attention`; expected to match the
      channel count of the tensors passed to `call`.
    ratio: reduction ratio handed to `Channel_Attention`.
    **kwargs: forwarded to `tf.keras.layers.Layer` (e.g. `name`, `dtype`).
  """
  def __init__(self,C,ratio,**kwargs) :
    super(CBAM,self).__init__(**kwargs)
    self.C = C
    self.ratio = ratio
    self.channel_attention = Channel_Attention(self.C,self.ratio)
    self.spatial_attention = Spatial_Attention()

  def call(self,y,H=None,W=None,C=None) :
    """Apply both attention stages to `y`.

    Args:
      y: 4D feature tensor, channels-last.
      H, W, C: accepted only so existing `cbam(x, H, W, C)` calls keep
        working. They are ignored: broadcasting expands the attention maps
        to the shape of `y`, so no explicit sizes are needed.

    Returns:
      Tensor with the same shape as `y`.
    """
    # (batch, C) -> (batch, 1, 1, C) so it broadcasts over the spatial axes.
    channel_weights = self.channel_attention(y)
    channel_weights = tf.expand_dims(tf.expand_dims(channel_weights, axis=1), axis=2)
    channel_refined = tf.math.multiply(channel_weights,y)
    # The spatial map has a single channel and broadcasts over the channel
    # axis, so it does not have to be tiled either.
    spatial_weights = self.spatial_attention(channel_refined)
    return tf.math.multiply(spatial_weights,channel_refined)

  def get_config(self) :
    """Return the constructor arguments so the layer can be re-created."""
    config = super(CBAM,self).get_config()
    config.update({'C': self.C, 'ratio': self.ratio})
    return config

In [59]:
# Symbolic test input: a 15x15 feature map with 64 channels (channels-last);
# the batch dimension is left unspecified.
inputs = tf.keras.Input(shape=(15,15,64))

In [60]:
# Build one instance of each module for 64 channels with reduction ratio 16,
# matching the channel count of `inputs`.
channel_atts = Channel_Attention(64,16)
spatial_atts = Spatial_Attention()
cbam = CBAM(64,16)

In [61]:
# Shape check: channel attention returns one weight per channel -> (None, 64).
channel_atts(inputs).shape

TensorShape([None, 64])

In [62]:
# Shape check: spatial attention returns a single-channel map -> (None, 15, 15, 1).
spatial_atts(inputs).shape

TensorShape([None, 15, 15, 1])

In [63]:
# Shape check: CBAM preserves the input shape -> (None, 15, 15, 64).
# The trailing arguments are the H, W and C values of `inputs`.
cbam(inputs,15,15,64).shape

TensorShape([None, 15, 15, 64])