diff --git a/merlin/dataloader/array.py b/merlin/dataloader/array.py
index 2f57186d..0e4369df 100644
--- a/merlin/dataloader/array.py
+++ b/merlin/dataloader/array.py
@@ -214,14 +214,7 @@ def array_lib(self):
         return np
 
     def __len__(self):
-        """Number of batches in the Sequence.
-
-        Note: This also resets the loader state.
-        Required because of the calls to `__getitem__`
-        from keras prior to the start of the main loop
-        through the loader.
-        """
-        ArrayLoaderBase.stop(self)
+        """Number of batches in the dataloader."""
         return ArrayLoaderBase.__len__(self)
 
     def __getitem__(self, index):
diff --git a/merlin/dataloader/tensorflow.py b/merlin/dataloader/tensorflow.py
index 30626891..4e1a6e32 100644
--- a/merlin/dataloader/tensorflow.py
+++ b/merlin/dataloader/tensorflow.py
@@ -56,6 +56,17 @@ def __init__(
             convert_col, _to_dlpack_fn=_to_dlpack_fn, _from_dlpack_fn=_from_dlpack_fn, _unsafe=True
         )
 
+    def __len__(self):
+        """Number of batches in the Sequence.
+
+        Note: This also resets the loader state.
+        Required because of the calls to `__getitem__`
+        from keras prior to the start of the main loop
+        through the loader.
+        """
+        ArrayLoader.stop(self)
+        return ArrayLoader.__len__(self)
+
     def __getitem__(self, index):
         """Gets batch at position `index`.
 
diff --git a/tests/unit/dataloader/test_torch_dataloader.py b/tests/unit/dataloader/test_torch_dataloader.py
index 246b62c5..75ed74a6 100644
--- a/tests/unit/dataloader/test_torch_dataloader.py
+++ b/tests/unit/dataloader/test_torch_dataloader.py
@@ -33,6 +33,8 @@
 # If pytorch isn't installed skip these tests. Note that the
 # torch_dataloader import needs to happen after this line
 torch = pytest.importorskip("torch")  # noqa
 
+from torch.utils.data import DataLoader, IterableDataset  # noqa
+
 import merlin.dataloader.torch as torch_dataloader  # noqa
 from merlin.dataloader.torch import (  # noqa isort:skip
@@ -40,6 +42,29 @@
 )
 
 
+def test_iterable_dataset():
+    df = pd.DataFrame({"feature": [1, 2, 3], "target": [0, 1, 0]})
+    dataset = Dataset(df)
+    dataset.schema["target"] = dataset.schema["target"].with_tags(Tags.TARGET)
+    iterable_dataset = torch_dataloader.Loader(dataset, batch_size=1)
+    assert isinstance(iterable_dataset, IterableDataset)
+
+
+def test_calling_len_during_iteration():
+    df = pd.DataFrame({"feature": [1, 2, 3], "target": [0, 1, 0]})
+    dataset = Dataset(df)
+    dataset.schema["target"] = dataset.schema["target"].with_tags(Tags.TARGET)
+    iterable_dataset = torch_dataloader.Loader(dataset, batch_size=1)
+    torch_data_loader = DataLoader(iterable_dataset)
+    batches = 0
+    for i, batch in enumerate(torch_data_loader):
+        len(torch_data_loader)
+        batches += 1
+        if i > 5:
+            break
+    assert batches == 3
+
+
 @pytest.mark.parametrize("shape", [(), (1,), (2,), (3, 4)])
 @pytest.mark.parametrize("num_cols", [1, 2])
 def test_fixed_column(shape, num_cols):