Add timeout to distributed torch tests to fail on hang #596
    raise subprocess.CalledProcessError(
        cmd,
        kwargs.get("args", None),
        output=stdout,
        stderr=stderr
    )
The CalledProcessError constructor signature is CalledProcessError(returncode, cmd, output=None, stderr=None), but this call passes the wrong values in the wrong order: the cmd list goes where returncode (an int) is expected, and kwargs.get("args", None) is always None here because args was never inserted into kwargs — cmd is a positional parameter of this function, not a kwarg. The exception will still raise, but exc.returncode will be a list and exc.cmd will be None, breaking any caller that introspects the exception (and obscuring the failure in pytest tracebacks). Suggested fix:
    - raise subprocess.CalledProcessError(
    -     cmd,
    -     kwargs.get("args", None),
    -     output=stdout,
    -     stderr=stderr
    - )
    + # Handle check=True
    + if check and p.returncode != 0:
    +     raise subprocess.CalledProcessError(
    +         p.returncode,
    +         cmd,
    +         output=stdout,
    +         stderr=stderr,
    +     )
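To make the effect of the argument order concrete, here is a minimal sketch that builds the exception both ways and inspects its attributes. The command list and return code are made-up values for illustration; they are not taken from the PR.

```python
import subprocess

# Hypothetical values standing in for the wrapper's local state.
cmd = ["torchrun", "--nproc_per_node=2", "run_test.py"]
returncode = 1

# Argument order as in the reviewed code: cmd lands in the returncode slot
# and the missing "args" kwarg puts None in the cmd slot.
wrong = subprocess.CalledProcessError(cmd, None, output="out", stderr="err")

# Argument order matching the documented signature
# CalledProcessError(returncode, cmd, output=None, stderr=None).
right = subprocess.CalledProcessError(returncode, cmd, output="out", stderr="err")

print("wrong:", type(wrong.returncode).__name__, wrong.cmd)
print("right:", type(right.returncode).__name__, right.cmd)
```

Any caller that branches on `exc.returncode` or logs `exc.cmd` would see the list and `None` from the first form.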
    except subprocess.TimeoutExpired:
        p.terminate()
        try:
            # Give the process time to terminate gracefully
            if capture_output:
                stdout, stderr = p.communicate(timeout=timeout)
            else:
                p.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            os.killpg(p.pid, signal.SIGKILL)
            if capture_output:
                stdout, stderr = p.communicate()
The PR description states the goal is to "kill all child processes on timeout", but p.terminate() only sends SIGTERM to p.pid (the session leader — i.e. torchrun), not to the whole process group. The actual hung workers are children of torchrun, and torchrun generally won't forward SIGTERM to them promptly. So the "graceful" stage almost never reclaims the workers; only the killpg(SIGKILL) fallback does, and that's gated behind a second full timeout wait. With timeout=120, the worst-case wall-clock before the hang is actually cleared is ~240s.
Two suggestions:
- Make the graceful stage match the stated intent by signaling the whole group: os.killpg(p.pid, signal.SIGTERM) instead of p.terminate().
- Use a short, bounded grace window (e.g. 10–30s) for the second wait/communicate rather than reusing timeout, so SIGKILL fires promptly when graceful shutdown is ignored.
Secondary concern: swallowing TimeoutExpired and returning a CompletedProcess with a negative returncode makes hang-killed runs indistinguishable from ordinary failures in the test logs. Since the whole point of this wrapper is to surface hangs, consider either re-raising TimeoutExpired (matching subprocess.run(timeout=...) semantics) or at minimum logging a clear "timed out, killed" line so CI failures are diagnosable.
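The suggestions above could be combined roughly as follows. This is a sketch under stated assumptions, not the PR's wrapper: `run_group_with_timeout`, `_signal_group`, and the `grace` parameter are hypothetical names, the code is POSIX-only, and it assumes the child is started with `start_new_session=True` so that its process-group ID equals `p.pid`.

```python
import os
import signal
import subprocess
import sys


def _signal_group(pgid, sig):
    """Send sig to every process in the group, ignoring an already-empty group."""
    try:
        os.killpg(pgid, sig)
    except ProcessLookupError:
        pass


def run_group_with_timeout(cmd, timeout, grace=10.0, **popen_kwargs):
    """Run cmd in its own session; on timeout, stop the whole group and re-raise.

    Hypothetical helper for illustration (POSIX only).
    """
    # start_new_session=True makes the child a session and process-group
    # leader, so killpg(p.pid, ...) also reaches the workers it spawns.
    p = subprocess.Popen(cmd, start_new_session=True, **popen_kwargs)
    try:
        stdout, stderr = p.communicate(timeout=timeout)
    except subprocess.TimeoutExpired:
        _signal_group(p.pid, signal.SIGTERM)      # graceful stage, whole group
        try:
            p.communicate(timeout=grace)          # short, bounded grace window
        except subprocess.TimeoutExpired:
            _signal_group(p.pid, signal.SIGKILL)  # forceful stage
            p.communicate()
        raise  # re-raise the original TimeoutExpired so the hang is visible
    return subprocess.CompletedProcess(cmd, p.returncode, stdout, stderr)


if __name__ == "__main__":
    try:
        run_group_with_timeout(
            [sys.executable, "-c", "import time; time.sleep(30)"],
            timeout=1,
            grace=5,
        )
    except subprocess.TimeoutExpired as exc:
        print(f"timed out, killed: {exc.cmd}")
```

With this shape the worst-case delay before SIGKILL is `timeout + grace` rather than twice `timeout`, and a hang reaches the caller as `TimeoutExpired`, as it would with `subprocess.run(timeout=...)`.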
Review summary
Scope: 4 files, +79/-8.
Verdict: changes are well-scoped to the stated goal, but the wrapper itself has two bugs worth fixing before merge (see inline comments).
Copyright headers: OK.
alextmagro left a comment
LGTM -- One comment, would it make sense to use the utils.py already present in pytorch/ or do you want to keep distributed/utils.py separate?
Because the tests are not loaded as a module, importing ../utils is problematic and requires manipulating sys.path. That is the main reason for creating a separate utils.py.
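For context, the sys.path workaround that this reply wants to avoid would look something like the sketch below. The helper name and the assumption that the shared module sits exactly one directory above the test file are illustrative only.

```python
import sys
from pathlib import Path


def add_parent_dir_to_sys_path(test_file):
    """Prepend the directory above test_file's directory to sys.path.

    Hypothetical helper: shows the kind of path manipulation a test under
    distributed/ would need before it could import a module one level up.
    """
    parent = str(Path(test_file).resolve().parent.parent)
    if parent not in sys.path:
        sys.path.insert(0, parent)
    return parent


# In a test file this would be called as add_parent_dir_to_sys_path(__file__)
# before importing the shared module.
```

Every test file would need this call before its imports, which is the cost the separate utils.py avoids.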
    result = run_subprocess(test_cmd, 120 if IS_HIP_EXTENSION else None, env=os.environ,
                            check=True)
Instead of using run_proctree_with_timeout, could we just use coreutils' timeout command?
This could perhaps even become a pytest fixture:
    # in conftest.py:
    import subprocess

    import pytest

    _DEFAULT_TIMEOUT = 120
    _DEFAULT_KILL_AFTER = 30

    def pytest_configure(config):
        config.addinivalue_line(
            "markers",
            "subprocess_timeout(seconds, kill_after=30): "
            "wrap subprocess.run calls with coreutils timeout to detect hangs",
        )

    @pytest.fixture(autouse=True)
    def subprocess_timeout(request, monkeypatch):
        marker = request.node.get_closest_marker("subprocess_timeout")
        if marker is None:
            return
        seconds = str(marker.args[0]) if marker.args else str(_DEFAULT_TIMEOUT)
        kill_after = str(marker.kwargs.get("kill_after", _DEFAULT_KILL_AFTER))
        original_run = subprocess.run

        def _run_with_timeout(cmd, *args, **kwargs):
            cmd = ["timeout", f"--kill-after={kill_after}", seconds] + list(cmd)
            # Defer check=True so that a timeout (exit status 124) is reported
            # as a hang rather than as a generic CalledProcessError.
            check = kwargs.pop("check", False)
            result = original_run(cmd, *args, **kwargs)
            if result.returncode == 124:
                pytest.fail(f"Subprocess timed out after {seconds}s (hang detected)")
            if check:
                result.check_returncode()
            return result

        monkeypatch.setattr(subprocess, "run", _run_with_timeout)

    # for the test:
    @pytest.mark.subprocess_timeout(20, kill_after=2)
    def test_xyz():
        ...

This would have a few advantages:
- No sys.path modification, shorter code
- No code changes inside test functions
- signal/kill whole process group natively
It is a brilliant idea. It can be done even more simply by modifying the command line right in the tests. I've modified the PR. The first CI run is at https://github.com/ROCm/TransformerEngine/actions/runs/26611894136/job/78425262492
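Modifying the command line directly might look roughly like the sketch below. The helper names and default values are hypothetical and are not taken from the updated PR. The exit statuses are the ones documented for GNU coreutils timeout: 124 when the command timed out, and 128+9 = 137 when KILL had to be sent.

```python
def with_coreutils_timeout(cmd, seconds, kill_after=30):
    """Return cmd prefixed with coreutils timeout (hypothetical helper).

    timeout sends TERM after `seconds`, then KILL `kill_after` seconds
    later if the command is still running.
    """
    return ["timeout", f"--kill-after={kill_after}", str(seconds), *cmd]


def looks_like_timeout(returncode):
    """Heuristic check for a run that was ended by timeout.

    124: the command timed out and exited after TERM.
    137: KILL was delivered (128 + 9). This is ambiguous, because a command
    killed by something else, such as the OOM killer, reports the same status.
    """
    return returncode in (124, 137)


# Example: wrap a (made-up) torchrun invocation with a 120 s limit.
test_cmd = with_coreutils_timeout(["torchrun", "--nproc_per_node=2", "run_test.py"], 120)
print(test_cmd)
```

Checking only for 124 would miss runs where the workers ignored TERM and had to be killed, so it may be worth treating 137 as a probable hang as well.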
This looks good.
There are no timeout-related messages in the distributed-test logs at https://github.com/ROCm/TransformerEngine/actions/runs/26611894136/job/78425262492 as far as I can tell, and the distributed tests passed, so the mechanism was not exercised in that run. Not sure if you want to test this further within this PR, but either way, good to go from my side.
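If exercising the mechanism is wanted, one low-cost option is a check that runs a deliberately hanging child under timeout and confirms the expected exit status. The sketch below is an assumption about how that could be done, not part of the PR: the function name is hypothetical, and it uses `-k`, the short form of `--kill-after`.

```python
import shutil
import subprocess
import sys


def exercise_timeout_path(seconds=1, kill_after=1):
    """Run a child that sleeps far longer than the limit, under timeout.

    Returns the exit status reported by timeout, or None when no timeout
    binary is available on PATH. Hypothetical helper for illustration.
    """
    if shutil.which("timeout") is None:
        return None
    hang = [sys.executable, "-c", "import time; time.sleep(60)"]
    cmd = ["timeout", "-k", str(kill_after), str(seconds), *hang]
    return subprocess.run(cmd).returncode


if __name__ == "__main__":
    # A status of 124 means timeout ended the sleeping child after TERM.
    print(exercise_timeout_path())
```

This only shows that the timeout wrapper fires and reports 124. It does not show that a hung torchrun worker tree is cleaned up, which would need a multi-process child.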
Description
Some distributed PyTorch tests hang at random. Sometimes the hang triggers a timeout in RCCL, but in most cases it results in an infinite wait.
This PR adds a test-level timeout and a torchrun wrapper so that it can kill all child processes on timeout.