## Datasets

> PyTorch-like Dataset API.

In [None]:
#| default_exp datasets

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *
from nbdev import show_doc
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
#| export
from __future__ import print_function, division, annotations
from jax_dataloader.imports import *
from jax_dataloader.utils import *

In [None]:
#| export
class Dataset:
    """A pytorch-like Dataset class."""

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, index):
        raise NotImplementedError
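A concrete dataset only needs to implement `__len__` and `__getitem__`. The following is a minimal sketch that assumes integer indices only. `BaseDataset` is a stand-in for `Dataset` so the cell runs on its own, and `SquaresDataset` is a hypothetical example that is not part of `jax_dataloader`.

```python
class BaseDataset:
    """Stand-in for the `Dataset` base class above (so this cell runs alone)."""

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, index):
        raise NotImplementedError


class SquaresDataset(BaseDataset):
    """Hypothetical dataset whose i-th sample is the pair (i, i**2)."""

    def __init__(self, n: int):
        self.n = n

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, index: int):
        # Accept negative indices, like a Python sequence
        if index < 0:
            index += self.n
        # Raising IndexError also lets plain `for` loops stop at the end
        if not 0 <= index < self.n:
            raise IndexError(f"index {index} out of range for length {self.n}")
        return index, index ** 2


squares = SquaresDataset(5)
print(len(squares))   # 5
print(squares[3])     # (3, 9)
print(list(squares))  # [(0, 0), (1, 1), (2, 4), (3, 9), (4, 16)]
```

Raising `IndexError` for out-of-range indices matters because Python falls back to calling `__getitem__(0)`, `__getitem__(1)`, ... when a class defines no `__iter__`, and it stops at the first `IndexError`.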

In [None]:
#| export
class ArrayDataset(Dataset):
    """Dataset wrapping numpy arrays."""

    def __init__(
        self, 
        *arrays: jnp.ndarray # Numpy or JAX arrays sharing the same first dimension
    ):
        assert all(arrays[0].shape[0] == arr.shape[0] for arr in arrays), \
            "All arrays must have the same first dimension."
        self.arrays = arrays

    def __len__(self):
        return self.arrays[0].shape[0]

    def __getitem__(self, index):
        return tuple(arr[index] for arr in self.arrays)

This is similar to [torch.utils.data.TensorDataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.TensorDataset), 
but it wraps numpy (or JAX) arrays instead of `torch.Tensor`s.
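To make the contract concrete, here is a pure-Python sketch of the same idea over lists instead of arrays. `ListDataset` is hypothetical and not part of `jax_dataloader`: every input must share the same length, and indexing returns one element from each input, packed in a tuple. Unlike `ArrayDataset`, which uses an `assert`, this sketch raises `ValueError` on a length mismatch.

```python
from typing import Any, Sequence, Tuple


class ListDataset:
    """Hypothetical list-based analogue of `ArrayDataset`."""

    def __init__(self, *seqs: Sequence[Any]):
        if not seqs:
            raise ValueError("At least one sequence is required.")
        n = len(seqs[0])
        # Mirror ArrayDataset's check: every input must have the same length
        if any(len(s) != n for s in seqs):
            raise ValueError("All sequences must have the same length.")
        self.seqs = seqs

    def __len__(self) -> int:
        return len(self.seqs[0])

    def __getitem__(self, index) -> Tuple[Any, ...]:
        # One element (or slice) from each sequence, in the order they were passed
        return tuple(s[index] for s in self.seqs)


features = [[0, 1], [2, 3], [4, 5]]
labels = ["a", "b", "c"]
pairs = ListDataset(features, labels)
print(len(pairs))   # 3
print(pairs[1])     # ([2, 3], 'b')
print(pairs[0:2])   # ([[0, 1], [2, 3]], ['a', 'b'])
```

Because indexing is simply forwarded to each input, a slice yields a tuple of slices; with numpy or JAX arrays the same holds for integer index arrays.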

In [None]:
X = jnp.arange(10000).reshape(1000, 10)
y = jnp.arange(1000)
ds = ArrayDataset(X, y)
assert len(ds) == 1000



`ArrayDataset` indexes each wrapped array along its first dimension: `ds[index]` returns a tuple with one entry per array, e.g. `(X[1], y[1])` for `ds[1]`.

In [None]:
x1, y1 = ds[1] # get the second sample (index 1)
assert jnp.array_equal(x1, X[1])
assert jnp.array_equal(y1, y[1])
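A data loader only needs `len(ds)` and `ds[index]` from a dataset. The sketch below shows one way a loader *could* turn those two calls into shuffled mini-batches. `iterate_batches` is a hypothetical helper, not the `jax_dataloader` data loader, and it uses plain Python lists so it runs without JAX.

```python
import random
from typing import Any, Iterator, List, Tuple


def iterate_batches(
    dataset: Any, batch_size: int, shuffle: bool = True, seed: int = 0
) -> Iterator[Tuple[List[Any], ...]]:
    """Yield batches as tuples of per-field lists (hypothetical helper)."""
    indices = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        batch_idx = indices[start:start + batch_size]
        samples = [dataset[i] for i in batch_idx]
        # Transpose [(x0, y0), (x1, y1), ...] into ([x0, x1, ...], [y0, y1, ...])
        yield tuple(list(field) for field in zip(*samples))


# A list of (x, y) tuples already supports len() and integer indexing
toy = [(i, i % 2) for i in range(10)]
for xs, ys in iterate_batches(toy, batch_size=4, shuffle=False):
    print(xs, ys)
# [0, 1, 2, 3] [0, 1, 0, 1]
# [4, 5, 6, 7] [0, 1, 0, 1]
# [8, 9] [0, 1]
```

This sketch keeps the last, smaller batch; a real loader may offer an option to drop it so every batch has a fixed shape, which avoids recompilation under `jax.jit`.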

In [None]:
#| exporti
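def _has_tensor(batch) -> bool:
    """Check `batch[0]` for a `torch.Tensor`, recursing into transposed fields for nested tuples/lists."""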
    if isinstance(batch[0], torch.Tensor):
        return True
    elif isinstance(batch[0], (tuple, list)):
        transposed = zip(*batch)
        return any([_has_tensor(samples) for samples in transposed])
    else:
        return False

In [None]:
#| exporti
class TorchDataset(Dataset):
    """[Deprecated] A Dataset class that wraps a pytorch Dataset."""
    
    def __init__(
        self, 
        dataset: torch_data.Dataset # Pytorch Dataset
    ):
        check_pytorch_installed()
        if not isinstance(dataset, torch_data.Dataset):
            raise TypeError(f"`dataset` must be a torch Dataset, but got {type(dataset)}")
        # Give a warning if the dataset is not in numpy format
        if _has_tensor(dataset[0]):
            warnings.warn("The dataset contains `torch.Tensor`. "
                "Please make sure the dataset is in numpy format.")
        self._ds = dataset

    def __len__(self):
        return len(self._ds)

    def __getitem__(self, index):
        return self._ds[index]

`TorchDataset` wraps a `torch.utils.data.Dataset`. It does not modify the behavior of the wrapped PyTorch `dataset`: `len()` and indexing are forwarded to it unchanged.
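Because the wrapper only forwards calls, whatever the wrapped dataset returns is passed through as-is. The stdlib-only sketch below illustrates this delegation pattern with a hypothetical `ForwardingDataset` around a plain list. It is not the library class and performs none of its type checks.

```python
from typing import Any, Sequence


class ForwardingDataset:
    """Hypothetical wrapper that delegates len() and indexing to `inner`."""

    def __init__(self, inner: Sequence[Any]):
        self._ds = inner

    def __len__(self) -> int:
        return len(self._ds)

    def __getitem__(self, index: int) -> Any:
        # No conversion: the wrapped object's sample is returned unchanged
        return self._ds[index]


inner = [({"pixels": [0.0, 1.0]}, 7), ({"pixels": [2.0, 3.0]}, 1)]
wrapped = ForwardingDataset(inner)
assert len(wrapped) == 2
assert wrapped[0] is inner[0]  # the very same object, not a copy
```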

:::{.callout-warning}

`TorchDataset` will **NOT** convert a `torch.Tensor` into a `numpy.ndarray`.
Make sure the input `dataset` already returns numpy arrays 
before passing it to `TorchDataset`.
`TorchDataset` warns when it finds a `torch.Tensor`, but the check looks only at the first sample, `dataset[0]`, and may miss tensors in other positions of that sample.

:::
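The warning relies on `_has_tensor`, which inspects `batch[0]` and, when that element is a tuple or list, transposes the input with `zip(*batch)` and recurses into each group. The sketch below mirrors that logic with a placeholder `FakeTensor` class standing in for `torch.Tensor`; this substitution is an assumption so the snippet runs without PyTorch.

```python
class FakeTensor:
    """Placeholder standing in for `torch.Tensor` (assumption: no PyTorch needed)."""


def has_tensor(batch, tensor_type=FakeTensor) -> bool:
    # Mirrors `_has_tensor`: only the first element is inspected directly
    first = batch[0]
    if isinstance(first, tensor_type):
        return True
    if isinstance(first, (tuple, list)):
        # Regroup element-wise and recurse into each group
        return any(has_tensor(group, tensor_type) for group in zip(*batch))
    return False


print(has_tensor((FakeTensor(), 3)))                         # True
print(has_tensor(("image", FakeTensor())))                   # False: second position is not checked
print(has_tensor([(1, FakeTensor()), (2, FakeTensor())]))    # True: batch of samples is transposed
```

The second call shows why the callout recommends converting to numpy up front rather than relying on the warning.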


Let's load the [MNIST](http://yann.lecun.com/exdb/mnist/) dataset 
as a PyTorch Dataset via `torchvision`.

In [None]:
from torch.utils.data import TensorDataset
from torchvision.datasets import MNIST

The transform below flattens each PIL image and casts it to a `numpy.ndarray`
(adapted from the [JAX official tutorial](https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html)).

In [None]:
class FlattenAndCast:
    """Flatten a PIL image into a 1-D float64 numpy array."""
    def __call__(self, pic):
        return np.ravel(np.array(pic, dtype=float))
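`np.ravel` flattens in row-major (C) order by default, so a 28x28 MNIST image becomes a vector of 784 floats whose first 28 entries are the first row. The stdlib-only sketch below imitates that on a nested list. `flatten_and_cast` is a hypothetical helper for illustration, not a replacement for `FlattenAndCast`.

```python
from typing import List, Sequence


def flatten_and_cast(image: Sequence[Sequence[int]]) -> List[float]:
    """Row-major flatten of a 2-D nested list, casting each pixel to float."""
    return [float(pixel) for row in image for pixel in row]


# A fake 28x28 "image" whose pixel at (r, c) has value r * 28 + c
image = [[r * 28 + c for c in range(28)] for r in range(28)]
flat = flatten_and_cast(image)
print(len(flat))   # 784
print(flat[:3])    # [0.0, 1.0, 2.0]
print(flat[28])    # 28.0, the first pixel of the second row
```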

We load the pytorch [MNIST](https://pytorch.org/vision/stable/generated/torchvision.datasets.MNIST.html#torchvision.datasets.MNIST) dataset.

In [None]:
mnist_torch = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())

Finally, we can wrap `mnist_torch` as follows.

In [None]:
mnist_ds = TorchDataset(mnist_torch)
assert isinstance(mnist_ds[0][0], np.ndarray)

In [None]:
#| exporti
class HFDataset(Dataset):
    """[Deprecated] A Dataset class that wraps a huggingface Dataset."""
    
    def __init__(
        self, 
        dataset: hf_datasets.Dataset # Huggingface Dataset
    ):
        check_hf_installed()
        # if not isinstance(dataset, hf_datasets.Dataset):
        #     raise TypeError(f"`dataset` must be a huggingface Dataset, "
        #                     f"but got {type(dataset)}")
        # Ensure the dataset is in jax format
        self._ds = dataset.with_format("jax")

    def __len__(self):
        return len(self._ds)

    def __getitem__(self, index):
        return self._ds[index]

`HFDataset` wraps a Hugging Face dataset. Unlike `TorchDataset`,
`HFDataset` converts the input dataset to JAX format via `dataset.with_format("jax")`, so indexing returns JAX arrays.
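As far as the Hugging Face `datasets` documentation describes it, `with_format` returns a view whose formatting is applied on the fly when rows are accessed, rather than rewriting the stored data. The sketch below imitates that idea with a hypothetical column store: `__getitem__` returns a dict row and applies a user-supplied `formatter` to each value. Here the formatter turns lists into tuples, standing in for the conversion to JAX arrays.

```python
from typing import Any, Callable, Dict, List


class ColumnDataset:
    """Hypothetical column-oriented dataset returning formatted dict rows."""

    def __init__(self, columns: Dict[str, List[Any]],
                 formatter: Callable[[Any], Any] = lambda v: v):
        lengths = {len(col) for col in columns.values()}
        if len(lengths) != 1:
            raise ValueError("All columns must have the same (non-zero count of) length.")
        self._columns = columns
        self._formatter = formatter

    def with_formatter(self, formatter: Callable[[Any], Any]) -> "ColumnDataset":
        # New view over the same columns; the stored data is left untouched
        return ColumnDataset(self._columns, formatter)

    def __len__(self) -> int:
        return len(next(iter(self._columns.values())))

    def __getitem__(self, index: int) -> Dict[str, Any]:
        # Format each value only when the row is accessed
        return {name: self._formatter(col[index]) for name, col in self._columns.items()}


raw = ColumnDataset({"image": [[0, 1], [2, 3]], "label": [5, 0]})
formatted = raw.with_formatter(lambda v: tuple(v) if isinstance(v, list) else v)
print(raw[0])        # {'image': [0, 1], 'label': 5}
print(formatted[0])  # {'image': (0, 1), 'label': 5}
```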

Again, we load the [MNIST](http://yann.lecun.com/exdb/mnist/) dataset, 
this time through the Hugging Face `datasets` library.

In [None]:
from datasets import load_dataset

In [None]:
#|output: false
mnist_hf = load_dataset("mnist", split="train")



We wrap the `mnist_hf` as follows:

In [None]:
mnist_ds = HFDataset(mnist_hf)
assert isinstance(mnist_ds[0]['image'], jnp.ndarray)