In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical

In [None]:
class AttentionModule(layers.Layer):
    """
    Self-attention mechanism for feature refinement
    """
    def __init__(self, **kwargs):
        super(AttentionModule, self).__init__(**kwargs)

    def build(self, input_shape):
        self.channels = input_shape[-1]
        self.query_conv = layers.Conv2D(self.channels // 8, kernel_size=1)
        self.key_conv = layers.Conv2D(self.channels // 8, kernel_size=1)
        self.value_conv = layers.Conv2D(self.channels, kernel_size=1)

    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]
        height, width = inputs.shape[1], inputs.shape[2]

        # Query, Key, Value projections
        query = self.query_conv(inputs)
        key = self.key_conv(inputs)
        value = self.value_conv(inputs)

        # Reshape for attention computation
        query = tf.reshape(query, [batch_size, -1, self.channels // 8])
        key = tf.reshape(key, [batch_size, -1, self.channels // 8])
        value = tf.reshape(value, [batch_size, -1, self.channels])

        # Compute attention scores
        attention_weights = tf.matmul(query, key, transpose_b=True)
        scale = tf.sqrt(tf.cast(self.channels // 8, tf.float32))
        attention_weights = tf.nn.softmax(attention_weights / scale)

        # Apply attention to values
        attention_output = tf.matmul(attention_weights, value)
        return tf.reshape(attention_output, [batch_size, height, width, self.channels])


In [None]:
class MultiscaleModule(layers.Layer):
    """
    Multiscale feature learning module with multiple receptive fields
    """
    def __init__(self, filters, **kwargs):
        super(MultiscaleModule, self).__init__(**kwargs)
        self.filters = filters

    def build(self, input_shape):
        # Different scale convolutions
        self.conv1 = layers.Conv2D(self.filters // 4, kernel_size=1, padding='same')
        self.conv3 = layers.Conv2D(self.filters // 4, kernel_size=3, padding='same')
        self.conv5 = layers.Conv2D(self.filters // 4, kernel_size=5, padding='same')
        self.conv7 = layers.Conv2D(self.filters // 4, kernel_size=7, padding='same')
        
        # Batch normalization layers
        self.bn1 = layers.BatchNormalization()
        self.bn3 = layers.BatchNormalization()
        self.bn5 = layers.BatchNormalization()
        self.bn7 = layers.BatchNormalization()

    def call(self, inputs, training=False):
        # Process at different scales
        scale1 = self.bn1(self.conv1(inputs), training=training)
        scale3 = self.bn3(self.conv3(inputs), training=training)
        scale5 = self.bn5(self.conv5(inputs), training=training)
        scale7 = self.bn7(self.conv7(inputs), training=training)

        # Activate all scales
        scale1 = tf.nn.relu(scale1)
        scale3 = tf.nn.relu(scale3)
        scale5 = tf.nn.relu(scale5)
        scale7 = tf.nn.relu(scale7)

        # Concatenate all scales
        return tf.concat([scale1, scale3, scale5, scale7], axis=-1)


In [None]:
def build_model(input_shape=(28, 28, 1), num_classes=10):
    inputs = layers.Input(shape=input_shape)
    
    # 1. Initial Convolution
    x = layers.Conv2D(32, kernel_size=3, padding='same')(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    
    # 2. Attention Mechanism
    attention_output = AttentionModule()(x)
    
    # 3. Multiscale Feature Learning
    multiscale_features = MultiscaleModule(64)(attention_output)
    
    # Save for residual connection
    residual = multiscale_features
    
    # 4. Residual Convolution Block
    x = layers.Conv2D(64, kernel_size=3, padding='same')(multiscale_features)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(64, kernel_size=3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    
    # Add residual connection
    x = layers.Add()([x, residual])
    x = layers.Activation('relu')(x)
    
    # 5. Global Average Pooling and Classification
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(128, activation='relu')(x)
    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    
    return models.Model(inputs, outputs)

In [None]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

In [None]:
model = build_model()
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

In [None]:
model.summary()

In [9]:
history = model.fit(
    x_train, y_train,
    batch_size=64,
    epochs=2,
    validation_split=0.2
)

[1m750/750[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m219s[0m 288ms/step - accuracy: 0.5548 - loss: 1.2393 - val_accuracy: 0.2738 - val_loss: 5.4314
Epoch 2/2
[1m750/750[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m216s[0m 288ms/step - accuracy: 0.8721 - loss: 0.4021 - val_accuracy: 0.1722 - val_loss: 9.4540


In [10]:
test_loss, test_accuracy = model.evaluate(x_test, y_test)
print(f"\nTest accuracy: {test_accuracy:.4f}")

[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 40ms/step - accuracy: 0.1737 - loss: 9.4431

Test accuracy: 0.1701
