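It is important to make optimal use of your hardware resources (CPU and GPU) while training a deep learning model. You can use the `tf.data.Dataset.prefetch(tf.data.AUTOTUNE)` and `tf.data.Dataset.cache()` methods for this purpose. Prefetching lets the input pipeline produce the next element while the current one is being consumed. Caching stores elements after the first pass, so later epochs skip the expensive reading and preprocessing.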

In [1]:
import tensorflow as tf
import time

In [2]:
# check the installed TensorFlow version (the outputs in this notebook were produced with 2.7.0)
tf.__version__

'2.7.0'

# Prefetch

In [3]:
class FileDataset(tf.data.Dataset):
    """Dummy dataset that mimics reading samples from files with artificial latency.

    Calling ``FileDataset(num_samples)`` does not produce a ``FileDataset``
    instance: ``__new__`` returns the plain ``tf.data.Dataset`` built by
    ``tf.data.Dataset.from_generator`` around ``read_files_in_batches``.
    """

    # Simulated latency in seconds: opening the "file" once, then reading each sample.
    OPEN_DELAY = 0.03
    READ_DELAY = 0.015

    @classmethod
    def read_files_in_batches(cls, num_samples):
        """Yield ``(sample_idx,)`` for each ``sample_idx`` in ``range(num_samples)``.

        Sleeps ``cls.OPEN_DELAY`` once up front (simulated file open) and
        ``cls.READ_DELAY`` before every sample (simulated read). Despite the
        name, samples are yielded one at a time, not in batches.
        """
        time.sleep(cls.OPEN_DELAY)  # simulated "open file"
        for sample_idx in range(num_samples):
            time.sleep(cls.READ_DELAY)  # simulated per-sample read
            yield (sample_idx,)

    def __new__(cls, num_samples=3):
        """Return a ``tf.data.Dataset`` of ``num_samples`` int64 tensors of shape (1,).

        ``num_samples`` is forwarded to the generator through ``args``.
        ``from_generator`` re-invokes the generator on every iteration of the
        dataset, so the simulated open cost is paid once per epoch.
        """
        return tf.data.Dataset.from_generator(
            cls.read_files_in_batches,
            output_signature=tf.TensorSpec(shape=(1,), dtype=tf.int64),
            args=(num_samples,),
        )
        

In [4]:
def benchmark(dataset, num_epochs=2, step_time=0.01):
    """Iterate over ``dataset`` ``num_epochs`` times, simulating one training step per element.

    Args:
        dataset: Any re-iterable object (e.g. a ``tf.data.Dataset``); it is
            iterated once per epoch.
        num_epochs: Number of full passes over ``dataset``.
        step_time: Seconds to sleep per element, standing in for the work of a
            training step. Defaults to 0.01, the value used for the timings below.

    Returns:
        None. Time the call externally, e.g. with ``%%timeit``.
    """
    for _epoch in range(num_epochs):
        for _sample in dataset:
            time.sleep(step_time)


In [8]:
%%timeit
# baseline without prefetching: each element is produced only when the loop requests it
benchmark(FileDataset())

1 loop, best of 5: 282 ms per loop


In [6]:
# NOTE(review): leftover scratch line. FileDataset() returns the tf.data.Dataset
# built in FileDataset.__new__, not a FileDataset instance.
# obj = FileDataset()

In [7]:
import time

In [9]:
%%timeit
# use prefetch to improve performance: a 1-element buffer lets the next element
# be produced while benchmark() sleeps on the current one

benchmark(FileDataset().prefetch(1))

1 loop, best of 5: 241 ms per loop


In [10]:
%%timeit

# use autotune: tf.data chooses the prefetch buffer size dynamically at runtime
benchmark(FileDataset().prefetch(tf.data.AUTOTUNE))

1 loop, best of 5: 242 ms per loop


# Cache

In [11]:
# create a new dataset
dataset = tf.data.Dataset.range(5)
for d in dataset:
  print(d.numpy())

0
1
2
3
4


In [12]:
# square components
dataset = dataset.map(lambda x: x**2)
for d in dataset:
  print(d.numpy())

0
1
4
9
16


In [14]:
# cache the squared dataset. With no filename, cache() keeps elements in memory;
# the full pass below fills the cache.
dataset = dataset.cache()

list(dataset.as_numpy_iterator())

[0, 1, 4, 9, 16]

In [15]:
# reading from cache
list(dataset.as_numpy_iterator())

[0, 1, 4, 9, 16]

In [20]:
# simulated expensive per-element preprocessing, applied via Dataset.map
def mapped_function(s):
    """Return ``s`` unchanged after 0.03 s of simulated preprocessing work.

    The sleep is wrapped in ``tf.py_function`` so it runs as Python code for
    every mapped element, even though ``Dataset.map`` traces this function.
    """
    def _simulated_work():
        time.sleep(0.03)

    tf.py_function(func=_simulated_work, inp=[], Tout=())
    return s
  


In [21]:
%%timeit -n1 -r1

# baseline for caching: all 5 epochs re-read every sample and re-run mapped_function
benchmark(FileDataset().map(mapped_function), 5)

1 loop, best of 1: 1.2 s per loop


In [22]:
%%timeit -n1 -r1

# improve performance with cache: reading and mapped_function run during the
# first epoch only; the remaining 4 epochs are served from the cache

benchmark(FileDataset().map(mapped_function).cache(), 5)

1 loop, best of 1: 405 ms per loop


Further reading: [Better performance with the tf.data API — Caching](https://www.tensorflow.org/guide/data_performance#caching)