diff --git a/test/legacy_test/test_fused_multi_transformer_op.py b/test/legacy_test/test_fused_multi_transformer_op.py index f7659464126ee..5ae564dcdb0ff 100644 --- a/test/legacy_test/test_fused_multi_transformer_op.py +++ b/test/legacy_test/test_fused_multi_transformer_op.py @@ -1412,7 +1412,9 @@ def test_fused_multi_transformer_op(self): ) -class TestFusedMultiAttentionAPIError(unittest.TestCase): +# Starts the name of this test with 'Z' to make this test +# run after others. If not, it will make other tests fail. +class ZTestFusedMultiAttentionAPIError(unittest.TestCase): def test_errors(self): def test_invalid_input_dim(): array = np.array([1.9], dtype=np.float32) @@ -1425,7 +1427,7 @@ def test_invalid_input_dim(): self.assertRaises(ValueError, test_invalid_input_dim) -class TestFusedMultiTransformerAPIError(unittest.TestCase): +class ZTestFusedMultiTransformerAPIError(unittest.TestCase): def test_errors(self): def test_invalid_input_dim(): array = np.array([], dtype=np.float32)