# Optimize tensorflow pipeline performance with prefetch and caching

# Prefetch

In [1]:
import tensorflow as tf
import time

In [2]:
tf.__version__

'2.16.1'

Function for Prefetch

In [31]:
class FileDataset(tf.data.Dataset):
    def read_file_in_batches(num_samples):
        #opening the file 
        time.sleep(0.03)
        for sample_idx in range(num_samples):
            # Reading data (line, record) from the file
            time.sleep(0.025)

            yield (sample_idx,)

    def __new__(self,num_samples=3):
        return tf.data.Dataset.from_generator(
            self.read_file_in_batches,output_signature=tf.TensorSpec(shape = (1,),dtype = tf.int64),args=(num_samples,)
        )
        

In [32]:
def benchmark(dataset,num_epochs = 2):
    for epoch_num in range (num_epochs):
        # Performing a training step
            time.sleep(0.03)

In [33]:
%%timeit
benchmark(FileDataset())

86.4 ms ± 1.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [30]:

%%timeit
benchmark(FileDataset().prefetch(1))

68.7 ms ± 2.83 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [10]:

%%timeit
benchmark(FileDataset().prefetch(tf.data.AUTOTUNE)) # by using autotune

47.8 ms ± 818 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


There is some improvement in execution time

# Cache

In [13]:
dataset = tf.data.Dataset.range(5)
dataset = dataset.map(lambda x: x**2)
dataset = dataset.cache("mycache.txt") # A my cache txt file would be generated
# The first time reading through the data will generate the data using
# `range` and `map`.
list(dataset.as_numpy_iterator())

[0, 1, 4, 9, 16]

In [14]:

# Subsequent iterations read from the cache.
list(dataset.as_numpy_iterator())

[0, 1, 4, 9, 16]

In [21]:

def mapped_function(s):
    # Do some hard pre-processing
    tf.py_function(lambda: time.sleep(0.5), [], ())
    return s

In [22]:

%%timeit -r1 -n1
benchmark(FileDataset().map(mapped_function), 5)

142 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [23]:

%%timeit -r1 -n1
benchmark(FileDataset().map(mapped_function).cache(), 5)

103 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


As by reading cache there is some improvement