# In [None]:
# =====================================
# 1. IMPORT LIBRARIES
# =====================================

import os                              # For file path and checking saved model
import numpy as np                     # For numerical operations and arrays
import matplotlib.pyplot as plt        # For displaying images and plots
from PIL import Image, ImageOps        # Pillow for image loading and manipulation
import tensorflow as tf                # TensorFlow for deep learning
from tensorflow.keras import layers, models, callbacks  # Keras layers, model, and callbacks
from tensorflow.keras.preprocessing.image import ImageDataGenerator  # For data augmentation

# =====================================
# 2. CONFIGURATION
# =====================================

# Where the trained Keras model is written; if present, it is loaded
# instead of retraining.
MODEL_FILENAME = "mnist_tf_robust.h5"
# Side length of an MNIST image in pixels. NOTE: not read by the functions
# in this file, which hard-code 28.
IMG_SIZE = 28
# Number of samples per gradient update.
BATCH_SIZE = 128
# Upper bound on training epochs (EarlyStopping may stop sooner).
EPOCHS = 8

def normalize_fn(x):
    """
    Map 8-bit pixel intensities from [0, 255] to float32 values in [-1, 1].

    Args:
        x: A NumPy array or any array-like (e.g. a nested list) of pixel
            values.

    Returns:
        A new float32 NumPy array with the same shape: 0 -> -1.0,
        127.5 -> 0.0, 255 -> 1.0.
    """
    # asarray (rather than x.astype) also accepts lists and other
    # array-likes; for ndarray input the result is identical.
    pixels = np.asarray(x, dtype="float32")
    return (pixels / 255.0 - 0.5) / 0.5

