In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical

In [2]:
class MultiscaleModule(layers.Layer):
    """Multiscale feature-learning block: parallel branches with different receptive fields.

    Four branches are applied to the same input and concatenated on the last
    (channel) axis:

    1. 1x1 convolution
    2. 3x3 convolution
    3. two stacked 3x3 convolutions (5x5 receptive field). No activation sits
       between them, so together they behave like a single linear 5x5 filter.
    4. 3x3 convolution with ``dilation_rate=2`` (5x5 receptive field)

    Each branch ends with BatchNormalization followed by ReLU. Every convolution
    uses stride 1 and ``padding='same'``, so the spatial size is preserved.

    Args:
        filters: Total number of output channels. They are split as evenly as
            possible across the four branches (any remainder goes to the first
            branches), so the output always has exactly ``filters`` channels.
            Must be at least 4.
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``).

    Raises:
        ValueError: If ``filters`` is smaller than 4 (some branch would get
            zero channels).
    """

    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        if filters < 4:
            raise ValueError(
                f"MultiscaleModule needs filters >= 4 (one per branch), got {filters}."
            )
        self.filters = filters

    @staticmethod
    def _branch_widths(filters):
        """Split ``filters`` into four branch widths that sum exactly to ``filters``."""
        base, extra = divmod(filters, 4)
        # Hand the remainder out one channel at a time to the first branches so the
        # concatenated output width equals `filters` (plain `filters // 4` drops it).
        return [base + (1 if i < extra else 0) for i in range(4)]

    def build(self, input_shape):
        """Create the convolution and BatchNorm sublayers; Keras builds them on first call."""
        w1, w2, w3, w4 = self._branch_widths(self.filters)

        # Branch 1: 1x1 convolution
        self.conv1 = layers.Conv2D(w1, kernel_size=1, padding='same')

        # Branch 2: 3x3 convolution
        self.conv3 = layers.Conv2D(w2, kernel_size=3, padding='same')

        # Branch 3: 5x5 receptive field, implemented as two stacked 3x3 convolutions
        self.conv5_1 = layers.Conv2D(w3, kernel_size=3, padding='same')
        self.conv5_2 = layers.Conv2D(w3, kernel_size=3, padding='same')

        # Branch 4: dilated 3x3 convolution
        self.conv_dilated = layers.Conv2D(w4, kernel_size=3, padding='same', dilation_rate=2)

        # One BatchNormalization per branch
        self.bn1 = layers.BatchNormalization()
        self.bn2 = layers.BatchNormalization()
        self.bn3 = layers.BatchNormalization()
        self.bn4 = layers.BatchNormalization()

    def call(self, inputs, training=None):
        """Run the four branches on ``inputs`` and concatenate them on the last axis.

        Args:
            inputs: 4-D feature map consumed by ``Conv2D``; channels-last is
                assumed, since branches are concatenated on axis -1.
            training: Forwarded to every BatchNormalization layer. ``None``
                lets Keras fill it in from the enclosing ``fit`` / ``evaluate``
                / ``predict`` call.

        Returns:
            Tensor with the same spatial size as ``inputs`` and ``filters`` channels.
        """
        # Branch 1: 1x1 convolution
        branch1 = self.conv1(inputs)
        branch1 = self.bn1(branch1, training=training)
        branch1 = tf.nn.relu(branch1)

        # Branch 2: 3x3 convolution
        branch2 = self.conv3(inputs)
        branch2 = self.bn2(branch2, training=training)
        branch2 = tf.nn.relu(branch2)

        # Branch 3: 5x5 receptive field (two 3x3 convs, normalized once at the end)
        branch3 = self.conv5_1(inputs)
        branch3 = self.conv5_2(branch3)
        branch3 = self.bn3(branch3, training=training)
        branch3 = tf.nn.relu(branch3)

        # Branch 4: dilated convolution
        branch4 = self.conv_dilated(inputs)
        branch4 = self.bn4(branch4, training=training)
        branch4 = tf.nn.relu(branch4)

        # Concatenate all branches along the channel axis
        return tf.concat([branch1, branch2, branch3, branch4], axis=-1)

    def get_config(self):
        """Return the layer config so models using this layer can be saved and re-created."""
        config = super().get_config()
        config.update({"filters": self.filters})
        return config

