diff --git a/python/paddle/framework/io.py b/python/paddle/framework/io.py
index 6d4c9268efee9..6a76ffaa9b7b5 100644
--- a/python/paddle/framework/io.py
+++ b/python/paddle/framework/io.py
@@ -173,7 +173,7 @@ def _load_state_dict_from_save_inference_model(model_path, config):
     # 2. load layer parameters & buffers
     with base.dygraph.guard():
         persistable_var_dict = _construct_params_and_buffers(
-            model_path, programs, config.params_filename, append_suffix=False
+            model_path, programs, config.params_filename
         )
 
     # 3. construct state_dict
diff --git a/python/paddle/jit/pir_translated_layer.py b/python/paddle/jit/pir_translated_layer.py
index a5f91add69d5a..bae21ca6d5d84 100644
--- a/python/paddle/jit/pir_translated_layer.py
+++ b/python/paddle/jit/pir_translated_layer.py
@@ -179,8 +179,12 @@ def _load_pir_persistable_vars(model_path, program_holder, params_filename):
     load_var_list = []
     load_densetensor_list = []
     persistable_var = program_holder.persistable_vars
-    persistable_name = program_holder.persistable_names
-    for name, var in sorted(zip(persistable_name, persistable_var)):
+    persistable_var_name = program_holder.persistable_names
+    origin_persistable_var_name = [
+        program_holder._suffix_varname_dict[var_name]
+        for var_name in persistable_var_name
+    ]
+    for name, var in sorted(zip(origin_persistable_var_name, persistable_var)):
         if var.persistable:
             # use default shape and dtype
             new_var = framework.EagerParamBase(
@@ -225,15 +229,6 @@
     return load_var_dict
 
 
-# NOTE(chenzhiyang): to adapt paddle.load to get state_dict
-def _remove_varname_suffix(var_dict, program_holder):
-    no_suffix_var_dict = {}
-    for var_name in var_dict:
-        no_suffix_name = program_holder._suffix_varname_dict[var_name]
-        no_suffix_var_dict[no_suffix_name] = var_dict[var_name]
-    return no_suffix_var_dict
-
-
 def _construct_program_holders(model_path, model_filename=None):
     # make sure the path has been checked
     program_holder_dict = {}
@@ -283,9 +278,7 @@
     return program_holder_dict
 
 
-def _construct_params_and_buffers(
-    model_path, programs, params_filename=None, append_suffix=True
-):
+def _construct_params_and_buffers(model_path, programs, params_filename=None):
     params_path = os.path.join(model_path, str(params_filename))
 
     if params_filename is not None and not os.path.exists(params_path):
@@ -317,10 +310,7 @@ def _construct_params_and_buffers(
                     )
                 )
 
-    if not append_suffix:
-        var_dict = _remove_varname_suffix(var_dict, programs['forward'])
-
-    return var_dict
+    return var_dict
 
 
 def _run_dygraph(instance, input, program_holder):
@@ -352,8 +342,11 @@ def _run_dygraph(instance, input, program_holder):
         input_tensors.append(tensor)
 
     persistable_tensors = []
-
-    for var_name in program_holder.persistable_names:
+    origin_persistable_var_name = [
+        program_holder._suffix_varname_dict[var_name]
+        for var_name in program_holder.persistable_names
+    ]
+    for var_name in origin_persistable_var_name:
         dy_var_name = instance._persistable_var_name_dict[var_name]
         if dy_var_name in instance._parameters:
             persistable_tensors.append(instance._parameters[dy_var_name])
diff --git a/test/deprecated/legacy_test/test_load_state_dict_from_old_format.py b/test/deprecated/legacy_test/test_load_state_dict_from_old_format.py
index 5a0127276ba47..54ab61e61de88 100644
--- a/test/deprecated/legacy_test/test_load_state_dict_from_old_format.py
+++ b/test/deprecated/legacy_test/test_load_state_dict_from_old_format.py
@@ -217,60 +217,66 @@ def check_load_state_dict(self, orig_dict, load_dict):
 
     @test_with_pir_api
     def test_load_default(self):
