moves configure ddp to each backend (#3924)
* moves configure ddp to each backend

* moves configure ddp to each backend

* moves configure ddp to each backend

* added torch manual seed in test_mean_error

* test for complicated batch structure

* test for complicated batch structure

* test for complicated batch structure

Co-authored-by: ananyahjha93 <ananya@pytorchlightning.ai>
williamFalcon and ananyahjha93 committed Oct 7, 2020
1 parent d65b037 commit 9c415d2
Showing 12 changed files with 107 additions and 75 deletions.
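In plain terms: ddp_train in each DDP-style accelerator backend now calls self.configure_ddp(...) on the backend itself instead of routing through the LightningModule hook model.configure_ddp(...), and every backend gains the identical default configure_ddp shown in the diffs below. A condensed sketch of the call-site change follows; the backend class name and the elided method body are illustrative, while the get_device_ids() and configure_ddp lines mirror the diff:

    from pytorch_lightning.accelerators.base_backend import Accelerator


    class DDPBackend(Accelerator):  # assumed class name; the same pattern repeats in all eight DDP-style backends
        def ddp_train(self, process_idx, model):
            ...  # process and device set-up elided
            device_ids = self.get_device_ids()
            # before this commit: model = model.configure_ddp(model, device_ids)
            model = self.configure_ddp(model, device_ids)  # the backend now owns the DDP wrapping
            ...  # training routine set-up and teardown elided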
1 change: 1 addition & 0 deletions docs/source/hooks.rst
@@ -25,6 +25,7 @@ Training set-up
- :meth:`~pytorch_lightning.core.datamodule.LightningDataModule.prepare_data`
- :meth:`~pytorch_lightning.core.datamodule.LightningDataModule.setup`
- :meth:`~pytorch_lightning.trainer.optimizers.TrainerOptimizersMixin.init_optimizers`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_apex`
- :meth:`~pytorch_lightning.core.datamodule.LightningDataModule.train_dataloader`
- :meth:`~pytorch_lightning.core.datamodule.LightningDataModule.test_dataloader`
- :meth:`~pytorch_lightning.core.datamodule.LightningDataModule.val_dataloader`
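The hooks page now also lists configure_apex in the training set-up sequence. For context, a minimal sketch of that LightningModule hook is shown below; the signature follows the LightningModule documentation of this release and is an assumption, since this commit only adds the hook to the list:

    import pytorch_lightning as pl


    class LitModel(pl.LightningModule):
        def configure_apex(self, amp, model, optimizers, amp_level):
            # default-style behaviour: let NVIDIA apex patch the model and optimizers in place
            model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
            return model, optimizers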
25 changes: 0 additions & 25 deletions docs/source/introduction_guide.rst
@@ -916,31 +916,6 @@ With your own
# do a custom way of backward
loss.backward(retain_graph=True)

Or if you wanted to initialize ddp in a different way than the default one

.. testcode::

def configure_ddp(self, model, device_ids):
# Lightning DDP simply routes to test_step, val_step, etc...
model = LightningDistributedDataParallel(
model,
device_ids=device_ids,
find_unused_parameters=True
)
return model

you could do your own:

.. testcode::

class LitMNIST(LightningModule):

def configure_ddp(self, model, device_ids):

model = Horovod(model)
# model = Ray(model)
return model

Every single part of training is configurable this way.
For a full list look at :ref:`LightningModule <lightning_module>`.

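With those paragraphs removed from the introduction guide, overriding configure_ddp on the LightningModule is no longer the documented customization point; after this commit the wrapping belongs to the backend, so a custom strategy would plug in there instead. A hedged sketch follows, assuming the class in ddp_backend.py is named DDPBackend; the subclass name is invented, and wiring it into the Trainer is outside the scope of this commit:

    from typing import List

    from torch.nn.parallel import DistributedDataParallel

    from pytorch_lightning.accelerators.ddp_backend import DDPBackend
    from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel


    class NoUnusedParamsDDPBackend(DDPBackend):
        """Hypothetical backend: the default wrapper, minus the unused-parameter scan."""

        def configure_ddp(
            self, model: "LightningModule", device_ids: List[int]
        ) -> DistributedDataParallel:
            # drop find_unused_parameters=True to skip the extra graph traversal each step
            return LightningDistributedDataParallel(model, device_ids=device_ids)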
14 changes: 13 additions & 1 deletion pytorch_lightning/accelerators/ddp2_backend.py
@@ -23,6 +23,9 @@
from pytorch_lightning.accelerators.base_backend import Accelerator
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from torch.nn.parallel import DistributedDataParallel
from typing import List

try:
from hydra.utils import to_absolute_path, get_original_cwd
@@ -174,7 +177,7 @@ def ddp_train(self, process_idx, mp_queue, model):
device_ids = self.get_device_ids()

# allow user to configure ddp
- model = model.configure_ddp(model, device_ids)
+ model = self.configure_ddp(model, device_ids)

# set up training routine
self.trainer.train_loop.setup_training(model)
@@ -186,6 +189,14 @@ def ddp_train(self, process_idx, mp_queue, model):
torch.cuda.empty_cache()
return results

def configure_ddp(
self, model: "LightningModule", device_ids: List[int]
) -> DistributedDataParallel:
model = LightningDistributedDataParallel(
model, device_ids=device_ids, find_unused_parameters=True
)
return model

def configure_sync_batchnorm(self, model: "LightningModule") -> "LightningModule":
"""
Add global batchnorm for a model spread across multiple GPUs and nodes.
@@ -200,4 +211,5 @@ def configure_sync_batchnorm(self, model: "LightningModule") -> "LightningModule":
LightningModule with batchnorm layers synchronized between process groups
"""
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None)

return model
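The unchanged configure_sync_batchnorm context above shows where torch.nn.SyncBatchNorm.convert_sync_batchnorm is applied when synchronized batchnorm is requested. As a usage reminder, that path is switched on from the Trainer; a minimal sketch, assuming the sync_batchnorm and distributed_backend flags of this release:

    import pytorch_lightning as pl

    # sync_batchnorm=True asks the backend to convert BatchNorm layers before DDP wrapping.
    trainer = pl.Trainer(gpus=2, distributed_backend="ddp", sync_batchnorm=True)
    # trainer.fit(model)  # model: any LightningModule containing BatchNorm layers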
14 changes: 13 additions & 1 deletion pytorch_lightning/accelerators/ddp_backend.py
@@ -30,6 +30,9 @@
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from torch.nn.parallel import DistributedDataParallel
from typing import List


try:
@@ -265,7 +268,7 @@ def ddp_train(self, process_idx, model):
device_ids = self.get_device_ids()

# allow user to configure ddp
- model = model.configure_ddp(model, device_ids)
+ model = self.configure_ddp(model, device_ids)

# set up training routine
self.barrier('ddp_setup')
@@ -279,6 +282,14 @@ def ddp_train(self, process_idx, model):

return results

def configure_ddp(
self, model: "LightningModule", device_ids: List[int]
) -> DistributedDataParallel:
model = LightningDistributedDataParallel(
model, device_ids=device_ids, find_unused_parameters=True
)
return model

def configure_sync_batchnorm(self, model: "LightningModule") -> "LightningModule":
"""
Add global batchnorm for a model spread across multiple GPUs and nodes.
@@ -293,4 +304,5 @@ def configure_sync_batchnorm(self, model: "LightningModule") -> "LightningModule":
LightningModule with batchnorm layers synchronized between process groups
"""
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None)

return model
14 changes: 13 additions & 1 deletion pytorch_lightning/accelerators/ddp_cpu_slurm_backend.py
@@ -22,6 +22,9 @@
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from torch.nn.parallel import DistributedDataParallel
from typing import List


try:
@@ -159,7 +162,7 @@ def ddp_train(self, process_idx, model):
device_ids = self.get_device_ids()

# allow user to configure ddp
- model = model.configure_ddp(model, device_ids)
+ model = self.configure_ddp(model, device_ids)

# set up training routine
self.trainer.train_loop.setup_training(model)
@@ -172,6 +175,14 @@ def ddp_train(self, process_idx, model):

return results

def configure_ddp(
self, model: "LightningModule", device_ids: List[int]
) -> DistributedDataParallel:
model = LightningDistributedDataParallel(
model, device_ids=device_ids, find_unused_parameters=True
)
return model

def configure_sync_batchnorm(self, model: "LightningModule") -> "LightningModule":
"""
Add global batchnorm for a model spread across multiple GPUs and nodes.
@@ -186,4 +197,5 @@ def configure_sync_batchnorm(self, model: "LightningModule") -> "LightningModule":
LightningModule with batchnorm layers synchronized between process groups
"""
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None)

return model
14 changes: 13 additions & 1 deletion pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py
@@ -26,6 +26,9 @@
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.distributed import find_free_network_port
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from torch.nn.parallel import DistributedDataParallel
from typing import List

try:
from hydra.core.hydra_config import HydraConfig
@@ -127,7 +130,7 @@ def ddp_train(self, process_idx, mp_queue, model):
device_ids = self.get_device_ids()

# allow user to configure ddp
- model = model.configure_ddp(model, device_ids)
+ model = self.configure_ddp(model, device_ids)

# set up training routine
self.trainer.train_loop.setup_training(model)
@@ -205,6 +208,14 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
mp_queue.put(best_model_path)
mp_queue.put(results)

def configure_ddp(
self, model: "LightningModule", device_ids: List[int]
) -> DistributedDataParallel:
model = LightningDistributedDataParallel(
model, device_ids=device_ids, find_unused_parameters=True
)
return model

def configure_sync_batchnorm(self, model: "LightningModule") -> "LightningModule":
"""
Add global batchnorm for a model spread across multiple GPUs and nodes.
@@ -219,4 +230,5 @@ def configure_sync_batchnorm(self, model: "LightningModule") -> "LightningModule":
LightningModule with batchnorm layers synchronized between process groups
"""
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None)

return model
15 changes: 13 additions & 2 deletions pytorch_lightning/accelerators/ddp_cpu_torchelastic_backend.py
@@ -22,7 +22,9 @@
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.distributed.dist import LightningDistributed

from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from torch.nn.parallel import DistributedDataParallel
from typing import List

try:
from hydra.utils import to_absolute_path, get_original_cwd
@@ -159,7 +161,7 @@ def ddp_train(self, process_idx, model):
device_ids = self.get_device_ids()

# allow user to configure ddp
- model = model.configure_ddp(model, device_ids)
+ model = self.configure_ddp(model, device_ids)

# set up training routine
self.trainer.train_loop.setup_training(model)
@@ -172,6 +174,14 @@ def ddp_train(self, process_idx, model):

return results

def configure_ddp(
self, model: "LightningModule", device_ids: List[int]
) -> DistributedDataParallel:
model = LightningDistributedDataParallel(
model, device_ids=device_ids, find_unused_parameters=True
)
return model

def configure_sync_batchnorm(self, model: "LightningModule") -> "LightningModule":
"""
Add global batchnorm for a model spread across multiple GPUs and nodes.
@@ -186,4 +196,5 @@ def configure_sync_batchnorm(self, model: "LightningModule") -> "LightningModule":
LightningModule with batchnorm layers synchronized between process groups
"""
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None)

return model
15 changes: 13 additions & 2 deletions pytorch_lightning/accelerators/ddp_slurm_backend.py
@@ -22,7 +22,9 @@
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.distributed.dist import LightningDistributed

from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from torch.nn.parallel import DistributedDataParallel
from typing import List

try:
from hydra.utils import to_absolute_path, get_original_cwd
@@ -165,7 +167,7 @@ def ddp_train(self, process_idx, model):
device_ids = self.get_device_ids()

# allow user to configure ddp
- model = model.configure_ddp(model, device_ids)
+ model = self.configure_ddp(model, device_ids)

# set up training routine
self.trainer.train_loop.setup_training(model)
@@ -178,6 +180,14 @@ def ddp_train(self, process_idx, model):

return results

def configure_ddp(
self, model: "LightningModule", device_ids: List[int]
) -> DistributedDataParallel:
model = LightningDistributedDataParallel(
model, device_ids=device_ids, find_unused_parameters=True
)
return model

def configure_sync_batchnorm(self, model: "LightningModule") -> "LightningModule":
"""
Add global batchnorm for a model spread across multiple GPUs and nodes.
@@ -192,4 +202,5 @@ def configure_sync_batchnorm(self, model: "LightningModule") -> "LightningModule":
LightningModule with batchnorm layers synchronized between process groups
"""
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None)

return model
15 changes: 13 additions & 2 deletions pytorch_lightning/accelerators/ddp_spawn_backend.py
@@ -27,7 +27,9 @@
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.utilities.distributed import find_free_network_port

from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from torch.nn.parallel import DistributedDataParallel
from typing import List

try:
from hydra.core.hydra_config import HydraConfig
@@ -140,7 +142,7 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0):
device_ids = self.get_device_ids()

# allow user to configure ddp
- model = model.configure_ddp(model, device_ids)
+ model = self.configure_ddp(model, device_ids)

# set up training routine
self.trainer.train_loop.setup_training(model)
@@ -233,6 +235,14 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
atomic_save(model.state_dict(), last_path)
mp_queue.put(last_path)

def configure_ddp(
self, model: "LightningModule", device_ids: List[int]
) -> DistributedDataParallel:
model = LightningDistributedDataParallel(
model, device_ids=device_ids, find_unused_parameters=True
)
return model

def configure_sync_batchnorm(self, model: "LightningModule") -> "LightningModule":
"""
Add global batchnorm for a model spread across multiple GPUs and nodes.
@@ -247,4 +257,5 @@ def configure_sync_batchnorm(self, model: "LightningModule") -> "LightningModule":
LightningModule with batchnorm layers synchronized between process groups
"""
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None)

return model
14 changes: 13 additions & 1 deletion pytorch_lightning/accelerators/ddp_torchelastic_backend.py
@@ -22,6 +22,9 @@
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from torch.nn.parallel import DistributedDataParallel
from typing import List


try:
@@ -161,7 +164,7 @@ def ddp_train(self, process_idx, model):
device_ids = self.get_device_ids()

# allow user to configure ddp
- model = model.configure_ddp(model, device_ids)
+ model = self.configure_ddp(model, device_ids)

# set up training routine
self.trainer.train_loop.setup_training(model)
@@ -174,6 +177,14 @@ def ddp_train(self, process_idx, model):

return results

def configure_ddp(
self, model: "LightningModule", device_ids: List[int]
) -> DistributedDataParallel:
model = LightningDistributedDataParallel(
model, device_ids=device_ids, find_unused_parameters=True
)
return model

def configure_sync_batchnorm(self, model: "LightningModule") -> "LightningModule":
"""
Add global batchnorm for a model spread across multiple GPUs and nodes.
@@ -188,4 +199,5 @@ def configure_sync_batchnorm(self, model: "LightningModule") -> "LightningModule":
LightningModule with batchnorm layers synchronized between process groups
"""
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None)

return model
