# Dataloader

> Support various dataloaders for loading batches.

In [None]:
#| default_exp core

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from nbdev import show_doc
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
#| export
from __future__ import annotations
import numpy as np
from typing import List, Tuple, Dict, Any, Optional, Iterable, Sequence, Iterator
import jax
from jax import vmap, grad, jit, numpy as jnp
from jax.random import PRNGKey
from abc import ABC
from dataclasses import dataclass
import collections
import warnings

In [None]:
#| export
try: 
    import torch.utils.data as torch_data
    import torch
except ModuleNotFoundError: 
    torch_data = None
    torch = None

try: import haiku as hk 
except ModuleNotFoundError: hk = None

try: import datasets as hf_datasets
except ModuleNotFoundError: hf_datasets = None

In [None]:
#| exporti
#| hide
@dataclass
class Config:
    """Global settings shared by the dataloaders in this module."""

    rng_reserve_size: int  # how many PRNG keys to split whenever the key queue runs empty
    global_seed: int  # seed used to build the shuffling key sequence

    @classmethod
    def default(cls) -> "Config":
        """Build the default configuration (reserve size 1, seed 42)."""
        defaults = {"global_seed": 42, "rng_reserve_size": 1}
        return cls(**defaults)

In [None]:
#| exporti
# Module-level configuration instance; `get_config()` returns this object.
main_config = Config.default()

In [None]:
#| exporti
def get_config() -> Config:
    """Return the module-level ``Config`` instance (the shared object, not a copy)."""
    config = main_config
    return config

In [None]:
#| export
class PRNGSequence(Iterator[PRNGKey]):
    """An Iterator of Jax PRNGKey (minimal version of haiku.PRNGSequence).

    Keys are split from an internal key in chunks (see ``reserve``) and
    handed out one at a time by ``next()``.
    """

    def __init__(self, seed: int):
        self._key = jax.random.PRNGKey(seed)
        self._subkeys = collections.deque()  # split keys not yet returned by ``next()``

    def reserve(self, num):
        """Splits additional ``num`` keys for later use."""
        if num > 0:
            # Split num + 1 keys: the first replaces the internal key, the
            # remaining ``num`` are queued for ``__next__``.
            new_keys = tuple(jax.random.split(self._key, num + 1))
            self._key = new_keys[0]
            self._subkeys.extend(new_keys[1:])

    def __next__(self):
        """Return the next unused key, splitting more keys when the queue is empty."""
        if not self._subkeys:
            # Reserve at least one key: a configured reserve size <= 0 would
            # otherwise leave the queue empty and `popleft` would raise IndexError.
            self.reserve(max(1, get_config().rng_reserve_size))
        return self._subkeys.popleft()

In [None]:
#| exporti
def _check_pytorch_installed():
    """Raise ``ModuleNotFoundError`` when ``torch.utils.data`` could not be imported."""
    if torch_data is not None:
        return
    raise ModuleNotFoundError(
        "`pytorch` library needs to be installed. "
        "Try `pip install torch`. Please refer to pytorch documentation for details: "
        "https://pytorch.org/get-started/."
    )
    
def _check_hf_installed():
    """Raise ``ModuleNotFoundError`` when huggingface `datasets` could not be imported."""
    if hf_datasets is not None:
        return
    raise ModuleNotFoundError(
        "`datasets` library needs to be installed. "
        "Try `pip install datasets`. Please refer to huggingface documentation for details: "
        "https://huggingface.co/docs/datasets/installation.html."
    )

## Dataset

In [None]:
#| export
class Dataset:
    """A pytorch-like Dataset class.

    Subclasses override ``__len__`` (number of samples) and ``__getitem__``
    (the sample(s) selected by ``index``); the base versions only raise.
    """

    def __len__(self):
        """Number of samples; must be overridden by subclasses."""
        raise NotImplementedError()

    def __getitem__(self, index):
        """Sample(s) at ``index``; must be overridden by subclasses."""
        raise NotImplementedError()

