diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py index 27deeeddfdb45..4c4fdc8f0d368 100644 --- a/pytorch_lightning/plugins/ddp_plugin.py +++ b/pytorch_lightning/plugins/ddp_plugin.py @@ -1,12 +1,16 @@ -from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel +from typing import List, Dict, Any + from pytorch_lightning.core.lightning import LightningModule -from typing import List +from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel class DDPPlugin(object): """ Plugin to link a custom ddp implementation to any arbitrary accelerator. + This plugin forwards all constructor arguments to `LightningDistributedDataParallel`, + which in turn forwards all args to `DistributedDataParallel`. + Example:: class MyDDP(DDPPlugin): @@ -17,11 +21,16 @@ def configure_ddp(self, model, device_ids): my_ddp = MyDDP() trainer = Trainer(accelerator='ddp_x', plugins=[my_ddp]) - """ - def configure_ddp(self, model: LightningModule, device_ids: List[int]) -> LightningDistributedDataParallel: + def __init__(self, **kwargs): + self._ddp_kwargs: Dict[str, Any] = kwargs + + def configure_ddp( + self, model: LightningModule, device_ids: List[int] + ) -> LightningDistributedDataParallel: """ + Pass through all customizations from constructor to `LightningDistributedDataParallel`. Override to define a custom DDP implementation. .. 
note:: Only requirement is that your DDP implementation subclasses LightningDistributedDataParallel @@ -43,5 +52,13 @@ def configure_ddp(self, model, device_ids): the model wrapped in LightningDistributedDataParallel """ - model = LightningDistributedDataParallel(model, device_ids=device_ids, find_unused_parameters=True) + # if unset, default `find_unused_parameters` to `True` + self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get( + "find_unused_parameters", True + ) + model = LightningDistributedDataParallel( + model, + device_ids=device_ids, + **self._ddp_kwargs, + ) return model diff --git a/tests/plugins/test_ddp_plugin.py b/tests/plugins/test_ddp_plugin.py index b190f34395522..69cd0e3beb7b4 100644 --- a/tests/plugins/test_ddp_plugin.py +++ b/tests/plugins/test_ddp_plugin.py @@ -1,25 +1,30 @@ -from pytorch_lightning.callbacks import Callback -from tests.base.boring_model import BoringModel -from pytorch_lightning import accelerators, Trainer -from pytorch_lightning.plugins.ddp_plugin import DDPPlugin -import pytest import os from unittest import mock +import pytest +from pytorch_lightning import Trainer, accelerators +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.plugins.ddp_plugin import DDPPlugin +from tests.base.boring_model import BoringModel -@mock.patch.dict(os.environ, { - "CUDA_VISIBLE_DEVICES": "0,1", - "SLURM_NTASKS": "2", - "SLURM_JOB_NAME": "SOME_NAME", - "SLURM_NODEID": "0", - "LOCAL_RANK": "0", - "SLURM_LOCALID": "0" -}) -@mock.patch('torch.cuda.device_count', return_value=2) -@pytest.mark.parametrize(['ddp_backend', 'gpus', 'num_processes'], - [('ddp_cpu', None, None), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)]) -def test_ddp_choice_default_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes): +@mock.patch.dict( + os.environ, + { + "CUDA_VISIBLE_DEVICES": "0,1", + "SLURM_NTASKS": "2", + "SLURM_JOB_NAME": "SOME_NAME", + "SLURM_NODEID": "0", + "LOCAL_RANK": "0", + "SLURM_LOCALID": "0", + }, +) 
+@mock.patch("torch.cuda.device_count", return_value=2) +@pytest.mark.parametrize( + ["ddp_backend", "gpus", "num_processes"], + [("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], +) +def test_ddp_choice_default_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes): class CB(Callback): def on_fit_start(self, trainer, pl_module): assert isinstance(trainer.accelerator_backend.ddp_plugin, DDPPlugin) @@ -31,24 +36,29 @@ def on_fit_start(self, trainer, pl_module): gpus=gpus, num_processes=num_processes, distributed_backend=ddp_backend, - callbacks=[CB()] + callbacks=[CB()], ) with pytest.raises(SystemExit): trainer.fit(model) -@mock.patch.dict(os.environ, { - "CUDA_VISIBLE_DEVICES": "0,1", - "SLURM_NTASKS": "2", - "SLURM_JOB_NAME": "SOME_NAME", - "SLURM_NODEID": "0", - "LOCAL_RANK": "0", - "SLURM_LOCALID": "0" -}) -@mock.patch('torch.cuda.device_count', return_value=2) -@pytest.mark.parametrize(['ddp_backend', 'gpus', 'num_processes'], - [('ddp_cpu', None, None), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)]) +@mock.patch.dict( + os.environ, + { + "CUDA_VISIBLE_DEVICES": "0,1", + "SLURM_NTASKS": "2", + "SLURM_JOB_NAME": "SOME_NAME", + "SLURM_NODEID": "0", + "LOCAL_RANK": "0", + "SLURM_LOCALID": "0", + }, +) +@mock.patch("torch.cuda.device_count", return_value=2) +@pytest.mark.parametrize( + ["ddp_backend", "gpus", "num_processes"], + [("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], +) def test_ddp_choice_custom_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes): class MyDDP(DDPPlugin): pass @@ -65,7 +75,48 @@ def on_fit_start(self, trainer, pl_module): num_processes=num_processes, distributed_backend=ddp_backend, plugins=[MyDDP()], - callbacks=[CB()] + callbacks=[CB()], + ) + + with pytest.raises(SystemExit): + trainer.fit(model) + + +@mock.patch.dict( + os.environ, + { + "CUDA_VISIBLE_DEVICES": "0,1", + "SLURM_NTASKS": "2", + "SLURM_JOB_NAME": "SOME_NAME", + "SLURM_NODEID": "0", + "LOCAL_RANK": "0", + 
"SLURM_LOCALID": "0", + }, +) +@mock.patch("torch.cuda.device_count", return_value=2) +@pytest.mark.parametrize( + ["ddp_backend", "gpus", "num_processes"], + [("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], +) +def test_ddp_choice_custom_ddp_cpu_custom_args( + tmpdir, ddp_backend, gpus, num_processes +): + class MyDDP(DDPPlugin): + pass + + class CB(Callback): + def on_fit_start(self, trainer, pl_module): + assert isinstance(trainer.accelerator_backend.ddp_plugin, MyDDP) + raise SystemExit() + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + gpus=gpus, + num_processes=num_processes, + distributed_backend=ddp_backend, + plugins=[MyDDP(broadcast_buffers=False, find_unused_parameters=True)], + callbacks=[CB()], ) with pytest.raises(SystemExit):