In [None]:
# MNIST Model Analysis & Evaluation

This notebook provides comprehensive analysis of our trained MNIST classifier, including detailed performance metrics, error analysis, and visualization of model behavior.

## Analysis Overview
- Load the pre-trained model (falls back to a quick 5-epoch training run if loading fails)
- Detailed performance evaluation
- Confusion matrix analysis
- Error analysis and misclassified examples
- Custom image testing
- Model interpretation and insights


In [None]:
## Setup and Load Model


In [None]:
import sys
import os
sys.path.append('../src')

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
from mnist_classifier import MNISTClassifier

# Plot styling: start from matplotlib's default style sheet, then set the
# default figure size and the seaborn colour palette used by later cells.
plt.style.use('default')
plt.rcParams.update({'figure.figsize': (12, 8)})
sns.set_palette("husl")

# Reaching this line means every import in this cell resolved.
print("✅ All imports successful!")


In [None]:
# Create the classifier and load its dataset; later cells read
# classifier.x_test / classifier.y_test and classifier.model.
classifier = MNISTClassifier()
classifier.load_data()

# Prefer the saved model. If loading fails for any ordinary reason, fall back
# to building and briefly training a fresh model so the rest of the notebook
# can still run.
try:
    classifier.load_model("../models/mnist_model.keras")
    print("✅ Loaded pre-trained model!")
except Exception as load_error:
    # `Exception` (rather than a bare `except:`) lets KeyboardInterrupt and
    # SystemExit propagate, so interrupting the kernel is not mistaken for a
    # missing model file.
    print("⚠️ No pre-trained model found. Please run training_demo.ipynb first.")
    # Surface the real cause: the failure is not necessarily a missing file.
    print(f"   Reason: {type(load_error).__name__}: {load_error}")
    print("For demo purposes, we'll build and do a quick train...")
    classifier.build_model()
    classifier.train(epochs=5, batch_size=128)

print("Model loaded and ready for analysis!")


In [None]:
## Comprehensive Performance Analysis


In [None]:
# Run the model over the test inputs once; the arrays below are all derived
# from this single prediction pass and are reused by later cells.
predictions = classifier.model.predict(classifier.x_test)
predicted_classes = np.argmax(predictions, axis=1)     # index of the highest score per sample
prediction_confidences = np.max(predictions, axis=1)   # the highest score itself
# argmax over axis 1 of y_test — assumes labels are one-hot encoded; TODO confirm
true_classes = np.argmax(classifier.y_test, axis=1)

is_correct = predicted_classes == true_classes
rule = "=" * 50

# Headline metrics reported by the classifier's own evaluate() method
test_loss, test_accuracy = classifier.evaluate()
print("🎯 OVERALL PERFORMANCE")
print(rule)
print(f"Test Accuracy: {test_accuracy:.4f} ({test_accuracy*100:.2f}%)")
print(f"Test Loss: {test_loss:.4f}")
print(f"Total Test Samples: {len(true_classes)}")
print(f"Correct Predictions: {np.sum(is_correct)}")
print(f"Incorrect Predictions: {np.sum(~is_correct)}")

# Distribution of the per-sample top score
print("\n📊 CONFIDENCE STATISTICS")
print(rule)
mean_conf = np.mean(prediction_confidences)
median_conf = np.median(prediction_confidences)
lowest_conf = np.min(prediction_confidences)
highest_conf = np.max(prediction_confidences)
print(f"Mean Confidence: {mean_conf:.4f}")
print(f"Median Confidence: {median_conf:.4f}")
print(f"Min Confidence: {lowest_conf:.4f}")
print(f"Max Confidence: {highest_conf:.4f}")


In [None]:
## Confusion Matrix & Classification Report


In [None]:
# Draw the confusion matrix using the classifier's own plotting helper
classifier.plot_confusion_matrix(true_classes, predicted_classes)

# Per-class text report, with one readable row label per digit 0-9
print("📋 DETAILED CLASSIFICATION REPORT")
print("=" * 60)
digit_labels = [f'Digit {i}' for i in range(10)]
report_text = classification_report(
    true_classes,
    predicted_classes,
    target_names=digit_labels,
)
print(report_text)


In [None]:
## Error Analysis - Misclassified Examples


In [None]:
# Find misclassified examples
# Cell-local import: TensorFlow is only needed here, to fetch the raw MNIST
# arrays for display.
import tensorflow as tf
(x_train_orig, y_train_orig), (x_test_orig, y_test_orig) = tf.keras.datasets.mnist.load_data()

# The raw images are indexed below with positions computed from the
# classifier's test set. That is only valid if both share the same order, so
# warn (without stopping the analysis) when the labels disagree.
if not np.array_equal(y_test_orig, true_classes):
    print("⚠️ Raw MNIST test labels do not match the classifier's test labels; "
          "the images shown below may not correspond to the listed predictions.")

misclassified_indices = np.where(predicted_classes != true_classes)[0]
print(f"🔍 Found {len(misclassified_indices)} misclassified examples")

# Show up to 12 errors, ordered by ascending confidence (least confident first)
if len(misclassified_indices) > 0:
    error_confidences = prediction_confidences[misclassified_indices]
    worst_errors_idx = misclassified_indices[np.argsort(error_confidences)[:12]]

    fig, axes = plt.subplots(3, 4, figsize=(16, 12))
    fig.suptitle('Most Difficult Misclassified Examples (Lowest Confidence)', fontsize=16, fontweight='bold')

    panels = axes.ravel()  # flat, row-major view of the 3x4 grid
    for ax, idx in zip(panels, worst_errors_idx):
        ax.imshow(x_test_orig[idx], cmap='gray')
        # A real "\n" escape puts the confidence on its own title line
        # (the previous "\\n" rendered a literal backslash-n).
        ax.set_title(
            f'True: {true_classes[idx]}, Pred: {predicted_classes[idx]}\nConf: {prediction_confidences[idx]:.3f}',
            color='red', fontweight='bold')
        ax.axis('off')

    # With fewer than 12 errors, blank the leftover panels instead of leaving
    # empty tick frames in the figure.
    for ax in panels[len(worst_errors_idx):]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()
else:
    print("🎉 Perfect classification! No errors found.")


In [None]:
## Sample Predictions Visualization


In [None]:
# Visualize sample predictions using the classifier's own helper
classifier.visualize_predictions(num_samples=10)

# Closing summary. Relies on test_accuracy, true_classes,
# prediction_confidences and misclassified_indices from earlier cells, so the
# notebook must be run top to bottom.
print("🎉 Analysis completed!")
# "\n" (not "\\n") so a blank line is printed rather than a literal backslash-n
print("\n📝 Summary:")
print(f"- Model achieves {test_accuracy*100:.2f}% accuracy on test set")
print(f"- {len(misclassified_indices)} out of {len(true_classes)} examples misclassified")
print(f"- Average prediction confidence: {np.mean(prediction_confidences):.3f}")
# NOTE(review): this verdict is printed unconditionally, whatever the measured accuracy is.
print("\n✅ The model shows excellent performance on MNIST digit classification!")