-        self.save_dirname = os.path.join(
-            self.temp_dir.name, "static_mnist.load_state_dict.default"
-        )
-        self.model_filename = None
-        self.params_filename = None
-        orig_param_dict = self.train_and_save_model()
+        with paddle.base.unique_name.guard():
+            self.save_dirname = os.path.join(
+                self.temp_dir.name, "static_mnist.load_state_dict.default"
+            )
+            self.model_filename = None
+            self.params_filename = None
+            orig_param_dict = self.train_and_save_model()
 
-        new_load_param_dict = paddle.load(self.save_dirname)
-        self.check_load_state_dict(orig_param_dict, new_load_param_dict)
+            new_load_param_dict = paddle.load(self.save_dirname)
+            self.check_load_state_dict(orig_param_dict, new_load_param_dict)
 
     @test_with_pir_api
     def test_load_with_model_filename(self):
-        self.save_dirname = os.path.join(
-            self.temp_dir.name, "static_mnist.load_state_dict.model_filename"
-        )
-        self.model_filename = "static_mnist.model"
-        self.params_filename = None
-        orig_param_dict = self.train_and_save_model()
+        with paddle.base.unique_name.guard():
+            self.save_dirname = os.path.join(
+                self.temp_dir.name,
+                "static_mnist.load_state_dict.model_filename",
+            )
+            self.model_filename = "static_mnist.model"
+            self.params_filename = None
+            orig_param_dict = self.train_and_save_model()
 
-        new_load_param_dict = paddle.load(
-            self.save_dirname, model_filename=self.model_filename
-        )
-        self.check_load_state_dict(orig_param_dict, new_load_param_dict)
+            new_load_param_dict = paddle.load(
+                self.save_dirname, model_filename=self.model_filename
+            )
+            self.check_load_state_dict(orig_param_dict, new_load_param_dict)
 
     @test_with_pir_api
     def test_load_with_param_filename(self):
-        self.save_dirname = os.path.join(
-            self.temp_dir.name, "static_mnist.load_state_dict.param_filename"
-        )
-        self.model_filename = None
-        self.params_filename = "static_mnist.params"
-        orig_param_dict = self.train_and_save_model()
+        with paddle.base.unique_name.guard():
+            self.save_dirname = os.path.join(
+                self.temp_dir.name,
+                "static_mnist.load_state_dict.param_filename",
+            )
+            self.model_filename = None
+            self.params_filename = "static_mnist.params"
+            orig_param_dict = self.train_and_save_model()
 
-        new_load_param_dict = paddle.load(
-            self.save_dirname, params_filename=self.params_filename
-        )
-        self.check_load_state_dict(orig_param_dict, new_load_param_dict)
+            new_load_param_dict = paddle.load(
+                self.save_dirname, params_filename=self.params_filename
+            )
+            self.check_load_state_dict(orig_param_dict, new_load_param_dict)
 
     @test_with_pir_api
     def test_load_with_model_and_param_filename(self):
-        self.save_dirname = os.path.join(
-            self.temp_dir.name,
-            "static_mnist.load_state_dict.model_and_param_filename",
-        )
-        self.model_filename = "static_mnist.model"
-        self.params_filename = "static_mnist.params"
-        orig_param_dict = self.train_and_save_model()
-
-        new_load_param_dict = paddle.load(
-            self.save_dirname,
-            params_filename=self.params_filename,
-            model_filename=self.model_filename,
-        )
-        self.check_load_state_dict(orig_param_dict, new_load_param_dict)
+        with paddle.base.unique_name.guard():
+            self.save_dirname = os.path.join(
+                self.temp_dir.name,
+                "static_mnist.load_state_dict.model_and_param_filename",
+            )
+            self.model_filename = "static_mnist.model"
+            self.params_filename = "static_mnist.params"
+            orig_param_dict = self.train_and_save_model()
+
+            new_load_param_dict = paddle.load(
+                self.save_dirname,
+                params_filename=self.params_filename,
+                model_filename=self.model_filename,
+            )
+            self.check_load_state_dict(orig_param_dict, new_load_param_dict)
 
 
 if __name__ == '__main__':
