# Training a neural network on EMNIST with Keras

This simple example demonstrates how to plug TensorFlow Datasets (TFDS) into a Keras model.


Copyright 2020 The TensorFlow Datasets Authors, Licensed under the Apache License, Version 2.0


In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds



## Step 1: Create your input pipeline

Start by building an efficient input pipeline using advice from:
* The [Performance tips](https://www.tensorflow.org/datasets/performances) guide
* The [Better performance with the `tf.data` API](https://www.tensorflow.org/guide/data_performance#optimize_performance) guide


### Load a dataset

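Load the EMNIST dataset (the code below requests `'emnist'`, an extension of MNIST that adds handwritten letters to the digits) with the following arguments:

* `shuffle_files=True`: Shuffles the input files. MNIST itself is only stored in a single file, but for larger datasets with multiple files on disk, it's good practice to shuffle them when training.
* `as_supervised=True`: Returns a tuple `(img, label)` instead of a dictionary `{'image': img, 'label': label}`.
* `with_info=True`: Also returns the dataset metadata (`ds_info`), which is used later to read the number of training examples.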

In [3]:
(ds_train, ds_test), ds_info = tfds.load(
    'emnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

### Build a training pipeline

Apply the following transformations:

* `tf.data.Dataset.map`: TFDS provides images of type `tf.uint8`, while the model expects `tf.float32`. Therefore, you need to normalize images.
* `tf.data.Dataset.cache`: Because the dataset fits in memory, cache it before shuffling for better performance.<br/>
__Note:__ Random transformations should be applied after caching.
* `tf.data.Dataset.shuffle`: For true randomness, set the shuffle buffer to the full dataset size.<br/>
__Note:__ For large datasets that can't fit in memory, use `buffer_size=1000` if your system allows it.
* `tf.data.Dataset.batch`: Batch elements of the dataset after shuffling to get unique batches at each epoch.
* `tf.data.Dataset.prefetch`: It is good practice to end the pipeline by prefetching [for performance](https://www.tensorflow.org/guide/data_performance#prefetching).
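The advice on shuffle buffer size is easier to see with a small simulation. The following is a pure-Python sketch of a buffered shuffle, written only for illustration: `buffered_shuffle` is a hypothetical helper, not the `tf.data` implementation, and it assumes the usual scheme in which the buffer is filled first and each emitted element is replaced by the next input element.

```python
import itertools
import random


def buffered_shuffle(items, buffer_size, seed=0):
    """Illustrative buffered shuffle (hypothetical helper, not tf.data itself)."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    rng = random.Random(seed)
    it = iter(items)
    # Fill the buffer with the first `buffer_size` elements.
    buffer = list(itertools.islice(it, buffer_size))
    for item in it:
        # Emit a random buffered element and put the next input in its slot.
        idx = rng.randrange(len(buffer))
        yield buffer[idx]
        buffer[idx] = item
    # Input exhausted: emit what is left in random order.
    rng.shuffle(buffer)
    yield from buffer


data = list(range(20))
small = list(buffered_shuffle(data, buffer_size=3))
full = list(buffered_shuffle(data, buffer_size=len(data)))
unshuffled = list(buffered_shuffle(data, buffer_size=1))
print(small)
print(full)
```

With `buffer_size=3`, an element can only move a few positions from where it started, so the order stays mostly local. With the buffer as large as the dataset, every permutation is possible. With `buffer_size=1`, the order is unchanged.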

In [4]:
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)
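
Because `batch(128)` keeps the final partial batch by default, the number of steps per epoch is the example count divided by the batch size, rounded up. The arithmetic can be sketched as follows. The value 697932 is the training-set size reported by the shuffle-buffer log during training in this notebook, and the variable names are only illustrative.

```python
import math

num_examples = 697932  # training examples, as reported in the training log
batch_size = 128       # value passed to ds_train.batch(...)

# The last batch is partial, so round up.
steps_per_epoch = math.ceil(num_examples / batch_size)
last_batch_size = num_examples - (steps_per_epoch - 1) * batch_size

print(steps_per_epoch)   # 5453
print(last_batch_size)   # 76
```

The result, 5453 steps, matches the step count shown in the Keras progress bar during training.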

### Build an evaluation pipeline

Your testing pipeline is similar to the training pipeline with small differences:

 * You don't need to call `tf.data.Dataset.shuffle`.
 * Caching is done after batching because batches can be the same between epochs.

In [5]:
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)

## Step 2: Create and train the model

Plug the TFDS input pipeline into a simple Keras model, compile the model, and train it.

In [11]:
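model = tf.keras.models.Sequential([
  # Flatten each 28x28 image into a single vector.
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  # One output logit per class (62 classes in this EMNIST configuration).
  # No activation here: the loss below is configured with from_logits=True.
  tf.keras.layers.Dense(62)
])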
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(
    ds_train,
    epochs=10,
    validation_data=ds_test,
    callbacks=[
        tf.keras.callbacks.ModelCheckpoint(
            filepath='tfKerasChars10Epochs',
            save_best_only=True,
            monitor="val_loss",
        ),
    ],
)

Epoch 1/10
2022-06-18 17:00:46.270925: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:390] Filling up shuffle buffer (this may take a while): 1 of 697932
2022-06-18 17:00:46.271505: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:390] Filling up shuffle buffer (this may take a while): 2 of 697932
2022-06-18 17:00:55.956864: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:390] Filling up shuffle buffer (this may take a while): 320917 of 697932
2022-06-18 17:01:05.956928: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:390] Filling up shuffle buffer (this may take a while): 665094 of 697932
   7/5453 [..............................] - ETA: 1:49 - loss: 3.8711 - sparse_categorical_accuracy: 0.0926
2022-06-18 17:01:06.898182: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:415] Shuffle buffer filled.
INFO:tensorflow:Assets written to: tfKerasChars10Epochs/assets
Epoch 2/10
INFO:tensorflow:Assets written to: tfKerasChars10Epochs/assets
Epoch 3/10
INFO:tensorflow:Assets written to: tfKerasChars10Epochs/assets
Epoch 4/10
INFO:tensorflow:Assets written to: tfKerasChars10Epochs/assets
Epoch 5/10
INFO:tensorflow:Assets written to: tfKerasChars10Epochs/assets
Epoch 6/10
INFO:tensorflow:Assets written to: tfKerasChars10Epochs/assets
Epoch 7/10
INFO:tensorflow:Assets written to: tfKerasChars10Epochs/assets
Epoch 8/10
Epoch 9/10
Epoch 10/10
INFO:tensorflow:Assets written to: tfKerasChars10Epochs/assets

<keras.callbacks.History at 0x7fa933630580>
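
The model above ends in `Dense(62)` with no activation, and the loss is created with `from_logits=True`, so the softmax is applied inside the loss rather than in the model. As a rough illustration of what that loss computes for a single example, here is a pure-Python sketch. The function name is hypothetical and this is not the Keras implementation.

```python
import math


def sparse_crossentropy_from_logits(logits, label):
    """Cross-entropy for one example, computed directly from raw logits."""
    # Numerically stable log-sum-exp: subtract the max before exponentiating.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    # -log(softmax(logits)[label]) == log_sum_exp - logits[label]
    return log_sum_exp - logits[label]


num_classes = 62
uniform_logits = [0.0] * num_classes  # a model with no preference yet
loss = sparse_crossentropy_from_logits(uniform_logits, label=5)
print(round(loss, 4))  # 4.1271, i.e. ln(62)
```

A model that assigns equal probability to all 62 classes has a loss of ln(62) ≈ 4.13. The loss of 3.8711 logged a few steps into the first epoch is just below that value, which is consistent with a network that has only begun to learn.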

In [12]:
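# Convert the SavedModel written by the ModelCheckpoint callback to TensorFlow Lite.
converter = tf.lite.TFLiteConverter.from_saved_model('tfKerasChars10Epochs')
tflite_model = converter.convert()

# The converter returns the serialized model as bytes; write them to disk.
with open("tfKerasChars10Epochs.tflite", "wb") as fp:
    fp.write(tflite_model)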

2022-06-18 17:12:33.568974: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:357] Ignored output_format.
2022-06-18 17:12:33.569131: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:360] Ignored drop_control_dependency.
2022-06-18 17:12:33.571514: I tensorflow/cc/saved_model/reader.cc:43] Reading SavedModel from: tfKerasChars10Epochs
2022-06-18 17:12:33.584278: I tensorflow/cc/saved_model/reader.cc:78] Reading meta graph with tags { serve }
2022-06-18 17:12:33.584377: I tensorflow/cc/saved_model/reader.cc:119] Reading SavedModel debug info (if present) from: tfKerasChars10Epochs
2022-06-18 17:12:33.592236: I tensorflow/cc/saved_model/loader.cc:228] Restoring SavedModel bundle.
2022-06-18 17:12:33.689579: I tensorflow/cc/saved_model/loader.cc:212] Running initialization op on SavedModel bundle at path: tfKerasChars10Epochs
2022-06-18 17:12:33.717265: I tensorflow/cc/saved_model/loader.cc:301] SavedModel load for tags { serve }; Status: success: OK. 