Pid port + duplicate rank_zero logging (#2231)
* init the port using a seed that matches process id for ddp

Co-authored-by: Zhaofeng Wu <zfw7@cs.washington.edu>
williamFalcon and ZhaofengWu committed Jun 18, 2020
1 parent 15cf6a8 commit 476911d
Showing 2 changed files with 16 additions and 5 deletions.
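The heart of the port change is small enough to show in isolation. Below is a minimal, standalone sketch of the same idea, seeding a NumPy generator with the process id so that independent DDP launches on the same machine pick repeatable and almost certainly distinct MASTER_PORT values; the helper name pick_ddp_port is illustrative and not part of the Lightning API.

    import os

    import numpy as np

    def pick_ddp_port() -> int:
        # illustrative helper, not part of pytorch_lightning
        # respect an explicitly configured port
        if 'MASTER_PORT' in os.environ:
            return int(os.environ['MASTER_PORT'])
        # seed a dedicated generator with the process id so each
        # independent training process draws its own repeatable port
        rng = np.random.RandomState(os.getpid())
        port = int(rng.randint(10000, 19999, 1)[0])
        os.environ['MASTER_PORT'] = str(port)
        return port

    print(pick_ddp_port())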
10 changes: 6 additions & 4 deletions pytorch_lightning/trainer/distrib_data_parallel.py
@@ -372,14 +372,16 @@ def set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids):
     def __set_random_port(self):
         """
         When running DDP NOT managed by SLURM, the ports might collide
         :return:
         """
         try:
             default_port = os.environ['MASTER_PORT']
         except Exception:
-            import random
-            default_port = random.randint(10000, 19000)
-            os.environ['MASTER_PORT'] = str(default_port)
+            # use the process id as a seed to a generator for port only
+            pid = os.getpid()
+            rng1 = np.random.RandomState(pid)
+            default_port = rng1.randint(10000, 19999, 1)[0]
+
+            os.environ['MASTER_PORT'] = str(default_port)
 
     def spawn_ddp_children(self, model):
         self.__set_random_port()
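Note that the replacement draws from a dedicated np.random.RandomState(pid) generator rather than reseeding Python's global random module, so the chosen port is tied to the process id without touching any global RNG state the training script may depend on.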
11 changes: 10 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -32,7 +32,7 @@
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
 from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities import rank_zero_warn, parsing, rank_zero_info
+from pytorch_lightning.utilities import rank_zero_warn, parsing, rank_zero_info, rank_zero_only
 
 try:
     from apex import amp
@@ -322,6 +322,14 @@ def __init__(
         # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
         os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0)
 
+        # init the default rank if exists
+        # we need to call this here or NVIDIA flags and other messaging in init will show on all ranks
+        # this way we only show it on rank 0
+        if 'LOCAL_RANK' in os.environ:
+            rank_zero_only.rank = os.environ['LOCAL_RANK']
+        if 'SLURM_JOB_ID' in os.environ:
+            rank_zero_only.rank = os.environ['SLURM_JOB_ID']
+
         # Init callbacks
         self.prepare_data_per_node = prepare_data_per_node
         self.callbacks = callbacks or []
@@ -892,6 +900,7 @@ def fit(
             mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model,))
 
         elif self.distributed_backend == 'ddp_spawn':
+            self.__set_random_port()
             model.share_memory()
 
             # spin up peers
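For context on the trainer.py change: rank_zero_only is the decorator in pytorch_lightning.utilities that runs the wrapped function only when its rank attribute is 0, and rank_zero_info / rank_zero_warn are built on it, so assigning rank_zero_only.rank from the environment early in __init__ is what keeps the NVIDIA-flag and other startup messages from being printed once per process. A simplified sketch of that behaviour (not Lightning's actual implementation; rank_zero_only_sketch and info are illustrative names) might look like:

    import os
    from functools import wraps

    def rank_zero_only_sketch(fn):
        # simplified stand-in for pytorch_lightning.utilities.rank_zero_only:
        # call fn only when the recorded rank is 0, otherwise do nothing
        @wraps(fn)
        def wrapped(*args, **kwargs):
            if getattr(rank_zero_only_sketch, 'rank', 0) == 0:
                return fn(*args, **kwargs)
        return wrapped

    # mirror what Trainer.__init__ now does: record this process's rank
    # before any startup messages are emitted
    rank_zero_only_sketch.rank = int(os.environ.get('LOCAL_RANK', 0))

    @rank_zero_only_sketch
    def info(msg):
        print(msg)

    info('CUDA_VISIBLE_DEVICES: [0, 1]')  # printed only on rank 0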
