In [17]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.layers import Layer
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical

In [18]:
# Fetch the MNIST train/test splits bundled with Keras.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Per split: cast pixels to float32, divide by 255.0, then append a trailing
# channel axis so the images are rank-4 (samples, rows, cols, 1).
x_train = (x_train.astype('float32') / 255.0)[..., tf.newaxis]
x_test = (x_test.astype('float32') / 255.0)[..., tf.newaxis]

# One-hot encode the integer labels over 10 classes.
y_train = to_categorical(y_train, num_classes=10)
y_test = to_categorical(y_test, num_classes=10)

In [19]:
class AttentionModule(Layer):
    """Spatial self-attention over a channels-last feature map, with a residual add.

    Every spatial position is projected to a query, key and value through 1x1
    convolutions with a reduced channel count. Scaled dot-product attention is
    computed across all ``height * width`` positions, the attended values are
    projected back to the input channel count by a final 1x1 convolution, and
    the result is added to the layer input.

    Args:
        reduction: Positive integer divisor applied to the input channel count
            to obtain the query/key/value width. The width is clamped to at
            least 1 so inputs with fewer than ``reduction`` channels still
            work. Defaults to 8.
        **kwargs: Forwarded unchanged to ``Layer.__init__``.

    Input:  tensor of shape ``[batch, height, width, channels]``.
    Output: tensor with the same shape as the input.

    Raises:
        ValueError: if ``reduction`` is smaller than 1, or (at build time) if
            the channel dimension of the input is not statically known.
    """

    def __init__(self, reduction=8, **kwargs):
        super().__init__(**kwargs)
        reduction = int(reduction)
        if reduction < 1:
            raise ValueError(
                f"`reduction` must be a positive integer, got {reduction}."
            )
        self.reduction = reduction

    def build(self, input_shape):
        """Create and build the four 1x1 projection convolutions."""
        channels = input_shape[-1]
        if channels is None:
            raise ValueError(
                "AttentionModule requires a statically known channel dimension."
            )
        self.channels = int(channels)
        # Clamp so that inputs with fewer channels than `reduction` do not
        # produce a convolution with zero filters.
        self.reduced_channels = max(1, self.channels // self.reduction)

        self.query_conv = layers.Conv2D(self.reduced_channels, kernel_size=1)
        self.key_conv = layers.Conv2D(self.reduced_channels, kernel_size=1)
        self.value_conv = layers.Conv2D(self.reduced_channels, kernel_size=1)
        self.output_conv = layers.Conv2D(self.channels, kernel_size=1)

        # Build the sub-layers here so that their weights exist as soon as this
        # layer reports itself as built, instead of appearing lazily on the
        # first call.
        full_shape = tuple(input_shape)
        reduced_shape = full_shape[:-1] + (self.reduced_channels,)
        for projection in (self.query_conv, self.key_conv, self.value_conv):
            projection.build(full_shape)
        self.output_conv.build(reduced_shape)

        super().build(input_shape)

    def call(self, inputs):
        """Apply self-attention across spatial positions and add the input back."""
        dynamic_shape = tf.shape(inputs)
        batch_size = dynamic_shape[0]
        # Prefer static spatial sizes (keeps static shape information on the
        # output); fall back to the runtime shape when a dimension is unknown.
        static_height, static_width = inputs.shape[1], inputs.shape[2]
        height = dynamic_shape[1] if static_height is None else static_height
        width = dynamic_shape[2] if static_width is None else static_width

        # Flatten the spatial grid into a sequence: [batch, h*w, reduced].
        sequence_shape = [batch_size, height * width, self.reduced_channels]
        query_seq = tf.reshape(self.query_conv(inputs), sequence_shape)
        key_seq = tf.reshape(self.key_conv(inputs), sequence_shape)
        value_seq = tf.reshape(self.value_conv(inputs), sequence_shape)

        # Scaled dot-product attention. The scale is a Python float so it
        # adopts the dtype of the tensor it multiplies.
        scaling_factor = float(self.reduced_channels) ** -0.5
        attention_scores = (
            tf.matmul(query_seq, key_seq, transpose_b=True) * scaling_factor
        )  # [batch, h*w, h*w]
        attention_weights = tf.nn.softmax(attention_scores, axis=-1)

        # Weighted sum of the values, restored to the spatial layout.
        attention_output = tf.matmul(attention_weights, value_seq)  # [batch, h*w, reduced]
        attention_output = tf.reshape(
            attention_output, [batch_size, height, width, self.reduced_channels]
        )

        # Project back to the input channel count and add the residual.
        # Plain addition avoids instantiating a new `Add` layer on every call.
        return self.output_conv(attention_output) + inputs

    def get_config(self):
        """Return the layer config, including `reduction`, for serialization."""
        config = super().get_config()
        config.update({"reduction": self.reduction})
        return config

In [20]:
# Assemble the classifier with the Keras functional API.
# NOTE(review): the run recorded in this notebook reaches only ~0.50 test
# accuracy after 2 epochs. Presumably global average pooling right after a
# single convolution stage discards too much spatial information; confirm
# before changing the architecture.
inputs = layers.Input(shape=(28, 28, 1))

# Convolutional stem: one 3x3 same-padded ReLU convolution, then 2x2 max pooling.
x = layers.Conv2D(filters=32, kernel_size=(3, 3), padding="same", activation="relu")(inputs)
x = layers.MaxPooling2D(pool_size=(2, 2))(x)

# Self-attention over the pooled feature map (residual add happens inside the layer).
x = AttentionModule()(x)

# Collapse the spatial dimensions into one value per channel.
x = layers.GlobalAveragePooling2D()(x)

# Dense head: 128 ReLU units followed by a 10-way softmax.
x = layers.Dense(units=128, activation="relu")(x)
outputs = layers.Dense(units=10, activation="softmax")(x)

model = models.Model(inputs=inputs, outputs=outputs)

In [21]:
# Configure training: Adam optimizer, categorical cross-entropy on the
# one-hot labels, and accuracy as the reported metric.
model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

# Print the layer-by-layer summary of the compiled model.
model.summary()

In [22]:
# Train for 2 epochs with mini-batches of 64, holding out 20% of the training
# data for validation. Left as the cell's last expression so the returned
# History object is displayed.
model.fit(
    x=x_train,
    y=y_train,
    batch_size=64,
    epochs=2,
    validation_split=0.2,
)

Epoch 1/2
750/750 ━━━━━━━━━━━━━━━━━━━━ 12s 14ms/step - accuracy: 0.2464 - loss: 1.9786 - val_accuracy: 0.3839 - val_loss: 1.5868
Epoch 2/2
750/750 ━━━━━━━━━━━━━━━━━━━━ 10s 13ms/step - accuracy: 0.3992 - loss: 1.5505 - val_accuracy: 0.4913 - val_loss: 1.3541


<keras.src.callbacks.history.History at 0x1d98f8cc470>

In [23]:
# Score the trained model on the held-out test split; evaluate() returns the
# loss followed by the compiled metric (accuracy).
loss, accuracy = model.evaluate(x=x_test, y=y_test)
print(f"Test Accuracy: {accuracy:.2f}")

313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - accuracy: 0.4862 - loss: 1.3664
Test Accuracy: 0.50
