A collection of examples showing how to use the [Dataset API](https://www.tensorflow.org/api_docs/python/tf/data/Dataset).

In [3]:
import numpy as np
import tensorflow as tf

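Create a Dataset object from a NumPy array using `from_tensor_slices()`; each element of the array becomes one element of the dataset. Call `make_one_shot_iterator()` on the dataset to create an iterator. This iterator can be consumed only once: after the last element, `get_next()` raises `tf.errors.OutOfRangeError` and the iterator cannot be reset.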

In [26]:
ds = tf.data.Dataset.from_tensor_slices(np.arange(0, 5))
iterator = ds.make_one_shot_iterator()
next_op = iterator.get_next()
with tf.Session() as sess:
    try:
        while True:
            val = sess.run(next_op)
            print(val)
    except tf.errors.OutOfRangeError:
        print('End of sequence')

0
1
2
3
4
End of sequence


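Call `make_initializable_iterator()` on the dataset to create an iterator that can be used multiple times. Run `iterator.initializer` before the first `get_next()` and again whenever the iterator should restart from the first element.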

In [37]:
ds = tf.data.Dataset.from_tensor_slices(np.arange(0, 5))
iterator = ds.make_initializable_iterator()
next_op = iterator.get_next()
with tf.Session() as sess:
    for i in range(12):
        if i % 5 == 0:
            sess.run(iterator.initializer)
        val = sess.run(next_op)
        print('%d: %d' % (i, val))


0: 0
1: 1
2: 2
3: 3
4: 4
5: 0
6: 1
7: 2
8: 3
9: 4
10: 0
11: 1
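
As a loose analogy in plain Python (not TensorFlow code), running `iterator.initializer` is comparable to creating a fresh Python iterator over the same data. The sketch below reproduces the loop from the cell above under that assumption; the class `ReinitializableIterator` and the function `run_steps` are hypothetical names used only for illustration.

```python
class ReinitializableIterator:
    """Hypothetical stand-in for an initializable iterator over a fixed sequence."""

    def __init__(self, data):
        self._data = list(data)
        self._it = None  # unusable until initialize() has been called

    def initialize(self):
        # Comparable to sess.run(iterator.initializer): restart at the first element.
        self._it = iter(self._data)

    def get_next(self):
        if self._it is None:
            raise RuntimeError("iterator has not been initialized")
        # Raises StopIteration at the end, where TensorFlow raises OutOfRangeError.
        return next(self._it)


def run_steps(num_steps, data):
    """Read num_steps values, re-initializing at the start of every pass."""
    data = list(data)
    it = ReinitializableIterator(data)
    seen = []
    for i in range(num_steps):
        if i % len(data) == 0:
            it.initialize()
        seen.append((i, it.get_next()))
    return seen


pairs = run_steps(12, range(5))
for step, value in pairs:
    print('%d: %d' % (step, value))
```

Without the re-initialization at step 5, the sixth call to `get_next()` would fail because the data is exhausted, which is the situation the `OutOfRangeError` handler covers in the one-shot example.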


A Dataset can be repeated with `repeat()`. The argument is the number of passes over the data:

In [38]:
ds = tf.data.Dataset.from_tensor_slices(np.arange(0, 5)).repeat(2)
iterator = ds.make_one_shot_iterator()
next_op = iterator.get_next()
with tf.Session() as sess:
    try:
        while True:
            val = sess.run(next_op)
            print(val)
    except tf.errors.OutOfRangeError:
        print('End of sequence')

0
1
2
3
4
0
1
2
3
4
End of sequence


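Use `batch()` to group consecutive elements into batches. When the dataset size is not a multiple of the batch size, the last batch is smaller: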

In [39]:
ds = tf.data.Dataset.from_tensor_slices(np.arange(0, 5)).batch(2)
iterator = ds.make_one_shot_iterator()
next_op = iterator.get_next()
with tf.Session() as sess:
    try:
        while True:
            val = sess.run(next_op)
            print(val)
    except tf.errors.OutOfRangeError:
        print('End of sequence')

[0 1]
[2 3]
[4]
End of sequence
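
The order in which `repeat()` and `batch()` are chained changes the result. The plain-Python sketch below models both transformations on ordinary sequences to make the difference visible; it is not how TensorFlow implements them, and the helper functions `repeat` and `batch` are written here only for illustration. The `drop_remainder` parameter mirrors the argument of the same name that `Dataset.batch()` accepts in TensorFlow versions that support it.

```python
from itertools import islice


def repeat(elements, count):
    """Yield the whole sequence `count` times, back to back."""
    elements = list(elements)
    for _ in range(count):
        yield from elements


def batch(elements, batch_size, drop_remainder=False):
    """Group consecutive elements into lists of at most `batch_size` items."""
    it = iter(elements)
    while True:
        chunk = list(islice(it, batch_size))
        if not chunk:
            return
        if len(chunk) < batch_size and drop_remainder:
            return  # discard the short final batch
        yield chunk


data = range(5)
# Repeating first lets batches span the boundary between two passes.
repeat_then_batch = list(batch(repeat(data, 2), 2))
# Batching first keeps the short final batch at the end of every pass.
batch_then_repeat = list(repeat(batch(data, 2), 2))
# Dropping the remainder keeps only full batches.
full_batches_only = list(batch(data, 2, drop_remainder=True))

print(repeat_then_batch)   # [[0, 1], [2, 3], [4, 0], [1, 2], [3, 4]]
print(batch_then_repeat)   # [[0, 1], [2, 3], [4], [0, 1], [2, 3], [4]]
print(full_batches_only)   # [[0, 1], [2, 3]]
```

Batching before repeating is the usual choice when each pass should see the same batches; repeating before batching avoids a short batch in the middle of training.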


Datasets can be zipped together. This is useful for creating pairs of training samples and labels.

In [36]:
X = tf.data.Dataset.from_tensor_slices(np.array([[0,0], [1,1], [2,2], [3,3], [4,4]]))
y = tf.data.Dataset.from_tensor_slices(np.array([0,1,1,0,0]))
ds = tf.data.Dataset.zip((X, y)) # try batching: .batch(2)
iterator = ds.make_one_shot_iterator()
next_op = iterator.get_next()
with tf.Session() as sess:
    try:
        while True:
            val = sess.run(next_op)
            print(val)
    except tf.errors.OutOfRangeError:
        print('End of sequence')

(array([0, 0]), 0)
(array([1, 1]), 1)
(array([2, 2]), 1)
(array([3, 3]), 0)
(array([4, 4]), 0)
End of sequence
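
Like Python's built-in `zip()`, `Dataset.zip()` stops as soon as the shortest input is exhausted, so a length mismatch between samples and labels silently drops data. The plain-Python sketch below shows the effect on ordinary lists and adds a length check; `zip_checked` is a hypothetical helper, not part of TensorFlow.

```python
def zip_checked(features, labels):
    """Pair features with labels, refusing inputs of different length."""
    if len(features) != len(labels):
        raise ValueError(
            'got %d features but %d labels' % (len(features), len(labels)))
    return list(zip(features, labels))


features = [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4]]
labels = [0, 1, 1, 0, 0]
short_labels = [0, 1, 1]  # deliberately shorter than features

# Plain zip() truncates to the shorter input without any warning.
truncated = list(zip(features, short_labels))
print(len(truncated))  # 3

# The checked version returns all pairs when the lengths match...
pairs = zip_checked(features, labels)
print(pairs[1])  # ([1, 1], 1)

# ...and raises when they do not.
try:
    zip_checked(features, short_labels)
    mismatch_detected = False
except ValueError:
    mismatch_detected = True
print(mismatch_detected)  # True
```

Checking the array lengths before building the two datasets is a cheap way to catch this kind of mismatch early.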


Datasets can be split into a number of partitions by calling `shard()`. Shard `i` of `n` contains every `n`-th element, starting at position `i`. This is useful for distributed training or k-fold validation.

In [56]:
ds = tf.data.Dataset.from_tensor_slices(np.arange(0, 20))
num_shards = 4
iterators = []
for shard_idx in range(num_shards):
    shard = ds.shard(num_shards, shard_idx)
    iterator = shard.make_one_shot_iterator()
    iterators.append(iterator)

with tf.Session() as sess:
    for i, iterator in enumerate(iterators):
        print('Shard %d' % i)
        next_op = iterator.get_next()
        try:
            while True:
                val = sess.run(next_op)
                print(val)
        except tf.errors.OutOfRangeError:
            pass

Shard 0
0
4
8
12
16
Shard 1
1
5
9
13
17
Shard 2
2
6
10
14
18
Shard 3
3
7
11
15
19
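
The output above follows a simple rule: an element at position `i` belongs to the shard with index `i % num_shards`. The plain-Python sketch below models that rule on a list and uses it to build k-fold train/validation splits. It is only an illustration of the idea under that assumption; the helpers `shard` and `k_fold` are hypothetical and do not use TensorFlow.

```python
def shard(elements, num_shards, index):
    """Keep the elements whose position modulo num_shards equals index."""
    return [x for i, x in enumerate(elements) if i % num_shards == index]


def k_fold(elements, k):
    """Return k (train, validation) pairs; fold j validates on shard j."""
    elements = list(elements)
    folds = []
    for val_idx in range(k):
        validation = shard(elements, k, val_idx)
        # Training data is everything outside the validation shard.
        train = [x for i, x in enumerate(elements) if i % k != val_idx]
        folds.append((train, validation))
    return folds


data = list(range(20))
shards = [shard(data, 4, i) for i in range(4)]
print(shards[1])  # [1, 5, 9, 13, 17]

folds = k_fold(data, 4)
train, validation = folds[0]
print(validation)   # [0, 4, 8, 12, 16]
print(len(train))   # 15
```

Because the shards interleave the data, each fold samples the whole range rather than one contiguous block, which matters when the input is ordered.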


An alternative way to partition a Dataset is to use `skip()` and `take()`, which produce contiguous partitions instead of interleaved ones.

In [55]:
ds = tf.data.Dataset.from_tensor_slices(np.arange(0, 20))
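num_partitions = 4
partition_size = 20 // num_partitions  # 5 elements per partition
iterators = []
for i in range(num_partitions):
    # Skip the elements of earlier partitions, then keep the next partition_size.
    partition = ds.skip(i * partition_size).take(partition_size)
    iterator = partition.make_one_shot_iterator()
    iterators.append(iterator)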

with tf.Session() as sess:
    for i, iterator in enumerate(iterators):
        print('Partition %d' % i)
        next_op = iterator.get_next()
        try:
            while True:
                val = sess.run(next_op)
                print(val)
        except tf.errors.OutOfRangeError:
            pass

Partition 0
0
1
2
3
4
Partition 1
5
6
7
8
9
Partition 2
10
11
12
13
14
Partition 3
15
16
17
18
19
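
On an ordinary sequence, `skip(a).take(b)` corresponds to the slice `[a:a + b]`. The plain-Python sketch below uses that correspondence to split a sequence whose length is not a multiple of the number of partitions. The helpers `skip_take` and `partition` are hypothetical names, and rounding the partition size up is one possible choice, not something the cell above prescribes.

```python
from itertools import islice


def skip_take(elements, skip, take):
    """Model of skip(skip).take(take): drop `skip` items, keep the next `take`."""
    return list(islice(elements, skip, skip + take))


def partition(elements, num_partitions):
    """Split into contiguous partitions; the last one may be shorter."""
    elements = list(elements)
    size = -(-len(elements) // num_partitions)  # ceiling division
    return [skip_take(elements, i * size, size) for i in range(num_partitions)]


even = partition(range(20), 4)
print(even[1])  # [5, 6, 7, 8, 9]

uneven = partition(range(10), 4)
print(uneven)   # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
```

Taking more elements than remain simply returns what is left, so the final partition absorbs the shortfall without special handling.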


Use `map()` to apply a function to every element of a Dataset.

In [58]:
ds = tf.data.Dataset.from_tensor_slices(np.arange(0, 5)).map(lambda x: x**2)
iterator = ds.make_one_shot_iterator()
next_op = iterator.get_next()
with tf.Session() as sess:
    try:
        while True:
            val = sess.run(next_op)
            print(val)
    except tf.errors.OutOfRangeError:
        pass

0
1
4
9
16
