Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def __init__(self,
if self.logger is None:
self.logger = TestTubeLogger(
save_dir=self.default_save_path,
version=self.slurm_job_id,
name='lightning_logs'
)

Expand Down Expand Up @@ -240,6 +241,15 @@ def __init__(self,
self.amp_level = amp_level
self.__init_amp(use_amp)

@property
def slurm_job_id(self):
    """
    Job id read from the ``SLURM_JOB_ID`` environment variable.

    :return: the value of ``SLURM_JOB_ID`` converted to ``int``, or ``None``
        when the variable is not set or its value is not a valid integer.
    """
    try:
        return int(os.environ['SLURM_JOB_ID'])
    except (KeyError, ValueError):
        # KeyError: the variable is not set in this environment.
        # ValueError: the variable is set but does not parse as an integer.
        # Any other exception is unexpected and is allowed to propagate
        # instead of being silently converted to None.
        return None

def __configure_weights_path(self, checkpoint_callback, weights_save_path):
"""
Weight path set in this priority:
Expand Down Expand Up @@ -882,12 +892,25 @@ def __init_tcp_connection(self):
:param tries:
:return:
"""
# sets the appropriate port

# use slurm job id for the port number
# guarantees unique ports across jobs from same grid search
try:
# use the last 4 digits of the job id as the port offset
default_port = os.environ['SLURM_JOB_ID']
default_port = default_port[-4:]

# all ports should be in the 10k+ range
default_port = int(default_port) + 10000

except Exception as e:
default_port = 12910

# if user gave a port number, use that one instead
try:
port = os.environ['MASTER_PORT']
except Exception:
port = 12910
os.environ['MASTER_PORT'] = str(port)
os.environ['MASTER_PORT'] = str(default_port)

# figure out the root node addr
try:
Expand Down