From dbcdfaa29401a08bce1c4fd66cb174933104d911 Mon Sep 17 00:00:00 2001
From: rohitgr7
Date: Wed, 8 Dec 2021 00:40:43 +0530
Subject: [PATCH 1/4] update value in warning

---
 pytorch_lightning/trainer/data_loading.py | 12 ++++++------
 tests/trainer/test_data_loading.py        | 17 ++++++++++++++++-
 2 files changed, 22 insertions(+), 7 deletions(-)

diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index b4cb610fa138b..fc7ab160b34e7 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -170,7 +170,7 @@ def _resolve_sampler(self, dataloader: DataLoader, shuffle: bool, mode: Optional
         if self._requires_distributed_sampler(dataloader):
             if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
                 raise MisconfigurationException(
-                    "You seem to have configured a sampler in your DataLoader. This will be replaced "
+                    "You seem to have configured a sampler in your DataLoader. This will be replaced"
                     " by `DistributedSampler` since `replace_sampler_ddp` is True and you are using"
                     " distributed training. Either remove the sampler from your DataLoader or set"
                     " `replace_sampler_ddp=False` if you want to use your custom sampler."
@@ -323,7 +323,7 @@ def _reset_eval_dataloader(
         module = model or self.lightning_module or self.datamodule
         if len(dataloaders) != 0:
             for i, dataloader in enumerate(dataloaders):
-                num_batches = (
+                orig_num_batches = (
                     len(dataloader)
                     if has_len_all_ranks(dataloader, self.training_type_plugin, module)
                     else float("inf")
@@ -335,9 +335,9 @@ def _reset_eval_dataloader(
 
                 # limit num batches either as a percent or num steps
                 if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0:
-                    num_batches = min(num_batches, int(limit_eval_batches))
-                elif num_batches != float("inf"):
-                    num_batches = int(num_batches * limit_eval_batches)
+                    num_batches = min(orig_num_batches, int(limit_eval_batches))
+                elif orig_num_batches != float("inf"):
+                    num_batches = int(orig_num_batches * limit_eval_batches)
                 elif limit_eval_batches != 1.0:
                     raise MisconfigurationException(
                         f"When using an IterableDataset for `limit_{mode}_batches`,"
@@ -349,7 +349,7 @@ def _reset_eval_dataloader(
                     min_pct = 1.0 / len(dataloader)
                     raise MisconfigurationException(
                         f"you requested to check {limit_eval_batches} of the `{mode.dataloader_prefix}_dataloader` but"
-                        f" {limit_eval_batches}*{num_batches} < 1. Please increase the"
+                        f" {limit_eval_batches}*{orig_num_batches} < 1. Please increase the"
                         f" `limit_{mode.dataloader_prefix}_batches` flag. Try at least"
                         f" `limit_{mode.dataloader_prefix}_batches={min_pct}`"
                     )
diff --git a/tests/trainer/test_data_loading.py b/tests/trainer/test_data_loading.py
index e191f681d209d..cddda4b4e27f0 100644
--- a/tests/trainer/test_data_loading.py
+++ b/tests/trainer/test_data_loading.py
@@ -19,6 +19,7 @@
 from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, Sampler, SequentialSampler
 
 from pytorch_lightning import Trainer
+from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.utilities.data import _update_dataloader
 from pytorch_lightning.utilities.enums import _StrategyType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -265,7 +266,7 @@ class CustomSampler(Sampler):
 
     # Should raise an error if existing sampler is being replaced
     dataloader = CustomDataLoader(dataset, sampler=CustomSampler(dataset))
-    with pytest.raises(MisconfigurationException, match="will be replaced  by `DistributedSampler`"):
+    with pytest.raises(MisconfigurationException, match="will be replaced by `DistributedSampler`"):
         trainer.prepare_dataloader(dataloader, shuffle=True)
 
 
@@ -334,3 +335,17 @@ def test_pre_made_batches():
     loader = DataLoader(RandomDataset(32, 10), batch_size=None)
     trainer = Trainer(fast_dev_run=1)
     trainer.predict(LoaderTestModel(), loader)
+
+
+def test_error_raised_with_float_limited_eval_batches():
+    """Test that an error is raised if there are not enough batches when passed with float value of
+    limit_eval_batches."""
+    model = BoringModel()
+    dl_size = len(model.val_dataloader())
+    limit_val_batches = 1 / (dl_size + 2)
+    trainer = Trainer(limit_val_batches=limit_val_batches)
+    trainer._data_connector.attach_data(model)
+    with pytest.raises(
+        MisconfigurationException, match=fr"{limit_val_batches}\*{dl_size} < 1. Please increase the `limit_val_batches`"
+    ):
+        trainer._reset_eval_dataloader(RunningStage.VALIDATING, model)

From 9a0bd4211a5943150b79592e196f16a7368336a2 Mon Sep 17 00:00:00 2001
From: rohitgr7
Date: Wed, 8 Dec 2021 00:43:30 +0530
Subject: [PATCH 2/4] add space

---
 pytorch_lightning/trainer/data_loading.py | 2 +-
 tests/trainer/test_data_loading.py        | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index fc7ab160b34e7..7db3587cb4fd1 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -349,7 +349,7 @@ def _reset_eval_dataloader(
                     min_pct = 1.0 / len(dataloader)
                     raise MisconfigurationException(
                         f"you requested to check {limit_eval_batches} of the `{mode.dataloader_prefix}_dataloader` but"
-                        f" {limit_eval_batches}*{orig_num_batches} < 1. Please increase the"
+                        f" {limit_eval_batches} * {orig_num_batches} < 1. Please increase the"
                         f" `limit_{mode.dataloader_prefix}_batches` flag. Try at least"
                         f" `limit_{mode.dataloader_prefix}_batches={min_pct}`"
                     )
diff --git a/tests/trainer/test_data_loading.py b/tests/trainer/test_data_loading.py
index cddda4b4e27f0..72cbf96fa807d 100644
--- a/tests/trainer/test_data_loading.py
+++ b/tests/trainer/test_data_loading.py
@@ -346,6 +346,7 @@ def test_error_raised_with_float_limited_eval_batches():
     trainer = Trainer(limit_val_batches=limit_val_batches)
     trainer._data_connector.attach_data(model)
     with pytest.raises(
-        MisconfigurationException, match=fr"{limit_val_batches}\*{dl_size} < 1. Please increase the `limit_val_batches`"
+        MisconfigurationException,
+        match=fr"{limit_val_batches} \* {dl_size} < 1. Please increase the `limit_val_batches`",
     ):
         trainer._reset_eval_dataloader(RunningStage.VALIDATING, model)

From e181e24f6b9133acfe0a73d962af8f64ee331ea0 Mon Sep 17 00:00:00 2001
From: rohitgr7
Date: Wed, 8 Dec 2021 01:04:20 +0530
Subject: [PATCH 3/4] keep as reference

---
 pytorch_lightning/trainer/data_loading.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index 7db3587cb4fd1..35c4c15a535ea 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -323,7 +323,7 @@ def _reset_eval_dataloader(
         module = model or self.lightning_module or self.datamodule
         if len(dataloaders) != 0:
             for i, dataloader in enumerate(dataloaders):
-                orig_num_batches = (
+                orig_num_batches = num_batches = (
                     len(dataloader)
                     if has_len_all_ranks(dataloader, self.training_type_plugin, module)
                     else float("inf")
@@ -335,9 +335,9 @@ def _reset_eval_dataloader(
 
                 # limit num batches either as a percent or num steps
                 if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0:
-                    num_batches = min(orig_num_batches, int(limit_eval_batches))
-                elif orig_num_batches != float("inf"):
-                    num_batches = int(orig_num_batches * limit_eval_batches)
+                    num_batches = min(num_batches, int(limit_eval_batches))
+                elif num_batches != float("inf"):
+                    num_batches = int(num_batches * limit_eval_batches)
                 elif limit_eval_batches != 1.0:
                     raise MisconfigurationException(
                         f"When using an IterableDataset for `limit_{mode}_batches`,"

From 9d9d467bff023c282a5e5b164ded0b397f849a03 Mon Sep 17 00:00:00 2001
From: Rohit Gupta
Date: Tue, 14 Dec 2021 22:55:05 +0530
Subject: [PATCH 4/4] bad merge

---
 tests/trainer/test_data_loading.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/trainer/test_data_loading.py b/tests/trainer/test_data_loading.py
index 8c0f511e1d5d8..95b9f9061b366 100644
--- a/tests/trainer/test_data_loading.py
+++ b/tests/trainer/test_data_loading.py
@@ -351,6 +351,7 @@ def test_error_raised_with_float_limited_eval_batches():
         MisconfigurationException,
         match=fr"{limit_val_batches} \* {dl_size} < 1. Please increase the `limit_val_batches`",
     ):
+        trainer._reset_eval_dataloader(RunningStage.VALIDATING, model)
 
 
 @pytest.mark.parametrize(