Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions tests/pytorch/attention/run_attention_with_cp_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,28 @@
import time
import traceback

# Rank-0 stdout is the JSON protocol channel to the parent. NCCL/cuBLAS/cuDNN
# (and torch itself) can write banners to fd 1 — e.g. NCCL emits
# "NCCL version ...\n" to stdout when NCCL_DEBUG is VERSION or INFO — which
# would corrupt the JSON stream. Reserve a private dup of the original fd 1 for
# JSON and point fd 1 at stderr, BEFORE importing torch, so even import-time /
# C-level writes land on stderr (still visible in CI logs, drained by the
# parent's stderr thread) instead of the protocol pipe.
_JSON_STDOUT = None


def _redirect_rank0_stdout_to_stderr() -> None:
"""Reserve rank-0's original stdout fd for the JSON channel; send the rest to stderr."""
global _JSON_STDOUT
if os.environ.get("RANK") != "0":
return
_JSON_STDOUT = os.fdopen(os.dup(1), "w", buffering=1)
os.dup2(2, 1)
sys.stdout = sys.stderr
Comment thread
sudhakarsingh27 marked this conversation as resolved.


_redirect_rank0_stdout_to_stderr()
Comment thread
sudhakarsingh27 marked this conversation as resolved.

import torch
import torch.distributed as dist

Expand All @@ -53,8 +75,8 @@ def _recv_request(rank: int) -> dict:

def _send_response(rank: int, payload: dict) -> None:
if rank == 0:
sys.stdout.write(json.dumps(payload) + "\n")
sys.stdout.flush()
_JSON_STDOUT.write(json.dumps(payload) + "\n")
_JSON_STDOUT.flush()


def _silence_non_rank0_stdout(rank: int) -> None:
Expand All @@ -64,6 +86,10 @@ def _silence_non_rank0_stdout(rank: int) -> None:
so Python/library writes on rank>0 would interleave with rank 0's JSON
protocol on the parent's pipe. Closing fd 1 at the OS level on rank>0
catches both Python (``print``) and C-level (NCCL, etc.) writes.

Rank 0 is handled differently (``_redirect_rank0_stdout_to_stderr``): its
fd 1 also goes to stderr, but a dup of the original is kept for the JSON
channel. rank>0 has nothing to preserve, so it just goes to /dev/null.
"""
if rank == 0:
return
Expand Down
14 changes: 7 additions & 7 deletions tests/pytorch/attention/test_attention_with_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,11 @@ def _submit_once(self, kwargs: dict, timeout: float) -> None:
self._kill()
raise AssertionError(msg)

# Worker redirects non-rank-0 stdout to /dev/null at fd level, so
# rank 0's JSON line is the only thing that arrives on this pipe.
# select() on a pipe fd is Linux/macOS only on Windows the select
# module only accepts sockets. CP attention tests run on Linux GPU
# hosts so this is fine; flag if portability is ever needed.
# The worker reserves rank-0's stdout fd for this JSON channel and sends
# every other rank's stdout (and rank 0's own library/Python writes) to
# /dev/null or stderr, so the only thing on this pipe is the response
# line. select() on a pipe fd is Linux/macOS only — fine for the GPU
# hosts these tests run on.
ready, _, _ = select.select([self.proc.stdout], [], [], timeout)
if not ready:
msg = self._diag(
Expand All @@ -234,8 +234,8 @@ def _submit_once(self, kwargs: dict, timeout: float) -> None:
self._kill()
raise AssertionError(msg)

# A stray non-JSON line from rank 0 would desynchronize the protocol;
# turn it into a clear test failure rather than a raw JSONDecodeError.
# A non-JSON line means stdout isolation failed somewhere; surface it
# clearly rather than as a raw JSONDecodeError.
try:
resp = json.loads(line)
except json.JSONDecodeError as e:
Expand Down
Loading