## Loaders

In [None]:
#| default_exp loaders

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *
from nbdev import show_doc
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
#| export
from __future__ import print_function, division, annotations
from jax_dataloader.imports import *
from jax_dataloader.utils import *
from jax_dataloader.datasets import *

In [None]:
#| hide
def test_dataloader(dataloader_cls, samples=1000, batch_size=12):
    """Exercise a dataloader class over every shuffle / drop_last combination.

    A synthetic dataset is built in which row `i` of the features is the value
    `i` repeated 10 times and the label of row `i` is `[i]`, so the first
    feature column always equals the label. Each configuration is iterated
    twice to check that the loader can be reused across epochs.
    """
    feats = jnp.arange(samples).repeat(10).reshape(samples, 10)
    labels = jnp.arange(samples).reshape(samples, 1)
    ds = ArrayDataset(feats, labels)

    def run_epoch(loader, check_alignment=False):
        # Drain one epoch; optionally verify per batch that features and
        # labels still belong together (first feature column == label).
        xs, ys = [], []
        for x, y in loader:
            if check_alignment:
                assert jnp.array_equal(x[:, :1], y)
            xs.append(x)
            ys.append(y)
        return jnp.concatenate(xs), jnp.concatenate(ys), len(xs)

    # Sequential, keep the (possibly smaller) last batch: N % batchsize != 0
    sequential = dataloader_cls(ds, batch_size=batch_size, shuffle=False)
    for _epoch in range(2):
        all_x, all_y, _ = run_epoch(sequential)
        assert jnp.array_equal(all_x, feats)
        assert jnp.array_equal(all_y, labels)

    # Sequential, drop the incomplete last batch.
    sequential_drop = dataloader_cls(
        ds, batch_size=batch_size, shuffle=False, drop_last=True
    )
    for _epoch in range(2):
        all_x, all_y, n_batches = run_epoch(sequential_drop)
        last_idx = n_batches * batch_size
        assert jnp.array_equal(all_x, feats[:last_idx])
        assert jnp.array_equal(all_y, labels[:last_idx])

    # Shuffled, keep the last batch: order differs but content is preserved.
    shuffled = dataloader_cls(
        ds, batch_size=batch_size, shuffle=True, drop_last=False
    )
    for _epoch in range(2):
        all_x, all_y, _ = run_epoch(shuffled, check_alignment=True)
        assert not jnp.array_equal(all_x, feats)
        assert not jnp.array_equal(all_y, labels)
        assert jnp.sum(all_x) == jnp.sum(feats), \
            f"jnp.sum(_X)={jnp.sum(all_x)}, jnp.sum(feats)={jnp.sum(feats)}"

    # Shuffled, drop the last batch: only full batches are produced.
    shuffled_drop = dataloader_cls(
        ds, batch_size=batch_size, shuffle=True, drop_last=True
    )
    for _epoch in range(2):
        all_x, all_y, n_batches = run_epoch(shuffled_drop, check_alignment=True)
        assert not jnp.array_equal(all_x, feats)
        assert not jnp.array_equal(all_y, labels)
        assert len(all_x) == n_batches * batch_size


