# Dataset

In [None]:
#| default_exp datasets

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
#| export
from __future__ import print_function, division, annotations
from jax_dataloader.imports import *
from jax_dataloader.utils import asnumpy

In [None]:
#| export
class Dataset:
    """Abstract, PyTorch-style dataset interface.

    Both `__len__` and `__getitem__` raise `NotImplementedError` here;
    concrete datasets are expected to override them.
    """

    def __getitem__(self, index):
        """Return the sample(s) selected by `index` (to be overridden)."""
        raise NotImplementedError()

    def __len__(self):
        """Return the number of samples (to be overridden)."""
        raise NotImplementedError()

In [None]:
#| export
class ArrayDataset(Dataset):
    """Dataset wrapping arrays that share the same first dimension.

    Indexing the dataset indexes every wrapped array along its first
    dimension and returns the results as a tuple, in the order the arrays
    were passed to the constructor.
    """

    def __init__(
        self,
        *arrays: jax.Array, # Arrays with the same first dimension
        asnumpy: bool = True, # Store arrays as numpy arrays if True; otherwise store as array type of *arrays
    ):
        # Validate with an explicit raise instead of an `assert` statement so
        # the check is not stripped when Python runs with `-O`. The exception
        # type and message are kept so existing callers see the same error.
        if any(arr.shape[0] != arrays[0].shape[0] for arr in arrays):
            raise AssertionError("All arrays must have the same dimension.")
        self.arrays = tuple(arrays)
        # Here `asnumpy` is the boolean flag (it shadows the module-level
        # helper inside this method only); `self.asnumpy()` is the method.
        if asnumpy:
            self.asnumpy()

    def asnumpy(self):
        """Convert all arrays to numpy arrays."""
        # Inside a method body the bare name `asnumpy` resolves to the
        # module-level helper imported from `jax_dataloader.utils`, not to
        # this method.
        self.arrays = tuple(asnumpy(arr) for arr in self.arrays)

    def __len__(self):
        """Return the size of the first dimension of the first array."""
        return self.arrays[0].shape[0]

    def __getitem__(self, index):
        """Index every stored array with `index` and return the results as a tuple."""
        return jax.tree_util.tree_map(lambda arr: arr[index], self.arrays)

This is similar to [torch.utils.data.TensorDataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.TensorDataset), 
but it wraps numpy arrays instead of tensors.

In [None]:
# Two arrays sharing a first dimension of 1000: X has shape (1000, 10), y has shape (1000,).
X = jnp.arange(10000).reshape(1000, 10)
y = jnp.arange(1000)
ds = ArrayDataset(X, y)
# The dataset length is the size of that shared first dimension.
assert len(ds) == 1000

We index numpy arrays along the first dimension. Dataset indexing is done via `ds[index]`.

In [None]:
x1, y1 = ds[1] # get the sample at index 1 (0-based, i.e. the second sample)
assert jnp.array_equal(x1, X[1])
assert jnp.array_equal(y1, y[1])

# Slicing works too: this takes the first 10 rows of every wrapped array.
x10, y10 = ds[:10]
assert jnp.array_equal(x10, X[:10])
assert jnp.array_equal(y10, y[:10])

By default, `ArrayDataset` stores arrays as [numpy.ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html).

In [None]:
# `ds` was created with the default `asnumpy=True`, so indexing yields
# numpy arrays rather than jax arrays.
x, _ = ds[:10]
assert isinstance(x, np.ndarray)
assert not isinstance(x, jnp.ndarray)

If you want to keep the arrays in the type you passed in (e.g., JAX arrays), 
you can simply pass `asnumpy=False`.

In [None]:
# With `asnumpy=False` the numpy conversion is skipped, so the jax arrays
# passed in are stored (and returned) as jax arrays.
ds = ArrayDataset(X, y, asnumpy=False)
x, _ = ds[:10]
assert isinstance(x, jnp.ndarray)

In [None]:
#| export
# `JAXDataset` is an alias bound to the same class object as `Dataset`.
JAXDataset = Dataset