# JAX Dataloader

In [None]:
#| default_exp loaders.jax

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
#| export
from __future__ import print_function, division, annotations
from jax_dataloader.imports import *
from jax_dataloader.datasets import ArrayDataset, JAXDataset
from jax_dataloader.loaders import BaseDataLoader
from jax_dataloader.utils import get_config, asnumpy
from jax_dataloader.tests import *
import jax_dataloader as jdl
from threading import Thread, Event
from queue import Queue

2024-02-01 22:18:26.142014: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-01 22:18:26.142138: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-01 22:18:26.151662: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [None]:
#| export
def chunk(seq: Sequence, size: int) -> List[Sequence]:
    """Split `seq` into consecutive slices holding at most `size` items each.

    Slices are taken in order starting at offset 0; the last one is shorter
    when `len(seq)` is not a multiple of `size`. An empty `seq` gives `[]`.
    """
    starts = range(0, len(seq), size)
    return [seq[lo:lo + size] for lo in starts]


In [None]:
#| export
def EpochIterator(
    data,
    batch_size: int,
    indices: Sequence[int]
):
    """Lazily yield one epoch of batches taken from `data`.

    `indices` is walked front to back in windows of `batch_size` entries;
    each window is used to index `data` and the result is yielded. The final
    window is shorter when `len(indices)` is not a multiple of `batch_size`.
    """
    total = len(indices)
    for start in range(0, total, batch_size):
        stop = start + batch_size
        yield data[indices[start:stop]]

In [None]:
#| export
class MultiprocessIterator(Thread):
    """[WIP] Background epoch iterator.

    Despite the name, batches are produced by a single worker *thread* (this
    class subclasses `Thread`). The worker indexes `data` with one batch of
    indices at a time and pushes each result onto `output_queue`; iterating
    over the object pops them in order.

    `None` is used as the end-of-stream marker on the queue, so a batch that
    is itself `None` would be mistaken for the end of the epoch.

    The worker thread is started from `__init__`.
    """

    # Set by the consumer once the end-of-stream marker has been received, so
    # that further `next()` calls raise instead of blocking on an empty queue.
    _epoch_done = False
    # Exception raised inside the worker; re-raised in the consumer thread.
    _worker_error = None

    def __init__(self, data, batch_size: int, indices=None):
        """Split `indices` into batches and start producing them.

        Args:
            data: Object indexed with a batch of indices (`data[batch_idx]`).
            batch_size: Maximum number of indices per batch.
            indices: Indices to iterate over; defaults to
                `np.arange(len(data))`.
        """
        super().__init__()
        self.data = data
        indices = np.arange(len(data)) if indices is None else indices
        batches = chunk(indices, batch_size)
        self.iter_idx = iter(batches)
        self.output_queue = Queue() # TODO: maxsize
        self.terminate_event = Event()
        self.start()

    def run(self):
        """Worker loop: enqueue batches until exhausted, closed, or failed."""
        try:
            # Stop early once `close()` has been requested.
            while not self.terminate_event.is_set():
                self.output_queue.put(self.get_data())
        except StopIteration:
            # Raised by `get_data` when every batch has been produced.
            pass
        except Exception as exc:
            # Hand the failure to the consumer; otherwise it would wait on
            # the queue forever while this thread silently died.
            self._worker_error = exc
        finally:
            # Always wake the consumer, whatever the reason for stopping.
            self.output_queue.put(None)

    def __next__(self):
        """Return the next batch.

        Raises:
            StopIteration: When the epoch is finished (also on every later
                call).
            Exception: Whatever exception the worker hit while loading data.
        """
        if self._epoch_done:
            raise StopIteration()
        result = self.output_queue.get()
        if result is None:
            self._epoch_done = True
            self.close()
            if self._worker_error is not None:
                error, self._worker_error = self._worker_error, None
                raise error
            raise StopIteration()
        return result

    def __iter__(self):
        return self

    def __del__(self):
        # `__init__` may have failed before `terminate_event` was created.
        if getattr(self, "terminate_event", None) is not None:
            self.close()

    def close(self):
        """Ask the worker to stop after the batch it is currently loading."""
        self.terminate_event.set()

    def get_data(self):
        """Load one batch; raises `StopIteration` when no batches remain."""
        batch_idx = next(self.iter_idx)
        batch = self.data[batch_idx]
        return batch


