# Dataloader

> Support various dataloaders for loading batches.

In [None]:
#| default_exp core

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from nbdev import show_doc
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
#| export
from __future__ import annotations
import numpy as np
from typing import Callable, Union, List, Tuple, Dict, Any, Optional, TypeVar, Generic, Iterable, Sequence, Iterator
import jax
from jax import vmap, grad, jit, numpy as jnp
from jax.random import PRNGKey
from abc import ABC
from dataclasses import dataclass
import collections

In [None]:
#| export
try: import torch.utils.data as torch_data
except ModuleNotFoundError: torch_data = None

try: import haiku as hk 
except ModuleNotFoundError: hk = None

In [None]:
#| exporti
#| hide
@dataclass
class Config:
    """Library-wide settings read by the PRNG sequence and the dataloaders."""
    # How many PRNG keys are split off at once when a key sequence runs dry.
    rng_reserve_size: int
    # Seed used to initialise a dataloader's PRNG key sequence.
    global_seed: int

    @classmethod
    def default(cls) -> Config:
        """Build the configuration holding the library's stock values."""
        stock_values = {"rng_reserve_size": 1, "global_seed": 42}
        return cls(**stock_values)

In [None]:
#| exporti
# Module-level configuration instance; `get_config` hands out this object.
main_config = Config.default()

In [None]:
#| exporti
def get_config() -> Config:
    """Return the module-level `Config` instance (`main_config`)."""
    return main_config

In [None]:
#| export
class PRNGSequence(Iterator[PRNGKey]):
    """An iterator of Jax PRNG keys (minimal version of haiku.PRNGSequence).

    Keys are obtained by splitting an internal root key. Split-off subkeys are
    buffered in a deque and handed out in FIFO order by `next`.
    """

    def __init__(self, seed: int):
        # Root key; it is replaced every time new subkeys are split off.
        self._key = jax.random.PRNGKey(seed)
        # Subkeys that were split off but have not been consumed yet.
        self._subkeys = collections.deque()

    def reserve(self, num: int) -> None:
        """Splits additional ``num`` keys for later use.

        A non-positive ``num`` leaves the sequence untouched.
        """
        if num > 0:
            # Split into `num + 1` keys: the first becomes the new root key,
            # the remaining `num` are appended to the buffer.
            new_keys = tuple(jax.random.split(self._key, num + 1))
            self._key = new_keys[0]
            self._subkeys.extend(new_keys[1:])

    def __next__(self):
        """Return the next buffered key, refilling the buffer when it is empty."""
        if not self._subkeys:
            # Always split off at least one key: with a configured reserve size
            # of zero (or less) `reserve` is a no-op and `popleft` below would
            # fail with an opaque IndexError on the empty deque.
            self.reserve(max(1, get_config().rng_reserve_size))
        return self._subkeys.popleft()

In [None]:
#| export
class Dataset:
    """A simple pytorch-like Numpy Dataset.

    Wraps one or more arrays that share the same size along their first axis.
    Indexing the dataset applies the same index to every wrapped array.
    """

    def __init__(self, *arrays: jnp.ndarray):
        # Without at least one array `__len__` has nothing to measure, so an
        # empty dataset is rejected here instead of failing later on.
        if not arrays:
            raise ValueError("At least one array is required.")
        n_samples = arrays[0].shape[0]
        # Raised explicitly (not via `assert`) so the check is not stripped
        # under `python -O`; the exception type is kept for existing callers.
        if any(arr.shape[0] != n_samples for arr in arrays):
            raise AssertionError("All arrays must have the same dimension.")
        self.arrays = arrays

    def __len__(self):
        """Return the number of samples, i.e. the size of the first axis."""
        return self.arrays[0].shape[0]

    def __getitem__(self, idx):
        """Return a tuple holding `arr[idx]` for every wrapped array."""
        return tuple(arr[idx] for arr in self.arrays)

