diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index 158fa46633c3e..fbd4a7b1e6d22 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -39,7 +39,7 @@
 from torch import nn
 from torch._C._distributed_c10d import ErrorType, OpType, WorkResult
 from torch.nn.parallel import DistributedDataParallel
-from torch.testing._internal.common_cuda import TEST_MULTIGPU
+from torch.testing._internal.common_cuda import TEST_MULTIGPU, _get_torch_rocm_version
 from torch.testing._internal.common_distributed import (
     get_timeout,
     init_multigpu_helper,
@@ -4634,9 +4634,17 @@ def test_trace_while_active(self, timing_enabled, only_active):
             else:
                 self.assertEqual(t[-1]["profiling_name"], "nccl:all_reduce")
                 self.assertEqual(t[-1]["collective_seq_id"], 2)
-                self.assertEqual(
-                    t[-1]["state"], self.started_or_scheduled(timing_enabled)
-                )
+
+                # The ROCm runtime used to call uSleep(20 µs) inside the default-signal
+                # busy-wait loop. That sleep has since been removed, so the host thread
+                # spins continuously and the entry can be in either the "scheduled" or
+                # "started" state by the time the test dumps the trace.
+                if torch.version.hip and _get_torch_rocm_version() >= (6, 4) and timing_enabled:
+                    self.assertIn(t[-1]["state"], ("scheduled", "started"))
+                else:
+                    self.assertEqual(
+                        t[-1]["state"], self.started_or_scheduled(timing_enabled)
+                    )
 
             self.parent.send("next")
             self.assertEqual("next", self.parent.recv())
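
For context, the version gate added above can be exercised on its own. The sketch below is a minimal, hypothetical illustration (not part of the patch) of how `torch.version.hip` and `_get_torch_rocm_version()` combine to select the accepted trace states; the helper name `expected_trace_states` and its return values are illustrative assumptions that mirror the test's `started_or_scheduled` logic.

```python
# Minimal sketch, assuming _get_torch_rocm_version() returns a (major, minor)
# tuple comparable with >= (6, 4) and torch.version.hip is None on non-ROCm
# builds. expected_trace_states is a hypothetical helper, not PyTorch API.
import torch
from torch.testing._internal.common_cuda import _get_torch_rocm_version


def expected_trace_states(timing_enabled: bool) -> tuple[str, ...]:
    """Return the flight-recorder states an in-flight entry may report."""
    if torch.version.hip and _get_torch_rocm_version() >= (6, 4) and timing_enabled:
        # Without the 20 µs sleep in the busy-wait loop, the dump may observe
        # the entry either before ("scheduled") or after ("started") launch.
        return ("scheduled", "started")
    # Otherwise the expectation is deterministic, matching the existing
    # started_or_scheduled(timing_enabled) branch in the test.
    return ("started",) if timing_enabled else ("scheduled",)
```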