# Dataloader

> Support various dataloaders for loading batches.

In [None]:
#| default_exp core

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *
from nbdev import show_doc
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
#| export
from __future__ import print_function, division, annotations
from jax_dataloader.imports import *
from jax_dataloader.utils import *
from jax_dataloader.datasets import *
from jax_dataloader.loaders import *

In [None]:
#| export
@dataclass(frozen=True)
class DataloaderBackends:
    """Registry that maps backend names to dataloader classes.

    An entry of `None` marks a backend that is known by name but has no
    dataloader class registered; such backends are left out of `supported`.
    """

    # NOTE: `jax` has no annotation, so (unlike the entries below) it is a
    # plain class attribute rather than a dataclass field.
    jax = DataLoaderJax
    pytorch: BaseDataLoader = DataLoaderPytorch
    tensorflow: BaseDataLoader = None
    merlin: BaseDataLoader = None

    # Name -> class lookup table, built once from the class-level values above.
    __all__ = {
        "jax": jax,
        "pytorch": pytorch,
        "tensorflow": tensorflow,
        "merlin": merlin,
    }

    def __getitem__(self, key):
        """Return the entry registered under `key` (`KeyError` if unknown)."""
        table = self.__all__
        return table[key]

    @property
    def supported(self) -> List[str]:
        """Names of the backends whose registered entry is not `None`."""
        table = self.__all__
        return [name for name in table if table[name] is not None]

In [None]:
#| export
def _get_backends() -> List[str]:
    """Return the names of all known dataloader backends.

    Every key of `DataloaderBackends.__all__` is included, even those whose
    entry is `None`; use `DataloaderBackends().supported` to get only the
    backends that have a dataloader class.

    The names are copied into a new list so the result matches the annotated
    return type and formats as a plain list (not `dict_keys(...)`) when
    interpolated into messages.
    """
    return list(DataloaderBackends().__all__)


def _dispatch_dataloader(
    backend: str # dataloader backend
) -> BaseDataLoader:
    """Look up the dataloader class registered for `backend`.

    Raises `ValueError` when `backend` is not among the supported backends.
    """
    registry = DataloaderBackends()
    if backend in registry.supported:
        return registry[backend]

    raise ValueError(f"backend=`{backend}` is either an invalid backend or not supported yet. "
        f"Should be one of {registry.supported}.")

In [None]:
#| export
def _dispatch_dataset(
    dataset, # Dataset or Pytorch Dataset or HuggingFace Dataset
) -> Dataset:
    """Return `dataset` in the form handed to the dataloaders.

    * `Dataset` instances are returned unchanged.
    * Pytorch datasets are returned unchanged; a warning is emitted when
      `has_pytorch_tensor` flags the first sample.
    * HuggingFace datasets are returned via `with_format("jax")`.

    Raises `ValueError` for any other type.
    """
    if isinstance(dataset, Dataset):
        return dataset

    if is_torch_dataset(dataset):
        # Only the first sample is inspected, so this is a cheap heuristic
        # rather than a full scan of the dataset.
        if has_pytorch_tensor(dataset[0]):
            warnings.warn("The dataset contains `torch.Tensor`. "
                "Please make sure the dataset is in numpy format.")
        return dataset

    if is_hf_dataset(dataset):
        return dataset.with_format("jax")

    # The package is `jax_dataloader`; the message previously named a
    # non-existent `jax_loader` package.
    raise ValueError("dataset must be one of `jax_dataloader.Dataset`, "
                     "`torch.utils.data.Dataset`, `datasets.Dataset`, "
                     f"but got {type(dataset)}")

In [None]:
#| exporti
def is_jdl_dataset(dataset):
    """Tell whether `dataset` is an instance of `Dataset`."""
    matches = isinstance(dataset, Dataset)
    return matches


def _check_backend_compatibility(dataset, backend: str):
    """Raise `ValueError` unless `dataset` can be used with `backend`.

    A dataset is accepted when at least one of the predicates listed for the
    backend returns a truthy value. Backends with an empty predicate list
    therefore reject every dataset.

    Raises:
        ValueError: if `backend` is not a known backend name, or if no
            predicate registered for `backend` accepts `dataset`.
    """
    # Per-backend predicates. Pytorch datasets are only listed under
    # "pytorch", so every other backend rejects them.
    compatible_set = {
        "jax": [is_jdl_dataset, is_hf_dataset],
        "pytorch": [is_jdl_dataset, is_torch_dataset, is_hf_dataset],
        "tensorflow": [],
        "merlin": [],
    }
    # Query once and materialize, so the error message below prints a plain
    # list instead of a `dict_keys(...)` repr.
    known_backends = list(_get_backends())
    # Internal consistency check: every known backend needs a table entry.
    assert all(name in compatible_set for name in known_backends)

    if backend not in known_backends:
        raise ValueError(f"backend=`{backend}` is not supported yet. "
            f"Should be one of {known_backends}.")

    if not any(accepts(dataset) for accepts in compatible_set[backend]):
        raise ValueError(f"dataset (type=`{type(dataset)}`) is not compatible with backend='{backend}'. ")

