```python
# import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow_datasets as tfds
```

# TensorFlow Datasets
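Data downloaded by TFDS is usually loaded into a program as a `tf.data.Dataset` or an `np.array`. Do not confuse TFDS with the `tf.data` API: TFDS is a high-level API that wraps `tf.data`. It is strongly recommended to get comfortable with `tf.data` before reading this document.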

## Installation
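Because tensorflow-datasets only supports the latest version of TensorFlow, upgrade TensorFlow first and then install TFDS. Otherwise the TF and TFDS versions, or the third-party libraries they depend on, may be incompatible. To install the stable release of TFDS: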
```bash
pip install --upgrade tensorflow -i https://pypi.douban.com/simple
pip install tensorflow-datasets -i https://pypi.douban.com/simple
```
Alternatively, run `pip install tfds-nightly` to install the latest TFDS build, which is updated daily.
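
After installing, it can help to confirm which versions actually ended up in the environment. The following is a minimal sketch using only the standard library; the helper name `installed_version` is made up for this example, and the package names are the ones used in the `pip` commands above.

```python
from importlib import metadata


def installed_version(package: str):
    """Return the installed version string of `package`, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Report the TF / TFDS distributions found in the current environment.
for name in ("tensorflow", "tensorflow-datasets", "tfds-nightly"):
    print(f"{name}: {installed_version(name) or 'not installed'}")
```

Only one of `tensorflow-datasets` and `tfds-nightly` is normally expected to be present at a time.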


## Available datasets
Every dataset builder is a subclass of `tfds.core.DatasetBuilder`. Call `tfds.list_builders()` to get the list of available datasets, or browse the [official dataset catalog](https://www.tensorflow.org/datasets/catalog/overview).


## Load a dataset
The simplest way to load a dataset is the `tfds.load()` function, which saves the downloaded data in `TFRecord` format and loads it into the program as a `tf.data.Dataset`.

```python
ds = tfds.load('mnist', split='train', shuffle_files=True)
assert isinstance(ds, tf.data.Dataset)
```

Alternatively, use `tfds.builder` or the [TFDS command-line interface (CLI)](https://www.tensorflow.org/datasets/cli):

```python
builder = tfds.builder('mnist')
builder.download_and_prepare()
ds = builder.as_dataset(split='train', shuffle_files=True)
```


## Iterate over a dataset
### As dict
By default, the `tf.data.Dataset` object contains a `dict` of `tf.Tensor`s:

```python
ds = tfds.load('mnist', split='train')
ds = ds.take(1)  # Only take a single example

for example in ds:  # example is `{'image': tf.Tensor, 'label': tf.Tensor}`
  print(list(example.keys()))
  image = example["image"]
  label = example["label"]
  print(image.shape, label)
```

```
['image', 'label']
(28, 28, 1) tf.Tensor(4, shape=(), dtype=int64)
```

To find out the dict key names and structure, look at the dataset's page in the [catalog](https://www.tensorflow.org/datasets/catalog/overview), for example the mnist documentation.

### As tuple (as_supervised=True)
By using `as_supervised=True`, you can get a tuple `(features, label)` instead for supervised datasets.

```python
ds = tfds.load('mnist', split='train', as_supervised=True)
ds = ds.take(1)

for image, label in ds:  # example is (image, label)
  print(image.shape, label)
```

```
(28, 28, 1) tf.Tensor(4, shape=(), dtype=int64)
```
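
The `(image, label)` tuple form is convenient because it plugs directly into `Dataset.map`. The sketch below shows one common preprocessing step, scaling pixels to `[0, 1]`. To stay runnable without a download, it uses random synthetic data in place of MNIST; the `normalize` function and the batch size of 4 are illustrative choices, not part of TFDS.

```python
import numpy as np
import tensorflow as tf

# Stand-in for tfds.load('mnist', split='train', as_supervised=True):
# eight random uint8 "images" with integer labels (synthetic, not MNIST).
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(8, 28, 28, 1), dtype=np.uint8)
labels = rng.integers(0, 10, size=(8,), dtype=np.int64)
ds = tf.data.Dataset.from_tensor_slices((images, labels))


def normalize(image, label):
    """Scale uint8 pixels to float32 values in [0, 1]; keep the label."""
    return tf.cast(image, tf.float32) / 255.0, label


ds = ds.map(normalize, num_parallel_calls=tf.data.AUTOTUNE).batch(4)

for image_batch, label_batch in ds.take(1):
    print(image_batch.shape, image_batch.dtype, label_batch.shape)
```

With a real dataset, the same `map` call would be applied to the result of `tfds.load(..., as_supervised=True)`.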

### As numpy (tfds.as_numpy)
Use `tfds.as_numpy` to convert:

- `tf.Tensor` -> `np.array`
- `tf.data.Dataset` -> `Iterator[Tree[np.array]]` (`Tree` can be an arbitrarily nested `Dict` or `Tuple`)

```python
ds = tfds.load('mnist', split='train', as_supervised=True)
ds = ds.take(1)

for image, label in tfds.as_numpy(ds):
  print(type(image), type(label), label)
```

```
<class 'numpy.ndarray'> <class 'numpy.int64'> 4
```

### As batched tf.Tensor (batch_size=-1)
By using `batch_size=-1`, you can load the full dataset in a single batch.

This can be combined with `as_supervised=True` and `tfds.as_numpy` to get the data as `(np.array, np.array)`:

```python
image, label = tfds.as_numpy(tfds.load(
    'mnist',
    split='test',
    batch_size=-1,
    as_supervised=True,
))

print(type(image), image.shape)
```

```
<class 'numpy.ndarray'> (10000, 28, 28, 1)
```

Be careful that your dataset can fit in memory, and that all examples have the same shape.
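
Whether a full split fits in memory can be estimated up front from the number of examples, the per-example shape and the dtype. The following is a rough sketch of that arithmetic for the MNIST test split shown above; the helper `dense_size_bytes` is made up for this example and only accounts for the raw array data, not for any overhead.

```python
import math

import numpy as np


def dense_size_bytes(num_examples: int, example_shape: tuple, dtype) -> int:
    """Bytes needed to hold `num_examples` arrays of `example_shape` densely."""
    return num_examples * math.prod(example_shape) * np.dtype(dtype).itemsize


# MNIST test split as reported above: 10000 images of shape (28, 28, 1), uint8.
image_bytes = dense_size_bytes(10_000, (28, 28, 1), np.uint8)
# Labels are scalar int64 values, so the per-example shape is ().
label_bytes = dense_size_bytes(10_000, (), np.int64)

print(f"images: {image_bytes / 2**20:.2f} MiB")
print(f"labels: {label_bytes / 2**20:.2f} MiB")
```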

## Benchmark your datasets
Benchmarking a dataset is a simple `tfds.benchmark` call on any iterable (e.g. `tf.data.Dataset`, `tfds.as_numpy`, ...).

```python
ds = tfds.load('mnist', split='train')
ds = ds.batch(32).prefetch(1)

tfds.benchmark(ds, batch_size=32)
tfds.benchmark(ds, batch_size=32)  # Second epoch much faster due to auto-caching
```

```
************ Summary ************

Examples/sec (First included) 47889.92 ex/sec (total: 60000 ex, 1.25 sec)
Examples/sec (First only) 110.24 ex/sec (total: 32 ex, 0.29 sec)
Examples/sec (First excluded) 62298.08 ex/sec (total: 59968 ex, 0.96 sec)

************ Summary ************

Examples/sec (First included) 290380.50 ex/sec (total: 60000 ex, 0.21 sec)
Examples/sec (First only) 2506.57 ex/sec (total: 32 ex, 0.01 sec)
Examples/sec (First excluded) 309338.21 ex/sec (total: 59968 ex, 0.19 sec)
```


- Do not forget to normalize the results per batch size with the `batch_size=` kwarg.
- In the summary, the first warmup batch is separated from the other ones to capture `tf.data.Dataset` extra setup time (e.g. buffer initialization, ...).
- Notice how the second iteration is much faster due to TFDS auto-caching.
- `tfds.benchmark` returns a `tfds.core.BenchmarkResult` which can be inspected for further analysis.
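
To make the three summary lines concrete, here is a hand-rolled sketch of how "first included", "first only" and "first excluded" rates can be derived from timestamps. This is not how `tfds.benchmark` is implemented; `simple_benchmark` is a hypothetical helper, and it assumes every batch contains exactly `batch_size` examples.

```python
import time


def simple_benchmark(batches, batch_size):
    """Time one pass over `batches`, reporting the first batch separately."""
    start = time.perf_counter()
    first_end = None
    num_batches = 0
    for _ in batches:
        num_batches += 1
        if first_end is None:
            # Timestamp taken right after the first (warmup) batch arrives.
            first_end = time.perf_counter()
    end = time.perf_counter()
    if num_batches == 0:
        raise ValueError("nothing to benchmark: the iterable is empty")

    def rate(n_batches, seconds):
        # Guard against a zero duration on very fast iterables.
        return n_batches * batch_size / max(seconds, 1e-9)

    return {
        "num_examples": num_batches * batch_size,
        "first_included": rate(num_batches, end - start),
        "first_only": rate(1, first_end - start),
        "first_excluded": rate(num_batches - 1, end - first_end),
    }


# 1875 dummy batches of 32 stand in for ds.batch(32) over 60000 examples.
result = simple_benchmark(iter([None] * 1875), batch_size=32)
print(result["num_examples"])
```

Passing the wrong `batch_size` scales every rate by the same factor, which is why the results need to be normalized per batch size.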
## Build end-to-end pipeline
To go further, see:

- The TFDS end-to-end Keras example, which shows a full training pipeline (with batching, shuffling, ...).
- The TFDS performance guide, to improve the speed of your pipelines (tip: use `tfds.benchmark(ds)` to benchmark your datasets).

## Visualization
### tfds.as_dataframe
`tf.data.Dataset` objects can be converted to `pandas.DataFrame` with `tfds.as_dataframe` to be visualized on Colab.

- Add the `tfds.core.DatasetInfo` as second argument of `tfds.as_dataframe` to visualize images, audio, texts, videos, ...
- Use `ds.take(x)` to only display the first `x` examples. `pandas.DataFrame` will load the full dataset in-memory, and can be very expensive to display.

```python
ds, info = tfds.load('mnist', split='train', with_info=True)

tfds.as_dataframe(ds.take(4), info)
```

### tfds.show_examples
`tfds.show_examples` returns a `matplotlib.figure.Figure` (only image datasets supported now):

```python
ds, info = tfds.load('mnist', split='train', with_info=True)

fig = tfds.show_examples(ds, info)
```

## Access the dataset metadata
All builders include a `tfds.core.DatasetInfo` object containing the dataset metadata.

It can be accessed through:

- The `tfds.load` API:

  ```python
  ds, info = tfds.load('mnist', with_info=True)
  ```

- The `tfds.core.DatasetBuilder` API:

  ```python
  builder = tfds.builder('mnist')
  info = builder.info
  ```

The dataset info contains additional information about the dataset (version, citation, homepage, description, ...).

```python
print(info)
```

```
tfds.core.DatasetInfo(
    name='mnist',
    full_name='mnist/3.0.1',
    description="""
    The MNIST database of handwritten digits.
    """,
    homepage='http://yann.lecun.com/exdb/mnist/',
    data_path='gs://tensorflow-datasets/datasets/mnist/3.0.1',
    download_size=11.06 MiB,
    dataset_size=21.00 MiB,
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
    }),
    supervised_keys=('image', 'label'),
    splits={
        'test': <SplitInfo num_examples=10000, num_shards=1>,
        'train': <SplitInfo num_examples=60000, num_shards=1>,
    },
    citation="""@article{lecun2010mnist,
      title={MNIST handwritten digit database},
      author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
      journal={ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist},
      volume={2},
      year={2010}
    }""",
)
```

### Features metadata (label names, image shape, ...)
Access the `tfds.features.FeaturesDict`:

```python
info.features
```

```
FeaturesDict({
    'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
    'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
})
```

Number of classes, label names:

```python
print(info.features["label"].num_classes)
print(info.features["label"].names)
print(info.features["label"].int2str(7))  # Human-readable label name (e.g. 8 -> 'cat' in a dataset with named classes)
print(info.features["label"].str2int('7'))
```

```
10
['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
7
7
```

Shapes, dtypes:

```python
print(info.features.shape)
print(info.features.dtype)
print(info.features['image'].shape)
print(info.features['image'].dtype)
```

```
{'image': (28, 28, 1), 'label': ()}
{'image': tf.uint8, 'label': tf.int64}
(28, 28, 1)
<dtype: 'uint8'>
```

### Split metadata (e.g. split names, number of examples, ...)
Access the `tfds.core.SplitDict`:

```python
print(info.splits)
```

```
{'test': <SplitInfo num_examples=10000, num_shards=1>, 'train': <SplitInfo num_examples=60000, num_shards=1>}
```

Available splits:

```python
print(list(info.splits.keys()))
```

```
['test', 'train']
```

Get info on individual split:

```python
print(info.splits['train'].num_examples)
print(info.splits['train'].filenames)
print(info.splits['train'].num_shards)
```

```
60000
['mnist-train.tfrecord-00000-of-00001']
1
```

It also works with the subsplit API:

```python
print(info.splits['train[15%:75%]'].num_examples)
print(info.splits['train[15%:75%]'].file_instructions)
```

```
36000
[FileInstruction(filename='mnist-train.tfrecord-00000-of-00001', skip=9000, take=36000, num_examples=36000)]
```
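
The `skip=9000, take=36000` values in the output above follow from simple percentage arithmetic on the 60000 training examples. The sketch below reproduces that calculation; `percent_slice` is a hypothetical helper, and it deliberately ignores whatever rounding TFDS applies when a percentage boundary does not fall on a whole example.

```python
def percent_slice(num_examples: int, start_pct: int, stop_pct: int):
    """Return (skip, take) for a `split[start%:stop%]` style slice."""
    skip = num_examples * start_pct // 100   # examples before the slice
    stop = num_examples * stop_pct // 100    # index where the slice ends
    return skip, stop - skip


# 'train[15%:75%]' on the 60000-example MNIST train split.
skip, take = percent_slice(60_000, 15, 75)
print(skip, take)
```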

## Troubleshooting
### Manual download (if download fails)
If the download fails for some reason (e.g. you are offline), you can always download the data manually and place it in the `manual_dir` (defaults to `~/tensorflow_datasets/download/manual/`).

To find out which URLs to download, look into:

- For new datasets (implemented as a folder): `tensorflow_datasets/<type>/<dataset_name>/checksums.tsv`. For example: `tensorflow_datasets/text/bool_q/checksums.tsv`.

  You can find the dataset source location in the catalog.

- For old datasets: `tensorflow_datasets/url_checksums/<dataset_name>.txt`

### Fixing NonMatchingChecksumError
TFDS ensures determinism by validating the checksums of downloaded URLs. If `NonMatchingChecksumError` is raised, it might indicate:

- The website may be down (e.g. 503 status code). Please check the URL.
- For Google Drive URLs, try again later, as Drive sometimes rejects downloads when too many people access the same URL.
- The original dataset files may have been updated. In this case the TFDS dataset builder should be updated. Please open a new GitHub issue or PR:
  - Register the new checksums with `tfds build --register_checksums`
  - Update the dataset generation code if needed.
  - Update the dataset `VERSION`
  - Update the dataset `RELEASE_NOTES`: What caused the checksums to change? Did some examples change?
  - Make sure the dataset can still be built.
  - Send us a PR

Note: You can also inspect the downloaded file in `~/tensorflow_datasets/download/`.
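
When inspecting a downloaded file by hand, it can be useful to compute its size and digest yourself and compare them with the registered checksum entry. The following is a minimal sketch using only the standard library. It assumes the entry you compare against is a SHA-256 hex digest; the helper name `file_size_and_sha256` and the temporary sample file are made up for this example.

```python
import hashlib
import tempfile
from pathlib import Path


def file_size_and_sha256(path, chunk_size=1 << 20):
    """Return (size_in_bytes, hex_sha256) of the file at `path`."""
    digest = hashlib.sha256()
    size = 0
    with Path(path).open("rb") as f:
        # Read in chunks so large archives need not fit in memory.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
            size += len(chunk)
    return size, digest.hexdigest()


# Demo on a throwaway file; point `sample` at a real download instead.
with tempfile.TemporaryDirectory() as tmp:
    sample = Path(tmp) / "sample.bin"
    sample.write_bytes(b"hello")
    size, sha256 = file_size_and_sha256(sample)

print(size, sha256)
```

A mismatch between these values and the registered entry points to a truncated download or an updated upstream file.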
## Citation
If you're using tensorflow-datasets for a paper, please include the following citation, in addition to any citation specific to the datasets used (which can be found in the dataset catalog).

```bibtex
@misc{TFDS,
  title = { {TensorFlow Datasets}, A collection of ready-to-use datasets},
  howpublished = {\url{https://www.tensorflow.org/datasets} },
}
```