In [None]:
#| export
class ArrayDataset(Dataset):
    """Dataset wrapping numpy arrays.

    Indexing forwards ``index`` to every wrapped array and returns a tuple,
    so ``ds[i]`` is ``(arrays[0][i], arrays[1][i], ...)``; slices and index
    arrays are forwarded the same way.
    """

    def __init__(
        self, 
        *arrays: jnp.ndarray # Numpy array with same first dimension
    ):
        # Validate with explicit exceptions rather than `assert`, which is
        # stripped when Python runs with `-O`.
        if not arrays:
            raise ValueError("`ArrayDataset` requires at least one array.")
        n_samples = arrays[0].shape[0]
        if any(arr.shape[0] != n_samples for arr in arrays):
            raise ValueError("All arrays must have the same first dimension.")
        self.arrays = arrays

    def __len__(self):
        return self.arrays[0].shape[0]

    def __getitem__(self, index):
        return tuple(arr[index] for arr in self.arrays)


This is similar to [torch.utils.data.TensorDataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.TensorDataset), 
but it wraps numpy arrays.

In [None]:
# Toy dataset: 1000 samples with 10 features each, plus 1000 labels.
X = jnp.arange(10000).reshape(1000, 10)
y = jnp.arange(1000)
ds = ArrayDataset(X, y)
assert len(ds) == 1000

We index numpy arrays along the first dimension.
Dataset indexing is done via `ds[index]`.

In [None]:
x1, y1 = ds[1] # get the second sample (index 1)
assert jnp.array_equal(x1, X[1])
assert jnp.array_equal(y1, y[1])

In [None]:
#| exporti
def _has_tensor(batch) -> bool:
    """Return True if ``batch`` is, or recursively contains, a ``torch.Tensor``.

    ``batch`` may be a single value or an arbitrarily nested tuple/list/dict.
    Every element is inspected, not only the first one. Unlike the
    collate-style original, this is safe to call on a single sample.
    """
    if torch is None:
        # PyTorch is not installed, so nothing can be a `torch.Tensor`.
        return False
    if isinstance(batch, torch.Tensor):
        return True
    if isinstance(batch, (tuple, list)):
        return any(_has_tensor(item) for item in batch)
    if isinstance(batch, dict):
        return any(_has_tensor(value) for value in batch.values())
    return False

In [None]:
#| export
class TorchDataset(Dataset):
    """A Dataset class that wraps a pytorch Dataset.

    Samples are returned exactly as the wrapped dataset produces them; no
    conversion from ``torch.Tensor`` to numpy is performed.
    """
    
    def __init__(
        self, 
        dataset: torch_data.Dataset # Pytorch Dataset
    ):
        _check_pytorch_installed()
        if not isinstance(dataset, torch_data.Dataset):
            raise TypeError(f"`dataset` must be a torch Dataset, but got {type(dataset)}")
        # Give a warning if the dataset is not in numpy format. Only the first
        # sample is inspected. An empty dataset has nothing to inspect, so the
        # best-effort check is skipped instead of failing with IndexError.
        try:
            first_sample = dataset[0]
        except IndexError:
            first_sample = None
        if first_sample is not None and _has_tensor(first_sample):
            warnings.warn("The dataset contains `torch.Tensor`. "
                "Please make sure the dataset is in numpy format.")
        self._ds = dataset

    def __len__(self):
        return len(self._ds)

    def __getitem__(self, index):
        return self._ds[index]

`TorchDataset` is a wrapper around a `torch.utils.data.Dataset`. It does not modify the inner behavior of the input pytorch `dataset`.

:::{.callout-warning}

`TorchDataset` will **NOT** turn a `torch.Tensor` into `numpy.array`.
Therefore, make sure the input `dataset` is in numpy format
before passing it to `TorchDataset`.
`TorchDataset` will give a warning if a `torch.Tensor` is found in the first sample of the dataset.

:::


Let's load the [MNIST](http://yann.lecun.com/exdb/mnist/) dataset 
using the Pytorch Dataset.

In [None]:
from torch.utils.data import TensorDataset
from torchvision.datasets import MNIST

We flatten the PIL image and cast it into a float `numpy.array`
(adapted from the [official jax tutorial](https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html)).

