# Grain Dataloader

In [None]:
#| default_exp loaders.grain

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
#| export
from __future__ import print_function, division, annotations
from jax_dataloader.imports import *
from jax_dataloader.datasets import ArrayDataset, JAXDataset
from jax_dataloader.loaders import BaseDataLoader
from jax_dataloader.utils import get_config
from jax_dataloader.tests import *
import jax_dataloader as jdl

In [None]:
#| export
class DataLoaderGrain(BaseDataLoader):
    """Data loader backed by `grain.DataLoader`.

    Wraps `dataset` with a `grain.IndexSampler` (no sharding, one pass over
    the data per iterator) and a `grain.Batch` operation, and exposes the
    `BaseDataLoader` iteration interface (`__iter__`, `__next__`, `__len__`).
    """

    # @typecheck
    def __init__(
        self, 
        dataset: Union[JAXDataset, TorchDataset, HFDataset],
        batch_size: int = 1,  # Batch size
        shuffle: bool = False,  # If true, dataloader shuffles before sampling each batch
        num_workers: int = 0, # Number of workers to use
        drop_last: bool = False, # Drop last batch or not
        **kwargs  # Accepted for interface compatibility; ignored by this backend
    ):
        # Remember sizing info so `__len__` can report the number of batches.
        self._num_records = len(dataset)
        self._batch_size = batch_size
        self._drop_last = drop_last

        # NOTE(review): the seed is fixed at construction time, so presumably
        # every pass yields the same shuffled order — confirm against grain's
        # IndexSampler semantics if per-epoch reshuffling is required.
        sampler = grain.IndexSampler(
            num_records=self._num_records,
            shuffle=shuffle,
            seed=get_config().global_seed,
            shard_options=grain.NoSharding(),
            # One pass over the data per iterator, so `for batch in loader`
            # terminates like the other loaders instead of depending on the
            # library's default epoch count.
            num_epochs=1,
        )
        operations = (grain.Batch(batch_size, drop_remainder=drop_last),)
        self.dataloader = grain.DataLoader(
            data_source=dataset,
            sampler=sampler,
            operations=operations,
            worker_count=num_workers,
        )
        # Active iterator used by `__next__`; created lazily and reset by `__iter__`.
        self._iterator = None

    def __len__(self):
        """Return the number of batches produced by one pass over the dataset."""
        if self._drop_last:
            return self._num_records // self._batch_size
        # Ceiling division: the trailing partial batch is kept.
        return -(-self._num_records // self._batch_size)

    def __next__(self):
        """Return the next batch, starting a new pass if none is active."""
        # `grain.DataLoader` is an iterable, not an iterator, so `next()` must
        # be applied to an iterator obtained from it rather than to the loader.
        if self._iterator is None:
            self._iterator = iter(self.dataloader)
        return next(self._iterator)

    def __iter__(self):
        """Start a fresh pass over the dataset and return its iterator."""
        self._iterator = iter(self.dataloader)
        return self._iterator

In [None]:
#| hide
# test_dataloader(DataLoaderGrain, samples=20, batch_size=12, test_len=False)
# test_dataloader(DataLoaderGrain, samples=20, batch_size=10, test_len=False)
# test_dataloader(DataLoaderGrain, samples=11, batch_size=10, test_len=False)
# test_dataloader(DataLoaderGrain, samples=40, batch_size=12, test_len=False)