diff --git a/test/distributed/tensor/parallel/test_tp_examples.py b/test/distributed/tensor/parallel/test_tp_examples.py
index 9b412f88440f1..194735f912368 100644
--- a/test/distributed/tensor/parallel/test_tp_examples.py
+++ b/test/distributed/tensor/parallel/test_tp_examples.py
@@ -27,6 +27,8 @@
     RowwiseParallel,
 )
 from torch.distributed.tensor.parallel.input_reshard import input_reshard
+from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FUSED_ATTENTION
+from torch.testing._internal.common_device_type import skipIf
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
@@ -412,6 +414,7 @@ def test_transformer_training(self, is_seq_parallel, dtype: torch.dtype):
         + f"{str(dtype).split('.')[-1]}_"
         + f"thaw_{'__'.join(sorted({n.rpartition('.')[0].replace('.', '_') for n in thaw})) if thaw else 'all'}",
     )
+    @skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
     def test_transformer_req_grad(self, thaw_params, is_seq_parallel, dtype, exp_cnts):
         # Sample a subset of `requires_grad` patterns