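This article compares the data-reading speed of tf.data with that of the Keras generator approach in tf.keras. It covers three points:
1. Reading speed of tf.data versus a Keras generator
2. Reading speed of a Keras generator wrapped in tf.data versus the original generator
3. Experiments using the data above with model.fit and model.fit_generator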

In [43]:
import tensorflow as tf
import numpy as np
import time
print(tf.__version__)

2.0.0


### 1. Prepare the data used in this article

In [3]:
(train_x,train_y),(test_x,test_y) = tf.keras.datasets.fashion_mnist.load_data()
print(train_x.shape,train_y.shape,test_x.shape,test_y.shape)

(60000, 28, 28) (60000,) (10000, 28, 28) (10000,)


### 2. Prepare the tf.data datasets

In [6]:
train_ds = tf.data.Dataset.from_tensor_slices((train_x,train_y))
test_ds = tf.data.Dataset.from_tensor_slices((test_x,test_y))


In [7]:
train_ds = train_ds.shuffle(buffer_size=1000).batch(256).prefetch(buffer_size=1000).repeat()
test_ds = test_ds.batch(256).prefetch(buffer_size=1000)
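
A note on `shuffle(buffer_size=1000)`: it does not shuffle all 60,000 samples at once. Each output element is drawn at random from a buffer of 1,000 pending elements, so the shuffling is only local. The pure-Python sketch below illustrates the idea. It is a conceptual model and not TensorFlow's actual implementation, and `buffered_shuffle` is a name made up for this illustration.

```python
import random


def buffered_shuffle(items, buffer_size, seed=None):
    """Yield `items` in an order produced by a fixed-size shuffle buffer."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            # Still filling the buffer: nothing is emitted yet.
            buffer.append(item)
            continue
        # Buffer is full: emit a random buffered element and
        # put the newly arrived item in its slot.
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Input exhausted: emit what is left in random order.
    rng.shuffle(buffer)
    yield from buffer


# With a small buffer, early outputs can only come from early inputs.
order = list(buffered_shuffle(range(20), buffer_size=5, seed=0))
print(order)
```

In this model the element at output position k always comes from the first k + buffer_size inputs, which is why a buffer much smaller than the dataset gives only a partial shuffle.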

In [24]:
# Sanity-check the data
for data,label in test_ds.take(1):
    pass
print(data.shape,label.shape)
np.testing.assert_array_almost_equal(data,test_x[:256,...])  # no exception raised means the arrays match
np.testing.assert_array_almost_equal(label,test_y[:256])
print(train_ds)

(256, 28, 28) (256,)
<RepeatDataset shapes: ((None, 28, 28), (None,)), types: (tf.uint8, tf.uint8)>


### 3. Reading data with a Keras generator

In [10]:
gen = tf.keras.preprocessing.image.ImageDataGenerator()  # no preprocessing or augmentation

In [37]:
new_train_x=np.expand_dims(train_x,-1)  # the Keras generator requires rank-4 input
new_test_x=np.expand_dims(test_x,-1)
train_flow=gen.flow(new_train_x,train_y,batch_size=256,shuffle=True)  # same batch size as the tf.data pipeline, with shuffling
test_flow=gen.flow(new_test_x,test_y,batch_size=256,shuffle=False)

In [27]:
# Sanity-check the data
data,label= next(test_flow)
np.testing.assert_array_almost_equal(data,new_test_x[:256,...])  # no exception raised means the arrays match
np.testing.assert_array_almost_equal(label,test_y[:256])
print(train_flow)

<keras_preprocessing.image.numpy_array_iterator.NumpyArrayIterator object at 0x0000025F0F253E48>


### 4. Wrapping the Keras generator with tf.data

In [118]:
gen = tf.keras.preprocessing.image.ImageDataGenerator()
wrap_train_ds = tf.data.Dataset.from_generator(lambda:gen.flow(new_train_x,train_y,batch_size=256,shuffle=True),
    output_types=(tf.uint8, tf.uint8),
    output_shapes = ([None,28,28,1],[None])
)
wrap_test_ds = tf.data.Dataset.from_generator(lambda:gen.flow(new_test_x,test_y,batch_size=256,shuffle=False),
    output_types=(tf.uint8, tf.uint8),
    output_shapes = (tf.TensorShape([None,28,28,1]),tf.TensorShape([None]))  # wrapping in tf.TensorShape is optional
)

In [119]:
# Sanity-check the data
for data,label in wrap_test_ds.take(1):
    pass
print(data.shape,label.shape)
np.testing.assert_array_almost_equal(data,new_test_x[:256,...])  # no exception raised means the arrays match
np.testing.assert_array_almost_equal(label,test_y[:256])
print(wrap_train_ds)

(256, 28, 28, 1) (256,)
<DatasetV1Adapter shapes: ((None, 28, 28, 1), (None,)), types: (tf.uint8, tf.uint8)>


### 5. Compare the speed of the three data sources

In [41]:
default_timeit_steps = 1000

def timeit(ds, steps=default_timeit_steps):
    start = time.time()
    it = iter(ds)
    for i in range(steps):
        batch = next(it)
        if i%10 == 0:
            print('.',end='')
    print()
    end = time.time()

    duration = end-start
    print("{} batches: {} s".format(steps, duration))
    print("{:0.5f} samples/s".format(256*steps/duration))
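
The `timeit` helper above includes the progress printing and the first, possibly slower, batches in the measured time. As a hedged alternative, the sketch below uses `time.perf_counter`, skips a few warm-up batches, and returns the numbers instead of printing them. The name `measure_throughput` and its `warmup` parameter are assumptions introduced here and are not part of the original notebook. It should accept anything that `iter()` accepts, which is assumed to include the three data sources above.

```python
import time
from itertools import islice


def measure_throughput(batches, steps=1000, batch_size=256, warmup=10):
    """Return (seconds, samples_per_second) for up to `steps` batches."""
    it = iter(batches)
    # Consume a few batches first so one-off start-up cost is not timed.
    for _ in islice(it, warmup):
        pass
    start = time.perf_counter()
    consumed = 0
    for _ in islice(it, steps):
        consumed += 1
    duration = time.perf_counter() - start
    # Guard against a zero reading on very short runs.
    rate = batch_size * consumed / duration if duration > 0 else float("inf")
    return duration, rate


def fake_batches():
    """Stand-in data source so the sketch runs without TensorFlow."""
    while True:
        yield [0] * 256


seconds, rate = measure_throughput(fake_batches(), steps=100)
print("{:.5f} s, {:.1f} samples/s".format(seconds, rate))
```

Counting the batches actually consumed keeps the rate correct when the source runs out before `steps` batches have been read.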

In [44]:
timeit(train_ds)

....................................................................................................
1000 batches: 1.3849620819091797 s
184842.60569 samples/s


In [45]:
timeit(train_flow)

....................................................................................................
1000 batches: 4.649678945541382 s
55057.56483 samples/s


In [120]:
timeit(wrap_train_ds)

....................................................................................................
1000 batches: 6.324928283691406 s
40474.76723 samples/s


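**Conclusion**
tf.data is clearly the fastest at roughly 185,000 samples/s, the raw Keras generator reaches about 55,000 samples/s, and the wrapped generator is the slowest at about 40,000 samples/s, so tf.data is the one to use. Why the wrapped generator is slower than the original Keras generator was not investigated here. One plausible explanation is that `from_generator` still runs the same Python generator and adds the cost of converting every batch into tensors. Either way, we will simply use tf.data directly.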
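
To put the gap in relative terms, the short script below recomputes throughput and slowdown from the three durations printed by the runs above. The dictionary keys are labels chosen here for readability. The durations are copied from the outputs and will differ on other machines.

```python
# Seconds taken to read 1000 batches of 256 samples, copied from the runs above.
durations = {
    "tf.data": 1.3849620819091797,
    "keras generator": 4.649678945541382,
    "wrapped generator": 6.324928283691406,
}

samples = 256 * 1000
baseline = durations["tf.data"]

# Slowdown of each source relative to the tf.data pipeline.
slowdowns = {name: secs / baseline for name, secs in durations.items()}

for name, secs in durations.items():
    print("{:<18} {:>10.0f} samples/s  {:.2f}x the tf.data time".format(
        name, samples / secs, slowdowns[name]))
```

On these measurements the raw generator takes roughly 3.4 times as long as tf.data and the wrapped generator roughly 4.6 times as long.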

In [None]:
# Tomorrow: rework the tf.data pipeline, run training, and check the results