[bugfix] Add mechanism to prevent deadlock for DDP on Exception Trigger (#8167)

* add mechanism to prevent deadlock

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* resolve flake8 + update changelog

* update on comments

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

* remove space

* resolve bugs

* overwrite config

* update on comments

* update on comments

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

* update

* update test with comments

* Update pytorch_lightning/plugins/training_type/parallel.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* update on comments

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
3 people committed Jun 30, 2021
1 parent 9ead05e commit f3adf73
Showing 5 changed files with 102 additions and 5 deletions.
53 changes: 50 additions & 3 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -13,8 +13,12 @@
# limitations under the License.
import logging
import os
import shutil
import signal
import subprocess
import sys
import tempfile
import time
from time import sleep
from typing import Any, Dict, List, Optional, Union

@@ -36,7 +40,7 @@
rank_zero_warn,
)
from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only, ReduceOp, sync_ddp_if_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
from pytorch_lightning.utilities.seed import reset_seed

if _HYDRA_AVAILABLE:
@@ -82,6 +86,8 @@ def __init__(
self._ddp_comm_state = ddp_comm_state
self._ddp_comm_hook = ddp_comm_hook
self._ddp_comm_wrapper = ddp_comm_wrapper
self._pids: Optional[List[int]] = None
self._sync_dir: Optional[str] = None
self.set_world_ranks()

@property
@@ -112,7 +118,6 @@ def setup_environment(self):
self.setup_distributed()

def _call_children_scripts(self):

# bookkeeping of spawned processes
assert self.local_rank == 0
self._check_can_spawn_children()
@@ -126,6 +131,9 @@ def _call_children_scripts(self):
os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank())
os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank())

# create a temporary directory used to synchronize processes on deadlock.
os.environ["PL_DDP_SYNC_TMPDIR"] = self._sync_dir = tempfile.mkdtemp()

# when user is using hydra find the absolute path
path_lib = os.path.abspath if not _HYDRA_AVAILABLE else to_absolute_path

@@ -281,7 +289,8 @@ def pre_dispatch(self):

self.configure_ddp()

self.barrier()
# share ddp pids to all processes
self._share_information_to_prevent_deadlock()

def post_dispatch(self) -> None:
self.cluster_environment.teardown()
@@ -344,3 +353,41 @@ def register_plugins(cls, plugin_registry: Dict) -> None:
description="DDP Plugin with `find_unused_parameters` as False",
find_unused_parameters=False
)

def _share_information_to_prevent_deadlock(self):
self._share_pids()

# remove `PL_DDP_SYNC_TMPDIR` from os.environ
self._sync_dir = os.environ.pop("PL_DDP_SYNC_TMPDIR", None)

def _share_pids(self):
"""
Make all DDP processes aware of all process PIDs.
"""
self.barrier()
pids = self.all_gather(torch.tensor(os.getpid(), device=self.root_device))
pids = pids.cpu().numpy().tolist()
self._pids = pids if isinstance(pids, list) else [pids]

def reconciliate_processes(self, trace: str):
if self.world_size < 2:
return

sync_dir = self._sync_dir

# save a file locally.
torch.save(True, os.path.join(sync_dir, f"{self.global_rank}.pl"))

# sleep for a short time
time.sleep(3)

# return if all processes wrote a file in the `sync_dir`.
    # TODO (tchaton): this assumes a shared filesystem; non-shared filesystems are not yet supported and will fail here.
if len(os.listdir(sync_dir)) == self.world_size:
return

for pid in self._pids:
if pid != os.getpid():
os.kill(pid, signal.SIGKILL)
shutil.rmtree(sync_dir)
raise DeadlockDetectedException(f"DeadLock detected from rank: {self.global_rank} \n {trace}")
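
In short, the handshake in `reconciliate_processes` works like this: the failing rank drops a marker file into the shared sync directory, waits a few seconds, and if the other ranks never drop their own markers it assumes they are hung on a collective and kills them by PID. A minimal standalone sketch of that idea (hypothetical names, assuming a shared filesystem; not the plugin's actual API):

import os
import signal
import time


def reconcile_on_failure(sync_dir: str, rank: int, world_size: int, peer_pids: list) -> bool:
    """Sketch of the marker-file handshake used to detect a DDP deadlock."""
    # announce that this rank reached the failure handler
    open(os.path.join(sync_dir, f"{rank}.pl"), "w").close()

    # give the other ranks a short window to write their own markers
    time.sleep(3)

    # if every rank wrote a marker, all of them saw the failure -- no hang
    if len(os.listdir(sync_dir)) == world_size:
        return True

    # missing ranks are assumed to be blocked on a collective: kill them
    for pid in peer_pids:
        if pid != os.getpid():
            os.kill(pid, signal.SIGKILL)
    return False

In the actual plugin, the marker is written with `torch.save`, the sync directory is the `PL_DDP_SYNC_TMPDIR` created by rank 0 in `_call_children_scripts`, and the peer PIDs come from the `all_gather` performed in `_share_pids`.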
5 changes: 5 additions & 0 deletions pytorch_lightning/plugins/training_type/parallel.py
@@ -76,6 +76,11 @@ def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(num_replicas=len(self.parallel_devices), rank=self.global_rank)
return distributed_sampler_kwargs

def reconciliate_processes(self, trace: str):
"""
        Reconcile processes after a failure. No-op in this base class; overridden by plugins that support it.
"""

def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
"""Perform a all_gather on all processes """
return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)
5 changes: 5 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -13,6 +13,7 @@
# limitations under the License.
"""Trainer to automate the training."""
import logging
import traceback
import warnings
from datetime import timedelta
from itertools import count
@@ -61,6 +62,7 @@
from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.utilities import DeviceType, parsing, rank_zero_warn
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -902,6 +904,9 @@ def run_train(self) -> None:
self.state.stage = None
except BaseException:
self.state.status = TrainerStatus.INTERRUPTED
if distributed_available() and self.world_size > 1:
                # try syncing remaining processes, kill them otherwise
self.training_type_plugin.reconciliate_processes(traceback.format_exc())
# give accelerators a chance to finish
self.accelerator.on_train_end()
# reset bookkeeping
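
The trainer-side change is small: any exception escaping the training loop is intercepted, and when more than one process is running the formatted traceback is handed to the training-type plugin so the surviving ranks can be reconciled before teardown. A simplified sketch of that control flow (hypothetical helper, not the Trainer's real structure):

import traceback


def run_train_guarded(run_train, training_type_plugin, world_size: int) -> None:
    """Run the training loop and reconcile DDP processes if one rank raises."""
    try:
        run_train()
    except BaseException:
        if world_size > 1:
            # pass the traceback along so the surviving rank can embed it
            # in the DeadlockDetectedException it raises
            training_type_plugin.reconciliate_processes(traceback.format_exc())
        raise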
10 changes: 9 additions & 1 deletion pytorch_lightning/utilities/exceptions.py
@@ -14,4 +14,12 @@


class MisconfigurationException(Exception):
pass
"""
    Exception used to inform users of misuse of PyTorch Lightning.
"""


class DeadlockDetectedException(Exception):
"""
Exception used when a deadlock has been detected and processes are being killed
"""
34 changes: 33 additions & 1 deletion tests/trainer/test_trainer.py
@@ -39,7 +39,7 @@
from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything
from tests.base import EvalModelTemplate
from tests.helpers import BoringModel, RandomDataset
@@ -2079,3 +2079,35 @@ def test_module_current_fx_attributes_reset(tmpdir):
assert (
model._current_dataloader_idx is None
), f"_current_dataloader_idx not reset after test: {model._current_dataloader_idx}"


@RunIf(min_gpus=2, special=True)
def test_ddp_terminate_when_deadlock_is_detected(tmpdir):
""" Test that DDP kills the remaining processes when only one rank is throwing an exception. """

class CustomException(Exception):
pass

class TestModel(BoringModel):

def training_step(self, batch, batch_idx):
if batch_idx == 1 and self.trainer.is_global_zero:
# rank 0: raises an exception
# rank 1: continues training but will hang on the next barrier in the training loop
raise CustomException
return super().training_step(batch, batch_idx)

model = TestModel()

trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=5,
num_sanity_val_steps=0,
gpus=2,
accelerator="ddp",
)

# simulate random failure in training_step on rank 0
with pytest.raises(DeadlockDetectedException, match="CustomException"):
trainer.fit(model)
