diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index e036e7e90254d..4460e235b11bd 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -171,7 +171,7 @@ def _resolve_sampler(self, dataloader: DataLoader, shuffle: bool, mode: Optional
         if self._requires_distributed_sampler(dataloader):
             if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
                 raise MisconfigurationException(
-                    "You seem to have configured a sampler in your DataLoader. This will be replaced "
+                    "You seem to have configured a sampler in your DataLoader. This will be replaced"
                     " by `DistributedSampler` since `replace_sampler_ddp` is True and you are using"
                     " distributed training. Either remove the sampler from your DataLoader or set"
                     " `replace_sampler_ddp=False` if you want to use your custom sampler."
@@ -322,7 +322,7 @@ def _reset_eval_dataloader(
         module = model or self.lightning_module or self.datamodule
         if len(dataloaders) != 0:
             for i, dataloader in enumerate(dataloaders):
-                num_batches = (
+                orig_num_batches = num_batches = (
                     len(dataloader)
                     if has_len_all_ranks(dataloader, self.training_type_plugin, module)
                     else float("inf")
@@ -348,7 +348,7 @@ def _reset_eval_dataloader(
                     min_pct = 1.0 / len(dataloader)
                     raise MisconfigurationException(
                         f"you requested to check {limit_eval_batches} of the `{mode.dataloader_prefix}_dataloader` but"
-                        f" {limit_eval_batches}*{num_batches} < 1. Please increase the"
+                        f" {limit_eval_batches} * {orig_num_batches} < 1. Please increase the"
                         f" `limit_{mode.dataloader_prefix}_batches` flag. Try at least"
                         f" `limit_{mode.dataloader_prefix}_batches={min_pct}`"
                     )
diff --git a/tests/trainer/test_data_loading.py b/tests/trainer/test_data_loading.py
index 74c074aef2110..95b9f9061b366 100644
--- a/tests/trainer/test_data_loading.py
+++ b/tests/trainer/test_data_loading.py
@@ -268,7 +268,7 @@ class CustomSampler(Sampler):
 
     # Should raise an error if existing sampler is being replaced
     dataloader = CustomDataLoader(dataset, sampler=CustomSampler(dataset))
-    with pytest.raises(MisconfigurationException, match="will be replaced  by `DistributedSampler`"):
+    with pytest.raises(MisconfigurationException, match="will be replaced by `DistributedSampler`"):
         trainer.prepare_dataloader(dataloader, shuffle=True)
 
 
@@ -339,6 +339,21 @@ def test_pre_made_batches():
     trainer.predict(LoaderTestModel(), loader)
 
 
+def test_error_raised_with_float_limited_eval_batches():
+    """Test that an error is raised if there are not enough batches when passed with float value of
+    limit_eval_batches."""
+    model = BoringModel()
+    dl_size = len(model.val_dataloader())
+    limit_val_batches = 1 / (dl_size + 2)
+    trainer = Trainer(limit_val_batches=limit_val_batches)
+    trainer._data_connector.attach_data(model)
+    with pytest.raises(
+        MisconfigurationException,
+        match=fr"{limit_val_batches} \* {dl_size} < 1. Please increase the `limit_val_batches`",
+    ):
+        trainer._reset_eval_dataloader(RunningStage.VALIDATING, model)
+
+
 @pytest.mark.parametrize(
     "val_dl",
     [
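For context, a minimal standalone sketch (not part of the PR) of the arithmetic that `test_error_raised_with_float_limited_eval_batches` exercises. The `dl_size = 64` value is hypothetical, and the scaling step assumes the trainer truncates `num_batches * limit` via `int(...)`, which is what makes the product-below-one case an error:

```python
# Standalone sketch (not from the diff): why the new test picks
# limit_val_batches = 1 / (dl_size + 2). Assumes the trainer scales the
# batch count roughly as int(num_batches * limit), which truncates to 0
# when the product is below 1 -- the case the reworded error reports.
dl_size = 64                          # hypothetical number of validation batches
limit_val_batches = 1 / (dl_size + 2)

scaled = int(dl_size * limit_val_batches)
assert scaled == 0                    # 64 * (1/66) is ~0.97, truncates to zero

# The message now interpolates the original (unscaled) batch count:
min_pct = 1.0 / dl_size
print(
    f"{limit_val_batches} * {dl_size} < 1. Please increase the"
    f" `limit_val_batches` flag. Try at least `limit_val_batches={min_pct}`"
)
```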