# Core API

> Supports various dataloader backends for loading batches.

In [None]:
#| default_exp core

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *
from nbdev import show_doc
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
#| export
from __future__ import print_function, division, annotations
from jax_dataloader.imports import *
from jax_dataloader.utils import *
from jax_dataloader.datasets import *
from jax_dataloader.loaders import *

2023-12-26 15:13:36.437449: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-26 15:13:36.437528: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-26 15:13:36.439236: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [None]:
#| export
# Dataset types this module knows how to probe; `get_backend_compatibilities`
# builds one sample per entry and asserts it covers exactly this many types.
SUPPORTED_DATASETS = [JAXDataset, TorchDataset, TFDataset, HFDataset]

In [None]:
#| export
@dataclass(frozen=True)
class DataloaderBackends:
    """Registry mapping backend names to their dataloader classes.

    A backend whose class is ``None`` (e.g. ``merlin``) is known by name but not
    available; `supported` filters those out.
    """
    # NOTE: unannotated, so this is a plain class attribute rather than a
    # dataclass field (unlike the entries below); it cannot be set via the
    # constructor.
    jax = DataLoaderJAX
    pytorch: BaseDataLoader = DataLoaderPytorch
    tensorflow: BaseDataLoader = DataLoaderTensorflow
    merlin: BaseDataLoader = None

    # Class-level table of every known backend name -> default class.
    # `_get_backends` reads its keys.
    __all__ = dict(
        jax=jax, pytorch=pytorch, tensorflow=tensorflow, merlin=merlin
    )

    def __getitem__(self, key):
        """Return the dataloader class registered for backend `key`.

        Reads the instance attribute so values given to the constructor
        (e.g. ``DataloaderBackends(pytorch=...)``) are honored instead of being
        shadowed by the class-level defaults.

        Raises:
            KeyError: if `key` is not a known backend name.
        """
        if key not in self.__all__:
            raise KeyError(key)
        return getattr(self, key)

    @property
    def supported(self) -> List[str]:
        """Names of the backends whose dataloader class is not ``None``."""
        return [name for name in self.__all__ if self[name] is not None]

In [None]:
#| export
def _get_backends() -> List[str]:
    """Return the names of all known dataloader backends, available or not.

    The result includes backends whose class is ``None`` (such as ``merlin``);
    use `DataloaderBackends().supported` for only the usable ones. Returns a
    real list, as annotated, rather than a live ``dict_keys`` view.
    """
    return list(DataloaderBackends().__all__)


def _dispatch_dataloader(
    backend: str # dataloader backend
) -> BaseDataLoader:
    """Look up the dataloader class registered under `backend`.

    Returns the class itself (not an instance).

    Raises:
        ValueError: if `backend` is unknown or its class is unavailable (``None``).
    """
    registry = DataloaderBackends()
    available = registry.supported
    if backend not in available:
        raise ValueError(
            f"backend=`{backend}` is either an invalid backend or not supported yet. "
            f"Should be one of {available}."
        )
    return registry[backend]

In [None]:
#| export
def _check_backend_compatibility(ds, backend: str):
    """Try to wrap `ds` in a `DataLoader` for `backend` and return it.

    Any error raised while constructing the loader propagates to the caller,
    which treats it as "incompatible".
    """
    loader = DataLoader(dataset=ds, backend=backend)
    return loader

In [None]:
#| export
def get_backend_compatibilities() -> dict[str, list[type]]:
    """Probe which dataset types each dataloader backend can wrap.

    Builds one small sample dataset per entry of `SUPPORTED_DATASETS` and tries
    to construct a `DataLoader` for every (backend, dataset type) pair.

    Returns:
        A dict mapping every backend name from `_get_backends()` to the list of
        dataset types it accepted; a backend that accepts none maps to ``[]``.
    """
    samples = {
        JAXDataset: ArrayDataset(np.array([1,2,3])),
        TorchDataset: torch_data.Dataset(),
        TFDataset: tf.data.Dataset.from_tensor_slices(np.array([1,2,3])),
        HFDataset: hf_datasets.Dataset.from_dict({'a': [1,2,3]})
    }
    # Keep this sample table in sync with SUPPORTED_DATASETS.
    assert len(samples) == len(SUPPORTED_DATASETS)
    compat = {b: [] for b in _get_backends()}
    for backend, accepted in compat.items():
        for ds_type, sample in samples.items():
            try:
                _check_backend_compatibility(sample, backend)
            except Exception:
                # Any construction failure (unsupported backend, incompatible
                # dataset) means "not compatible". Unlike a bare `except:`, this
                # lets KeyboardInterrupt / SystemExit propagate.
                continue
            accepted.append(ds_type)

    return compat

