Skip to content

Commit

Permalink
Relax restrictions on wrapping a custom batch sampler in predict (#19678)
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Mar 27, 2024
1 parent 94167d6 commit 438f29f
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- The `prepare_data()` hook in `LightningModule` and `LightningDataModule` is now subject to a barrier without timeout to avoid long-running tasks to be interrupted ([#19448](https://github.com/Lightning-AI/lightning/pull/19448))

-
- Relaxed the requirement for custom batch samplers to expose `drop_last` for prediction ([#19678](https://github.com/Lightning-AI/pytorch-lightning/pull/19678))

-

Expand Down
9 changes: 9 additions & 0 deletions src/lightning/pytorch/utilities/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
has_iterable_dataset,
sized_len,
)
from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper
from lightning.pytorch.trainer.states import RunningStage
from lightning.pytorch.utilities.exceptions import MisconfigurationException
Expand Down Expand Up @@ -301,6 +302,14 @@ def _dataloader_init_kwargs_resolve_sampler(
" or set `Trainer(use_distributed_sampler=False)`. If you choose the latter, you will be"
" responsible for handling the distributed sampling within your batch sampler."
) from ex
elif is_predicting:
rank_zero_warn(
f"You are using a custom batch sampler `{batch_sampler_cls.__qualname__}` for prediction."
" Lightning would normally set `drop_last=False` to ensure all samples are returned, but for"
" custom samplers it can't guarantee this. Make sure your sampler is configured correctly to return"
" all indices.",
category=PossibleUserWarning,
)
else:
# The sampler is not a PyTorch `BatchSampler`, we don't know how to inject a custom sampler or
# how to adjust the `drop_last` value
Expand Down
13 changes: 10 additions & 3 deletions tests/tests_pytorch/utilities/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
import torch
from lightning.fabric.utilities.data import _replace_dunder_methods
from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import RandomDataset, RandomIterableDataset
from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper
Expand Down Expand Up @@ -230,7 +231,8 @@ def __len__(self) -> int:
assert batch_sampler.drop_last == (not predicting)


def test_custom_batch_sampler():
@pytest.mark.parametrize("predicting", [True, False])
def test_custom_batch_sampler(predicting):
"""Test that a custom (non-PyTorch) batch sampler requires the user to set `use_distributed_sampler=False`."""

class CustomBatchSampler: # not inheriting from `BatchSampler`
Expand All @@ -240,8 +242,13 @@ def __iter__(self):

batch_sampler = CustomBatchSampler()
dataloader = DataLoader(range(100), batch_sampler=batch_sampler)
with pytest.raises(TypeError, match=r"can't inject a \(distributed\) sampler into your batch sampler"):
_ = _update_dataloader(dataloader, sampler=Mock())

if predicting:
with pytest.warns(PossibleUserWarning, match=r"Make sure your sampler is configured correctly to return all"):
_ = _update_dataloader(dataloader, sampler=Mock(), mode=RunningStage.PREDICTING)
else:
with pytest.raises(TypeError, match=r"can't inject a \(distributed\) sampler into your batch sampler"):
_ = _update_dataloader(dataloader, sampler=Mock(), mode=None)


def test_custom_batch_sampler_no_drop_last():
Expand Down

0 comments on commit 438f29f

Please sign in to comment.