In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical

In [2]:
# Load and preprocess MNIST dataset.
# Keras' mnist.load_data() returns (images, labels) pairs for the train and test splits.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Cast to float32, rescale pixel values by 1/255, and append a trailing channel axis
# so each split matches the (28, 28, 1) input shape the model declares.
x_train, x_test = [
    (split.astype("float32") / 255.0)[..., tf.newaxis]
    for split in (x_train, x_test)
]

# One-hot encode the integer digit labels into 10-element vectors.
y_train, y_test = [to_categorical(split, 10) for split in (y_train, y_test)]

In [3]:
# Multiscale feature learning
def multiscale_feature_learning(inputs, filters=64, kernel_sizes=(3, 5, 7)):
    """Run parallel same-padded ReLU convolutions at several kernel sizes and join them.

    Every branch convolves the same ``inputs`` with a different kernel size, so
    the joined output carries features computed over several receptive-field sizes.

    Args:
        inputs: Keras tensor to convolve. Assumed channels-last, e.g.
            (batch, H, W, C); TODO confirm if a channels-first config is ever used.
            The default ``Concatenate`` axis (-1) is used.
        filters: Number of filters in each branch (default 64, as before).
        kernel_sizes: Iterable of square kernel sizes, one branch per entry,
            built in the given order (default (3, 5, 7), as before).

    Returns:
        Keras tensor with the same spatial size as ``inputs`` ("same" padding,
        stride 1) and ``filters * len(kernel_sizes)`` channels.

    Raises:
        ValueError: If ``kernel_sizes`` is empty.
    """
    kernel_sizes = tuple(kernel_sizes)
    if not kernel_sizes:
        raise ValueError("kernel_sizes must contain at least one kernel size")

    # Build the branches in the given order, so layer names and weight-init order match the original.
    branches = [
        layers.Conv2D(filters, kernel_size=k, padding="same", activation="relu")(inputs)
        for k in kernel_sizes
    ]

    # With only one branch there is nothing to join, so skip Concatenate entirely.
    if len(branches) == 1:
        return branches[0]
    return layers.Concatenate()(branches)

In [4]:
# Build the model
# Seed Python's `random`, NumPy and the TensorFlow backend in one call before any
# layer is created, so the randomly initialised weights (and hence the reported
# accuracies) repeat across Restart & Run All.
# NOTE(review): bit-exact GPU runs may additionally need
# tf.config.experimental.enable_op_determinism() — confirm if that matters here.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Single-channel 28x28 images, matching the preprocessed arrays.
inputs = layers.Input(shape=(28, 28, 1))

# CNN backbone: a same-padded 3x3 conv, then 2x2 max-pooling (28x28 -> 14x14).
x = layers.Conv2D(32, kernel_size=3, activation="relu", padding="same")(inputs)
x = layers.MaxPooling2D(pool_size=2)(x)

# Multiscale feature learning: parallel conv branches joined along the channel axis.
x = multiscale_feature_learning(x)

# Classifier head: average each channel spatially, then a small MLP ending in
# 10-way softmax probabilities.
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(128, activation="relu")(x)
outputs = layers.Dense(10, activation="softmax")(x)

model = models.Model(inputs, outputs)

In [5]:
# Compile and train the model
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

model.summary()

In [7]:
# Train for 2 epochs. Per the Keras docs, validation_split=0.2 holds out the LAST 20%
# of x_train/y_train (sliced before any shuffling) as the validation set.
# Keep the returned History so the per-epoch loss/accuracy curves stay available;
# binding it also stops the bare `<History at 0x...>` repr from cluttering the output.
history = model.fit(x_train, y_train, epochs=2, batch_size=64, validation_split=0.2)

Epoch 1/2
750/750 ━━━━━━━━━━━━━━━━━━━━ 21s 28ms/step - accuracy: 0.8232 - loss: 0.5886 - val_accuracy: 0.9208 - val_loss: 0.2685
Epoch 2/2
750/750 ━━━━━━━━━━━━━━━━━━━━ 22s 29ms/step - accuracy: 0.9183 - loss: 0.2686 - val_accuracy: 0.9377 - val_loss: 0.2069


<keras.src.callbacks.history.History at 0x2aa1a85b950>

In [8]:
# Evaluate the model on the held-out test split.
# With metrics=["accuracy"], evaluate() returns [loss, accuracy] in that order.
loss, accuracy = model.evaluate(x_test, y_test)

# Report both values at 4 decimals; 2 decimals would hide differences like 0.935 vs 0.944.
print(f"Test Loss: {loss:.4f} - Test Accuracy: {accuracy:.4f}")

[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 5ms/step - accuracy: 0.9266 - loss: 0.2389
Test Accuracy: 0.94