# =====================================
# 3. PREPROCESSING FUNCTION
# =====================================
def preprocess_pil_to_mnist(pil_img):
    """
    Convert any input image into an MNIST-like array.

    Pipeline:
    - Convert to 8-bit grayscale ("L" mode).
    - Downscale images whose longest side exceeds 1024 px (aspect preserved).
    - If the whole image is mostly bright (mean > 127), invert it so the
      digit is light on a dark background. Polarity is decided once, on the
      whole image. Re-deciding on the tight digit crop would wrongly flip
      bold strokes (e.g. a thick "1") that fill most of their bounding box.
    - Crop to the bounding box of pixels brighter than 10, plus 8 px of
      padding on every side (clamped to the image).
    - Scale the crop so its longer side is exactly 20 px, keeping the aspect
      ratio. Small crops are enlarged and large ones shrunk.
    - Center it on a 28x28 black canvas and normalize to [-1, 1].

    If no pixel exceeds the threshold, the whole image is simply resized to
    28x28 instead.

    Args:
        pil_img: A PIL image in any mode.

    Returns:
        float32 NumPy array of shape (28, 28, 1) with values in [-1, 1].
    """
    img = pil_img.convert("L")

    # Cap the working resolution. Clamp each side to >= 1 px so extremely
    # thin images do not produce an invalid zero-size resize.
    longest = max(img.size)
    if longest > 1024:
        ratio = 1024 / longest
        reduced = (max(1, int(img.size[0] * ratio)),
                   max(1, int(img.size[1] * ratio)))
        img = img.resize(reduced, Image.LANCZOS)

    np_img = np.array(img)

    # Mostly-bright image -> assume a dark digit on a light background.
    if np.mean(np_img) > 127:
        img = ImageOps.invert(img)
        np_img = np.array(img)

    # Foreground = pixels clearly above the (now dark) background.
    mask = np_img > 10
    if not mask.any():
        img_small = img.resize((28, 28), Image.LANCZOS)
        return normalize_fn(np.array(img_small)).reshape(28, 28, 1)

    # Inclusive bounding box of the foreground (array rows = y, cols = x).
    ys, xs = np.nonzero(mask)
    pad = 8
    width, height = img.size
    # PIL crop boxes are (left, upper, right, lower) with right/lower
    # exclusive, hence the +1 so the padding is the same on every side.
    box = (
        max(int(xs.min()) - pad, 0),
        max(int(ys.min()) - pad, 0),
        min(int(xs.max()) + 1 + pad, width),
        min(int(ys.max()) + 1 + pad, height),
    )
    cropped = img.crop(box)

    # Fit the longer side to exactly 20 px. Image.thumbnail only ever
    # shrinks, so it would leave small digits undersized.
    crop_w, crop_h = cropped.size
    scale = 20 / max(crop_w, crop_h)
    fitted = (max(1, round(crop_w * scale)), max(1, round(crop_h * scale)))
    cropped = cropped.resize(fitted, Image.LANCZOS)

    # Paste centered (by bounding box) on a 28x28 black canvas.
    canvas = Image.new("L", (28, 28), color=0)
    canvas.paste(cropped, ((28 - fitted[0]) // 2, (28 - fitted[1]) // 2))

    return normalize_fn(np.array(canvas)).reshape(28, 28, 1)

# =====================================
# 4. BUILD MODEL
# =====================================
def build_model():
    """
    Create the (uncompiled) CNN used for MNIST digit classification.

    Architecture:
    - input 28x28x1
    - two conv stages, 32 then 64 filters, each:
      3x3 'same' Conv2D + ReLU -> BatchNorm -> 2x2 MaxPool -> Dropout(0.25)
    - Flatten
    - Dense(256) -> BatchNorm -> Dropout(0.4)
    - Dense(128) -> BatchNorm -> Dropout(0.3)
    - Dense(10) with softmax over the digit classes 0-9

    Returns:
        A tf.keras Sequential model.
    """
    stack = [layers.Input(shape=(28, 28, 1))]

    # Convolutional feature extractor.
    for n_filters in (32, 64):
        stack += [
            layers.Conv2D(n_filters, (3, 3), activation='relu', padding='same'),
            layers.BatchNormalization(),
            layers.MaxPooling2D((2, 2)),
            layers.Dropout(0.25),
        ]

    stack.append(layers.Flatten())

    # Fully connected classifier head with decreasing dropout.
    for units, drop_rate in ((256, 0.4), (128, 0.3)):
        stack += [
            layers.Dense(units, activation='relu'),
            layers.BatchNormalization(),
            layers.Dropout(drop_rate),
        ]

    stack.append(layers.Dense(10, activation='softmax'))
    return models.Sequential(stack)

# =====================================
# 5. LOAD + AUGMENT DATA
# =====================================
def prepare_training_data(seed=None):
    """
    Load MNIST, normalize it to [-1, 1], and append one augmented copy of
    every training image.

    Args:
        seed: Optional int for reproducibility. When given, it is passed to
            the augmentation iterator and seeds a dedicated NumPy Generator
            for the final shuffle. When None (the default), the shuffle uses
            np.random's global state, so an earlier np.random.seed(...)
            still applies.

    Returns:
        (x_train, y_train, x_test, y_test). The training arrays hold the
        original and augmented samples shuffled together (twice the original
        training size). The test arrays are normalized but not augmented.
        Images have shape (N, 28, 28, 1).
    """
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

    # Add the channel axis and scale to [-1, 1].
    x_train = normalize_fn(x_train.reshape(-1, 28, 28, 1))
    x_test = normalize_fn(x_test.reshape(-1, 28, 28, 1))

    # Random rotation (degrees), shifts (fraction of size), shear and zoom;
    # 'nearest' fills uncovered areas from the nearest edge pixel.
    datagen = ImageDataGenerator(
        rotation_range=15,
        width_shift_range=0.10,
        height_shift_range=0.10,
        shear_range=10,
        zoom_range=0.1,
        fill_mode='nearest'
    )

    # One unshuffled batch the size of the whole set yields one augmented
    # copy of each image in the original order, so labels stay aligned.
    x_aug, y_aug = next(datagen.flow(x_train, y_train,
                                     batch_size=len(x_train),
                                     shuffle=False, seed=seed))

    x_combined = np.concatenate([x_train, x_aug], axis=0)
    y_combined = np.concatenate([y_train, y_aug], axis=0)

    # Unseeded: keep using the global NumPy RNG. Seeded: use a Generator
    # built from `seed` so the shuffle order is reproducible.
    rng = np.random if seed is None else np.random.default_rng(seed)
    perm = rng.permutation(len(x_combined))
    return x_combined[perm], y_combined[perm], x_test, y_test

# =====================================
# 6. TRAIN / LOAD MODEL
# =====================================
def train_and_save_if_needed():
    """
    Return a trained model, training a new one only if none is saved yet.

    If MODEL_FILENAME exists, it is loaded and returned as-is. Otherwise a
    new model is built and trained on the augmented training set, with the
    MNIST test split used as validation data.

    During training, ModelCheckpoint writes the lowest-val_loss model to
    MODEL_FILENAME. Afterwards that best checkpoint is reloaded and returned,
    rather than being overwritten by the final epoch's weights. (In some
    Keras versions EarlyStopping restores the best weights only when it
    actually stops training early.)

    Returns:
        A compiled tf.keras model.
    """
    if os.path.exists(MODEL_FILENAME):
        print("Loading existing model from file...")
        return tf.keras.models.load_model(MODEL_FILENAME)

    x_train, y_train, x_val, y_val = prepare_training_data()
    print(f"Training samples: {len(x_train)}, Validation samples: {len(x_val)}")

    model = build_model()
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    # NOTE: the "validation" data is the MNIST test split. It drives LR
    # reduction, checkpoint selection and early stopping, so the reported
    # val accuracy is an optimistic estimate for truly unseen data.
    cb = [
        callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2, verbose=1),
        callbacks.ModelCheckpoint(MODEL_FILENAME, save_best_only=True, monitor='val_loss', verbose=1),
        callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True, verbose=1)
    ]

    model.fit(x_train, y_train,
              validation_data=(x_val, y_val),
              epochs=EPOCHS,
              batch_size=BATCH_SIZE,
              callbacks=cb,
              verbose=2)

    if os.path.exists(MODEL_FILENAME):
        # Return the best checkpoint instead of clobbering it with the
        # last epoch's (possibly worse) weights.
        model = tf.keras.models.load_model(MODEL_FILENAME)
    else:
        # The checkpoint never fired (e.g. val_loss never improved on its
        # initial value, as with NaN losses). Persist the final model so
        # the next run can skip training.
        model.save(MODEL_FILENAME)
    print("Model saved successfully.")
    return model

