Update DataLoader.persistent_workers warnings in ddp_spawn (#6762)
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
s-rog and carmocca committed Apr 9, 2021
1 parent 5e4dfd7 commit e35192d
Showing 3 changed files with 53 additions and 10 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -116,6 +116,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed `PyTorchProfiler` to use `torch.autograd.profiler.record_function` to record functions ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349))


- Changed warnings and recommendations for dataloaders in `ddp_spawn` ([#6762](https://github.com/PyTorchLightning/pytorch-lightning/pull/6762/))


### Deprecated

- Deprecated `TrainerTrainingTricksMixin` in favor of a separate utilities module for NaN/Inf detection for gradients and parameters ([#6834](https://github.com/PyTorchLightning/pytorch-lightning/pull/6834/))
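For context on the new `ddp_spawn` dataloader entry above: the sketch below shows the user-side configuration the updated recommendation steers toward, assuming PyTorch >= 1.7 (where `DataLoader` exposes `persistent_workers`). The dataset, batch size, and `gpus` value are placeholders rather than anything prescribed by this commit.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import Trainer

# Placeholder dataset; any map-style dataset behaves the same way here.
dataset = TensorDataset(torch.randn(64, 32))

# Keeping workers alive between epochs avoids re-spawning them each epoch,
# which is the bottleneck the updated warnings describe under ddp_spawn.
train_loader = DataLoader(
    dataset,
    batch_size=8,
    num_workers=4,
    persistent_workers=True,  # only available on PyTorch >= 1.7
)

trainer = Trainer(accelerator="ddp_spawn", gpus=2)
# trainer.fit(model, train_loader)  # `model` is a placeholder LightningModule
```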
38 changes: 28 additions & 10 deletions pytorch_lightning/trainer/data_loading.py
@@ -61,18 +61,36 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None:
         using_spawn = self.accelerator_connector.distributed_backend == "ddp_spawn"
         if is_dataloader and not on_windows:
             if dataloader.num_workers > 0 and using_spawn:
-                rank_zero_warn(
-                    'Dataloader(num_workers>0) and ddp_spawn do not mix well!'
-                    ' Your performance might suffer dramatically.'
-                    ' Please consider setting accelerator=ddp to use num_workers > 0'
-                    ' (this is a bottleneck of Python .spawn() and PyTorch'
-                )
+                # checks for the attr persistent_workers available in pytorch >= 1.7
+                if hasattr(dataloader, "persistent_workers"):
+                    if not dataloader.persistent_workers:
+                        rank_zero_warn(
+                            'num_workers>0, persistent_workers=False, and accelerator=ddp_spawn'
+                            ' may result in data loading bottlenecks.'
+                            ' Consider setting persistent_workers=True'
+                            ' (this is a limitation of Python .spawn() and PyTorch)'
+                        )
+                else:
+                    rank_zero_warn(
+                        'num_workers>0 and accelerator=ddp_spawn do not mix well'
+                        ' and may result in data loading bottlenecks.'
+                        ' Consider setting accelerator=ddp to use num_workers>0'
+                        ' (this is a limitation of Python .spawn() and PyTorch)'
+                    )
 
             elif dataloader.num_workers == 0 and using_spawn:
-                rank_zero_warn(
-                    'You are using `accelerator=ddp_spawn` with num_workers=0.'
-                    ' For much faster performance, switch to `accelerator=ddp` and set `num_workers>0`'
-                )
+                # checks for the attr persistent_workers available in pytorch >= 1.7
+                if hasattr(dataloader, "persistent_workers"):
+                    if not dataloader.persistent_workers:
+                        rank_zero_warn(
+                            'accelerator=ddp_spawn and num_workers=0 may result in data loading bottlenecks.'
+                            ' Consider setting num_workers>0 and persistent_workers=True'
+                        )
+                else:
+                    rank_zero_warn(
+                        'accelerator=ddp_spawn and num_workers=0 may result in data loading bottlenecks.'
+                        ' Consider setting accelerator=ddp and set num_workers>0'
+                    )
 
             elif dataloader.num_workers <= 2 and multiprocessing.cpu_count() > 2 and not using_spawn:
                 num_cpus = multiprocessing.cpu_count()
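A note on the feature check above: the new branches use `hasattr(dataloader, "persistent_workers")` rather than comparing PyTorch version strings, because the attribute only exists on `DataLoader` from PyTorch 1.7 onward. Below is a standalone sketch of that check; the helper names are hypothetical and not part of this commit.

```python
from torch.utils.data import DataLoader


def _supports_persistent_workers(dataloader: DataLoader) -> bool:
    # Hypothetical helper: True when this PyTorch build (>= 1.7) exposes
    # the `persistent_workers` attribute on DataLoader instances.
    return hasattr(dataloader, "persistent_workers")


def _keeps_workers_alive(dataloader: DataLoader) -> bool:
    # Hypothetical helper: True only when persistent workers are both supported
    # and enabled, i.e. the one case the updated check does not warn about
    # when num_workers > 0 under ddp_spawn.
    return getattr(dataloader, "persistent_workers", False)
```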
22 changes: 22 additions & 0 deletions tests/trainer/test_data_loading.py
@@ -102,3 +102,25 @@ def check_replace_distrubuted_sampler(tmpdir, save_preds_on_dl_idx, accelerator,
@pytest.mark.parametrize("mode", [1, 2])
def test_replace_distrubuted_sampler_custom_dataloader_custom_batch_sampler(tmpdir, mode):
    check_replace_distrubuted_sampler(tmpdir, True, "ddp", 2, 2, mode)


@pytest.mark.parametrize("num_workers", [0, 1])
def test_dataloader_warnings(num_workers):

    class TestModel(BoringModel):

        def on_train_start(self, *_) -> None:
            raise SystemExit()

    dl = DataLoader(RandomDataset(32, 64), num_workers=num_workers)
    if hasattr(dl, "persistent_workers"):
        if num_workers == 0:
            warn_str = "Consider setting num_workers>0 and persistent_workers=True"
        else:
            warn_str = "Consider setting persistent_workers=True"
    else:
        warn_str = "Consider setting accelerator=ddp"

    trainer = Trainer(accelerator="ddp_spawn")
    with pytest.warns(UserWarning, match=warn_str), pytest.raises(SystemExit):
        trainer.fit(TestModel(), dl)
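Assuming a standard development checkout, the new test can likely be run on its own with `pytest tests/trainer/test_data_loading.py -k test_dataloader_warnings`; the `SystemExit` raised in `on_train_start` appears intended to abort each parametrized run as soon as the dataloader check has fired, so no full training loop is needed.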
