In [None]:
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
import tensorflow_datasets as tfds
import random
from utils import plot_samples_with_labels, classify, plot_conf_matrix

### Load and prepare data

In [None]:
# Fetch the MNIST dataset (handwritten digit images with labels) through Keras.
# load_data() yields a (train, test) pair, each of which is (images, labels).
mnist = keras.datasets.mnist
train_split, test_split = mnist.load_data()
training_images, training_labels = train_split
test_images, test_labels = test_split

In [None]:
def scale_and_add_channel(images):
    """Scale raw pixel values to [0, 1] and add a trailing channel axis.

    The division by 255 is applied only while ``images`` still has an integer
    dtype (i.e. it holds the raw pixel values). This keeps the cell idempotent:
    running it a second time does not shrink already-scaled data by another
    factor of 255.

    Args:
        images: array of 28x28 images, shaped (num_images, 28, 28) or
            (num_images, 28, 28, 1).

    Returns:
        Array of shape (num_images, 28, 28, 1) with floating-point values.
    """
    images = np.asarray(images)
    if np.issubdtype(images.dtype, np.integer):
        images = images / 255.0
    # Reshape the data to include a channel dimension.
    return images.reshape(images.shape[0], 28, 28, 1)


training_images = scale_and_add_channel(training_images)
test_images = scale_and_add_channel(test_images)

### Build the model

In [None]:
from tensorflow.keras import layers, models

# Assemble the CNN incrementally, one layer per add() call.
model = models.Sequential()

# Convolutional layers 1 and 2, each followed by 2x2 max pooling.
# The first layer also declares the 28x28 single-channel input shape.
model.add(layers.Conv2D(16, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(32, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))

# Convolutional layers 3 and 4 (no pooling in between).
model.add(layers.Conv2D(32, (3, 3), activation='relu'))
model.add(layers.Conv2D(32, (3, 3), activation='relu'))

# Collapse the 3D feature maps into a 1D vector for the dense layers.
model.add(layers.Flatten())

# Classification head: hidden dense layer, then a 10-way softmax (one unit per class).
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))

# Configure training: Adam optimizer, sparse categorical cross-entropy loss,
# accuracy reported as the metric.
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

# Print the architecture and parameter counts.
model.summary()

In [None]:
# Compile the model again with the same optimizer, loss and metric that the
# model-definition cell uses.
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

### Train the model

In [None]:
# Training configuration.
BATCH_SIZE = 32
num_epochs = 10
num_training_images = len(training_images)

# Fit on the training split. The test split is passed as validation data, so
# `history` carries both the training and the validation ('val_*') metrics
# that the plotting cells read.
history = model.fit(
    x=training_images,
    y=training_labels,
    batch_size=BATCH_SIZE,
    epochs=num_epochs,
    validation_data=(test_images, test_labels),
)

In [None]:
# Write the trained model to an .h5 file under models/.
saved_model_path = 'models/digit_CNN.h5'
model.save(saved_model_path)

In [None]:
# Re-declares the epoch count; this is the same value that the training cell
# assigns to `num_epochs`.
num_epochs = 10

In [None]:
# Per-epoch metric values recorded during model.fit().
metrics = history.history
# Training loss and validation (test-split) loss, one value per epoch.
training_loss_list = metrics['loss']
test_loss_list = metrics['val_loss']

# Build the x-axis from the number of epochs actually recorded instead of from
# `num_epochs`: if training ran for a different number of epochs than that
# variable currently holds, plt.plot would fail on mismatched x/y lengths.
x = np.arange(len(training_loss_list))

# Plot the training and test loss.
plt.title('Training and Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.plot(x, training_loss_list, label='Training Loss')
plt.plot(x, test_loss_list, label='Test Loss')
plt.legend()
plt.show()

In [None]:
# Accuracy curves from the training history. Relies on `metrics` and `x`
# defined in the loss-plot cell.
train_accuracy_list = metrics['accuracy']
test_accuracy_list = metrics['val_accuracy']

accuracy_curves = (
    (train_accuracy_list, 'Training Accuracy'),
    (test_accuracy_list, 'Test Accuracy'),
)
for accuracy_values, curve_label in accuracy_curves:
    plt.plot(x, accuracy_values, label=curve_label)

plt.title('Training and Test Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

### Using the model to classify new images

In [None]:
from tensorflow.keras.models import load_model

# Reload the trained CNN from disk. The path must match the one passed to
# model.save() in the training section ('models/digit_CNN.h5').
model_cnn = load_model('models/digit_CNN.h5')

In [None]:
# Run the reloaded model over the test images through the `classify` helper
# (imported from utils) and keep the result as a NumPy array.
predicted_labels = np.array(classify(test_images, model_cnn))

In [None]:
from sklearn.metrics import classification_report, accuracy_score

# Compute both summaries first, then print them.
overall_accuracy = accuracy_score(test_labels, predicted_labels)
per_class_report = classification_report(test_labels, predicted_labels)

print("Accuracy:", overall_accuracy)
print("\nClassification Report:\n")
print(per_class_report)

In [None]:
# Generate confusion matrix: hand the true and the predicted test labels to the
# `plot_conf_matrix` helper imported from utils.
plot_conf_matrix(test_labels, predicted_labels)

In [None]:
# Show a sample of the test images whose predicted label is wrong.

# Boolean mask: True wherever the true label and the prediction disagree.
misclassified_indices = (test_labels != predicted_labels)

# Apply the same mask to the images, the true labels and the predictions.
misclassified_images, misclassified_true_labels, misclassified_predicted_labels = (
    values[misclassified_indices]
    for values in (test_images, test_labels, predicted_labels)
)

if len(misclassified_images) == 0:
    print("No misclassified images found in the selected batch.")
else:
    plot_samples_with_labels(
        misclassified_images,
        misclassified_true_labels,
        misclassified_predicted_labels,
        num_samples=10,
        randomize=True,
    )