# Dataloader for JAX

![Python](https://img.shields.io/pypi/pyversions/jax-dataloader.svg)
![CI status](https://github.com/BirkhoffG/jax-dataloader/actions/workflows/nbdev.yaml/badge.svg)
![Docs](https://github.com/BirkhoffG/jax-dataloader/actions/workflows/deploy.yaml/badge.svg)
![pypi](https://img.shields.io/pypi/v/jax-dataloader.svg)
![GitHub License](https://img.shields.io/github/license/BirkhoffG/jax-dataloader.svg)
<a href="https://static.pepy.tech/badge/jax-dataloader"><img src="https://static.pepy.tech/badge/jax-dataloader" alt="Downloads"></a>


## Overview

`jax_dataloader` brings a *pytorch-like* dataloader API to `jax`. 
It supports

* **4 datasets to download and pre-process data**: 
    * [jax dataset](https://birkhoffg.github.io/jax-dataloader/dataset/)
    * [huggingface datasets](https://github.com/huggingface/datasets) 
    * [pytorch Dataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset)
    * [tensorflow dataset](https://www.tensorflow.org/datasets)

* **3 backends to iteratively load batches**: 
    * [jax dataloader](https://birkhoffg.github.io/jax-dataloader/core.html#jax-dataloader)
    * [pytorch dataloader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) 
    * [tensorflow dataset](https://www.tensorflow.org/datasets)


A minimal `jax-dataloader` example:

```python
import jax_dataloader as jdl

dataloader = jdl.DataLoader(
    dataset, # Can be a jdl.Dataset, or a pytorch, huggingface, or tensorflow dataset
    backend='jax', # Use 'jax' for loading data (also supports 'pytorch' and 'tensorflow')
)

batch = next(iter(dataloader)) # fetch the first batch
```

## Installation

The latest `jax-dataloader` release can be installed directly from PyPI:

```sh
pip install jax-dataloader
```

or install directly from the repository:

```sh
pip install git+https://github.com/BirkhoffG/jax-dataloader.git
```

:::{.callout-note} 

We keep `jax-dataloader`'s dependencies minimal: it installs only `jax`-related dependencies and `plum-dispatch` (for backend dispatching).
To use the `pytorch`, huggingface `datasets`, or `tensorflow` integrations,
we recommend installing those dependencies manually.

You can also run `pip install jax-dataloader[all]` to install everything (not recommended).

:::
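
To see which optional integrations are importable in the current environment before choosing a backend, a small check along these lines may help. This is a sketch using only the standard library; the helper name `find_optional_deps` is hypothetical and not part of `jax-dataloader`, and `torch`, `datasets`, and `tensorflow` are assumed to be the top-level import names of those packages.

```python
from importlib.util import find_spec

# Top-level import names of the optional integrations (assumed here,
# not read from jax-dataloader itself).
OPTIONAL_DEPS = ("torch", "datasets", "tensorflow")


def find_optional_deps(names=OPTIONAL_DEPS):
    """Map each module name to True if it can be imported, else False."""
    # find_spec returns None for a top-level module that is not installed.
    return {name: find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, found in find_optional_deps().items():
        print(f"{name}: {'installed' if found else 'missing'}")
```

Any integration reported as missing can then be installed individually instead of pulling in everything.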

## Usage

`jax_dataloader.core.DataLoader` follows an API similar to that of the pytorch dataloader.

* The `dataset` should be an instance of a subclass of `jax_dataloader.core.Dataset`, 
`torch.utils.data.Dataset`, (the huggingface) `datasets.Dataset`,
or `tf.data.Dataset`.
* The `backend` should be one of `"jax"`, `"pytorch"`, or `"tensorflow"`. 
This argument specifies which backend dataloader is used to load batches.

Note that not every dataset is compatible with every backend. See the compatibility table below:

In [None]:
#| hide
%load_ext autoreload
%autoreload 2
import warnings
from IPython.display import Markdown
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
#| hide
from jax_dataloader.core import get_backend_compatibilities, SUPPORTED_DATASETS, JAXDataset
from jax_dataloader.imports import *
import pandas as pd
import jax_dataloader as jdl
import jax.numpy as jnp

In [None]:
#|echo: false
compat = get_backend_compatibilities()
annotate = {
    JAXDataset: "`jdl.Dataset`",
    TorchDataset: "`torch_data.Dataset`",
    TFDataset: "`tf.data.Dataset`",
    HFDataset: "`datasets.Dataset`",
}
assert len(annotate) == len(SUPPORTED_DATASETS)
supported = {}

for backend, ds in compat.items():
    if len(ds) > 0:
        _supported = [s in ds for s in SUPPORTED_DATASETS]
        supported[f'`"{backend}"`'] = list(map(lambda x: "✅" if x else "❌", _supported))

Markdown(
    pd.DataFrame(supported)
    .T
    .rename(columns={"index": "Backend"})
    .rename(columns={i: annotated for i, annotated in enumerate(annotate.values())})
    .to_markdown()
)

|                | `jdl.Dataset`   | `torch_data.Dataset`   | `tf.data.Dataset`   | `datasets.Dataset`   |
|:---------------|:----------------|:-----------------------|:--------------------|:---------------------|
| `"jax"`        | ✅              | ❌                     | ❌                  | ✅                   |
| `"pytorch"`    | ✅              | ✅                     | ❌                  | ✅                   |
| `"tensorflow"` | ✅              | ❌                     | ✅                  | ✅                   |
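
The table above can also be encoded as a lookup so that an incompatible pairing is caught before a `DataLoader` is constructed. The following is a minimal sketch that hard-codes the table; the names `COMPATIBILITY` and `is_compatible` are hypothetical, and the library's own `get_backend_compatibilities` remains the source of truth.

```python
# Backend -> dataset types it can load, transcribed from the table above.
# Hypothetical helper data, not part of jax-dataloader.
COMPATIBILITY = {
    "jax": {"jdl.Dataset", "datasets.Dataset"},
    "pytorch": {"jdl.Dataset", "torch_data.Dataset", "datasets.Dataset"},
    "tensorflow": {"jdl.Dataset", "tf.data.Dataset", "datasets.Dataset"},
}


def is_compatible(backend: str, dataset_type: str) -> bool:
    """Return True if `backend` can load `dataset_type` per the table."""
    if backend not in COMPATIBILITY:
        raise ValueError(f"Unknown backend: {backend!r}")
    return dataset_type in COMPATIBILITY[backend]


print(is_compatible("jax", "torch_data.Dataset"))      # False
print(is_compatible("pytorch", "torch_data.Dataset"))  # True
```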

### Using `ArrayDataset`

The `jax_dataloader.core.ArrayDataset` is an easy way to wrap 
multiple `jax.numpy` arrays into one Dataset. For example, 
we can create an `ArrayDataset` as follows:

In [None]:
#| output: false
# Create features `X` and labels `y`
X = jnp.arange(100).reshape(10, 10)
y = jnp.arange(10)
# Create an `ArrayDataset`
arr_ds = jdl.ArrayDataset(X, y)

This `arr_ds` can be loaded by *every* backend.

In [None]:
#| torch
# Create a `DataLoader` from the `ArrayDataset` via jax backend
dataloader = jdl.DataLoader(arr_ds, 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(arr_ds, 'pytorch', batch_size=5, shuffle=True)
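
Conceptually, a dataloader with `batch_size=5` and `shuffle=True` permutes the sample indices and then slices them into consecutive batches. The plain-Python sketch below illustrates that idea for a 10-sample dataset shaped like the one above; it is not how `jax_dataloader` is implemented, and `iter_batches` and its `seed` parameter are hypothetical names used only for illustration.

```python
import random


def iter_batches(features, labels, batch_size, shuffle=False, seed=0):
    """Yield (features, labels) mini-batches as lists."""
    assert len(features) == len(labels), "features and labels must align"
    indices = list(range(len(features)))
    if shuffle:
        # Permute the sample order; a fixed seed keeps the sketch reproducible.
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        batch_idx = indices[start:start + batch_size]
        yield [features[i] for i in batch_idx], [labels[i] for i in batch_idx]


# Same layout as the example above: 10 samples with 10 features each.
X = [list(range(10 * i, 10 * i + 10)) for i in range(10)]
y = list(range(10))

batches = list(iter_batches(X, y, batch_size=5, shuffle=True))
print(len(batches))        # 2
print(len(batches[0][0]))  # 5
```

Each feature row stays paired with its label because both are gathered with the same shuffled indices.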

### Using Huggingface Datasets

The huggingface [datasets](https://github.com/huggingface/datasets)
is a modern library for downloading, pre-processing, and sharing datasets.
`jax_dataloader` supports passing huggingface datasets directly.

In [None]:
#| hf
from datasets import load_dataset

For example, we load the `"squad"` dataset from `datasets`:

In [None]:
#| output: false
#| hf
hf_ds = load_dataset("squad")

Then, we can use `jax_dataloader` to load batches of `hf_ds`.

In [None]:
#| hf torch
# Create a `DataLoader` from the `datasets.Dataset` via jax backend
dataloader = jdl.DataLoader(hf_ds['train'], 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(hf_ds['train'], 'pytorch', batch_size=5, shuffle=True)

### Using Pytorch Datasets

The [pytorch Dataset](https://pytorch.org/docs/stable/data.html)
and its ecosystem (e.g., 
[torchvision](https://pytorch.org/vision/stable/index.html),
[torchtext](https://pytorch.org/text/stable/index.html),
[torchaudio](https://pytorch.org/audio/stable/index.html)) 
support many built-in datasets. 
`jax_dataloader` supports passing a pytorch Dataset directly.

:::{.callout-note} 

Unfortunately, the [pytorch Dataset](https://pytorch.org/docs/stable/data.html)
can only work with `backend="pytorch"`. See the example below.

:::

In [None]:
#| torch
from torchvision.datasets import MNIST
import numpy as np

We load the MNIST dataset from `torchvision`. 
The `ToNumpy` object transforms images to `numpy.array`.

In [None]:
class ToNumpy(object):
  """Transform that converts an image to a float `numpy.array`."""
  def __call__(self, pic):
    return np.array(pic, dtype=float)

In [None]:
#| torch
pt_ds = MNIST('/tmp/mnist/', download=True, transform=ToNumpy(), train=False)

This `pt_ds` can **only** be loaded via `"pytorch"` dataloaders.

In [None]:
#| torch
dataloader = jdl.DataLoader(pt_ds, 'pytorch', batch_size=5, shuffle=True)

### Using Tensorflow Datasets

`jax_dataloader` supports passing [tensorflow datasets](https://www.tensorflow.org/datasets) directly.

In [None]:
import tensorflow_datasets as tfds
import tensorflow as tf

For instance, we can load the MNIST dataset from `tensorflow_datasets`

In [None]:
tf_ds = tfds.load('mnist', split='test', as_supervised=True)

and use `jax_dataloader` for iterating the dataset.

In [None]:
dataloader = jdl.DataLoader(tf_ds, 'tensorflow', batch_size=5, shuffle=True)