diff --git a/test/dygraph_to_static/test_mobile_net.py b/test/dygraph_to_static/test_mobile_net.py
index 14bda81c6bea7..b842d83a72bd5 100644
--- a/test/dygraph_to_static/test_mobile_net.py
+++ b/test/dygraph_to_static/test_mobile_net.py
@@ -30,6 +30,8 @@
 from paddle import base
 from paddle.base.framework import unique_name
 from paddle.base.param_attr import ParamAttr
+from paddle.framework import use_pir_api
+from paddle.jit.pir_translated_layer import PIR_INFER_MODEL_SUFFIX
 from paddle.jit.translated_layer import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
 from paddle.nn import BatchNorm, Linear
 
@@ -592,8 +594,7 @@ def train_mobilenet(args, to_static):
                 batch_id += 1
                 t_last = time.time()
                 if batch_id > args.train_step:
-                    # TODO(@xiongkun): open after save / load supported in pir.
-                    if to_static and not paddle.base.framework.use_pir_api():
+                    if to_static:
                         paddle.jit.save(net, args.model_save_prefix)
                     else:
                         paddle.save(
@@ -609,7 +610,10 @@ def predict_static(args, data):
     paddle.enable_static()
     exe = base.Executor(args.place)
     # load inference model
-
+    if use_pir_api():
+        model_filename = args.pir_model_filename
+    else:
+        model_filename = args.model_filename
     [
         inference_program,
         feed_target_names,
@@ -617,7 +621,7 @@ def predict_static(args, data):
     ] = paddle.static.io.load_inference_model(
         args.model_save_dir,
         executor=exe,
-        model_filename=args.model_filename,
+        model_filename=model_filename,
         params_filename=args.params_filename,
     )
 
@@ -685,6 +689,7 @@ def train(self, model_name, to_static):
         )
         self.args.model_filename = model_name + INFER_MODEL_SUFFIX
         self.args.params_filename = model_name + INFER_PARAMS_SUFFIX
+        self.args.pir_model_filename = model_name + PIR_INFER_MODEL_SUFFIX
         self.args.dy_state_dict_save_path = os.path.join(
             self.temp_dir.name, model_name + ".dygraph"
         )
@@ -717,7 +722,6 @@ def assert_same_predict(self, model_name):
         dy_pre = predict_dygraph(self.args, image)
         st_pre = predict_static(self.args, image)
         dy_jit_pre = predict_dygraph_jit(self.args, image)
-        predictor_pre = predict_analysis_inference(self.args, image)
         np.testing.assert_allclose(
             dy_pre,
             st_pre,
@@ -730,29 +734,27 @@ def assert_same_predict(self, model_name):
             rtol=1e-05,
             err_msg=f'dy_jit_pre:\n {dy_jit_pre}\n, st_pre: \n{st_pre}.',
         )
-        np.testing.assert_allclose(
-            predictor_pre,
-            st_pre,
-            rtol=1e-05,
-            atol=1e-05,
-            err_msg=f'inference_pred_res:\n {predictor_pre}\n, st_pre: \n{st_pre}.',
-        )
+        if not use_pir_api():
+            predictor_pre = predict_analysis_inference(self.args, image)
+            np.testing.assert_allclose(
+                predictor_pre,
+                st_pre,
+                rtol=1e-05,
+                atol=1e-05,
+                err_msg=f'inference_pred_res:\n {predictor_pre}\n, st_pre: \n{st_pre}.',
+            )
 
     @test_legacy_and_pir
-    def test_mobile_net(self):
-        # MobileNet-V1
+    def test_mobile_net_v1(self):
         self.assert_same_loss("MobileNetV1")
-        # MobileNet-V2
-        self.assert_same_loss("MobileNetV2")
-
-        # TODO(@xiongkun): open after save / load supported in pir.
-        if not paddle.base.framework.use_pir_api():
-            self.verify_predict()
-    def verify_predict(self):
-        # MobileNet-V1
         self.assert_same_predict("MobileNetV1")
