In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.applications  import ResNet50
from tensorflow.keras import layers
from tensorflow.keras import Model , models
from tensorflow import keras

In [None]:
# ImageNet-pretrained ResNet50 feature extractor; include_top=False drops the classification head.
base_model = ResNet50(
  include_top=False,
  weights='imagenet',
  input_shape=(224, 224, 3),
)

In [None]:
def spatial_attention(tensor, kernel_size=(7, 7)):
  """Apply CBAM-style spatial attention to a feature map.

  Pools `tensor` over its last axis twice (mean and max, keeping that axis
  with size 1) and concatenates the two maps. It then runs a single-filter
  Conv2D with a sigmoid over them and multiplies the resulting map back
  into `tensor`.

  Args:
    tensor: Keras tensor. Pooling is over axis -1, so this assumes
      channels-last input such as (batch, H, W, C) -- TODO confirm for other
      data formats.
    kernel_size: Kernel size of the attention convolution. Defaults to
      (7, 7), the value previously hard-coded here.

  Returns:
    Keras tensor produced by multiplying `tensor` with the single-channel
    attention map.
  """
  # Two per-position descriptors: the mean and the max across channels.
  avg_pool = keras.ops.mean(tensor, axis=-1, keepdims=True)
  max_pool = keras.ops.max(tensor, axis=-1, keepdims=True)
  pooled = layers.Concatenate(axis=-1)([avg_pool, max_pool])
  # 'same' padding keeps the spatial size so the map lines up with `tensor`.
  attention = layers.Conv2D(filters=1, kernel_size=kernel_size, padding='same')(pooled)
  attention = layers.Activation(activation='sigmoid')(attention)
  # The single-channel map is broadcast across every channel of `tensor`.
  return layers.Multiply()([tensor, attention])