In [None]:
#| export
class DataLoader:
    """Backend-agnostic front end that yields data batches from `dataset`.

    The constructor resolves the loader class registered for `backend` and
    instantiates it with the batching options. `len()`, `iter()` and `next()`
    are all forwarded to that wrapped loader (stored as `self.dataloader`).
    """

    def __init__(
        self,
        dataset, # Dataset or Pytorch Dataset or HuggingFace Dataset
        backend: str, # Dataloader backend
        batch_size: int = 1,  # batch size
        shuffle: bool = False,  # if true, dataloader shuffles before sampling each batch
        drop_last: bool = False, # drop last batches or not
        **kwargs
    ):
        # Resolve the backend first so an invalid name fails before any work.
        loader_cls = _dispatch_dataloader(backend)
        options = dict(
            batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, **kwargs
        )
        self.dataloader = loader_cls(dataset=dataset, **options)

    def __iter__(self):
        """Return an iterator from the wrapped backend loader."""
        return iter(self.dataloader)

    def __next__(self):
        """Advance the wrapped backend loader by one batch."""
        return next(self.dataloader)

    def __len__(self):
        """Length as reported by the wrapped backend loader."""
        return len(self.dataloader)

#### A Minimal Example of Using `DataLoader`

We show how to use `DataLoader` to train a simple linear regression model.


In [None]:
from sklearn.datasets import make_regression
import optax
import haiku as hk

In [None]:
# 500 synthetic samples with 20 features; targets become a (500, 1) column.
X, y = make_regression(n_samples=500, n_features=20)
dataset = ArrayDataset(X, y[:, None])
# Seeded key stream; each `next(keys)` yields a fresh PRNG key.
keys = hk.PRNGSequence(0)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


Define `loss`, `step`, `train`, and `eval`:

In [None]:
def loss(w, x, y):
    """Batch mean of `optax.l2_loss` between the linear prediction `x @ w.T` and `y`."""
    preds = x @ w.T
    per_example = vmap(optax.l2_loss)(preds, y)
    return jnp.mean(per_example)

def step(w, x, y, lr: float = 0.1):
    """Apply one gradient-descent update to weights `w` on the batch `(x, y)`.

    Args:
        w: current weights.
        x, y: one batch of inputs and targets.
        lr: learning rate (default 0.1, the previously hard-coded value).

    Returns:
        The updated weights. The caller's `w` is not modified in place, even
        if it is a mutable NumPy array.
    """
    grad = jax.grad(loss)(w, x, y)
    return w - lr * grad

def train(
    dataloader: DataLoader,
    key: jax.random.PRNGKey,
    n_epochs: int = 10,
    n_features: int = 20,
):
    """Fit linear weights of shape (1, n_features) by mini-batch gradient descent.

    Weights start from a standard-normal draw using `key`. Each epoch calls
    `step` once per batch yielded by `dataloader`. The defaults (10 epochs,
    20 features) reproduce the previously hard-coded values; `n_features` must
    match the width of the batches' `x`.

    Returns:
        The trained weight array.
    """
    w = jrand.normal(key, shape=(1, n_features))
    for _ in range(n_epochs):
        for x, y in dataloader:
            w = step(w, x, y)
    return w

def eval(dataloader: DataLoader, w):
    """Return the mean of the per-batch `loss` of weights `w` over `dataloader`.

    Every batch counts equally, so a smaller final batch (if any) weighs as much
    as a full one. Note: this name shadows the builtin `eval` in the notebook.
    """
    return np.mean([loss(w, x, y) for x, y in dataloader])
    

Train this linear regression model via the `jax` backend (`DataLoaderJAX`):

In [None]:
# Shuffled batches of 128 rows from the pure-JAX backend.
dataloader = DataLoader(dataset, backend='jax', batch_size=128, shuffle=True)
# `block_until_ready` waits for JAX's asynchronous computation to finish.
w = train(dataloader, next(keys)).block_until_ready()
# assert np.allclose(eval(dataloader, w), 0.)

In [None]:
# Same backend with a batch size (200) that does not evenly divide 500 rows.
dataloader = DataLoader(dataset, backend='jax', batch_size=200, shuffle=True)
w = train(dataloader, next(keys)).block_until_ready()
# assert np.allclose(eval(dataloader, w), 0.)

Train this linear regression model via the `pytorch` backend:

In [None]:
#| torch
# Same data and batch size, batched by the PyTorch backend.
dataloader = DataLoader(dataset, backend='pytorch', batch_size=128, shuffle=True)
w = train(dataloader, next(keys)).block_until_ready()
# assert np.allclose(eval(dataloader, w), 0.)

Train this linear regression model via the `tensorflow` backend:

In [None]:
#| tf
# Same data and batch size, batched by the TensorFlow backend.
dataloader = DataLoader(dataset, backend='tensorflow', batch_size=128, shuffle=True)
w = train(dataloader, next(keys)).block_until_ready()
# assert np.allclose(eval(dataloader, w), 0.)