In [None]:
import tensorflow as tf
import numpy as np
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import matplotlib.pyplot as plt

In [None]:
# Load MNIST through the Keras datasets API: it yields two (images, labels)
# pairs, one for training and one for the held-out test split.
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

In [None]:
# Preprocess: add a trailing channel axis (28x28 -> 28x28x1) and scale pixel
# values into the [0, 1] range.
# The ndim check keeps this cell idempotent: re-running it on arrays that are
# already 4-D (i.e. already preprocessed) would otherwise divide by 255 again.
if train_images.ndim == 3:
    train_images = train_images.reshape((train_images.shape[0], 28, 28, 1)).astype('float32') / 255
if test_images.ndim == 3:
    test_images = test_images.reshape((test_images.shape[0], 28, 28, 1)).astype('float32') / 255

In [None]:
# Convert integer class labels to one-hot vectors, as required by the
# 'categorical_crossentropy' loss the model is compiled with.
# num_classes=10 is explicit so the encoding width never depends on which
# digits happen to be present in the array.
# The ndim check makes the cell safe to re-run: encoding labels that are
# already one-hot a second time would yield an array of the wrong shape.
if train_labels.ndim == 1:
    train_labels = tf.keras.utils.to_categorical(train_labels, num_classes=10)
if test_labels.ndim == 1:
    test_labels = tf.keras.utils.to_categorical(test_labels, num_classes=10)

In [None]:
# IMPORTANT: Split training data into train and validation sets.
# This keeps the test set completely unseen until final evaluation.
# The validation set is the LAST `VAL_SIZE` training samples (no shuffling).
# An explicit sample count is used rather than a fractional split so the
# validation set is exactly this size (a fraction such as 0.1667 of 60,000 is
# subject to rounding and does not give exactly 10,000).
VAL_SIZE = 10_000
assert 0 < VAL_SIZE < len(train_images), 'VAL_SIZE must leave some training data'

split_index = len(train_images) - VAL_SIZE
validation_split = VAL_SIZE / len(train_images)  # Fraction actually held out, for reference

train_imgs = train_images[:split_index]
train_lbls = train_labels[:split_index]
val_imgs = train_images[split_index:]
val_lbls = train_labels[split_index:]

# Sanity-check the split: nothing lost, images and labels still aligned.
assert len(train_imgs) + len(val_imgs) == len(train_images)
assert len(train_imgs) == len(train_lbls)
assert len(val_imgs) == len(val_lbls) == VAL_SIZE

print(f"Training Samples: {len(train_imgs)}")
print(f"Validation Samples: {len(val_imgs)}")
print(f"Test Samples: {len(test_images)}")

In [None]:
# Build the model.
# Seed the Python, NumPy and TensorFlow RNGs so weight initialization, dropout
# and batch shuffling repeat across runs (GPU kernels may still introduce
# small nondeterminism).
SEED = 42
tf.keras.utils.set_random_seed(SEED)

model = models.Sequential()

# Declare the input shape with an explicit Input layer; passing `input_shape=`
# to the first Conv2D is deprecated in recent Keras versions.
model.add(layers.Input(shape=(28, 28, 1)))

# First convolutional block
model.add(layers.Conv2D(32, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D(2, 2))

# Second convolutional block
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D(2, 2))

# Flatten the feature maps into a 1D vector
model.add(layers.Flatten())

# Dense layers with dropout for regularization
model.add(layers.Dense(128, activation='relu'))
model.add(layers.Dropout(0.5))

model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dropout(0.5))

# Output layer with 10 units (one per digit class)
model.add(layers.Dense(10, activation='softmax'))

In [None]:
# Compile with the Adam optimizer, categorical cross-entropy loss (the labels
# are one-hot encoded) and accuracy as the tracked metric.
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

In [None]:
# Show a header followed by the layer-by-layer architecture and parameter counts.
# flush=True makes sure the header is emitted before the summary table.
print('\n Model Architecture:', flush=True)
model.summary()

In [None]:
# Callbacks for smarter training.
# EarlyStopping watches the validation loss and halts training once it has
# not improved for 3 consecutive epochs; restore_best_weights=True restores
# the weights from the epoch with the best (lowest) validation loss.
early_stop = EarlyStopping(
    monitor='val_loss',
    patience=3,
    verbose=1,
    restore_best_weights=True,
)

In [None]:
# Model checkpoint: write the best model seen so far to disk.
# It monitors the same quantity as EarlyStopping (validation loss, lower is
# better), so the saved file and early stopping agree on which epoch is
# "best". Monitoring val_accuracy here could save a different epoch than the
# one EarlyStopping considers best.
checkpoint = ModelCheckpoint(
    'best_mnist_model.keras',  # Output file (native .keras format)
    monitor='val_loss',        # Same criterion as EarlyStopping
    save_best_only=True,       # Only save when the monitored value improves
    mode='min',                # Lower validation loss is better
    verbose=1
)

In [None]:
# Train the model on the training split, validating on the held-out
# validation split (never the test set) and applying the callbacks above.
print("\nStarting training...")
history = model.fit(train_imgs, train_lbls,
                    epochs=20,       # Upper bound; early stopping may end training sooner
                    batch_size=64,
                    validation_data=(val_imgs, val_lbls),  # Proper validation set
                    callbacks=[early_stop, checkpoint],    # Early stopping + best-model checkpoint
                    verbose=1
                   )

