# MNIST CNN Training for TensorFlow.js

This notebook trains a Convolutional Neural Network (CNN) on the MNIST dataset and converts it to TensorFlow.js format for use in the web application.

## Overview
- Load and preprocess MNIST data
- Build a CNN architecture optimized for digit recognition
- Train the model with data augmentation
- Evaluate performance
- Convert to TensorFlow.js format


In [None]:
# Import required libraries (standard library first, then third-party)
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflowjs as tfjs
from tensorflow import keras
from tensorflow.keras import layers

# Report the runtime environment
print(f"TensorFlow version: {tf.__version__}")
print(f"GPU available: {tf.config.list_physical_devices('GPU')}")

# Seed TensorFlow's and NumPy's global RNGs so runs are reproducible
SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)


## 1. Load and Preprocess MNIST Data


In [None]:
# Load the raw MNIST train/test splits
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

for split_name, split_array in (
    ("Training data", x_train),
    ("Training labels", y_train),
    ("Test data", x_test),
    ("Test labels", y_test),
):
    print(f"{split_name} shape: {split_array.shape}")

# Scale pixel intensities to [0, 1] as float32, then append a trailing
# channel axis so each image becomes (28, 28, 1).
x_train = (x_train.astype('float32') / 255.0)[..., np.newaxis]
x_test = (x_test.astype('float32') / 255.0)[..., np.newaxis]

# One-hot encode the digit labels over 10 classes
y_train = keras.utils.to_categorical(y_train, num_classes=10)
y_test = keras.utils.to_categorical(y_test, num_classes=10)

print(f"Preprocessed training data shape: {x_train.shape}")
print(f"Preprocessed training labels shape: {y_train.shape}")


In [None]:
# Show the first ten training digits; labels are decoded from one-hot via argmax
fig, axes = plt.subplots(2, 5, figsize=(12, 6))
for idx, ax in enumerate(axes.flat):
    ax.imshow(x_train[idx].reshape(28, 28), cmap='gray')
    ax.set_title(f'Label: {np.argmax(y_train[idx])}')
    ax.axis('off')
fig.suptitle('Sample MNIST Images')
fig.tight_layout()
plt.show()


## 2. Build CNN Architecture

We'll create a lightweight CNN that's optimized for web deployment while maintaining good accuracy.


In [None]:
def create_cnn_model(input_shape=(28, 28, 1), num_classes=10):
    """
    Build the (uncompiled) CNN used for MNIST digit recognition and web deployment.

    Architecture:
    - 2 convolutional blocks: Conv2D + ReLU, 2x2 MaxPooling, Dropout(0.25)
    - Flatten -> Dense(128, ReLU) -> Dropout(0.5)
    - Dense softmax output with one unit per class
    - Kept small for fast in-browser inference

    Args:
        input_shape: Shape of one input image, excluding the batch axis.
            Defaults to (28, 28, 1), i.e. 28x28 single-channel images.
        num_classes: Number of output classes. Defaults to 10 (digits 0-9).

    Returns:
        An uncompiled ``keras.Sequential`` model; call ``compile`` before training.
    """
    model = keras.Sequential([
        # Declare the input explicitly. Keras 3 warns when `input_shape=` is
        # passed to a layer; an Input object works in Keras 2 and 3.
        keras.Input(shape=input_shape),

        # First convolutional block
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(0.25),

        # Second convolutional block
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(0.25),

        # Flatten and dense classification head
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation='softmax'),
    ])

    return model

# Create the model
model = create_cnn_model()

# Display model architecture
model.summary()


In [None]:
# Adam with its default settings; categorical cross-entropy matches the one-hot labels
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Random geometric augmentation: rotations (degrees), plus zoom and shifts
# (fractions of the image size).
augmentation_settings = {
    'rotation_range': 10,
    'zoom_range': 0.1,
    'width_shift_range': 0.1,
    'height_shift_range': 0.1,
}
datagen = keras.preprocessing.image.ImageDataGenerator(**augmentation_settings)

# fit() computes data statistics needed only by the featurewise_* / zca options,
# none of which are enabled above.
datagen.fit(x_train)


## 3. Train the Model


In [None]:
# Training hyperparameters
BATCH_SIZE = 128
EPOCHS = 15
VALIDATION_SIZE = 6000  # images held out from the training split

# Hold out the last VALIDATION_SIZE training images for model selection.
# EarlyStopping (restore_best_weights) and ReduceLROnPlateau act on the
# validation metrics. Validating on the test set would leak it into model
# selection and bias the test accuracy reported in the evaluation step.
x_fit, y_fit = x_train[:-VALIDATION_SIZE], y_train[:-VALIDATION_SIZE]
x_val, y_val = x_train[-VALIDATION_SIZE:], y_train[-VALIDATION_SIZE:]

# Define callbacks for training
callbacks = [
    keras.callbacks.EarlyStopping(
        monitor='val_accuracy',
        patience=3,
        restore_best_weights=True
    ),
    keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.2,
        patience=2,
        min_lr=0.0001
    )
]

# Train on augmented batches; validate on the untouched held-out images
print("Starting training...")
history = model.fit(
    datagen.flow(x_fit, y_fit, batch_size=BATCH_SIZE),
    steps_per_epoch=len(x_fit) // BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(x_val, y_val),
    callbacks=callbacks,
    verbose=1
)

print("Training completed!")


## 4. Evaluate Model Performance


In [None]:
# Final evaluation on the held-out test split
test_loss, test_accuracy = model.evaluate(x_test, y_test, verbose=0)
print(f"Test accuracy: {test_accuracy:.4f}")
print(f"Test loss: {test_loss:.4f}")

# Learning curves: one panel per metric, training vs. validation
curves = history.history
fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 4))
panels = (
    (acc_ax, 'accuracy', 'Training Accuracy', 'Validation Accuracy', 'Model Accuracy', 'Accuracy'),
    (loss_ax, 'loss', 'Training Loss', 'Validation Loss', 'Model Loss', 'Loss'),
)
for ax, key, train_label, val_label, title, ylabel in panels:
    ax.plot(curves[key], label=train_label)
    ax.plot(curves['val_' + key], label=val_label)
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend()

fig.tight_layout()
plt.show()


In [None]:
# Predict the first ten test digits and compare with the ground truth
sample_count = 10
predictions = model.predict(x_test[:sample_count])
predicted_classes = predictions.argmax(axis=1)
true_classes = y_test[:sample_count].argmax(axis=1)

fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for i, ax in enumerate(axes.flat):
    ax.imshow(x_test[i].reshape(28, 28), cmap='gray')
    ax.set_title(f'True: {true_classes[i]}, Pred: {predicted_classes[i]}')
    ax.axis('off')
fig.suptitle('Sample Predictions')
fig.tight_layout()
plt.show()

# Confidence = the largest softmax probability for each image
print("Prediction confidence scores:")
for i, (digit, confidence) in enumerate(zip(predicted_classes, predictions.max(axis=1))):
    print(f"Image {i}: Predicted {digit} with {confidence:.3f} confidence")


## 5. Convert Model to TensorFlow.js Format

This step converts the trained model to a format that can be used in the web browser.


In [None]:
# Create directory for the TensorFlow.js model
model_dir = './tfjs_model'
os.makedirs(model_dir, exist_ok=True)

# Convert and save the model in TensorFlow.js format
print("Converting model to TensorFlow.js format...")
tfjs.converters.save_keras_model(model, model_dir)
print(f"Model saved to {model_dir}")

# List the generated files (sorted for stable output) and accumulate their
# total size in the same pass, rather than listing the directory twice.
# `os` comes from the setup cell.
print("\nGenerated files:")
total_size = 0
for file in sorted(os.listdir(model_dir)):
    size = os.path.getsize(os.path.join(model_dir, file))
    total_size += size
    print(f"  {file}: {size:,} bytes")

print(f"\nTotal model size: {total_size:,} bytes")


## 6. Test Model with Canvas-like Input

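Let's check the preprocessing path the web app will use on canvas drawings. A 28×28 grayscale image with 0–255 intensities is normalized to [0, 1] and reshaped before prediction. Here, an MNIST test digit scaled back to 0–255 stands in for a user's drawing.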


In [None]:
def preprocess_canvas_image(image_array):
    """
    Convert a 0-255 grayscale digit image into the model's input tensor.

    This simulates the preprocessing the web app applies before prediction.
    No resizing is done here: the caller must already have downscaled the
    drawing to 28x28 (784 pixels).

    Args:
        image_array: Array-like of pixel intensities in the 0-255 range with
            exactly 28 * 28 elements, e.g. shape (28, 28).

    Returns:
        float32 ndarray of shape (1, 28, 28, 1) with values scaled to [0, 1].

    Raises:
        ValueError: If the input does not contain exactly 28 * 28 pixels.
    """
    # asarray also accepts nested lists. The division below always produces a
    # new array, so the caller's data is never modified.
    pixels = np.asarray(image_array, dtype='float32')

    # Fail fast with a clear message instead of a generic reshape error.
    if pixels.size != 28 * 28:
        raise ValueError(
            f"expected a 28x28 image (784 pixels), got shape {pixels.shape}"
        )

    # Normalize to [0, 1], matching the scaling used for the training data
    normalized = pixels / 255.0

    # Add batch and channel dimensions: (batch, height, width, channels)
    return normalized.reshape(1, 28, 28, 1)

# Round-trip an MNIST test digit through the canvas preprocessing path.
# Scale it back to 0-255 first, because preprocess_canvas_image divides by 255.
test_image = x_test[0].reshape(28, 28) * 255
processed_image = preprocess_canvas_image(test_image)

# Single-image prediction
prediction = model.predict(processed_image, verbose=0)
probabilities = prediction[0]
predicted_digit = np.argmax(prediction)
confidence = np.max(prediction)

print(f"Predicted digit: {predicted_digit}")
print(f"Confidence: {confidence:.4f}")
print(f"All probabilities: {probabilities}")

# Left: the input digit. Right: the softmax distribution over the 10 classes.
fig, (image_ax, prob_ax) = plt.subplots(1, 2, figsize=(8, 3))
image_ax.imshow(test_image, cmap='gray')
image_ax.set_title('Input Image')
image_ax.axis('off')

prob_ax.bar(range(10), probabilities)
prob_ax.set_title('Prediction Probabilities')
prob_ax.set_xlabel('Digit')
prob_ax.set_ylabel('Probability')
prob_ax.set_xticks(range(10))
fig.tight_layout()
plt.show()


## 7. Save the Model

Save the trained model in multiple formats for backup and future use.


In [None]:
import pickle

# Output locations, defined once so the save calls and the summary below
# always name the same files.
model_path = './mnist_cnn_model.keras'
architecture_path = './mnist_cnn_architecture.json'
# HDF5 weight files must have a filename ending with `.weights.h5`
weights_path = './mnist_cnn_model.weights.h5'
history_path = './training_history.pkl'
tfjs_dir = './tfjs_model/'  # written by the TensorFlow.js conversion step

# Save the model in native Keras format (.keras)
model.save(model_path)
print(f"Model saved as {model_path}")

# Save model architecture as JSON (optional)
with open(architecture_path, 'w') as json_file:
    json_file.write(model.to_json())
print(f"Model architecture saved as {os.path.basename(architecture_path)}")

# Save model weights separately
model.save_weights(weights_path)
print(f"Model weights saved as {os.path.basename(weights_path)}")

# Save training history. Only unpickle this file from trusted sources:
# pickle.load can execute arbitrary code.
with open(history_path, 'wb') as f:
    pickle.dump(history.history, f)
print(f"Training history saved as {os.path.basename(history_path)}")

# Print summary of saved files, reusing the exact paths written above
print("\n" + "=" * 50)
print("TRAINING COMPLETE - FILES SAVED:")
print("=" * 50)
print(f"✓ TensorFlow.js model: {tfjs_dir}")
print(f"✓ Keras model: {model_path}")
print(f"✓ Model architecture: {architecture_path}")
print(f"✓ Model weights: {weights_path}")
print(f"✓ Training history: {history_path}")
print(f"✓ Final test accuracy: {test_accuracy:.4f}")
print("=" * 50)
