# MNIST Handwritten Digit Recognition Neural Network

**Description:** Educational implementation of a neural network that recognizes handwritten digits (0-9) using the MNIST dataset.

**Purpose:** Learn fundamental concepts of deep learning including data preprocessing, neural network architecture, training, and evaluation.

**Target Accuracy:** >95% on test set

**Author:** AI Development Course - Lesson 36

---

## How to Run on Google Colab
1. Go to **Runtime > Change runtime type** and select **T4 GPU** for faster training
2. Run each cell sequentially from top to bottom
3. Training takes under 1 minute on a GPU and roughly 2-5 minutes on a CPU

## Neural Network Architecture
```
Input:    784 neurons  (28x28 flattened image)
  |
Hidden 1: 128 neurons  (ReLU activation)
  |
Hidden 2:  64 neurons  (ReLU activation)
  |
Output:    10 neurons  (Softmax activation) -> digits 0-9
```
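
To make the diagram concrete, here is a minimal NumPy sketch of one forward pass through the same layer sizes. It is not the Keras model built later: the weights are random and untrained, and the helper names (`relu`, `softmax`, `layer_sizes`) are assumptions chosen for illustration. It only shows how the shapes change from layer to layer and that the softmax output sums to 1.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed so the sketch is reproducible

def relu(x):
    # ReLU(x) = max(0, x), applied element-wise
    return np.maximum(0.0, x)

def softmax(z):
    # Subtract the row maximum first for numerical stability
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

# Random, untrained weights: only the shapes matter in this sketch.
layer_sizes = [784, 128, 64, 10]
weights = [rng.normal(0.0, 0.05, size=(n_in, n_out))
           for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])]
biases = [np.zeros(n_out) for n_out in layer_sizes[1:]]

batch = rng.random((5, 784))  # 5 fake flattened "images" with values in [0, 1)
h1 = relu(batch @ weights[0] + biases[0])         # (5, 128)
h2 = relu(h1 @ weights[1] + biases[1])            # (5, 64)
probs = softmax(h2 @ weights[2] + biases[2])      # (5, 10)

print(h1.shape, h2.shape, probs.shape)
print(probs.sum(axis=1))  # each row sums to 1 (up to rounding)
```

Each row of `probs` plays the role of the 10 output neurons: one probability per digit.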

## Cell 1: Imports and GPU Configuration
This cell imports all required libraries and configures GPU acceleration.
GPU acceleration significantly speeds up neural network training (minutes -> seconds).

In [None]:
# ============================================
# CELL 1: Imports and GPU Configuration
# ============================================

# Core Deep Learning Framework
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

# Numerical Operations
import numpy as np

# Data Visualization
import matplotlib.pyplot as plt          # For plotting graphs and displaying images
import seaborn as sns                    # For beautiful confusion matrix heatmaps

# Machine Learning Utilities
from sklearn.metrics import confusion_matrix  # For evaluating prediction errors

# Image Processing (for custom predictions)
from PIL import Image                    # For loading and preprocessing custom images

# System Utilities
import os
import warnings
warnings.filterwarnings('ignore')        # Suppress warnings for cleaner output

print("=" * 60)
print("MNIST HANDWRITTEN DIGIT RECOGNITION")
print("=" * 60)
print()

# ============================================
# GPU CONFIGURATION AND VERIFICATION
# ============================================
# WHY GPU? Neural networks perform millions of matrix operations. GPUs are designed
# for parallel processing and can train large models 10-100x faster than CPUs;
# a small network like this one sees a more modest speedup.
#
# What we're doing:
# 1. Check if GPU is available
# 2. Configure GPU memory growth (prevents out-of-memory errors)
# 3. Print GPU information
# ============================================

print("Checking GPU availability...")
print("-" * 60)

# Get list of all available GPUs
gpus = tf.config.list_physical_devices('GPU')

if gpus:
    try:
        # Configure GPU memory growth
        # WHY? By default, TensorFlow allocates all GPU memory at once.
        # Memory growth allows TensorFlow to allocate memory as needed,
        # preventing crashes and allowing multiple programs to use the GPU.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)

        # Print GPU information
        print(f"\u2713 GPU DETECTED: {len(gpus)} GPU(s) available")
        for i, gpu in enumerate(gpus):
            print(f"  GPU {i}: {gpu.name}")
        print()
        print("GPU will be used for training (typically much faster than CPU)")
        print("Expected training time: < 1 minute for 10 epochs")

    except RuntimeError as e:
        print(f"GPU configuration error: {e}")
        print("Falling back to CPU...")
else:
    print("\u26a0 NO GPU DETECTED - Using CPU")
    print("Training will be slower (~2-5 minutes for 10 epochs)")
    print()
    print("To enable GPU in Colab:")
    print("  Go to Runtime > Change runtime type > T4 GPU")

print("=" * 60)

## Cell 2: Load MNIST Dataset
MNIST (Modified National Institute of Standards and Technology) is a classic dataset of 70,000 handwritten digit images (0-9).

**Dataset Structure:**
- Training set: 60,000 images (used to train the model)
- Test set: 10,000 images (used to evaluate final performance)
- Image format: 28x28 pixels, grayscale (1 channel)
- Pixel values: 0-255 (0=black, 255=white)
- Labels: 0-9 (which digit the image represents)

In [None]:
# ============================================
# CELL 2: Load MNIST Dataset
# ============================================

print("LOADING MNIST DATASET")
print("=" * 60)

# Load the MNIST dataset
# This function returns 4 arrays:
# - X_train: Training images (60000, 28, 28)
# - y_train: Training labels (60000,)
# - X_test: Test images (10000, 28, 28)
# - y_test: Test labels (10000,)
(X_train, y_train), (X_test, y_test) = keras.datasets.mnist.load_data()

# Display dataset information
print(f"\u2713 Dataset loaded successfully!")
print()
print("Dataset Statistics:")
print("-" * 60)
print(f"Training Images: {X_train.shape[0]:,} samples")
print(f"  Shape: {X_train.shape} (samples, height, width)")
print(f"  Pixel value range: {X_train.min()} to {X_train.max()}")
print()
print(f"Training Labels: {y_train.shape[0]:,} samples")
print(f"  Label range: {y_train.min()} to {y_train.max()} (digits 0-9)")
print()
print(f"Test Images: {X_test.shape[0]:,} samples")
print(f"  Shape: {X_test.shape}")
print()
print(f"Test Labels: {y_test.shape[0]:,} samples")
print()
print("Each image:")
print(f"  - Resolution: 28\u00d728 pixels (784 total pixels)")
print(f"  - Color: Grayscale (single channel)")
print(f"  - Format: White digit on black background")
print("=" * 60)

## Cell 3: Preview Sample Images (Digits 6-9)
Visualizing the data helps us understand what we're working with.

**Why visualize?**
- Verify data loaded correctly
- Understand image quality and variation
- See what the model needs to learn to distinguish

In [None]:
# ============================================
# CELL 3: Preview Sample Images (Digits 6-9)
# ============================================

print("PREVIEWING SAMPLE DIGITS (6, 7, 8, 9)")
print("=" * 60)

# Create a figure with 2x2 subplots
fig, axes = plt.subplots(2, 2, figsize=(10, 10))
fig.suptitle('Sample MNIST Digits: 6, 7, 8, 9', fontsize=16, fontweight='bold')

# Digits we want to display
target_digits = [6, 7, 8, 9]

# Find and display one sample of each digit
for idx, digit in enumerate(target_digits):
    # Find the first occurrence of this digit in training set
    # np.where returns indices where condition is True
    sample_index = np.where(y_train == digit)[0][0]

    # Get the image
    sample_image = X_train[sample_index]

    # Calculate subplot position (row, col)
    row = idx // 2  # 0, 0, 1, 1
    col = idx % 2   # 0, 1, 0, 1

    # Display the image
    axes[row, col].imshow(sample_image, cmap='gray')
    axes[row, col].set_title(f'Digit: {digit}', fontsize=14, fontweight='bold')
    axes[row, col].axis('off')  # Hide axis for cleaner display

plt.tight_layout()
plt.show()

print("\u2713 Sample images displayed successfully")
print("  Notice: Images are 28\u00d728 pixels, grayscale")
print("  White pixels = digit, Black pixels = background")
print("=" * 60)

## Cell 4: Data Preprocessing
Raw data needs preprocessing before it is fed to a neural network. There are three critical steps:

1. **Normalization** - Scale pixel values from [0-255] to [0.0-1.0]
2. **Flattening** - Convert 2D images (28x28) to 1D vectors (784)
3. **One-Hot Encoding** - Convert integer labels to categorical vectors

### Why each step matters:
- **Normalization:** Large input values can cause unstable gradients and slow convergence
- **Flattening:** Dense layers require 1D input vectors
- **One-Hot Encoding:** Gives each label the same shape as the 10-way softmax output, as the categorical crossentropy loss expects, and avoids treating labels as ordered values (e.g., 9 > 1)
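
The three steps can be sketched on a tiny synthetic batch before running them on the real data. The array below is a hypothetical stand-in for MNIST (random pixels, made-up labels), and `np.eye` is used here as a plain-NumPy substitute for `keras.utils.to_categorical`.

```python
import numpy as np

# Hypothetical stand-in for MNIST: 4 random 28x28 uint8 images and 4 labels.
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)
labels = np.array([3, 0, 9, 3])

# 1. Normalization: uint8 [0, 255] -> float32 [0.0, 1.0]
x = images.astype("float32") / 255.0

# 2. Flattening: (4, 28, 28) -> (4, 784)
x = x.reshape(len(x), -1)

# 3. One-hot encoding: row k of the 10x10 identity matrix encodes label k
one_hot = np.eye(10, dtype="float32")[labels]

print(x.shape, x.dtype, float(x.min()), float(x.max()))
print(one_hot[0])  # the single 1 sits at index 3
```

Indexing the identity matrix with the label array picks one row per sample, which is why the `1` always lands at the position of the digit class.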

In [None]:
# ============================================
# CELL 4: Data Preprocessing
# ============================================

print("DATA PREPROCESSING")
print("=" * 60)
print()

# ============================================
# STEP 4.1: NORMALIZATION
# ============================================
# WHY NORMALIZE?
# - Original pixel values: 0-255 (integers)
# - Neural networks learn better with smaller values (0-1 range)
# - Large values can cause:
#   1. Unstable gradients during training
#   2. Slow convergence
#   3. Numerical overflow/underflow
#
# HOW? Divide all pixel values by 255.0
# Result: Values now in range [0.0, 1.0]
# ============================================

print("Step 1: NORMALIZATION")
print("-" * 60)
print("Converting pixel values from [0-255] to [0.0-1.0]")
print()
print(f"Before normalization:")
print(f"  Min value: {X_train.min()}, Max value: {X_train.max()}")
print(f"  Data type: {X_train.dtype}")

# Normalize by dividing by 255.0
X_train = X_train.astype('float32') / 255.0
X_test = X_test.astype('float32') / 255.0

print()
print(f"After normalization:")
print(f"  Min value: {X_train.min():.4f}, Max value: {X_train.max():.4f}")
print(f"  Data type: {X_train.dtype}")
print("\u2713 Normalization complete")
print()

# ============================================
# STEP 4.2: FLATTENING
# ============================================
# WHY FLATTEN?
# - Current shape: (60000, 28, 28) - 2D images
# - Dense layers need 1D input: (60000, 784)
# - 28 x 28 = 784 pixels per image
# ============================================

print("Step 2: FLATTENING")
print("-" * 60)
print("Converting 2D images (28\u00d728) to 1D vectors (784)")
print()
print(f"Before flattening:")
print(f"  Training shape: {X_train.shape}")
print(f"  Test shape: {X_test.shape}")

# Reshape from (num_samples, 28, 28) to (num_samples, 784)
X_train = X_train.reshape(X_train.shape[0], 28 * 28)
X_test = X_test.reshape(X_test.shape[0], 28 * 28)

print()
print(f"After flattening:")
print(f"  Training shape: {X_train.shape}")
print(f"  Test shape: {X_test.shape}")
print(f"  Each image is now a vector of {X_train.shape[1]} pixel values")
print("\u2713 Flattening complete")
print()

# ============================================
# STEP 4.3: ONE-HOT ENCODING
# ============================================
# WHY ONE-HOT ENCODE?
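# - Current labels: Single integers (0, 1, 2, ..., 9)
# - Problem: categorical_crossentropy expects one target vector per sample that
#   matches the 10 softmax outputs, and a raw integer would imply an ordering
#   (9 > 1) that has no meaning for digit classes
# - Solution: One-hot encoding - each class gets its own dimension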
#
# How it works:
# - Label 0 -> [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
# - Label 1 -> [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]
# - Label 3 -> [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
# - Label 9 -> [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
# ============================================

print("Step 3: ONE-HOT ENCODING")
print("-" * 60)
print("Converting integer labels to categorical vectors")
print()
print(f"Before one-hot encoding:")
print(f"  Training labels shape: {y_train.shape}")
print(f"  Sample labels: {y_train[:10]}")

# Convert labels to one-hot encoding
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)

print()
print(f"After one-hot encoding:")
print(f"  Training labels shape: {y_train.shape}")
print(f"  Sample label (first image):")
print(f"    {y_train[0]}")
print(f"    \u2191 Position with '1' indicates the digit class")
print("\u2713 One-hot encoding complete")
print()
print("=" * 60)
print("\u2713 ALL PREPROCESSING COMPLETE")
print("=" * 60)

## Cell 5: Build Neural Network Architecture

We build a fully connected (Dense) neural network:

| Layer | Neurons | Activation | Parameters | Purpose |
|-------|---------|------------|------------|---------|
| Hidden 1 | 128 | ReLU | 100,480 | Feature extraction |
| Hidden 2 | 64 | ReLU | 8,256 | Feature combination |
| Output | 10 | Softmax | 650 | Classification |

**Why this architecture?**
- **Funnel Shape (784->128->64->10):** Forces the network to learn compressed features
- **ReLU:** Cheap to compute and helps mitigate vanishing gradients
- **Softmax:** Outputs probabilities summing to 1.0
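
The parameter counts in the table follow from one formula: a Dense layer has `inputs x units` weights plus one bias per unit. A short pure-Python check (the helper name `dense_params` is made up for this sketch):

```python
def dense_params(n_inputs, n_units):
    # One weight per (input, unit) pair, plus one bias per unit
    return n_inputs * n_units + n_units

layer_sizes = [784, 128, 64, 10]
per_layer = [dense_params(n_in, n_out)
             for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])]
total = sum(per_layer)

print(per_layer)  # [100480, 8256, 650]
print(total)      # 109386
```

The total should match the figure reported by `model.summary()` in the code cell.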

In [None]:
# ============================================
# CELL 5: Build Neural Network Architecture
# ============================================

print("BUILDING NEURAL NETWORK ARCHITECTURE")
print("=" * 60)
print()

# Create Sequential model (layers stacked in sequence)
model = Sequential([

    # HIDDEN LAYER 1: 128 neurons with ReLU
    # Input: 784 features (flattened 28x28 image)
    # Output: 128 features
    # Parameters: 784 x 128 + 128 (bias) = 100,480
    #
    # WHY ReLU (Rectified Linear Unit)?
    # - Formula: ReLU(x) = max(0, x)
    # - Fast to compute (just max operation)
    # - Helps mitigate the vanishing gradient problem
    # - Introduces non-linearity (allows learning complex patterns)
    Dense(128, activation='relu', input_shape=(784,), name='hidden_layer_1'),

    # HIDDEN LAYER 2: 64 neurons with ReLU
    # Input: 128 features from previous layer
    # Output: 64 features
    # Parameters: 128 x 64 + 64 (bias) = 8,256
    #
    # Layer 1: Low-level features (edges, strokes)
    # Layer 2: High-level features (curves, loops, digit shapes)
    Dense(64, activation='relu', name='hidden_layer_2'),

    # OUTPUT LAYER: 10 neurons with Softmax
    # Input: 64 features from previous layer
    # Output: 10 probabilities (one per digit 0-9)
    # Parameters: 64 x 10 + 10 (bias) = 650
    #
    # WHY Softmax?
    # - Converts raw scores to probabilities (sum = 1.0)
    # - Perfect for multi-class classification
    # - Example: [0.01, 0.02, 0.03, 0.89, ...] -> 89% confident it's digit 3
    Dense(10, activation='softmax', name='output_layer')
])

print("\u2713 Model architecture created")
print()
print("ARCHITECTURE SUMMARY:")
print("-" * 60)
print("Layer Structure (Funnel Design):")
print("  Input:    784 neurons  (28\u00d728 flattened image)")
print("           \u2193")
print("  Hidden 1: 128 neurons  (ReLU) - Feature extraction")
print("           \u2193")
print("  Hidden 2:  64 neurons  (ReLU) - Feature combination")
print("           \u2193")
print("  Output:    10 neurons  (Softmax) - Classification")
print()

# Display detailed model summary
print("DETAILED MODEL SUMMARY:")
print("-" * 60)
model.summary()

total_params = model.count_params()
print()
print(f"Total Parameters: {total_params:,}")
print("  These are the 'weights' that will be learned during training")
print("=" * 60)

## Cell 6: Compile Model
Compilation configures the learning process:

1. **Optimizer (Adam):** Algorithm that updates weights - combines momentum with adaptive learning rates
2. **Loss (Categorical Crossentropy):** Measures how wrong predictions are - `Loss = -sum(y_true * log(y_pred))`
3. **Metrics (Accuracy):** Performance measure to track - correct predictions / total predictions
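
The loss formula can be checked by hand with NumPy. This is a sketch of the formula only, not Keras's internal implementation; the clipping constant `eps` is an assumption added to avoid `log(0)`, and the two probability vectors are made up so that the true class (digit 3) receives 0.85 and 0.01 respectively.

```python
import numpy as np

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    # Loss = -sum(y_true * log(y_pred)); clip to keep log() finite
    y_pred = np.clip(y_pred, eps, 1.0)
    return -np.sum(y_true * np.log(y_pred), axis=-1)

y_true = np.eye(10)[3]  # one-hot vector for digit 3

# Made-up predictions: remaining probability mass is spread evenly
confident = np.full(10, 0.15 / 9)
confident[3] = 0.85
wrong = np.full(10, 0.99 / 9)
wrong[3] = 0.01

good_loss = categorical_crossentropy(y_true, confident)
bad_loss = categorical_crossentropy(y_true, wrong)
print(f"{good_loss:.3f}")  # 0.163
print(f"{bad_loss:.3f}")   # 4.605
```

Because `y_true` is one-hot, only the probability assigned to the correct class contributes to the sum, so the loss reduces to `-log(p_correct)`.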

In [None]:
# ============================================
# CELL 6: Compile Model
# ============================================

print("COMPILING THE MODEL")
print("=" * 60)
print()

# ============================================
# OPTIMIZER: ADAM (Adaptive Moment Estimation)
# ============================================
# WHY ADAM?
# - A widely used general-purpose optimizer for deep learning
# - Combines momentum + adaptive learning rates
# - Fast convergence with minimal hyperparameter tuning
#
# Default Hyperparameters:
# - learning_rate = 0.001
# - beta_1 = 0.9 (momentum decay)
# - beta_2 = 0.999 (variance decay)
# ============================================

print("1. OPTIMIZER: Adam")
print("-" * 60)
print("Adam (Adaptive Moment Estimation) is a widely used default optimizer")
print()
print("How it works:")
print("  - Standard SGD: weight = weight - learning_rate \u00d7 gradient")
print("  - Adam: weight = weight - adaptive_lr \u00d7 momentum_adjusted_gradient")
print()
print("Key features:")
print("  \u2713 Adaptive learning rates per parameter")
print("  \u2713 Momentum (remembers past gradients)")
print("  \u2713 Fast convergence")
print("  \u2713 Minimal tuning needed")
print()

# ============================================
# LOSS FUNCTION: CATEGORICAL CROSSENTROPY
# ============================================
# Formula: Loss = -sum(y_true * log(y_pred))
#
# Good prediction (digit 3):  Loss = -log(0.85) = 0.163 (LOW)
# Bad prediction (digit 3):   Loss = -log(0.01) = 4.605 (HIGH)
# ============================================

print("2. LOSS FUNCTION: Categorical Crossentropy")
print("-" * 60)
print("Formula: Loss = -\u03a3(y_true \u00d7 log(y_pred))")
print()
print("Example:")
print("  Good prediction: Loss = -log(0.85) = 0.163  \u2190 LOW (good!)")
print("  Bad prediction:  Loss = -log(0.01) = 4.605  \u2190 HIGH (bad!)")
print()

print("3. METRICS: Accuracy")
print("-" * 60)
print("Accuracy = Correct Predictions / Total Predictions")
print()

# Compile the model
model.compile(
    optimizer='adam',                    # Adam optimizer with default params
    loss='categorical_crossentropy',     # Loss function for multi-class
    metrics=['accuracy']                 # Track accuracy during training
)

print("=" * 60)
print("\u2713 MODEL COMPILED SUCCESSFULLY")
print("=" * 60)
print("Ready for training with:")
print("  Optimizer: Adam (learning_rate=0.001)")
print("  Loss: Categorical Crossentropy")
print("  Metric: Accuracy")
print("=" * 60)

## Cell 7: Train the Model

Training is the process where the model learns to recognize digits:
1. Show model a batch of images
2. Model makes predictions
3. Calculate loss (how wrong the predictions are)
4. Backpropagation: Calculate gradients
5. Optimizer updates weights to reduce loss
6. Repeat!

**Training Parameters:**
- **Epochs:** 10 (complete passes through training data)
- **Batch Size:** 32 (samples per weight update)
- **Validation Split:** 10% (6,000 of the 60,000 training images are held out to monitor overfitting, leaving 54,000 for weight updates)
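
These three settings determine how many weight updates the model performs. A rough back-of-the-envelope calculation, assuming Keras holds out the last 10% of the samples for validation and processes the final partial batch as its own step:

```python
import math

n_samples = 60_000
validation_split = 0.1
batch_size = 32
epochs = 10

n_val = round(n_samples * validation_split)         # images held out for validation
n_train = n_samples - n_val                         # images used for weight updates
steps_per_epoch = math.ceil(n_train / batch_size)   # last batch may be smaller
total_updates = steps_per_epoch * epochs

print(n_val, n_train)     # 6000 54000
print(steps_per_epoch)    # 1688
print(total_updates)      # 16880
```

The `steps_per_epoch` value is the number that appears next to the progress bar during training.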

In [None]:
# ============================================
# CELL 7: Train the Model
# ============================================

print("TRAINING THE MODEL")
print("=" * 60)
print()

# Training Parameters
EPOCHS = 10          # Number of complete passes through training data
BATCH_SIZE = 32      # Number of samples per weight update

print("Training Configuration:")
print("-" * 60)
print(f"Epochs: {EPOCHS}")
print(f"  \u2192 The model will see the training portion of the dataset {EPOCHS} times")
print()
print(f"Batch Size: {BATCH_SIZE}")
print(f"  \u2192 Process {BATCH_SIZE} images at a time before updating weights")
print()
print(f"Validation Split: 10%")
print("  \u2192 Use 10% of training data to monitor overfitting")
print()
print("=" * 60)
print("STARTING TRAINING...")
print("=" * 60)
print()

# Train the model
# history object stores metrics for each epoch
history = model.fit(
    X_train,                    # Training images (flattened, normalized)
    y_train,                    # Training labels (one-hot encoded)
    epochs=EPOCHS,              # Number of complete passes through data
    batch_size=BATCH_SIZE,      # Samples per gradient update
    validation_split=0.1,       # Use 10% of training data for validation
    verbose=1                   # Display progress bar
)

print()
print("=" * 60)
print("\u2713 TRAINING COMPLETE!")
print("=" * 60)
print()
print("What just happened?")
print(f"  - Model saw 54,000 training images {EPOCHS} times ({EPOCHS} epochs);")
print("    the other 6,000 were held out for validation")
print(f"  - Adjusted {model.count_params():,} parameters to minimize prediction error")
print("  - Loss decreased \u2192 Model learned patterns in digit images")
print("  - Accuracy increased \u2192 Model makes better predictions")
print()
print("Notice the trend:")
print("  - Early epochs: Big improvements (learning basic patterns)")
print("  - Later epochs: Small improvements (refining understanding)")
print("  - Validation accuracy close to training = good generalization")
print("=" * 60)

## Cell 8: Evaluate Model on Test Set

**Why test data?**
- Test set contains images the model has **NEVER** seen
- Measures real-world performance (generalization)
- Training accuracy can be misleading (overfitting)

**Overfitting vs. Generalization:**
- **Overfitting:** Model memorizes training data, fails on new data
- **Good generalization:** Model learns patterns, works on new data
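
One simple way to quantify this is the signed gap between training and test accuracy. The sketch below uses the 3% and 5% thresholds from the evaluation cell; the function name and the example accuracies are hypothetical, chosen only to exercise each branch.

```python
def generalization_check(train_acc, test_acc):
    # Positive gap: the model does better on data it has already seen
    gap = train_acc - test_acc
    if gap < 0.03:
        return gap, "good generalization"
    if gap < 0.05:
        return gap, "slight overfitting"
    return gap, "overfitting"

# Hypothetical accuracies, not real results from this notebook
for train_acc, test_acc in [(0.995, 0.975), (0.99, 0.95), (0.999, 0.93)]:
    gap, label = generalization_check(train_acc, test_acc)
    print(f"train={train_acc:.3f} test={test_acc:.3f} gap={gap:+.3f} -> {label}")
```

Keeping the sign matters: a test accuracy above the training accuracy is not a symptom of overfitting, so it should not trip the thresholds.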

In [None]:
# ============================================
# CELL 8: Evaluate Model
# ============================================

print("EVALUATING MODEL ON TEST SET")
print("=" * 60)
print()
print("Why evaluate on test set?")
print("  - Test data was NEVER seen during training")
print("  - Measures real-world performance")
print("  - Detects overfitting (memorization vs. learning)")
print()
print("Evaluating on 10,000 test images...")
print("-" * 60)

# Evaluate model on test set
test_loss, test_accuracy = model.evaluate(X_test, y_test, verbose=1)

print()
print("=" * 60)
print("TEST SET RESULTS")
print("=" * 60)
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_accuracy * 100:.2f}%")
print()

# Interpret results
if test_accuracy >= 0.97:
    print("\u2713 EXCELLENT! Model performs very well (\u226597%)")
elif test_accuracy >= 0.95:
    print("\u2713 GOOD! Model meets target accuracy (\u226595%)")
elif test_accuracy >= 0.90:
    print("\u26a0 ACCEPTABLE but below target (90-95%)")
else:
    print("\u2717 POOR performance (<90%) - Model needs improvement")

print()
print("What does this mean?")
print(f"  - Out of {len(X_test):,} test images, model correctly classifies")
print(f"    {round(test_accuracy * len(X_test)):,} images")
print(f"  - Error rate: {(1 - test_accuracy) * 100:.2f}%")
print()

# Compare training vs test accuracy (check for overfitting)
# Overfitting means training accuracy exceeds test accuracy, so keep the sign
# of the gap instead of taking its absolute value.
final_train_accuracy = history.history['accuracy'][-1]
accuracy_gap = final_train_accuracy - test_accuracy

print("Overfitting Check:")
print(f"  Final Training Accuracy: {final_train_accuracy * 100:.2f}%")
print(f"  Test Accuracy: {test_accuracy * 100:.2f}%")
print(f"  Difference (train - test): {accuracy_gap * 100:+.2f}%")

if accuracy_gap < 0.03:
    print("  \u2713 Good generalization (training exceeds test by < 3%)")
elif accuracy_gap < 0.05:
    print("  \u26a0 Slight overfitting (difference 3-5%)")
else:
    print("  \u2717 Overfitting detected (difference > 5%)")
    print("    Model may be memorizing training data instead of learning patterns")

print("=" * 60)

## Cell 9: Training History - Loss Curve

**How to interpret:**
- Both lines decreasing = Model is learning
- Lines close together = Good generalization
- Training << Validation = Overfitting
- Both lines flat/high = Underfitting
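
The same reading can be done numerically. The loss values below are hypothetical numbers invented for this sketch (they are not output from this notebook); with a real run, `history.history['loss']` and `history.history['val_loss']` would take their place.

```python
import numpy as np

# Hypothetical per-epoch losses, made up for illustration only
train_loss = np.array([0.30, 0.14, 0.10, 0.07, 0.05, 0.04])
val_loss = np.array([0.15, 0.11, 0.09, 0.085, 0.09, 0.10])

best_epoch = int(np.argmin(val_loss))   # 0-based index of lowest validation loss
gap = val_loss - train_loss             # a growing positive gap suggests overfitting
val_rising = bool(val_loss[-1] > val_loss[best_epoch])

print(f"Lowest validation loss at epoch {best_epoch + 1}")
print(f"Final gap (val - train): {gap[-1]:.3f}")
print(f"Validation loss rising after its minimum: {val_rising}")
```

In this made-up run the training loss keeps falling while the validation loss turns upward after epoch 4, which is the "Training << Validation" pattern described above.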

In [None]:
# ============================================
# CELL 9: Plot Training History (Loss Curve)
# ============================================

print("VISUALIZING TRAINING HISTORY - LOSS CURVE")
print("=" * 60)

plt.figure(figsize=(12, 5))

plt.plot(history.history['loss'], marker='o', linestyle='-', linewidth=2,
         label='Training Loss', color='#2E86AB')
plt.plot(history.history['val_loss'], marker='s', linestyle='--', linewidth=2,
         label='Validation Loss', color='#A23B72')

plt.title('Model Loss Over Training Epochs', fontsize=16, fontweight='bold', pad=20)
plt.xlabel('Epoch', fontsize=12, fontweight='bold')
plt.ylabel('Loss (Categorical Crossentropy)', fontsize=12, fontweight='bold')
plt.legend(loc='upper right', fontsize=11, frameon=True, shadow=True)
plt.grid(True, alpha=0.3, linestyle='--')
plt.xticks(range(EPOCHS), range(1, EPOCHS + 1))  # Label epochs 1..EPOCHS instead of 0..EPOCHS-1
plt.tight_layout()
plt.show()

print()
print("How to interpret this graph:")
print("-" * 60)
print("\u2713 Both lines decreasing = Model is learning")
print("\u2713 Lines close together = Good generalization")
print("\u2717 Training << Validation = Overfitting")
print("\u2717 Both lines flat/high = Underfitting")
print("=" * 60)

## Cell 10: Training History - Accuracy Curve

Accuracy is easier to interpret than loss (higher = better).

In [None]:
# ============================================
# CELL 10: Plot Training History (Accuracy Curve)
# ============================================

print("VISUALIZING TRAINING HISTORY - ACCURACY CURVE")
print("=" * 60)

plt.figure(figsize=(12, 5))

plt.plot(history.history['accuracy'], marker='o', linestyle='-', linewidth=2,
         label='Training Accuracy', color='#06A77D')
plt.plot(history.history['val_accuracy'], marker='s', linestyle='--', linewidth=2,
         label='Validation Accuracy', color='#D4526E')

plt.title('Model Accuracy Over Training Epochs', fontsize=16, fontweight='bold', pad=20)
plt.xlabel('Epoch', fontsize=12, fontweight='bold')
plt.ylabel('Accuracy', fontsize=12, fontweight='bold')
plt.legend(loc='lower right', fontsize=11, frameon=True, shadow=True)
plt.grid(True, alpha=0.3, linestyle='--')
plt.xticks(range(EPOCHS), range(1, EPOCHS + 1))  # Label epochs 1..EPOCHS instead of 0..EPOCHS-1
plt.ylim([0.9, 1.0])  # Focus on 90-100% range
plt.tight_layout()
plt.show()

print()
print("How to interpret this graph:")
print("-" * 60)
print("\u2713 Both lines increasing = Model is learning")
print("\u2713 Lines close together = Good generalization")
print("\u2717 Training >> Validation = Overfitting")
print("\u2022 Rapid improvement in early epochs (basic patterns)")
print("\u2022 Slower improvement in later epochs (fine-tuning)")
print("=" * 60)

## Cell 11: Confusion Matrix

A confusion matrix shows which digits the model confuses with each other:
- **Rows:** Actual/True labels
- **Columns:** Predicted labels
- **Diagonal:** Correct predictions (should be highest)
- **Off-diagonal:** Misclassifications (errors)

Common confusions: 4<->9, 3<->8, 5<->6, 7<->1
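
To see how the rows and columns are filled, here is a small NumPy sketch that builds a confusion matrix by hand. The eight label pairs are invented for illustration; the notebook itself uses `sklearn.metrics.confusion_matrix` on the real test set.

```python
import numpy as np

# Made-up true and predicted labels for 8 samples
y_true = np.array([4, 4, 9, 9, 9, 3, 8, 4])
y_pred = np.array([4, 9, 9, 9, 4, 3, 8, 9])

n_classes = 10
cm = np.zeros((n_classes, n_classes), dtype=int)
np.add.at(cm, (y_true, y_pred), 1)  # row = true digit, column = predicted digit

# Most frequent misclassification: zero out the diagonal, then take the argmax
off_diag = cm.copy()
np.fill_diagonal(off_diag, 0)
true_digit, predicted_digit = (int(i) for i in
                               np.unravel_index(np.argmax(off_diag), off_diag.shape))

# Per-digit accuracy; digits with no samples stay NaN
row_totals = cm.sum(axis=1)
recall = np.divide(cm.diagonal(), row_totals,
                   out=np.full(n_classes, np.nan), where=row_totals > 0)

print(f"Most common confusion: {true_digit} predicted as {predicted_digit} "
      f"({cm[true_digit, predicted_digit]} times)")
print(f"Accuracy for digit 4: {recall[4]:.2f}")
```

Here the digit 4 is predicted as 9 twice, so the entry in row 4, column 9 is the largest off-diagonal cell.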

In [None]:
# ============================================
# CELL 11: Confusion Matrix
# ============================================

print("GENERATING CONFUSION MATRIX")
print("=" * 60)
print()

# Generate predictions on test set
y_pred_probs = model.predict(X_test, verbose=0)

# Convert predictions from probabilities to class labels
y_pred_classes = np.argmax(y_pred_probs, axis=1)

# Convert true labels from one-hot to class labels
y_true_classes = np.argmax(y_test, axis=1)

# Create confusion matrix
cm = confusion_matrix(y_true_classes, y_pred_classes)

print(f"\u2713 Generated predictions for {len(y_test):,} test images")
print()

# Visualize confusion matrix as heatmap
plt.figure(figsize=(12, 10))

sns.heatmap(cm,
            annot=True,              # Show numbers in cells
            fmt='d',                 # Integer format
            cmap='Blues',            # Color scheme
            xticklabels=range(10),
            yticklabels=range(10),
            cbar_kws={'label': 'Number of Predictions'},
            linewidths=0.5,
            linecolor='gray')

plt.title('Confusion Matrix - MNIST Test Set (10,000 images)',
          fontsize=16, fontweight='bold', pad=20)
plt.xlabel('Predicted Digit', fontsize=13, fontweight='bold')
plt.ylabel('True Digit', fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

print()
print("How to read the confusion matrix:")
print("-" * 60)
print("Diagonal: Correct predictions (darker = more correct)")
print("Off-diagonal: Misclassifications (errors)")
print()

# Find most common confusion
max_confusion = 0
confused_pair = (0, 0)
for i in range(10):
    for j in range(10):
        if i != j and cm[i, j] > max_confusion:
            max_confusion = cm[i, j]
            confused_pair = (i, j)

print("Most common confusion:")
print(f"  Digit {confused_pair[0]} misclassified as {confused_pair[1]}: "
      f"{max_confusion} times")
print()

# Per-digit accuracy
print("Per-digit accuracy:")
print("-" * 60)
for digit in range(10):
    correct = cm[digit, digit]
    total = cm[digit].sum()
    digit_accuracy = correct / total * 100
    print(f"  Digit {digit}: {digit_accuracy:.2f}% ({correct}/{total} correct)")

print()
print("Common confusions to look for:")
print("  - 4 \u2194 9 (similar top parts)")
print("  - 3 \u2194 8 (similar curves)")
print("  - 5 \u2194 6 (similar shapes)")
print("  - 7 \u2194 1 (similar strokes)")
print("=" * 60)

## Cell 12: Custom Image Prediction Functions

These functions allow you to predict digits from your own handwritten images.

**Pipeline:**
1. Load image -> Convert to grayscale -> Resize to 28x28
2. Normalize to [0,1] -> Flatten to 784 -> Predict

**Image requirements:**
- Format: PNG, JPG, or JPEG
- Content: Single handwritten digit (0-9)
- Recommended: White digit on black background

In [None]:
# ============================================
# CELL 12: Custom Image Prediction Functions
# ============================================

print("CUSTOM IMAGE PREDICTION FUNCTIONS")
print("=" * 60)
print()

def preprocess_image(image_path):
    """
    Preprocess a custom image to match MNIST format.

    Pipeline:
    1. Load image from file
    2. Convert to grayscale (if RGB)
    3. Resize to 28x28 pixels
    4. Optionally invert colors (manual step, commented out below; MNIST = white on black)
    5. Normalize to [0, 1]
    6. Flatten to 784-element vector
    7. Reshape to (1, 784) for model input

    Args:
        image_path (str): Path to image file

    Returns:
        numpy.ndarray: Preprocessed image ready for prediction (1, 784)
    """
    try:
        img = Image.open(image_path)
        img = img.convert('L')  # Grayscale
        img = img.resize((28, 28), Image.Resampling.LANCZOS)
        img_array = np.array(img)

        # Invert colors if needed (MNIST = white on black)
        # Uncomment if your image has black digit on white background:
        # img_array = 255 - img_array

        img_array = img_array.astype('float32') / 255.0
        img_array = img_array.reshape(1, 784)
        return img_array

    except FileNotFoundError:
        print(f"Error: File not found - {image_path}")
        return None
    except Exception as e:
        print(f"Error preprocessing image: {e}")
        return None


def predict_digit(model, image_path):
    """
    Predict the digit in a custom image.

    Args:
        model: Trained Keras model
        image_path (str): Path to image file

    Returns:
        tuple: (predicted_digit, confidence, all_probabilities)
    """
    img_array = preprocess_image(image_path)
    if img_array is None:
        return None, None, None

    predictions = model.predict(img_array, verbose=0)
    predicted_digit = np.argmax(predictions[0])
    confidence = predictions[0][predicted_digit] * 100
    all_probabilities = predictions[0]

    return predicted_digit, confidence, all_probabilities


def display_prediction(image_path, digit, confidence, probabilities):
    """
    Display the input image with prediction results.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Left: Original image
    img = Image.open(image_path).convert('L')
    ax1.imshow(img, cmap='gray')
    ax1.set_title('Input Image', fontsize=14, fontweight='bold')
    ax1.axis('off')

    # Right: Probability bar chart
    ax2.bar(range(10), probabilities * 100, color='steelblue', alpha=0.7)
    ax2.bar(digit, probabilities[digit] * 100, color='crimson', alpha=0.9)
    ax2.set_xlabel('Digit', fontsize=12, fontweight='bold')
    ax2.set_ylabel('Probability (%)', fontsize=12, fontweight='bold')
    ax2.set_title(f'Prediction: {digit} (Confidence: {confidence:.2f}%)',
                  fontsize=14, fontweight='bold')
    ax2.set_xticks(range(10))
    ax2.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.show()

    print()
    print("=" * 60)
    print("PREDICTION RESULTS")
    print("=" * 60)
    print(f"Predicted Digit: {digit}")
    print(f"Confidence: {confidence:.2f}%")
    print()
    print("All class probabilities:")
    for i, prob in enumerate(probabilities):
        marker = " \u2190 PREDICTED" if i == digit else ""
        print(f"  Digit {i}: {prob * 100:5.2f}%{marker}")
    print("=" * 60)


print("\u2713 Custom prediction functions defined:")
print("  - preprocess_image(image_path)")
print("  - predict_digit(model, image_path)")
print("  - display_prediction(image_path, digit, confidence, probabilities)")
print()
print("=" * 60)

## Cell 13: Test with a Custom Image (Google Colab)

**3 ways to test your own digit:**

| Method | How | Best For |
|--------|-----|----------|
| **A - Upload** | Click "Choose Files" button | Testing images from your computer |
| **B - Draw** | Draw a digit with your mouse directly in the notebook | Quick testing, no files needed |
| **C - Test Set** | Automatically picks a random test image | Quick demo, no input needed |

**Important:** Most handwritten images have a **black digit on white background**, but MNIST expects **white digit on black background**. The code auto-inverts colors for you.
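
As an illustration of what automatic inversion can look like, here is one possible heuristic. It is a sketch under stated assumptions, not necessarily the logic used in the cell below: the function name `to_mnist_polarity` and the threshold of 127 are hypothetical, and the rule assumes the background covers most of the image.

```python
import numpy as np

def to_mnist_polarity(img_array):
    """Return an image with a dark background and a bright digit."""
    img_array = np.asarray(img_array, dtype=np.uint8)
    # Assumption: the background dominates, so a high mean means a light background
    if img_array.mean() > 127:
        return 255 - img_array
    return img_array

# Synthetic "scan": white paper with a dark vertical stroke (roughly a 1)
paper = np.full((28, 28), 255, dtype=np.uint8)
paper[6:22, 13:15] = 0

fixed = to_mnist_polarity(paper)
print(fixed[0, 0], fixed[10, 13])  # 0 255 -> black background, white stroke

# An image that is already white-on-black passes through unchanged
print(np.array_equal(to_mnist_polarity(fixed), fixed))  # True
```

A mean-based rule like this can misfire on images where the digit fills most of the frame, so it is worth displaying the image as the model sees it before trusting a prediction.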

In [None]:
# ============================================
# CELL 13: Test with Custom Image - 3 Methods
# ============================================
# Method A: Upload image from your computer
# Method B: Draw a digit directly in the notebook
# Method C: Test with a random image from the test set
# ============================================

import io
import base64

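def predict_from_array(img_array_2d):
    """
    Predict the digit in a 28x28 numpy array and plot the result.

    Handles preprocessing (normalize, flatten) and uses the global `model`
    trained in the earlier cells.
    Returns (predicted_digit, confidence), with confidence in percent.
    """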
    # Show the image as the model sees it
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    ax1.imshow(img_array_2d, cmap='gray')
    ax1.set_title('Image (as model sees it)', fontsize=14, fontweight='bold')
    ax1.axis('off')

    # Normalize to 0-1 (skipped if the array is already scaled) and flatten
    img_flat = img_array_2d.astype('float32')
    if img_flat.max() > 1:
        img_flat /= 255.0
    img_flat = img_flat.reshape(1, 784)

    # Predict
    predictions = model.predict(img_flat, verbose=0)
    predicted_digit = np.argmax(predictions[0])
    confidence = predictions[0][predicted_digit] * 100

    # Probability bar chart
    ax2.bar(range(10), predictions[0] * 100, color='steelblue', alpha=0.7)
    ax2.bar(predicted_digit, predictions[0][predicted_digit] * 100, color='crimson', alpha=0.9)
    ax2.set_xlabel('Digit', fontsize=12, fontweight='bold')
    ax2.set_ylabel('Probability (%)', fontsize=12, fontweight='bold')
    ax2.set_title(f'Prediction: {predicted_digit} ({confidence:.1f}%)',
                  fontsize=14, fontweight='bold')
    ax2.set_xticks(range(10))
    ax2.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.show()

    print(f"\nPredicted Digit: {predicted_digit}")
    print(f"Confidence: {confidence:.2f}%")
    print()
    print("All class probabilities:")
    for i, prob in enumerate(predictions[0]):
        marker = " <-- PREDICTED" if i == predicted_digit else ""
        print(f"  Digit {i}: {prob * 100:5.2f}%{marker}")

    return predicted_digit, confidence


# ============================================
# Choose your method below!
# Method C runs by default. To use Method A or B instead,
# uncomment its block (or run Cell 13A / 13B further down).
# ============================================

# ==========================================
# METHOD A: Upload image from your computer
# ==========================================
# Works in Google Colab - shows a "Choose Files" button.
# Supports PNG, JPG, JPEG images.

# try:
#     from google.colab import files
#     print("METHOD A: Upload Image")
#     print("=" * 60)
#     print("Click 'Choose Files' below to upload a handwritten digit image:")
#     print()
#     uploaded = files.upload()
#
#     for filename in uploaded.keys():
#         print(f"\nProcessing: {filename}")
#         print("-" * 60)
#
#         # Load and preprocess
#         img = Image.open(filename).convert('L')        # Convert to grayscale
#         img = img.resize((28, 28), Image.Resampling.LANCZOS)  # Resize to 28x28
#         img_array = np.array(img)
#
#         # Auto-invert: if background is bright (white paper), invert colors
#         # MNIST expects white digit on black background
#         if img_array.mean() > 127:
#             print("  (Auto-inverted colors: black-on-white -> white-on-black)")
#             img_array = 255 - img_array
#
#         predict_from_array(img_array)
#
# except ImportError:
#     print("Not running on Google Colab. Use Method B or C instead.")


# ==========================================
# METHOD B: Draw a digit with your mouse
# ==========================================
# Creates an interactive canvas in Google Colab.
# Draw a digit, then click "Predict" button.

# try:
#     from google.colab import output
#     from IPython.display import HTML, display
#
#     print("METHOD B: Draw a Digit")
#     print("=" * 60)
#     print("Draw a digit (0-9) in the black box below, then click 'Predict'")
#     print()
#
#     canvas_html = """
#     <div style="text-align:center">
#     <canvas id="canvas" width="280" height="280"
#             style="border:2px solid #666; cursor:crosshair; background:black;"></canvas>
#     <br><br>
#     <button onclick="predict()" style="padding:10px 30px; font-size:16px;
#             background:#4CAF50; color:white; border:none; border-radius:5px;
#             cursor:pointer; margin:5px;">Predict</button>
#     <button onclick="clearCanvas()" style="padding:10px 30px; font-size:16px;
#             background:#f44336; color:white; border:none; border-radius:5px;
#             cursor:pointer; margin:5px;">Clear</button>
#     </div>
#
#     <script>
#     var canvas = document.getElementById('canvas');
#     var ctx = canvas.getContext('2d');
#     var drawing = false;
#
#     // Set up drawing style (white pen on black background)
#     ctx.fillStyle = 'black';
#     ctx.fillRect(0, 0, 280, 280);
#     ctx.strokeStyle = 'white';
#     ctx.lineWidth = 18;
#     ctx.lineCap = 'round';
#     ctx.lineJoin = 'round';
#
#     canvas.addEventListener('mousedown', function(e) {
#         drawing = true;
#         ctx.beginPath();
#         var rect = canvas.getBoundingClientRect();
#         ctx.moveTo(e.clientX - rect.left, e.clientY - rect.top);
#     });
#
#     canvas.addEventListener('mousemove', function(e) {
#         if (drawing) {
#             var rect = canvas.getBoundingClientRect();
#             ctx.lineTo(e.clientX - rect.left, e.clientY - rect.top);
#             ctx.stroke();
#         }
#     });
#
#     canvas.addEventListener('mouseup', function() { drawing = false; });
#     canvas.addEventListener('mouseleave', function() { drawing = false; });
#
#     // Touch support for mobile/tablet
#     canvas.addEventListener('touchstart', function(e) {
#         e.preventDefault();
#         drawing = true;
#         ctx.beginPath();
#         var rect = canvas.getBoundingClientRect();
#         var touch = e.touches[0];
#         ctx.moveTo(touch.clientX - rect.left, touch.clientY - rect.top);
#     });
#
#     canvas.addEventListener('touchmove', function(e) {
#         e.preventDefault();
#         if (drawing) {
#             var rect = canvas.getBoundingClientRect();
#             var touch = e.touches[0];
#             ctx.lineTo(touch.clientX - rect.left, touch.clientY - rect.top);
#             ctx.stroke();
#         }
#     });
#
#     canvas.addEventListener('touchend', function() { drawing = false; });
#
#     function clearCanvas() {
#         ctx.fillStyle = 'black';
#         ctx.fillRect(0, 0, 280, 280);
#     }
#
#     function predict() {
#         var dataURL = canvas.toDataURL('image/png');
#         google.colab.kernel.invokeFunction('notebook.predict_drawn', [dataURL], {});
#     }
#     </script>
#     """
#
#     def predict_drawn(data_url):
#         """Callback: receives the drawn image from the canvas."""
#         # Decode base64 image data
#         header, data = data_url.split(',', 1)
#         image_bytes = base64.b64decode(data)
#         img = Image.open(io.BytesIO(image_bytes)).convert('L')
#         img = img.resize((28, 28), Image.Resampling.LANCZOS)
#         img_array = np.array(img)
#         predict_from_array(img_array)
#
#     output.register_callback('notebook.predict_drawn', predict_drawn)
#     display(HTML(canvas_html))
#
# except ImportError:
#     print("Not running on Google Colab. Use Method C instead.")


# ==========================================
# METHOD C: Test with random test set image
# ==========================================
# No upload needed - picks a random image from the MNIST test set.
# Run this cell multiple times to see different examples.

print("METHOD C: Random Test Set Image")
print("=" * 60)
print("Picking a random image from the 10,000 test images...")
print()

random_idx = np.random.randint(0, len(X_test))
test_image_flat = X_test[random_idx]
actual_digit = np.argmax(y_test[random_idx])

# Convert back to a 28x28 uint8 image (X_test is already normalized to 0-1).
# Round before casting so float error cannot truncate e.g. 254.99 to 254.
test_image_2d = np.rint(test_image_flat.reshape(28, 28) * 255).astype(np.uint8)

print(f"Test image index: {random_idx}")
print(f"True label: {actual_digit}")
print()

predicted, confidence = predict_from_array(test_image_2d)

print()
if predicted == actual_digit:
    print(f"CORRECT! Model predicted {predicted}, actual is {actual_digit}")
else:
    print(f"WRONG! Model predicted {predicted}, but actual is {actual_digit}")
print("=" * 60)
print()
print("Run this cell again to test another random image!")

## Cell 13A: Upload Your Own Image (Google Colab)
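Run this cell to upload a handwritten digit image from your computer. Run Cell 13 first, because it defines `predict_from_array`.
The code **auto-inverts colors** (black-on-white -> white-on-black) when the image's mean brightness exceeds 127, to match the MNIST format.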

In [None]:
# ============================================
# CELL 13A: Upload Image from Your Computer
# ============================================
# Run this cell -> Click "Choose Files" -> Select your image -> See prediction!
# ============================================

from google.colab import files

print("METHOD A: Upload Image from Your Computer")
print("=" * 60)
print("Click 'Choose Files' below to upload a handwritten digit image (PNG/JPG):")
print()

uploaded = files.upload()

for filename in uploaded.keys():
    print(f"\nProcessing: {filename}")
    print("-" * 60)

    # Load and preprocess the uploaded image
    img = Image.open(filename).convert('L')                # Convert to grayscale
    img = img.resize((28, 28), Image.Resampling.LANCZOS)   # Resize to 28x28
    img_array = np.array(img)

    # Auto-invert colors if needed
    # Most handwritten images: black digit on white paper (mean > 127)
    # MNIST format: white digit on black background (mean < 127)
    if img_array.mean() > 127:
        print("  (Auto-inverted colors: black-on-white -> white-on-black)")
        img_array = 255 - img_array

    predict_from_array(img_array)

print("=" * 60)

## Cell 13B: Draw a Digit with Your Mouse (Google Colab)
Run this cell to get an interactive drawing canvas. Draw a digit with your mouse, then click **"Predict"**.
Click **"Clear"** to start over. Touch input on phones and tablets is supported too. Run Cell 13 first, because it defines `predict_from_array`.

In [None]:
# ============================================
# CELL 13B: Draw a Digit with Your Mouse
# ============================================
# A black canvas appears below. Draw a white digit with your mouse.
# Click "Predict" to see the model's prediction.
# Click "Clear" to erase and try again.
# ============================================

import io
import base64

from google.colab import output
from IPython.display import HTML, display

print("METHOD B: Draw a Digit")
print("=" * 60)
print("Draw a digit (0-9) in the black box below, then click 'Predict'")
print()

canvas_html = """
<div style="text-align:center">
<canvas id="canvas" width="280" height="280"
        style="border:2px solid #666; cursor:crosshair; background:black;"></canvas>
<br><br>
<button onclick="predict()" style="padding:10px 30px; font-size:16px;
        background:#4CAF50; color:white; border:none; border-radius:5px;
        cursor:pointer; margin:5px;">Predict</button>
<button onclick="clearCanvas()" style="padding:10px 30px; font-size:16px;
        background:#f44336; color:white; border:none; border-radius:5px;
        cursor:pointer; margin:5px;">Clear</button>
<p style="color:#888; margin-top:10px;">Draw a large, centered digit for best results</p>
</div>

<script>
var canvas = document.getElementById('canvas');
var ctx = canvas.getContext('2d');
var drawing = false;

// Black background, white pen (matches MNIST format)
ctx.fillStyle = 'black';
ctx.fillRect(0, 0, 280, 280);
ctx.strokeStyle = 'white';
ctx.lineWidth = 18;
ctx.lineCap = 'round';
ctx.lineJoin = 'round';

// Mouse events
canvas.addEventListener('mousedown', function(e) {
    drawing = true;
    ctx.beginPath();
    var rect = canvas.getBoundingClientRect();
    ctx.moveTo(e.clientX - rect.left, e.clientY - rect.top);
});

canvas.addEventListener('mousemove', function(e) {
    if (drawing) {
        var rect = canvas.getBoundingClientRect();
        ctx.lineTo(e.clientX - rect.left, e.clientY - rect.top);
        ctx.stroke();
    }
});

canvas.addEventListener('mouseup', function() { drawing = false; });
canvas.addEventListener('mouseleave', function() { drawing = false; });

// Touch events (mobile/tablet)
canvas.addEventListener('touchstart', function(e) {
    e.preventDefault();
    drawing = true;
    ctx.beginPath();
    var rect = canvas.getBoundingClientRect();
    var touch = e.touches[0];
    ctx.moveTo(touch.clientX - rect.left, touch.clientY - rect.top);
});

canvas.addEventListener('touchmove', function(e) {
    e.preventDefault();
    if (drawing) {
        var rect = canvas.getBoundingClientRect();
        var touch = e.touches[0];
        ctx.lineTo(touch.clientX - rect.left, touch.clientY - rect.top);
        ctx.stroke();
    }
});

canvas.addEventListener('touchend', function() { drawing = false; });

function clearCanvas() {
    ctx.fillStyle = 'black';
    ctx.fillRect(0, 0, 280, 280);
}

function predict() {
    var dataURL = canvas.toDataURL('image/png');
    google.colab.kernel.invokeFunction('notebook.predict_drawn', [dataURL], {});
}
</script>
"""

def predict_drawn(data_url):
    """Callback: receives the drawn image from the JavaScript canvas."""
    # Decode base64 image data from canvas
    header, data = data_url.split(',', 1)
    image_bytes = base64.b64decode(data)

    # Open as grayscale, resize to 28x28 (MNIST size)
    img = Image.open(io.BytesIO(image_bytes)).convert('L')
    img = img.resize((28, 28), Image.Resampling.LANCZOS)
    img_array = np.array(img)

    predict_from_array(img_array)

# Register the JavaScript -> Python callback
output.register_callback('notebook.predict_drawn', predict_drawn)

# Display the canvas
display(HTML(canvas_html))

## Cell 14: Program Summary

Recap of everything we accomplished and key learnings.
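
The parameter count printed in the summary can be checked by hand. The sketch below assumes every layer in the 784 -> 128 -> 64 -> 10 architecture is a standard fully connected layer with biases enabled (the Keras `Dense` default), so each layer contributes `inputs * outputs + outputs` parameters.

```python
layer_sizes = [784, 128, 64, 10]   # input -> hidden 1 -> hidden 2 -> output

total = 0
for n_in, n_out in zip(layer_sizes, layer_sizes[1:]):
    params = n_in * n_out + n_out   # weights + biases
    print(f"{n_in:>3} -> {n_out:>3}: {params:,} parameters")
    total += params

print(f"Total: {total:,}")
```

Under that assumption the three layers contribute 100,480, 8,256 and 650 parameters, for a total of 109,386, which should match `model.count_params()`.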

In [None]:
# ============================================
# CELL 14: Program Summary
# ============================================

print()
print("=" * 60)
print("PROGRAM EXECUTION COMPLETE!")
print("=" * 60)
print()
print("Summary of what we accomplished:")
print("-" * 60)
print("\u2713 1. GPU Configuration")
print("     - Detected and configured GPU for acceleration")
print()
print("\u2713 2. Data Loading")
print("     - Loaded 60,000 training + 10,000 test images")
print()
print("\u2713 3. Sample Preview")
print("     - Displayed sample digits (6, 7, 8, 9)")
print()
print("\u2713 4. Data Preprocessing")
print("     - Normalized, flattened, one-hot encoded")
print()
print("\u2713 5. Model Building")
print("     - Neural network: 784 \u2192 128 \u2192 64 \u2192 10")
print(f"     - {model.count_params():,} trainable parameters")
print()
print("\u2713 6. Model Compilation")
print("     - Adam + Categorical Crossentropy")
print()
print("\u2713 7. Model Training")
print(f"     - {EPOCHS} epochs, batch size {BATCH_SIZE}")
print()
print("\u2713 8. Model Evaluation")
print(f"     - Test Accuracy: {test_accuracy * 100:.2f}%")
print(f"     - Test Loss: {test_loss:.4f}")
print()
print("\u2713 9. Visualizations")
print("     - Loss curve, Accuracy curve, Confusion matrix")
print()
print("\u2713 10. Custom Prediction")
print("     - Functions ready for custom image prediction")
print()
print("=" * 60)
print()
print("Key Learnings:")
print("-" * 60)
print("1. Neural networks learn by adjusting weights to minimize loss")
print("2. Preprocessing (normalization, encoding) is critical for training")
print("3. Architecture choices affect model capacity and performance")
print("4. Adam + categorical crossentropy is a solid default for multi-class classification")
print("5. Test accuracy estimates how well the model generalizes to unseen digits")
print("6. Confusion matrix reveals systematic prediction errors")
print()
print("Next Steps:")
print("-" * 60)
print("1. Experiment with different architectures (more/fewer layers)")
print("2. Try different activation functions (tanh, LeakyReLU)")
print("3. Add dropout layers to prevent overfitting")
print("4. Implement convolutional layers (CNN) for better accuracy")
print("5. Test on your own handwritten digit images")
print()
print("=" * 60)
print("Thank you for learning with this educational implementation!")
print("=" * 60)