Loads & preprocesses MNIST (normalization, resizing, RGB conversion)

Uses DenseNet169 as a feature extractor (pretrained on ImageNet)

Applies global average pooling (GAP) instead of Flatten to reduce parameters & overfitting

Trains for 5 epochs and evaluates performance

Plots accuracy & loss curves for analysis

In [2]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.applications import DenseNet169
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D

# Load MNIST dataset. Per the Keras dataset documentation the images arrive as
# uint8 grayscale arrays of shape (N, 28, 28) with integer class labels 0-9.
(x_train, y_train), (x_test, y_test) = mnist.load_data()


def _to_densenet_input(images, size=(128, 128)):
    """Convert raw grayscale digits into RGB tensors normalized for DenseNet.

    Args:
        images: Array of grayscale images with pixel values in 0-255 and
            shape (N, H, W).
        size: Target (height, width) handed to ``tf.image.resize``.

    Returns:
        The resized images as a tensor of shape (N, size[0], size[1], 3);
        ``tf.image.resize`` with its default bilinear method yields float32.
    """
    # Cast to float32 before any arithmetic: combining a uint8 array with a
    # Python float promotes to float64 and doubles every intermediate array.
    # The single gray channel is then repeated three times to form RGB.
    rgb = np.repeat(images[..., np.newaxis].astype(np.float32), 3, axis=-1)

    # The ImageNet weights were trained on inputs run through DenseNet's own
    # preprocessing (scale to 0-1, then per-channel mean/std normalization);
    # a bare division by 255 feeds the frozen backbone a shifted distribution.
    rgb = tf.keras.applications.densenet.preprocess_input(rgb)

    # Normalizing before upscaling touches far fewer pixels, and gives the same
    # result because bilinear upsampling is a weighted average whose weights
    # sum to one (so it commutes with the per-channel affine normalization).
    return tf.image.resize(rgb, size)


# Upscale to 128x128 to match the input_shape the DenseNet base is built with.
# NOTE(review): the whole resized split is materialized at once; for the
# standard 60,000-image training set that is roughly 11-12 GB of float32.
# Consider resizing inside the model or a tf.data pipeline if memory is tight.
x_train = _to_densenet_input(x_train)
x_test = _to_densenet_input(x_test)

# One-hot encode labels (10 classes)
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

# Pretrained ImageNet backbone, created without its classification head
base_model = DenseNet169(
    include_top=False,
    weights='imagenet',
    input_shape=(128, 128, 3),
)

# Mark every backbone layer as non-trainable so that only the layers
# added below are updated during training
for layer in base_model.layers:
    layer.trainable = False

# Assemble the classifier: frozen backbone followed by a small dense head
model = Sequential()
model.add(base_model)                       # DenseNet as feature extractor
model.add(GlobalAveragePooling2D())         # average each feature map (no Flatten)
model.add(Dense(256, activation='relu'))    # hidden fully connected layer
model.add(Dense(10, activation='softmax'))  # class probabilities (10 classes)

# Adam optimizer with categorical cross-entropy (labels are one-hot); track accuracy
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

# Fit for 5 epochs in batches of 256, reporting metrics on (x_test, y_test)
# after every epoch.
# NOTE(review): the test split doubles as validation data here, so the
# validation curves and the final test score come from the same images.
history = model.fit(
    x=x_train,
    y=y_train,
    batch_size=256,
    epochs=5,
    validation_data=(x_test, y_test),
)

# Final pass over the test split; evaluate() yields the loss followed by the
# compiled metric (accuracy)
evaluation = model.evaluate(x_test, y_test)
test_loss, test_acc = evaluation
print(f"Test Accuracy: {test_acc:.4f}")

# Learning curves side by side: accuracy on the left, loss on the right
fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 4))
metrics_log = history.history  # per-epoch values recorded by model.fit

# Left panel: training vs. validation accuracy
acc_ax.plot(metrics_log['accuracy'], label='Training Accuracy')
acc_ax.plot(metrics_log['val_accuracy'], label='Validation Accuracy')
acc_ax.set_title('Model Accuracy')
acc_ax.set_xlabel('Epochs')
acc_ax.set_ylabel('Accuracy')
acc_ax.legend()

# Right panel: training vs. validation loss
loss_ax.plot(metrics_log['loss'], label='Training Loss')
loss_ax.plot(metrics_log['val_loss'], label='Validation Loss')
loss_ax.set_title('Model Loss')
loss_ax.set_xlabel('Epochs')
loss_ax.set_ylabel('Loss')
loss_ax.legend()

plt.show()