# JAX Dataloader

In [None]:
#| default_exp jax.loaders

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
#| export
from __future__ import print_function, division, annotations
from jax_dataloader.imports import *
from jax_dataloader.datasets import ArrayDataset
from jax_dataloader.loaders import BaseDataLoader
from jax_dataloader.utils import get_config
from jax_dataloader.tests import *
from threading import Thread, Event
from queue import Queue

In [None]:
#| export
def chunk(seq: Sequence, size: int):
    """Lazily split ``seq`` into consecutive pieces of ``size`` items.

    Each piece is ``seq[start:start + size]`` for ``start`` stepping by
    ``size`` from 0, so its type is whatever slicing ``seq`` returns and the
    final piece is shorter when ``len(seq)`` is not a multiple of ``size``.
    A ``size`` of 0 makes ``range`` raise ``ValueError`` on the first
    ``next()``; a negative ``size`` yields nothing.
    """
    starts = range(0, len(seq), size)
    yield from (seq[lo:lo + size] for lo in starts)

In [None]:
#| export
class EpochIterator(Thread):
    """Iterate over one epoch of batches, prefetched by a background thread.

    The worker thread splits ``indices`` into consecutive groups of
    ``batch_size`` (via ``chunk``), indexes ``data`` with each group and puts
    the result on ``output_queue``; ``__next__`` consumes them in order.

    Args:
        data: Object indexed as ``data[batch_idx]`` with a group of indices.
        batch_size: Indices per batch; the last batch may be smaller.
        indices: Visiting order; defaults to ``np.arange(len(data))``.

    An exception raised while producing a batch is re-raised to the consumer
    from ``__next__`` instead of leaving it blocked on the queue forever.
    """

    def __init__(self, data, batch_size: int, indices=None):
        # Daemon thread: an abandoned or stalled worker must not keep the
        # interpreter alive at exit.
        super().__init__(daemon=True)
        self.data = data
        indices = np.arange(len(data)) if indices is None else indices
        self.iter_idx = iter(chunk(indices, batch_size))
        self.output_queue = Queue()  # TODO: maxsize to bound prefetch memory
        self.terminate_event = Event()
        self._worker_error = None  # exception raised in the worker, if any
        self._sentinel_seen = False  # True once the end-of-epoch marker is consumed
        self.start()

    def run(self):
        """Produce batches until the epoch ends, ``close()`` is called, or an error occurs."""
        try:
            while not self.terminate_event.is_set():
                self.output_queue.put(self.get_data())
        except StopIteration:
            pass  # normal end of the epoch
        except Exception as err:
            # Hand the failure to the consumer; __next__ re-raises it.
            self._worker_error = err
        finally:
            # Always emit the sentinel so a consumer blocked in __next__ wakes up.
            self.output_queue.put(None)

    def __next__(self):
        """Return the next batch, or raise ``StopIteration`` at the end of the epoch.

        Raises:
            Exception: whatever the worker raised while producing a batch.
        """
        if self._sentinel_seen:
            # The queue is drained; another get() would block forever.
            raise StopIteration()
        result = self.output_queue.get()
        if result is None:
            self._sentinel_seen = True
            self.close()
            if self._worker_error is not None:
                raise self._worker_error
            raise StopIteration()
        return result

    def __iter__(self):
        return self

    def __del__(self):
        self.close()

    def close(self):
        """Ask the worker thread to stop producing batches."""
        # getattr guard: __del__ may run on an instance whose __init__ failed
        # before terminate_event was created.
        event = getattr(self, "terminate_event", None)
        if event is not None:
            event.set()

    def get_data(self):
        """Fetch the next batch; raises ``StopIteration`` when indices run out."""
        batch_idx = next(self.iter_idx)
        batch = self.data[batch_idx]
        return batch

