【Fix PIR JIT SaveLoad Unittest No.20-21】modify test_mobile_net.py, test_resnet.py #64315

Merged (7 commits, May 16, 2024)
2 changes: 1 addition & 1 deletion python/paddle/framework/io.py
@@ -173,7 +173,7 @@ def _load_state_dict_from_save_inference_model(model_path, config):
     # 2. load layer parameters & buffers
     with base.dygraph.guard():
         persistable_var_dict = _construct_params_and_buffers(
-            model_path, programs, config.params_filename, append_suffix=False
+            model_path, programs, config.params_filename
         )

     # 3. construct state_dict
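This helper sits behind paddle.load when it is pointed at a directory written by the old save_inference_model flow (the NOTE on the helper removed from pir_translated_layer.py below says as much), so the name-suffix handling now lives entirely in the translated-layer code. A rough usage sketch, with a hypothetical directory name:

    import paddle

    # Hypothetical directory produced earlier by the legacy save_inference_model
    # format; model_filename and params_filename keep their defaults, as in the tests.
    state_dict = paddle.load("static_mnist.load_state_dict.default")
    for name, value in state_dict.items():
        print(name, value.shape)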
33 changes: 13 additions & 20 deletions python/paddle/jit/pir_translated_layer.py
@@ -179,8 +179,12 @@ def _load_pir_persistable_vars(model_path, program_holder, params_filename):
     load_var_list = []
     load_densetensor_list = []
     persistable_var = program_holder.persistable_vars
-    persistable_name = program_holder.persistable_names
-    for name, var in sorted(zip(persistable_name, persistable_var)):
+    persistable_var_name = program_holder.persistable_names
+    origin_persistable_var_name = [
+        program_holder._suffix_varname_dict[var_name]
+        for var_name in persistable_var_name
+    ]
+    for name, var in sorted(zip(origin_persistable_var_name, persistable_var)):
         if var.persistable:
             # use default shape and dtype
             new_var = framework.EagerParamBase(
@@ -225,15 +229,6 @@ def _load_pir_persistable_vars(model_path, program_holder, params_filename):
     return load_var_dict


-# NOTE(chenzhiyang): to adapt paddle.load to get state_dict
-def _remove_varname_suffix(var_dict, program_holder):
-    no_suffix_var_dict = {}
-    for var_name in var_dict:
-        no_suffix_name = program_holder._suffix_varname_dict[var_name]
-        no_suffix_var_dict[no_suffix_name] = var_dict[var_name]
-    return no_suffix_var_dict
-
-
 def _construct_program_holders(model_path, model_filename=None):
     # make sure the path has been checked
     program_holder_dict = {}
@@ -283,9 +278,7 @@ def _construct_program_holders(model_path, model_filename=None):
     return program_holder_dict


-def _construct_params_and_buffers(
-    model_path, programs, params_filename=None, append_suffix=True
-):
+def _construct_params_and_buffers(model_path, programs, params_filename=None):
     params_path = os.path.join(model_path, str(params_filename))

     if params_filename is not None and not os.path.exists(params_path):
@@ -317,10 +310,7 @@ def _construct_params_and_buffers(
                 )
             )

-    if not append_suffix:
-        var_dict = _remove_varname_suffix(var_dict, programs['forward'])
-
-    return var_dict
+    return var_dict


 def _run_dygraph(instance, input, program_holder):
@@ -352,8 +342,11 @@ def _run_dygraph(instance, input, program_holder):
         input_tensors.append(tensor)

     persistable_tensors = []
-
-    for var_name in program_holder.persistable_names:
+    origin_persistable_var_name = [
+        program_holder._suffix_varname_dict[var_name]
+        for var_name in program_holder.persistable_names
+    ]
+    for var_name in origin_persistable_var_name:
         dy_var_name = instance._persistable_var_name_dict[var_name]
         if dy_var_name in instance._parameters:
             persistable_tensors.append(instance._parameters[dy_var_name])
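The deleted _remove_varname_suffix helper and the new inline list comprehensions express the same renaming: _suffix_varname_dict maps the suffixed names stored with the PIR program back to the original parameter names, and the translation now happens before the variables are materialized instead of afterwards. A small self-contained sketch with made-up names:

    # Made-up stand-ins for program_holder._suffix_varname_dict and
    # program_holder.persistable_names.
    suffix_varname_dict = {"linear_0.w_0_0": "linear_0.w_0", "linear_0.b_0_0": "linear_0.b_0"}
    persistable_names = ["linear_0.w_0_0", "linear_0.b_0_0"]

    # Old style: collect values under suffixed keys, then strip the suffix afterwards.
    def remove_varname_suffix(var_dict):
        return {suffix_varname_dict[name]: value for name, value in var_dict.items()}

    # New style: translate the names up front, as the list comprehension now does.
    origin_persistable_var_name = [
        suffix_varname_dict[name] for name in persistable_names
    ]
    print(origin_persistable_var_name)  # ['linear_0.w_0', 'linear_0.b_0']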
90 changes: 48 additions & 42 deletions test/deprecated/legacy_test/test_load_state_dict_from_old_format.py
@@ -217,60 +217,66 @@ def check_load_state_dict(self, orig_dict, load_dict):

     @test_with_pir_api
     def test_load_default(self):
-        self.save_dirname = os.path.join(
-            self.temp_dir.name, "static_mnist.load_state_dict.default"
-        )
-        self.model_filename = None
-        self.params_filename = None
-        orig_param_dict = self.train_and_save_model()
+        with paddle.base.unique_name.guard():
+            self.save_dirname = os.path.join(
+                self.temp_dir.name, "static_mnist.load_state_dict.default"
+            )
+            self.model_filename = None
+            self.params_filename = None
+            orig_param_dict = self.train_and_save_model()

-        new_load_param_dict = paddle.load(self.save_dirname)
-        self.check_load_state_dict(orig_param_dict, new_load_param_dict)
+            new_load_param_dict = paddle.load(self.save_dirname)
+            self.check_load_state_dict(orig_param_dict, new_load_param_dict)

     @test_with_pir_api
     def test_load_with_model_filename(self):
-        self.save_dirname = os.path.join(
-            self.temp_dir.name, "static_mnist.load_state_dict.model_filename"
-        )
-        self.model_filename = "static_mnist.model"
-        self.params_filename = None
-        orig_param_dict = self.train_and_save_model()
+        with paddle.base.unique_name.guard():
+            self.save_dirname = os.path.join(
+                self.temp_dir.name,
+                "static_mnist.load_state_dict.model_filename",
+            )
+            self.model_filename = "static_mnist.model"
+            self.params_filename = None
+            orig_param_dict = self.train_and_save_model()

-        new_load_param_dict = paddle.load(
-            self.save_dirname, model_filename=self.model_filename
-        )
-        self.check_load_state_dict(orig_param_dict, new_load_param_dict)
+            new_load_param_dict = paddle.load(
+                self.save_dirname, model_filename=self.model_filename
+            )
+            self.check_load_state_dict(orig_param_dict, new_load_param_dict)

     @test_with_pir_api
     def test_load_with_param_filename(self):
-        self.save_dirname = os.path.join(
-            self.temp_dir.name, "static_mnist.load_state_dict.param_filename"
-        )
-        self.model_filename = None
-        self.params_filename = "static_mnist.params"
-        orig_param_dict = self.train_and_save_model()
+        with paddle.base.unique_name.guard():
+            self.save_dirname = os.path.join(
+                self.temp_dir.name,
+                "static_mnist.load_state_dict.param_filename",
+            )
+            self.model_filename = None
+            self.params_filename = "static_mnist.params"
+            orig_param_dict = self.train_and_save_model()

-        new_load_param_dict = paddle.load(
-            self.save_dirname, params_filename=self.params_filename
-        )
-        self.check_load_state_dict(orig_param_dict, new_load_param_dict)
+            new_load_param_dict = paddle.load(
+                self.save_dirname, params_filename=self.params_filename
+            )
+            self.check_load_state_dict(orig_param_dict, new_load_param_dict)

     @test_with_pir_api
     def test_load_with_model_and_param_filename(self):
-        self.save_dirname = os.path.join(
-            self.temp_dir.name,
-            "static_mnist.load_state_dict.model_and_param_filename",
-        )
-        self.model_filename = "static_mnist.model"
-        self.params_filename = "static_mnist.params"
-        orig_param_dict = self.train_and_save_model()
-
-        new_load_param_dict = paddle.load(
-            self.save_dirname,
-            params_filename=self.params_filename,
-            model_filename=self.model_filename,
-        )
-        self.check_load_state_dict(orig_param_dict, new_load_param_dict)
+        with paddle.base.unique_name.guard():
+            self.save_dirname = os.path.join(
+                self.temp_dir.name,
+                "static_mnist.load_state_dict.model_and_param_filename",
+            )
+            self.model_filename = "static_mnist.model"
+            self.params_filename = "static_mnist.params"
+            orig_param_dict = self.train_and_save_model()
+
+            new_load_param_dict = paddle.load(
+                self.save_dirname,
+                params_filename=self.params_filename,
+                model_filename=self.model_filename,
+            )
+            self.check_load_state_dict(orig_param_dict, new_load_param_dict)


 if __name__ == '__main__':
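Each test body is now wrapped in paddle.base.unique_name.guard(), the context manager that resets Paddle's auto-generated name counters, which keeps the parameter names produced while rebuilding the network independent of whatever earlier tests created. A tiny sketch of the behaviour the guard provides (it is also exposed as paddle.utils.unique_name.guard):

    import paddle

    with paddle.utils.unique_name.guard():
        print(paddle.utils.unique_name.generate("fc"))  # fc_0
        print(paddle.utils.unique_name.generate("fc"))  # fc_1

    with paddle.utils.unique_name.guard():
        # The counter starts over inside a fresh guard.
        print(paddle.utils.unique_name.generate("fc"))  # fc_0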
46 changes: 24 additions & 22 deletions test/dygraph_to_static/test_mobile_net.py
@@ -30,6 +30,8 @@
 from paddle import base
 from paddle.base.framework import unique_name
 from paddle.base.param_attr import ParamAttr
+from paddle.framework import use_pir_api
+from paddle.jit.pir_translated_layer import PIR_INFER_MODEL_SUFFIX
 from paddle.jit.translated_layer import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
 from paddle.nn import BatchNorm, Linear

@@ -592,8 +594,7 @@ def train_mobilenet(args, to_static):
                 batch_id += 1
                 t_last = time.time()
                 if batch_id > args.train_step:
-                    # TODO(@xiongkun): open after save / load supported in pir.
-                    if to_static and not paddle.base.framework.use_pir_api():
+                    if to_static:
                         paddle.jit.save(net, args.model_save_prefix)
                     else:
                         paddle.save(
@@ -609,15 +610,18 @@ def predict_static(args, data):
     paddle.enable_static()
     exe = base.Executor(args.place)
     # load inference model

+    if use_pir_api():
+        model_filename = args.pir_model_filename
+    else:
+        model_filename = args.model_filename
     [
         inference_program,
         feed_target_names,
         fetch_targets,
     ] = paddle.static.io.load_inference_model(
         args.model_save_dir,
         executor=exe,
-        model_filename=args.model_filename,
+        model_filename=model_filename,
         params_filename=args.params_filename,
     )

@@ -685,6 +689,7 @@ def train(self, model_name, to_static):
         )
         self.args.model_filename = model_name + INFER_MODEL_SUFFIX
         self.args.params_filename = model_name + INFER_PARAMS_SUFFIX
+        self.args.pir_model_filename = model_name + PIR_INFER_MODEL_SUFFIX
         self.args.dy_state_dict_save_path = os.path.join(
             self.temp_dir.name, model_name + ".dygraph"
         )
@@ -717,7 +722,6 @@ def assert_same_predict(self, model_name):
         dy_pre = predict_dygraph(self.args, image)
         st_pre = predict_static(self.args, image)
         dy_jit_pre = predict_dygraph_jit(self.args, image)
-        predictor_pre = predict_analysis_inference(self.args, image)
         np.testing.assert_allclose(
             dy_pre,
             st_pre,
@@ -730,29 +734,27 @@ def assert_same_predict(self, model_name):
             rtol=1e-05,
             err_msg=f'dy_jit_pre:\n {dy_jit_pre}\n, st_pre: \n{st_pre}.',
         )
-        np.testing.assert_allclose(
-            predictor_pre,
-            st_pre,
-            rtol=1e-05,
-            atol=1e-05,
-            err_msg=f'inference_pred_res:\n {predictor_pre}\n, st_pre: \n{st_pre}.',
-        )
+        if not use_pir_api():
+            predictor_pre = predict_analysis_inference(self.args, image)
+            np.testing.assert_allclose(
+                predictor_pre,
+                st_pre,
+                rtol=1e-05,
+                atol=1e-05,
+                err_msg=f'inference_pred_res:\n {predictor_pre}\n, st_pre: \n{st_pre}.',
+            )

     @test_legacy_and_pir
-    def test_mobile_net(self):
-        # MobileNet-V1
+    def test_mobile_net_v1(self):
         self.assert_same_loss("MobileNetV1")
-        # MobileNet-V2
-        self.assert_same_loss("MobileNetV2")
-
-        # TODO(@xiongkun): open after save / load supported in pir.
-        if not paddle.base.framework.use_pir_api():
-            self.verify_predict()

-    def verify_predict(self):
-        # MobileNet-V1
         self.assert_same_predict("MobileNetV1")
+
+    @test_legacy_and_pir
+    def test_mobile_net_v2(self):
         # MobileNet-V2
-        self.assert_same_predict("MobileNetV2")
+        self.assert_same_loss("MobileNetV2")
+
+        self.assert_same_predict("MobileNetV2")

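Because a PIR-saved inference model uses a different file suffix than the legacy format, predict_static now picks the model filename at run time; the same pattern appears in test_resnet.py below. A minimal sketch of that selection, using the suffix constants the test already imports (the prefix is hypothetical):

    from paddle.framework import use_pir_api
    from paddle.jit.pir_translated_layer import PIR_INFER_MODEL_SUFFIX
    from paddle.jit.translated_layer import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX

    prefix = "mobilenet_v1"  # hypothetical save prefix
    model_filename = prefix + (
        PIR_INFER_MODEL_SUFFIX if use_pir_api() else INFER_MODEL_SUFFIX
    )
    params_filename = prefix + INFER_PARAMS_SUFFIX
    # These names are then handed to paddle.static.io.load_inference_model().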
37 changes: 20 additions & 17 deletions test/dygraph_to_static/test_resnet.py
@@ -29,6 +29,7 @@

 import paddle
 from paddle.base import core
+from paddle.framework import use_pir_api

 SEED = 2020
 IMAGENET1000 = 1281167
@@ -260,6 +261,9 @@ def __init__(self):
         self.model_filename = (
             "resnet" + paddle.jit.translated_layer.INFER_MODEL_SUFFIX
         )
+        self.pir_model_filename = (
+            "resnet" + paddle.jit.pir_translated_layer.PIR_INFER_MODEL_SUFFIX
+        )
         self.params_filename = (
             "resnet" + paddle.jit.translated_layer.INFER_PARAMS_SUFFIX
         )
@@ -339,12 +343,7 @@ def train(self, to_static, build_strategy=None):
                         )
                     if batch_id == 10:
                         if to_static:
-                            # TODO(@xiongkun): open after save / load supported in pir.
-                            if (
-                                to_static
-                                and not paddle.base.framework.use_pir_api()
-                            ):
-                                paddle.jit.save(resnet, self.model_save_prefix)
+                            paddle.jit.save(resnet, self.model_save_prefix)
                         else:
                             paddle.save(
                                 resnet.state_dict(),
@@ -374,14 +373,19 @@ def predict_dygraph(self, data):
     def predict_static(self, data):
         with static_guard():
             exe = paddle.static.Executor(place)
+            if use_pir_api():
+                model_filename = self.pir_model_filename
+            else:
+                model_filename = self.model_filename
+
             [
                 inference_program,
                 feed_target_names,
                 fetch_targets,
             ] = paddle.static.load_inference_model(
                 self.model_save_dir,
                 executor=exe,
-                model_filename=self.model_filename,
+                model_filename=model_filename,
                 params_filename=self.params_filename,
             )

@@ -426,7 +430,6 @@ def verify_predict(self):
         dy_pre = self.resnet_helper.predict_dygraph(image)
         st_pre = self.resnet_helper.predict_static(image)
         dy_jit_pre = self.resnet_helper.predict_dygraph_jit(image)
-        predictor_pre = self.resnet_helper.predict_analysis_inference(image)
         np.testing.assert_allclose(
             dy_pre,
             st_pre,
@@ -439,12 +442,14 @@ def verify_predict(self):
             rtol=1e-05,
             err_msg=f'dy_jit_pre:\n {dy_jit_pre}\n, st_pre: \n{st_pre}.',
         )
-        np.testing.assert_allclose(
-            predictor_pre,
-            st_pre,
-            rtol=1e-05,
-            err_msg=f'predictor_pre:\n {predictor_pre}\n, st_pre: \n{st_pre}.',
-        )
+        if not use_pir_api():
+            predictor_pre = self.resnet_helper.predict_analysis_inference(image)
+            np.testing.assert_allclose(
+                predictor_pre,
+                st_pre,
+                rtol=1e-05,
+                err_msg=f'predictor_pre:\n {predictor_pre}\n, st_pre: \n{st_pre}.',
+            )

     @test_default_and_pir
     def test_resnet(self):
@@ -456,9 +461,7 @@ def test_resnet(self):
             rtol=1e-05,
             err_msg=f'static_loss: {static_loss} \n dygraph_loss: {dygraph_loss}',
         )
-        # TODO(@xiongkun): open after save / load supported in pir.
-        if not paddle.framework.use_pir_api():
-            self.verify_predict()
+        self.verify_predict()

     @test_default_and_pir
     def test_resnet_composite(self):
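Taken together, the tests now exercise the full save-and-load round trip under PIR as well as under the legacy IR. A self-contained sketch of that round trip with a made-up tiny network (the layer and paths are hypothetical, not the ResNet or MobileNet from the tests):

    import numpy as np
    import paddle

    class TinyNet(paddle.nn.Layer):
        def __init__(self):
            super().__init__()
            self.fc = paddle.nn.Linear(4, 2)

        def forward(self, x):
            return self.fc(x)

    # Save in dygraph mode; input_spec lets paddle.jit.save trace the forward.
    net = TinyNet()
    net.eval()
    x_spec = paddle.static.InputSpec(shape=[None, 4], dtype='float32')
    paddle.jit.save(net, "./tiny_net/infer", input_spec=[x_spec])

    # Load and run in static mode, as predict_static does in both test files.
    paddle.enable_static()
    exe = paddle.static.Executor(paddle.CPUPlace())
    [program, feed_names, fetch_targets] = paddle.static.load_inference_model(
        "./tiny_net/infer", exe
    )
    out = exe.run(
        program,
        feed={feed_names[0]: np.ones([1, 4], dtype='float32')},
        fetch_list=fetch_targets,
    )
    print(out[0].shape)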