## Loaders

In [None]:
#| default_exp loaders

In [None]:
#| include: false
# Setup for running this notebook interactively (this cell is not exported).
# Reload edited modules automatically before each cell runs.
%load_ext autoreload
%autoreload 2
from ipynb_path import *
from nbdev import show_doc
import warnings
# Silence FutureWarning messages in this notebook's output.
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
#| export
from __future__ import print_function, division, annotations
from jax_dataloader.imports import *
from jax_dataloader.utils import *
from jax_dataloader.datasets import *

2023-11-28 00:39:22.955810: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-28 00:39:22.955851: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-28 00:39:22.956517: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [None]:
#| hide
def _collect_epoch(dl):
    """Iterate `dl` once and return `(num_batches, X, Y)` with batches concatenated.

    Every batch is also checked for feature/label alignment: in the test data
    built by `test_dataloader`, row `i` holds the value `i` in every feature
    column and has label `i`.
    """
    X_list, Y_list = [], []
    for x, y in dl:
        assert jnp.array_equal(x[:, :1], y)
        X_list.append(x)
        Y_list.append(y)
    return len(X_list), jnp.concatenate(X_list), jnp.concatenate(Y_list)


def test_dataloader(dataloader_cls, samples=1000, batch_size=12):
    """Exercise `dataloader_cls` over two epochs in four configurations.

    Covers sequential / shuffled iteration, each with and without `drop_last`.
    The data is built so that ordering, coverage and feature/label alignment
    can all be asserted exactly.
    """
    feats = jnp.arange(samples).repeat(10).reshape(samples, 10)
    labels = jnp.arange(samples).reshape(samples, 1)
    ds = ArrayDataset(feats, labels)
    n_full = samples // batch_size        # batches when the remainder is dropped
    n_all = -(-samples // batch_size)     # batches when the remainder is kept (ceil)

    # Sequential, keep last (N % batch_size may be != 0): data comes back unchanged.
    dl = dataloader_cls(ds, batch_size=batch_size, shuffle=False)
    for _ in range(2):
        n_batches, _X, _Y = _collect_epoch(dl)
        assert n_batches == n_all
        assert jnp.array_equal(_X, feats)
        assert jnp.array_equal(_Y, labels)

    # Sequential, drop last: only the leading full batches are returned.
    dl = dataloader_cls(ds, batch_size=batch_size, shuffle=False, drop_last=True)
    for _ in range(2):
        n_batches, _X, _Y = _collect_epoch(dl)
        assert n_batches == n_full
        last_idx = n_batches * batch_size
        assert jnp.array_equal(_X, feats[: last_idx])
        assert jnp.array_equal(_Y, labels[: last_idx])

    # Shuffled, keep last: each epoch is a different permutation of all samples.
    dl_shuffle = dataloader_cls(ds, batch_size=batch_size, shuffle=True, drop_last=False)
    last_X, last_Y = jnp.array([]), jnp.array([])
    for _ in range(2):
        n_batches, _X, _Y = _collect_epoch(dl_shuffle)
        assert n_batches == n_all
        assert not jnp.array_equal(_X, feats)
        assert not jnp.array_equal(_Y, labels)
        assert jnp.sum(_X) == jnp.sum(feats), \
            f"jnp.sum(_X)={jnp.sum(_X)}, jnp.sum(feats)={jnp.sum(feats)}"
        # Equal sums alone do not prove coverage; sorted labels must be 0..N-1.
        assert jnp.array_equal(jnp.sort(_Y, axis=0), labels)
        assert not jnp.array_equal(_X, last_X)
        assert not jnp.array_equal(_Y, last_Y)
        last_X, last_Y = _X, _Y

    # Shuffled, drop last: full batches only, and no sample repeated in an epoch.
    dl_shuffle = dataloader_cls(ds, batch_size=batch_size, shuffle=True, drop_last=True)
    for _ in range(2):
        n_batches, _X, _Y = _collect_epoch(dl_shuffle)
        assert n_batches == n_full
        assert not jnp.array_equal(_X, feats)
        assert not jnp.array_equal(_Y, labels)
        assert len(_X) == n_batches * batch_size
        assert jnp.unique(_Y).size == _Y.size


In [None]:
#| hide
def _collect_keras_epoch(epoch_iterator):
    """Run one epoch of a Keras `EpochIterator` and return concatenated `(X, Y)`.

    Each batch is checked for feature/label alignment (row `i` of the test data
    holds the value `i` in every feature column and has label `i`).
    """
    X_list, Y_list = [], []
    for _step, batch in epoch_iterator.enumerate_epoch('np'):
        x, y = batch[0]
        assert jnp.array_equal(x[:, :1], y)
        X_list.append(x)
        Y_list.append(y)
    return jnp.concatenate(X_list), jnp.concatenate(Y_list)


def test_keras_dataloader(samples=1000, batch_size=12):
    """Baseline check of Keras' `EpochIterator` on the same data as `test_dataloader`.

    `EpochIterator` is constructed here without any drop-last option, so every
    epoch is expected to cover all `samples` rows. The drop-last scenarios of
    `test_dataloader` therefore have no counterpart here.
    """
    from keras.trainers.epoch_iterator import EpochIterator

    feats = jnp.arange(samples).repeat(10).reshape(samples, 10)
    labels = jnp.arange(samples).reshape(samples, 1)

    # Sequential (N % batch_size may be != 0): data comes back unchanged.
    dl = EpochIterator(feats, labels, batch_size=batch_size, shuffle=False)
    for _ in range(2):
        _X, _Y = _collect_keras_epoch(dl)
        assert jnp.array_equal(_X, feats)
        assert jnp.array_equal(_Y, labels)

    # Shuffled: each epoch should be a different permutation of all samples.
    dl_shuffle = EpochIterator(feats, labels, batch_size=batch_size, shuffle=True)
    last_X, last_Y = jnp.array([]), jnp.array([])
    for _ in range(2):
        _X, _Y = _collect_keras_epoch(dl_shuffle)
        assert not jnp.array_equal(_X, feats)
        assert not jnp.array_equal(_Y, labels)
        assert jnp.sum(_X) == jnp.sum(feats), \
            f"jnp.sum(_X)={jnp.sum(_X)}, jnp.sum(feats)={jnp.sum(feats)}"
        assert jnp.array_equal(jnp.sort(_Y, axis=0), labels)
        assert len(_X) == samples
        assert not jnp.array_equal(_X, last_X)
        assert not jnp.array_equal(_Y, last_Y)
        last_X, last_Y = _X, _Y


In [None]:
#| export
class BaseDataLoader:
    """Interface shared by all dataloader backends.

    Concrete loaders override `__len__`, `__next__` and `__iter__`; the
    versions defined here only raise `NotImplementedError`.
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int = 1,  # number of samples per batch
        shuffle: bool = False,  # if true, samples are shuffled before batching
        drop_last: bool = False,  # if true, a trailing incomplete batch is discarded
        **kwargs,
    ):
        """Accept the common constructor arguments; the base class stores none of them."""

    def __len__(self):
        raise NotImplementedError()

    def __next__(self):
        raise NotImplementedError()

    def __iter__(self):
        raise NotImplementedError()

## JAX Dataloader


In [None]:
#| export
class DataLoaderJax(BaseDataLoader):
    """Dataloader implemented in vanilla JAX.

    Batches are produced by indexing `dataset` with consecutive slices of an
    index array. When `shuffle` is true, that index array is permuted with
    `jax.random.permutation` at construction and after every completed epoch.
    Keys come from a PRNG sequence seeded with `get_config().global_seed`.
    The loader is its own iterator; `pose` counts the batches already returned
    in the current epoch.
    """

    def __init__(
        self, 
        dataset,
        batch_size: int = 1,  # batch size
        shuffle: bool = False,  # if true, dataloader shuffles before sampling each batch
        drop_last: bool = False, # drop last batches or not
        **kwargs
    ):
        # Without this check, 0 fails with ZeroDivisionError and a negative
        # size silently yields zero batches.
        if batch_size < 1:
            raise ValueError(f"`batch_size` must be a positive integer, but got {batch_size}.")
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

        # Use `hk.PRNGSequence` when `hk` is available (not None), else the local `PRNGSequence`.
        self.keys = PRNGSequence(seed=get_config().global_seed) \
            if hk is None else hk.PRNGSequence(get_config().global_seed)
        self.data_len = len(dataset)  # Length of the dataset
        self.indices = jnp.arange(self.data_len)  # available indices in the dataset
        self.pose = 0  # number of batches already returned in the current epoch
        self._shuffle()

        self.num_batches = len(self)

    def _shuffle(self):
        """Permute `self.indices` with a fresh key when shuffling is enabled."""
        if self.shuffle:
            self.indices = jax.random.permutation(next(self.keys), self.indices)

    def _stop_iteration(self):
        """End the current epoch: rewind, reshuffle for the next one, then raise."""
        self.pose = 0
        self._shuffle()
        raise StopIteration

    def __len__(self):
        """Number of batches per epoch (floor if `drop_last`, otherwise ceil)."""
        if self.drop_last:
            batches = len(self.dataset) // self.batch_size  # get the floor of division
        else:
            batches = -(len(self.dataset) // -self.batch_size)  # get the ceil of division
        return batches

    def __next__(self):
        """Return the next batch, or raise `StopIteration` at the end of an epoch."""
        if self.pose >= self.num_batches:
            self._stop_iteration()
        start = self.pose * self.batch_size
        batch_indices = self.indices[start: start + self.batch_size]
        batch_data = self.dataset[batch_indices]
        self.pose += 1
        return batch_data

    def __iter__(self):
        """Return `self`, starting a fresh epoch if the previous one was abandoned.

        A completed epoch already leaves `pose == 0`, so normal loops are unchanged.
        If a previous loop stopped early (e.g. via `break`), the new loop would
        otherwise resume mid-epoch. Rewind and reshuffle instead.
        """
        if self.pose != 0:
            self.pose = 0
            self._shuffle()
        return self


In [None]:
#| hide
# (samples, batch_size) pairs covering both N % batch_size == 0 and != 0.
for n_samples, bs in [(20, 12), (20, 10), (11, 10), (40, 12)]:
    test_dataloader(DataLoaderJax, samples=n_samples, batch_size=bs)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [None]:
#| hide
# Same (samples, batch_size) grid as the DataLoaderJax checks.
for n_samples, bs in [(20, 12), (20, 10), (11, 10), (40, 12)]:
    test_keras_dataloader(samples=n_samples, batch_size=bs)

In [None]:
%%timeit -n 5 -r 3
# Benchmark: a full `test_dataloader` run for DataLoaderJax (1280 samples, batch size 10).
test_dataloader(DataLoaderJax, samples=1280, batch_size=10)

1.48 s ± 29.8 ms per loop (mean ± std. dev. of 3 runs, 5 loops each)


In [None]:
#| hide
%%timeit -n 5 -r 3
# Benchmark for comparison: a full `test_keras_dataloader` run with the same sizes.
test_keras_dataloader(samples=1280, batch_size=10)

301 ms ± 2.4 ms per loop (mean ± std. dev. of 3 runs, 5 loops each)


## `PyTorch`-backed Dataloader

Uses `PyTorch` to load batches. Requires [PyTorch](https://pytorch.org/get-started/) to be installed.

In [None]:
#| exporti
# adapted from https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html
def _numpy_collate(batch):
    """Collate a list of samples into NumPy arrays, following the first sample's structure.

    Arrays are stacked along a new leading axis. Tuples and lists are regrouped
    field by field (returned as a list). Dicts are collated per key. Anything
    else is passed to `np.array`.
    """
    head = batch[0]
    if isinstance(head, (np.ndarray, jax.Array)):
        return np.stack(batch)
    if isinstance(head, (tuple, list)):
        # The k-th output element collates the k-th item of every sample.
        return list(map(_numpy_collate, zip(*batch)))
    if isinstance(head, dict):
        return {name: _numpy_collate([sample[name] for sample in batch]) for name in head}
    return np.array(batch)

def _convert_dataset_pytorch(dataset: Dataset):
    """Wrap `dataset` in a `torch.utils.data.Dataset` that forwards `len()` and indexing."""

    class DatasetPytorch(torch_data.Dataset):
        def __init__(self, wrapped: Dataset):
            self.dataset = wrapped

        def __len__(self):
            return len(self.dataset)

        def __getitem__(self, idx):
            return self.dataset[idx]

    wrapper = DatasetPytorch(dataset)
    return wrapper

In [None]:
#| export
class DataLoaderPytorch(BaseDataLoader):
    """Dataloader backed by `torch.utils.data.DataLoader`.

    Indices are grouped by a `BatchSampler`, and samples are collated into
    NumPy arrays by `_numpy_collate`. Shuffling uses a `torch.Generator` seeded
    with `get_config().global_seed`, matching the JAX and TensorFlow loaders.
    """
    def __init__(
        self, 
        dataset,
        batch_size: int = 1,  # Batch size
        shuffle: bool = False,  # If true, dataloader shuffles before sampling each batch
        drop_last: bool = False, # Drop last batch or not
        **kwargs  # Extra keyword arguments forwarded to `torch.utils.data.DataLoader`
    ):
        super().__init__(dataset, batch_size, shuffle, drop_last)
        check_pytorch_installed()
        import torch
        from torch.utils.data import BatchSampler, RandomSampler, SequentialSampler

        if 'sampler' in kwargs:
            warnings.warn("`sampler` is currently not supported. We will ignore it and use `shuffle` instead.")
            kwargs.pop('sampler')

        if shuffle:
            # Seeded generator: shuffling is reproducible under the global seed.
            generator = torch.Generator().manual_seed(get_config().global_seed)
            sampler = RandomSampler(dataset, generator=generator)
        else:
            sampler = SequentialSampler(dataset)
        batch_sampler = BatchSampler(sampler, batch_size=batch_size, drop_last=drop_last)

        self.dataloader = torch_data.DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            collate_fn=_numpy_collate,
            **kwargs
        )
        self._iterator = None  # iterator consumed by `__next__`, created lazily

    def __len__(self):
        return len(self.dataloader)

    def __next__(self):
        """Return the next batch. After `StopIteration`, the following call starts a new epoch.

        `torch.utils.data.DataLoader` is an iterable, not an iterator, so an
        iterator over it is created on demand.
        """
        if self._iterator is None:
            self._iterator = iter(self)
        try:
            return next(self._iterator)
        except StopIteration:
            self._iterator = None
            raise

    def __iter__(self):
        return iter(self.dataloader)

In [None]:
#| hide
#| torch
# (samples, batch_size) pairs covering both N % batch_size == 0 and != 0.
for n_samples, bs in [(20, 12), (20, 10), (11, 10)]:
    test_dataloader(DataLoaderPytorch, samples=n_samples, batch_size=bs)

## `TensorFlow`-backed Dataloader

In [None]:
#| export
def to_tf_dataset(dataset) -> tf.data.Dataset:
    """Return `dataset` as a `tf.data.Dataset`.

    A dataset accepted by `is_tf_dataset` is returned unchanged. Datasets
    accepted by `is_hf_dataset` or `is_jdl_dataset` are converted via their
    own `to_tf_dataset()` method. Any other type raises `ValueError`.
    """
    if is_tf_dataset(dataset):
        return dataset
    if is_hf_dataset(dataset) or is_jdl_dataset(dataset):
        return dataset.to_tf_dataset()
    raise ValueError(f"Dataset type {type(dataset)} is not supported.")

In [None]:
#| export
class DataLoaderTensorflow(BaseDataLoader):
    """Dataloader backed by `tf.data`; iterating it yields NumPy batches."""
    def __init__(
        self, 
        dataset,
        batch_size: int = 1,  # Batch size
        shuffle: bool = False,  # If true, dataloader shuffles before sampling each batch
        drop_last: bool = False, # Drop last batch or not
        **kwargs
    ):
        super().__init__(dataset, batch_size, shuffle, drop_last)
        check_tf_installed()
        # Convert to tf dataset
        ds = to_tf_dataset(dataset)
        if shuffle:
            # A buffer as large as the dataset gives a full shuffle (see `tf.data.Dataset.shuffle`).
            ds = ds.shuffle(buffer_size=len(dataset), seed=get_config().global_seed)
        ds = ds.batch(batch_size, drop_remainder=drop_last)
        ds = ds.prefetch(tf.data.AUTOTUNE)
        self.dataloader = ds
        self._iterator = None  # NumPy iterator consumed by `__next__`, created lazily

    def __len__(self):
        return len(self.dataloader)

    def __next__(self):
        """Return the next NumPy batch. After `StopIteration`, the following call starts a new epoch.

        `tf.data.Dataset` is iterable but not an iterator, so an iterator over it
        is created on demand.
        """
        if self._iterator is None:
            self._iterator = iter(self)
        try:
            return next(self._iterator)
        except StopIteration:
            self._iterator = None
            raise

    def __iter__(self):
        return self.dataloader.as_numpy_iterator()

In [None]:
#| hide
#| tf
# (samples, batch_size) pairs covering both N % batch_size == 0 and != 0.
for n_samples, bs in [(20, 12), (20, 10), (11, 10)]:
    test_dataloader(DataLoaderTensorflow, samples=n_samples, batch_size=bs)