In [None]:
# Evaluate on the test set (completely unseen data).
print('\nEvaluating on test set...')
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=0)

print('\nFinal Results:')
print(f'Test loss: {test_loss:.4f}')
# `.2f` = two decimal places. A bare `.2` means two *significant digits* and
# would render an accuracy of ~99% as '9.9e+01'.
print(f'Test accuracy: {test_acc:.4f} ({test_acc*100:.2f}%)')

In [None]:
# Spot-check the model on the first few test images: predicted digit,
# true digit, and the softmax probability of the predicted class.
N_SAMPLES_TO_SHOW = 10

print("\nSample Predictions:")
predictions = model.predict(test_images[:N_SAMPLES_TO_SHOW], verbose=0)

# Vectorized argmax/max over the class axis instead of per-row calls.
predicted_classes = predictions.argmax(axis=1)
true_classes = test_labels[:N_SAMPLES_TO_SHOW].argmax(axis=1)
confidences = predictions.max(axis=1) * 100

for predicted_class, true_class, confidence in zip(predicted_classes, true_classes, confidences):
    status = '✓' if predicted_class == true_class else '✗'
    print(f'{status} Prediction: {predicted_class}, True: {true_class}, Confidence: {confidence:.1f}%')

In [None]:
# Report how many epochs actually ran; EarlyStopping can halt training
# before the `epochs` limit passed to fit().
epochs_trained = len(history.history['loss'])
print(f"Total epochs trained: {epochs_trained}")

In [None]:
# Plot the training history: loss and accuracy curves side by side.
# Epochs are numbered from 1 so the x-axis matches the 1-based `best_epoch`
# reported further down.
hist = history.history
epoch_numbers = range(1, len(hist['loss']) + 1)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: Loss over epochs
# This shows how "wrong" the model is over time (lower is better)
ax1.plot(epoch_numbers, hist['loss'], label='Training Loss', linewidth=2, marker='o')
ax1.plot(epoch_numbers, hist['val_loss'], label='Validation Loss', linewidth=2, marker='s')
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Loss', fontsize=12)
ax1.set_title('Model Loss During Training', fontsize=14, fontweight='bold')
ax1.legend(fontsize=11)
ax1.grid(True, alpha=0.3)

# Annotation explaining what we want to see (axes coordinates: top-left corner)
ax1.text(0.02, 0.98,
         'Good: Both curves decrease together\nBad: Training drops but validation rises (overfitting)',
         transform=ax1.transAxes, fontsize=9, verticalalignment='top',
         bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.3))

# Plot 2: Accuracy over epochs
# This shows how "correct" the model is over time (higher is better)
ax2.plot(epoch_numbers, hist['accuracy'], label='Training Accuracy', linewidth=2, marker='o')
ax2.plot(epoch_numbers, hist['val_accuracy'], label='Validation Accuracy', linewidth=2, marker='s')
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('Accuracy', fontsize=12)
ax2.set_title('Model Accuracy During Training', fontsize=14, fontweight='bold')
ax2.legend(fontsize=11)
ax2.grid(True, alpha=0.3)

# Annotation placed in axes coordinates (bottom-right corner), like the loss
# plot. The epoch count and accuracy range vary from run to run, so a fixed
# data-coordinate position can land outside the visible axes.
ax2.text(0.98, 0.02,
         'Good: Both curves increase together\nBad: Training climbs but validation plateaus',
         transform=ax2.transAxes, fontsize=9,
         horizontalalignment='right', verticalalignment='bottom',
         bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.3))

plt.tight_layout()
plt.show()

In [None]:
# Summarize where the best validation scores occurred (epochs are 1-based).
# History values are pulled into variables first: reusing single quotes inside
# a single-quoted f-string is a SyntaxError before Python 3.12.
val_acc_history = history.history['val_accuracy']
val_loss_history = history.history['val_loss']

best_val_acc = max(val_acc_history)
best_epoch = val_acc_history.index(best_val_acc) + 1
# Validation loss is the quantity EarlyStopping monitors, so report its best
# epoch as well; it need not coincide with the best-accuracy epoch.
best_loss_epoch = val_loss_history.index(min(val_loss_history)) + 1

print(f'Best validation accuracy achieved at epoch: {best_epoch}')
print(f'Best validation accuracy: {best_val_acc:.4f}')
print(f'Lowest validation loss at epoch: {best_loss_epoch}')
print(f"Final training accuracy: {history.history['accuracy'][-1]:.4f}")
print(f'Final validation accuracy: {val_acc_history[-1]:.4f}')

In [None]:
# Check for overfitting signs by comparing final-epoch training and
# validation accuracy. Accuracies are fractions in [0, 1], so the threshold is
# a fraction too: 0.05 = 5 percentage points. (A threshold of 0.5 would require
# a 50-point gap and the warning would effectively never fire.)
# Training accuracy is measured with the Dropout layers active, so it can sit
# at or even below validation accuracy; a negative gap is not a concern.
OVERFIT_GAP_THRESHOLD = 0.05

train_val_gap = history.history['accuracy'][-1] - history.history['val_accuracy'][-1]
if train_val_gap > OVERFIT_GAP_THRESHOLD:
    print(f'\n Warning: Training accuracy is {train_val_gap:.2%} higher than validation')
    print('This suggests overfitting. Consider adding more dropout or regularization.')
else:
    print(f'\n Good: Training and validation accuracy are close (gap: {train_val_gap:.2%})')
    print('  The model is generalizing well!')