In [1]:
import tensorflow as tf
import numpy as np

In [2]:
import functools

## tf.data.Dataset.from_generator的用法

参考：
- https://www.jianshu.com/p/d80ea5d73446

In [3]:
# Generator function that feeds the dataset.
def data_generator():
    """Yield the integers 0 through 4 one at a time.

    Each value is printed just before it is yielded, so the cell output
    shows exactly when the generator is being pulled from.
    """
    for value in np.arange(5):
        print("in func, data:", value)
        yield value

In [4]:
# Each element produced by the generator is declared as a scalar int32 tensor.
shapes = tf.TensorShape([])
types = tf.int32
dataset = tf.data.Dataset.from_generator(
    data_generator,
    output_types=types,
    output_shapes=shapes,
)

Instructions for updating:
tf.py_func is deprecated in TF V2. Instead, use
    tf.py_function, which takes a python function which manipulates tf eager
    tensors instead of numpy arrays. It's easy to convert a tf eager tensor to
    an ndarray (just call tensor.numpy()) but having access to eager tensors
    means `tf.py_function`s can use accelerators such as GPUs as well as
    being differentiable using a gradient tape.
    


In [5]:
# The API supports the following four kinds of iterator, in increasing
# order of complexity:
#     one-shot
#     initializable
#     reinitializable
#     feedable
#
# one-shot
# The one-shot iterator is the simplest kind: it only supports a single pass
# over the whole dataset and needs no explicit initialization.
# A one-shot iterator does not support parameterization.
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()

with tf.Session() as sess:
    batch_num = 0
    while True:
        try:
            one_batch = sess.run(one_element)
        except tf.errors.OutOfRangeError:
            # Raised by sess.run once the iterator has no more elements.
            print('end!')
            break
        print('Batch No. %d:' % batch_num, one_batch)
        batch_num += 1

in func, data: 0
Batch No. 0: 0
in func, data: 1
Batch No. 1: 1
in func, data: 2
Batch No. 2: 2
in func, data: 3
Batch No. 3: 3
in func, data: 4
Batch No. 4: 4
end!


In [6]:
# Two epochs: repeat the dataset twice.
dataset = dataset.repeat(2)
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()

with tf.Session() as sess:
    batch_num = 0
    while True:
        try:
            one_batch = sess.run(one_element)
        except tf.errors.OutOfRangeError:
            # Both passes over the data are finished.
            print('end!')
            break
        print('Batch No. %d:' % batch_num, one_batch)
        batch_num += 1

in func, data: 0
Batch No. 0: 0
in func, data: 1
Batch No. 1: 1
in func, data: 2
Batch No. 2: 2
in func, data: 3
Batch No. 3: 3
in func, data: 4
Batch No. 4: 4
in func, data: 0
Batch No. 5: 0
in func, data: 1
Batch No. 6: 1
in func, data: 2
Batch No. 7: 2
in func, data: 3
Batch No. 8: 3
in func, data: 4
Batch No. 9: 4
end!


In [7]:
# Shuffle the dataset using a shuffle buffer of size 8.
dataset = dataset.shuffle(8)
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()

with tf.Session() as sess:
    batch_num = 0
    while True:
        try:
            one_batch = sess.run(one_element)
        except tf.errors.OutOfRangeError:
            # The shuffled dataset has been fully consumed.
            print('end!')
            break
        print('Batch No. %d:' % batch_num, one_batch)
        batch_num += 1

in func, data: 0
in func, data: 1
in func, data: 2
in func, data: 3
in func, data: 4
in func, data: 0
in func, data: 1
in func, data: 2
Batch No. 0: 2
in func, data: 3
Batch No. 1: 0
in func, data: 4
Batch No. 2: 2
Batch No. 3: 0
Batch No. 4: 3
Batch No. 5: 4
Batch No. 6: 1
Batch No. 7: 3
Batch No. 8: 4
Batch No. 9: 1
end!


In [8]:
# Use a batch size of 3. shuffle(10) is applied after batch(3), so it
# reorders whole batches rather than individual elements.
dataset = dataset.batch(3)
dataset = dataset.shuffle(10)
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()

with tf.Session() as sess:
    batch_num = 0
    while True:
        try:
            one_batch = sess.run(one_element)
        except tf.errors.OutOfRangeError:
            # No batches left.
            print('end!')
            break
        print('Batch No. %d:' % batch_num, one_batch)
        batch_num += 1

in func, data: 0
in func, data: 1
in func, data: 2
in func, data: 3
in func, data: 4
in func, data: 0
in func, data: 1
in func, data: 2
in func, data: 3
in func, data: 4
Batch No. 0: [0]
Batch No. 1: [1 4 2]
Batch No. 2: [0 3 3]
Batch No. 3: [1 2 4]
end!


