Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Fix PIR JIT SaveLoad Unittest No.2-6,8-10,13,15】open many tests for pir jit save/load #64400

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion test/deprecated/legacy_test/test_ops_nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from test_nms_op import nms

import paddle
from paddle.pir_utils import test_with_pir_api
from paddle.pir_utils import test_with_dygraph_pir, test_with_pir_api


def _find(condition):
Expand Down Expand Up @@ -197,6 +197,7 @@ def test_multiclass_nms_static(self):
err_msg=f'paddle out: {out}\n py out: {out_py}\n',
)

@test_with_dygraph_pir
def test_multiclass_nms_dynamic_to_static(self):
for device in self.devices:
for dtype in self.dtypes:
Expand Down
48 changes: 22 additions & 26 deletions test/dygraph_to_static/test_declarative.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,29 +148,27 @@ def test_with_input_spec(self):

# 2. test save load
net.inner_function(x)
# TODO(pir-save-load): Fix this after we support save/load in PIR
if not use_pir_api():
paddle.jit.save(net, self.model_path)
infer_net = paddle.jit.load(self.model_path)
pred = infer_net(x)
np.testing.assert_allclose(out.numpy(), pred.numpy(), rtol=1e-05)

# 3. we can decorate any method
x_2 = paddle.to_tensor(np.ones([4, 20]).astype('float32'))
# uses `to_static(func)` instead of `@to_static`
net.add_func = paddle.jit.to_static(net.add_func)
out = net.add_func(x_2, np.ones([20]).astype('float32'))
self.assertTrue(len(net.add_func.program_cache) == 1)

# 5. test input with list
out = net.func_with_list([x, y], int_val)

# 6. test input with dict
out = net.func_with_dict({'x': x, 'y': y})

# 7. test input with lits contains dict
int_np = np.ones([1]).astype('float32')
out = net.func_with_list_dict([int_np, {'x': x, 'y': y}])
paddle.jit.save(net, self.model_path)
infer_net = paddle.jit.load(self.model_path)
pred = infer_net(x)
np.testing.assert_allclose(out.numpy(), pred.numpy(), rtol=1e-05)

# 3. we can decorate any method
x_2 = paddle.to_tensor(np.ones([4, 20]).astype('float32'))
# uses `to_static(func)` instead of `@to_static`
net.add_func = paddle.jit.to_static(net.add_func)
out = net.add_func(x_2, np.ones([20]).astype('float32'))
self.assertTrue(len(net.add_func.program_cache) == 1)

# 5. test input with list
out = net.func_with_list([x, y], int_val)

# 6. test input with dict
out = net.func_with_dict({'x': x, 'y': y})

# 7. test input with lits contains dict
int_np = np.ones([1]).astype('float32')
out = net.func_with_list_dict([int_np, {'x': x, 'y': y}])

@test_legacy_and_pt_and_pir
def test_with_error(self):
Expand Down Expand Up @@ -504,9 +502,7 @@ def test_set_buffers1(self):
net = paddle.jit.to_static(SetBuffersNet1())
out = net()
self.assertEqual(out.numpy().tolist(), [2])
# TODO(pir-save-load): Fix this after we support save/load in PIR
if not use_pir_api():
paddle.jit.save(net, self.model_path)
paddle.jit.save(net, self.model_path)

@test_ast_only
def test_set_buffers2(self):
Expand Down
4 changes: 2 additions & 2 deletions test/dygraph_to_static/test_for_enumerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from dygraph_to_static_utils import (
Dy2StTestBase,
enable_to_static_guard,
test_legacy_and_pir,
test_legacy_and_pt_and_pir,
)

Expand Down Expand Up @@ -554,7 +555,6 @@ def tearDown(self):

@test_legacy_and_pt_and_pir
def test_for_zip_error(self):
# TODO(pir-save-load): enable PIR test after support PIR save load
with self.assertRaises(RuntimeError):
model_path = os.path.join(self.temp_dir.name, 'for_zip_error')
paddle.jit.save(
Expand All @@ -568,8 +568,8 @@ def test_for_zip_error(self):
model_path,
)

@test_legacy_and_pir
def test_for_zip(self):
# TODO(pir-save-load): enable PIR test after support PIR save load
model_path = os.path.join(self.temp_dir.name, 'for_zip')
paddle.jit.save(
paddle.jit.to_static(
Expand Down
7 changes: 0 additions & 7 deletions test/dygraph_to_static/test_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
)

import paddle
from paddle.framework import use_pir_api


class GradLayer(paddle.nn.Layer):
Expand Down Expand Up @@ -101,9 +100,6 @@ def tearDown(self):

@test_legacy_and_pt_and_pir
def test_save_infer_program(self):
# TODO(pir-save-load): Fix this after we support save/load in PIR
if use_pir_api():
return
static_fn = paddle.jit.to_static(self.func)
input_spec = [
paddle.static.InputSpec(shape=[10, 2, 5], dtype='float32')
Expand Down Expand Up @@ -132,9 +128,6 @@ def test_save_train_program(self):

static_fn.clear_gradients()

# TODO(pir-save-load): Fix this after we support save/load in PIR
if use_pir_api():
return
paddle.jit.save(static_fn, self.train_model_path)
load_func = paddle.jit.load(self.train_model_path)

Expand Down
4 changes: 1 addition & 3 deletions test/dygraph_to_static/test_layer_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ def train_net(self, to_static=False):
net = paddle.jit.to_static(net)
out = net(self.x)

# TODO(xiongkun) save / load unitest.
if to_static and not paddle.base.framework.use_pir_api():
paddle.jit.save(net, self.path)
paddle.jit.save(net, self.path, input_spec=[self.x])

return float(out)

Expand Down
5 changes: 1 addition & 4 deletions test/dygraph_to_static/test_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

import paddle
import paddle.nn.functional as F
from paddle.base.framework import use_pir_api
from paddle.jit.dy2static.transformers.loop_transformer import NameVisitor
from paddle.utils import gast

Expand Down Expand Up @@ -466,9 +465,7 @@ def test_start(self):
],
)
temp_dir = tempfile.TemporaryDirectory()
# TODO(pir-save-load): Fix this after we support save/load in PIR
if not use_pir_api():
paddle.jit.save(model, temp_dir.name)
paddle.jit.save(model, temp_dir.name)
temp_dir.cleanup()