In [None]:
# Build a toy dataset of 1000 samples from two randomly drawn arrays.
keys = PRNGSequence(seed=0)
X = jax.random.normal(next(keys), shape=(1000, 10))
y = jax.random.normal(next(keys), shape=(1000, ))
ds = Dataset(X, y)

# The dataset length is the size of the arrays' first axis.
assert len(ds) == 1000

Indexing `Dataset` using `ds[idx]`

In [None]:
# Indexing returns one entry per wrapped array, each indexed the same way.
x1, y1 = ds[1]
assert jnp.array_equal(x1, X[1])
assert jnp.array_equal(y1, y[1])

In [None]:
#| export
class BaseDataLoader(ABC):
    """Common interface of all dataloader backends.

    The constructor accepts the shared arguments but stores nothing, and the
    `__len__`, `__next__` and `__iter__` hooks raise `NotImplementedError`
    until a backend overrides them.
    """

    def __init__(
        self,
        dataset,
        batch_size: int = 1,  # batch size
        shuffle: bool = False,  # if true, dataloader shuffles before sampling each batch
        drop_last: bool = False,
        **kwargs
    ):
        # Nothing is stored here; concrete loaders keep what they need.
        return None

    def __len__(self):
        """Length hook; not implemented on the base class."""
        raise NotImplementedError()

    def __next__(self):
        """Next-batch hook; not implemented on the base class."""
        raise NotImplementedError()

    def __iter__(self):
        """Iterator hook; not implemented on the base class."""
        raise NotImplementedError()

## Jax Dataloader

