Skip to content

Commit

Permalink
fix the bug of test_fused_multi_transformer_op on cuda12 (#55431)
Browse files Browse the repository at this point in the history
  • Loading branch information
pkuzyc committed Jul 21, 2023
1 parent 7ed818f commit 3132054
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions test/legacy_test/test_fused_multi_transformer_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1412,7 +1412,9 @@ def test_fused_multi_transformer_op(self):
)


class TestFusedMultiAttentionAPIError(unittest.TestCase):
# Starts the name of this test with 'Z' to make this test
# run after others. If not, it will make other tests fail.
class ZTestFusedMultiAttentionAPIError(unittest.TestCase):
def test_errors(self):
def test_invalid_input_dim():
array = np.array([1.9], dtype=np.float32)
Expand All @@ -1425,7 +1427,7 @@ def test_invalid_input_dim():
self.assertRaises(ValueError, test_invalid_input_dim)


class TestFusedMultiTransformerAPIError(unittest.TestCase):
class ZTestFusedMultiTransformerAPIError(unittest.TestCase):
def test_errors(self):
def test_invalid_input_dim():
array = np.array([], dtype=np.float32)
Expand Down

0 comments on commit 3132054

Please sign in to comment.