diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index cb0e9b858cd3f..093e9d14cd985 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -180,6 +180,7 @@ def __init__(self,
         if self.logger is None:
             self.logger = TestTubeLogger(
                 save_dir=self.default_save_path,
+                version=self.slurm_job_id,
                 name='lightning_logs'
             )
 
@@ -240,6 +241,21 @@ def __init__(self,
         self.amp_level = amp_level
         self.__init_amp(use_amp)
 
+    @property
+    def slurm_job_id(self):
+        """
+        Integer value of the ``SLURM_JOB_ID`` environment variable.
+
+        :return: the job id as an int, or None when the variable is unset
+            or does not hold an integer
+        """
+        try:
+            job_id = int(os.environ['SLURM_JOB_ID'])
+        except (KeyError, ValueError):
+            # KeyError: variable not set; ValueError: value is not an integer
+            job_id = None
+        return job_id
+
     def __configure_weights_path(self, checkpoint_callback, weights_save_path):
         """
         Weight path set in this priority:
@@ -882,12 +898,26 @@ def __init_tcp_connection(self):
         :param tries:
         :return:
         """
-        # sets the appropriate port
+
+        # use slurm job id for the port number
+        # keeps ports distinct across jobs from the same grid search
+        try:
+            # use the last 4 digits of the job id as the id
+            default_port = os.environ['SLURM_JOB_ID']
+            default_port = default_port[-4:]
+
+            # all ports should be in the 10k+ range
+            default_port = int(default_port) + 10000
+
+        except (KeyError, ValueError):
+            # job id missing or not numeric: fall back to the fixed port
+            default_port = 12910
+
+        # if user gave a port number, use that one instead
         try:
             port = os.environ['MASTER_PORT']
         except Exception:
-            port = 12910
-            os.environ['MASTER_PORT'] = str(port)
+            os.environ['MASTER_PORT'] = str(default_port)
 
         # figure out the root node addr
         try: