diff --git a/test/distributed/_tools/test_sac_ilp.py b/test/distributed/_tools/test_sac_ilp.py
index 6d1b4dfbb8857..777b93ea31640 100644
--- a/test/distributed/_tools/test_sac_ilp.py
+++ b/test/distributed/_tools/test_sac_ilp.py
@@ -20,9 +20,11 @@
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_utils import (
     run_tests,
-    skipIfRocm,
     skipIfTorchDynamo,
     TestCase,
+    skipIfRocm,
+    skipIfRocmArch,
+    NAVI_ARCH,
 )
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     ModelArgs,
@@ -178,6 +180,7 @@ def test_sac_ilp_case1(self):
 
     @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
     @unittest.skipIf(not TEST_CUDA, "CUDA not available")
+    @skipIfRocmArch(NAVI_ARCH)
     def test_sac_ilp_case2(self):
         """
         This is a case where the memory budget is not binding, meaning that no
diff --git a/test/distributed/pipelining/test_schedule_multiproc.py b/test/distributed/pipelining/test_schedule_multiproc.py
index 8491881f7fe23..831e6ffe1d595 100644
--- a/test/distributed/pipelining/test_schedule_multiproc.py
+++ b/test/distributed/pipelining/test_schedule_multiproc.py
@@ -604,7 +604,7 @@ def test_schedule_with_native_zero_bubble(self, ScheduleClass):
             for name, p in stage_module.named_parameters():
                 ref_p = ref_submod.get_parameter(name)
                 try:
-                    torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5)
+                    torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=9e-5)
                 except AssertionError:
                     print(
                         f"Parameter test failed for {submod_name}.{name}: {p.grad} vs {ref_p.grad}"
diff --git a/test/distributed/tensor/debug/test_comm_mode_features.py b/test/distributed/tensor/debug/test_comm_mode_features.py
index bf1f14a0ec34b..cbd0ff0e8508b 100644
--- a/test/distributed/tensor/debug/test_comm_mode_features.py
+++ b/test/distributed/tensor/debug/test_comm_mode_features.py
@@ -24,7 +24,8 @@
     with_comms,
 )
 
-
+from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FUSED_ATTENTION
+import unittest
 c10d_functional = torch.ops.c10d_functional
 
 
@@ -221,6 +222,7 @@ def test_MLP_module_tracing(self):
 
     @skip_unless_torch_gpu
     @with_comms
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
     def test_transformer_module_tracing(self, is_seq_parallel=False):
         """
         tests module-level tracing for more complicated transformer module and