# In [None]:
# MNIST Handwritten Digits Classification with CNN
# Google Colab Ready Implementation

# Import required libraries
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns

# Seed NumPy and TensorFlow so repeated runs start from the same random state
np.random.seed(42)
tf.random.set_seed(42)

# Report the runtime environment (useful when switching Colab runtimes)
print("TensorFlow version:", tf.__version__)
print("GPU Available:", tf.config.list_physical_devices('GPU'))

# Fetch MNIST; labels stay as integer digits in y_train / y_test
print("Loading MNIST dataset...")
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

print(f"Training data shape: {x_train.shape}")
print(f"Test data shape: {x_test.shape}")
print(f"Number of classes: {np.unique(y_train).size}")

# Scale pixels to [0, 1] as float32 and append the single channel axis the
# convolutional layers expect, giving arrays of shape (N, 28, 28, 1)
x_train = np.expand_dims(x_train.astype('float32') / 255.0, axis=-1)
x_test = np.expand_dims(x_test.astype('float32') / 255.0, axis=-1)

# One-hot encode the labels for the categorical cross-entropy loss; the
# integer labels are kept untouched for the reports further down
num_classes = 10
y_train_cat = keras.utils.to_categorical(y_train, num_classes=num_classes)
y_test_cat = keras.utils.to_categorical(y_test, num_classes=num_classes)

print(f"Preprocessed training data shape: {x_train.shape}")
print(f"Preprocessed test data shape: {x_test.shape}")

# Build the CNN model
def create_cnn_model(input_shape=(28, 28, 1)):
    """Build the (uncompiled) CNN used to classify digit images.

    The network stacks three convolutional stages (32, 64 and 128 filters,
    each with batch normalisation and dropout) followed by two dense layers
    and a softmax output with ``num_classes`` units, where ``num_classes``
    is read from module scope at call time.

    Args:
        input_shape: Shape of one input image as (height, width, channels).
            Defaults to the MNIST shape (28, 28, 1). The 3x3 convolutions
            are used with Keras' default padding, so with the default input
            the feature map is 2x2 before flattening; noticeably smaller
            inputs will not survive the convolution/pooling stack.

    Returns:
        A ``keras.Sequential`` model that still has to be compiled.
    """
    model = keras.Sequential([
        # Declare the input with an explicit Input layer instead of passing
        # `input_shape=` to the first Conv2D, which newer Keras releases
        # warn about for Sequential models.
        keras.Input(shape=input_shape),

        # First Convolutional Block
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.BatchNormalization(),
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(0.25),

        # Second Convolutional Block
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.BatchNormalization(),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(0.25),

        # Third Convolutional Block (no pooling)
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(0.25),

        # Fully Connected Layers
        layers.Flatten(),
        layers.Dense(512, activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(0.5),
        layers.Dense(256, activation='relu'),
        layers.Dropout(0.5),
        # One probability per class
        layers.Dense(num_classes, activation='softmax')
    ])

    return model

# Create and compile the model
model = create_cnn_model()

# Adam with categorical cross-entropy matches the one-hot labels and the
# softmax output layer
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# Display model architecture
print("\nModel Architecture:")
model.summary()

# Calculate total parameters
total_params = model.count_params()
print(f"\nTotal parameters: {total_params:,}")

# Define callbacks for training. Both watch validation metrics, which are
# computed on the hold-out slice of the training data configured in fit().
callbacks = [
    # Stop once validation accuracy has stalled and roll back to the best weights
    keras.callbacks.EarlyStopping(
        monitor='val_accuracy',
        patience=5,
        restore_best_weights=True,
        verbose=1
    ),
    # Halve the learning rate when validation loss plateaus
    keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=3,
        min_lr=1e-7,
        verbose=1
    )
]

# Train the model
print("\nStarting training...")
batch_size = 128
# NOTE: EarlyStopping needs `patience` (5) consecutive non-improving epochs
# after the first one, so it cannot trigger within a 5-epoch run; raise
# `epochs` if early stopping should actually take effect.
epochs = 5

