# Training a neural network on MNIST with Keras

This simple example demonstrates how to plug TensorFlow Datasets (TFDS) into a Keras model.


In [1]:
import tensorflow as tf
import numpy as np
import tensorflow_datasets as tfds






In [2]:
tf.config.list_physical_devices('GPU')

[]

## Step 1: Create your input pipeline

Start by building an efficient input pipeline, following the advice in:
* The [Performance tips](https://www.tensorflow.org/datasets/performances) guide
* The [Better performance with the `tf.data` API](https://www.tensorflow.org/guide/data_performance#optimize_performance) guide


### Load a dataset

Load the MNIST dataset with the following arguments:

* `shuffle_files=True`: The MNIST data is only stored in a single file, but for larger datasets with multiple files on disk, it's good practice to shuffle them when training.
* `as_supervised=True`: Returns a tuple `(img, label)` instead of a dictionary `{'image': img, 'label': label}`.

In [3]:
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

### Build a training pipeline

Apply the following transformations:

* `tf.data.Dataset.map`: TFDS provides images of type `tf.uint8`, while the model expects `tf.float32`. Therefore, you need to cast the images to `tf.float32` and scale them to the `[0, 1]` range.
* `tf.data.Dataset.cache`: Because the dataset fits in memory, cache it before shuffling for better performance.<br/>
__Note:__ Random transformations should be applied after caching.
* `tf.data.Dataset.shuffle`: For true randomness, set the shuffle buffer to the full dataset size.<br/>
__Note:__ For large datasets that can't fit in memory, use `buffer_size=1000` if your system allows it.
* `tf.data.Dataset.batch`: Batch elements of the dataset after shuffling to get unique batches at each epoch.
* `tf.data.Dataset.prefetch`: It is good practice to end the pipeline by prefetching [for performance](https://www.tensorflow.org/guide/data_performance#prefetching).
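To make the note on shuffle buffers more concrete, here is a minimal pure-Python sketch of how a bounded shuffle buffer behaves. It is a simplified model written for illustration, not TensorFlow's implementation; the function name `buffered_shuffle` and its `seed` parameter are hypothetical.

```python
import random


def buffered_shuffle(items, buffer_size, seed=0):
    """Simplified model of a shuffle buffer (assumes buffer_size >= 1)."""
    rng = random.Random(seed)
    buffer, output = [], []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)       # fill the buffer first
            continue
        idx = rng.randrange(len(buffer))
        output.append(buffer[idx])    # emit a random buffered element
        buffer[idx] = item            # refill that slot from the stream
    rng.shuffle(buffer)               # drain the remainder in random order
    output.extend(buffer)
    return output


data = list(range(20))
small = buffered_shuffle(data, buffer_size=3)           # local shuffling only
full = buffered_shuffle(data, buffer_size=len(data))    # any order is possible
print(small)
print(full)
```

In this model, an element can never appear more than `buffer_size - 1` positions earlier than its original position, so a small buffer only mixes nearby elements. A buffer as large as the dataset removes that limit.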

In [4]:
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)
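The effect of `normalize_img` can be checked without TensorFlow. The sketch below is a NumPy analogue (the name `normalize_img_np` and the sample pixel values are made up for illustration): it casts `uint8` pixels to `float32` and divides by 255 so that values land in `[0, 1]`.

```python
import numpy as np


def normalize_img_np(image):
    """NumPy analogue of `normalize_img`: uint8 [0, 255] -> float32 [0, 1]."""
    return image.astype(np.float32) / np.float32(255.0)


# Hypothetical one-row "image" covering the minimum, a mid value, and the maximum.
fake_image = np.array([[0, 128, 255]], dtype=np.uint8)
scaled = normalize_img_np(fake_image)
print(scaled.dtype, scaled.min(), scaled.max())
```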

### Build an evaluation pipeline

Your testing pipeline is similar to the training pipeline with small differences:

 * You don't need to call `tf.data.Dataset.shuffle`.
 * Caching is done after batching because batches can be the same between epochs.

In [5]:
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)
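As a quick sanity check on the batching, the number of batches per pass can be worked out by hand. The sketch below assumes the standard MNIST split sizes of 60,000 training and 10,000 test images (in the notebook these values would come from `ds_info.splits[...].num_examples`) and relies on `batch` keeping the final partial batch, which is its default behavior. The helper `batch_stats` is a hypothetical name.

```python
import math

BATCH_SIZE = 128


def batch_stats(num_examples, batch_size=BATCH_SIZE):
    """Return (number of batches, size of the last batch) for one pass."""
    num_batches = math.ceil(num_examples / batch_size)
    last_batch = num_examples - (num_batches - 1) * batch_size
    return num_batches, last_batch


train_batches, train_last = batch_stats(60_000)  # assumed train split size
test_batches, test_last = batch_stats(10_000)    # assumed test split size
print(train_batches, train_last)  # 469 96
print(test_batches, test_last)    # 79 16
```

The training figure is the step count Keras reports per epoch when it iterates over `ds_train`.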

## Step 2: Create and train the model

Plug the TFDS input pipeline into a simple Keras model, compile the model, and train it.
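For a rough sense of scale, the model in this step flattens each 28×28 image into 784 values, then applies a 128-unit dense layer and a 10-unit output layer. The sketch below counts the trainable parameters by hand; `dense_params` is a hypothetical helper, and the count assumes each dense layer has one weight per input-unit pair plus one bias per unit.

```python
def dense_params(inputs, units):
    """Weights (inputs * units) plus one bias per unit."""
    return inputs * units + units


flattened = 28 * 28                          # Flatten has no parameters
hidden = dense_params(flattened, 128)        # Dense(128, activation='relu')
output = dense_params(128, 10)               # Dense(10), raw logits
print(hidden, output, hidden + output)       # 100480 1290 101770
```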

In [6]:
model = tf.keras.models.Sequential([
  tf.keras.layers.Input(shape=(28, 28)),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)







Epoch 1/6

Epoch 2/6
  1/469 [..............................] - ETA: 35s - loss: 0.1062 - sparse_categorical_accuracy: 0.9766
 23/469 [>.............................] - ETA: 1s - loss: 0.1855 - sparse_categorical_accuracy: 0.9467
 45/469 [=>............................] - ETA: 0s - loss: 0.1784 - sparse_categorical_accuracy: 0.9497
 67/469 [===>..........................] - ETA: 0s - loss: 0.1805 - sparse_categorical_accuracy: 0.9488
 89/469 [====>.........................] - ETA: 0s - loss: 0.1765 - sparse_categorical_accuracy: 0.9506

Epoch 3/6
  1/469 [..............................] - ETA: 32s - loss: 0.1546 - sparse_categorical_accuracy: 0.9609
 23/469 [>.............................] - ETA: 1s - loss: 0.1243 - sparse_categorical_accuracy: 0.9640
 45/469 [=>............................] - ETA: 0s - loss: 0.1216 - sparse_categorical_accuracy: 0.9648
 67/469 [===>..........................] - ETA: 0s - loss: 0.1272 - sparse_categorical_accuracy: 0.9628
 90/469 [====>.........................] - ETA: 0s - loss: 0.1230 - sparse_categorical_accuracy: 0.9653

Epoch 4/6
  1/469 [..............................] - ETA: 32s - loss: 0.0986 - sparse_categorical_accuracy: 0.9688
 24/469 [>.............................] - ETA: 0s - loss: 0.1070 - sparse_categorical_accuracy: 0.9684
 47/469 [==>...........................] - ETA: 0s - loss: 0.1030 - sparse_categorical_accuracy: 0.9714
 70/469 [===>..........................] - ETA: 0s - loss: 0.0953 - sparse_categorical_accuracy: 0.9733
 93/469 [====>.........................] - ETA: 0s - loss: 0.0948 - sparse_categorical_accuracy: 0.9733

Epoch 5/6
  1/469 [..............................] - ETA: 32s - loss: 0.0625 - sparse_categorical_accuracy: 0.9922
 23/469 [>.............................] - ETA: 1s - loss: 0.0698 - sparse_categorical_accuracy: 0.9820
 45/469 [=>............................] - ETA: 0s - loss: 0.0682 - sparse_categorical_accuracy: 0.9811
 68/469 [===>..........................] - ETA: 0s - loss: 0.0707 - sparse_categorical_accuracy: 0.9805
 90/469 [====>.........................] - ETA: 0s - loss: 0.0729 - sparse_categorical_accuracy: 0.9793

Epoch 6/6
  1/469 [..............................] - ETA: 32s - loss: 0.0219 - sparse_categorical_accuracy: 1.0000
 24/469 [>.............................] - ETA: 1s - loss: 0.0620 - sparse_categorical_accuracy: 0.9821
 47/469 [==>...........................] - ETA: 0s - loss: 0.0629 - sparse_categorical_accuracy: 0.9830
 70/469 [===>..........................] - ETA: 0s - loss: 0.0640 - sparse_categorical_accuracy: 0.9830
 93/469 [====>.........................] - ETA: 0s - loss: 0.0619 - sparse_categorical_accuracy: 0.9830

<keras.src.callbacks.History at 0x7fc41e0cb880>

In [7]:
model.evaluate(ds_test)



[0.08246538788080215, 0.974399983882904]
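`model.evaluate` returns the loss followed by the compiled metrics, so the two numbers above are the sparse categorical cross-entropy and the accuracy on the test set. The NumPy sketch below shows how these two quantities can be computed from raw logits. It is an illustration only: the function names and the toy `logits`/`labels` are made up, and it ignores details of how Keras aggregates results across batches.

```python
import numpy as np


def sparse_ce_from_logits(logits, labels):
    """Mean cross-entropy from raw logits and integer class labels."""
    shifted = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(labels)), labels].mean())


def sparse_accuracy(logits, labels):
    """Fraction of rows whose highest logit matches the label."""
    return float((logits.argmax(axis=1) == labels).mean())


# Toy example: 3 samples, 3 classes; the third sample is misclassified.
logits = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0], [1.0, 1.5, 0.0]])
labels = np.array([0, 2, 0])
print(sparse_ce_from_logits(logits, labels), sparse_accuracy(logits, labels))
```

Because the final `Dense(10)` layer has no activation, the model outputs logits, which is why the loss is configured with `from_logits=True`.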

In [8]:
model.save("model.keras")
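A saved `.keras` file can be checked by loading it back and comparing predictions. The sketch below does this with a tiny stand-in model and random input written to a temporary directory, so it does not touch `model.keras`; the names `model_demo`, `x_demo`, and `demo.keras` are hypothetical, and it assumes a TensorFlow version that supports the `.keras` format.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Tiny stand-in model (hypothetical; not the MNIST model from this notebook).
model_demo = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(4,)),
    tf.keras.layers.Dense(3),
])

x_demo = np.random.default_rng(0).normal(size=(2, 4)).astype(np.float32)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "demo.keras")
    model_demo.save(path)                          # write the single-file archive
    restored = tf.keras.models.load_model(path)    # rebuild the model from disk
    same = np.allclose(
        model_demo.predict(x_demo, verbose=0),
        restored.predict(x_demo, verbose=0),
    )

print(same)
```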

In [9]:

from tensorflow.keras.models import save_model

# Export to the `tf_model/` directory in SavedModel format, without the optimizer state.
save_model(model=model, filepath='tf_model/', include_optimizer=False)

INFO:tensorflow:Assets written to: tf_model/assets




In [10]:
batches = list(ds_test)  # materialize every test batch as (images, labels) tuples

In [11]:
X, y = batches[0]  # first batch of images and labels
X = X.numpy()
y = y.numpy()
X, y

(array([[[[0.],
          [0.],
          [0.],
          ...,
          [0.],
          [0.],
          [0.]],
 
         [[0.],
          [0.],
          [0.],
          ...,
          [0.],
          [0.],
          [0.]],
 
         [[0.],
          [0.],
          [0.],
          ...,
          [0.],
          [0.],
          [0.]],
 
         ...,
 
         [[0.],
          [0.],
          [0.],
          ...,
          [0.],
          [0.],
          [0.]],
 
         [[0.],
          [0.],
          [0.],
          ...,
          [0.],
          [0.],
          [0.]],
 
         [[0.],
          [0.],
          [0.],
          ...,
          [0.],
          [0.],
          [0.]]],
 
 
        [[[0.],
          [0.],
          [0.],
          ...,
          [0.],
          [0.],
          [0.]],
 
         [[0.],
          [0.],
          [0.],
          ...,
          [0.],
          [0.],
          [0.]],
 
         [[0.],
          [0.],
          [0.],
          ...,
      

In [12]:
np.savez("test_data.npz", X = X, y = y)

In [13]:
npzfile = np.load("test_data.npz")
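The object returned by `np.load` for an `.npz` file behaves like a dictionary keyed by the names passed to `np.savez`. The round trip below uses small placeholder arrays and a temporary file (the `_demo` names and `demo_data.npz` are hypothetical) to show how the saved arrays can be read back.

```python
import os
import tempfile

import numpy as np

# Placeholder arrays shaped like a small batch of MNIST images and labels.
X_demo = np.zeros((2, 28, 28, 1), dtype=np.float32)
y_demo = np.array([7, 2], dtype=np.int64)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "demo_data.npz")
    np.savez(path, X=X_demo, y=y_demo)       # keyword names become archive keys
    with np.load(path) as npzfile:
        keys = sorted(npzfile.files)         # names of the stored arrays
        X_loaded = npzfile["X"]              # arrays are read on access
        y_loaded = npzfile["y"]

print(keys, X_loaded.shape, y_loaded)
```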

In [14]:
X.shape

(128, 28, 28, 1)

In [38]:
%timeit model.predict(X)

65.9 ms ± 3.54 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
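Since `X` holds a batch of 128 images, the timing above is per batch, not per image. A short calculation converts it into per-image latency and throughput. This only rescales the measured mean and ignores the reported spread; the numbers apply to the machine the notebook ran on.

```python
batch_latency_ms = 65.9   # mean time per model.predict(X) call, from %timeit
batch_size = 128          # first dimension of X

per_image_ms = batch_latency_ms / batch_size
images_per_second = batch_size / (batch_latency_ms / 1000.0)

print(f"{per_image_ms:.3f} ms/image, {images_per_second:.0f} images/s")
# 0.515 ms/image, 1942 images/s
```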
