In [None]:
#| default_exp tests

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *
import warnings
# Ignore every FutureWarning raised from this point on in the session.
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
#| export
from __future__ import print_function, division, annotations
from jax_dataloader.imports import *
from jax_dataloader.datasets import ArrayDataset

In [None]:
#| exporti
def get_batch(batch):
    """Return the `(feats, labels)` pair held by a loader batch.

    A `dict` batch is unpacked through its `'feats'` and `'labels'` keys;
    any other batch is handed back unchanged, so the caller is expected to
    unpack it as a two-element sequence.

    The `#| exporti` directive above is required: the exported test helpers
    in this notebook call `get_batch`, so it has to be written to the
    generated module as well, otherwise they fail with `NameError` there.
    """
    if isinstance(batch, dict):
        return batch['feats'], batch['labels']
    return batch

In [None]:
#| exporti
def test_no_shuffle(cls, ds, batch_size: int, feats, labels):
    """Check that an unshuffled loader reproduces the data in its given order.

    Builds `cls(ds, batch_size=batch_size, shuffle=False)` and iterates it
    twice. On each pass the batches are concatenated and must be equal to
    `feats` and `labels` respectively.
    """
    loader = cls(ds, batch_size=batch_size, shuffle=False)
    # Two passes: the same loader object must give the same result when re-iterated.
    for _epoch in range(2):
        feat_batches, label_batches = [], []
        for feat_batch, label_batch in map(get_batch, loader):
            feat_batches.append(feat_batch)
            label_batches.append(label_batch)
        epoch_feats = jnp.concatenate(feat_batches)
        epoch_labels = jnp.concatenate(label_batches)
        assert jnp.array_equal(epoch_feats, feats)
        assert jnp.array_equal(epoch_labels, labels)

In [None]:
#| exporti
def test_no_shuffle_drop_last(cls, ds, batch_size: int, feats, labels):
    """Check an unshuffled loader with `drop_last=True` over two passes.

    Builds `cls(ds, batch_size=batch_size, shuffle=False, drop_last=True)`
    and, on each pass, asserts that

    * exactly `len(feats) // batch_size` batches are produced, and
    * the concatenated batches equal the leading
      `(len(feats) // batch_size) * batch_size` rows of `feats` / `labels`.

    The expected sizes are computed from the data rather than from the
    number of batches the loader happened to yield. Deriving them from the
    loader's own output would let a loader that ignores `drop_last` pass:
    its extra partial batch inflates the slice bound past `len(feats)`, the
    slice becomes the whole array, and the comparison still succeeds.
    """
    dl = cls(ds, batch_size=batch_size, shuffle=False, drop_last=True)
    n_batches = len(feats) // batch_size
    last_idx = n_batches * batch_size
    for _ in range(2):
        X_list, Y_list = [], []
        for batch in dl:
            x, y = get_batch(batch)
            X_list.append(x)
            Y_list.append(y)
        assert len(X_list) == n_batches, \
            f"len(X_list)={len(X_list)}, expected n_batches={n_batches}"
        if n_batches == 0:
            # Nothing was (correctly) yielded; jnp.concatenate rejects an empty list.
            continue
        _X, _Y = map(jnp.concatenate, (X_list, Y_list))
        assert jnp.array_equal(_X, feats[: last_idx])
        assert jnp.array_equal(_Y, labels[: last_idx])

In [None]:
#| exporti
def test_shuffle(cls, ds, batch_size: int, feats, labels):
    """Check a shuffling loader (`shuffle=True, drop_last=False`) over two passes.

    For every batch, the first feature column must equal the labels, i.e.
    rows and labels stay paired. For every pass, the concatenated output must

    * differ from `feats` / `labels` in their given order,
    * have the same total sum as `feats`, and
    * differ from the output of the previous pass (the first pass is
      compared against empty arrays).
    """
    loader = cls(ds, batch_size=batch_size, shuffle=True, drop_last=False)
    prev_feats = jnp.array([])
    prev_labels = jnp.array([])
    for _epoch in range(2):
        feat_batches, label_batches = [], []
        for feat_batch, label_batch in map(get_batch, loader):
            assert jnp.array_equal(feat_batch[:, :1], label_batch)
            feat_batches.append(feat_batch)
            label_batches.append(label_batch)
        epoch_feats = jnp.concatenate(feat_batches)
        epoch_labels = jnp.concatenate(label_batches)
        # The order must have changed relative to the source data ...
        assert not jnp.array_equal(epoch_feats, feats)
        assert not jnp.array_equal(epoch_labels, labels)
        # ... while the overall content still sums to the same value.
        assert jnp.sum(epoch_feats) == jnp.sum(feats), \
            f"jnp.sum(_X)={jnp.sum(epoch_feats)}, jnp.sum(feats)={jnp.sum(feats)}"
        # Each pass must produce a different order than the one before it.
        assert not jnp.array_equal(epoch_feats, prev_feats)
        assert not jnp.array_equal(epoch_labels, prev_labels)
        prev_feats = epoch_feats
        prev_labels = epoch_labels

