In [1]:
# Importing necessary libraries
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, GlobalAveragePooling2D, Dense, Activation, Add

In [2]:
def Channel(input_shape, r):
    """Channel-attention branch of CBAM.

    Squeezes the spatial axes with both mean- and max-pooling, passes each
    pooled descriptor through the same two-layer Dense MLP, sums the two
    results, applies a sigmoid, and rescales the input channel-wise.

    Args:
        input_shape: Despite its name, the feature-map *tensor* to refine
            (not a shape tuple). Pooling over axes 1 and 2 and reading
            ``.shape[-1]`` as the channel count assumes a channels-last 4-D
            tensor such as ``(batch, height, width, channels)`` — TODO confirm
            for any caller other than ``create_CBAM_model``.
        r: Reduction ratio of the MLP bottleneck; the hidden layer has
            ``channels // r`` units, clamped to at least 1.

    Returns:
        Tensor with the same shape as ``input_shape``: the input multiplied by
        per-channel sigmoid attention weights.

    Raises:
        ValueError: If the channel dimension is not statically known or if
            ``r`` is not positive.
    """
    channels = input_shape.shape[-1]
    if channels is None:
        raise ValueError("Channel attention requires a statically known channel dimension.")
    if r <= 0:
        raise ValueError(f"Reduction ratio r must be positive, got {r!r}.")

    # Calculate average and maximum values along height and width dimensions.
    # keepdims=True keeps singleton spatial axes so the final weights broadcast
    # back onto the input.
    avg_p = tf.reduce_mean(input_shape, axis=[1, 2], keepdims=True)
    max_p = tf.reduce_max(input_shape, axis=[1, 2], keepdims=True)

    # Bottleneck width. Clamp to >= 1 so inputs with fewer channels than r
    # (e.g. 1 channel with r=2) don't produce a zero-unit hidden layer.
    reduced_features = max(1, int(channels // r))

    # One pair of Dense layers, applied to BOTH pooled descriptors (shared weights).
    dense1 = Dense(reduced_features, activation="relu")
    dense2 = Dense(channels, activation="linear")

    # Pass average and maximum values through the first dense layer
    Dense1_avg = dense1(avg_p)
    Dense1_max = dense1(max_p)

    # Pass the outputs of the first dense layer through the second dense layer
    Dense2_avg = dense2(Dense1_avg)
    Dense2_max = dense2(Dense1_max)

    # Combine both descriptors and squash them to (0, 1) attention weights
    out = Activation('sigmoid')(Dense2_avg + Dense2_max)

    return out * input_shape  # Scale the input channel-wise by the attention weights

In [3]:
def Spatial(input_shape, ks=7):
    """Spatial-attention branch of CBAM.

    Summarises every spatial position by its mean and max across the last
    axis, turns that two-channel map into a single sigmoid-gated attention map
    with a ``ks x ks`` convolution, and multiplies it into the input.

    Args:
        input_shape: The feature-map tensor to refine (despite the name, not a
            shape tuple); channels are assumed to be the last axis — TODO confirm.
        ks: Side length of the square convolution kernel (default 7).

    Returns:
        ``input_shape`` scaled position-wise by the attention map.
    """
    # Per-position statistics over the channel axis, each keeping a size-1 last axis
    pooled = [
        tf.reduce_mean(input_shape, axis=[-1], keepdims=True),
        tf.reduce_max(input_shape, axis=[-1], keepdims=True),
    ]
    stacked = tf.concat(pooled, axis=-1)

    # 'same' padding preserves height/width; a single filter yields one attention map
    attention_map = Conv2D(1, (ks, ks), padding='same', activation='sigmoid')(stacked)

    return attention_map * input_shape

In [4]:
def CBAM(input_shape, ks=7, r=2):
    """Convolutional Block Attention Module, in the variant used by this notebook.

    Applies channel attention, then spatial attention to the channel-refined
    features, then sums the two refined maps and passes the sum through ReLU.

    Args:
        input_shape: Feature-map tensor to refine (a tensor, not a shape tuple).
        ks: Kernel size forwarded to ``Spatial``.
        r: Reduction ratio forwarded to ``Channel``.

    Returns:
        ``relu(channel_refined + spatially_refined)``.

    NOTE(review): the CBAM paper (Woo et al., 2018) returns the spatially
    refined map directly; the residual-style sum + ReLU here is a deviation —
    confirm it is intentional.
    """
    refined_by_channel = Channel(input_shape, r)
    refined_by_space = Spatial(refined_by_channel, ks)
    merged = Add()([refined_by_channel, refined_by_space])
    return Activation('relu')(merged)

In [5]:
def create_CBAM_model(input_shape, num_classes):
    """Build a small Keras classifier with two Conv2D + CBAM stages.

    Architecture: Input -> [Conv2D(64, 3x3, ReLU) -> CBAM]
    -> [Conv2D(128, 3x3, ReLU) -> CBAM] -> GlobalAveragePooling2D
    -> Dense(num_classes, softmax).

    Args:
        input_shape: Shape of one input sample without the batch axis,
            e.g. ``(224, 224, 12)``.
        num_classes: Number of output classes (width of the softmax layer).

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    inputs = Input(shape=input_shape)

    features = inputs
    # Each stage: 3x3 'same' convolution with ReLU, then CBAM refinement
    for filters in (64, 128):
        features = Conv2D(filters, (3, 3), padding='same', activation='relu')(features)
        features = CBAM(features)

    pooled = GlobalAveragePooling2D()(features)
    probabilities = Dense(num_classes, activation='softmax')(pooled)

    return tf.keras.Model(inputs=inputs, outputs=probabilities)

In [6]:
# Example usage: build the model for 224x224 inputs with 12 channels and 10 classes
input_shape = (224, 224, 12)  # (height, width, channels) of one sample — adjust to your data
num_classes = 10  # number of target classes — adjust to your task

CBAM_model = create_CBAM_model(input_shape=input_shape, num_classes=num_classes)
CBAM_model.summary()  # Print a layer-by-layer overview of the model

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 224, 224, 12 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 224, 224, 64) 6976        input_1[0][0]                    
__________________________________________________________________________________________________
tf.math.reduce_mean (TFOpLambda (None, 1, 1, 64)     0           conv2d[0][0]                     
__________________________________________________________________________________________________
tf.math.reduce_max (TFOpLambda) (None, 1, 1, 64)     0           conv2d[0][0]                     
______________________________________________________________________________________________