From 4f535952e7c9719e4b804c33886504d0ca5d0ee2 Mon Sep 17 00:00:00 2001
From: Ryan <44900829+DrRyanHuang@users.noreply.github.com>
Date: Thu, 25 Jan 2024 10:21:45 +0800
Subject: [PATCH] [pir] fix frexp datatype (#61087)

* Update math.py

* add test in pir
---
 python/paddle/tensor/math.py            |  7 ++++++-
 test/legacy_test/test_frexp_api.py      |  2 ++
 test/legacy_test/test_smooth_l1_loss.py | 16 ++++++++--------
 3 files changed, 16 insertions(+), 9 deletions(-)

diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index d5401446e0628..1a88acd43d0a2 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -6547,7 +6547,12 @@ def frexp(x, name=None):
             Tensor(shape=[1, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
             [[1., 2., 2., 3.]])
     """
-    if x.dtype not in [paddle.float32, paddle.float64]:
+    if x.dtype not in [
+        paddle.float32,
+        paddle.float64,
+        DataType.FLOAT32,
+        DataType.FLOAT64,
+    ]:
         raise TypeError(
             f"The data type of input must be one of ['float32', 'float64'], but got {x.dtype}"
         )
diff --git a/test/legacy_test/test_frexp_api.py b/test/legacy_test/test_frexp_api.py
index 520b6a90785a4..a156eb4e898f1 100644
--- a/test/legacy_test/test_frexp_api.py
+++ b/test/legacy_test/test_frexp_api.py
@@ -17,6 +17,7 @@

 import paddle
 import paddle.base
+from paddle.pir_utils import test_with_pir_api


 class TestFrexpAPI(unittest.TestCase):
@@ -35,6 +36,7 @@ def set_input(self):
         self.x_np = np.random.uniform(-3, 3, [10, 12]).astype('float32')

     # static graph unit test
+    @test_with_pir_api
     def test_static_api(self):
         # enable static graph mode
         paddle.enable_static()
diff --git a/test/legacy_test/test_smooth_l1_loss.py b/test/legacy_test/test_smooth_l1_loss.py
index d9c1b3d4fcb13..90ae5593ec4b7 100644
--- a/test/legacy_test/test_smooth_l1_loss.py
+++ b/test/legacy_test/test_smooth_l1_loss.py
@@ -57,7 +57,7 @@ def test_smooth_l1_loss_mean(self):
         expected = smooth_l1_loss_np(input_np, label_np, reduction='mean')

         @test_with_pir_api
-        def test_dynamic_or_pir_mode():
+        def test_static():
             prog = paddle.static.Program()
             startup_prog = paddle.static.Program()
             with paddle.static.program_guard(prog, startup_prog):
@@ -90,7 +90,7 @@ def test_dynamic_or_pir_mode():
             dy_ret_value = dy_ret.numpy()
             self.assertIsNotNone(dy_ret_value)

-        test_dynamic_or_pir_mode()
+        test_static()
         np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)

     def test_smooth_l1_loss_sum(self):
@@ -105,7 +105,7 @@ def test_smooth_l1_loss_sum(self):
         expected = smooth_l1_loss_np(input_np, label_np, reduction='sum')

         @test_with_pir_api
-        def test_dynamic_or_pir_mode():
+        def test_static():
             prog = paddle.static.Program()
             startup_prog = paddle.static.Program()
             with paddle.static.program_guard(prog, startup_prog):
@@ -138,7 +138,7 @@ def test_dynamic_or_pir_mode():
             dy_ret_value = dy_ret.numpy()
             self.assertIsNotNone(dy_ret_value)

-        test_dynamic_or_pir_mode()
+        test_static()
         np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)

     def test_smooth_l1_loss_none(self):
@@ -153,7 +153,7 @@ def test_smooth_l1_loss_none(self):
         expected = smooth_l1_loss_np(input_np, label_np, reduction='none')

         @test_with_pir_api
-        def test_dynamic_or_pir_mode():
+        def test_static():
             prog = paddle.static.Program()
             startup_prog = paddle.static.Program()
             with paddle.static.program_guard(prog, startup_prog):
@@ -186,7 +186,7 @@ def test_dynamic_or_pir_mode():
             dy_ret_value = dy_ret.numpy()
             self.assertIsNotNone(dy_ret_value)

-        test_dynamic_or_pir_mode()
+        test_static()
         np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)

     def test_smooth_l1_loss_delta(self):
@@ -202,7 +202,7 @@ def test_smooth_l1_loss_delta(self):
         expected = smooth_l1_loss_np(input_np, label_np, delta=delta)

         @test_with_pir_api
-        def test_dynamic_or_pir_mode():
+        def test_static():
             prog = paddle.static.Program()
             startup_prog = paddle.static.Program()
             with paddle.static.program_guard(prog, startup_prog):
@@ -235,7 +235,7 @@ def test_dynamic_or_pir_mode():
             dy_ret_value = dy_ret.numpy()
             self.assertIsNotNone(dy_ret_value)

-        test_dynamic_or_pir_mode()
+        test_static()
         np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)
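
For reference only, not part of the patch above: a minimal usage sketch of paddle.frexp showing the behavior the patched dtype check guards, in both dynamic and static graph modes. Under PIR, x.dtype on a static Value is a DataType enum, which is why the check now also lists DataType.FLOAT32/DataType.FLOAT64; whether the static branch actually runs under PIR or the legacy IR depends on the build's PIR flags (the tests switch via the test_with_pir_api decorator). Tensor values mirror the docstring example.

import numpy as np
import paddle

# Dynamic graph: float32/float64 inputs are accepted.
x = paddle.to_tensor([[1.0, 2.0, 3.0, 4.0]], dtype='float32')
mantissa, exponent = paddle.frexp(x)
print(exponent.numpy())  # [[1. 2. 2. 3.]]

# Unsupported dtypes hit the TypeError raised by the patched check.
try:
    paddle.frexp(paddle.to_tensor([1, 2, 3], dtype='int32'))
except TypeError as err:
    print(err)

# Static graph: the same call, built into a Program and run by an Executor.
paddle.enable_static()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x_static = paddle.static.data(name='x', shape=[1, 4], dtype='float32')
    m_static, e_static = paddle.frexp(x_static)
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup_prog)
m_np, e_np = exe.run(
    main_prog,
    feed={'x': np.array([[1.0, 2.0, 3.0, 4.0]], dtype='float32')},
    fetch_list=[m_static, e_static],
)
print(e_np)  # [[1. 2. 2. 3.]]
paddle.disable_static()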