In [None]:
#| exporti
def test_shuffle_drop_last(cls, ds, batch_size: int, feats, labels):
    """Check a shuffling loader with `drop_last=True` over two passes.

    For every batch, the first feature column must equal the labels. For
    every pass, the concatenated output must differ from `feats` / `labels`
    and its length must be the number of batches times `batch_size`.
    """
    loader = cls(ds, batch_size=batch_size, shuffle=True, drop_last=True)
    for _epoch in range(2):
        feat_batches, label_batches = [], []
        for feat_batch, label_batch in map(get_batch, loader):
            assert jnp.array_equal(feat_batch[:, :1], label_batch)
            feat_batches.append(feat_batch)
            label_batches.append(label_batch)
        epoch_feats = jnp.concatenate(feat_batches)
        epoch_labels = jnp.concatenate(label_batches)
        assert not jnp.array_equal(epoch_feats, feats)
        assert not jnp.array_equal(epoch_labels, labels)
        # Only full batches were yielded if the row count is an exact multiple.
        assert len(epoch_feats) == len(feat_batches) * batch_size

In [None]:
#| export
def test_dataloader(cls, ds_type='jax', samples=1000, batch_size=12):
    feats = np.arange(samples).repeat(10).reshape(samples, 10)
    labels = np.arange(samples).reshape(samples, 1)

    if ds_type == 'jax':
        ds = ArrayDataset(feats, labels)
    elif ds_type == 'torch':
        ds = torch.utils.data.TensorDataset(
            torch.from_numpy(feats), torch.from_numpy(labels))
    elif ds_type == 'tf':
        ds = tf.data.Dataset.from_tensor_slices((feats, labels))
    elif ds_type == "hf":
        ds = hf_datasets.Dataset.from_dict({"feats": feats, "labels": labels})
    else:
        raise ValueError(f"Unknown ds_type: {ds_type}")
    
    test_no_shuffle(cls, ds, batch_size, feats, labels)
    test_no_shuffle_drop_last(cls, ds, batch_size, feats, labels)
    test_shuffle(cls, ds, batch_size, feats, labels)
    test_shuffle_drop_last(cls, ds, batch_size, feats, labels)

In [None]:
from jax_dataloader.loaders import DataLoaderJAX

In [None]:
# Run the loader checks on DataLoaderJAX with the 'jax' dataset type (an ArrayDataset).
test_dataloader(DataLoaderJAX, ds_type='jax')

In [None]:
# def test_dataloader(dataloader_cls, samples=1000, batch_size=12):
#     feats = jnp.arange(samples).repeat(10).reshape(samples, 10)
#     labels = jnp.arange(samples).reshape(samples, 1)
#     ds = ArrayDataset(feats, labels)
#     # N % batchsize != 0
#     dl = dataloader_cls(ds, batch_size=batch_size, shuffle=False)
#     for _ in range(2):
#         X_list, Y_list = [], []
#         for x, y in dl:
#             X_list.append(x)
#             Y_list.append(y)
#         _X, _Y = map(jnp.concatenate, (X_list, Y_list))
#         assert jnp.array_equal(_X, feats)
#         assert jnp.array_equal(_Y, labels)

#     dl = dataloader_cls(ds, batch_size=batch_size, shuffle=False, drop_last=True)
#     for _ in range(2):
#         X_list, Y_list = [], []
#         for x, y in dl:
#             X_list.append(x)
#             Y_list.append(y)
#         _X, _Y = map(jnp.concatenate, (X_list, Y_list))
#         last_idx = len(X_list) * batch_size
#         assert jnp.array_equal(_X, feats[: last_idx])
#         assert jnp.array_equal(_Y, labels[: last_idx])


#     dl_shuffle = dataloader_cls(ds, batch_size=batch_size, shuffle=True, drop_last=False)
#     last_X, last_Y = jnp.array([]), jnp.array([])
#     for _ in range(2):
#         X_list, Y_list = [], []
#         for x, y in dl_shuffle:
#             assert jnp.array_equal(x[:, :1], y)
#             X_list.append(x)
#             Y_list.append(y)
#         _X, _Y = map(jnp.concatenate, (X_list, Y_list))
#         assert not jnp.array_equal(_X, feats)
#         assert not jnp.array_equal(_Y, labels)
#         assert jnp.sum(_X) == jnp.sum(feats), \
#             f"jnp.sum(_X)={jnp.sum(_X)}, jnp.sum(feats)={jnp.sum(feats)}"
#         assert not jnp.array_equal(_X, last_X)
#         assert not jnp.array_equal(_Y, last_Y)
#         last_X, last_Y = _X, _Y


#     dl_shuffle = dataloader_cls(ds, batch_size=batch_size, shuffle=True, drop_last=True)
#     for _ in range(2):
#         X_list, Y_list = [], []
#         for x, y in dl_shuffle:
#             assert jnp.array_equal(x[:, :1], y)
#             X_list.append(x)
#             Y_list.append(y)
#         _X, _Y = map(jnp.concatenate, (X_list, Y_list))
#         assert not jnp.array_equal(_X, feats)
#         assert not jnp.array_equal(_Y, labels)
#         assert len(_X) == len(X_list) * batch_size