+
+    @test_legacy_and_pir
+    def test_mobile_net_v2(self):
         # MobileNet-V2
+        self.assert_same_loss("MobileNetV2")
+        self.assert_same_predict("MobileNetV2")
diff --git a/test/dygraph_to_static/test_resnet.py b/test/dygraph_to_static/test_resnet.py
index 665620d1da390..dd5d386b82bc9 100644
--- a/test/dygraph_to_static/test_resnet.py
+++ b/test/dygraph_to_static/test_resnet.py
@@ -29,6 +29,7 @@
 
 import paddle
 from paddle.base import core
+from paddle.framework import use_pir_api
 
 SEED = 2020
 IMAGENET1000 = 1281167
@@ -260,6 +261,9 @@ def __init__(self):
         self.model_filename = (
             "resnet" + paddle.jit.translated_layer.INFER_MODEL_SUFFIX
         )
+        self.pir_model_filename = (
+            "resnet" + paddle.jit.pir_translated_layer.PIR_INFER_MODEL_SUFFIX
+        )
         self.params_filename = (
             "resnet" + paddle.jit.translated_layer.INFER_PARAMS_SUFFIX
         )
@@ -339,12 +343,7 @@ def train(self, to_static, build_strategy=None):
                     )
                 if batch_id == 10:
                     if to_static:
-                        # TODO(@xiongkun): open after save / load supported in pir.
-                        if (
-                            to_static
-                            and not paddle.base.framework.use_pir_api()
-                        ):
-                            paddle.jit.save(resnet, self.model_save_prefix)
+                        paddle.jit.save(resnet, self.model_save_prefix)
                     else:
                         paddle.save(
                             resnet.state_dict(),
@@ -374,6 +373,11 @@ def predict_dygraph(self, data):
     def predict_static(self, data):
         with static_guard():
             exe = paddle.static.Executor(place)
+            if use_pir_api():
+                model_filename = self.pir_model_filename
+            else:
+                model_filename = self.model_filename
+
             [
                 inference_program,
                 feed_target_names,
@@ -381,7 +385,7 @@ def predict_static(self, data):
             ] = paddle.static.load_inference_model(
                 self.model_save_dir,
                 executor=exe,
-                model_filename=self.model_filename,
+                model_filename=model_filename,
                 params_filename=self.params_filename,
             )
 
@@ -426,7 +430,6 @@ def verify_predict(self):
         dy_pre = self.resnet_helper.predict_dygraph(image)
         st_pre = self.resnet_helper.predict_static(image)
         dy_jit_pre = self.resnet_helper.predict_dygraph_jit(image)
-        predictor_pre = self.resnet_helper.predict_analysis_inference(image)
         np.testing.assert_allclose(
             dy_pre,
             st_pre,
@@ -439,12 +442,14 @@ def verify_predict(self):
             rtol=1e-05,
             err_msg=f'dy_jit_pre:\n {dy_jit_pre}\n, st_pre: \n{st_pre}.',
         )
-        np.testing.assert_allclose(
-            predictor_pre,
-            st_pre,
-            rtol=1e-05,
-            err_msg=f'predictor_pre:\n {predictor_pre}\n, st_pre: \n{st_pre}.',
-        )
+        if not use_pir_api():
+            predictor_pre = self.resnet_helper.predict_analysis_inference(image)
+            np.testing.assert_allclose(
+                predictor_pre,
+                st_pre,
+                rtol=1e-05,
+                err_msg=f'predictor_pre:\n {predictor_pre}\n, st_pre: \n{st_pre}.',
+            )
 
     @test_default_and_pir
     def test_resnet(self):
@@ -456,9 +461,7 @@ def test_resnet(self):
             rtol=1e-05,
             err_msg=f'static_loss: {static_loss} \n dygraph_loss: {dygraph_loss}',
         )
-        # TODO(@xiongkun): open after save / load supported in pir.
-        if not paddle.framework.use_pir_api():
-            self.verify_predict()
+        self.verify_predict()
 
     @test_default_and_pir
     def test_resnet_composite(self):