In [None]:
#| export
class BaseDataLoader:
    """Common interface shared by the dataloader backends.

    The base class stores nothing and does no work: its constructor only
    fixes the accepted arguments, and `__len__`, `__iter__` and `__next__`
    raise `NotImplementedError` until a subclass overrides them.
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int = 1,  # batch size
        shuffle: bool = False,  # if true, dataloader shuffles before sampling each batch
        drop_last: bool = False,  # drop the last incomplete batch or not
        **kwargs
    ):
        """Accept the shared constructor arguments; intentionally a no-op."""

    def __iter__(self):
        """Return an iterator over batches (must be overridden)."""
        raise NotImplementedError

    def __next__(self):
        """Return the next batch (must be overridden)."""
        raise NotImplementedError

    def __len__(self):
        """Return the number of batches (must be overridden)."""
        raise NotImplementedError

## JAX Dataloader


In [None]:
#| export
class DataLoaderJax(BaseDataLoader):
    """Dataloader in vanilla JAX.

    The loader is its own iterator: it keeps a cursor (`pose`) into an index
    array and slices `batch_size` indices from it on every `__next__` call.
    When `shuffle` is true the index array is permuted once at construction
    and again every time an epoch ends or is restarted.
    """

    def __init__(
        self, 
        dataset,
        batch_size: int = 1,  # batch size
        shuffle: bool = False,  # if true, indices are reshuffled for every epoch
        drop_last: bool = False, # drop the last incomplete batch or not
        **kwargs
    ):
        """Set up the index array and the PRNG key sequence.

        Raises:
            ValueError: if `batch_size` is smaller than 1.
        """
        super().__init__(dataset, batch_size, shuffle, drop_last)
        # A non-positive batch size would make `__next__` return empty batches
        # forever (0) or slice nonsense (negative), and `__len__` divide by 0.
        if batch_size < 1:
            raise ValueError(
                f"`batch_size` must be a positive integer, got {batch_size}."
            )

        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

        # Use haiku's key sequence when haiku is available, else the fallback.
        self.keys = PRNGSequence(seed=get_config().global_seed) \
            if hk is None else hk.PRNGSequence(get_config().global_seed)
        self.data_len = len(dataset)  # Length of the dataset
        self.indices = jnp.arange(self.data_len) # available indices in the dataset
        self.pose = 0  # record the current position in the dataset
        self._shuffle()

    def _shuffle(self):
        """Permute the index array with a fresh key when shuffling is enabled."""
        if self.shuffle:
            self.indices = jax.random.permutation(next(self.keys), self.indices)
        
    def _stop_iteration(self):
        """Rewind the cursor, reshuffle for the next epoch and end iteration."""
        self.pose = 0
        self._shuffle()
        raise StopIteration

    def __len__(self):
        """Return the number of batches produced in one epoch."""
        if self.drop_last:
            batches = self.data_len // self.batch_size  # get the floor of division
        else:
            batches = -(self.data_len // -self.batch_size)  # get the ceil of division
        return batches

    def __next__(self):
        """Return the next batch, or raise `StopIteration` at the end of an epoch."""
        if self.pose + self.batch_size <= self.data_len:
            # A full batch is still available.
            batch_indices = self.indices[self.pose: self.pose + self.batch_size]
            batch_data = self.dataset[batch_indices]
            self.pose += self.batch_size
            return batch_data
        elif self.pose < self.data_len and not self.drop_last:
            # Remaining samples form a smaller, final batch.
            batch_indices = self.indices[self.pose:]
            batch_data = self.dataset[batch_indices]
            self.pose += self.batch_size
            return batch_data
        else:
            self._stop_iteration()

    def __iter__(self):
        """Start a new epoch and return the loader itself as the iterator.

        If the previous epoch was abandoned midway (e.g. a `break`, or a
        `next(iter(loader))` peek), the cursor is rewound and the indices are
        reshuffled so the new epoch does not silently skip the batches that
        were already consumed. After a fully consumed epoch the cursor is
        already 0 and nothing extra happens.
        """
        if self.pose != 0:
            self.pose = 0
            self._shuffle()
        return self

In [None]:
#| hide
test_dataloader(DataLoaderJax, samples=20, batch_size=12)
test_dataloader(DataLoaderJax, samples=20, batch_size=10)
test_dataloader(DataLoaderJax, samples=11, batch_size=10)

## Pytorch Dataloader

Use `PyTorch` to load batches. This requires [PyTorch](https://pytorch.org/get-started/) to be installed.

In [None]:
#| exporti
# adapted from https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html
def _numpy_collate(batch):
    """Collate a list of samples into NumPy arrays, recursing into containers.

    The type of the first sample decides how the whole batch is handled:
    arrays are stacked, tuples/lists are collated field by field (returning a
    list), dicts are collated key by key, and anything else is passed to
    `np.array`.
    """
    first = batch[0]

    if isinstance(first, (np.ndarray, jax.Array)):
        return np.stack(batch)

    if isinstance(first, (tuple, list)):
        # Regroup [(a0, b0), (a1, b1), ...] into [(a0, a1, ...), (b0, b1, ...)].
        return [_numpy_collate(field) for field in zip(*batch)]

    if isinstance(first, dict):
        collated = {}
        for key in first:
            collated[key] = _numpy_collate([sample[key] for sample in batch])
        return collated

    return np.array(batch)

def _convert_dataset_pytorch(dataset: Dataset):
    """Wrap `dataset` in a `torch_data.Dataset` subclass.

    The wrapper forwards `len()` and indexing to the wrapped object unchanged.
    """

    class DatasetPytorch(torch_data.Dataset):
        """Thin adapter that delegates to the wrapped dataset."""

        def __init__(self, dataset: Dataset):
            self.dataset = dataset

        def __len__(self):
            return len(self.dataset)

        def __getitem__(self, idx):
            return self.dataset[idx]

    wrapped = DatasetPytorch(dataset)
    return wrapped

In [None]:
#| export
class DataLoaderPytorch(BaseDataLoader):
    """Dataloader backed by `torch.utils.data.DataLoader`.

    Batches are collated with `_numpy_collate`, so they are returned as NumPy
    arrays (or lists/dicts of them) rather than torch tensors.
    """
    def __init__(
        self, 
        dataset,
        batch_size: int = 1,  # Batch size
        shuffle: bool = False,  # If true, dataloader shuffles before sampling each batch
        drop_last: bool = False, # Drop last batch or not
        **kwargs
    ):
        """Build the underlying torch dataloader.

        Any extra `kwargs` are forwarded to `torch_data.DataLoader`, except
        `sampler`, which is dropped with a warning.
        """
        super().__init__(dataset, batch_size, shuffle, drop_last)
        check_pytorch_installed()
        from torch.utils.data import BatchSampler, RandomSampler, SequentialSampler
        
        if 'sampler' in kwargs:
            warnings.warn("`sampler` is currently not supported. We will ignore it and use `shuffle` instead.")
            del kwargs['sampler']

        sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
        batch_sampler = BatchSampler(sampler, batch_size=batch_size, drop_last=drop_last)

        # `batch_size`, `shuffle` and `drop_last` are deliberately not passed:
        # torch treats them as mutually exclusive with `batch_sampler`, which
        # already encodes all three.
        self.dataloader = torch_data.DataLoader(
            dataset, 
            batch_sampler=batch_sampler,
            collate_fn=_numpy_collate,
            **kwargs
        ) 
        # Iterator backing direct `next(loader)` calls; created lazily.
        self._iterator = None

    def __len__(self):
        """Return the number of batches per epoch."""
        return len(self.dataloader)

    def __next__(self):
        """Return the next batch when the loader is advanced with `next()`.

        A torch `DataLoader` is an iterable, not an iterator, so calling
        `next()` on it directly raises `TypeError`. An iterator is therefore
        created on first use and discarded once exhausted, so the following
        `next()` call starts a new epoch.
        """
        if self._iterator is None:
            self._iterator = iter(self.dataloader)
        try:
            return next(self._iterator)
        except StopIteration:
            self._iterator = None
            raise

    def __iter__(self):
        """Return a fresh iterator over one epoch of batches."""
        return iter(self.dataloader)

In [None]:
#| hide
test_dataloader(DataLoaderPytorch, samples=20, batch_size=12)
test_dataloader(DataLoaderPytorch, samples=20, batch_size=10)
test_dataloader(DataLoaderPytorch, samples=11, batch_size=10)