In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

In [None]:
# Load MNIST as (image, label) pairs, split into train/test, together with the
# TFDS metadata object (`info`) used later for the shuffle buffer size.
(train, test), info = tfds.load(
    "mnist",
    split=["train", "test"],
    as_supervised=True,
    with_info=True,
)
assert isinstance(test, tf.data.Dataset)

In [None]:
# Dump the TFDS DatasetInfo metadata; `info.splits` is read again below to size
# the training shuffle buffer.
print(str(info))

In [None]:
def normalize_img(image, label):
    """Cast an image to float32 and scale it by 1/255; pass the label through unchanged.

    Intended for `uint8` pixel data, which this maps into the range [0, 1].

    Args:
        image: image tensor (expected dtype `uint8`).
        label: class label, returned as-is.

    Returns:
        A `(scaled_image, label)` tuple, so it can be used directly with
        `tf.data.Dataset.map` on an `as_supervised=True` dataset.
    """
    scaled = tf.cast(image, tf.float32) / 255.
    return scaled, label

In [None]:
# Input-pipeline tunables; the batch size is shared by both splits so they
# cannot silently drift apart.
BATCH_SIZE = 128
SHUFFLE_SEED = 42  # makes the per-epoch shuffle order reproducible across runs

# NOTE: this cell rebinds `train` and `test` in place. Re-running it without
# first re-running the loading cell would normalize and batch the data twice.

# Training split: normalize -> cache in memory -> shuffle with a buffer as large
# as the split (a full uniform shuffle, reshuffled each epoch) -> batch ->
# prefetch so input preparation overlaps with training.
train = (
    train
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .shuffle(info.splits['train'].num_examples, seed=SHUFFLE_SEED)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

# Test split: no shuffling for evaluation. Caching after batching is fine here
# because the batches are identical on every pass.
test = (
    test
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)


In [None]:
# Minimal MLP classifier: flatten the 28x28 image, one 128-unit ReLU hidden
# layer, and a 10-unit output layer with no activation (raw logits). The loss in
# the compile step is configured with `from_logits=True` to match.
model = tf.keras.Sequential(
    layers=[
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10),
    ]
)

In [None]:
# Adam with learning rate 1e-3. The loss takes raw logits because the final
# Dense layer has no softmax, and accuracy is tracked against integer labels.
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

In [None]:
# Train for 6 epochs. NOTE: the *test* split is passed as validation data, so
# the reported val_* metrics are test-set metrics and should not drive
# model selection or hyperparameter tuning.
history = model.fit(train, validation_data=test, epochs=6)

The plot below shows how training and validation accuracy change across the epochs.

In [None]:
import pandas as pd

In [None]:
# Metrics recorded by `model.fit`, one row per epoch (index starts at 0).
history_df = pd.DataFrame.from_dict(history.history)

# Give the figure a title and axis labels so it stands on its own; the trailing
# `;` suppresses the Axes repr that would otherwise be echoed as cell output.
accuracy_ax = history_df[["sparse_categorical_accuracy", "val_sparse_categorical_accuracy"]].plot(
    title="Training vs. validation accuracy",
)
accuracy_ax.set(xlabel="Epoch (0-based)", ylabel="Sparse categorical accuracy");

In [None]:
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Re-load the test split without batching or normalization, so the images can
# be inspected one at a time in the misclassification search below.
samples = tfds.load("mnist", split="test", as_supervised=True)
assert isinstance(samples, tf.data.Dataset)

Search the test data for misclassified images, i.e. those where the model's prediction does not match the true label. For each one, print the predicted and true labels and display the image.

In [None]:
# Gather the normalized test images and integer labels as NumPy arrays, so the
# whole split is scored with one batched `model.predict` call instead of one
# call per image.
sample_images = []
sample_labels = []
for image, label in samples.map(normalize_img).as_numpy_iterator():
    # Reshape to 28x28 for imshow (TFDS MNIST images presumably carry a trailing
    # channel axis -- the original reshape implies it).
    sample_images.append(image.reshape(28, 28))
    sample_labels.append(int(label))  # plain int, so it prints as a number, not a tensor repr
sample_images = np.stack(sample_images)
sample_labels = np.array(sample_labels)

# The model outputs logits; argmax over the class axis gives the predicted digit.
predictions = np.argmax(model.predict(sample_images, verbose=0), axis=1)

# Show every misclassified image with its predicted and true label.
for idx in np.flatnonzero(predictions != sample_labels):
    print("Prediction: ", predictions[idx], "Label: ", sample_labels[idx])
    fig, ax = plt.subplots()
    ax.imshow(sample_images[idx], cmap='gray_r')
    plt.show()
    plt.close(fig)  # avoid accumulating many open figures in the loop
