## Optimize TensorFlow input pipeline performance with prefetch and caching

In [1]:
import os
import tempfile
import time

import tensorflow as tf

In [2]:
# Display the version string of the imported TensorFlow package.
tf.__version__

'2.9.2'

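# Prefetch

Prefetching overlaps data loading with model execution: while one training step is running, the input pipeline is already preparing the data for the next step.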

In [3]:
class FileDataset(tf.data.Dataset):
    """Toy dataset that mimics reading records from a slow file.

    Calling ``FileDataset(...)`` does not produce a ``FileDataset`` instance:
    ``__new__`` returns the dataset built by
    ``tf.data.Dataset.from_generator`` around ``read_file_in_batches``.
    """

    def read_file_in_batches(num_samples):
        """Yield the one-element tuples ``(0,)``, ``(1,)``, ... ``num_samples`` times.

        Sleeps 0.03 s once before the first item (stand-in for opening the
        file) and 0.015 s before every item (stand-in for reading one record).
        There is no ``self``/``cls`` parameter: the function is looked up on
        the class and handed to ``from_generator`` as a plain callable.
        """
        time.sleep(0.03)  # simulated cost of opening the file

        for record_no in range(num_samples):
            time.sleep(0.015)  # simulated cost of reading one line/record
            yield (record_no,)

    def __new__(cls, num_samples=3):
        """Return a generator-backed dataset that yields ``num_samples`` items."""
        element_spec = tf.TensorSpec(shape=(1,), dtype=tf.int64)
        return tf.data.Dataset.from_generator(
            cls.read_file_in_batches,
            args=(num_samples,),
            output_signature=element_spec,
        )

In [4]:
def benchmark(dataset, num_epochs=2, step_time=0.01):
    """Iterate over ``dataset`` for ``num_epochs`` epochs, simulating training.

    Args:
        dataset: Any re-iterable object (e.g. a ``tf.data.Dataset``). It is
            iterated from the start once per epoch.
        num_epochs: Number of full passes to make over ``dataset``.
        step_time: Seconds to sleep per sample as a stand-in for one training
            step. Defaults to 0.01, the value that used to be hard-coded.

    Returns:
        None. The function only consumes time, so it is meant to be wrapped
        in a timer such as ``%%timeit``.
    """
    for _ in range(num_epochs):
        for _ in dataset:
            # Stand-in for performing a training step on the current sample.
            time.sleep(step_time)

In [5]:
%%timeit
# Baseline: the plain pipeline, with no prefetching or caching applied.
benchmark(FileDataset())

Metal device set to: Apple M1


2023-02-17 16:54:17.296359: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2023-02-17 16:54:17.296693: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
2023-02-17 16:54:17.335442: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


313 ms ± 6.64 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [6]:
%%timeit
# Same benchmark with `prefetch(1)` appended to the pipeline.
benchmark(FileDataset().prefetch(1))

271 ms ± 8.91 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
%%timeit
# Same benchmark, passing `tf.data.AUTOTUNE` to `prefetch` instead of a
# fixed value.
benchmark(FileDataset().prefetch(tf.data.AUTOTUNE))

259 ms ± 20.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


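# Cache

`Dataset.cache` stores the elements of a dataset in memory or in a file, so the transformations placed before the cache (such as file reading and `map`) are executed only during the first epoch.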

In [8]:
# Cache to a file inside a fresh temporary directory. A file-backed cache
# persists across runs, so a fixed path in the working directory would make a
# re-run of this notebook read the old cache files instead of recomputing.
# The directory is left in the system temp location and is not removed here.
cache_path = os.path.join(tempfile.mkdtemp(prefix="tf_data_cache_"), "squares")

dataset = (
    tf.data.Dataset.range(5)
    .map(lambda x: x**2)
    .cache(cache_path)
)
# The first time reading through the data will generate the data using
# `range` and `map`.
list(dataset.as_numpy_iterator())

2023-02-17 16:55:46.008798: W tensorflow/core/kernels/data/cache_dataset_ops.cc:296] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


[0, 1, 4, 9, 16]

In [9]:
# Iterate over the same `dataset` object a second time. Subsequent
# iterations read from the cache.
list(dataset.as_numpy_iterator())

[0, 1, 4, 9, 16]

In [10]:
def mapped_function(s):
    """Return ``s`` unchanged after simulating expensive pre-processing.

    The simulated work is a 0.03 s sleep wrapped in ``tf.py_function`` with
    no inputs and no outputs.
    """
    def _simulate_preprocessing():
        # Stand-in for some hard pre-processing.
        time.sleep(0.03)

    tf.py_function(_simulate_preprocessing, [], ())
    return s

In [11]:
%%timeit -r1 -n1
# Five epochs with the slow `mapped_function` applied and no cache.
benchmark(FileDataset().map(mapped_function), 5)

1.31 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [12]:
%%timeit -r1 -n1
# Same five-epoch run with `.cache()` added after the map step.
benchmark(FileDataset().map(mapped_function).cache(), 5)

424 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
