In [18]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical

In [19]:
# Load the MNIST train/test splits (images plus integer digit labels).
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Images: cast to float32, scale pixel values by 1/255, then append a trailing
# channel axis so each sample has shape (height, width, 1).
x_train = (x_train.astype('float32') / 255.0)[..., tf.newaxis]
x_test = (x_test.astype('float32') / 255.0)[..., tf.newaxis]

# Labels: one-hot encode over the 10 digit classes.
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

In [20]:
def multiscale_feature_learning(inputs, filters=32, kernel_sizes=(1, 3, 5)):
    """Extract features at several receptive-field sizes in parallel.

    One ReLU-activated, "same"-padded ``Conv2D`` branch is applied to
    ``inputs`` per entry in ``kernel_sizes``; the branch outputs are joined
    with ``layers.Concatenate()`` (its default axis).

    Args:
        inputs: Tensor fed to every convolution branch.
        filters: Number of filters in each branch. Defaults to 32.
        kernel_sizes: Iterable of kernel sizes, one branch per entry.
            Defaults to ``(1, 3, 5)``, which reproduces the original
            three-branch layout.

    Returns:
        The concatenated branch outputs, or the single branch output when
        only one kernel size is given.

    Raises:
        ValueError: If ``kernel_sizes`` is empty.
    """
    # Materialise first so generators work and emptiness can be checked
    # before any layer is created.
    kernel_sizes = tuple(kernel_sizes)
    if not kernel_sizes:
        raise ValueError("kernel_sizes must contain at least one kernel size.")

    branches = [
        layers.Conv2D(filters, kernel_size=size, activation="relu", padding="same")(inputs)
        for size in kernel_sizes
    ]

    # A single branch has nothing to be concatenated with.
    if len(branches) == 1:
        return branches[0]
    return layers.Concatenate()(branches)