In [None]:
class FlattenAndCast:
    """Transform that converts an image (e.g. a PIL image) into a flat float numpy array."""

    def __call__(self, pic):
        as_float = np.array(pic, dtype=float)
        return as_float.ravel()

We load the pytorch [MNIST](https://pytorch.org/vision/stable/generated/torchvision.datasets.MNIST.html#torchvision.datasets.MNIST) dataset.

In [None]:
# Load MNIST from /tmp/mnist/ (`download=True` lets torchvision fetch it); each image is flattened by FlattenAndCast.
mnist_torch = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())

Finally, we can wrap `mnist_torch` as follows.

In [None]:
# Samples stay numpy arrays because the MNIST transform (FlattenAndCast) produces them.
mnist_ds = TorchDataset(mnist_torch)
assert isinstance(mnist_ds[0][0], np.ndarray)

In [None]:
#| export
class HFDataset(Dataset):
    """A Dataset class that wraps a huggingface Dataset.

    The wrapped object is stored as ``dataset.with_format("jax")``, so
    indexing returns JAX arrays.
    """

    def __init__(
        self, 
        dataset: hf_datasets.Dataset # Huggingface Dataset
    ):
        _check_hf_installed()
        # No isinstance check is performed here: any object providing
        # `with_format` is accepted.
        jax_formatted = dataset.with_format("jax")
        self._ds = jax_formatted

    def __len__(self):
        """Length of the wrapped dataset."""
        return len(self._ds)

    def __getitem__(self, index):
        """Row(s) at ``index`` from the jax-formatted dataset."""
        row = self._ds[index]
        return row

`HFDataset` wraps a huggingface dataset. Unlike `TorchDataset`,
`HFDataset` converts the input dataset to the `jax` format, so samples are returned as JAX arrays.

