TensorFlow provides the tf.data API for building efficient input pipelines for
machine learning models. Its core class is tf.data.Dataset.


A Dataset object is an iterable: you can use it in a for loop. It typically yields
batches of input data and labels. You can pass a Dataset object directly to the fit()
method of a Keras model.
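
As a minimal sketch of both uses, the snippet below loops over a Dataset and then passes it to fit(). The toy data, the layer sizes, and the names features, labels, and train_dataset are illustrative assumptions, not part of the original example; from_tensor_slices() and batch() are introduced later in this section.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy data: 256 samples with 16 features and a binary label.
features = np.random.normal(size=(256, 16)).astype("float32")
labels = np.random.randint(0, 2, size=(256, 1)).astype("float32")

# Each element is a (features, labels) pair; batch() groups 32 of them.
train_dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(32)

# Iterating in a for loop yields one batch per step.
for batch_features, batch_labels in train_dataset.take(1):
    first_batch_shapes = (tuple(batch_features.shape), tuple(batch_labels.shape))
    print(first_batch_shapes)

# A deliberately tiny model, just enough to show the fit() call.
inputs = tf.keras.Input(shape=(16,))
hidden = tf.keras.layers.Dense(8, activation="relu")(inputs)
outputs = tf.keras.layers.Dense(1, activation="sigmoid")(hidden)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer="rmsprop", loss="binary_crossentropy")

# The dataset already yields batches, so fit() takes no batch_size argument.
history = model.fit(train_dataset, epochs=1, verbose=0)
```

Because the dataset supplies both inputs and targets, fit() receives a single argument in place of separate x and y arrays.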


The Dataset class handles many key features that would otherwise be cumbersome
to implement yourself—in particular, asynchronous data prefetching (preprocessing
the next batch of data while the previous one is being handled by the model, which
keeps execution flowing without interruptions).
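
Prefetching is typically enabled by ending the pipeline with a prefetch() call. The sketch below assumes a TensorFlow version that provides tf.data.AUTOTUNE, which lets the runtime choose the buffer size; the array and variable names are illustrative.

```python
import numpy as np
import tensorflow as tf

data = np.random.normal(size=(1000, 16))

# prefetch() goes last so that whole batches are prepared in the background
# while the consumer is still working on the previous one.
dataset = (
    tf.data.Dataset.from_tensor_slices(data)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)

# Prefetching changes when elements are produced, not what is produced.
num_batches = sum(1 for _ in dataset)
print(num_batches)
```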


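The Dataset class also exposes a functional-style API for modifying datasets.
As a quick example, create a NumPy array of random numbers holding 1,000 samples,
each a vector of size 16.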

In [1]:
import numpy as np
import tensorflow as tf
random_numbers = np.random.normal(size=(1000, 16))

In [3]:
random_numbers.shape

(1000, 16)

The from_tensor_slices() static method can be used to create a Dataset from a
NumPy array, or from a tuple or dict of NumPy arrays.
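
The tuple and dict forms can be sketched as follows. The labels array and the dict keys "features" and "labels" are assumptions made for illustration. In both cases the arrays are sliced along their first axis, so they must have the same number of rows.

```python
import numpy as np
import tensorflow as tf

features = np.random.normal(size=(1000, 16))
labels = np.random.randint(0, 10, size=(1000,))

# Tuple input: each element is a (feature_vector, label) pair.
pair_dataset = tf.data.Dataset.from_tensor_slices((features, labels))
print(pair_dataset.element_spec)

# Dict input: each element is a dict with the same keys as the input.
dict_dataset = tf.data.Dataset.from_tensor_slices(
    {"features": features, "labels": labels}
)
print(dict_dataset.element_spec)
```

The tuple form matches the (inputs, targets) structure that Keras expects during training, while the dict form is convenient when inputs need to be addressed by name.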

In [4]:
dataset = tf.data.Dataset.from_tensor_slices(random_numbers)

In [5]:
dataset

<_TensorSliceDataset element_spec=TensorSpec(shape=(16,), dtype=tf.float64, name=None)>

At first, our dataset just yields single samples:

In [6]:
for i, element in enumerate(dataset):
  print(element.shape)
  if i >= 2:
    break

(16,)
(16,)
(16,)


We can use the .batch() method to batch the data:

In [8]:
batched_dataset = dataset.batch(32)
for i, element in enumerate(batched_dataset):
  print(element.shape)
  if i >= 2:
    break

(32, 16)
(32, 16)
(32, 16)
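
Note that 1,000 is not a multiple of 32, so the final batch is smaller than the others. The sketch below compares the default behavior with drop_remainder=True, which discards the incomplete batch; the variable names are illustrative.

```python
import numpy as np
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(np.random.normal(size=(1000, 16)))

# Default: the last batch holds the 8 leftover samples (1000 = 31 * 32 + 8).
batch_shapes = [tuple(batch.shape) for batch in dataset.batch(32)]
print(len(batch_shapes), batch_shapes[-1])

# drop_remainder=True keeps only full batches, so every batch has 32 samples.
full_shapes = [
    tuple(batch.shape) for batch in dataset.batch(32, drop_remainder=True)
]
print(len(full_shapes), full_shapes[-1])
```

Dropping the remainder is useful when downstream code requires a fixed batch dimension.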


More broadly, we have access to a range of useful dataset methods, such as

 .shuffle(buffer_size): Shuffles elements within a buffer of the given size.

 .prefetch(buffer_size): Prepares a buffer of elements ahead of time, so that
upcoming data is ready while the current batch is being processed, to achieve
better device utilization.

 .map(callable): Applies an arbitrary transformation to each element of the
dataset. The callable is a function that takes as input a single element yielded
by the dataset.
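
These methods each return a new Dataset, so they can be chained. The sketch below is one plausible ordering (shuffle, then map, then batch, then prefetch) on a small range of integers, chosen so the effect is easy to inspect; the buffer sizes and the doubling function are illustrative assumptions.

```python
import numpy as np
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(np.arange(10))

# A buffer at least as large as the dataset gives a full shuffle;
# a smaller buffer only mixes elements that are near each other.
shuffled = dataset.shuffle(buffer_size=10, seed=0)
shuffled_values = [int(x) for x in shuffled]
print(shuffled_values)

# Each call returns a new Dataset, so the steps can be chained.
pipeline = (
    dataset
    .shuffle(buffer_size=10)
    .map(lambda x: x * 2)        # per-element transformation
    .batch(4)                    # group elements into batches of 4
    .prefetch(tf.data.AUTOTUNE)  # prepare upcoming batches in the background
)
batches = [batch.numpy().tolist() for batch in pipeline]
print(batches)
```

Shuffling before batching mixes individual samples rather than whole batches, and placing prefetch() last lets the entire upstream pipeline run ahead of the consumer.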

The .map() method, in particular, is one that you will use often. Here’s an example.

In [9]:
reshaped_dataset = dataset.map(lambda x: tf.reshape(x, (4, 4)))
for i, element in enumerate(reshaped_dataset):
  print(element.shape)
  if i >= 2:
    break

(4, 4)
(4, 4)
(4, 4)
