In [None]:
#| default_exp loaders.torch

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
#| export
from __future__ import print_function, division, annotations
from jax_dataloader.imports import *
from jax_dataloader.loaders import BaseDataLoader
from jax_dataloader.datasets import Dataset, ArrayDataset, JAXDataset
from jax_dataloader.utils import check_pytorch_installed
from jax_dataloader.tests import *
from jax.tree_util import tree_map


In [None]:
#| hide
from fastcore.test import *

## `PyTorch`-backed Dataloader

Use `PyTorch` to load batches: `DataLoaderPytorch` wraps `torch.utils.data.DataLoader` and returns every batch as NumPy arrays. It requires [PyTorch](https://pytorch.org/get-started/) to be installed.

In [None]:
#| export
# adapted from https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html
def _numpy_collate(batch):
  return tree_map(np.asarray, torch_data.default_collate(batch))

In [None]:
#| export
@dispatch
def to_torch_dataset(dataset: JAXDataset) -> torch_data.Dataset:
    """Wrap a `JAXDataset` in a minimal `torch.utils.data.Dataset`.

    `len` and integer indexing are forwarded unchanged to the wrapped dataset.
    """
    class DatasetPytorch(torch_data.Dataset):
        def __init__(self, dataset: Dataset): self.dataset = dataset
        def __len__(self): return len(self.dataset)
        def __getitem__(self, idx): return self.dataset[idx]
    
    return DatasetPytorch(dataset)

@dispatch
def to_torch_dataset(dataset: TorchDataset):
    """A torch dataset can be consumed by `torch.utils.data.DataLoader` directly; return it as-is."""
    return dataset

@dispatch
def to_torch_dataset(dataset: HFDataset):
    """Return a view of a Hugging Face dataset whose rows are formatted as NumPy.

    The dataloader batches samples with torch's `default_collate` (via
    `_numpy_collate`), which accepts NumPy arrays/scalars but not JAX arrays,
    so the "jax" format would fail at collation time.
    """
    return dataset.with_format("numpy")

In [None]:
#| export
class DataLoaderPytorch(BaseDataLoader):
    """Pytorch Dataloader

    Wraps `torch.utils.data.DataLoader`; each batch is returned as NumPy arrays
    (collated by `_numpy_collate`). Extra keyword arguments are forwarded to the
    torch `DataLoader`, except `sampler` and `batch_sampler`, which are ignored
    (with a warning) because sampling is controlled by `shuffle`, `batch_size`
    and `drop_last`.
    """
    
    @typecheck
    def __init__(
        self, 
        dataset: Union[JAXDataset, TorchDataset, HFDataset],
        batch_size: int = 1,  # Batch size
        shuffle: bool = False,  # If true, dataloader shuffles before sampling each batch
        drop_last: bool = False, # Drop last batch or not
        **kwargs
    ):
        import warnings  # the notebook-level import is not exported to the module

        # Fail fast with a clear message before any torch machinery is touched.
        check_pytorch_installed()
        super().__init__(dataset, batch_size, shuffle, drop_last)
        from torch.utils.data import BatchSampler, RandomSampler, SequentialSampler

        # A user-supplied sampler would conflict with the `batch_sampler` built below
        # (torch raises "got multiple values" / mutually-exclusive errors), so drop it.
        for key in ('sampler', 'batch_sampler'):
            if key in kwargs:
                warnings.warn(f"`{key}` is currently not supported. We will ignore it and use `shuffle` instead.")
                del kwargs[key]

        # Iterator backing `__next__`; created lazily on first use.
        self._iterator = None

        dataset = to_torch_dataset(dataset)
        sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
        batch_sampler = BatchSampler(sampler, batch_size=batch_size, drop_last=drop_last)

        self.dataloader = torch_data.DataLoader(
            dataset, 
            batch_sampler=batch_sampler,
            collate_fn=_numpy_collate,
            **kwargs
        )

    def __len__(self):
        """Number of batches per epoch (the length of the underlying batch sampler)."""
        return len(self.dataloader)

    def __next__(self):
        """Return the next batch from a persistent iterator over the dataloader.

        A torch `DataLoader` is iterable but not an iterator, so `next()` cannot
        be applied to it directly. Once an epoch is exhausted, `StopIteration` is
        raised and the iterator is reset so the following call starts a new epoch.
        """
        if self._iterator is None:
            self._iterator = iter(self.dataloader)
        try:
            return next(self._iterator)
        except StopIteration:
            self._iterator = None
            raise

    def __iter__(self):
        """Return a fresh iterator over one epoch of batches."""
        return iter(self.dataloader)

In [None]:
# Toy data: row i of `feats` is ten copies of i, and `labels[i] == [i]`.
samples, batch_size = 1280, 12
feats = np.repeat(np.arange(samples), 10).reshape(samples, 10)
labels = np.arange(samples).reshape(-1, 1)

# The same data exposed as a torch `TensorDataset` and as a project `ArrayDataset`.
ds_torch = torch_data.TensorDataset(*(torch.from_numpy(arr) for arr in (feats, labels)))
ds_array = ArrayDataset(feats, labels)

In [None]:
def _assert_numpy_batches(dl, epochs: int = 1):
    """Iterate `dl` for `epochs` epochs, checking every batch and the batch count.

    Each batch must be an `(x, y)` pair of NumPy arrays with matching leading
    size no larger than `batch_size`. Each epoch must yield exactly `len(dl)` batches.
    """
    for _ in range(epochs):
        n_batches = 0
        for x, y in dl:
            assert isinstance(x, np.ndarray) and isinstance(y, np.ndarray)
            assert len(x) == len(y) <= batch_size
            n_batches += 1
        assert n_batches == len(dl)

# Repeated epochs check that the loader can be re-iterated.
dl_1 = DataLoaderPytorch(ds_torch, batch_size=batch_size, shuffle=True)
_assert_numpy_batches(dl_1, epochs=10)

dl_2 = DataLoaderPytorch(ds_array, batch_size=batch_size, shuffle=True)
_assert_numpy_batches(dl_2)

In [None]:
#| hide
#| torch
# (samples, batch_size) configurations: uneven last batch, even split, and a single full batch plus remainder.
for n_samples, bs in [(20, 12), (20, 10), (11, 10)]:
    test_dataloader(DataLoaderPytorch, samples=n_samples, batch_size=bs)