Skip to content

Commit

Permalink
[pir] fix frexp datatype (#61087)
Browse files Browse the repository at this point in the history
* Update math.py

* add test in pir
  • Loading branch information
DrRyanHuang committed Jan 25, 2024
1 parent cf649d7 commit 4f53595
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 9 deletions.
7 changes: 6 additions & 1 deletion python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -6547,7 +6547,12 @@ def frexp(x, name=None):
Tensor(shape=[1, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
[[1., 2., 2., 3.]])
"""
if x.dtype not in [paddle.float32, paddle.float64]:
if x.dtype not in [
paddle.float32,
paddle.float64,
DataType.FLOAT32,
DataType.FLOAT64,
]:
raise TypeError(
f"The data type of input must be one of ['float32', 'float64'], but got {x.dtype}"
)
Expand Down
2 changes: 2 additions & 0 deletions test/legacy_test/test_frexp_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import paddle
import paddle.base
from paddle.pir_utils import test_with_pir_api


class TestFrexpAPI(unittest.TestCase):
Expand All @@ -35,6 +36,7 @@ def set_input(self):
self.x_np = np.random.uniform(-3, 3, [10, 12]).astype('float32')

# 静态图单测
@test_with_pir_api
def test_static_api(self):
# 开启静态图模式
paddle.enable_static()
Expand Down
16 changes: 8 additions & 8 deletions test/legacy_test/test_smooth_l1_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_smooth_l1_loss_mean(self):
expected = smooth_l1_loss_np(input_np, label_np, reduction='mean')

@test_with_pir_api
def test_dynamic_or_pir_mode():
def test_static():
prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(prog, startup_prog):
Expand Down Expand Up @@ -90,7 +90,7 @@ def test_dynamic_or_pir_mode():
dy_ret_value = dy_ret.numpy()
self.assertIsNotNone(dy_ret_value)

test_dynamic_or_pir_mode()
test_static()
np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)

def test_smooth_l1_loss_sum(self):
Expand All @@ -105,7 +105,7 @@ def test_smooth_l1_loss_sum(self):
expected = smooth_l1_loss_np(input_np, label_np, reduction='sum')

@test_with_pir_api
def test_dynamic_or_pir_mode():
def test_static():
prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(prog, startup_prog):
Expand Down Expand Up @@ -138,7 +138,7 @@ def test_dynamic_or_pir_mode():
dy_ret_value = dy_ret.numpy()
self.assertIsNotNone(dy_ret_value)

test_dynamic_or_pir_mode()
test_static()
np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)

def test_smooth_l1_loss_none(self):
Expand All @@ -153,7 +153,7 @@ def test_smooth_l1_loss_none(self):
expected = smooth_l1_loss_np(input_np, label_np, reduction='none')

@test_with_pir_api
def test_dynamic_or_pir_mode():
def test_static():
prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(prog, startup_prog):
Expand Down Expand Up @@ -186,7 +186,7 @@ def test_dynamic_or_pir_mode():
dy_ret_value = dy_ret.numpy()
self.assertIsNotNone(dy_ret_value)

test_dynamic_or_pir_mode()
test_static()
np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)

def test_smooth_l1_loss_delta(self):
Expand All @@ -202,7 +202,7 @@ def test_smooth_l1_loss_delta(self):
expected = smooth_l1_loss_np(input_np, label_np, delta=delta)

@test_with_pir_api
def test_dynamic_or_pir_mode():
def test_static():
prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(prog, startup_prog):
Expand Down Expand Up @@ -235,7 +235,7 @@ def test_dynamic_or_pir_mode():
dy_ret_value = dy_ret.numpy()
self.assertIsNotNone(dy_ret_value)

test_dynamic_or_pir_mode()
test_static()
np.testing.assert_allclose(dy_ret_value, expected, rtol=1e-05)


Expand Down

0 comments on commit 4f53595

Please sign in to comment.