
fix: distributed training hanging with submit_job#141

Merged
HennerM merged 1 commit into main from fix/distributed-training-submit on Apr 3, 2026

Conversation

Collaborator

@HennerM HennerM commented Apr 2, 2026

Distributed training was hanging when launched from the noether-train-submit-job CLI.

There were two main issues:

  1. The sbatch command we use for submit-job ran python directly instead of invoking it through srun. As a result, only one process was started rather than one process per task as intended. srun takes care of this and spawns the right number of processes.
  2. A more subtle problem was how the device is assigned to each process that takes part in the distributed communication. Previously we relied on CUDA_VISIBLE_DEVICES: if it was set, we restricted it to the one device that corresponds to the local rank. The intention was that once only one device is visible per process, it is obvious to torch and NCCL which one should be used to form the process group. However, changing the environment variable is only really safe before CUDA and NCCL are initialised, which can't always be guaranteed. Instead, we now use the cuda.set_device function to set the default device based on the local rank, and also pass the chosen device to init_process_group.
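
As a rough illustration of the first issue, the sketch below builds a batch script in which the training command is wrapped in srun. The helper name `build_sbatch_script`, the `train.py` command, and the chosen resource values are hypothetical and not taken from this PR; only the sbatch options themselves (`--nodes`, `--ntasks-per-node`, `--gpus-per-node`) are standard SLURM flags.

```python
import shlex


def build_sbatch_script(command, nodes, ntasks_per_node, gpus_per_node):
    """Return an sbatch script that launches `command` once per task via srun.

    Hypothetical helper for illustration; not the code used by submit-job.
    """
    lines = [
        "#!/bin/bash",
        f"#SBATCH --nodes={nodes}",
        f"#SBATCH --ntasks-per-node={ntasks_per_node}",
        f"#SBATCH --gpus-per-node={gpus_per_node}",
        "",
        # Without srun, the script body runs as a single process on the first
        # node. With srun, SLURM starts one process per task in the allocation.
        "srun " + shlex.join(command),
    ]
    return "\n".join(lines) + "\n"


script = build_sbatch_script(
    ["python", "train.py", "--config", "cfg.yaml"],
    nodes=2,
    ntasks_per_node=4,
    gpus_per_node=4,
)
print(script)
```

With these example values, srun would start eight processes in total (four per node), whereas a bare `python` line would start one.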

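The device assignment described in the second issue could look roughly like the following. This is a minimal sketch, not the code in managed.py: the function names are hypothetical, reading the local rank from `LOCAL_RANK` or `SLURM_LOCALID` is an assumption, and it assumes `RANK`, `WORLD_SIZE`, `MASTER_ADDR` and `MASTER_PORT` are already set for the default `env://` initialisation. The `device_id` argument of `init_process_group` requires a reasonably recent PyTorch release.

```python
import os


def local_rank_from_env(env=None):
    """Return the local rank from LOCAL_RANK or SLURM_LOCALID, defaulting to 0.

    Which variable the real code reads is an assumption made for this sketch.
    """
    env = os.environ if env is None else env
    for key in ("LOCAL_RANK", "SLURM_LOCALID"):
        if key in env:
            return int(env[key])
    return 0


def init_distributed(backend="nccl"):
    """Pick the device from the local rank and bind the process group to it."""
    # torch is imported here so the rank helper above works without PyTorch.
    import torch
    import torch.distributed as dist

    device = torch.device("cuda", local_rank_from_env())
    # Set the default CUDA device instead of rewriting CUDA_VISIBLE_DEVICES,
    # which is only safe before CUDA and NCCL have been initialised.
    torch.cuda.set_device(device)
    # Passing device_id tells the backend explicitly which GPU this rank uses.
    dist.init_process_group(backend=backend, device_id=device)
    return device
```

Each process launched by srun would call `init_distributed()` once at startup, before creating any CUDA tensors.
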
One problem remains, although it might be related to the installed CUDA version: if gpus_per_task is set instead of gpus_per_node, each process should simply get one GPU assigned, but this doesn't work. Whenever the SLURM per-process GPU restriction is used with the currently installed version of CUDA, we get an error in ncclCommInitRankConfig:

ncclUnhandledCudaError: Call to CUDA function failed.
Last error:
Cuda failure 'invalid argument

This seems to come from https://github.com/pytorch/pytorch/blob/v2.11.0/torch/csrc/distributed/c10d/NCCLUtils.cpp#L70. It is not a big problem, though, because we can always set the GPUs per node as well; we can come back to this if it becomes really necessary.

A few small changes I introduced with this fix:

  • I added the option to set the distributed timeout via an environment variable and changed the default to 2 minutes instead of the 10-minute default PyTorch uses. This is helpful for debugging this kind of issue.
  • I added an exit handler that calls destroy_process_group.
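
A minimal sketch of how these two changes might fit together is shown below. The environment variable name `DISTRIBUTED_TIMEOUT_SECONDS` and all function names are hypothetical, since the PR description does not state the real ones; `timeout` is a documented argument of `init_process_group`, and the default `env://` initialisation is assumed.

```python
import atexit
import os
from datetime import timedelta

# Hypothetical variable name; the actual name used in the PR is not stated.
TIMEOUT_ENV_VAR = "DISTRIBUTED_TIMEOUT_SECONDS"
DEFAULT_TIMEOUT = timedelta(minutes=2)


def distributed_timeout(env=None):
    """Return the timeout from the environment, or the 2-minute default."""
    env = os.environ if env is None else env
    raw = env.get(TIMEOUT_ENV_VAR)
    if raw is None:
        return DEFAULT_TIMEOUT
    return timedelta(seconds=float(raw))


def _destroy_process_group():
    """Tear down the process group at interpreter exit if one exists."""
    import torch.distributed as dist

    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def init_with_timeout(backend="nccl"):
    """Initialise the process group with the configured timeout."""
    import torch.distributed as dist

    dist.init_process_group(backend=backend, timeout=distributed_timeout())
    # Register cleanup only after initialisation has succeeded.
    atexit.register(_destroy_process_group)
```

A short timeout makes a hanging collective fail quickly, which is convenient while debugging; long-running production jobs may want a larger value set through the environment variable.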

@HennerM HennerM force-pushed the fix/distributed-training-submit branch 2 times, most recently from fcc2095 to 4cd9a0f Compare April 2, 2026 12:11
Contributor

github-actions Bot commented Apr 2, 2026

Coverage

| Tests | Skipped | Failures | Errors | Time |
|-------|---------|----------|--------|------|
| 1204 | 21 💤 | 0 ❌ | 0 🔥 | 25.758s ⏱️ |

Member

@Ndles Ndles left a comment


Thanks for digging into it! If the CUDA version becomes a problem, I think we need to raise this with Fabian and Furkan.

Comment thread src/noether/core/distributed/run/managed.py
@HennerM HennerM force-pushed the fix/distributed-training-submit branch from 4cd9a0f to 7dbcf75 Compare April 2, 2026 13:51
@HennerM HennerM force-pushed the fix/distributed-training-submit branch from 7dbcf75 to b8397af Compare April 2, 2026 20:56
@HennerM HennerM merged commit 0fbd090 into main Apr 3, 2026
9 checks passed
@HennerM HennerM deleted the fix/distributed-training-submit branch April 3, 2026 07:02
@github-actions github-actions Bot locked and limited conversation to collaborators Apr 3, 2026