From 2fa5ea0e37154266c6f5b0ce49b17e0bb154ec26 Mon Sep 17 00:00:00 2001
From: Dmitry Nikolaev
Date: Wed, 16 Jul 2025 23:28:09 +0000
Subject: [PATCH] skip test_transformer_req_grad on Navi32/Navi4x

---
 test/distributed/tensor/parallel/test_tp_examples.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/test/distributed/tensor/parallel/test_tp_examples.py b/test/distributed/tensor/parallel/test_tp_examples.py
index 9b412f88440f1..194735f912368 100644
--- a/test/distributed/tensor/parallel/test_tp_examples.py
+++ b/test/distributed/tensor/parallel/test_tp_examples.py
@@ -27,6 +27,8 @@
     RowwiseParallel,
 )
 from torch.distributed.tensor.parallel.input_reshard import input_reshard
+from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FUSED_ATTENTION
+from torch.testing._internal.common_device_type import skipIf
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
@@ -412,6 +414,7 @@ def test_transformer_training(self, is_seq_parallel, dtype: torch.dtype):
             + f"{str(dtype).split('.')[-1]}_"
             + f"thaw_{'__'.join(sorted({n.rpartition('.')[0].replace('.', '_') for n in thaw})) if thaw else 'all'}",
     )
+    @skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
     def test_transformer_req_grad(self, thaw_params, is_seq_parallel, dtype, exp_cnts):
         # Sample a subset of `requires_grad` patterns
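
Note: the decorator added by this patch gates the test on fused scaled dot product
attention support, which is unavailable on the Navi32/Navi4x ROCm parts named in the
subject. A minimal, self-contained sketch of the same skip pattern is shown below; it
uses unittest.skipIf and a hypothetical supports_fused_sdpa() probe in place of
PyTorch's internal PLATFORM_SUPPORTS_FUSED_ATTENTION flag and skipIf helper, so names
other than the test method are illustrative assumptions, not the repository's code.

    import unittest

    import torch


    def supports_fused_sdpa() -> bool:
        # Hypothetical probe: assume fused SDPA requires a CUDA/ROCm device.
        # The real suite imports PLATFORM_SUPPORTS_FUSED_ATTENTION from
        # torch.testing._internal.common_cuda instead of using this check.
        return torch.cuda.is_available()


    class TransformerTPTest(unittest.TestCase):
        @unittest.skipIf(
            not supports_fused_sdpa(),
            "Does not support fused scaled dot product attention",
        )
        def test_transformer_req_grad(self):
            # Placeholder body; the actual test lives in
            # test/distributed/tensor/parallel/test_tp_examples.py.
            self.assertTrue(True)


    if __name__ == "__main__":
        unittest.main()

When the probe returns False, the test is reported as skipped with the given reason
rather than failing, which mirrors the behavior the patch adds for unsupported GPUs.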