Expand Down
50 changes: 28 additions & 22 deletions test/dygraph_to_static/test_save_inference_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from dygraph_to_static_utils import (
Dy2StTestBase,
test_ast_only,
test_legacy_and_pir,
test_legacy_and_pt_and_pir,
test_legacy_only,
)

import paddle
Expand All @@ -32,6 +32,7 @@
from paddle.jit.dy2static.pir_partial_program import (
partial_program_from as pir_partial_program_from,
)
from paddle.jit.pir_translated_layer import PIR_INFER_MODEL_SUFFIX
from paddle.jit.translated_layer import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX

SEED = 2020
Expand Down Expand Up @@ -89,6 +90,7 @@ def tearDown(self):
self.temp_dir.cleanup()

@test_ast_only
@test_legacy_and_pir
def test_save_inference_model(self):
fc_size = 20
x_data = np.random.random((fc_size, fc_size)).astype('float32')
Expand All @@ -112,29 +114,29 @@ def test_save_inference_model(self):
infer_model_dir = os.path.join(
self.temp_dir.name, "test_dy2stat_inference_in_guard"
)
# TODO(pir-save-load): Fix this after we support save/load in PIR
if not use_pir_api():
paddle.jit.save(
layer=layer,
path=infer_model_prefix,
input_spec=[x],
output_spec=[pred],
)
# Check the correctness of the inference
dygraph_out, _ = layer(x)
self.check_save_inference_model(
layer, [x_data], dygraph_out.numpy()
)
self.check_save_inference_model(
layer, [x_data], dygraph_out.numpy(), fetch=[loss]
)
self.check_save_inference_model(
layer, [x_data], dygraph_out.numpy(), feed=[x]
)

paddle.jit.save(
layer=layer,
path=infer_model_prefix,
input_spec=[x],
output_spec=[1] if use_pir_api() else [pred],
)
# Check the correctness of the inference
dygraph_out, _ = layer(x)
self.check_save_inference_model(layer, [x_data], dygraph_out.numpy())
self.check_save_inference_model(
layer,
[x_data],
dygraph_out.numpy(),
fetch=[0] if use_pir_api() else [loss],
)
self.check_save_inference_model(
layer, [x_data], dygraph_out.numpy(), feed=[x]
)

# TODO(MarioLulab): Disable PT test until we support PIR PyLayer
@test_ast_only
@test_legacy_only
@test_legacy_and_pir
def test_save_pylayer_model(self):
fc_size = 20
x_data = np.random.random((fc_size, fc_size)).astype('float32')
Expand Down Expand Up @@ -187,8 +189,12 @@ def check_save_inference_model(
infer_model_dir = os.path.join(
self.temp_dir.name, "test_dy2stat_inference"
)
model_filename = "model" + INFER_MODEL_SUFFIX
if use_pir_api():
model_filename = "model" + PIR_INFER_MODEL_SUFFIX
else:
model_filename = "model" + INFER_MODEL_SUFFIX
params_filename = "model" + INFER_PARAMS_SUFFIX

paddle.jit.save(
layer=model,
path=infer_model_prefix,
Expand Down
15 changes: 6 additions & 9 deletions test/dygraph_to_static/test_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
)

import paddle
from paddle.framework import use_pir_api
from paddle.static import InputSpec

SEED = 2020
Expand Down Expand Up @@ -199,14 +198,12 @@ def test_set_value_with_save(self):
LayerWithSetValue(input_dim=10, hidden=1)
)
x = paddle.full(shape=[5, 10], fill_value=5.0, dtype="float32")
# TODO(pir-save-load): Fix this after we support save/load in PIR
if not use_pir_api():
paddle.jit.save(
layer=model,
path=self.model_path,
input_spec=[x],
output_spec=None,
)
paddle.jit.save(
layer=model,
path=self.model_path,
input_spec=[x],
output_spec=None,
)


class TestSliceSupplementSpecialCase(Dy2StTestBase):
Expand Down
3 changes: 2 additions & 1 deletion test/dygraph_to_static/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import Dict, List, Tuple

import numpy as np
from dygraph_to_static_utils import Dy2StTestBase
from dygraph_to_static_utils import Dy2StTestBase, test_legacy_and_pir

import paddle

Expand Down Expand Up @@ -93,6 +93,7 @@ def run_dy(self):
out, _ = self.net(self.x)
return out

@test_legacy_and_pir
def test_type(self):
self.net = self.build_net()
out = self.run_dy()
Expand Down