In [None]:
class DataLoaderJAX(BaseDataLoader):
    """Dataloader that prefetches batches on a background thread.

    Every ``iter()`` call starts a new epoch: indices are optionally shuffled
    with a fresh subkey split from ``self.key`` and handed to ``EpochIterator``.
    """

    def __init__(
        self, 
        dataset, 
        batch_size: int = 1,  # batch size
        shuffle: bool = False,  # if true, dataloader shuffles before sampling each batch
        num_workers: int = 0,  # how many subprocesses to use for data loading. Ignored.
        drop_last: bool = False,  # if true, drop the final incomplete batch of each epoch
        **kwargs  # accepted for signature compatibility; ignored
    ):
        if batch_size < 1:
            # Otherwise chunk() fails inside the worker thread (range step 0)
            # or silently yields nothing (negative step).
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        self.key = jrand.PRNGKey(get_config().global_seed)
        self.dataset = dataset
        self.indices = np.arange(len(dataset))
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
    
    def __iter__(self):
        """Start a new epoch and return an ``EpochIterator`` over its batches."""
        if self.shuffle:
            # Stored so each epoch permutes the previous order (same shuffle
            # sequence as before for a given seed).
            self.indices = jrand.permutation(self.next_key(), self.indices)

        indices = self.indices
        if self.drop_last:
            # Truncate a local copy only: writing it back to self.indices would
            # permanently exclude the dropped samples from every later epoch.
            indices = indices[:len(indices) - len(indices) % self.batch_size]
        return EpochIterator(self.dataset, self.batch_size, indices)

    def next_key(self):
        """Split ``self.key``, keep one half as the new key and return the other."""
        self.key, subkey = jrand.split(self.key)
        return subkey

In [None]:
samples = 1280
# Row k of `feats` holds ten copies of k and `labels[k] == [k]`, so a batch's
# features can be matched to its labels by eye.
feats = jnp.repeat(jnp.arange(samples), 10).reshape(samples, 10)
labels = jnp.arange(samples)[:, None]
ds = ArrayDataset(feats, labels)

In [None]:
dl = DataLoaderJAX(ds, batch_size=12, shuffle=True)
# Smoke test: 8 shuffled epochs; check every batch is well-formed and that
# each epoch yields `samples` rows in total (drop_last is off).
for _ in range(8):
    n_seen = 0
    for i, batch in enumerate(dl):
        # Features and labels must stay aligned; no batch may exceed batch_size.
        assert batch[0].shape[0] == batch[1].shape[0] <= dl.batch_size, (i, batch[0].shape, batch[1].shape)
        n_seen += batch[0].shape[0]
    assert n_seen == samples, n_seen


In [None]:
from keras.trainers.epoch_iterator import EpochIterator

In [None]:
dl = EpochIterator(feats, labels, batch_size=12, shuffle=True)
for _ in range(8):
    for i, (batch) in dl.enumerate_epoch('np'):
        print(batch)
        break


[(array([[ 811,  811,  811,  811,  811,  811,  811,  811,  811,  811],
       [1009, 1009, 1009, 1009, 1009, 1009, 1009, 1009, 1009, 1009],
       [1179, 1179, 1179, 1179, 1179, 1179, 1179, 1179, 1179, 1179],
       [ 542,  542,  542,  542,  542,  542,  542,  542,  542,  542],
       [1124, 1124, 1124, 1124, 1124, 1124, 1124, 1124, 1124, 1124],
       [ 647,  647,  647,  647,  647,  647,  647,  647,  647,  647],
       [1248, 1248, 1248, 1248, 1248, 1248, 1248, 1248, 1248, 1248],
       [ 153,  153,  153,  153,  153,  153,  153,  153,  153,  153],
       [  23,   23,   23,   23,   23,   23,   23,   23,   23,   23],
       [1099, 1099, 1099, 1099, 1099, 1099, 1099, 1099, 1099, 1099],
       [1027, 1027, 1027, 1027, 1027, 1027, 1027, 1027, 1027, 1027],
       [ 656,  656,  656,  656,  656,  656,  656,  656,  656,  656]],
      dtype=int32), array([[ 811],
       [1009],
       [1179],
       [ 542],
       [1124],
       [ 647],
       [1248],
       [ 153],
       [  23],
       [1099],

In [None]:
# Exercise DataLoaderJAX on dataset sizes that do and do not divide evenly
# into batches.
for n_samples, bs in [(20, 12), (20, 10), (11, 10), (40, 12)]:
    test_dataloader(DataLoaderJAX, samples=n_samples, batch_size=bs)

In [None]:
%%timeit -n 5 -r 3
# Benchmark: 3 runs x 5 loops of test_dataloader over 1280 samples, batch_size=10.
test_dataloader(DataLoaderJAX, samples=1280, batch_size=10)

1.64 s ± 35.8 ms per loop (mean ± std. dev. of 3 runs, 5 loops each)


In [None]:
i  # scratch check: loop index left in the namespace by an earlier cell

31

In [None]:
divmod(1000, 32)[0]  # scratch arithmetic: quotient of 1000 / 32

31