# Training a neural network on MNIST with Keras

This simple example demonstrates how to plug TensorFlow Datasets (TFDS) into a Keras model.


In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

## Step 1: Create your input pipeline

Start by building an efficient input pipeline using advice from:
* The [Performance tips](https://www.tensorflow.org/datasets/performances) guide
* The [Better performance with the `tf.data` API](https://www.tensorflow.org/guide/data_performance#optimize_performance) guide


### Load a dataset

Load the MNIST dataset with the following arguments:

* `shuffle_files=True`: The MNIST data is only stored in a single file, but for larger datasets with multiple files on disk, it's good practice to shuffle them when training.
* `as_supervised=True`: Returns a tuple `(img, label)` instead of a dictionary `{'image': img, 'label': label}`.

In [None]:
# Load the MNIST 'train' and 'test' splits together with the dataset metadata.
# `as_supervised=True` makes each element an `(image, label)` tuple, and
# `with_info=True` makes `tfds.load` return the info object as a second value.
splits, ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
# The datasets come back in the same order as the requested `split` list.
ds_train, ds_test = splits

### Build a training pipeline

Apply the following transformations:

* `tf.data.Dataset.map`: TFDS provides images of type `tf.uint8`, while the model expects `tf.float32`. Therefore, you need to normalize images.
* `tf.data.Dataset.cache`: As you fit the dataset in memory, cache it before shuffling for better performance.<br/>
__Note:__ Random transformations should be applied after caching.
* `tf.data.Dataset.shuffle`: For true randomness, set the shuffle buffer to the full dataset size.<br/>
__Note:__ For large datasets that can't fit in memory, use `buffer_size=1000` if your system allows it.
* `tf.data.Dataset.batch`: Batch elements of the dataset after shuffling to get unique batches at each epoch.
* `tf.data.Dataset.prefetch`: It is good practice to end the pipeline by prefetching [for performance](https://www.tensorflow.org/guide/data_performance#prefetching).

In [None]:
def normalize_img(image, label):
  """Converts an image to `float32` and divides it by 255.

  Args:
    image: Image tensor; expected to hold `uint8` pixel data.
    label: Label paired with the image; passed through untouched.

  Returns:
    A `(normalized_image, label)` tuple.
  """
  as_float = tf.cast(image, tf.float32)
  return as_float / 255., label

# Training input pipeline. The order of the stages matters:
#   map      -> normalize every image (in parallel),
#   cache    -> keep the normalized examples so later epochs skip the map,
#   shuffle  -> buffer as large as the whole train split, applied after cache
#               so each epoch gets a fresh order,
#   batch    -> groups of 128, formed after shuffling,
#   prefetch -> overlap input preparation with training.
ds_train = (
    ds_train
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .shuffle(ds_info.splits['train'].num_examples)
    .batch(128)
    .prefetch(tf.data.AUTOTUNE)
)

### Build an evaluation pipeline

Your testing pipeline is similar to the training pipeline with small differences:

 * You don't need to call `tf.data.Dataset.shuffle`.
 * Caching is done after batching because batches can be the same between epochs.

In [None]:
# Evaluation input pipeline: normalize, batch by 128, then cache the already
# batched data (no shuffling here), and prefetch at the end.
ds_test = (
    ds_test
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(128)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

## Step 2: Create and train the model

Plug the TFDS input pipeline into a simple Keras model, compile the model, and train it.

In [None]:
# Small fully-connected classifier: flatten each image, one hidden ReLU layer
# of 128 units, then 10 raw output scores (logits), one per digit class.
model = tf.keras.models.Sequential([
  # TFDS MNIST images carry an explicit channel axis, so the declared input
  # shape is (28, 28, 1) to match what the pipelines actually feed the model.
  # Flatten still produces 784 features per image.
  tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
  tf.keras.layers.Dense(128, activation='relu'),
  # No activation here: the loss below is configured with `from_logits=True`.
  tf.keras.layers.Dense(10)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    # Labels are integer class ids, hence the "sparse" loss and metric.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

# Train for 6 epochs on the training pipeline, using the test pipeline as
# validation data. The returned object is kept in `history`.
history = model.fit(ds_train, validation_data=ds_test, epochs=6)

In [None]:
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import numpy as np

# --- Metrics for Training Data ---
# Gather true and predicted class ids batch by batch. Each batch's labels are
# taken together with its images, so the pairing holds even though `ds_train`
# is shuffled.
all_true_labels_train = []
all_predicted_labels_train = []

for images, labels in ds_train:
  # Call the model directly instead of `model.predict`: Keras documents
  # `predict` as not intended for use inside loops over small batches.
  predictions = model(images, training=False)
  # The model outputs one score per class; the arg-max is the predicted digit.
  predicted_labels = tf.argmax(predictions, axis=1).numpy()
  all_true_labels_train.extend(labels.numpy())
  all_predicted_labels_train.extend(predicted_labels)

# Precision, recall and F1 are macro-averaged over the classes.
train_accuracy = accuracy_score(all_true_labels_train, all_predicted_labels_train)
train_precision = precision_score(all_true_labels_train, all_predicted_labels_train, average='macro')
train_recall = recall_score(all_true_labels_train, all_predicted_labels_train, average='macro')
train_f1 = f1_score(all_true_labels_train, all_predicted_labels_train, average='macro')

print("\n--- Training Metrics ---")
print(f"Training Accuracy: {train_accuracy:.4f}")
print(f"Training Precision (macro): {train_precision:.4f}")
print(f"Training Recall (macro): {train_recall:.4f}")
print(f"Training F1-score (macro): {train_f1:.4f}")

In [None]:
import matplotlib.pyplot as plt

# Draw the per-epoch training and validation loss stored in `history`.
loss_curves = {
    'Training Loss': history.history['loss'],
    'Validation Loss': history.history['val_loss'],
}

plt.figure(figsize=(10, 6))
for curve_name, curve_values in loss_curves.items():
  plt.plot(curve_values, label=curve_name)
plt.title('Training and Validation Loss Over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
# Run the compiled loss/metrics over the test pipeline and print the result.
test_results = model.evaluate(ds_test)
print(test_results)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Preview a single test batch: show its first 25 images in a 5x5 grid, each
# titled with the predicted and the true class.
for batch_images, batch_labels in ds_test.take(1):
  scores = model.predict(batch_images)
  guessed = tf.argmax(scores, axis=1).numpy()

  plt.figure(figsize=(10, 10))
  for idx in range(25):
    ax = plt.subplot(5, 5, idx + 1)
    # `squeeze()` drops size-1 axes so the image can be drawn as a 2-D array.
    pixels = batch_images[idx].numpy().squeeze()
    ax.imshow(pixels, cmap='gray')
    ax.set_title(f"Pred: {guessed[idx]}\nTrue: {batch_labels[idx].numpy()}")
    ax.axis('off')
  plt.tight_layout()
  plt.show()

In [None]:
import numpy as np
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Collect true and predicted class ids over the whole test pipeline.
all_true_labels = []
all_predicted_labels = []

for images, labels in ds_test:
  # Call the model directly instead of `model.predict`: Keras documents
  # `predict` as not intended for use inside loops over small batches, and
  # without `verbose=0` it would print a progress bar for every batch.
  predictions = model(images, training=False)
  predicted_labels = tf.argmax(predictions, axis=1).numpy()
  all_true_labels.extend(labels.numpy())
  all_predicted_labels.extend(predicted_labels)

# Calculate the confusion matrix (true labels first, predictions second)
cm = confusion_matrix(all_true_labels, all_predicted_labels)

print("Confusion Matrix:")
print(cm)

# Visualize the confusion matrix as an annotated heatmap of integer counts
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.title('Confusion Matrix')
plt.show()

In [None]:
# --- Metrics for Test Data ---
# NOTE: this cell uses `accuracy_score`, `precision_score`, `recall_score` and
# `f1_score`, which are imported in the training-metrics cell; run that first.
all_true_labels_test = []
all_predicted_labels_test = []

for images, labels in ds_test:
  # Call the model directly instead of `model.predict`: Keras documents
  # `predict` as not intended for use inside loops over small batches.
  predictions = model(images, training=False)
  # The model outputs one score per class; the arg-max is the predicted digit.
  predicted_labels = tf.argmax(predictions, axis=1).numpy()
  all_true_labels_test.extend(labels.numpy())
  all_predicted_labels_test.extend(predicted_labels)

# Precision, recall and F1 are macro-averaged over the classes.
test_accuracy = accuracy_score(all_true_labels_test, all_predicted_labels_test)
test_precision = precision_score(all_true_labels_test, all_predicted_labels_test, average='macro')
test_recall = recall_score(all_true_labels_test, all_predicted_labels_test, average='macro')
test_f1 = f1_score(all_true_labels_test, all_predicted_labels_test, average='macro')

print("\n--- Test Metrics ---")
print(f"Test Accuracy: {test_accuracy:.4f}")
print(f"Test Precision (macro): {test_precision:.4f}")
print(f"Test Recall (macro): {test_recall:.4f}")
print(f"Test F1-score (macro): {test_f1:.4f}")