Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#4 from feifei-111/test_cache_program
Browse files Browse the repository at this point in the history
update test_cache_program
  • Loading branch information
2742195759 committed Jun 15, 2023
2 parents 7bc67ed + 360cbb6 commit 9d03ed0
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
4 changes: 3 additions & 1 deletion test/dygraph_to_static/test_cache_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from collections import Counter

import numpy as np
from dygraph_to_static_util import dy2static_unittest
from test_fetch_feed import Linear, Pool2D

import paddle
Expand All @@ -24,6 +25,7 @@
from paddle.jit.dy2static import convert_to_static


@dy2static_unittest
class TestCacheProgram(unittest.TestCase):
def setUp(self):
self.batch_num = 5
Expand All @@ -36,7 +38,7 @@ def test_cache(self):
with fluid.dygraph.guard(fluid.CPUPlace()):
static_net = self.dygraph_class()
for batch_id in range(self.batch_num):
out = static_net(self.data)
out = static_net(paddle.to_tensor(self.data))
# Check outputs
prev_out = cur_out
cur_out = out
Expand Down
3 changes: 3 additions & 0 deletions test/dygraph_to_static/test_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import unittest

import numpy as np
from dygraph_to_static_util import ast_only_test

import paddle

Expand Down Expand Up @@ -84,6 +85,7 @@ def test_case_net_fallback(self):
u_net(self.x).numpy(),
)

@ast_only_test
def test_case_net_error(self):
s_net = SuppportNet()
u_net = UnsuppportNet()
Expand All @@ -110,6 +112,7 @@ def test_case_training(self):
np.testing.assert_allclose(u_net(self.x).numpy(), [1, 1])
assert u_net.training is False, "Training must be false."

@ast_only_test
def test_case_save_error(self):
"""
test the save will raise error.
Expand Down

0 comments on commit 9d03ed0

Please sign in to comment.