In [9]:
# Build the dataset from an explicit range (0..8), batch it in threes and
# shuffle the resulting batches.
dataset = tf.data.Dataset.range(9)
dataset = dataset.batch(3)
dataset = dataset.shuffle(10)
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()

with tf.Session() as sess:
    batch_num = 0
    while True:
        try:
            one_batch = sess.run(one_element)
        except tf.errors.OutOfRangeError:
            # No batches left.
            print('end!')
            break
        print('Batch No. %d:' % batch_num, one_batch)
        batch_num += 1

Batch No. 0: [6 7 8]
Batch No. 1: [0 1 2]
Batch No. 2: [3 4 5]
end!


### Initializable iterator

要求在使用之前显式地调用 iterator.initializer 操作进行初始化，这使得在定义数据集时可以结合 tf.placeholder 传入参数。

In [2]:
# Initializable iterator: the dataset size is a placeholder whose value is
# fed each time the iterator's initializer op is run.
max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

# The context manager closes the session even if an assertion below fails.
with tf.Session() as sess:
    # Initialize an iterator over a dataset with 5 elements.
    sess.run(iterator.initializer, feed_dict={max_value: 5})
    for i in range(5):
        value = sess.run(next_element)
        assert i == value
        print("(1)", i)

    # Initialize the same iterator over a dataset with 100 elements.
    # Only the first 7 elements are read back here.
    sess.run(iterator.initializer, feed_dict={max_value: 100})
    for i in range(7):
        value = sess.run(next_element)
        assert i == value
        print("(2)", i)

Instructions for updating:
Colocations handled automatically by placer.
(1) 0
(1) 1
(1) 2
(1) 3
(1) 4
(2) 0
(2) 1
(2) 2
(2) 3
(2) 4
(2) 5
(2) 6


### reinitializable iterator 

可以被不同的 dataset 对象初始化，
比如对于训练集进行了shuffle的操作，对于验证集则没有处理，通常这种情况会使用两个具有相同结构的dataset对象

In [2]:
# Define training and validation datasets with the same structure.
# The training elements get a random integer offset added in the map step.
training_dataset = tf.data.Dataset.range(3).map(lambda x: x + tf.random_uniform([], -10, 10, tf.int64))
validation_dataset = tf.data.Dataset.range(2)

# A reinitializable iterator is defined by its structure. We could use the
# `output_types` and `output_shapes` properties of either `training_dataset`
# or `validation_dataset` here, because they are compatible.
print("training_dataset.output_types", training_dataset.output_types)
print("training_dataset.output_shapes", training_dataset.output_shapes)
iterator = tf.data.Iterator.from_structure(training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()

# One initializer op per dataset; running one points the shared iterator at
# that dataset.
training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)

# Run 2 epochs in which the training dataset is traversed, followed by the
# validation dataset. The context manager closes the session on any exit.
with tf.Session() as sess:
    for i in range(2):
        # Initialize an iterator over the training dataset.
        sess.run(training_init_op)
        for _ in range(3):
            output = sess.run(next_element)
            print("epoch [%d] train:" % (i), output)

        # Initialize an iterator over the validation dataset.
        sess.run(validation_init_op)
        for _ in range(2):
            output = sess.run(next_element)
            print("epoch [%d] val:" % (i), output)

training_dataset.output_types <dtype: 'int64'>
training_dataset.output_shapes ()
Instructions for updating:
Colocations handled automatically by placer.
epoch [0] train: 7
epoch [0] train: 2
epoch [0] train: 3
epoch [0] val: 0
epoch [0] val: 1
epoch [1] train: -9
epoch [1] train: 10
epoch [1] train: 6
epoch [1] val: 0
epoch [1] val: 1


### feedable iterator
可以与 tf.placeholder 结合使用，通过 feed_dict 机制在每次调用 tf.Session.run 时选择使用哪个 Iterator。

它提供了与 reinitializable iterator 类似的功能，并且在切换迭代器时不需要从数据集的开头重新初始化 iterator。

通过tf.data.Iterator.from_string_handle来定义一个 feedable iterator，达到切换数据集的目的。

例子如下，和上面的例子类似。