Again, we load the [MNIST](http://yann.lecun.com/exdb/mnist/) dataset, 
but this time via huggingface `datasets`.

In [None]:
from datasets import load_dataset

In [None]:
#|output: false
# Load the "mnist" train split through huggingface `datasets`.
mnist_hf = load_dataset("mnist", split="train")



We wrap the `mnist_hf` as follows:

In [None]:
# HFDataset applies the "jax" format, so the image field comes back as a jnp.ndarray.
mnist_ds = HFDataset(mnist_hf)
assert isinstance(mnist_ds[0]['image'], jnp.ndarray)

## Dataloader

In [None]:
#| export
class BaseDataLoader:
    """Dataloader Interface

    Concrete loaders implement ``__len__``, ``__next__`` and ``__iter__``.
    The constructor accepts the common options but stores nothing itself.
    """
    def __init__(
        self, 
        dataset, 
        batch_size: int = 1,  # batch size
        shuffle: bool = False,  # if true, dataloader shuffles before sampling each batch
        drop_last: bool = False,
        **kwargs
    ):
        # Intentionally a no-op; subclasses keep whatever state they need.
        return None

    def __len__(self):
        """Number of batches per epoch; must be overridden."""
        raise NotImplementedError()

    def __next__(self):
        """Next batch; must be overridden."""
        raise NotImplementedError()

    def __iter__(self):
        """Iterator over batches; must be overridden."""
        raise NotImplementedError()

## Jax Dataloader

In [None]:
#| export
class DataLoaderJax(BaseDataLoader):
    """Dataloader in vanilla Jax.

    Batches are produced by slicing an array of sample indices (permuted once
    per epoch when ``shuffle=True``) and indexing ``dataset`` with each slice.
    Each call to ``iter()`` on a partially consumed loader starts a new epoch,
    matching ``DataLoaderPytorch``.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """

    def __init__(
        self, 
        dataset: Dataset,
        batch_size: int = 1,  # batch size
        shuffle: bool = False,  # if true, dataloader shuffles before sampling each batch
        drop_last: bool = False, # drop last batches or not
        **kwargs
    ):
        if batch_size < 1:
            # A non-positive batch size never advances `pose` (endless
            # iteration) and makes `__len__` divide by zero.
            raise ValueError(f"`batch_size` must be a positive integer, but got {batch_size}.")
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

        # Prefer haiku's PRNGSequence when installed, else the local minimal one.
        self.keys = PRNGSequence(seed=get_config().global_seed) \
            if hk is None else hk.PRNGSequence(get_config().global_seed)
        self.data_len = len(dataset)  # Length of the dataset
        self.indices = jnp.arange(self.data_len) # available indices in the dataset
        self.pose = 0  # record the current position in the dataset
        self._shuffle()

    def _shuffle(self):
        """Permute ``self.indices`` with a fresh key when shuffling is enabled."""
        if self.shuffle:
            self.indices = jax.random.permutation(next(self.keys), self.indices)

    def _reset(self):
        """Rewind to the first batch and reshuffle for the next epoch."""
        self.pose = 0
        self._shuffle()

    def _stop_iteration(self):
        """Finish the epoch: reset the position/shuffle, then raise ``StopIteration``."""
        self._reset()
        raise StopIteration

    def __len__(self):
        if self.drop_last:
            return self.data_len // self.batch_size  # floor of the division
        return -(self.data_len // -self.batch_size)  # ceil of the division

    def __next__(self):
        remaining = self.data_len - self.pose
        # Stop when the data is exhausted, or when only a short final batch
        # is left and `drop_last` asks to discard it.
        if remaining <= 0 or (remaining < self.batch_size and self.drop_last):
            self._stop_iteration()
        batch_indices = self.indices[self.pose: self.pose + self.batch_size]
        self.pose += self.batch_size
        # Index with the index array alone (not the tuple `(indices, ...)`) so
        # datasets whose `__getitem__` does not accept a tuple also work.
        return self.dataset[batch_indices]

    def __iter__(self):
        # If a previous loop stopped part-way, start a fresh epoch so every
        # `for` loop sees the whole dataset.
        if self.pose != 0:
            self._reset()
        return self

In [None]:
#| hide
def test_dataloader(dataloader_cls, samples=1000, batch_size=12):
    """Check ``dataloader_cls`` for two epochs under each shuffle/drop_last setting.

    Row ``i`` of the features is filled with ``i`` and its label is ``i``,
    so the feature/label pairing of every batch can be verified after shuffling.
    """
    feats = jnp.arange(samples).repeat(10).reshape(samples, 10)
    labels = jnp.arange(samples).reshape(samples, 1)
    ds = ArrayDataset(feats, labels)

    def run_epoch(loader, check_pairing):
        # Consume one epoch; return concatenated features, labels and the batch count.
        xs, ys = [], []
        for x, y in loader:
            if check_pairing:
                assert jnp.array_equal(x[:, :1], y)
            xs.append(x)
            ys.append(y)
        return jnp.concatenate(xs), jnp.concatenate(ys), len(xs)

    # N % batchsize != 0: every sample comes back in order.
    dl = dataloader_cls(ds, batch_size=batch_size, shuffle=False)
    for _epoch in range(2):
        _X, _Y, n_batches = run_epoch(dl, check_pairing=False)
        assert jnp.array_equal(_X, feats)
        assert jnp.array_equal(_Y, labels)

    # drop_last: only full batches, still in order.
    dl = dataloader_cls(ds, batch_size=batch_size, shuffle=False, drop_last=True)
    for _epoch in range(2):
        _X, _Y, n_batches = run_epoch(dl, check_pairing=False)
        last_idx = n_batches * batch_size
        assert jnp.array_equal(_X, feats[: last_idx])
        assert jnp.array_equal(_Y, labels[: last_idx])

    # Shuffled, keeping the last batch: order changes, content does not.
    dl_shuffle = dataloader_cls(ds, batch_size=batch_size, shuffle=True, drop_last=False)
    for _epoch in range(2):
        _X, _Y, n_batches = run_epoch(dl_shuffle, check_pairing=True)
        assert not jnp.array_equal(_X, feats)
        assert not jnp.array_equal(_Y, labels)
        assert jnp.sum(_X) == jnp.sum(feats), \
            f"jnp.sum(_X)={jnp.sum(_X)}, jnp.sum(feats)={jnp.sum(feats)}"

    # Shuffled with drop_last: only full batches are returned.
    dl_shuffle = dataloader_cls(ds, batch_size=batch_size, shuffle=True, drop_last=True)
    for _epoch in range(2):
        _X, _Y, n_batches = run_epoch(dl_shuffle, check_pairing=True)
        assert not jnp.array_equal(_X, feats)
        assert not jnp.array_equal(_Y, labels)
        assert len(_X) == n_batches * batch_size


In [None]:
#| hide
# Cases where N % batch_size != 0 (20/12, 11/10) and N % batch_size == 0 (20/10).
test_dataloader(DataLoaderJax, samples=20, batch_size=12)
test_dataloader(DataLoaderJax, samples=20, batch_size=10)
test_dataloader(DataLoaderJax, samples=11, batch_size=10)

## Pytorch Dataloader

Use `Pytorch` to load batches. It requires [pytorch](https://pytorch.org/get-started/) to be installed.

In [None]:
#| exporti
# copy from https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html#data-loading-with-pytorch
def _numpy_collate(batch):
    """Collate a list of samples into numpy arrays (``collate_fn`` for torch's DataLoader).

    - numpy arrays are stacked along a new leading axis;
    - tuple/list samples are collated field-wise and returned as a list;
    - mapping samples (e.g. rows from ``HFDataset``) are collated key-wise into a dict;
    - anything else is converted with ``np.array``.
    """
    # Local import: the file imports `collections`, which does not guarantee `collections.abc`.
    from collections.abc import Mapping

    first = batch[0]
    if isinstance(first, np.ndarray):
        return np.stack(batch)
    elif isinstance(first, (tuple, list)):
        transposed = zip(*batch)
        return [_numpy_collate(samples) for samples in transposed]
    elif isinstance(first, Mapping):
        return {key: _numpy_collate([sample[key] for sample in batch]) for key in first}
    else:
        return np.array(batch)

def _convert_dataset_pytorch(dataset: Dataset):
    """Adapt ``dataset`` to ``torch.utils.data.Dataset`` by delegating ``len()`` and indexing."""
    # Defined inside the function so the module imports even when torch is missing.
    class DatasetPytorch(torch_data.Dataset):
        def __init__(self, dataset: Dataset):
            self.dataset = dataset

        def __len__(self):
            return len(self.dataset)

        def __getitem__(self, index):
            return self.dataset[index]

    adapted = DatasetPytorch(dataset)
    return adapted

In [None]:
#| export
class DataLoaderPytorch(BaseDataLoader):
    """Pytorch Dataloader

    Adapts ``dataset`` to a torch Dataset and delegates batching to
    ``torch.utils.data.DataLoader``, collating batches into numpy arrays with
    ``_numpy_collate``. Extra ``**kwargs`` are forwarded to the torch loader.
    """
    def __init__(
        self, 
        dataset: Dataset,
        batch_size: int = 1,  # Batch size
        shuffle: bool = False,  # If true, dataloader shuffles before sampling each batch
        drop_last: bool = False, # Drop last batch or not
        **kwargs
    ):
        super().__init__(dataset, batch_size, shuffle, drop_last)
        _check_pytorch_installed()
        
        dataset = _convert_dataset_pytorch(dataset)
        self.dataloader = torch_data.DataLoader(
            dataset, 
            batch_size=batch_size, 
            shuffle=shuffle, 
            drop_last=drop_last,
            collate_fn=_numpy_collate,
            **kwargs
        )
        # Epoch iterator used by `__next__`; created lazily on first use.
        self._iterator = None

    def __len__(self):
        return len(self.dataloader)

    def __next__(self):
        # `torch.utils.data.DataLoader` is an iterable, not an iterator, so
        # `next(self.dataloader)` raises TypeError; keep our own epoch iterator.
        if self._iterator is None:
            self._iterator = iter(self.dataloader)
        try:
            return next(self._iterator)
        except StopIteration:
            self._iterator = None  # the next call starts a new epoch
            raise

    def __iter__(self):
        return self.dataloader.__iter__()

In [None]:
#| hide
# Same cases as for the jax backend: 20/12 and 11/10 leave a partial batch, 20/10 does not.
test_dataloader(DataLoaderPytorch, samples=20, batch_size=12)
test_dataloader(DataLoaderPytorch, samples=20, batch_size=10)
test_dataloader(DataLoaderPytorch, samples=11, batch_size=10)

## Main Dataloader Class

In [None]:
#| export
def _is_hf_dataset(dataset) -> bool:
    """Return True if ``dataset`` is a huggingface ``Dataset`` or ``DatasetDict``.

    Always returns False (a real bool, not None) when `datasets` is not installed.
    """
    if hf_datasets is None:
        return False
    return isinstance(dataset, (hf_datasets.Dataset, hf_datasets.DatasetDict))

def _dispatch_dataset(
    dataset, # Dataset or Pytorch Dataset or HuggingFace Dataset
):
    """Return ``dataset`` as a ``Dataset`` instance.

    - ``Dataset`` instances are returned unchanged;
    - ``torch.utils.data.Dataset`` instances are wrapped in ``TorchDataset``;
    - huggingface datasets are wrapped in ``HFDataset``.

    Raises:
        ValueError: if ``dataset`` is none of the supported types.
    """
    if isinstance(dataset, Dataset):
        return dataset
    elif torch_data is not None and isinstance(dataset, torch_data.Dataset):
        return TorchDataset(dataset)
    elif _is_hf_dataset(dataset):
        return HFDataset(dataset)
    else:
        # The message names the real module path (`jax_dataloader.core`).
        raise ValueError(f"dataset must be one of `jax_dataloader.core.Dataset`, "
                         "`torch.utils.data.Dataset`, `datasets.Dataset`, "
                         f"but got {type(dataset)}")

In [None]:
#| export
@dataclass(frozen=True)
class DataloaderBackends:
    """Registry mapping backend names to dataloader classes.

    A field set to ``None`` marks a backend that is not implemented yet.
    """
    jax: BaseDataLoader = DataLoaderJax
    pytorch: BaseDataLoader = DataLoaderPytorch
    tensorflow: BaseDataLoader = None
    merlin: BaseDataLoader = None

    # Class-level mapping of the *default* backends, kept for backward
    # compatibility. Lookups below read the instance fields instead, so a
    # backend overridden at construction, e.g. `DataloaderBackends(jax=...)`,
    # is honored.
    __all__ = dict(
        jax=jax, pytorch=pytorch, tensorflow=tensorflow, merlin=merlin
    )

    def _as_dict(self) -> Dict[str, Any]:
        """Map each backend field name to this instance's dataloader class (or None)."""
        from dataclasses import fields  # local: the file only imports `dataclass`
        return {f.name: getattr(self, f.name) for f in fields(self)}

    def __getitem__(self, key):
        return self._as_dict()[key]

    @property
    def supported(self) -> List[str]:
        return [
            backend for backend, dl_cls in self._as_dict().items() if dl_cls is not None
        ]

In [None]:
#| exporti
def _dispatch_dataloader(
    backend: str # dataloader backend
) -> BaseDataLoader:
    """Return Dataloader class based on given `backend`

    Raises ``ValueError`` if `backend` is unknown or has no implementation yet.
    """
    registry = DataloaderBackends()
    available = registry.supported
    if backend not in available:
        raise ValueError(f"backend=`{backend}` is either an invalid backend or not supported yet. "
            f"Should be one of {available}.")
    return registry[backend]


In [None]:
# Render nbdev API documentation for `_dispatch_dataloader`.
show_doc(_dispatch_dataloader)

---

[source](https://github.com/birkhoffg/jax-dataloader/blob/master/jax_dataloader/core.py#L330){target="_blank" style="float:right; font-size:smaller"}

### _dispatch_dataloader

>      _dispatch_dataloader (backend:str)

Return Dataloader class based on given `backend`

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| backend | str | dataloader backend |
| **Returns** | **BaseDataLoader** |  |

In [None]:
#| export
class DataLoader:
    """Main Dataloader class to load Numpy data batches

    ``dataset`` is first normalized by ``_dispatch_dataset``. It is then
    handed to the dataloader class registered for ``backend``. ``len()``,
    ``next()`` and ``iter()`` are delegated to that backend instance.
    """
    def __init__(
        self,
        dataset: Dataset,
        backend: str, # Dataloader backend
        batch_size: int = 1,  # batch size
        shuffle: bool = False,  # if true, dataloader shuffles before sampling each batch
        drop_last: bool = False, # drop last batches or not
        **kwargs
    ):
        wrapped = _dispatch_dataset(dataset)
        loader_cls = _dispatch_dataloader(backend)
        loader_options = dict(
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            **kwargs,
        )
        self.dataloader = loader_cls(dataset=wrapped, **loader_options)

    def __len__(self):
        """Number of batches reported by the backend loader."""
        return len(self.dataloader)

    def __next__(self):
        """Next batch from the backend loader."""
        return next(self.dataloader)

    def __iter__(self):
        """Iterator provided by the backend loader."""
        return self.dataloader.__iter__()

#### A Minimum Example of using Dataloader

We showcase how to use `DataLoader` for training a simple regression model.


In [None]:
from sklearn.datasets import make_regression
import optax
import haiku as hk

In [None]:
# Synthetic regression data: 10000 samples with 20 features; targets reshaped into a column.
X, y = make_regression(n_samples=10000, n_features=20)
dataset = ArrayDataset(X, y.reshape(-1, 1))
keys = hk.PRNGSequence(0)  # keys for the weight initialization in `train`

Define `loss`, `step`, `train`:

In [None]:
def loss(w, x, y):
    """Mean ``optax.l2_loss`` between the predictions ``x @ w.T`` and the targets ``y``."""
    preds = x @ w.T
    per_sample = vmap(optax.l2_loss)(preds, y)
    return jnp.mean(per_sample)

def step(w, x, y):
    """Apply one gradient-descent update (learning rate 0.1) to ``w`` and return it."""
    learning_rate = 0.1
    # Named `grads` so the module-level `grad` import is not shadowed.
    grads = jax.grad(loss)(w, x, y)
    w -= learning_rate * grads
    return w

def train(dataloader: DataLoader, key: jax.random.PRNGKey):
    """Fit a (1, 20) weight matrix with `step` over every batch for 10 epochs."""
    n_epochs = 10
    w = jax.random.normal(key, shape=(1, 20))
    for _epoch in range(n_epochs):
        for x, y in dataloader:
            w = step(w, x, y)
    return w

def eval(dataloader: DataLoader, w):
    """Return the per-sample mean loss of ``w`` over every batch of ``dataloader``.

    Each batch's mean loss is weighted by that batch's size, so a smaller
    final batch (``drop_last=False``) does not skew the average.
    """
    # NOTE: this shadows the builtin `eval`; the name is kept because later cells call it.
    losses, sizes = [], []
    for x, y in dataloader:
        losses.append(loss(w, x, y))
        sizes.append(len(x))
    if not sizes:
        return np.nan  # no batches: same result as np.mean([]), without the warning
    return np.average(losses, weights=sizes)
    

Train this linear regression model via `DataLoaderJax`:

In [None]:
# 10000 samples / batch size 128 leaves a smaller last batch.
dataloader = DataLoader(
    dataset, 'jax', batch_size=128, shuffle=True)
w = train(dataloader, next(keys)).block_until_ready()
assert np.allclose(eval(dataloader, w), 0.)

In [None]:
# Batch size 200 divides the 10000 samples evenly (50 batches).
dataloader = DataLoader(dataset, 'jax', batch_size=200, shuffle=True)
w = train(dataloader, next(keys)).block_until_ready()
assert np.allclose(eval(dataloader, w), 0.)

Train this linear regression model via `DataLoaderPytorch`:

In [None]:
# Same training loop, with batches produced by the pytorch backend.
dataloader = DataLoader(
    dataset, 'pytorch', batch_size=128, shuffle=True)
w = train(dataloader, next(keys)).block_until_ready()
assert np.allclose(eval(dataloader, w), 0.)