# MNIST Digit Classification using CNN - Minimal Version (For Memorization)

A convolutional neural network (CNN) that classifies handwritten MNIST digits (0-9).


In [None]:
# Imports
import random

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout

# Seed every RNG that can influence training so runs are repeatable.
# Seeding NumPy alone is not enough: TensorFlow has its own global seed, which
# drives weight initialisers, Dropout masks and shuffling in `fit`. Python's
# `random` is seeded too because some Keras versions draw default seeds from it.
# NOTE: some GPU kernels can still be nondeterministic unless op determinism
# is enabled (see tf.config.experimental.enable_op_determinism).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)


In [None]:
# Load the MNIST train/test splits; each split is an (images, labels) pair.
# Keras downloads and caches the dataset on first use.
train_split, test_split = mnist.load_data()
X_train, y_train = train_split
X_test, y_test = test_split
del train_split, test_split  # keep only the named arrays around
print("Train: {}, Test: {}".format(X_train.shape, X_test.shape))


In [None]:
# Append a trailing channel axis of size 1 so the images match the model's
# (28, 28, 1) input shape expected by the first Conv2D layer.
X_train, X_test = (images.reshape(-1, 28, 28, 1) for images in (X_train, X_test))
print("Reshaped: {}".format(X_train.shape))


In [None]:
# Normalize pixel values from [0, 255] to [0, 1].
# Cast to float32 *before* dividing: Keras loads MNIST as uint8, and
# `uint8_array / 255.0` promotes to float64. That doubles memory (~376 MB vs
# ~188 MB for the 60k training images) only for Keras to convert back to its
# default float32 during training.
X_train = X_train.astype("float32") / 255.0
X_test = X_test.astype("float32") / 255.0


In [None]:
# Turn integer digit labels into one-hot vectors of length 10, the target
# format required by the 'categorical_crossentropy' loss the model compiles with.
y_train, y_test = (to_categorical(labels, num_classes=10) for labels in (y_train, y_test))
print("Labels shape: {}".format(y_train.shape))


In [None]:
# Build the CNN. Two conv + max-pool stages extract spatial features from the
# 28x28x1 input. A dense head with dropout then maps them to a softmax over
# the 10 digit classes.
model = Sequential()
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.3))  # regularization: drop 30% of dense activations while training
model.add(Dense(10, activation='softmax'))

# Labels are one-hot, hence categorical (not sparse) cross-entropy.
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
model.summary()


In [None]:
# Train for 10 epochs in mini-batches of 128. validation_split=0.1 makes Keras
# hold out the last 10% of the training arrays (taken before shuffling) and
# report validation metrics after each epoch; they end up in history.history.
history = model.fit(
    X_train,
    y_train,
    batch_size=128,
    epochs=10,
    validation_split=0.1,
    verbose=1,
)


In [None]:
# Score the trained model on the untouched test split. The result is
# [loss, accuracy] because 'accuracy' is the only compiled metric.
loss, acc = model.evaluate(X_test, y_test, verbose=1)
print(f"Test Accuracy: {acc:.4f}\nTest Loss: {loss:.4f}")


In [None]:
# Plot learning curves side by side: accuracy (left) and loss (right), each
# comparing the per-epoch training and validation values recorded by fit().
plt.figure(figsize=(12, 4))
for panel, (metric, title) in enumerate((('accuracy', 'Accuracy'), ('loss', 'Loss')), start=1):
    plt.subplot(1, 2, panel)
    plt.plot(history.history[metric], label='Train')
    plt.plot(history.history['val_' + metric], label='Val')
    plt.xlabel('Epoch')
    plt.ylabel(title)
    plt.legend()
    plt.title(title)
    plt.grid(True)
plt.tight_layout()
plt.show()
