In [8]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.layers import Conv2D, GlobalAveragePooling2D, GlobalMaxPooling2D, Dense, Add, Multiply, Reshape, Activation

class CBAM(layers.Layer):
    def __init__(self, ratio=8, channel_attention=True, spatial_attention=True, **kwargs):
        super(CBAM, self).__init__(**kwargs)
        self.ratio = ratio
        self.channel_attention = channel_attention
        self.spatial_attention = spatial_attention

    def build(self, input_shape):
        # Channel Attention (CA)
        if self.channel_attention:
            self.channel_avg_pool = GlobalAveragePooling2D()
            self.channel_max_pool = GlobalMaxPooling2D()
            self.channel_fc1 = Dense(input_shape[-1] // self.ratio, activation='relu')  # Reduce the number of channels
            self.channel_fc2 = Dense(input_shape[-1], activation='sigmoid')  # Return original number of channels

        # Spatial Attention (SA)
        if self.spatial_attention:
            self.spatial_conv = Conv2D(1, (7, 7), padding='same', activation='sigmoid')

    def call(self, inputs):
        x = inputs
        
        # Channel Attention (CA)
        if self.channel_attention:
            avg_pool = self.channel_avg_pool(x)
            max_pool = self.channel_max_pool(x)
            
            # Reshaping to match channel dimensions for dense layer
            avg_pool = Reshape((1, 1, avg_pool.shape[-1]))(avg_pool)
            max_pool = Reshape((1, 1, max_pool.shape[-1]))(max_pool)
            
            avg_out = self.channel_fc2(self.channel_fc1(avg_pool))
            max_out = self.channel_fc2(self.channel_fc1(max_pool))
            
            # Add the two outputs to create the channel attention map
            channel_attention_map = Add()([avg_out, max_out])
            x = Multiply()([x, channel_attention_map])  # Apply the attention map

        # Spatial Attention (SA)
        if self.spatial_attention:
            spatial_attention_map = self.spatial_conv(x)
            x = Multiply()([x, spatial_attention_map])  # Apply the attention map

        return x

In [9]:
# Build CNN model with CBAM
def build_cnn_with_cbam(input_shape=(32, 32, 3), num_classes=10):
    inputs = layers.Input(shape=input_shape)

    # First Convolution Block
    x = Conv2D(32, kernel_size=3, padding='same', activation='relu')(inputs)
    x = CBAM(ratio=8)(x)  # Apply CBAM here

    # Second Convolution Block
    x = Conv2D(64, kernel_size=3, padding='same', activation='relu')(x)
    x = CBAM(ratio=8)(x)  # Apply CBAM here

    # Third Convolution Block
    x = Conv2D(128, kernel_size=3, padding='same', activation='relu')(x)
    x = CBAM(ratio=8)(x)  # Apply CBAM here

    # Global Average Pooling and Fully Connected Layer
    x = GlobalAveragePooling2D()(x)
    x = Dense(256, activation='relu')(x)
    x = Dense(num_classes, activation='softmax')(x)

    # Create model
    model = models.Model(inputs, x)
    return model

In [10]:
# Example usage for CIFAR-10
input_shape = (32, 32, 3)  # CIFAR-10 images are 32x32x3
num_classes = 10  # CIFAR-10 has 10 classes
model = build_cnn_with_cbam(input_shape, num_classes)

# Show model summary
model.summary()

In [None]:
# Compile the model
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Assuming you have a training dataset
# model.fit(train_data, train_labels, epochs=10, batch_size=32)


In [15]:
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical

# Load and preprocess CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Normalize the image data to [0, 1]
x_train, x_test = x_train / 255.0, x_test / 255.0

# One-hot encode the labels
y_train, y_test = to_categorical(y_train, num_classes=10), to_categorical(y_test, num_classes=10)

# Train the model
model.fit(x_train, y_train, epochs=10, batch_size=64, validation_data=(x_test, y_test))


Epoch 1/10
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m220s[0m 275ms/step - accuracy: 0.2241 - loss: 2.0316 - val_accuracy: 0.3940 - val_loss: 1.5916
Epoch 2/10
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m209s[0m 268ms/step - accuracy: 0.4134 - loss: 1.5684 - val_accuracy: 0.4877 - val_loss: 1.3945
Epoch 3/10
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m232s[0m 297ms/step - accuracy: 0.4966 - loss: 1.3756 - val_accuracy: 0.5190 - val_loss: 1.3041
Epoch 4/10
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m333s[0m 426ms/step - accuracy: 0.5420 - loss: 1.2607 - val_accuracy: 0.5638 - val_loss: 1.1812
Epoch 5/10
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m383s[0m 489ms/step - accuracy: 0.5781 - loss: 1.1583 - val_accuracy: 0.5862 - val_loss: 1.1366
Epoch 6/10
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m384s[0m 491ms/step - accuracy: 0.6022 - loss: 1.0997 - val_accuracy: 0.5998 - val_loss: 1.0865
Epoc

<keras.src.callbacks.history.History at 0x227c7ec1ee0>