16 changes: 12 additions & 4 deletions test/distributed/test_c10d_nccl.py
@@ -39,7 +39,7 @@
 from torch import nn
 from torch._C._distributed_c10d import ErrorType, OpType, WorkResult
 from torch.nn.parallel import DistributedDataParallel
-from torch.testing._internal.common_cuda import TEST_MULTIGPU
+from torch.testing._internal.common_cuda import TEST_MULTIGPU, _get_torch_rocm_version
 from torch.testing._internal.common_distributed import (
     get_timeout,
     init_multigpu_helper,
@@ -4634,9 +4634,17 @@ def test_trace_while_active(self, timing_enabled, only_active):
             else:
                 self.assertEqual(t[-1]["profiling_name"], "nccl:all_reduce")
                 self.assertEqual(t[-1]["collective_seq_id"], 2)
-                self.assertEqual(
-                    t[-1]["state"], self.started_or_scheduled(timing_enabled)
-                )
+
+                # The ROCm runtime used to call uSleep(20 µs) inside the
+                # default-signal busy-wait loop. That sleep has been removed,
+                # letting the host thread spin continuously, so the state can be
+                # either scheduled or started by the time the test dumps the trace.
+                if torch.version.hip and _get_torch_rocm_version() >= (6, 4) and timing_enabled:
+                    self.assertIn(t[-1]["state"], ("scheduled", "started"))
+                else:
+                    self.assertEqual(
+                        t[-1]["state"], self.started_or_scheduled(timing_enabled)
+                    )
 
                 self.parent.send("next")
                 self.assertEqual("next", self.parent.recv())
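For reference, a minimal standalone sketch of the version gate used above. `torch.version.hip` and `_get_torch_rocm_version` come from the diff itself; the wrapper name `rocm_accepts_either_state` is hypothetical, added only for illustration:

```python
import torch

# Internal PyTorch test helper imported in the diff; returns the ROCm
# version as a tuple, e.g. (6, 4).
from torch.testing._internal.common_cuda import _get_torch_rocm_version


def rocm_accepts_either_state(timing_enabled: bool) -> bool:
    # torch.version.hip is None on CUDA builds, so the gate is False there.
    # On ROCm >= 6.4 the runtime no longer sleeps inside its busy-wait loop,
    # so an in-flight collective may be captured as "scheduled" or "started".
    return (
        bool(torch.version.hip)
        and _get_torch_rocm_version() >= (6, 4)
        and timing_enabled
    )
```

On older ROCm builds, and on CUDA, the test keeps the original strict check via `self.started_or_scheduled(timing_enabled)`, so only the ROCm >= 6.4 path is relaxed.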