In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical

In [2]:
# Load and preprocess MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
x_train = x_train[..., tf.newaxis]  # Add channel dimension
x_test = x_test[..., tf.newaxis]
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

In [3]:
# Attention mechanism for CNNs
def attention_module(inputs):
    # Flatten spatial dimensions into a single sequence dimension
    height, width, channels = inputs.shape[1], inputs.shape[2], inputs.shape[3]
    flattened_inputs = layers.Reshape((height * width, channels))(inputs)

    # Create query, key, value tensors
    query = layers.Dense(64, activation="relu")(flattened_inputs)
    key = layers.Dense(64, activation="relu")(flattened_inputs)
    value = layers.Dense(64, activation="relu")(flattened_inputs)

    # Apply attention
    attention_scores = layers.Attention()([query, key])
    attention_output = layers.Multiply()([attention_scores, value])

    # Reshape back to the original spatial dimensions
    attention_output = layers.Reshape((height, width, channels))(attention_output)
    return attention_output


In [4]:
# Build the model
inputs = layers.Input(shape=(28, 28, 1))

# CNN backbone
x = layers.Conv2D(32, kernel_size=3, activation="relu", padding="same")(inputs)
x = layers.MaxPooling2D(pool_size=2)(x)
x = layers.Conv2D(64, kernel_size=3, activation="relu", padding="same")(x)
x = layers.MaxPooling2D(pool_size=2)(x)

# Attention mechanism
x = attention_module(x)

# Classifier head
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(128, activation="relu")(x)
outputs = layers.Dense(10, activation="softmax")(x)

model = models.Model(inputs, outputs)

In [5]:
# Compile and train the model
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

model.summary()

In [7]:
model.fit(x_train, y_train, epochs=2, batch_size=64, validation_split=0.2)

Epoch 1/2
[1m750/750[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 10ms/step - accuracy: 0.6801 - loss: 0.9279 - val_accuracy: 0.8974 - val_loss: 0.3333
Epoch 2/2
[1m750/750[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 9ms/step - accuracy: 0.8972 - loss: 0.3323 - val_accuracy: 0.9312 - val_loss: 0.2212


<keras.src.callbacks.history.History at 0x24b3b6a16d0>

In [8]:
loss, accuracy = model.evaluate(x_test, y_test)
print(f"Test Accuracy: {accuracy:.2f}")

[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 2ms/step - accuracy: 0.9192 - loss: 0.2583
Test Accuracy: 0.93
