Add timeout to distributed torch tests to fail on hang#596

Merged
ipanfilo merged 4 commits into dev from ipanfilo/pytorch_tests_timeout
May 30, 2026

@ipanfilo
Collaborator

Description

Some distributed PyTorch tests hang intermittently. Sometimes the hang triggers a timeout in RCCL, but in most cases it results in an infinite wait.
This PR adds a test-level timeout and a torchrun wrapper that kills all child processes when the timeout expires.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add subprocess.run wrapper that supports timeout and graceful shutdown

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Comment thread tests/pytorch/distributed/utils.py Outdated
Comment on lines +48 to +53
raise subprocess.CalledProcessError(
    cmd,
    kwargs.get("args", None),
    output=stdout,
    stderr=stderr
)

The CalledProcessError constructor signature is CalledProcessError(returncode, cmd, output=None, stderr=None), but this call passes the wrong values in the wrong order: the cmd list goes where returncode (an int) is expected, and kwargs.get("args", None) is always None here because args was never inserted into kwargs; cmd is a positional parameter of this function, not a kwarg. The exception will still raise, but exc.returncode will be a list and exc.cmd will be None, breaking any caller that introspects the exception (and obscuring the failure in pytest tracebacks). Suggested fix:

Suggested change
- raise subprocess.CalledProcessError(
-     cmd,
-     kwargs.get("args", None),
-     output=stdout,
-     stderr=stderr
- )
+ # Handle check=True
+ if check and p.returncode != 0:
+     raise subprocess.CalledProcessError(
+         p.returncode,
+         cmd,
+         output=stdout,
+         stderr=stderr,
+     )
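
The argument-order problem described above can be reproduced in isolation. The following is a minimal sketch, not code from the PR; the `cmd` list is a hypothetical placeholder command. It relies only on the documented `subprocess.CalledProcessError(returncode, cmd, output=None, stderr=None)` signature and shows which attribute each positional argument ends up in.

```python
import subprocess

# Hypothetical command, used only to illustrate the argument order.
cmd = ["torchrun", "--nproc_per_node=2", "run_test.py"]

# Correct order: returncode first, then cmd.
ok = subprocess.CalledProcessError(1, cmd, output="out", stderr="err")

# Swapped order, mirroring the call flagged in the review:
# the list lands in .returncode and None lands in .cmd.
swapped = subprocess.CalledProcessError(cmd, None, output="out", stderr="err")

print(type(ok.returncode).__name__, ok.cmd)
print(type(swapped.returncode).__name__, swapped.cmd)
```

The constructor does not validate its arguments, so the swapped call only surfaces later, when a caller reads `.returncode` or `.cmd`.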

Comment thread tests/pytorch/distributed/utils.py Outdated
Comment on lines +33 to +44
except subprocess.TimeoutExpired:
    p.terminate()
    try:
        # Give the process time to terminate gracefully
        if capture_output:
            stdout, stderr = p.communicate(timeout=timeout)
        else:
            p.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        os.killpg(p.pid, signal.SIGKILL)
        if capture_output:
            stdout, stderr = p.communicate()

The PR description states the goal is to "kill all child processes on timeout", but p.terminate() only sends SIGTERM to p.pid (the session leader — i.e. torchrun), not to the whole process group. The actual hung workers are children of torchrun, and torchrun generally won't forward SIGTERM to them promptly. So the "graceful" stage almost never reclaims the workers; only the killpg(SIGKILL) fallback does, and that's gated behind a second full timeout wait. With timeout=120, the worst-case wall-clock before the hang is actually cleared is ~240s.

Two suggestions:

  1. Make the graceful stage match the stated intent by signaling the whole group: os.killpg(p.pid, signal.SIGTERM) instead of p.terminate().
  2. Use a short, bounded grace window (e.g. 10–30s) for the second wait/communicate rather than reusing timeout, so SIGKILL fires promptly when graceful shutdown is ignored.

Secondary concern: swallowing TimeoutExpired and returning a CompletedProcess with a negative returncode makes hang-killed runs indistinguishable from ordinary failures in the test logs. Since the whole point of this wrapper is to surface hangs, consider either re-raising TimeoutExpired (matching subprocess.run(timeout=...) semantics) or at minimum logging a clear "timed out, killed" line so CI failures are diagnosable.
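
The two suggestions and the secondary concern can be combined into one wrapper. The following is a sketch under stated assumptions, not the PR's implementation: the name `run_group_with_timeout` and the `grace` parameter are hypothetical, the code is POSIX-only, and it assumes the child is started with `start_new_session=True` so that its pid is also its process-group id.

```python
import os
import signal
import subprocess
import sys


def run_group_with_timeout(cmd, timeout, grace=10.0, **popen_kwargs):
    """Run cmd in its own session and signal the whole group on timeout.

    Hypothetical helper for illustration; POSIX only.
    """
    # The child becomes a session and process-group leader, so p.pid can be
    # passed to os.killpg() to reach the launcher and its workers together.
    p = subprocess.Popen(cmd, start_new_session=True, **popen_kwargs)
    try:
        out, err = p.communicate(timeout=timeout)
    except subprocess.TimeoutExpired:
        # Log clearly so a hang is distinguishable from an ordinary failure.
        print(f"timed out after {timeout}s, stopping process group {p.pid}",
              file=sys.stderr)
        # Graceful stage: SIGTERM to the whole group, not just the leader.
        os.killpg(p.pid, signal.SIGTERM)
        try:
            # Short, bounded grace window instead of reusing `timeout`.
            p.communicate(timeout=grace)
        except subprocess.TimeoutExpired:
            # Forced stage: SIGKILL cannot be caught or ignored.
            os.killpg(p.pid, signal.SIGKILL)
            p.communicate()
        # Re-raise to match subprocess.run(timeout=...) semantics.
        raise
    return subprocess.CompletedProcess(cmd, p.returncode, out, err)
```

With this shape, the worst-case delay before the hang is cleared is `timeout + grace` rather than twice `timeout`, and callers see `TimeoutExpired` instead of a negative return code.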

@github-actions

Review summary

Scope: 4 files, +79/-8 — adds tests/pytorch/distributed/utils.py with a run_proctree_with_timeout wrapper and switches three distributed tests (test_comm_gemm_overlap.py, test_torch_fsdp2.py, test_torch_fsdp2_fp8.py) to use it. CUDA path is preserved in two of them via 120 if IS_HIP_EXTENSION else None; test_torch_fsdp2_fp8.py (AMD-only file) applies the timeout unconditionally.

Verdict: changes are well-scoped to the stated goal, but the wrapper itself has two bugs worth fixing before merge (see inline comments on utils.py):

  • CalledProcessError is raised with constructor arguments swapped, so the exception's .returncode and .cmd carry the wrong values when check=True fails.
  • p.terminate() only signals the session leader (torchrun), not the whole process group, so the "graceful" stage doesn't accomplish the PR's stated "kill all child processes on timeout" goal — only the killpg(SIGKILL) fallback does, after a second full timeout wait. Also flagged: swallowing TimeoutExpired makes hang-killed runs visually indistinguishable from ordinary test failures.

Copyright headers: OK — utils.py (new ROCm-only) carries AMD 2026 + license; test_torch_fsdp2.py correctly adds AMD 2026 alongside the preserved NVIDIA copyright; the other two modified files have current-year AMD headers from prior PRs.

Contributor

@alextmagro alextmagro left a comment


LGTM -- One comment, would it make sense to use the utils.py already present in pytorch/ or do you want to keep distributed/utils.py separate?

@ipanfilo
Collaborator Author

> LGTM -- One comment, would it make sense to use the utils.py already present in pytorch/ or do you want to keep distributed/utils.py separate?

Because the tests are not loaded as a module, importing ../utils is problematic and requires manipulating sys.path. That is the main reason for creating a separate utils.py.
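
For context, the sys.path manipulation mentioned above would look roughly like the following. This is only an illustrative sketch; the helper name `add_parent_dir_to_sys_path` is hypothetical and nothing like it appears in the PR.

```python
import sys
from pathlib import Path


def add_parent_dir_to_sys_path(test_file):
    """Make the directory above the test's own directory importable.

    Hypothetical helper: a test under distributed/ would call this with
    __file__ before importing a module that lives one level up.
    """
    parent = str(Path(test_file).resolve().parent.parent)
    if parent not in sys.path:
        # Insert at the front so the parent directory is searched first.
        sys.path.insert(0, parent)
    return parent
```

A test file would call `add_parent_dir_to_sys_path(__file__)` before `import utils`. This mutates global interpreter state for every later import, which is the kind of workaround a separate utils.py avoids.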

Comment on lines +43 to +44
result = run_subprocess(test_cmd, 120 if IS_HIP_EXTENSION else None, env=os.environ,
check=True)
Contributor


Instead of using run_proctree_with_timeout, could we just use coreutils' timeout command?

This could perhaps even become a pytest fixture:

# in conftest.py:
import subprocess

import pytest

_DEFAULT_TIMEOUT = 120
_DEFAULT_KILL_AFTER = 30


def pytest_configure(config):
    config.addinivalue_line(
        "markers",
        "subprocess_timeout(seconds, kill_after=30): "
        "wrap subprocess.run calls with coreutils timeout to detect hangs",
    )


@pytest.fixture(autouse=True)
def subprocess_timeout(request, monkeypatch):
    marker = request.node.get_closest_marker("subprocess_timeout")
    if marker is None:
        return

    seconds = str(marker.args[0]) if marker.args else str(_DEFAULT_TIMEOUT)
    kill_after = str(marker.kwargs.get("kill_after", _DEFAULT_KILL_AFTER))

    original_run = subprocess.run

    def _run_with_timeout(cmd, *args, **kwargs):
        cmd = ["timeout", f"--kill-after={kill_after}", seconds] + list(cmd)
        result = original_run(cmd, *args, **kwargs)
        if result.returncode == 124:
            pytest.fail(f"Subprocess timed out after {seconds}s (hang detected)")
        return result

    monkeypatch.setattr(subprocess, "run", _run_with_timeout)

# for the test:
import pytest


@pytest.mark.subprocess_timeout(20, kill_after=2)
def test_xyz():
    ...  # test body omitted in the original comment

This would have a few advantages:

  • No sys.path modification, shorter code
  • No code changes inside test functions
  • Signals/kills the whole process group natively

Collaborator Author

@ipanfilo ipanfilo May 29, 2026


It is a brilliant idea. It can be done even more simply by modifying the command line right in the tests. I've modified the PR. The first CI run is at https://github.com/ROCm/TransformerEngine/actions/runs/26611894136/job/78425262492
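
The updated diff is not shown in this thread, so the following is only a sketch of what modifying the command line in a test could look like. The helper name `with_coreutils_timeout` and its defaults are hypothetical; the only external assumption is the coreutils `timeout` CLI with its `--kill-after` option, as used in the fixture proposal above.

```python
def with_coreutils_timeout(cmd, seconds, kill_after=30):
    """Prefix cmd with coreutils `timeout`; return cmd unchanged if seconds is None.

    Hypothetical helper for illustration, not code from the PR.
    """
    if seconds is None:
        # No timeout requested (e.g. on platforms where hangs are not seen).
        return list(cmd)
    return ["timeout", f"--kill-after={kill_after}", str(seconds), *cmd]


# Example with a placeholder test command.
test_cmd = ["torchrun", "--nproc_per_node=2", "run_test.py"]
print(with_coreutils_timeout(test_cmd, 120))
```

The resulting list can be passed to `subprocess.run` as before, so the test body needs no wrapper function and no extra import.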

Contributor


This looks good.

There are no timeout-related messages in the distributed-test logs at https://github.com/ROCm/TransformerEngine/actions/runs/26611894136/job/78425262492 as far as I can tell, and the distributed tests passed, so the mechanism was not exercised in that run. Not sure if you want to test this further within this PR, but either way, good to go from my side.
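
One way to exercise the mechanism outside CI is to run a deliberately hung child under `timeout` and check the exit status. This is a sketch, assuming GNU coreutils `timeout` is on PATH; the function name `timeout_fires` is hypothetical. It relies on coreutils returning exit status 124 when the command times out and stops after the first signal.

```python
import shutil
import subprocess
import sys


def timeout_fires(seconds=1, kill_after=1):
    """Return True if coreutils `timeout` stops a child that sleeps too long.

    Returns None when `timeout` is not installed. Hypothetical helper.
    """
    if shutil.which("timeout") is None:
        return None
    cmd = [
        "timeout", f"--kill-after={kill_after}", str(seconds),
        # Stand-in for a hung test: sleeps far longer than the limit.
        sys.executable, "-c", "import time; time.sleep(60)",
    ]
    result = subprocess.run(cmd)
    # 124 is the status coreutils reports when the command timed out.
    return result.returncode == 124
```

A check like this only confirms that the timeout path works for a simple child; it does not reproduce an actual RCCL hang.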

@ipanfilo ipanfilo requested a review from matthiasdiener May 29, 2026 05:22
@ipanfilo ipanfilo merged commit dfbec23 into dev May 30, 2026
22 of 25 checks passed
@ipanfilo ipanfilo deleted the ipanfilo/pytorch_tests_timeout branch May 30, 2026 01:27

Labels

ci-level 3 CI test level 3


5 participants