In [21]:
import tensorflow as tf
import time

# Prefetch

In [30]:
# Define a custom dataset class using tf.data.Dataset

class FileDataset(tf.data.Dataset):
    """Toy dataset that imitates reading samples from a slow file.

    Calling ``FileDataset(...)`` does not create a ``FileDataset`` instance:
    ``__new__`` returns the dataset built by
    ``tf.data.Dataset.from_generator`` around ``read_file_in_batches``.
    """

    @staticmethod
    def read_file_in_batches(num_samples):
        """Yield ``num_samples`` one-element tuples with artificial I/O delays.

        Declared as a static method because it uses neither an instance nor
        the class; it is only ever handed to ``from_generator`` as a plain
        callable.

        Args:
            num_samples: Number of samples to produce.

        Yields:
            ``(index,)`` for each ``index`` in ``range(num_samples)``.
        """
        # Simulate opening the file (0.03 seconds delay)
        time.sleep(0.03)

        for sample_idx in range(num_samples):
            # Simulate reading one sample from the file (0.015 seconds delay)
            time.sleep(0.015)
            yield (sample_idx,)

    def __new__(cls, num_samples=3):
        """Build the generator-backed dataset.

        Args:
            num_samples: Forwarded to ``read_file_in_batches`` through the
                ``args`` parameter of ``from_generator``. Defaults to 3.

        Returns:
            The dataset returned by ``tf.data.Dataset.from_generator``, whose
            elements are declared as ``int64`` tensors of shape ``(1,)``.
        """
        return tf.data.Dataset.from_generator(
            cls.read_file_in_batches,
            output_signature=tf.TensorSpec(shape=(1,), dtype=tf.int64),
            args=(num_samples,),
        )

In [31]:
# Define a benchmark function to measure dataset performance
def benchmark(dataset, num_epochs=2):
    """Iterate over ``dataset`` repeatedly, pausing after every element.

    The pause stands in for a training step, so the total wall time reflects
    how quickly the dataset can deliver elements to a consumer. The function
    is meant to be timed from the outside (e.g. with ``%%timeit``).

    Args:
        dataset: Any iterable; it is iterated once per epoch.
        num_epochs: Number of full passes over ``dataset``. Defaults to 2.

    Returns:
        None.
    """
    step_seconds = 0.01  # simulated duration of one training step
    for _epoch in range(num_epochs):
        for _sample in dataset:
            time.sleep(step_seconds)

In [39]:
%%timeit
# Baseline: time the benchmark (default of 2 epochs) on the dataset without prefetching.
benchmark(FileDataset())

268 ms ± 2.92 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [40]:
%%timeit
# Same benchmark, with the dataset wrapped in prefetch(1) (buffer size of one element).
benchmark(FileDataset().prefetch(1))

266 ms ± 4.83 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [44]:
%%timeit
# Same benchmark, with the prefetch buffer size left to tf.data (tf.data.AUTOTUNE).
benchmark(FileDataset().prefetch(tf.data.AUTOTUNE))

263 ms ± 8.26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


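**Notice that using prefetch lowers the measured time only slightly: from 268 ms to 266 ms with `prefetch(1)` and 263 ms with `prefetch(tf.data.AUTOTUNE)`. These differences are about the same size as the reported run-to-run deviation (roughly 3–8 ms), so this run does not show a clear speed-up.**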

# Cache

In [46]:
# Build the pipeline in one chain:
#   1. generate the integers 0..4,
#   2. square each element,
#   3. cache the squared elements in files using the "mycache.txt" name.
# NOTE: per the tf.data documentation, a file-backed cache persists between
# runs; remove the cache files (or change the name) after editing the steps
# that come before `.cache(...)`.
dataset = (
    tf.data.Dataset.range(5)
    .map(lambda value: value ** 2)
    .cache("mycache.txt")
)

In [48]:
# First pass over the dataset: the elements are generated with `range` and `map`
# (unless cache files from an earlier run already exist) and printed as a list.
print(list(dataset.as_numpy_iterator()))

[0, 1, 4, 9, 16]


In [49]:
# Second pass: the elements are read back from the cache, so `range` and `map`
# are not recomputed; the printed values are the same as on the first pass.
print(list(dataset.as_numpy_iterator()))

[0, 1, 4, 9, 16]


In [50]:
# Define a mapped function to simulate heavy preprocessing
def mapped_function(s):
    """Return ``s`` unchanged after a simulated slow preprocessing step.

    Intended for use with ``Dataset.map``: the delay is wrapped in
    ``tf.py_function`` so that the Python-level sleep runs as part of the
    mapped computation.

    Args:
        s: The dataset element; passed through untouched.

    Returns:
        ``s``, exactly as received.
    """
    def _simulate_preprocessing():
        # Stand-in for heavy preprocessing: block for 0.03 seconds.
        time.sleep(0.03)

    # No inputs and no outputs: the call exists only for its delay.
    tf.py_function(func=_simulate_preprocessing, inp=[], Tout=())
    return s

In [51]:
%%timeit -r1 -n1
# Without caching: 5 epochs over the mapped dataset, timed once (-r1 -n1).
# Every epoch goes through the generator and the slow `mapped_function` again.
benchmark(FileDataset().map(mapped_function), 5)

1.25 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [54]:
%%timeit -r1 -n1
# With caching: same 5-epoch benchmark, but the mapped elements are cached
# (cache() with no filename), timed once (-r1 -n1) for comparison with the uncached run.
benchmark(FileDataset().map(mapped_function).cache(), 5)

413 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
