# JAX Dataloader

In [None]:
#| default_exp loaders.jax

In [None]:
#| include: false
# Notebook-only setup cell (no `#| export`, so nothing here reaches the module):
# auto-reload edited library code and silence FutureWarnings in the rendered output.
%load_ext autoreload
%autoreload 2
from ipynb_path import *
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
#| export
from __future__ import print_function, division, annotations
from jax_dataloader.imports import *
from jax_dataloader.datasets import ArrayDataset, JAXDataset
from jax_dataloader.loaders import BaseDataLoader
from jax_dataloader.utils import get_config, asnumpy, Generator
from jax_dataloader.tests import *
import jax_dataloader as jdl
from threading import Thread, Event
from queue import Queue

In [None]:
#| export
def EpochIterator(
    data,
    batch_size: int,
    indices: Sequence[int]
):
    """Yield one epoch of mini-batches from ``data``.

    ``indices`` is walked front to back in chunks of ``batch_size``; every
    chunk is used directly as an index into ``data`` (``data[chunk]``), so
    ``data`` must accept a slice of ``indices`` as a key. The last chunk may
    be shorter than ``batch_size``; an empty ``indices`` yields nothing.
    """
    offsets = range(0, len(indices), batch_size)
    yield from (data[indices[lo:lo + batch_size]] for lo in offsets)

In [None]:
#| export
@dispatch
def to_jax_dataset(dataset: JAXDataset):
    """Prepare a `JAXDataset` for `DataLoaderJAX` and return the same object.

    For an `ArrayDataset`, ``asnumpy()`` is invoked for its side effect (its
    return value is discarded); any other `JAXDataset` is returned untouched.
    """
    is_array_dataset = isinstance(dataset, ArrayDataset)
    if is_array_dataset:
        dataset.asnumpy()
    return dataset

@dispatch
def to_jax_dataset(dataset: HFDataset):
    """Return ``dataset.with_format('numpy')`` so batches come back as numpy arrays."""
    numpy_formatted = dataset.with_format('numpy')
    return numpy_formatted

In [None]:
#| export
class DataLoaderJAX(BaseDataLoader):
    """Single-process data loader built on numpy indexing and `jax.random`.

    Each epoch, batches are produced by indexing the numpy-formatted dataset
    with consecutive slices of ``self.indices``. When ``shuffle`` is true, the
    indices are permuted with a fresh `jax.random` subkey at the start of every
    epoch, so successive epochs see different orders.
    """

    @typecheck
    def __init__(
        self, 
        dataset: Union[JAXDataset, HFDataset], 
        batch_size: int = 1,  # batch size
        shuffle: bool = False,  # if true, dataloader shuffles before sampling each batch
        num_workers: int = 0,  # how many subprocesses to use for data loading. Ignored.
        drop_last: bool = False, # if true, drop the last incomplete batch
        generator: Optional[Generator | jax.Array | torch.Generator] = None, # random seed generator
        **kwargs  # accepted for API compatibility with other loaders; ignored
    ):
        # Reject invalid sizes up front; otherwise the failure only surfaces later
        # as a ZeroDivisionError in `__len__` or an error/empty epoch in `__iter__`.
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}.")

        self.dataset = to_jax_dataset(dataset)
        # Size the index array from the converted dataset actually being indexed.
        self.indices = np.arange(len(self.dataset))
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

        # Init the rng key via a Generator. `None` must be checked first: once a
        # value has been wrapped in `Generator`, it can no longer be `None`, which
        # previously made the explicit global-seed branch unreachable.
        if generator is None:
            # explicitly set the manual seed of the generator
            generator = Generator().manual_seed(get_config().global_seed)
        elif not isinstance(generator, Generator):
            generator = Generator(generator)
        self.key = generator.jax_generator()

    def __iter__(self):
        """Return an iterator over the batches of one epoch."""
        if self.shuffle:
            # Fresh permutation every epoch; `next_key` advances `self.key`.
            indices = np.asarray(jrand.permutation(self.next_key(), self.indices))
        else:
            indices = self.indices

        if self.drop_last:
            # Truncate to a multiple of batch_size so no partial batch is yielded.
            indices = indices[:len(indices) - len(indices) % self.batch_size]
        return EpochIterator(self.dataset, self.batch_size, indices)

    def next_key(self):
        """Split the stored key, keep one half as the new state, return the other."""
        self.key, subkey = jrand.split(self.key)
        return subkey

    def __len__(self):
        """Number of batches per epoch (the partial tail counts unless `drop_last`)."""
        complete_batches, remainder = divmod(len(self.indices), self.batch_size)
        return complete_batches if self.drop_last else complete_batches + bool(remainder)

# %%

In [None]:
# drop_last defaults to False: 1280 samples / 12 per batch -> 106 full batches + 1 partial batch.
samples, batch_size = 1280, 12
feats = np.arange(samples).repeat(10).reshape(-1, 10)
labels = np.arange(samples).reshape(-1, 1)
ds = ArrayDataset(feats, labels)
dl = DataLoaderJAX(ds, batch_size=batch_size, shuffle=True)
expected_batches = -(-samples // batch_size)  # ceil division
assert len(dl) == expected_batches
assert len(dl.indices) == samples

In [None]:
# One batch that exactly covers the dataset counts as a single batch whether or not drop_last is set.
samples = batch_size = 128
feats = np.arange(samples).repeat(10).reshape(-1, 10)
labels = np.arange(samples).reshape(-1, 1)
ds = ArrayDataset(feats, labels)
for drop_last in (True, False):
    dl = DataLoaderJAX(ds, batch_size=batch_size, shuffle=True, drop_last=drop_last)
    assert len(dl) == 1

In [None]:
#| hide
# (samples, batch_size) pairs: exact multiples (10/10, 20/10) and
# datasets that leave a partial final batch (20/12, 11/10, 40/12).
for n_samples, bs in [(10, 10), (20, 12), (20, 10), (11, 10), (40, 12)]:
    test_dataloader(DataLoaderJAX, samples=n_samples, batch_size=bs)

In [None]:
#| hide
# Same checks with ds_type='hf' (presumably a Hugging Face `datasets` backend — see jax_dataloader.tests).
test_dataloader(DataLoaderJAX, samples=40, batch_size=12, ds_type='hf')

In [None]:
%%timeit -n 5 -r 3
# Benchmark: 3 repeats x 5 loops of the full test_dataloader check on 1280 samples, batch_size=10.
test_dataloader(DataLoaderJAX, samples=1280, batch_size=10)

281 ms ± 27.8 ms per loop (mean ± std. dev. of 3 runs, 5 loops each)