# =====================================
# 7. PREDICT FROM FILE
# =====================================
def predict_from_filepath(image_path, model):
    """
    Load an image file, preprocess it, predict its digit, and plot the result.

    Shows the grayscale original next to the 28x28 array fed to the model.

    Args:
        image_path: Path to any image file Pillow can open.
        model: Keras model accepting a batch of shape (1, 28, 28, 1) as
            produced by preprocess_pil_to_mnist (e.g. the model returned by
            train_and_save_if_needed).

    Returns:
        Index of the highest-scoring output class (a NumPy integer).
    """
    # The `with` block releases the file handle even if decoding fails
    # (Image.open is lazy). exif_transpose applies any EXIF orientation tag
    # so camera photos are upright before preprocessing.
    with Image.open(image_path) as src:
        pil = ImageOps.exif_transpose(src).convert("RGB")

    processed = preprocess_pil_to_mnist(pil)
    batch = np.expand_dims(processed, axis=0)  # shape (1, 28, 28, 1)

    scores = model.predict(batch, verbose=0)
    pred = np.argmax(scores, axis=1)[0]

    plt.figure(figsize=(6, 3))

    plt.subplot(1, 2, 1)
    plt.imshow(pil.convert("L"), cmap='gray')
    plt.title("Original Image")
    plt.axis('off')

    plt.subplot(1, 2, 2)
    # Fixed [-1, 1] color range matches normalize_fn, so the panel shows the
    # model input faithfully instead of auto-stretching each image.
    plt.imshow(processed.squeeze(), cmap='gray', vmin=-1, vmax=1)
    plt.title(f"Predicted: {pred}")
    plt.axis('off')
    plt.show()

    return pred

# =====================================
# 8. RUN TRAINING (OR LOAD MODEL)
# =====================================
# Executes when this cell/script runs: loads MODEL_FILENAME if it exists,
# otherwise trains a new model (which writes MODEL_FILENAME).
model = train_and_save_if_needed()

# =====================================
# 9. TEST ON YOUR OWN IMAGE
# =====================================
# Example usage. A relative path such as "my_digit.png" is resolved against
# the current working directory (typically the notebook's folder):
# predict_from_filepath("my_digit.png", model)
