# data iterator

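In TensorFlow 1.x, tf.data supports four kinds of data iterators:
1. one-shot: the simplest. It iterates through the dataset only once, needs no explicit initialization, and cannot be parameterized.
2. initializable: must be explicitly initialized before use, and can be parameterized (e.g. with a tf.placeholder fed at initialization time).
3. reinitializable: a single iterator that can be initialized from several different datasets, as long as they share the same structure (element types and shapes).
4. feedable: similar to reinitializable, but it selects among several concrete iterators through a string handle fed via a tf.placeholder, so switching between datasets does not restart iteration from the beginning.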

The four iterator types are listed in order of increasing complexity.

## one-shot

In [1]:
import tensorflow as tf

In [2]:
sess = tf.Session()

In [3]:
dataset = tf.data.Dataset.range(100)

In [4]:
iterator = dataset.make_one_shot_iterator()
new_element = iterator.get_next()

In [5]:
for i in range(100):
    value = sess.run(new_element)
    assert i == value

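Running the loop above a second time raises tf.errors.OutOfRangeError, because a one-shot iterator can traverse the dataset only once.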
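To make the "only one pass" behavior concrete, here is a plain-Python model, not TensorFlow code. A Python generator behaves like a one-shot iterator in that it has no initializer and, once exhausted, every further next() raises StopIteration, which stands in for tf.errors.OutOfRangeError in this sketch. The function name one_shot_range is hypothetical.

```python
def one_shot_range(n):
    """Hypothetical stand-in for Dataset.range(n).make_one_shot_iterator().

    A generator cannot be re-initialized: after it is exhausted,
    every further next() call raises StopIteration.
    """
    yield from range(n)


it = one_shot_range(100)

# First pass: all 100 elements, in order.
first_pass = [next(it) for _ in range(100)]

# Second pass: the iterator is already exhausted.
try:
    next(it)
    second_pass_failed = False
except StopIteration:
    # In TF 1.x, sess.run(new_element) raises tf.errors.OutOfRangeError here.
    second_pass_failed = True

print(second_pass_failed)  # True
```

In TF 1.x itself, the usual way to consume a one-shot iterator without hard-coding its length is a while True loop around sess.run(new_element) that breaks on tf.errors.OutOfRangeError.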

## initializable

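An initializable iterator lets you parameterize the dataset, for example its size, by building the dataset from a tf.placeholder that is fed when the iterator is initialized.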

In [6]:
max_value = tf.placeholder(tf.int64, shape = [])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

In [7]:
# Initialize the iterator, feeding the placeholder: the dataset becomes range(10)
sess.run(iterator.initializer, feed_dict = {max_value: 10})

In [8]:
for i in range(10):
    value = sess.run(next_element)
    print(value)

0
1
2
3
4
5
6
7
8
9
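The key property of an initializable iterator is that the parameter is supplied at initialization time, not when the iterator is built, and that re-initializing restarts it. The plain-Python model below illustrates this; the class InitializableRange and its methods are hypothetical names, and max_value plays the role of the tf.placeholder.

```python
class InitializableRange:
    """Hypothetical plain-Python model of
    Dataset.range(max_value).make_initializable_iterator().
    """

    def __init__(self):
        self._it = None  # not usable until initialize() is called

    def initialize(self, max_value):
        # Plays the role of sess.run(iterator.initializer, feed_dict={max_value: ...}).
        self._it = iter(range(max_value))

    def get_next(self):
        if self._it is None:
            # TF 1.x reports an uninitialized iterator as FailedPreconditionError.
            raise RuntimeError("iterator has not been initialized")
        return next(self._it)  # StopIteration stands in for OutOfRangeError


it = InitializableRange()
it.initialize(10)
values = [it.get_next() for _ in range(10)]
print(values)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

# Re-initializing with a different parameter restarts from the beginning.
it.initialize(3)
restarted = [it.get_next() for _ in range(3)]
print(restarted)  # [0, 1, 2]

# Using an iterator that was never initialized is an error.
fresh = InitializableRange()
try:
    fresh.get_next()
except RuntimeError as err:
    print("error:", err)  # error: iterator has not been initialized
```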


## reinitializable

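A single reinitializable iterator can iterate over different datasets, as long as they share the same structure (element types and shapes). Below, the training set adds random noise to each element, while the validation set does not.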

In [9]:
training_dataset = tf.data.Dataset.range(100).map(
    lambda x: x + tf.random_uniform([], -10, 10, tf.int64)
)

validation_dataset = tf.data.Dataset.range(50)

Define the iterator from a structure (output types and shapes) rather than from a particular dataset:

In [10]:
iterator = tf.data.Iterator.from_structure(training_dataset.output_types, training_dataset.output_shapes)

In [11]:
next_element = iterator.get_next()

In [12]:
training_init_op = iterator.make_initializer(training_dataset)

In [13]:
validation_init_op = iterator.make_initializer(validation_dataset)

In [14]:
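for _ in range(20):
    # Point the iterator at the training data (restarts from its first element)
    sess.run(training_init_op)
    for _ in range(100):
        sess.run(next_element)

    # Re-point the same iterator at the validation data
    sess.run(validation_init_op)
    for _ in range(50):
        sess.run(next_element)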
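The pattern above, one iterator built from a structure and re-pointed at datasets by separate initializer ops, can be modeled in plain Python as follows. This is a sketch, not TensorFlow code: the Dataset namedtuple (a declared element type plus data) and the ReinitializableIterator class are hypothetical, and the type check loosely stands in for the output_types/output_shapes compatibility check. Unlike the TF map, the noise here is drawn once rather than on every epoch.

```python
import random
from collections import namedtuple

# Hypothetical stand-in for a tf.data.Dataset: a declared element type plus data.
Dataset = namedtuple("Dataset", ["output_type", "elements"])


class ReinitializableIterator:
    """Hypothetical model of tf.data.Iterator.from_structure()."""

    def __init__(self, output_type):
        self.output_type = output_type  # only the structure, no data yet
        self._it = None

    def make_initializer(self, dataset):
        # Reject datasets whose structure does not match the iterator's.
        if dataset.output_type != self.output_type:
            raise TypeError("dataset structure does not match iterator structure")

        def initializer():
            # Each call restarts iteration from the dataset's first element.
            self._it = iter(dataset.elements)

        return initializer

    def get_next(self):
        return next(self._it)


# Training data with noise in [-10, 10), like tf.random_uniform([], -10, 10, tf.int64).
training = Dataset(int, [x + random.randrange(-10, 10) for x in range(100)])
validation = Dataset(int, list(range(50)))

iterator = ReinitializableIterator(training.output_type)
training_init_op = iterator.make_initializer(training)
validation_init_op = iterator.make_initializer(validation)

for _ in range(2):  # a couple of epochs
    training_init_op()
    train_seen = [iterator.get_next() for _ in range(100)]
    validation_init_op()
    val_seen = [iterator.get_next() for _ in range(50)]

# A dataset with a different structure is rejected.
try:
    iterator.make_initializer(Dataset(str, ["a", "b"]))
    mismatch_rejected = False
except TypeError:
    mismatch_rejected = True

print(val_seen == list(range(50)), mismatch_rejected)  # True True
```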

## feedable

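A feedable iterator is used together with a tf.placeholder: the concrete iterator it reads from is not fixed in advance but chosen at each tf.Session.run call by feeding that iterator's string handle, somewhat like polymorphism. Unlike a reinitializable iterator, switching between concrete iterators does not restart them from the beginning.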

In [46]:
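# .repeat() makes the training data endless, so the 200-step loop below does not
# exhaust the one-shot training iterator (which would otherwise stop after 100 elements)
training_dataset = tf.data.Dataset.range(100).map(
    lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat()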

validation_dataset = tf.data.Dataset.range(50)

handle = tf.placeholder(tf.string, shape = [])
iterator = tf.data.Iterator.from_string_handle(handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()

In [47]:
training_iterator = training_dataset.make_one_shot_iterator()

In [48]:
validation_iterator = validation_dataset.make_initializable_iterator()

In [49]:
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())

In [None]:
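# Draw 200 training elements through the training handle
for _ in range(200):
    sess.run(next_element, feed_dict = {handle: training_handle})

# Re-initialize the validation iterator, then read all 50 elements through its handle
sess.run(validation_iterator.initializer)
for _ in range(50):
    sess.run(next_element, feed_dict = {handle: validation_handle})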
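The dispatch-by-handle idea can be modeled in plain Python, as a sketch rather than TensorFlow code: concrete iterators are registered under string handles, and each get_next call picks one by handle, much as feeding handle through feed_dict does. The registry _HANDLES and the functions string_handle and get_next are hypothetical names. The sketch shows the point that distinguishes feedable from reinitializable iterators: after switching away and back, the training iterator continues where it left off. It does not model re-initializing the validation iterator under the same handle.

```python
import itertools

_HANDLES = {}  # hypothetical registry: string handle -> concrete iterator


def string_handle(iterator):
    """Register a concrete iterator and return a string handle for it
    (loosely modeling Iterator.string_handle())."""
    handle = "handle-%d" % len(_HANDLES)
    _HANDLES[handle] = iterator
    return handle


def get_next(handle):
    """Model of evaluating next_element with feed_dict={handle: ...}:
    the fed handle decides which concrete iterator advances."""
    return next(_HANDLES[handle])


# Endless training data, like Dataset.range(100).repeat().
training_handle = string_handle(itertools.cycle(range(100)))
validation_handle = string_handle(iter(range(50)))

train_values = [get_next(training_handle) for _ in range(150)]
val_values = [get_next(validation_handle) for _ in range(50)]

# Switching back: training resumes at element 150 % 100 == 50, with no restart.
next_train = get_next(training_handle)
print(next_train)  # 50
```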