
Revise python save load api using new load/save op #7995

Merged: 10 commits, Feb 1, 2018

Conversation


@kexinzhao (Contributor) commented Jan 31, 2018:

Fixes #7959.

The op's info is not well synchronized on the Python side, so I have to use the OpDesc info instead. input_arg_names() is defined in protobuf.cc as a binding to the InputArguments() method of the C++ OpDesc class.
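For context, a minimal sketch of how input_arg_names() might be used on the Python side to collect every variable name referenced as an op input (the helper name and traversal are illustrative assumptions, not the exact PR code):

def collect_op_input_names(program):
    # Walk every op in every block and gather its input argument names;
    # input_arg_names() is the pybind wrapper of the C++ method noted above.
    input_args = set()
    for block in program.blocks:
        for op in block.ops:
            input_args.update(op.desc.input_arg_names())
    return input_args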

@kexinzhao added the 预测 (Prediction; formerly "Inference", includes C-API prediction issues, etc.) label on Jan 31, 2018
@kexinzhao added this to Documentation in Inference Framework on Jan 31, 2018
@kexinzhao moved this from Documentation to DOING in Inference Framework on Jan 31, 2018
        main_program=None,
        vars=None,
        predicate=None,
        save_file_name='__parameters__'):
Contributor:
I suggest setting the default value of save_file_name to None. If it is None, all variables will be saved into separate files, as before. I am not sure whether it is suitable to change the storage format of training results, so maybe it is better to support both formats?
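A hedged sketch of the suggested default (parameter names taken from the diff above; the branch bodies are placeholders, not the PR's implementation):

def save_vars(executor, dirname, main_program=None, vars=None,
              predicate=None, save_file_name=None):
    if save_file_name is None:
        # old format: save each variable into its own file, as before
        pass
    else:
        # new format: save all variables into the single named file
        pass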

Contributor (Author):

I agree. Will do.

Contributor (Author):

Done

        main_program=None,
        vars=None,
        predicate=None,
        load_file_name='__parameters__'):
Contributor:

The same as for save_vars.

Contributor (Author):

Done

        vars=None,
        predicate=_is_presistable_and_exist_)
    parameter_list = get_parameters(inference_program)
    save_vars(executor, dirname, inference_program, parameter_list)
Contributor:
#7874 has been merged. Directly changing save_inference_model will fail the CI on the develop branch. Please update from the develop branch first.

Contributor (Author):
Got it. Thanks!


for var in program.list_vars():
    if is_persistable(var) and var.name in input_args:
        parameter_list.append(var)
Contributor:
It is actually not a parameter list but a persistable-variable list. Normally, the program should not contain unreferenced variables, so the check var.name in input_args should be removed. When loading, if a persistable variable is absent, there should be an error message.
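A minimal sketch of the suggested simplification (illustrative; program and is_persistable are the names used in the surrounding diff):

# Collect every persistable variable without filtering by input_args;
# a missing variable would then surface as a load-time error.
persistable_vars = [var for var in program.list_vars() if is_persistable(var)]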

Contributor (Author) @kexinzhao, Jan 31, 2018:

Thanks for the explanation. Then I guess we don't need to define this function. There is already a save_persistables method, so I will use that one instead.

Contributor (Author):

But I think we still need to exclude the 'feed' and 'fetch' variables, right? (They have been added to the program desc.) They are also persistable, and we don't want to store them.

Contributor (Author):

I tried removing var.name in input_args; the 'feed' and 'fetch' variables are then also included, which triggers the error "The type of var fetch is unsupported", since the type of feed/fetch is vector<LoDTensor>, which is not supported by the load/save ops.

Contributor:

I think this implementation can potentially solve the problem described in PR #8020.

Contributor:

I will try including these changes to verify.

Contributor:

> But I think we still need to exclude the 'feed' and 'fetch' variables, right?

@kexinzhao Can we change the function is_persistable() or rewrite load/save_persistables() to exclude feed and fetch variables?

> I think this implementation can potentially solve the problem described in PR #8020

@sidgoyal78 I think the problem in #8020 is that the means of batch norm are not parameters but persistable variables. In fact, we should save all persistable variables in save_inference_model, not only the parameters. I have thought about this for a long time, and there are some related issues: #7931 #7163

Contributor:

Yeah. 👍

Contributor (Author) @kexinzhao, Feb 1, 2018:

> Can we change the function is_persistable() or rewrite load/save_persistables() to exclude feed and fetch variables?

I can think of three ways to redefine load/save_persistables():

1. One is like this:

def load_persistables(xxx):
    parameter_list = get_parameters(program)
    load_vars(xxx, parameter_list)

which basically moves the use of get_parameters inside save/load_persistables. We could also get rid of the get_parameters method and move its code directly into save/load_persistables, but we prefer not to, because that would duplicate the code in two functions.

2. If we don't want to reuse the code of the get_parameters() method, then basically we want to modify the is_persistable(var) predicate function so that it excludes the 'feed' and 'fetch' vars. Note that since this is a predicate, we only have the variable as input. Although, as shown in framework.py, class Variable has a data member op, this op is only set to the operator that outputs the variable. That means if a variable is not the output of any operator (e.g., the feed variable and the weight parameters), then var.op is None.

So we cannot use code like the following to exclude 'feed' and 'fetch':

def is_persistable(var):
    if var.op.desc.type() == 'feed' or var.op.desc.type() == 'fetch':
        return False
    return var.persistable
3. Just as we fix the feed/fetch operator type to be 'feed'/'fetch', we can also fix the name of the feed/fetch variable to be 'feed'/'fetch' (or some better names). This means we can get rid of the part of the current Inference Design API that optionally allows the user to provide their own 'feed_holder_name' and 'fetch_holder_name'. With this design, we can simply modify is_persistable as follows:
def is_persistable(var):
    if var.desc.name() == 'feed' or var.desc.name() == 'fetch':
        return False
    return var.persistable

If we want to go with this option, we can first do a quick fix in this PR using the code above, and then, in a future PR, fix the feed/fetch var names, modify the API accordingly, define a global constant such as kFeedVarName in C++, pybind it to Python, and so on.

@Xreki @luotao1 @sidgoyal78, which option do you prefer, or do you have other suggestions?

Contributor:

I'd prefer the 2nd method, but I am not sure whether it is suitable to use var.op.desc.type(). In the C++ definition of VarDesc, there is no member that records the owning op. Also, a variable may be shared among multiple ops.

For the 3rd method, I think it is not suitable to use the name, but maybe we can use the type, which should be FEED_MINIBATCH.

@Xreki requested a review from luotao1 on January 31, 2018, 06:31
load_var_map = {}
for each_var in vars:
    assert isinstance(each_var, Variable)
    new_var = _clone_var_in_block_(load_block, each_var)
Contributor:
Move the common code on lines 202-204 out of the if statement?
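A hypothetical before/after of the suggested refactor (the actual lines 202-204 are paraphrased, not quoted; some_condition stands in for the real test):

# before: the clone is duplicated in both branches
if some_condition:
    new_var = _clone_var_in_block_(load_block, each_var)
    ...  # branch-specific handling
else:
    new_var = _clone_var_in_block_(load_block, each_var)
    ...  # branch-specific handling

# after: hoist the common clone above the if statement
new_var = _clone_var_in_block_(load_block, each_var)
if some_condition:
    ...  # branch-specific handling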

Contributor (Author):

Done


load_vars(
parameter_list = get_parameters(inference_program)
save_vars(
Contributor:
We can try to call save_persistables here.
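In other words (a one-line sketch; assumes save_persistables takes the executor, directory, and program like the code above):

# Replace the manual get_parameters + save_vars pair with a single call:
save_persistables(executor, dirname, inference_program)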

Contributor (Author):
Done

@@ -342,7 +410,13 @@ def load_inference_model(dirname, executor):
    program_desc_str = f.read()

    program = Program.parse_from_string(program_desc_str)
    load_persistables_if_exist(executor, dirname, program)
    parameter_list = get_parameters(program)
    load_vars(
Contributor:

We can also try to call load_persistables here.
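That is, roughly (assuming load_persistables mirrors the save_persistables call suggested above):

# Replace the get_parameters + load_vars pair with a single call:
load_persistables(executor, dirname, program)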

Contributor (Author):

Done

@@ -46,6 +46,9 @@ def is_parameter(var):


def is_persistable(var):
    if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
            var.desc.type() == core.VarDesc.VarType.FETCH_LIST:
        return False
Contributor (Author):

@Xreki I have changed the code to go with option 3, following your suggestion. For option 2, there is a problem: on the Python side, the op field of a Variable is only set by the operator that has that variable as an output. So for the feed variable, which is not the output of any operator, the op data member is None.
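To illustrate why option 2 fails (a sketch; assumes the program's global block contains the feed variable under the name 'feed'):

feed_var = program.global_block().var('feed')
# The feed variable is not produced by any operator, so its op member
# is None; option 2's var.op.desc.type() would raise an AttributeError.
assert feed_var.op is None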

Contributor @Xreki left a comment:

LGTM. But I am still thinking about what the interface of load_inference_model should be; I mean, the argument list.

Labels
预测 (Prediction; formerly Inference, includes C-API prediction issues, etc.)
Projects
No open projects
Inference Framework
Basic Usage (DONE)
Development

Successfully merging this pull request may close these issues.

Modify save/load api for the new combined load save operators
3 participants