# Validate on a slice held out from the training set, NOT on the test set.
# The callbacks above select weights and learning rates based on validation
# metrics, so validating on the test data would leak it into training
# decisions and make the final test accuracy optimistic. Keras takes the
# validation samples from the end of the arrays before shuffling.
history = model.fit(
    x_train, y_train_cat,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=0.1,
    callbacks=callbacks,
    verbose=1
)

# Evaluate the model on the untouched test set
print("\nEvaluating model...")
test_loss, test_accuracy = model.evaluate(x_test, y_test_cat, verbose=0)
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_accuracy:.4f} ({test_accuracy*100:.2f}%)")

# Check if we achieved >95% accuracy
if test_accuracy > 0.95:
    print("✅ Successfully achieved >95% test accuracy!")
else:
    print("⚠️  Test accuracy is below 95%. Consider training longer or adjusting hyperparameters.")

# Plot training history
def plot_training_history(history):
    """Show training vs. validation accuracy and loss side by side.

    Args:
        history: Object whose ``history`` attribute maps the metric names
            'accuracy', 'val_accuracy', 'loss' and 'val_loss' to per-epoch
            value sequences (as returned by ``model.fit``).
    """
    # One row per panel: (train key, val key, train label, val label, title, y label)
    panels = (
        ('accuracy', 'val_accuracy',
         'Training Accuracy', 'Validation Accuracy', 'Model Accuracy', 'Accuracy'),
        ('loss', 'val_loss',
         'Training Loss', 'Validation Loss', 'Model Loss', 'Loss'),
    )

    _, axes = plt.subplots(1, 2, figsize=(12, 4))
    curves = history.history

    for axis, panel in zip(axes, panels):
        train_key, val_key, train_label, val_label, title, y_label = panel
        axis.plot(curves[train_key], label=train_label)
        axis.plot(curves[val_key], label=val_label)
        axis.set_title(title)
        axis.set_xlabel('Epoch')
        axis.set_ylabel(y_label)
        axis.legend()
        axis.grid(True)

    plt.tight_layout()
    plt.show()

# Render the accuracy/loss curves recorded during training
plot_training_history(history)

# Class probabilities for every test image, reduced to the most likely digit
y_pred = model.predict(x_test, verbose=0)
y_pred_classes = np.argmax(y_pred, axis=1)

# Per-class precision / recall / F1 against the integer test labels
print("\nClassification Report:")
print(classification_report(y_test, y_pred_classes))

# Confusion matrix: rows are true digits, columns are predicted digits
cm = confusion_matrix(y_test, y_pred_classes)
plt.figure(figsize=(10, 8))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=range(10),
    yticklabels=range(10),
)
plt.title('Confusion Matrix')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.show()

# Visualize predictions on 5 sample images
def visualize_predictions(model, x_test, y_test, num_samples=5):
    """Plot randomly chosen test images with their true and predicted labels.

    Each image is titled with the true label, the predicted label and the
    highest predicted probability; the title is green when the prediction
    is correct and red otherwise.

    Args:
        model: Object with a ``predict(x, verbose=0)`` method returning one
            row of class scores per input image.
        x_test: Array of images; each one is reshaped to 28x28 for display.
        y_test: Integer labels aligned with ``x_test``.
        num_samples: How many distinct images to draw.

    Returns:
        Tuple ``(indices, predicted_classes, confidence_scores)`` for the
        sampled images, in plotting order.
    """
    # Draw distinct sample positions (uses NumPy's global random state)
    indices = np.random.choice(len(x_test), num_samples, replace=False)

    # Score only the sampled images
    predictions = model.predict(x_test[indices], verbose=0)
    predicted_classes = np.argmax(predictions, axis=1)
    confidence_scores = np.max(predictions, axis=1)

    # squeeze=False always yields a 2-D grid, so the single row can be taken
    # uniformly even when only one image is requested
    _, axes_grid = plt.subplots(1, num_samples, figsize=(15, 3), squeeze=False)
    axes_row = axes_grid[0]

    samples = zip(axes_row, indices, predicted_classes, confidence_scores)
    for axis, idx, pred_label, confidence in samples:
        axis.imshow(x_test[idx].reshape(28, 28), cmap='gray')
        axis.axis('off')

        true_label = y_test[idx]
        if true_label == pred_label:
            color = 'green'
        else:
            color = 'red'

        title = f'True: {true_label}\nPred: {pred_label}\nConf: {confidence:.3f}'
        axis.set_title(title, color=color, fontsize=10)

    plt.suptitle('Model Predictions on Sample Images', fontsize=14)
    plt.tight_layout()
    plt.show()

    return indices, predicted_classes, confidence_scores

