Skip to content

Commit

Permalink
【Fix PIR JIT SaveLoad Unittest No.2-6,8-10,13,15】open many tests for …
Browse files Browse the repository at this point in the history
…pir jit save/load (#64400)

* modify test_modbile_net.py

* delete print

* modify ci

* add test_custom_relu_model.py

* add test_custom_relu_model.py

* modify se_resnet

* add change

* add test_input_sepc

* add pir test

* recover
  • Loading branch information
xiaoguoguo626807 committed May 20, 2024
1 parent 3e275d5 commit b167952
Show file tree
Hide file tree
Showing 9 changed files with 64 additions and 75 deletions.
3 changes: 2 additions & 1 deletion test/deprecated/legacy_test/test_ops_nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from test_nms_op import nms

import paddle
from paddle.pir_utils import test_with_pir_api
from paddle.pir_utils import test_with_dygraph_pir, test_with_pir_api


def _find(condition):
Expand Down Expand Up @@ -197,6 +197,7 @@ def test_multiclass_nms_static(self):
err_msg=f'paddle out: {out}\n py out: {out_py}\n',
)

@test_with_dygraph_pir
def test_multiclass_nms_dynamic_to_static(self):
for device in self.devices:
for dtype in self.dtypes:
Expand Down
48 changes: 22 additions & 26 deletions test/dygraph_to_static/test_declarative.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,29 +148,27 @@ def test_with_input_spec(self):

# 2. test save load
net.inner_function(x)
# TODO(pir-save-load): Fix this after we support save/load in PIR
if not use_pir_api():
paddle.jit.save(net, self.model_path)
infer_net = paddle.jit.load(self.model_path)
pred = infer_net(x)
np.testing.assert_allclose(out.numpy(), pred.numpy(), rtol=1e-05)

# 3. we can decorate any method
x_2 = paddle.to_tensor(np.ones([4, 20]).astype('float32'))
# uses `to_static(func)` instead of `@to_static`
net.add_func = paddle.jit.to_static(net.add_func)
out = net.add_func(x_2, np.ones([20]).astype('float32'))
self.assertTrue(len(net.add_func.program_cache) == 1)

# 5. test input with list
out = net.func_with_list([x, y], int_val)

# 6. test input with dict
out = net.func_with_dict({'x': x, 'y': y})

# 7. test input with lits contains dict
int_np = np.ones([1]).astype('float32')
out = net.func_with_list_dict([int_np, {'x': x, 'y': y}])
paddle.jit.save(net, self.model_path)
infer_net = paddle.jit.load(self.model_path)
pred = infer_net(x)
np.testing.assert_allclose(out.numpy(), pred.numpy(), rtol=1e-05)

# 3. we can decorate any method
x_2 = paddle.to_tensor(np.ones([4, 20]).astype('float32'))
# uses `to_static(func)` instead of `@to_static`
net.add_func = paddle.jit.to_static(net.add_func)
out = net.add_func(x_2, np.ones([20]).astype('float32'))
self.assertTrue(len(net.add_func.program_cache) == 1)

# 5. test input with list
out = net.func_with_list([x, y], int_val)

# 6. test input with dict
out = net.func_with_dict({'x': x, 'y': y})

# 7. test input with lits contains dict
int_np = np.ones([1]).astype('float32')
out = net.func_with_list_dict([int_np, {'x': x, 'y': y}])

@test_legacy_and_pt_and_pir
def test_with_error(self):
Expand Down Expand Up @@ -504,9 +502,7 @@ def test_set_buffers1(self):
net = paddle.jit.to_static(SetBuffersNet1())
out = net()
self.assertEqual(out.numpy().tolist(), [2])
# TODO(pir-save-load): Fix this after we support save/load in PIR
if not use_pir_api():
paddle.jit.save(net, self.model_path)
paddle.jit.save(net, self.model_path)

@test_ast_only
def test_set_buffers2(self):
Expand Down
4 changes: 2 additions & 2 deletions test/dygraph_to_static/test_for_enumerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from dygraph_to_static_utils import (
Dy2StTestBase,
enable_to_static_guard,
test_legacy_and_pir,
test_legacy_and_pt_and_pir,
)

Expand Down Expand Up @@ -554,7 +555,6 @@ def tearDown(self):

@test_legacy_and_pt_and_pir
def test_for_zip_error(self):
# TODO(pir-save-load): enable PIR test after support PIR save load
with self.assertRaises(RuntimeError):
model_path = os.path.join(self.temp_dir.name, 'for_zip_error')
paddle.jit.save(
Expand All @@ -568,8 +568,8 @@ def test_for_zip_error(self):
model_path,
)

