# 读取输入数据
## 消费 NumPy 数组
如果所有输入数据都适合存储在内存中，则根据输入数据创建 Dataset 的最简单方法是使用`Dataset.from_tensor_slices()`
将它们转换为 tf.Tensor 对象。

In [3]:
import tensorflow as tf
import numpy as np

# Interactive session shared by the cells that follow.
sess = tf.InteractiveSession()

# Random demo data: 100 feature rows (2 columns) and 100 label rows (1 column).
features = np.random.sample((100, 2))
labels = np.random.sample((100, 1))

# Row k of `features` is expected to pair with row k of `labels`.
assert features.shape[0] == labels.shape[0]

# Build the dataset directly from the in-memory NumPy arrays.
dataset = tf.data.Dataset.from_tensor_slices((features, labels))

# The one-shot iterator visits the dataset one element at a time; every
# run of `next_element` returns the next (features_row, labels_row) pair.
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

# Print the first ten elements.
for _ in range(10):
    print(sess.run(next_element))


(array([0.28547566, 0.22282825]), array([0.07755469]))
(array([0.61760959, 0.0492551 ]), array([0.91944163]))
(array([0.19500593, 0.79088898]), array([0.8959786]))
(array([0.19422631, 0.32383958]), array([0.0190849]))
(array([0.64358333, 0.9216484 ]), array([0.7334814]))
(array([0.01495688, 0.35054628]), array([0.45931417]))
(array([0.18024108, 0.72396095]), array([0.13781602]))
(array([0.06179509, 0.7726721 ]), array([0.58175556]))
(array([0.1859441 , 0.25900502]), array([0.05224747]))
(array([0.39245439, 0.1266883 ]), array([0.74371682]))




请注意，上面的代码段会将 features 和 labels 数组作为 tf.constant() 指令嵌入在 TensorFlow 图中。这样非常适合小型数据集，但会浪费内存，因为会多次复制数组的内容。作为替代方案，可以基于 tf.placeholder() 张量定义 Dataset，并在初始化 Iterator 时通过 feed_dict 传入 NumPy 数组，如下面的代码所示。

In [4]:
features, labels = (np.random.sample((100,2)), np.random.sample((100,1)))

# Define the dataset on placeholders instead of on the arrays themselves,
# so the array contents are fed at initialization time rather than being
# embedded in the graph.
features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))

iterator = dataset.make_initializable_iterator()
# Rebind `next_element` to THIS iterator. Without this line the loop below
# would keep running the `next_element` op created from an earlier iterator
# and would never read the placeholder-fed dataset.
next_element = iterator.get_next()

# An initializable iterator must be initialized before use; the NumPy
# arrays are supplied here through the placeholders.
sess.run(iterator.initializer, feed_dict={features_placeholder: features,
                                          labels_placeholder: labels})
for i in range(10):
    value = sess.run(next_element)
    print(value)

(array([0.62839225, 0.0977435 ]), array([0.17247136]))
(array([0.09415541, 0.96974174]), array([0.42514313]))
(array([0.06437329, 0.40008927]), array([0.77355629]))
(array([0.65468868, 0.9642231 ]), array([0.0815337]))
(array([0.13263284, 0.02321559]), array([0.57445881]))
(array([0.20099834, 0.00196639]), array([0.41225941]))
(array([0.85103846, 0.80587609]), array([0.42399514]))
(array([0.00628278, 0.13168694]), array([0.44083654]))
(array([0.17048121, 0.06848028]), array([0.7710174]))
(array([0.27084372, 0.96403615]), array([0.40035554]))


## 消费TFRecord数据

其实就是通过 tf.data.TFRecordDataset 这个 API 读取 TFRecord 文件，生成出 dataset 对象。

In [5]:
def _parse_function(string_record):
    """Decode one serialized ``tf.train.Example`` into an image array and a label.

    Args:
        string_record: Serialized Example bytes, handed to ``ParseFromString``.

    Returns:
        A tuple ``(img, label)``. ``img`` is the ``image_raw`` bytes viewed as
        uint8 and reshaped to ``(height, width, -1)``; ``label`` is an int.
    """
    record = tf.train.Example()
    record.ParseFromString(string_record)
    feature_map = record.features.feature

    # Every field is stored as a list; only its first entry is read.
    rows = int(feature_map['height'].int64_list.value[0])
    cols = int(feature_map['width'].int64_list.value[0])
    raw_pixels = feature_map['image_raw'].bytes_list.value[0]
    label = int(feature_map['label'].int64_list.value[0])

    # The last dimension is left as -1 so it is inferred from the byte count.
    img = np.frombuffer(raw_pixels, dtype=np.uint8).reshape((rows, cols, -1))
    return img, label

In [10]:
# Define the dataset; any transformation methods would be chained onto it here.
tfrecord_path = './keras_cifar10_tfreocrds/cifar10_train.tfrceords'
dataset = tf.data.TFRecordDataset(tfrecord_path)

# Create the iterator and the op that yields the next serialized record.
sample_iter = dataset.make_one_shot_iterator()
next_element = sample_iter.get_next()

# Fetch a single record and decode it in Python with _parse_function.
for _ in range(1):
    serialized = sess.run(next_element)
    print(_parse_function(serialized))

(array([[[ 59,  62,  63],
        [ 43,  46,  45],
        [ 50,  48,  43],
        ...,
        [158, 132, 108],
        [152, 125, 102],
        [148, 124, 103]],

       [[ 16,  20,  20],
        [  0,   0,   0],
        [ 18,   8,   0],
        ...,
        [123,  88,  55],
        [119,  83,  50],
        [122,  87,  57]],

       [[ 25,  24,  21],
        [ 16,   7,   0],
        [ 49,  27,   8],
        ...,
        [118,  84,  50],
        [120,  84,  50],
        [109,  73,  42]],

       ...,

       [[208, 170,  96],
        [201, 153,  34],
        [198, 161,  26],
        ...,
        [160, 133,  70],
        [ 56,  31,   7],
        [ 53,  34,  20]],

       [[180, 139,  96],
        [173, 123,  42],
        [186, 144,  30],
        ...,
        [184, 148,  94],
        [ 97,  62,  34],
        [ 83,  53,  34]],

       [[177, 144, 116],
        [168, 129,  94],
        [179, 142,  87],
        ...,
        [216, 184, 140],
        [151, 118,  84],
        [123,  92,  72]