In [3]:
# mnist_cnn.py
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

# 1) Load dataset (MNIST comes bundled in Keras)
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# 2) Preprocess data
# - MNIST images are 28x28 grayscale with values 0-255.
# - We normalize to 0-1 and add a channel dimension for the CNN: (28,28,1).
x_train = x_train.astype("float32") / 255.0
x_test  = x_test.astype("float32")  / 255.0

# Add channel dimension (height, width, channels)
x_train = np.expand_dims(x_train, -1)  # shape -> (num_samples, 28, 28, 1)
x_test  = np.expand_dims(x_test, -1)

# Convert labels to integer type (they are already ints, but this is explicit)
y_train = y_train.astype("int32")
y_test  = y_test.astype("int32")

# 3) Build the CNN model
model = keras.Sequential(
    [
        # First convolutional block: extract low-level features (edges, small curves)
        layers.Conv2D(32, kernel_size=(3, 3), activation="relu", input_shape=(28, 28, 1)),
        layers.MaxPooling2D(pool_size=(2, 2)),  # reduce spatial size -> focus on important features

        # Second convolutional block: learn more complex features
        layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),

        # Flatten feature maps to pass into dense (fully-connected) layers
        layers.Flatten(),

        # Dense layer: combine detected features to decide patterns
        layers.Dense(128, activation="relu"),
        layers.Dropout(0.4),  # randomly turn off 40% of neurons during training to avoid overfitting

        # Output layer: 10 neurons for 10 digit classes, softmax gives probability distribution
        layers.Dense(10, activation="softmax"),
    ]
)

# 4) Compile the model: define loss, optimizer and metrics
model.compile(
    optimizer=keras.optimizers.Adam(),                  # Adam optimizer adapts learning rates
    loss="sparse_categorical_crossentropy",             # labels are integers 0..9
    metrics=["accuracy"],                               # track accuracy during training
)

# Optional: show model architecture summary
model.summary()

# 5) Train the model
history = model.fit(
    x_train,
    y_train,
    validation_split=0.1,   # hold out 10% of training data as validation set
    epochs=10,              # how many times to iterate over the dataset
    batch_size=64,          # number of samples per gradient update
    verbose=2,
)

# 6) Evaluate on test data
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print(f"\nTest accuracy: {test_acc:.4f}, Test loss: {test_loss:.4f}")

# 7) Predict on new samples (first 5 test images as example)
pred_probs = model.predict(x_test[:5])          # probabilities for each class
pred_labels = np.argmax(pred_probs, axis=1)     # predicted class = index of max probability
print("Predictions for first 5 test images:", pred_labels)
print("True labels:                 ", y_test[:5])

# 8) Save the trained model (optional)
model.save("mnist_cnn.keras")



Epoch 1/10
844/844 - 8s - 9ms/step - accuracy: 0.9293 - loss: 0.2267 - val_accuracy: 0.9847 - val_loss: 0.0537
Epoch 2/10
844/844 - 6s - 8ms/step - accuracy: 0.9773 - loss: 0.0752 - val_accuracy: 0.9905 - val_loss: 0.0351
Epoch 3/10
844/844 - 6s - 8ms/step - accuracy: 0.9836 - loss: 0.0544 - val_accuracy: 0.9898 - val_loss: 0.0342
Epoch 4/10
844/844 - 6s - 7ms/step - accuracy: 0.9872 - loss: 0.0425 - val_accuracy: 0.9917 - val_loss: 0.0307
Epoch 5/10
844/844 - 6s - 7ms/step - accuracy: 0.9889 - loss: 0.0352 - val_accuracy: 0.9913 - val_loss: 0.0345
Epoch 6/10
844/844 - 6s - 8ms/step - accuracy: 0.9898 - loss: 0.0308 - val_accuracy: 0.9900 - val_loss: 0.0355
Epoch 7/10
844/844 - 6s - 8ms/step - accuracy: 0.9916 - loss: 0.0264 - val_accuracy: 0.9902 - val_loss: 0.0405
Epoch 8/10
844/844 - 6s - 8ms/step - accuracy: 0.9926 - loss: 0.0230 - val_accuracy: 0.9917 - val_loss: 0.0313
Epoch 9/10
844/844 - 6s - 8ms/step - accuracy: 0.9934 - loss: 0.0205 - val_accuracy: 0.9912 - val_loss: 0.0359
E