From db5f9339de4d4fad2d449afc48bc355b9f459e02 Mon Sep 17 00:00:00 2001
From: Dmitry Nikolaev <139769634+dnikolaev-amd@users.noreply.github.com>
Date: Mon, 21 Jul 2025 19:01:08 +0200
Subject: [PATCH] [rocm7.0_internal_testing] skip test_transformer_req_grad on
 Navi32/Navi4x (#2385)

Skip `distributed/tensor/parallel/test_tp_examples.py::DistTensorParallelExampleTest::test_transformer_req_grad_*`
on Navi32/Navi4x

Fixes SWDEV-510742
---
 test/distributed/tensor/parallel/test_tp_examples.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/test/distributed/tensor/parallel/test_tp_examples.py b/test/distributed/tensor/parallel/test_tp_examples.py
index 9b412f88440f1..194735f912368 100644
--- a/test/distributed/tensor/parallel/test_tp_examples.py
+++ b/test/distributed/tensor/parallel/test_tp_examples.py
@@ -27,6 +27,8 @@
     RowwiseParallel,
 )
 from torch.distributed.tensor.parallel.input_reshard import input_reshard
+from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FUSED_ATTENTION
+from torch.testing._internal.common_device_type import skipIf
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
@@ -412,6 +414,7 @@ def test_transformer_training(self, is_seq_parallel, dtype: torch.dtype):
         + f"{str(dtype).split('.')[-1]}_"
         + f"thaw_{'__'.join(sorted({n.rpartition('.')[0].replace('.', '_') for n in thaw})) if thaw else 'all'}",
     )
+    @skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
     def test_transformer_req_grad(self, thaw_params, is_seq_parallel, dtype, exp_cnts):
         # Sample a subset of `requires_grad` patterns
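
Note on the guard (not part of the patch): the added decorator skips the whole
parametrized test whenever the platform lacks fused scaled dot product attention
support, which is how the Navi32/Navi4x configurations are excluded. The sketch
below reproduces the same skip pattern as a standalone file, using the
standard-library unittest.skipIf in place of the device-type skipIf helper the
diff imports; the test class and its body are hypothetical and exist only to
illustrate the condition, assuming a recent PyTorch build where
PLATFORM_SUPPORTS_FUSED_ATTENTION is available in common_cuda.

    # Minimal sketch, not part of the patch: apply the same skip condition with
    # unittest.skipIf. The test is skipped when fused SDPA is unsupported.
    import unittest

    import torch
    from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FUSED_ATTENTION
    from torch.testing._internal.common_utils import run_tests, TestCase


    class FusedAttentionGuardExample(TestCase):  # hypothetical class, for illustration
        @unittest.skipIf(
            not PLATFORM_SUPPORTS_FUSED_ATTENTION,
            "Does not support fused scaled dot product attention",
        )
        def test_sdpa_runs(self):
            # Tiny tensors; shape is (batch, heads, seq_len, head_dim).
            q = k = v = torch.randn(1, 1, 4, 8)
            out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
            self.assertEqual(out.shape, q.shape)


    if __name__ == "__main__":
        run_tests()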