From f7200d9eb67e9f86366ac1a7201ce1ab013e46ea Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Fri, 14 Jan 2022 20:46:57 +0530 Subject: [PATCH] fix --- CHANGELOG.md | 1 - docs/source/common/test_set.rst | 4 ++-- .../trainer/connectors/data_connector.py | 6 ++++-- tests/trainer/connectors/test_data_connector.py | 12 +++++------- 4 files changed, 11 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c4f0dfb9286e1e..364db0637012f6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -75,7 +75,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a warning when using `DistributedSampler` during evaluation ([#11479](https://github.com/PyTorchLightning/pytorch-lightning/pull/11479)) - ### Changed - Raised exception in `init_dist_connection()` when torch distibuted is not available ([#10418](https://github.com/PyTorchLightning/pytorch-lightning/issues/10418)) diff --git a/docs/source/common/test_set.rst b/docs/source/common/test_set.rst index 81d1149c911158..d753c6926a005a 100644 --- a/docs/source/common/test_set.rst +++ b/docs/source/common/test_set.rst @@ -146,10 +146,10 @@ Apart from this ``.validate`` has same API as ``.test``, but would rely respecti .. automethod:: pytorch_lightning.trainer.Trainer.validate :noindex: - + .. warning:: - It is recommended to test on single device since Distributed Training such as DDP internally + It is recommended to validate on single device since Distributed Training such as DDP internally uses :class:`~torch.utils.data.distributed.DistributedSampler` which replicates some samples to make sure all devices have same batch size in case of uneven inputs. This is helpful to make sure benchmarking for research papers is done the right way. diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 8ffaff1c685d5c..7b8e4ce40b3cc1 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -396,10 +396,12 @@ def _resolve_sampler(self, dataloader: DataLoader, shuffle: bool, mode: Optional trainer_fn = self.trainer.state.fn if isinstance(sampler, DistributedSampler) and trainer_fn in (TrainerFn.VALIDATING, TrainerFn.TESTING): rank_zero_warn( - 'Using `DistributedSampler` with the dataloaders. It is recommended to' - ' use single device strategy with evaluation.' + "Using `DistributedSampler` with the dataloaders. It is recommended to" + " use single device strategy with evaluation." ) + return sampler + return dataloader.sampler @staticmethod diff --git a/tests/trainer/connectors/test_data_connector.py b/tests/trainer/connectors/test_data_connector.py index 5f347361118235..73834022984cd2 100644 --- a/tests/trainer/connectors/test_data_connector.py +++ b/tests/trainer/connectors/test_data_connector.py @@ -14,12 +14,11 @@ from unittest.mock import Mock import pytest -from pytorch_lightning.accelerators import accelerator -from pytorch_lightning.trainer.states import TrainerFn from torch.utils.data import DataLoader -from pytorch_lightning import Trainer, strategies +from pytorch_lightning import Trainer from pytorch_lightning.trainer.connectors.data_connector import _DataLoaderSource +from pytorch_lightning.trainer.states import TrainerFn from tests.helpers import BoringDataModule, BoringModel @@ -74,16 +73,15 @@ def test_eval_distributed_sampler_warning(tmpdir): """Test that a warning is raised with `DistributedSampler` is used with evaluation.""" model = BoringModel() - trainer = Trainer(strategy='ddp', devices=2, accelerator='cpu') + trainer = Trainer(strategy="ddp", devices=2, accelerator="cpu") trainer._data_connector.attach_data(model) # validation - with pytest.warns(UserWarning, match = 'It is recommended to use single device strategy'): + with pytest.warns(UserWarning, match="It is recommended to use single device strategy"): trainer.state.fn = TrainerFn.VALIDATING trainer.reset_val_dataloader(model) # testing - with pytest.warns(UserWarning, match = 'It is recommended to use single device strategy'): + with pytest.warns(UserWarning, match="It is recommended to use single device strategy"): trainer.state.fn = TrainerFn.TESTING trainer.reset_test_dataloader(model) -