In [None]:
#| export
class DataLoaderJax(BaseDataLoader):
    """Dataloader in Vanilla Jax

    Batches are drawn from `dataset` through an index array. The loader is its
    own iterator: when an epoch ends the position is rewound (and the indices
    are reshuffled if `shuffle` is set), so the same object can be looped over
    again.
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int = 1,  # batch size
        shuffle: bool = False,  # if true, indices are reshuffled at the start of every epoch
        drop_last: bool = False,  # drop the last incomplete batch or not
        **kwargs
    ):
        # A non-positive batch size would never advance through the data
        # (endless iteration) and breaks the division in `__len__`.
        if batch_size <= 0:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size}.")

        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

        # Prefer haiku's key sequence when haiku is installed; otherwise fall
        # back to the minimal local implementation.
        self.keys = PRNGSequence(seed=get_config().global_seed) \
            if hk is None else hk.PRNGSequence(get_config().global_seed)
        self.data_len = len(dataset)  # Length of the dataset
        self.indices = jnp.arange(self.data_len)  # available indices in the dataset
        self.pose = 0  # record the current position in the dataset
        self._shuffle()

    def _shuffle(self):
        """Permute the index array with a fresh key when shuffling is enabled."""
        if self.shuffle:
            self.indices = jax.random.permutation(next(self.keys), self.indices)

    def _stop_iteration(self):
        """Rewind (and reshuffle) for the next epoch, then end the current one."""
        self.pose = 0
        self._shuffle()
        raise StopIteration

    def __len__(self):
        """Return the number of batches produced per epoch."""
        if self.drop_last:
            return self.data_len // self.batch_size  # floor of the division
        return -(self.data_len // -self.batch_size)  # ceil of the division

    def __next__(self):
        """Return the next batch, or end the epoch when the data is used up."""
        start = self.pose
        stop = start + self.batch_size
        if stop > self.data_len and (start >= self.data_len or self.drop_last):
            # Either nothing is left, or only an incomplete batch that
            # `drop_last` discards.
            self._stop_iteration()
        # For the last incomplete batch `stop` exceeds the array length;
        # slicing clamps it to the remaining indices.
        batch_data = self.dataset[self.indices[start:stop], ...]
        self.pose = stop
        return batch_data

    def __iter__(self):
        """Return the loader itself.

        The position is only rewound when an epoch runs to completion, so a
        loop that was abandoned midway resumes from where it stopped.
        """
        return self

In [None]:
def test_dataloader(dataloader_cls, samples=1000, batch_size=12):
    """Check `dataloader_cls` under every shuffle / drop_last combination.

    Each loader is iterated for two epochs so that re-use of the same loader
    object is covered as well.
    """
    feats = jnp.arange(samples).repeat(10).reshape(samples, 10)
    labels = jnp.arange(samples).reshape(samples, 1)
    ds = Dataset(feats, labels)

    def gather(loader, check_pairs):
        # Run one epoch; return the concatenated features and labels together
        # with the number of batches seen.
        xs, ys = [], []
        for x, y in loader:
            if check_pairs:
                # Row i of `feats` is filled with i, so the first feature
                # column equals the label as long as samples stay paired.
                assert jnp.array_equal(x[:, :1], y)
            xs.append(x)
            ys.append(y)
        return jnp.concatenate(xs), jnp.concatenate(ys), len(xs)

    # Sequential, incomplete last batch kept: the data comes back unchanged.
    loader = dataloader_cls(ds, batch_size=batch_size, shuffle=False)
    for _epoch in range(2):
        got_x, got_y, _ = gather(loader, check_pairs=False)
        assert jnp.array_equal(got_x, feats)
        assert jnp.array_equal(got_y, labels)

    # Sequential, incomplete last batch dropped: only a prefix comes back.
    loader = dataloader_cls(ds, batch_size=batch_size, shuffle=False, drop_last=True)
    for _epoch in range(2):
        got_x, got_y, n_batches = gather(loader, check_pairs=False)
        kept = n_batches * batch_size
        assert jnp.array_equal(got_x, feats[:kept])
        assert jnp.array_equal(got_y, labels[:kept])

    # Shuffled, incomplete last batch kept: order differs, content is preserved.
    loader = dataloader_cls(ds, batch_size=batch_size, shuffle=True, drop_last=False)
    for _epoch in range(2):
        got_x, got_y, _ = gather(loader, check_pairs=True)
        assert not jnp.array_equal(got_x, feats)
        assert not jnp.array_equal(got_y, labels)
        assert jnp.sum(got_x) == jnp.sum(feats), \
            f"jnp.sum(_X)={jnp.sum(got_x)}, jnp.sum(feats)={jnp.sum(feats)}"

    # Shuffled, incomplete last batch dropped: only full batches are produced.
    loader = dataloader_cls(ds, batch_size=batch_size, shuffle=True, drop_last=True)
    for _epoch in range(2):
        got_x, got_y, n_batches = gather(loader, check_pairs=True)
        assert not jnp.array_equal(got_x, feats)
        assert not jnp.array_equal(got_y, labels)
        assert len(got_x) == n_batches * batch_size


In [None]:
#| hide
# Cover a sample count that is not a multiple of the batch size (1000 / 12),
# one that is (1000 / 10), and one that leaves a single sample over (1001 / 10).
test_dataloader(DataLoaderJax, samples=1000, batch_size=12)
test_dataloader(DataLoaderJax, samples=1000, batch_size=10)
test_dataloader(DataLoaderJax, samples=1001, batch_size=10)

## Pytorch Dataloader

Use `Pytorch` to load batches. It requires [pytorch](https://pytorch.org/get-started/) to be installed.

In [None]:
#| exporti
# copy from https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html#data-loading-with-pytorch
def _numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        transposed = zip(*batch)
        return [_numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)

def _convert_dataset_pytorch(dataset: Dataset):
    class DatasetPytorch(torch_data.Dataset):
        def __init__(self, dataset: Dataset): self.dataset = dataset
        def __len__(self): return len(self.dataset)
        def __getitem__(self, idx): return self.dataset[idx]
    
    return DatasetPytorch(dataset)

In [None]:
#| export
class DataLoaderPytorch(BaseDataLoader):
    """Pytorch Dataloader

    Delegates batching to `torch.utils.data.DataLoader` and collates batches
    with `_numpy_collate`. Extra keyword arguments are forwarded to the torch
    loader.
    """
    def __init__(
        self,
        dataset: Dataset,
        batch_size: int = 1,  # batch size
        shuffle: bool = False,  # if true, dataloader shuffles before sampling each batch
        drop_last: bool = False,  # drop last batch or not
        **kwargs
    ):
        if torch_data is None:
            raise ModuleNotFoundError("`pytorch` library needs to be installed. Try `pip install torch`. "
            "Please refer to pytorch documentation for details: https://pytorch.org/get-started/.")

        dataset = _convert_dataset_pytorch(dataset)
        self.dataloader = torch_data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            collate_fn=_numpy_collate,
            **kwargs
        )
        # Iterator backing `__next__`; created lazily on first use.
        self._iterator = None

    def __len__(self):
        """Return the number of batches reported by the torch loader."""
        return len(self.dataloader)

    def __next__(self):
        """Return the next batch.

        The torch `DataLoader` is iterable but is not itself an iterator, so
        `next()` cannot be called on it directly; an iterator is kept here
        and renewed once it is exhausted.
        """
        if self._iterator is None:
            self._iterator = iter(self.dataloader)
        try:
            return next(self._iterator)
        except StopIteration:
            # Forget the exhausted iterator so the following `next` call
            # starts a new pass over the data.
            self._iterator = None
            raise

    def __iter__(self):
        """Return a fresh iterator over the torch loader."""
        return iter(self.dataloader)

In [None]:
#| hide
# Same sample-count / batch-size combinations as used for the Jax dataloader.
test_dataloader(DataLoaderPytorch, samples=1000, batch_size=12)
test_dataloader(DataLoaderPytorch, samples=1000, batch_size=10)
test_dataloader(DataLoaderPytorch, samples=1001, batch_size=10)

## Main Dataloader Class

In [None]:
#| export
@dataclass(frozen=True)
class DataloaderBackends:
    """Registry mapping backend names to dataloader classes.

    A backend whose value is `None` is known by name but not supported yet.
    """
    # NOTE: the fields hold dataloader *classes* (or `None`), not instances.
    jax: BaseDataLoader = DataLoaderJax
    pytorch: BaseDataLoader = DataLoaderPytorch
    tensorflow: BaseDataLoader = None
    merlin: BaseDataLoader = None

    # Un-annotated, hence a plain class attribute and not a dataclass field.
    # It is built while the class body runs, so it only captures the default
    # values; its keys serve as the list of known backend names.
    __all__ = dict(
        jax=jax, pytorch=pytorch, tensorflow=tensorflow, merlin=merlin
    )

    def __getitem__(self, key):
        """Return the dataloader class registered under `key`.

        The value is read from the instance, so backends passed to the
        constructor are honoured. Raises `KeyError` for an unknown name.
        """
        if key not in self.__all__:
            raise KeyError(key)
        return getattr(self, key)

    @property
    def supported(self) -> List[str]:
        """Names of the backends on this instance that are not `None`."""
        return [
            backend for backend in self.__all__
            if getattr(self, backend) is not None
        ]

In [None]:
#| exporti
def _dispatch_dataloader(
    backend: str # dataloader backend
) -> BaseDataLoader:
    """Return Dataloader class based on given `backend`

    Raises `ValueError` when `backend` is not among the supported backends.
    """
    registry = DataloaderBackends()
    available = registry.supported
    if backend in available:
        return registry[backend]
    raise ValueError(f"backend=`{backend}` is either an invalid backend or not supported yet. "
        f"Should be one of {available}.")


In [None]:
show_doc(_dispatch_dataloader)

---

### _dispatch_dataloader

>      _dispatch_dataloader (backend:str)

Return Dataloader class based on given `backend`

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| backend | str | dataloader backend |
| **Returns** | **BaseDataLoader** |  |

In [None]:
#| export
class DataLoader:
    """Main Dataloader class to load Numpy data batches

    Picks the dataloader class registered for `backend` and forwards all
    other arguments to it.
    """
    def __init__(
        self,
        dataset: Dataset,
        backend: str,  # Dataloader backend
        batch_size: int = 1,  # batch size
        shuffle: bool = False,  # if true, dataloader shuffles before sampling each batch
        drop_last: bool = False,  # drop last batches or not
        **kwargs
    ):
        dataloader_cls = _dispatch_dataloader(backend)
        self.dataloader = dataloader_cls(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            **kwargs
        )
        # Iterator backing `__next__`; created lazily on first use.
        self._iterator = None

    def __len__(self):
        """Return the number of batches reported by the backend loader."""
        return len(self.dataloader)

    def __next__(self):
        """Return the next batch from the backend loader.

        The batch is pulled from `iter(self.dataloader)` rather than from the
        loader itself, so backends whose loader is iterable without being an
        iterator work as well.
        """
        if self._iterator is None:
            self._iterator = iter(self.dataloader)
        try:
            return next(self._iterator)
        except StopIteration:
            # Forget the exhausted iterator so the following `next` call
            # starts a new pass over the data.
            self._iterator = None
            raise

    def __iter__(self):
        """Return the backend loader's iterator."""
        return iter(self.dataloader)