#-----
def channel_attention(tensor, rr=16):
  """Apply CBAM-style channel attention to a feature map.

  Global-average- and global-max-pooled descriptors of `tensor` each go
  through ONE MLP whose weights are shared by both branches: its Dense
  layers are created once and applied twice. The two outputs are summed,
  passed through a sigmoid, and multiplied back into `tensor`.

  Args:
    tensor: Keras tensor whose last axis is treated as channels (assumed
      channels-last 4D input, e.g. (batch, H, W, C) -- TODO confirm).
    rr: Reduction ratio for the MLP hidden layer, which gets C // rr units
      (at least 1).

  Returns:
    Keras tensor produced by multiplying `tensor` with the per-channel
    attention weights.
  """
  n_channels = tensor.shape[-1]
  # Guard against C < rr, which would otherwise request a 0-unit Dense layer.
  hidden_units = max(1, n_channels // rr)
  # Build the MLP layers ONCE so both pooling branches reuse the same weights.
  # Constructing new Dense layers inside the helper on every call would give
  # each branch its own, unshared MLP.
  mlp_hidden = layers.Dense(hidden_units, activation='relu', kernel_initializer='he_normal')
  mlp_output = layers.Dense(n_channels)

  def shared_MLP(input_T):
    return mlp_output(mlp_hidden(input_T))

  # keepdims=True keeps singleton spatial axes so the weights broadcast over H and W.
  gap = layers.GlobalAveragePooling2D(keepdims=True)(tensor)
  gmp = layers.GlobalMaxPooling2D(keepdims=True)(tensor)

  summed = layers.Add()([shared_MLP(gap), shared_MLP(gmp)])
  attention = layers.Activation(activation='sigmoid')(summed)
  return layers.Multiply()([tensor, attention])
#------
def CBAM(tensor, rr=16):
  """Convolutional Block Attention Module: channel attention, then spatial attention.

  Args:
    tensor: Keras feature-map tensor, passed to `channel_attention`.
    rr: Reduction ratio forwarded to `channel_attention`. Defaults to 16,
      which matches `channel_attention`'s own default.

  Returns:
    Keras tensor refined first along channels, then along spatial positions.
  """
  # Order matters: the channel-refined map is the input to spatial attention.
  channel_refined = channel_attention(tensor, rr=rr)
  return spatial_attention(channel_refined)

In [None]:
# Based on the pretrained implementation of the ResNet models in the Keras package.
def resnet_cbam(base_model, target_layer="conv5_block3_3_bn", freeze_pretrained=True):
  """Clone `base_model`, inserting a CBAM block after the layer named `target_layer`.

  The model is rebuilt with `models.clone_model`, whose `call_function` wraps
  the output of `target_layer` in `CBAM`. Afterwards, every cloned layer
  whose name also exists in `base_model` receives that layer's weights and,
  by default, is frozen. Layers added by CBAM have no counterpart, so they
  keep their fresh initialization and stay trainable.

  Args:
    base_model: Functional Keras model, e.g. the ResNet50 backbone built above.
    target_layer: Name of the layer whose output receives CBAM. The default
      is a layer name from Keras's ResNet50; change it for other backbones.
    freeze_pretrained: If True (the previous, hard-coded behavior), set
      `trainable = False` on every layer whose weights were copied.

  Returns:
    The cloned Keras model with CBAM inserted.

  Raises:
    ValueError: if no layer named `target_layer` was visited during cloning,
      or (raised by `set_weights`) if copied weights do not fit the cloned
      layer.
  """
  # Records whether the target layer was actually visited during cloning.
  inserted = []

  def insert_cbam(layer, *args, **kwargs):
    out = layer(*args, **kwargs)
    if layer.name == target_layer:
      out = CBAM(out)
      inserted.append(layer.name)
    return out

  resnet_add_cbam = models.clone_model(
    base_model,
    input_tensors=base_model.input,
    call_function=insert_cbam,
  )
  if not inserted:
    raise ValueError(
      f"Layer {target_layer!r} not found in base_model; CBAM was not inserted."
    )

  # clone_model creates new layers (and thus new weights), so copy the
  # pretrained weights over by layer name.
  weight_dict = {layer.name: layer.get_weights() for layer in base_model.layers}
  for layer in resnet_add_cbam.layers:
    # Layers created by CBAM have no pretrained counterpart: skip them.
    # This explicit membership test replaces a bare `except: pass`, which
    # also hid real failures such as weight-shape mismatches.
    if layer.name not in weight_dict:
      continue
    layer.set_weights(weight_dict[layer.name])
    if freeze_pretrained:
      layer.trainable = False
  return resnet_add_cbam

In [None]:
def top_cls(backbone, hidden_unit=32, name='resnet50_cbam'):
  """Attach a binary-classification head to `backbone`.

  The head is GlobalAveragePooling2D -> Dense(hidden_unit, relu) ->
  Dense(1, sigmoid).

  Args:
    backbone: Functional Keras model whose output is a 4D feature map, as
      required by GlobalAveragePooling2D (assumed (batch, H, W, C) -- TODO
      confirm for other data formats).
    hidden_unit: Width of the hidden Dense layer.
    name: Name of the returned Model. Defaults to 'resnet50_cbam', the value
      previously hard-coded, so existing callers are unaffected.

  Returns:
    Keras Model mapping `backbone.input` to a single sigmoid output.
  """
  features = layers.GlobalAveragePooling2D()(backbone.output)
  hidden = layers.Dense(hidden_unit, activation='relu', kernel_initializer='he_normal')(features)
  probability = layers.Dense(1, activation='sigmoid')(hidden)
  return Model(inputs=backbone.input, outputs=probability, name=name)

In [None]:
# Keep the CBAM backbone and the full classifier under separate names, so each
# name refers to one kind of model and the backbone remains inspectable.
resnet50_cbam_backbone = resnet_cbam(base_model)
resnet50_cbam = top_cls(resnet50_cbam_backbone, hidden_unit=128)

In [None]:
# show_trainable adds a column confirming that the copied pretrained layers are
# frozen and only the CBAM and head layers will train.
resnet50_cbam.summary(show_trainable=True)