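# Spatial Attention

A spatial attention block reweights every pixel of a feature map by a sigmoid mask computed from its channel-wise average- and max-pooled features; below, one such block follows each convolution of a small CNN classifier.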

In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, GlobalAveragePooling2D, Dense

In [2]:
def spatial_attention(input_tensor, kernel_size=7):
    """Apply CBAM-style spatial attention to a feature map.

    The channel axis is collapsed two ways (mean and max). The two resulting
    single-channel maps are concatenated, and a single-filter convolution
    followed by a sigmoid turns them into a per-pixel weight in (0, 1). That
    weight rescales every channel of ``input_tensor``.

    Args:
        input_tensor: Feature map, assumed channels-last (the pooling reduces
            the last axis), e.g. (batch, height, width, channels).
        kernel_size: Kernel size (int or pair of ints) of the convolution that
            mixes neighbouring pooled values into the attention mask. Defaults
            to 7, as in CBAM. A size of 1 sees a single pixel and therefore
            ignores spatial context.

    Returns:
        Tensor with the same shape as ``input_tensor``, multiplied by the
        spatial attention mask.
    """
    # Collapse the channel axis two ways; keepdims leaves a length-1 channel axis.
    avg_pool = tf.reduce_mean(input_tensor, axis=-1, keepdims=True)
    max_pool = tf.reduce_max(input_tensor, axis=-1, keepdims=True)

    # Stack the two descriptors as a 2-channel map for the convolution.
    pooled_features = tf.keras.layers.Concatenate(axis=-1)([avg_pool, max_pool])

    # 'same' padding keeps height/width, so the 1-channel mask broadcasts over
    # all channels of input_tensor in the final multiplication.
    attention_logits = tf.keras.layers.Conv2D(
        filters=1, kernel_size=kernel_size, padding='same'
    )(pooled_features)
    attention_mask = tf.keras.activations.sigmoid(attention_logits)

    return attention_mask * input_tensor

In [3]:
def create_spatial_attention_model(input_shape, num_classes, filters=(64, 128)):
    """Build a small CNN classifier with spatial attention after each conv layer.

    Each entry of ``filters`` adds a 3x3 ReLU ``Conv2D`` layer followed by
    ``spatial_attention``. The result is then global-average-pooled and
    classified by a softmax ``Dense`` layer.

    Args:
        input_shape: Shape of a single input sample without the batch axis,
            e.g. ``(height, width, channels)``.
        num_classes: Number of output classes (units of the softmax layer).
        filters: Filter counts of the successive conv + attention blocks. The
            default ``(64, 128)`` reproduces the original example architecture;
            adjust it to your requirements.

    Returns:
        An uncompiled ``tf.keras.Model`` mapping inputs to class probabilities.
    """
    input_tensor = Input(shape=input_shape)

    x = input_tensor
    for n_filters in filters:
        x = Conv2D(n_filters, (3, 3), padding='same', activation='relu')(x)
        x = spatial_attention(x)

    x = GlobalAveragePooling2D()(x)
    outputs = Dense(num_classes, activation='softmax')(x)

    return tf.keras.Model(inputs=input_tensor, outputs=outputs)

In [4]:
# Build the example model: 224x224 inputs with 12 channels, 10 output classes.
input_shape = (224, 224, 12)  # (height, width, channels) - change to match your data
num_classes = 10  # number of target classes - change to match your task
SA_model = create_spatial_attention_model(input_shape=input_shape,
                                          num_classes=num_classes)

# Print the layer-by-layer architecture and parameter counts.
SA_model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 224, 224, 12 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 224, 224, 64) 6976        input_1[0][0]                    
__________________________________________________________________________________________________
tf.math.reduce_mean (TFOpLambda (None, 224, 224, 1)  0           conv2d[0][0]                     
__________________________________________________________________________________________________
tf.math.reduce_max (TFOpLambda) (None, 224, 224, 1)  0           conv2d[0][0]                     
______________________________________________________________________________________________