#### A Minimal Example of Using `DataLoader`

We showcase how to use `DataLoader` for training a simple regression model.


In [None]:
from sklearn.datasets import make_regression
import optax
import haiku as hk

In [None]:
# Synthetic regression problem: 10000 samples with 20 features each.
X, y = make_regression(n_samples=10000, n_features=20)
# Targets are reshaped into a column so both arrays share the first axis.
dataset = Dataset(X, y.reshape(-1, 1))
keys = hk.PRNGSequence(0)

Define `loss`, `step`, `train`:

In [None]:
def loss(w, x, y):
    """Mean `optax.l2_loss` of the linear predictions `x @ w.T` against `y`."""
    predictions = x @ w.T
    per_sample = vmap(optax.l2_loss)(predictions, y)
    return jnp.mean(per_sample)

def step(w, x, y):
    """Apply one gradient-descent update to `w` using the batch `(x, y)`."""
    learning_rate = 0.1
    # `jax.grad` differentiates with respect to the first argument, i.e. `w`.
    gradient = jax.grad(loss)(w, x, y)
    w -= learning_rate * gradient
    return w

def train(dataloader: DataLoader, key: jax.random.PRNGKey):
    """Fit the weights with `step` over ten passes through `dataloader`.

    The weights start from a normal draw of shape `(1, 20)` made with `key`.
    """
    n_epochs = 10
    w = jax.random.normal(key, shape=(1, 20))
    for _epoch in range(n_epochs):
        for features, targets in dataloader:
            w = step(w, features, targets)
    return w

def eval(dataloader: DataLoader, w):
    """Average the per-batch `loss` of `w` over one pass through `dataloader`.

    NOTE: within this notebook the name shadows the builtin `eval`.
    """
    batch_losses = [loss(w, x, y) for x, y in dataloader]
    return np.mean(batch_losses)
    

Train this linear regression model via `DataLoaderJax`:

In [None]:
# Train with the Jax backend, then check that the loss is numerically zero.
dataloader = DataLoader(
    dataset, 'jax', batch_size=128, shuffle=True)
w = train(dataloader, next(keys)).block_until_ready()
assert np.allclose(eval(dataloader, w), 0.)


In [None]:
# Same check with a batch size that divides the 10000 samples evenly.
dataloader = DataLoader(dataset, 'jax', batch_size=200, shuffle=True)
w = train(dataloader, next(keys)).block_until_ready()
assert np.allclose(eval(dataloader, w), 0.)

Train this linear regression model via `DataLoaderPytorch`:

In [None]:
# Train with the Pytorch backend, then check that the loss is numerically zero.
dataloader = DataLoader(
    dataset, 'pytorch', batch_size=128, shuffle=True)
w = train(dataloader, next(keys)).block_until_ready()
assert np.allclose(eval(dataloader, w), 0.)