In [1]:
import tensorflow as tf
import time

In [2]:
# TensorFlow version these timings were recorded with.
tf.version.VERSION

'2.7.0'

## Baseline: training with an unoptimized input pipeline

In [3]:
class FileDataset(tf.data.Dataset):
    """Simulated file-backed dataset used to benchmark tf.data pipelines.

    Calling ``FileDataset(num_samples)`` does not produce a ``FileDataset``
    instance: ``__new__`` returns a ``tf.data.Dataset`` built with
    ``from_generator`` whose generator sleeps to imitate file I/O latency.
    """

    @staticmethod
    def read_file_in_batches(num_samples):
        """Yield ``num_samples`` one-element tuples ``(sample_idx,)``.

        Sleeps 0.03 s once to simulate opening the file, then 0.015 s before
        each record to simulate reading it. Despite the name, records are
        yielded one at a time, not in batches.
        """
        # Simulate the latency of opening the file.
        time.sleep(0.03)

        for sample_idx in range(num_samples):
            # Simulate the latency of reading one record (line) from the file.
            time.sleep(0.015)

            yield (sample_idx,)

    def __new__(cls, num_samples=3):
        """Return a generator-backed ``tf.data.Dataset`` of ``num_samples`` elements.

        Each element is an int64 tensor of shape (1,).
        """
        return tf.data.Dataset.from_generator(
            cls.read_file_in_batches,
            output_signature=tf.TensorSpec(shape=(1,), dtype=tf.int64),
            # Forwarded to read_file_in_batches as its positional arguments.
            args=(num_samples,),
        )

In [8]:
def benchmark(dataset, num_epochs=10, step_time=0.01):
    """Iterate over ``dataset`` for ``num_epochs`` epochs, simulating a training step per element.

    Args:
        dataset: any iterable (e.g. a ``tf.data.Dataset``); it is re-iterated
            once per epoch.
        num_epochs: number of full passes over ``dataset``.
        step_time: seconds slept per element, standing in for the cost of one
            training step. The default matches the original hard-coded value.

    Returns:
        None. Time the call externally (e.g. with ``%%timeit``).
    """
    for _ in range(num_epochs):
        for _element in dataset:
            # Stand-in for the time spent performing a training step.
            time.sleep(step_time)


In [9]:
# Display the dataset's repr; its element spec is a (1,)-shaped int64 tensor.
FileDataset(num_samples=3)

<FlatMapDataset shapes: (1,), types: tf.int64>

In [10]:
%%timeit
benchmark(FileDataset())

1 loop, best of 5: 1.36 s per loop


## Optimizing performance with `prefetch`

In [11]:
%%timeit
benchmark(FileDataset().prefetch(1))

1 loop, best of 5: 1.17 s per loop


## Using tf.data.AUTOTUNE

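> `tf.data.AUTOTUNE` lets the tf.data runtime tune the prefetch buffer size dynamically, so upcoming elements are prepared while the current training step runs (on a GPU in real training; simulated here with `time.sleep`).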



In [12]:
%%timeit
benchmark(FileDataset().prefetch(tf.data.AUTOTUNE))

1 loop, best of 5: 1.17 s per loop


## The `cache()` transformation

In [13]:
dataset = tf.data.Dataset.range(5)

# as_numpy_iterator() yields each element already converted to NumPy.
for value in dataset.as_numpy_iterator():
    print(value)

0
1
2
3
4


In [15]:
# Rebuild from the source range so re-running this cell squares the values
# once, instead of mapping again over the previous (already squared) result.
dataset = tf.data.Dataset.range(5).map(lambda x: x**2)

for d in dataset:
    print(d.numpy())

0
1
4
9
16


In [18]:
dataset_cache = dataset.cache()

# Iterate the cached dataset, not `dataset`: the first full pass fills the
# cache, and later passes over `dataset_cache` read from it.
for d in dataset_cache:
    print(d.numpy())

0
1
4
9
16


In [20]:
def mapped_function(s):
    """Return ``s`` unchanged after simulating 0.03 s of expensive preprocessing.

    The delay runs through ``tf.py_function`` so it executes as eager Python
    each time the mapped dataset produces an element.
    """
    def _expensive_preprocessing():
        time.sleep(0.03)

    tf.py_function(_expensive_preprocessing, inp=[], Tout=())
    return s

## Without `cache()`

In [23]:
%%timeit -r1 -n1
benchmark(FileDataset().map(mapped_function), 25)

1 loop, best of 1: 5.88 s per loop


## With `cache()`

In [24]:
%%timeit -r1 -n1
benchmark(FileDataset().map(mapped_function).cache(), 25)

1 loop, best of 1: 1.2 s per loop
