In [49]:
# Libraries
import tensorflow as tf
import time

In [50]:
from numpy import dtype


class FileDataset(tf.data.Dataset):
    """Synthetic dataset that imitates reading samples from a slow file.

    Calling ``FileDataset(num_samples)`` does not produce a ``FileDataset``
    instance: ``__new__`` returns the dataset that
    ``tf.data.Dataset.from_generator`` builds around ``read_files``.
    """

    @staticmethod
    def read_files(num_samples):
        """Yield one-element tuples ``(0,)``, ``(1,)``, ... ``(num_samples - 1,)``.

        The ``time.sleep`` calls add artificial latency: 0.03 s once before
        the first sample, then 0.015 s before every sample.

        Args:
            num_samples: number of samples to generate.

        Yields:
            A 1-tuple holding the index of the current sample.
        """
        # open file
        time.sleep(0.03)
        for sample_idx in range(num_samples):
            # per-sample delay before the record is handed out
            time.sleep(0.015)
            yield (sample_idx,)

    def __new__(cls, num_samples=3):
        """Build the generator-backed dataset.

        Args:
            num_samples: forwarded to ``read_files`` through ``args``.

        Returns:
            The result of ``tf.data.Dataset.from_generator`` whose elements
            are declared as ``int64`` tensors of shape ``(1,)``.
        """
        return tf.data.Dataset.from_generator(
            cls.read_files,
            output_signature=tf.TensorSpec(shape=(1,), dtype=tf.int64),
            args=(num_samples,),
        )

In [51]:
def benchmark(dataset, num_epoches=2):
    """Iterate ``dataset`` in full ``num_epoches`` times, pausing per element.

    Each element costs a 0.01 s ``time.sleep``; the elements themselves are
    never inspected.

    Args:
        dataset: any iterable that can be iterated once per epoch.
        num_epoches: how many complete passes to make over ``dataset``.

    Returns:
        None.
    """
    step_seconds = 0.01
    for _epoch in range(num_epoches):
        for _element in dataset:
            # fixed pause for every element consumed
            time.sleep(step_seconds)

In [52]:
%%timeit
# Baseline: time the plain pipeline with the defaults (3 samples, 2 epochs).
benchmark(FileDataset())

354 ms ± 42.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Using Prefetch API

In [53]:
%%timeit
# Same benchmark with prefetch() appended; tf.data.AUTOTUNE leaves the
# choice of prefetch buffer size to tf.data.
benchmark(FileDataset().prefetch(tf.data.AUTOTUNE))

323 ms ± 32.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Using the Cache Function

In [54]:
# Dataset of the integers 0..4; print each element converted to a NumPy value.
dataset_1 = tf.data.Dataset.range(5)
for d in dataset_1:
    print(d.numpy())

0
1
2
3
4


In [55]:
# Square every element of dataset_1 with map() and print the results.
dataset_2 = dataset_1.map(lambda x: x**2)
for d in dataset_2:
    print(d.numpy())

0
1
4
9
16


In [56]:
# Wrap the squared dataset in cache() and iterate it as plain NumPy values.
# Per the tf.data docs, elements are cached during the first full iteration.
dataset_3 = dataset_2.cache()
for d in dataset_3.as_numpy_iterator():
    print(d)

0
1
4
9
16


In [57]:
def mapped_func(s):
    """Return ``s`` unchanged after a 0.03 s sleep run via ``tf.py_function``.

    Args:
        s: the dataset element handed over by ``Dataset.map``.

    Returns:
        ``s`` itself, unmodified.
    """
    def _pause():
        # the sleep happens in Python, wrapped by tf.py_function below
        time.sleep(0.03)

    # no inputs and no declared outputs: the call exists only for its delay
    tf.py_function(func=_pause, inp=[], Tout=[])
    return s

In [61]:
%%timeit -n1 -r1
# Without cache(): the mapped pipeline is iterated from scratch in each of the 5 epochs.
benchmark(FileDataset().map(mapped_func), 5)

1.24 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [62]:
%%timeit -n1 -r1
# Improving performance: cache() placed after map() so the mapped elements
# can be reused by later epochs instead of being recomputed.
benchmark(FileDataset().map(mapped_func).cache(), 5)

525 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