print("\nVisualizing predictions on 5 sample images:")
sample_indices, sample_predictions, sample_confidences = visualize_predictions(
    model, x_test, y_test, num_samples=5
)

# Text version of the same samples, one line per image
print("\nDetailed Results for Sample Images:")
for i, (idx, pred_label, confidence) in enumerate(
    zip(sample_indices, sample_predictions, sample_confidences)
):
    true_label = y_test[idx]
    if true_label == pred_label:
        status = "✅ Correct"
    else:
        status = "❌ Incorrect"
    print(f"Sample {i+1}: True={true_label}, Predicted={pred_label}, Confidence={confidence:.4f} - {status}")

# Additional analysis: Show model performance per digit
def analyze_per_digit_performance(y_true, y_pred):
    """Compute and plot, for each digit 0-9, how often it was classified correctly.

    For every digit the accuracy is the fraction of samples whose true label
    is that digit and whose predicted label matches it. The values are shown
    as a bar chart with the number printed above each bar.

    Args:
        y_true: Integer true labels (any array-like).
        y_pred: Integer predicted labels aligned with ``y_true``.

    Returns:
        List of 10 accuracies indexed by digit. A digit that never occurs in
        ``y_true`` gets NaN.
    """
    # Coerce to arrays so boolean masking also works for plain Python lists
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    accuracy_per_digit = []
    for digit in range(10):
        digit_mask = (y_true == digit)
        if digit_mask.any():
            digit_accuracy = np.mean(y_pred[digit_mask] == digit)
        else:
            # No samples of this digit: report NaN explicitly instead of
            # triggering NumPy's "mean of empty slice" warning
            digit_accuracy = np.nan
        accuracy_per_digit.append(digit_accuracy)

    plt.figure(figsize=(10, 6))
    bars = plt.bar(range(10), accuracy_per_digit, color='skyblue', edgecolor='navy')
    plt.xlabel('Digit')
    plt.ylabel('Accuracy')
    plt.title('Per-Digit Classification Accuracy')
    plt.xticks(range(10))
    # Leave headroom above 1.0 so labels on near-perfect bars stay inside the axes
    plt.ylim(0, 1.05)
    plt.grid(axis='y', alpha=0.3)

    # Add accuracy values on top of bars (skipped for digits without samples)
    for bar in bars:
        height = bar.get_height()
        if np.isfinite(height):
            plt.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                     f'{height:.3f}', ha='center', va='bottom')

    plt.tight_layout()
    plt.show()

    return accuracy_per_digit

print("\nAnalyzing per-digit performance:")
digit_accuracies = analyze_per_digit_performance(y_test, y_pred_classes)

# Headline numbers: overall accuracy plus the easiest and hardest digits
best_digit = np.argmax(digit_accuracies)
worst_digit = np.argmin(digit_accuracies)
print("\nSummary Statistics:")
print(f"Overall Test Accuracy: {test_accuracy:.4f}")
print(f"Best Digit Accuracy: {max(digit_accuracies):.4f} (Digit {best_digit})")
print(f"Worst Digit Accuracy: {min(digit_accuracies):.4f} (Digit {worst_digit})")
print(f"Average Per-Digit Accuracy: {np.mean(digit_accuracies):.4f}")

# Optional: persist the trained model next to the notebook
# NOTE(review): newer Keras releases flag HDF5 (.h5) as a legacy format --
# consider the native .keras format if downstream loaders allow it.
print("\nSaving model...")
model.save('mnist_cnn_model.h5')
print("Model saved as 'mnist_cnn_model.h5'")

# Closing banner with the final score
separator = "=" * 50
print("\n" + separator)
print("MNIST CNN Classification Complete!")
print(f"Final Test Accuracy: {test_accuracy*100:.2f}%")
if test_accuracy > 0.95:
    print("🎉 Successfully achieved >95% accuracy target!")
print(separator)