In [35]:
import time
import timeit

import tensorflow as tf

In [36]:
# Display the installed TensorFlow version.
tf.__version__

'2.13.0'

In [37]:
class FileDataset(tf.data.Dataset):
    """Toy dataset that imitates reading samples from a file.

    Calling ``FileDataset(...)`` does not produce a ``FileDataset`` instance:
    ``__new__`` returns the dataset built by
    ``tf.data.Dataset.from_generator`` around ``read_files_in_batches``.
    """

    @staticmethod
    def read_files_in_batches(num_samples):
        """Yield ``num_samples`` one-element tuples ``(0,), (1,), ...``.

        The sleeps are placeholders for I/O latency: one 0.03 s pause before
        the first sample and a 0.015 s pause before every sample.

        Declared as a static method because it takes no ``self``/``cls``;
        it is only ever looked up through the class in ``__new__``.
        """
        # Simulated cost of opening the file.
        time.sleep(0.03)
        for sample_idx in range(num_samples):
            # Simulated cost of reading one record from the file.
            time.sleep(0.015)
            # A 1-tuple, matching the shape (1,) declared in output_signature.
            yield (sample_idx,)

    def __new__(cls, num_samples=3):
        """Return a generator-backed dataset that produces ``num_samples`` elements."""
        return tf.data.Dataset.from_generator(
            cls.read_files_in_batches,
            output_signature=tf.TensorSpec(shape=(1,), dtype=tf.int64),
            args=(num_samples,),
        )
       

In [38]:
def benchmark(dataset, num_epochs=2, step_seconds=0.01):
    """Iterate over ``dataset`` for ``num_epochs`` epochs, simulating training.

    Args:
        dataset: Any object that can be iterated repeatedly (for example a
            ``tf.data.Dataset``); it is iterated once per epoch.
        num_epochs: Number of full passes over ``dataset``.
        step_seconds: Seconds slept for each sample, standing in for the
            duration of one training step. Defaults to the 0.01 s that was
            previously hard-coded.

    Returns:
        None. The function exists to be timed (e.g. with ``%%timeit``).
    """
    for _ in range(num_epochs):
        for _ in dataset:
            # Placeholder for the work done in one training step.
            time.sleep(step_seconds)

In [39]:
%%timeit
benchmark(FileDataset())  # Baseline: iterate the file-backed dataset with no prefetching

292 ms ± 4.54 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [40]:
%%timeit   #Use prefetch() API to chcek how it increases the performance
benchmark(FileDataset().prefetch(1))

290 ms ± 8.15 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [43]:
%%timeit
# Same benchmark, passing tf.data.AUTOTUNE as the prefetch buffer size.
benchmark(FileDataset().prefetch(tf.data.AUTOTUNE))

287 ms ± 13.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


#### Cache 

In [45]:
# Build a small dataset holding the integers 0..4 and print each element.
dataset = tf.data.Dataset.range(5)
for d in dataset:
    print(d.numpy())

0
1
2
3
4


In [46]:
# Square every element.
# NOTE: this rebinds `dataset`, so re-running this cell squares the values again.
dataset = dataset.map(lambda x:x**2)
for d in dataset:
    print(d.numpy())

0
1
4
9
16


In [49]:
# Cache the elements of `dataset`.
# NOTE: this rebinds `dataset`, so re-running this cell wraps it in another cache().
dataset = dataset.cache()
# First complete pass over the dataset after cache() was applied.
list(dataset.as_numpy_iterator())


[0, 1, 4, 9, 16]

In [50]:
list(dataset.as_numpy_iterator())  # Second pass: elements are read from the cache

[0, 1, 4, 9, 16]

In [52]:
def mapped_function(s):
    """Run a 0.03 s sleep through ``tf.py_function`` and return ``s`` unchanged.

    The sleep stands in for an expensive per-element preprocessing step.
    """

    def _simulate_preprocessing():
        # Takes no inputs and produces no outputs; it only burns time.
        time.sleep(0.03)

    tf.py_function(func=_simulate_preprocessing, inp=[], Tout=())
    return s

In [54]:
%%timeit
# Five epochs with mapped_function in the pipeline and no caching.
benchmark(FileDataset().map(mapped_function),5)

1.21 s ± 12.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [55]:
%%timeit #Not calling the map functio for 2nd, 3rd..epochs, using the map data in the catche
benchmark(FileDataset().map(mapped_function).cache(),5)

423 ms ± 5.09 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