In [21]:
class AttentionModule(layers.Layer):
    """Spatial self-attention over a channels-last feature map.

    Every spatial position of a ``(batch, height, width, channels)`` input is
    projected with 1x1 convolutions to a query, a key and a value. Scaled
    dot-product attention is computed between all pairs of positions and the
    attended values are reshaped back to the input's spatial layout.

    The attention matrix has ``(height * width) ** 2`` entries per sample, so
    memory grows quadratically with the number of spatial positions.

    Input shape:
        ``(batch, height, width, channels)``; ``channels`` must be statically
        known, ``height`` and ``width`` may be dynamic.

    Output shape:
        Same as the input shape.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        """Create and build the 1x1 query/key/value projections.

        Raises:
            ValueError: If the channel dimension of ``input_shape`` is unknown.
        """
        channels = input_shape[-1]
        if channels is None:
            raise ValueError(
                "AttentionModule requires a statically known channel dimension."
            )
        self.channels = int(channels)
        # Queries/keys use a bottleneck of channels // 8, clamped to at least 1
        # so inputs with fewer than 8 channels do not request a 0-filter conv.
        self.key_dim = max(1, self.channels // 8)

        self.query_conv = layers.Conv2D(self.key_dim, kernel_size=1)
        self.key_conv = layers.Conv2D(self.key_dim, kernel_size=1)
        self.value_conv = layers.Conv2D(self.channels, kernel_size=1)

        # Build the projections now so their weights exist as soon as this
        # layer's build() returns, rather than on the first call().
        for projection in (self.query_conv, self.key_conv, self.value_conv):
            projection.build(input_shape)

        super().build(input_shape)

    def call(self, inputs):
        """Apply self-attention across all spatial positions of ``inputs``."""
        runtime_shape = tf.shape(inputs)
        batch_size = runtime_shape[0]
        # Use static spatial dims when known (keeps static shape information);
        # otherwise fall back to the runtime shape so variable-sized inputs work.
        height = inputs.shape[1] if inputs.shape[1] is not None else runtime_shape[1]
        width = inputs.shape[2] if inputs.shape[2] is not None else runtime_shape[2]
        num_positions = height * width

        # Query, key and value projections, flattened to (batch, positions, dim).
        query = tf.reshape(
            self.query_conv(inputs), [batch_size, num_positions, self.key_dim]
        )
        key = tf.reshape(
            self.key_conv(inputs), [batch_size, num_positions, self.key_dim]
        )
        value = tf.reshape(
            self.value_conv(inputs), [batch_size, num_positions, self.channels]
        )

        # Scaled dot-product scores: (batch, positions, positions). The scale
        # is cast to the score dtype so non-float32 compute dtypes also work.
        scale = tf.sqrt(tf.cast(self.key_dim, query.dtype))
        scores = tf.matmul(query, key, transpose_b=True) / scale
        # Normalise over the key positions (last axis).
        attention_weights = tf.nn.softmax(scores, axis=-1)

        # Weighted sum of values, restored to the spatial layout of the input.
        attended = tf.matmul(attention_weights, value)
        return tf.reshape(attended, [batch_size, height, width, self.channels])

In [22]:
def build_model(input_shape=(28, 28, 1), num_classes=10):
    """Assemble the multi-scale + self-attention classifier.

    Pipeline: input -> ``multiscale_feature_learning`` -> ``AttentionModule``
    added back onto the multi-scale features (residual connection) -> global
    average pooling -> Dense(128, relu) -> Dropout(0.5) -> softmax classifier.

    Args:
        input_shape: Shape of a single input sample, without the batch axis.
        num_classes: Number of units in the final softmax layer.

    Returns:
        An uncompiled ``models.Model`` mapping the input to class probabilities.
    """
    image = layers.Input(shape=input_shape)

    # Feature extraction followed by attention-based refinement.
    features = multiscale_feature_learning(image)
    attended = AttentionModule()(features)

    # Residual connection: keep the raw features alongside the attended ones.
    refined = layers.Add()([features, attended])

    # Classification head on the spatially pooled features.
    pooled = layers.GlobalAveragePooling2D()(refined)
    hidden = layers.Dense(128, activation="relu")(pooled)
    hidden = layers.Dropout(0.5)(hidden)
    probabilities = layers.Dense(num_classes, activation="softmax")(hidden)

    return models.Model(image, probabilities)

In [23]:
# Seed the Python, NumPy and TensorFlow random generators before any weights
# are created, so that re-running this cell followed by the training cell
# gives the same initialisation, shuffling and dropout masks (as far as the
# hardware's kernels are deterministic).
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Build the network and attach the optimiser, loss and metric. The loss is
# categorical cross-entropy because the labels are one-hot encoded.
model = build_model()
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss="categorical_crossentropy",
    metrics=["accuracy"]
)

In [24]:
# Print the layer-by-layer summary of the compiled model.
model.summary()

In [26]:
# Train the model, holding out 20% of the training data for validation.
# NOTE(review): with epochs=2 neither callback appears able to reach its
# patience (3 and 2 epochs without improvement), so both are effectively
# inactive for this run — raise `epochs` if they are meant to take effect.
history = model.fit(
    x=x_train,
    y=y_train,
    epochs=2,
    batch_size=64,
    validation_split=0.2,
    callbacks=[
        tf.keras.callbacks.EarlyStopping(restore_best_weights=True, patience=3),
        tf.keras.callbacks.ReduceLROnPlateau(patience=2, factor=0.5),
    ],
)

Epoch 1/2
750/750 ━━━━━━━━━━━━━━━━━━━━ 177s 236ms/step - accuracy: 0.2472 - loss: 1.9748 - val_accuracy: 0.5671 - val_loss: 1.1442 - learning_rate: 0.0010
Epoch 2/2
750/750 ━━━━━━━━━━━━━━━━━━━━ 170s 226ms/step - accuracy: 0.5312 - loss: 1.2149 - val_accuracy: 0.6647 - val_loss: 0.9418 - learning_rate: 0.0010


In [27]:
# Score the trained model on the test split; evaluate() yields the loss
# followed by the compiled metric (accuracy).
test_loss, test_accuracy = model.evaluate(x=x_test, y=y_test)
print(f"\nTest accuracy: {test_accuracy:.4f}")

313/313 ━━━━━━━━━━━━━━━━━━━━ 11s 35ms/step - accuracy: 0.6488 - loss: 0.9676

Test accuracy: 0.6642