In [7]:
# Define training and validation datasets with the same structure.
# The training dataset repeats indefinitely; the validation dataset does not.
training_dataset = tf.data.Dataset.range(3).map(lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat()
validation_dataset = tf.data.Dataset.range(2)

# A feedable iterator is defined by a handle placeholder and its structure. We
# could use the `output_types` and `output_shapes` properties of either
# `training_dataset` or `validation_dataset` here, because they have
# identical structure.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()

# You can use feedable iterators with a variety of different kinds of iterator
# (such as one-shot and initializable iterators).
training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()

# The context manager closes the session even if one of the runs below fails.
with tf.Session() as sess:
    # The `Iterator.string_handle()` method returns a tensor that can be evaluated
    # and used to feed the `handle` placeholder.
    training_handle = sess.run(training_iterator.string_handle())
    validation_handle = sess.run(validation_iterator.string_handle())

    # Pull three elements through the training iterator.
    for _ in range(3):
        output = sess.run(next_element, feed_dict={handle: training_handle})
        print("train:", output)

    # Run one pass over the validation dataset.
    sess.run(validation_iterator.initializer)
    for _ in range(2):
        output = sess.run(next_element, feed_dict={handle: validation_handle})
        print("val:", output)

train: -10
train: -7
train: -4
val: 0
val: 1


### 如果序列不等长，在形成dataset batch时可以使用Dataset.padded_batch方法


In [4]:
# Clear the default graph, then open the session that the following cells run in.
tf.reset_default_graph()
sess = tf.Session()

# tf.TensorShape([])     denotes a single number (a scalar).
# tf.TensorShape([None]) denotes a vector of unknown length.
    # padded_shapes=(tf.TensorShape([None]),)
    # Note: do not add a "," after tf.TensorShape([None]) as in the line above.
    # padded_shapes is a nested structure, so the comma is taken to mean that
    # another component follows (per the original note, one whose format is None).

In [5]:
# Iterate once over the integers 0..9.
dataset = tf.data.Dataset.range(10)
iterator = dataset.make_one_shot_iterator()
# Build the get_next op once and reuse it: calling iterator.get_next() inside
# the loop would add a new op to the graph on every iteration.
next_element = iterator.get_next()
try:
    while True:
        print(sess.run(next_element))
except tf.errors.OutOfRangeError:
    # Raised once the iterator has no more elements.
    print("end")

0
1
2
3
4
5
6
7
8
9
end


In [6]:
# Turn each integer x into a vector of length x filled with the value x,
# giving elements of different lengths.
dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))
iterator = dataset.make_one_shot_iterator()
# Build the get_next op once and reuse it: calling iterator.get_next() inside
# the loop would add a new op to the graph on every iteration.
next_element = iterator.get_next()
try:
    while True:
        print(sess.run(next_element))
except tf.errors.OutOfRangeError:
    # Raised once the iterator has no more elements.
    print("end")

[]
[1]
[2 2]
[3 3 3]
[4 4 4 4]
[5 5 5 5 5]
[6 6 6 6 6 6]
[7 7 7 7 7 7 7]
[8 8 8 8 8 8 8 8]
[9 9 9 9 9 9 9 9 9]
end


In [7]:
# Batch the variable-length vectors three at a time. padded_shapes=[None]
# leaves the padded length unspecified.
dataset = dataset.padded_batch(3, padded_shapes=[None])

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

# Fetch and print four batches.
for _ in range(4):
    print(sess.run(next_element))

[[0 0]
 [1 0]
 [2 2]]
[[3 3 3 0 0]
 [4 4 4 4 0]
 [5 5 5 5 5]]
[[6 6 6 6 6 6 0 0]
 [7 7 7 7 7 7 7 0]
 [8 8 8 8 8 8 8 8]]
[[9 9 9 9 9 9 9 9 9]]


In [8]:
x = [[1, 0, 0],
     [2, 3, 0],
     [4, 5, 6],
     [7, 8, 0],
     [9, 0, 0]]
x_new = [np.array(i) for i in x]  # kept from the original cell; not used in this cell

# One dataset element per row of x.
dataset = tf.data.Dataset.from_tensor_slices(x)
iterator = dataset.make_one_shot_iterator()
# Build the get_next op once instead of creating a new op on every iteration.
next_element = iterator.get_next()
# Reuse the session opened in the earlier cell. Rebinding `sess` to a fresh
# tf.Session() here would drop the only reference to that session without
# closing it.
try:
    while True:
        print(sess.run(next_element))
except tf.errors.OutOfRangeError:
    # Raised once the iterator has no more elements.
    print("end")

# Batch the rows two at a time, leaving the padded length unspecified.
# A `padding_values` argument was left commented out in the original call.
padded_shapes = tf.TensorShape([None])
dataset = dataset.padded_batch(2, padded_shapes=padded_shapes)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

try:
    while True:
        print(sess.run(next_element))
except tf.errors.OutOfRangeError:
    print("end")

[1 0 0]
[2 3 0]
[4 5 6]
[7 8 0]
[9 0 0]
end
[[1 0 0]
 [2 3 0]]
[[4 5 6]
 [7 8 0]]
[[9 0 0]]
end


In [9]:
# Close the session currently bound to `sess` to release its resources.
sess.close()