@test_legacy_and_pir
def test_for_zip(self):
# TODO(pir-save-load): enable PIR test after support PIR save load
model_path = os.path.join(self.temp_dir.name, 'for_zip')
paddle.jit.save(
paddle.jit.to_static(
Expand Down
7 changes: 0 additions & 7 deletions test/dygraph_to_static/test_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
)

import paddle
from paddle.framework import use_pir_api


class GradLayer(paddle.nn.Layer):
Expand Down Expand Up @@ -101,9 +100,6 @@ def tearDown(self):

@test_legacy_and_pt_and_pir
def test_save_infer_program(self):
# TODO(pir-save-load): Fix this after we support save/load in PIR
if use_pir_api():
return
static_fn = paddle.jit.to_static(self.func)
input_spec = [
paddle.static.InputSpec(shape=[10, 2, 5], dtype='float32')
Expand Down Expand Up @@ -132,9 +128,6 @@ def test_save_train_program(self):

static_fn.clear_gradients()

# TODO(pir-save-load): Fix this after we support save/load in PIR
if use_pir_api():
return
paddle.jit.save(static_fn, self.train_model_path)
load_func = paddle.jit.load(self.train_model_path)

Expand Down
4 changes: 1 addition & 3 deletions test/dygraph_to_static/test_layer_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ def train_net(self, to_static=False):
net = paddle.jit.to_static(net)
out = net(self.x)

# TODO(xiongkun) save / load unitest.
if to_static and not paddle.base.framework.use_pir_api():
paddle.jit.save(net, self.path)
paddle.jit.save(net, self.path, input_spec=[self.x])

return float(out)

Expand Down
5 changes: 1 addition & 4 deletions test/dygraph_to_static/test_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

import paddle
import paddle.nn.functional as F
from paddle.base.framework import use_pir_api
from paddle.jit.dy2static.transformers.loop_transformer import NameVisitor
from paddle.utils import gast

Expand Down Expand Up @@ -466,9 +465,7 @@ def test_start(self):
],
)
temp_dir = tempfile.TemporaryDirectory()
# TODO(pir-save-load): Fix this after we support save/load in PIR
if not use_pir_api():
paddle.jit.save(model, temp_dir.name)
paddle.jit.save(model, temp_dir.name)
temp_dir.cleanup()


Expand Down
50 changes: 28 additions & 22 deletions test/dygraph_to_static/test_save_inference_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from dygraph_to_static_utils import (
Dy2StTestBase,
test_ast_only,
test_legacy_and_pir,
test_legacy_and_pt_and_pir,
test_legacy_only,
)

import paddle
Expand All @@ -32,6 +32,7 @@
from paddle.jit.dy2static.pir_partial_program import (
partial_program_from as pir_partial_program_from,
)
from paddle.jit.pir_translated_layer import PIR_INFER_MODEL_SUFFIX
from paddle.jit.translated_layer import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX

SEED = 2020
Expand Down Expand Up @@ -89,6 +90,7 @@ def tearDown(self):
self.temp_dir.cleanup()

@test_ast_only
@test_legacy_and_pir
def test_save_inference_model(self):
fc_size = 20
x_data = np.random.random((fc_size, fc_size)).astype('float32')
Expand All @@ -112,29 +114,29 @@ def test_save_inference_model(self):
infer_model_dir = os.path.join(
self.temp_dir.name, "test_dy2stat_inference_in_guard"
)
# TODO(pir-save-load): Fix this after we support save/load in PIR
if not use_pir_api():
paddle.jit.save(
layer=layer,
path=infer_model_prefix,
input_spec=[x],
output_spec=[pred],
)
# Check the correctness of the inference
dygraph_out, _ = layer(x)
self.check_save_inference_model(
layer, [x_data], dygraph_out.numpy()
)
self.check_save_inference_model(
layer, [x_data], dygraph_out.numpy(), fetch=[loss]
)
self.check_save_inference_model(
layer, [x_data], dygraph_out.numpy(), feed=[x]
)

paddle.jit.save(
layer=layer,
path=infer_model_prefix,
input_spec=[x],
output_spec=[1] if use_pir_api() else [pred],
)
# Check the correctness of the inference
dygraph_out, _ = layer(x)
self.check_save_inference_model(layer, [x_data], dygraph_out.numpy())
self.check_save_inference_model(
layer,
[x_data],
dygraph_out.numpy(),
fetch=[0] if use_pir_api() else [loss],
)
self.check_save_inference_model(
layer, [x_data], dygraph_out.numpy(), feed=[x]
)

# TODO(MarioLulab): Disable PT test until we support PIR PyLayer
@test_ast_only
@test_legacy_only
@test_legacy_and_pir
def test_save_pylayer_model(self):
fc_size = 20
x_data = np.random.random((fc_size, fc_size)).astype('float32')
Expand Down Expand Up @@ -187,8 +189,12 @@ def check_save_inference_model(
infer_model_dir = os.path.join(
self.temp_dir.name, "test_dy2stat_inference"
)
model_filename = "model" + INFER_MODEL_SUFFIX
if use_pir_api():
model_filename = "model" + PIR_INFER_MODEL_SUFFIX
else:
model_filename = "model" + INFER_MODEL_SUFFIX
params_filename = "model" + INFER_PARAMS_SUFFIX

paddle.jit.save(
layer=model,
path=infer_model_prefix,
Expand Down
15 changes: 6 additions & 9 deletions test/dygraph_to_static/test_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
)

import paddle
from paddle.framework import use_pir_api
from paddle.static import InputSpec

SEED = 2020
Expand Down Expand Up @@ -199,14 +198,12 @@ def test_set_value_with_save(self):
LayerWithSetValue(input_dim=10, hidden=1)
)
x = paddle.full(shape=[5, 10], fill_value=5.0, dtype="float32")
# TODO(pir-save-load): Fix this after we support save/load in PIR
if not use_pir_api():
paddle.jit.save(
layer=model,
path=self.model_path,
input_spec=[x],
output_spec=None,
)
paddle.jit.save(
layer=model,
path=self.model_path,
input_spec=[x],
output_spec=None,
)


class TestSliceSupplementSpecialCase(Dy2StTestBase):
Expand Down
3 changes: 2 additions & 1 deletion test/dygraph_to_static/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import Dict, List, Tuple

import numpy as np
from dygraph_to_static_utils import Dy2StTestBase
from dygraph_to_static_utils import Dy2StTestBase, test_legacy_and_pir

import paddle

Expand Down Expand Up @@ -93,6 +93,7 @@ def run_dy(self):
out, _ = self.net(self.x)
return out

@test_legacy_and_pir
def test_type(self):
self.net = self.build_net()
out = self.run_dy()
Expand Down

0 comments on commit b167952

Please sign in to comment.