6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/data_loading.py
@@ -171,7 +171,7 @@ def _resolve_sampler(self, dataloader: DataLoader, shuffle: bool, mode: Optional
         if self._requires_distributed_sampler(dataloader):
             if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
                 raise MisconfigurationException(
-                    "You seem to have configured a sampler in your DataLoader. This will be replaced "
+                    "You seem to have configured a sampler in your DataLoader. This will be replaced"
                     " by `DistributedSampler` since `replace_sampler_ddp` is True and you are using"
                     " distributed training. Either remove the sampler from your DataLoader or set"
                     " `replace_sampler_ddp=False` if you want to use your custom sampler."
@@ -322,7 +322,7 @@ def _reset_eval_dataloader(
         module = model or self.lightning_module or self.datamodule
         if len(dataloaders) != 0:
             for i, dataloader in enumerate(dataloaders):
-                num_batches = (
+                orig_num_batches = num_batches = (
                     len(dataloader)
                     if has_len_all_ranks(dataloader, self.training_type_plugin, module)
                     else float("inf")
@@ -348,7 +348,7 @@ def _reset_eval_dataloader(
                     min_pct = 1.0 / len(dataloader)
                     raise MisconfigurationException(
                         f"you requested to check {limit_eval_batches} of the `{mode.dataloader_prefix}_dataloader` but"
-                        f" {limit_eval_batches}*{num_batches} < 1. Please increase the"
+                        f" {limit_eval_batches} * {orig_num_batches} < 1. Please increase the"
                         f" `limit_{mode.dataloader_prefix}_batches` flag. Try at least"
                         f" `limit_{mode.dataloader_prefix}_batches={min_pct}`"
                     )
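A note on why the message needed `orig_num_batches`: between the two hunks above, `_reset_eval_dataloader` rescales `num_batches` by the float `limit_eval_batches`, so by the time the exception is built, `num_batches` already holds the (too small) limited count. The sketch below uses illustrative numbers; the rescaling line is an assumption about code between the hunks, which this diff does not show.

```python
# Sketch with illustrative numbers; names mirror the diff above. The rescaling
# step is an assumption about code between the hunks, not shown in this diff.
limit_eval_batches = 0.25
orig_num_batches = num_batches = 2  # len(dataloader) before limiting

# Assumed rescaling step between the hunks:
num_batches = int(num_batches * limit_eval_batches)  # int(0.5) == 0 batches

if num_batches == 0 and limit_eval_batches > 0.0:
    # len(dataloader) == orig_num_batches here, so this is the smallest
    # limit value that still selects at least one batch.
    min_pct = 1.0 / orig_num_batches  # 0.5
    # Old message interpolated the rescaled count: "0.25*0 < 1" -- confusing.
    # New message reports the original length:    "0.25 * 2 < 1".
```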
17 changes: 16 additions & 1 deletion tests/trainer/test_data_loading.py
@@ -268,7 +268,7 @@ class CustomSampler(Sampler):

     # Should raise an error if existing sampler is being replaced
     dataloader = CustomDataLoader(dataset, sampler=CustomSampler(dataset))
-    with pytest.raises(MisconfigurationException, match="will be replaced  by `DistributedSampler`"):
+    with pytest.raises(MisconfigurationException, match="will be replaced by `DistributedSampler`"):
         trainer.prepare_dataloader(dataloader, shuffle=True)
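For reference, a minimal sketch of the situation this test and the message change above guard against; the `EveryOtherSampler` below is hypothetical, not part of this PR. With `replace_sampler_ddp=True` (the default) under distributed training, Lightning can only substitute a `DistributedSampler` for the default `SequentialSampler`/`RandomSampler`, so any other sampler raises the error rather than being silently dropped.

```python
# Hypothetical repro sketch, not part of this PR: a custom sampler that
# Lightning cannot safely replace with a DistributedSampler.
import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset

class EveryOtherSampler(Sampler):
    """Yields every second index -- anything beyond the default samplers."""

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(0, len(self.data_source), 2))

    def __len__(self):
        return (len(self.data_source) + 1) // 2

dataset = TensorDataset(torch.randn(8, 2))
loader = DataLoader(dataset, sampler=EveryOtherSampler(dataset))
# Passing `loader` to a Trainer configured for distributed training with
# replace_sampler_ddp=True would raise the MisconfigurationException above.
```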


@@ -339,6 +339,21 @@ def test_pre_made_batches():
     trainer.predict(LoaderTestModel(), loader)


+def test_error_raised_with_float_limited_eval_batches():
+    """Test that an error is raised if there are not enough batches when a float value of
+    `limit_eval_batches` is passed."""
+    model = BoringModel()
+    dl_size = len(model.val_dataloader())
+    limit_val_batches = 1 / (dl_size + 2)
+    trainer = Trainer(limit_val_batches=limit_val_batches)
+    trainer._data_connector.attach_data(model)
+    with pytest.raises(
+        MisconfigurationException,
+        match=fr"{limit_val_batches} \* {dl_size} < 1. Please increase the `limit_val_batches`",
+    ):
+        trainer._reset_eval_dataloader(RunningStage.VALIDATING, model)
+
+
 @pytest.mark.parametrize(
     "val_dl",
     [