In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical

In [2]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

In [3]:
class AttentionModule(layers.Layer):
    """
    Self-attention mechanism for feature refinement
    """
    def __init__(self, **kwargs):
        super(AttentionModule, self).__init__(**kwargs)

    def build(self, input_shape):
        self.channels = input_shape[-1]
        self.query_conv = layers.Conv2D(self.channels // 8, kernel_size=1)
        self.key_conv = layers.Conv2D(self.channels // 8, kernel_size=1)
        self.value_conv = layers.Conv2D(self.channels, kernel_size=1)
        super(AttentionModule, self).build(input_shape)

    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]
        height, width = inputs.shape[1], inputs.shape[2]

        # Query, Key, Value projections
        query = self.query_conv(inputs)
        key = self.key_conv(inputs)
        value = self.value_conv(inputs)

        # Reshape for attention computation
        query = tf.reshape(query, [batch_size, -1, self.channels // 8])
        key = tf.reshape(key, [batch_size, -1, self.channels // 8])
        value = tf.reshape(value, [batch_size, -1, self.channels])

        # Compute attention scores with scaling
        attention_weights = tf.matmul(query, key, transpose_b=True)
        scale = tf.sqrt(tf.cast(self.channels // 8, tf.float32))
        attention_weights = tf.nn.softmax(attention_weights / scale)

        # Apply attention to values
        attention_output = tf.matmul(attention_weights, value)
        return tf.reshape(attention_output, [batch_size, height, width, self.channels])

In [4]:
def build_model(input_shape=(28, 28, 1), num_classes=10):
    inputs = layers.Input(shape=input_shape)
    
    # 1. Initial Convolution Layer
    x = layers.Conv2D(64, kernel_size=3, padding='same')(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    conv_output = x  # Save for residual connection
    
    # 2. Global Average Pooling
    gap_features = layers.GlobalAveragePooling2D()(x)
    
    # Reshape and expand to maintain spatial information
    gap_features = layers.Dense(64)(gap_features)
    gap_features = layers.Reshape((1, 1, 64))(gap_features)
    gap_features = layers.UpSampling2D(size=(28, 28))(gap_features)
    
    # 3. Attention Mechanism
    attention_output = AttentionModule()(gap_features)
    
    # 4. Residual Connection
    x = layers.Add()([conv_output, attention_output])
    x = layers.Activation('relu')(x)
    
    # Classification head
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(128, activation='relu')(x)
    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    
    return models.Model(inputs, outputs)

In [5]:
model = build_model()
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)




In [6]:
model.summary()

In [8]:
history = model.fit(
    x_train, y_train,
    batch_size=64,
    epochs=2,
    validation_split=0.2,
)

Epoch 1/2
[1m750/750[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m162s[0m 216ms/step - accuracy: 0.2973 - loss: 1.8752 - val_accuracy: 0.4720 - val_loss: 1.4154
Epoch 2/2
[1m750/750[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m157s[0m 209ms/step - accuracy: 0.4545 - loss: 1.4408 - val_accuracy: 0.2567 - val_loss: 2.1778


In [9]:
test_loss, test_accuracy = model.evaluate(x_test, y_test)
print(f"\nTest accuracy: {test_accuracy:.4f}")

[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 34ms/step - accuracy: 0.2428 - loss: 2.1614

Test accuracy: 0.2457