In [None]:
#| export
@dispatch
def to_jax_dataset(dataset: JAXDataset):
    """Prepare a `JAXDataset` for `DataLoaderJAX` and return it.

    The same object is returned. When it is an `ArrayDataset`, its
    `asnumpy()` method is called first (presumably converting the stored
    arrays to NumPy; confirm in `jax_dataloader.datasets`).
    """
    is_array_backed = isinstance(dataset, ArrayDataset)
    if is_array_backed:
        dataset.asnumpy()
    return dataset

@dispatch
def to_jax_dataset(dataset: HFDataset):
    """Return `dataset.with_format('numpy')` for an `HFDataset` input."""
    numpy_formatted = dataset.with_format('numpy')
    return numpy_formatted

In [None]:
#| export
class DataLoaderJAX(BaseDataLoader):
    """Dataloader that batches a dataset by indexing it with index arrays.

    Each call to `iter()` starts a new epoch: the sample indices are
    optionally permuted with a fresh PRNG subkey, optionally trimmed to a
    multiple of `batch_size`, and handed to `EpochIterator`.
    """

    @typecheck
    def __init__(
        self,
        dataset: Union[JAXDataset, HFDataset],
        batch_size: int = 1,  # batch size
        shuffle: bool = False,  # if true, sample order is re-permuted at the start of every epoch
        num_workers: int = 0,  # how many subprocesses to use for data loading. Ignored.
        drop_last: bool = False,  # if true, the trailing incomplete batch is dropped
        **kwargs
    ):
        # A non-positive batch size would otherwise surface later as a
        # ZeroDivisionError in `__len__` or as silently empty epochs.
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}.")

        self.key = jrand.PRNGKey(get_config().global_seed)
        self.dataset = to_jax_dataset(dataset)

        self.indices = np.arange(len(dataset))
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

    def __iter__(self):
        """Return an iterator over the batches of one epoch."""
        indices = self.indices
        if self.shuffle:
            # New permutation every epoch; converted back to NumPy so the
            # dataset is indexed with a NumPy array.
            indices = np.asarray(jrand.permutation(self.next_key(), indices))

        if self.drop_last:
            # Keep only the indices that fill complete batches.
            n_kept = len(indices) - len(indices) % self.batch_size
            indices = indices[:n_kept]
        return EpochIterator(self.dataset, self.batch_size, indices)

    def next_key(self):
        """Advance the stored PRNG key and return a fresh subkey."""
        self.key, subkey = jrand.split(self.key)
        return subkey

    def __len__(self):
        """Return the number of batches yielded per epoch."""
        full_batches, remainder = divmod(len(self.indices), self.batch_size)
        # The partial batch only counts when it exists and is not dropped.
        if self.drop_last or remainder == 0:
            return full_batches
        return full_batches + 1

In [None]:
# Example: 1280 samples. Each feature row repeats its sample index 10 times,
# and the matching label is that same index.
samples = 1280
batch_size = 12
feats = np.arange(samples).repeat(10).reshape(samples, 10)
labels = np.arange(samples).reshape(samples, 1)
ds = ArrayDataset(feats, labels)
dl = DataLoaderJAX(ds, batch_size=batch_size, shuffle=True)
# 1280 = 106 * 12 + 8, so 106 full batches plus one partial batch are expected.
assert len(dl) == 1280 // 12 + 1
# One index per sample.
assert len(dl.indices) == 1280

In [None]:
#| hide
# Sizes chosen to cover a partial final batch (20/12, 11/10, 40/12) and an
# exact multiple of the batch size (20/10).
test_dataloader(DataLoaderJAX, samples=20, batch_size=12)
test_dataloader(DataLoaderJAX, samples=20, batch_size=10)
test_dataloader(DataLoaderJAX, samples=11, batch_size=10)
test_dataloader(DataLoaderJAX, samples=40, batch_size=12)

In [None]:
#| hide
# NOTE(review): ds_type='hf' presumably selects a Hugging Face dataset fixture;
# confirm against `test_dataloader` in jax_dataloader.tests.
test_dataloader(DataLoaderJAX, ds_type='hf', samples=40, batch_size=12)

In [None]:
%%timeit -n 5 -r 3
# Timed run: 1280 samples in batches of 10 (an exact multiple, no partial batch).
test_dataloader(DataLoaderJAX, samples=1280, batch_size=10)

281 ms ± 27.8 ms per loop (mean ± std. dev. of 3 runs, 5 loops each)
