Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions merlin/dataloader/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,14 +214,7 @@ def array_lib(self):
return np

def __len__(self):
"""Number of batches in the Sequence.

Note: This also resets the loader state.
Required because of the calls to `__getitem__`
from keras prior to the start of the main loop
through the loader.
"""
ArrayLoaderBase.stop(self)
"""Number of batches in the dataloader."""
return ArrayLoaderBase.__len__(self)

def __getitem__(self, index):
Expand Down
11 changes: 11 additions & 0 deletions merlin/dataloader/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,17 @@ def __init__(
convert_col, _to_dlpack_fn=_to_dlpack_fn, _from_dlpack_fn=_from_dlpack_fn, _unsafe=True
)

def __len__(self):
"""Number of batches in the Sequence.

Note: This also resets the loader state.
Required because of the calls to `__getitem__`
from keras prior to the start of the main loop
through the loader.
"""
ArrayLoader.stop(self)
return ArrayLoader.__len__(self)

def __getitem__(self, index):
"""Gets batch at position `index`.

Expand Down
25 changes: 25 additions & 0 deletions tests/unit/dataloader/test_torch_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,38 @@
# If pytorch isn't installed skip these tests. Note that the
# torch_dataloader import needs to happen after this line
torch = pytest.importorskip("torch") # noqa
from torch.utils.data import DataLoader, IterableDataset # noqa

import merlin.dataloader.torch as torch_dataloader # noqa

from merlin.dataloader.torch import ( # noqa isort:skip
Loader as torch_loader,
)


def test_iterable_dataset():
    # The torch Loader must be usable as a torch IterableDataset.
    frame = pd.DataFrame({"feature": [1, 2, 3], "target": [0, 1, 0]})
    ds = Dataset(frame)
    ds.schema["target"] = ds.schema["target"].with_tags(Tags.TARGET)
    loader = torch_dataloader.Loader(ds, batch_size=1)
    assert isinstance(loader, IterableDataset)


def test_calling_len_during_iteration():
    # Calling len() mid-iteration must not reset the underlying loader:
    # with 3 rows and batch_size=1 we expect exactly 3 batches.
    frame = pd.DataFrame({"feature": [1, 2, 3], "target": [0, 1, 0]})
    ds = Dataset(frame)
    ds.schema["target"] = ds.schema["target"].with_tags(Tags.TARGET)
    torch_data_loader = DataLoader(torch_dataloader.Loader(ds, batch_size=1))
    seen = 0
    for step, _batch in enumerate(torch_data_loader):
        len(torch_data_loader)
        seen += 1
        if step > 5:
            # Safety valve: stop if len() wrongly restarted iteration.
            break
    assert seen == 3


@pytest.mark.parametrize("shape", [(), (1,), (2,), (3, 4)])
@pytest.mark.parametrize("num_cols", [1, 2])
def test_fixed_column(shape, num_cols):
Expand Down