diff --git a/test/distributed/_tensor/debug/test_comm_mode_features.py b/test/distributed/_tensor/debug/test_comm_mode_features.py index fc19cddb58f4a..764f5e3c67b4a 100644 --- a/test/distributed/_tensor/debug/test_comm_mode_features.py +++ b/test/distributed/_tensor/debug/test_comm_mode_features.py @@ -24,7 +24,8 @@ with_comms, ) - +from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FUSED_ATTENTION +import unittest c10d_functional = torch.ops.c10d_functional @@ -221,6 +222,7 @@ def test_MLP_module_tracing(self): @skip_unless_torch_gpu @with_comms + @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention") def test_transformer_module_tracing(self, is_seq_parallel=False): """ tests module-level tracing for more complicated transformer module and diff --git a/test/distributed/_tools/test_sac_ilp.py b/test/distributed/_tools/test_sac_ilp.py index 282689a2fc4b0..62c6afe76829c 100644 --- a/test/distributed/_tools/test_sac_ilp.py +++ b/test/distributed/_tools/test_sac_ilp.py @@ -24,7 +24,7 @@ skipIfTorchDynamo, TestCase, skipIfRocmArch, - NAVI4_ARCH, + NAVI_ARCH, ) from torch.testing._internal.distributed._tensor.common_dtensor import ( @@ -138,7 +138,7 @@ def _collect_module_info_with_fake_tensor_mode(self) -> ModuleInfo: @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653") @unittest.skipIf(not TEST_CUDA, "CUDA not available") - @skipIfRocmArch(NAVI4_ARCH) + @skipIfRocmArch(NAVI_ARCH) def test_sac_ilp_case1(self): """ This is a case where the memory budget is either binding or too tight, @@ -181,6 +181,7 @@ def test_sac_ilp_case1(self): @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653") @unittest.skipIf(not TEST_CUDA, "CUDA not available") + @skipIfRocmArch(NAVI_ARCH) def test_sac_ilp_case2(self): """ This is a case where the memory budget is not binding, meaning that no diff --git a/test/distributed/pipelining/test_schedule_multiproc.py b/test/distributed/pipelining/test_schedule_multiproc.py index f41c06b6b3168..f58686e2f23d1 100644 --- a/test/distributed/pipelining/test_schedule_multiproc.py +++ b/test/distributed/pipelining/test_schedule_multiproc.py @@ -599,7 +599,7 @@ def test_schedule_with_native_zero_bubble(self, ScheduleClass): for name, p in stage_module.named_parameters(): ref_p = ref_submod.get_parameter(name) try: - torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5) + torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=9e-5) except AssertionError: print( f"Parameter test failed for {submod_name}.{name}: {p.grad} vs {ref_p.grad}"