In [3]:
def build_model(input_shape=(28, 28, 1), num_classes=10):
    """Build the multiscale CNN classifier (uncompiled).

    Pipeline: stem conv -> MultiscaleModule -> global-context path
    (GAP -> Dense -> tiled back over the feature map -> dilated conv) summed
    with the multiscale features -> GAP -> Dense/Dropout -> softmax.

    Args:
        input_shape: ``(height, width, channels)`` of a single image. Height and
            width must be concrete ints, because the pooled context vector is
            upsampled back to exactly that spatial size.
        num_classes: Number of output classes (width of the softmax layer).

    Returns:
        A functional ``Model`` mapping images to class probabilities.

    NOTE(review): the recorded run in this notebook reaches ~0.94 training
    accuracy but only ~0.14 validation / ~0.15 test accuracy, and validation
    loss grows between epochs. Something behaves very differently in inference
    mode (BatchNorm moving statistics are a likely suspect) — TODO investigate.
    """
    # Channel width shared by the multiscale features and the context path;
    # the two must agree for the residual `Add` below.
    feature_channels = 64
    # Every convolution before the context path uses stride 1 and 'same'
    # padding, so the multiscale feature map keeps the input's spatial size.
    # The context vector is tiled back to that size rather than a fixed 28x28.
    height, width = input_shape[0], input_shape[1]

    inputs = layers.Input(shape=input_shape)

    # 1. Initial convolution
    x = layers.Conv2D(32, kernel_size=3, padding='same')(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    # 2. Multiscale feature learning
    multiscale_features = MultiscaleModule(feature_channels)(x)

    # Kept for the residual connection in step 5
    residual = multiscale_features

    # 3. Global context: pool to one vector per image, project it, then tile it
    #    back over the full feature map
    gap_features = layers.GlobalAveragePooling2D()(multiscale_features)
    gap_features = layers.Dense(feature_channels)(gap_features)
    gap_features = layers.Reshape((1, 1, feature_channels))(gap_features)
    gap_features = layers.UpSampling2D(size=(height, width))(gap_features)

    # 4. Dilated convolution over the tiled context map
    dilated_conv = layers.Conv2D(
        feature_channels, kernel_size=3, padding='same', dilation_rate=2
    )(gap_features)
    dilated_conv = layers.BatchNormalization()(dilated_conv)
    dilated_conv = layers.Activation('relu')(dilated_conv)

    # 5. Residual connection: add the context features onto the multiscale features
    x = layers.Add()([residual, dilated_conv])
    x = layers.Activation('relu')(x)

    # Classification head
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(128, activation='relu')(x)
    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)

    return models.Model(inputs, outputs)

In [4]:
def _prepare_images(images):
    """Cast pixel values to float32, scale by 1/255 and append a trailing channel axis."""
    return (images.astype('float32') / 255.0)[..., tf.newaxis]


# Load MNIST, normalize the images and one-hot encode the 10 digit labels.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = _prepare_images(x_train), _prepare_images(x_test)
y_train, y_test = (to_categorical(labels, 10) for labels in (y_train, y_test))

In [5]:
# Seed Python's `random`, NumPy and TensorFlow in one call so weight
# initialization, dropout masks and fit-time shuffling are reproducible across
# "Restart & Run All". (Bit-exact GPU runs may additionally need
# `tf.config.experimental.enable_op_determinism()`.)
SEED = 42
tf.keras.utils.set_random_seed(SEED)

model = build_model()
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss='categorical_crossentropy',  # labels were one-hot encoded with to_categorical
    metrics=['accuracy'],
)




In [6]:
# Print the layer-by-layer table of output shapes and parameter counts.
model.summary()

In [8]:
# Training hyper-parameters. `validation_split` holds out the *last* 20% of
# x_train / y_train as validation data (Keras selects it before shuffling).
BATCH_SIZE = 64
EPOCHS = 2
VALIDATION_SPLIT = 0.2

history = model.fit(
    x_train,
    y_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_split=VALIDATION_SPLIT,
)

Epoch 1/2
[1m750/750[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m78s[0m 104ms/step - accuracy: 0.7188 - loss: 0.8474 - val_accuracy: 0.3067 - val_loss: 4.9969
Epoch 2/2
[1m750/750[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m81s[0m 108ms/step - accuracy: 0.9434 - loss: 0.1917 - val_accuracy: 0.1357 - val_loss: 14.3407


In [9]:
# Score the trained model on the held-out MNIST test set: [loss, accuracy].
test_metrics = model.evaluate(x_test, y_test)
test_loss, test_accuracy = test_metrics
print("\n" + f"Test accuracy: {test_accuracy:.4f}")

[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 13ms/step - accuracy: 0.1627 - loss: 13.6450

Test accuracy: 0.1503
