Disable attaching samplers when using IterableDataset (#11507)
rohitgr7 authored and lexierule committed Jan 19, 2022
1 parent e95d8b1 commit f72bf31
Showing 3 changed files with 25 additions and 4 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -17,6 +17,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed `LSFEnvironment` to use `LSB_DJOB_RANKFILE` environment variable instead of `LSB_HOSTS` for determining node rank and main address ([#10825](https://github.com/PyTorchLightning/pytorch-lightning/pull/10825))
 
 
+- Disabled sampler replacement when using `IterableDataset` ([#11507](https://github.com/PyTorchLightning/pytorch-lightning/pull/11507))
+
+
 ## [1.5.8] - 2022-01-05
 
 ### Fixed
10 changes: 7 additions & 3 deletions pytorch_lightning/trainer/data_loading.py
@@ -272,9 +272,13 @@ def _get_dataloader_init_kwargs(
 
     # kwargs to re-construct the dataloader
     dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
-    dl_kwargs.update(
-        TrainerDataLoadingMixin._dataloader_init_kwargs_resolve_sampler(dataloader, sampler, mode=mode)
-    )
+    if isinstance(dl_kwargs["dataset"], IterableDataset):
+        dl_kwargs["batch_sampler"] = None
+        dl_kwargs["sampler"] = None
+    else:
+        dl_kwargs.update(
+            TrainerDataLoadingMixin._dataloader_init_kwargs_resolve_sampler(dataloader, sampler, mode=mode)
+        )
 
     required_args = {
         p.name
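For context on why the sampler keys are forced to `None` here: PyTorch's `DataLoader` itself refuses sampler options for iterable-style datasets, so there is nothing for Lightning to resolve or replace. A minimal sketch of that constraint (not part of this commit; `StreamDataset` is a hypothetical stand-in):

```python
# Sketch: DataLoader rejects samplers for iterable-style datasets.
import torch
from torch.utils.data import DataLoader, IterableDataset
from torch.utils.data.sampler import SequentialSampler


class StreamDataset(IterableDataset):
    """Hypothetical iterable-style dataset yielding ten random tensors."""

    def __iter__(self):
        for _ in range(10):
            yield torch.randn(3)


DataLoader(StreamDataset(), batch_size=4)  # fine: sampler left unspecified

try:
    DataLoader(StreamDataset(), sampler=SequentialSampler(range(10)))
except ValueError as err:
    # PyTorch raises: "DataLoader with IterableDataset: expected unspecified
    # sampler option, but got sampler=..."
    print(err)
```

Setting both `sampler` and `batch_sampler` to `None` in `dl_kwargs` therefore lets the re-created `DataLoader` fall back to PyTorch's default handling of iterable-style datasets.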
16 changes: 15 additions & 1 deletion tests/trainer/test_data_loading.py
@@ -20,11 +20,12 @@
 from torch.utils.data.sampler import BatchSampler, Sampler, SequentialSampler
 
 from pytorch_lightning import Trainer
+from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
 from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.trainer.supporters import CombinedLoader
 from pytorch_lightning.utilities.enums import DistributedType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers import BoringModel, RandomDataset
+from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset
 from tests.helpers.runif import RunIf
 
 
@@ -389,3 +390,16 @@ def test_non_sequential_sampler_warning_is_raised_for_eval_dataloader(val_dl):
     trainer._data_connector.attach_data(model, val_dataloaders=val_dl)
     with pytest.warns(UserWarning, match="recommended .* turn this off for val/test/predict"):
         trainer._reset_eval_dataloader(RunningStage.VALIDATING, model)
+
+
+@pytest.mark.parametrize("mode", [RunningStage.TRAINING, RunningStage.PREDICTING, RunningStage.TESTING])
+def test_dataloader_kwargs_replacement_with_iterable_dataset(mode):
+    """Test that DataLoader kwargs are not replaced when using an IterableDataset."""
+    dataset = RandomIterableDataset(7, 100)
+    dataloader = DataLoader(dataset, batch_size=32)
+    dl_kwargs = TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, dataloader.sampler, mode=mode)
+    assert dl_kwargs["sampler"] is None
+    assert dl_kwargs["batch_sampler"] is None
+    assert dl_kwargs["batch_size"] is dataloader.batch_size
+    assert dl_kwargs["dataset"] is dataloader.dataset
+    assert dl_kwargs["collate_fn"] is dataloader.collate_fn