In [None]:
#| export
def _dispatch_dataset_and_backend(
    dataset, # Dataset or Pytorch Dataset or HuggingFace Dataset
    backend: str # dataloader backend
) -> Tuple[Dataset, BaseDataLoader]:
    """Resolve `dataset` and `backend` into a (dataset, dataloader class) pair"""
    # Validate the pairing first so an incompatible combination fails
    # before the dataset is converted or the backend is looked up.
    _check_backend_compatibility(dataset, backend)
    # Tuple elements evaluate left to right: dataset first, then the class.
    return _dispatch_dataset(dataset), _dispatch_dataloader(backend)


In [None]:
#| export
class DataLoader:
    """Main Dataloader class to load Numpy data batches.

    Wraps the backend-specific dataloader stored in `self.dataloader` and
    forwards `len`, `next` and iteration to it.
    """

    def __init__(
        self,
        dataset, # Dataset or Pytorch Dataset or HuggingFace Dataset
        backend: str, # Dataloader backend
        batch_size: int = 1,  # batch size
        shuffle: bool = False,  # if true, dataloader shuffles before sampling each batch
        drop_last: bool = False, # drop last batches or not
        **kwargs
    ):
        # Resolve the dataset and the backend's dataloader class together;
        # any extra keyword arguments go straight to that class.
        resolved_dataset, loader_cls = _dispatch_dataset_and_backend(dataset, backend)
        self.dataloader = loader_cls(
            dataset=resolved_dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            **kwargs,
        )

    def __len__(self):
        """Length reported by the wrapped dataloader."""
        wrapped = self.dataloader
        return len(wrapped)

    def __next__(self):
        """Next item produced by the wrapped dataloader."""
        wrapped = self.dataloader
        return next(wrapped)

    def __iter__(self):
        """Iterator obtained from the wrapped dataloader."""
        wrapped = self.dataloader
        return wrapped.__iter__()

#### A Minimal Example of Using `DataLoader`

We showcase how to use `DataLoader` for training a simple regression model.


In [None]:
from sklearn.datasets import make_regression
import optax
import haiku as hk

In [None]:
# Synthetic regression problem: 500 samples with 20 features each.
X, y = make_regression(n_samples=500, n_features=20)
# Targets are reshaped into a single column before being paired with X.
dataset = ArrayDataset(X, y.reshape(-1, 1))
# `next(keys)` is passed as the `key` argument of `train` in the cells below.
# NOTE(review): presumably a stream of PRNG keys seeded with 0 — confirm.
keys = hk.PRNGSequence(0)

Define `loss`, `step`, `train`:

In [None]:
def loss(w, x, y):
    """Mean of `optax.l2_loss`, vmapped over predictions `x @ w.T` and targets `y`."""
    predictions = x @ w.T
    per_sample = vmap(optax.l2_loss)(predictions, y)
    return jnp.mean(per_sample)

def step(w, x, y, lr=0.1):
    """Return `w` after one gradient-descent update on the batch `(x, y)`.

    The gradient of `loss` with respect to `w` is scaled by the learning
    rate `lr` (default 0.1, the value previously hard-coded) and subtracted
    from `w`.

    The update builds a new array rather than using `w -= ...`, so the
    caller's `w` is not modified even if it is a mutable array type.
    """
    grad = jax.grad(loss)(w, x, y)
    return w - lr * grad

def train(dataloader: DataLoader, key: jax.random.PRNGKey):
    """Fit the weights by running `step` on every batch for 10 epochs.

    Weights start from `jax.random.normal(key, shape=(1, 20))`; the final
    weights are returned.
    """
    n_epochs = 10
    weights = jax.random.normal(key, shape=(1, 20))
    for _epoch in range(n_epochs):
        # One full pass over the dataloader per epoch.
        for batch_x, batch_y in dataloader:
            weights = step(weights, batch_x, batch_y)
    return weights

def eval(dataloader: DataLoader, w):
    """Return the mean of the per-batch `loss` values over `dataloader`.

    Each batch contributes equally to the mean, whatever its size.
    """
    # NOTE(review): this name shadows the builtin `eval` within the notebook.
    batch_losses = [loss(w, x, y) for x, y in dataloader]
    return np.mean(batch_losses)
    

Train this linear regression model via `DataLoaderJax`:

In [None]:
# 'jax' backend, shuffled batches of 128. The dataset's 500 samples are not
# a multiple of 128 and `drop_last` keeps its default (False).
dataloader = DataLoader(
    dataset, 'jax', batch_size=128, shuffle=True)
# Draw a fresh key from `keys` for the weight initialization in `train`.
w = train(dataloader, next(keys)).block_until_ready()
# Convergence check, currently disabled:
# assert np.allclose(eval(dataloader, w), 0.)

In [None]:
# Same run with a batch size of 200 (again not a divisor of the 500 samples).
dataloader = DataLoader(dataset, 'jax', batch_size=200, shuffle=True)
# Draw a fresh key from `keys` for the weight initialization in `train`.
w = train(dataloader, next(keys)).block_until_ready()
# Convergence check, currently disabled:
# assert np.allclose(eval(dataloader, w), 0.)

Train this linear regression model via `DataLoaderPytorch`:

In [None]:
# Same dataset and batch size as the first 'jax' run, but using the
# 'pytorch' backend.
dataloader = DataLoader(
    dataset, 'pytorch', batch_size=128, shuffle=True)
# Draw a fresh key from `keys` for the weight initialization in `train`.
w = train(dataloader, next(keys)).block_until_ready()
# Convergence check, currently disabled:
# assert np.allclose(eval(dataloader, w), 0.)