From e543ef898f9a9b8949059f0e5a9baa09132094b3 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Mon, 1 Jun 2026 17:08:56 -0700 Subject: [PATCH] [PyTorch] Isolate CP pool worker stdout from NCCL/library banners MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The CP attention test pool worker (PR #2993) uses rank-0 stdout as a line-delimited JSON protocol channel to the parent pytest process. When NCCL_DEBUG is set to VERSION or INFO, NCCL writes an "NCCL version ...\n" banner to stdout (fd 1); that banner reaches the parent ahead of the first JSON response, json.loads raises, and because the pool fixture is session-scoped and killed on first failure, every subsequent CP test in the file fails. CI runners that export NCCL_DEBUG hit this on all ~200 non-skipped cases. The prior mitigation only redirected non-rank-0 stdout to /dev/null, so rank 0's own banner still corrupted the stream. Fix it at the source: in the worker, before importing torch, dup rank-0's original fd 1 into a private stream reserved for JSON, then point fd 1 at stderr. Any banner from NCCL/cuBLAS/cuDNN/torch (Python or C level, import-time or runtime) now lands on stderr — still drained into CI logs by the parent's stderr thread — instead of the protocol pipe. Combined with the existing non-rank-0 /dev/null redirect, the pipe carries only rank-0 JSON, so the parent's single-line read needs no change. Validated on 8xH100 (TE built from this commit). With NCCL_DEBUG=VERSION and =INFO, flash (p2p/all_gather/a2a) and fused cases pass and zero non-protocol lines reach the pipe; without the fix the same cases fail with "pool worker JSON protocol broke". Control with NCCL_DEBUG unset also passes. Signed-off-by: Sudhakar Singh --- .../attention/run_attention_with_cp_pool.py | 30 +++++++++++++++++-- .../attention/test_attention_with_cp.py | 14 ++++----- 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp_pool.py b/tests/pytorch/attention/run_attention_with_cp_pool.py index 3e5f64a429..67d10ebda5 100644 --- a/tests/pytorch/attention/run_attention_with_cp_pool.py +++ b/tests/pytorch/attention/run_attention_with_cp_pool.py @@ -32,6 +32,28 @@ import time import traceback +# Rank-0 stdout is the JSON protocol channel to the parent. NCCL/cuBLAS/cuDNN +# (and torch itself) can write banners to fd 1 — e.g. NCCL emits +# "NCCL version ...\n" to stdout when NCCL_DEBUG is VERSION or INFO — which +# would corrupt the JSON stream. Reserve a private dup of the original fd 1 for +# JSON and point fd 1 at stderr, BEFORE importing torch, so even import-time / +# C-level writes land on stderr (still visible in CI logs, drained by the +# parent's stderr thread) instead of the protocol pipe. +_JSON_STDOUT = None + + +def _redirect_rank0_stdout_to_stderr() -> None: + """Reserve rank-0's original stdout fd for the JSON channel; send the rest to stderr.""" + global _JSON_STDOUT + if os.environ.get("RANK") != "0": + return + _JSON_STDOUT = os.fdopen(os.dup(1), "w", buffering=1) + os.dup2(2, 1) + sys.stdout = sys.stderr + + +_redirect_rank0_stdout_to_stderr() + import torch import torch.distributed as dist @@ -53,8 +75,8 @@ def _recv_request(rank: int) -> dict: def _send_response(rank: int, payload: dict) -> None: if rank == 0: - sys.stdout.write(json.dumps(payload) + "\n") - sys.stdout.flush() + _JSON_STDOUT.write(json.dumps(payload) + "\n") + _JSON_STDOUT.flush() def _silence_non_rank0_stdout(rank: int) -> None: @@ -64,6 +86,10 @@ def _silence_non_rank0_stdout(rank: int) -> None: so Python/library writes on rank>0 would interleave with rank 0's JSON protocol on the parent's pipe. Closing fd 1 at the OS level on rank>0 catches both Python (``print``) and C-level (NCCL, etc.) writes. + + Rank 0 is handled differently (``_redirect_rank0_stdout_to_stderr``): its + fd 1 also goes to stderr, but a dup of the original is kept for the JSON + channel. rank>0 has nothing to preserve, so it just goes to /dev/null. """ if rank == 0: return diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index a03f51f6c9..59b0e0bdbf 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -214,11 +214,11 @@ def _submit_once(self, kwargs: dict, timeout: float) -> None: self._kill() raise AssertionError(msg) - # Worker redirects non-rank-0 stdout to /dev/null at fd level, so - # rank 0's JSON line is the only thing that arrives on this pipe. - # select() on a pipe fd is Linux/macOS only — on Windows the select - # module only accepts sockets. CP attention tests run on Linux GPU - # hosts so this is fine; flag if portability is ever needed. + # The worker reserves rank-0's stdout fd for this JSON channel and sends + # every other rank's stdout (and rank 0's own library/Python writes) to + # /dev/null or stderr, so the only thing on this pipe is the response + # line. select() on a pipe fd is Linux/macOS only — fine for the GPU + # hosts these tests run on. ready, _, _ = select.select([self.proc.stdout], [], [], timeout) if not ready: msg = self._diag( @@ -234,8 +234,8 @@ def _submit_once(self, kwargs: dict, timeout: float) -> None: self._kill() raise AssertionError(msg) - # A stray non-JSON line from rank 0 would desynchronize the protocol; - # turn it into a clear test failure rather than a raw JSONDecodeError. + # A non-JSON line means stdout isolation failed somewhere; surface it + # clearly rather than as a raw JSONDecodeError. try: resp = json.loads(line) except json.JSONDecodeError as e: