Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions pytorch_lightning/plugins/ddp_plugin.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from typing import List, Dict, Any

from pytorch_lightning.core.lightning import LightningModule
from typing import List
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel


class DDPPlugin(object):
"""
Plugin to link a custom ddp implementation to any arbitrary accelerator.

This plugin forwards all constructor arguments to `LightningDistributedDataParallel`,
which in turn forwards all args to `DistributedDataParallel`.

Example::

class MyDDP(DDPPlugin):
Expand All @@ -17,11 +21,16 @@ def configure_ddp(self, model, device_ids):

my_ddp = MyDDP()
trainer = Trainer(accelerator='ddp_x', plugins=[my_ddp])

"""

def configure_ddp(self, model: LightningModule, device_ids: List[int]) -> LightningDistributedDataParallel:
def __init__(self, **kwargs):
    """Capture all constructor keyword arguments for later forwarding.

    Args:
        **kwargs: Keyword arguments that ``configure_ddp`` passes verbatim
            to ``LightningDistributedDataParallel`` (which in turn forwards
            them to ``DistributedDataParallel``), e.g.
            ``find_unused_parameters`` or ``broadcast_buffers``.
    """
    # Stored (not applied) here; configure_ddp injects a default of
    # find_unused_parameters=True before unpacking these kwargs.
    self._ddp_kwargs: Dict[str, Any] = kwargs

def configure_ddp(
self, model: LightningModule, device_ids: List[int]
) -> LightningDistributedDataParallel:
"""
Pass through all customizations from constructor to `LightningDistributedDataParallel`.
Override to define a custom DDP implementation.

.. note:: Only requirement is that your DDP implementation subclasses LightningDistributedDataParallel
Expand All @@ -43,5 +52,13 @@ def configure_ddp(self, model, device_ids):
the model wrapped in LightningDistributedDataParallel

"""
model = LightningDistributedDataParallel(model, device_ids=device_ids, find_unused_parameters=True)
# if unset, default `find_unused_parameters` `True`
self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get(
"find_unused_parameters", True
)
model = LightningDistributedDataParallel(
model,
device_ids=device_ids,
**self._ddp_kwargs,
)
return model
111 changes: 81 additions & 30 deletions tests/plugins/test_ddp_plugin.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,30 @@
from pytorch_lightning.callbacks import Callback
from tests.base.boring_model import BoringModel
from pytorch_lightning import accelerators, Trainer
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
import pytest
import os
from unittest import mock

import pytest
from pytorch_lightning import Trainer, accelerators
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from tests.base.boring_model import BoringModel

@mock.patch.dict(os.environ, {
"CUDA_VISIBLE_DEVICES": "0,1",
"SLURM_NTASKS": "2",
"SLURM_JOB_NAME": "SOME_NAME",
"SLURM_NODEID": "0",
"LOCAL_RANK": "0",
"SLURM_LOCALID": "0"
})
@mock.patch('torch.cuda.device_count', return_value=2)
@pytest.mark.parametrize(['ddp_backend', 'gpus', 'num_processes'],
[('ddp_cpu', None, None), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)])
def test_ddp_choice_default_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):

@mock.patch.dict(
os.environ,
{
"CUDA_VISIBLE_DEVICES": "0,1",
"SLURM_NTASKS": "2",
"SLURM_JOB_NAME": "SOME_NAME",
"SLURM_NODEID": "0",
"LOCAL_RANK": "0",
"SLURM_LOCALID": "0",
},
)
@mock.patch("torch.cuda.device_count", return_value=2)
@pytest.mark.parametrize(
["ddp_backend", "gpus", "num_processes"],
[("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)],
)
def test_ddp_choice_default_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
class CB(Callback):
def on_fit_start(self, trainer, pl_module):
assert isinstance(trainer.accelerator_backend.ddp_plugin, DDPPlugin)
Expand All @@ -31,24 +36,29 @@ def on_fit_start(self, trainer, pl_module):
gpus=gpus,
num_processes=num_processes,
distributed_backend=ddp_backend,
callbacks=[CB()]
callbacks=[CB()],
)

with pytest.raises(SystemExit):
trainer.fit(model)


@mock.patch.dict(os.environ, {
"CUDA_VISIBLE_DEVICES": "0,1",
"SLURM_NTASKS": "2",
"SLURM_JOB_NAME": "SOME_NAME",
"SLURM_NODEID": "0",
"LOCAL_RANK": "0",
"SLURM_LOCALID": "0"
})
@mock.patch('torch.cuda.device_count', return_value=2)
@pytest.mark.parametrize(['ddp_backend', 'gpus', 'num_processes'],
[('ddp_cpu', None, None), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)])
@mock.patch.dict(
os.environ,
{
"CUDA_VISIBLE_DEVICES": "0,1",
"SLURM_NTASKS": "2",
"SLURM_JOB_NAME": "SOME_NAME",
"SLURM_NODEID": "0",
"LOCAL_RANK": "0",
"SLURM_LOCALID": "0",
},
)
@mock.patch("torch.cuda.device_count", return_value=2)
@pytest.mark.parametrize(
["ddp_backend", "gpus", "num_processes"],
[("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)],
)
def test_ddp_choice_custom_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
class MyDDP(DDPPlugin):
pass
Expand All @@ -65,7 +75,48 @@ def on_fit_start(self, trainer, pl_module):
num_processes=num_processes,
distributed_backend=ddp_backend,
plugins=[MyDDP()],
callbacks=[CB()]
callbacks=[CB()],
)

with pytest.raises(SystemExit):
trainer.fit(model)


@mock.patch.dict(
os.environ,
{
"CUDA_VISIBLE_DEVICES": "0,1",
"SLURM_NTASKS": "2",
"SLURM_JOB_NAME": "SOME_NAME",
"SLURM_NODEID": "0",
"LOCAL_RANK": "0",
"SLURM_LOCALID": "0",
},
)
@mock.patch("torch.cuda.device_count", return_value=2)
@pytest.mark.parametrize(
["ddp_backend", "gpus", "num_processes"],
[("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)],
)
def test_ddp_choice_custom_ddp_cpu_custom_args(
tmpdir, ddp_backend, gpus, num_processes
):
class MyDDP(DDPPlugin):
pass

class CB(Callback):
def on_fit_start(self, trainer, pl_module):
assert isinstance(trainer.accelerator_backend.ddp_plugin, MyDDP)
raise SystemExit()

model = BoringModel()
trainer = Trainer(
fast_dev_run=True,
gpus=gpus,
num_processes=num_processes,
distributed_backend=ddp_backend,
plugins=[MyDDP(broadcast_buffers=False, find_unused_parameters=True)],
callbacks=[CB()],
)

with pytest.raises(SystemExit):
Expand Down