In [None]:
#| default_exp tensorflow.loaders

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
#| export
from __future__ import print_function, division, annotations
from jax_dataloader.imports import *
from jax_dataloader.loaders import BaseDataLoader
from jax_dataloader.datasets import Dataset, ArrayDataset
from jax_dataloader.utils import is_tf_dataset, is_hf_dataset, is_jdl_dataset, check_tf_installed, get_config
from jax_dataloader.tests import *
from jax.tree_util import tree_map

## `TensorFlow`-backed Dataloader

In [None]:
#| export
def to_tf_dataset(dataset) -> tf.data.Dataset:
    """Turn a supported dataset object into a `tf.data.Dataset`.

    The dataset kinds are checked in this order:

    * tf dataset (`is_tf_dataset`): returned unchanged.
    * Hugging Face dataset (`is_hf_dataset`): converted through its own
      `to_tf_dataset()` method.
    * jax_dataloader dataset (`is_jdl_dataset`): the full slice `dataset[:]`
      is passed to `tf.data.Dataset.from_tensor_slices`.

    Raises:
        ValueError: if `dataset` matches none of the kinds above.
    """
    # Nothing to convert when we are already handed a tf dataset.
    if is_tf_dataset(dataset):
        return dataset

    if is_hf_dataset(dataset):
        return dataset.to_tf_dataset()

    # Anything that is not a jax_dataloader dataset at this point is unknown.
    if not is_jdl_dataset(dataset):
        raise ValueError(f"Dataset type {type(dataset)} is not supported.")

    whole_dataset = dataset[:]
    return tf.data.Dataset.from_tensor_slices(whole_dataset)

In [None]:
#| export
class DataLoaderTensorflow(BaseDataLoader):
    """Tensorflow Dataloader

    Converts the given dataset with `to_tf_dataset`, then builds a `tf.data`
    pipeline that optionally shuffles, batches and prefetches it. Iterating
    over the loader yields the batches produced by the pipeline's
    `as_numpy_iterator()`.
    """
    def __init__(
        self,
        dataset,
        batch_size: int = 1,  # Batch size
        shuffle: bool = False,  # If true, dataloader shuffles before sampling each batch
        drop_last: bool = False, # Drop last batch or not
        **kwargs  # Accepted but not used by this loader
    ):
        super().__init__(dataset, batch_size, shuffle, drop_last)
        check_tf_installed()
        # Convert to tf dataset
        ds = to_tf_dataset(dataset)
        if shuffle:
            # The shuffle buffer covers the whole dataset; the seed comes from
            # the library-wide config.
            ds = ds.shuffle(buffer_size=len(dataset), seed=get_config().global_seed)
        ds = ds.batch(batch_size, drop_remainder=drop_last)
        ds = ds.prefetch(tf.data.AUTOTUNE)
        self.dataloader = ds
        # Cursor used only by `__next__`; created lazily on first use.
        self._next_iter = None

    def __len__(self):
        """Return `len()` of the batched tf dataset."""
        return len(self.dataloader)

    def __next__(self):
        """Return the next batch from an internal cursor over the pipeline.

        Raises:
            StopIteration: when the current pass is exhausted. The cursor is
                then discarded, so the following call starts a new pass.
        """
        # A `tf.data.Dataset` is iterable but is not itself an iterator, so
        # `next()` has to be applied to an iterator created from it.
        if self._next_iter is None:
            self._next_iter = self.dataloader.as_numpy_iterator()
        try:
            return next(self._next_iter)
        except StopIteration:
            self._next_iter = None
            raise

    def __iter__(self):
        """Return a fresh numpy iterator over the batched dataset."""
        return self.dataloader.as_numpy_iterator()

In [None]:
#| tf
# (samples, batch_size) pairs: 20/12 and 11/10 leave a partial final batch,
# 20/10 divides evenly.
for n_samples, n_per_batch in [(20, 12), (20, 10), (11, 10)]:
    test_dataloader(DataLoaderTensorflow, samples=n_samples, batch_size=n_per_batch)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
