In [19]:
import tensorflow as tf
import time

# Prefetch

In [20]:
class FileDataset(tf.data.Dataset):
    """Toy dataset that imitates reading samples from a slow file.

    Calling ``FileDataset(...)`` does not create a ``FileDataset`` instance:
    ``__new__`` returns the dataset built by
    ``tf.data.Dataset.from_generator`` instead.
    """

    @staticmethod
    def read_files_in_batches(num_samples):
        """Yield ``num_samples`` one-element tuples ``(0,)``, ``(1,)``, ...

        Sleeps 30 ms once before the first sample (stand-in for opening a
        file) and 15 ms before every yielded sample (stand-in for reading a
        record).

        Declared as a static method because it uses neither an instance nor
        the class; it is only ever handed to ``from_generator`` as a plain
        callable.
        """
        # Open file
        time.sleep(0.03)

        for sample_idx in range(num_samples):
            # Read one record
            time.sleep(0.015)
            yield (sample_idx,)

    def __new__(cls, num_samples=3):
        """Build the generator-backed dataset.

        Args:
            num_samples: Number of samples the generator yields each time it
                is run.

        Returns:
            The dataset created by ``tf.data.Dataset.from_generator``; every
            element is declared as shape ``(1,)`` with dtype ``tf.int64``.
        """
        return tf.data.Dataset.from_generator(
            cls.read_files_in_batches,
            output_signature=tf.TensorSpec(shape=(1,), dtype=tf.int64),
            args=(num_samples,),
        )

In [21]:
def benchmark(dataset, num_epochs=10, step_time=0.01):
    """Iterate over ``dataset`` for several epochs, pausing after each element.

    The pause is a stand-in for the work a consumer (e.g. a training step)
    would do with each element, so the total wall time reflects how well the
    input pipeline keeps up with that consumer.

    Args:
        dataset: Any iterable that can be iterated once per epoch.
        num_epochs: Number of full passes over ``dataset``.
        step_time: Seconds to sleep per element. The default of 0.01 matches
            the previously hard-coded delay.

    Returns:
        None. The function is meant to be timed, not to produce a value.
    """
    for _epoch in range(num_epochs):
        # The element itself is irrelevant; only the time spent per element matters.
        for _sample in dataset:
            time.sleep(step_time)

In [22]:
%%timeit
# Baseline: time the default 10 epochs over the plain generator-backed
# dataset, with no prefetching or caching applied.
benchmark(FileDataset())

1.09 s ± 7.27 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [24]:
%%timeit
# Same benchmark with prefetch(tf.data.AUTOTUNE) appended to the pipeline.
# This was expected to beat the baseline, but the measured time below is
# essentially the same as without prefetch.
# TODO: investigate why prefetching gives no speed-up here.
benchmark(FileDataset().prefetch(tf.data.AUTOTUNE))

1.08 s ± 3.83 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Cache

In [28]:
# Small dataset holding the integers 0..4, used for the cache demo below.
dataset = tf.data.Dataset.range(5)

# Print every element as a plain NumPy value, one per line.
for sample in dataset.as_numpy_iterator():
    print(sample)

0
1
2
3
4


In [29]:
# Transform the dataset: square every element.
# Note: this rebinds `dataset`, so re-running the cell stacks another map on top.
dataset = dataset.map(lambda x: x * x)

In [30]:
# Cache the elements produced by the pipeline so far (the mapped values).
dataset = dataset.cache()

# First full pass over the cached dataset; the cell displays the resulting list.
list(dataset.as_numpy_iterator())

[0, 1, 4, 9, 16]

In [32]:
# Second pass: the elements now come from the cache that the first full pass
# filled, so the map transformation does not have to run again.
list(dataset.as_numpy_iterator())

[0, 1, 4, 9, 16]

In [33]:
def map_function(s):
    """Return ``s`` unchanged after triggering a 30 ms blocking pause.

    The pause runs as Python code through ``tf.py_function`` (no inputs, no
    outputs); in this notebook it plays the role of a costly per-element
    transformation.
    """
    def _pause():
        # Blocking sleep executed on the Python side.
        time.sleep(0.03)

    tf.py_function(_pause, inp=[], Tout=[])
    return s

In [35]:
%%timeit -n1 -r1
# Single timed run (one loop, one repeat) of the benchmark with the slow
# map_function applied to every element and no caching.
benchmark(FileDataset().map(map_function))

2.05 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [36]:
%%timeit -n1 -r1
# Same pipeline with cache() appended after the map. Once the first epoch has
# filled the cache, later epochs read the already-transformed elements from it,
# skipping both the generator and map_function, hence the much shorter time.
benchmark(FileDataset().map(map_function).cache())

582 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
