In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical

In [2]:
# Load and preprocess MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
x_train = x_train[..., tf.newaxis]  # Add channel dimension
x_test = x_test[..., tf.newaxis]
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)


In [3]:
# Build the model
inputs = layers.Input(shape=(28, 28, 1))

# CNN backbone
x = layers.Conv2D(32, kernel_size=3, activation="relu", padding="same")(inputs)
x = layers.MaxPooling2D(pool_size=2)(x)

# Global average pooling
x = layers.GlobalAveragePooling2D()(x)

# Expand dimensions back for dilated convolution
x = layers.Reshape((1, 1, -1))(x)

# Dilated convolution
x = layers.Conv2D(64, kernel_size=3, dilation_rate=2, activation="relu", padding="same")(x)
x = layers.Reshape((1, 1, 64))(x)  # Ensure compatibility for classification

# Classifier head
x = layers.Flatten()(x)
x = layers.Dense(128, activation="relu")(x)
outputs = layers.Dense(10, activation="softmax")(x)

model = models.Model(inputs, outputs)

In [4]:
# Compile and train the model
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

model.summary()

In [5]:
model.fit(x_train, y_train, epochs=2, batch_size=64, validation_split=0.2)

Epoch 1/2
[1m750/750[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 5ms/step - accuracy: 0.2223 - loss: 2.0802 - val_accuracy: 0.3372 - val_loss: 1.6874
Epoch 2/2
[1m750/750[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 4ms/step - accuracy: 0.3651 - loss: 1.6390 - val_accuracy: 0.4308 - val_loss: 1.4472


<keras.src.callbacks.history.History at 0x1c9ca7c5b50>

In [6]:
loss, accuracy = model.evaluate(x_test, y_test)
print(f"Test Accuracy: {accuracy:.2f}")

[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step - accuracy: 0.4334 - loss: 1.4504
Test Accuracy: 0.44
