In [16]:
import glob
import os
import time

import tensorflow as tf

In [17]:
# Record which TensorFlow version produced the timings below.
tf.version.VERSION

'2.10.1'

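# Prefetch

`prefetch` overlaps producing the next element with consuming the current one, so the (simulated) file reads in `FileDataset` can run while `benchmark` is busy with its (simulated) training step.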

In [18]:
# Dummy dataset that simulates slow file I/O, used to make the effect of
# tf.data optimisations (prefetch, cache) visible in the timings below.
class FileDataset(tf.data.Dataset):
    """Synthetic dataset whose generator sleeps to mimic reading from disk.

    Calling ``FileDataset(num_samples)`` does not produce an instance of this
    class: ``__new__`` returns the ``tf.data.Dataset`` built by
    ``tf.data.Dataset.from_generator``. Each element is an int64 tensor of
    shape ``(1,)`` holding the sample index.
    """

    @staticmethod
    def read_files_in_batches(num_samples):
        """Yield ``(sample_idx,)`` for every ``sample_idx`` in ``range(num_samples)``.

        Sleeps 0.03 s once up front (simulated "open file") and 0.015 s
        before each sample (simulated per-record read).
        """
        time.sleep(0.03)  # simulated file-open latency
        for sample_idx in range(num_samples):
            time.sleep(0.015)  # simulated per-record read latency
            # A 1-tuple, matching the (1,) shape in output_signature.
            yield (sample_idx,)

    def __new__(cls, num_samples=3):
        # Overriding __new__ lets FileDataset(...) read like a constructor
        # while actually returning the generator-backed dataset.
        return tf.data.Dataset.from_generator(
            cls.read_files_in_batches,
            output_signature=tf.TensorSpec(shape=(1,), dtype=tf.int64),
            args=(num_samples,),
        )
        

In [19]:
def benchmark(dataset, num_epochs=2):
    """Iterate over ``dataset`` ``num_epochs`` times, sleeping 10 ms per element.

    The sleep stands in for a training step, so the wall time measured around
    this call is roughly data-loading time plus 0.01 s per element consumed.
    Returns ``None``.
    """
    simulated_step_seconds = 0.01
    for _ in range(num_epochs):
        for _element in dataset:
            time.sleep(simulated_step_seconds)

In [20]:
%%timeit
# Baseline: no prefetch, so each element is produced only when the loop asks
# for it, and the simulated reads and simulated training steps run back to back.
benchmark(FileDataset())

398 ms ± 21.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [21]:
%%timeit
# prefetch(1) keeps a one-element buffer: the next element is prepared in the
# background while the loop body sleeps, overlapping reads with "training".
benchmark(FileDataset().prefetch(1))

345 ms ± 21.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [22]:
%%timeit
# tf.data.AUTOTUNE lets tf.data choose the prefetch buffer size at runtime
# instead of fixing it by hand.
benchmark(FileDataset().prefetch(tf.data.AUTOTUNE))

338 ms ± 17.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


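# Cache API

`Dataset.cache` stores elements the first time the dataset is iterated (in memory, or in files when given a filename), so later iterations can skip the work that produced them.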

In [41]:
# A toy dataset whose elements are simply the numbers 0 through 4.
dataset = tf.data.Dataset.range(5)

for value in dataset.as_numpy_iterator():
    print(value)

0
1
2
3
4


In [48]:
# Square each number and cache the mapped elements to files on disk.

# Filename *prefix* for the cache: TensorFlow writes its own binary files
# whose names start with this string (they are not text files).
CACHE_PREFIX = "mycache.txt"

# A file cache persists across runs: if these files already exist, even the
# first pass below reads them, and any change to `range`/`map` is ignored.
# Delete leftovers so the cell behaves the same under Restart & Run All.
# "?*" requires at least one extra character, so a file named exactly
# CACHE_PREFIX is never touched.
for stale_path in glob.glob(glob.escape(CACHE_PREFIX) + "?*"):
    os.remove(stale_path)

dataset = tf.data.Dataset.range(5)
dataset = dataset.map(lambda x: x**2)
dataset = dataset.cache(CACHE_PREFIX)
# With no cache files present, this first pass generates the data using
# `range` and `map` and writes it to the cache files.
list(dataset.as_numpy_iterator())

[0, 1, 4, 9, 16]

In [49]:
# Reading data from cache.
# Bind the extra in-memory cache to a new name instead of re-assigning
# `dataset`, so re-running this cell does not stack another cache layer on
# top of the previous one each time.
in_memory_dataset = dataset.cache()

# The file cache under `dataset` was completed by the previous pass, so this
# pass reads the elements back from it rather than recomputing `range`/`map`;
# the in-memory cache is filled along the way.
list(in_memory_dataset.as_numpy_iterator())

[0, 1, 4, 9, 16]

In [50]:
# Another pass: served from the cache, so `map` is not re-run.
[*dataset.as_numpy_iterator()]

[0, 1, 4, 9, 16]

In [51]:
def mapped_function(s):
    """Identity map function that simulates 30 ms of expensive preprocessing.

    The sleep is wrapped in ``tf.py_function`` so that it runs every time the
    op executes inside the tf.data pipeline, not just once while
    ``Dataset.map`` traces this function. The element ``s`` is returned
    unchanged.
    """
    def _simulate_heavy_preprocessing():
        time.sleep(0.03)

    tf.py_function(func=_simulate_heavy_preprocessing, inp=[], Tout=())
    return s

In [54]:
%%timeit -r1 -n1
# Without caching, each of the 5 epochs re-runs the generator sleeps and the
# 30 ms `mapped_function` for every element. A single run/loop (-r1 -n1)
# presumably keeps the total benchmark time short.
# NOTE(review): this cell was executed (In [54]) after the following cell
# (In [53]); re-run the notebook top to bottom to confirm the recorded timing.
benchmark(FileDataset().map(mapped_function), 5)

1.57 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [53]:
%%timeit -r1 -n1
# `.cache()` after `map` keeps the mapped elements in memory after the first
# epoch; the remaining epochs read from that cache and skip both the
# generator and `mapped_function`.
benchmark(FileDataset().map(mapped_function).cache(), 5)

540 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
