Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revise python save load api using new load/save op #7995

Merged
merged 10 commits into from
Feb 1, 2018
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/paddle/v2/fluid/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ def find_name(var_list, name):
no_kernel_op_set = {
'feed', 'fetch', 'save', 'load', 'recurrent',
'rnn_memory_helper_grad', 'conditional_block', 'while', 'send',
'recv', 'parallel_do'
'recv', 'parallel_do', 'save_combine', 'load_combine'
}
if type not in no_kernel_op_set:
self.desc.infer_var_type(self.block.desc)
Expand Down
174 changes: 125 additions & 49 deletions python/paddle/v2/fluid/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,12 @@ def _clone_var_in_block_(block, var):
persistable=True)


def save_vars(executor, dirname, main_program=None, vars=None, predicate=None):
def save_vars(executor,
dirname,
main_program=None,
vars=None,
predicate=None,
save_file_name=None):
"""
Save variables to directory by executor.

Expand All @@ -69,9 +74,12 @@ def save_vars(executor, dirname, main_program=None, vars=None, predicate=None):
:param main_program: program. If vars is None, then filter all variables in this
program which fit `predicate`. Default default_main_program.
:param predicate: The Predicate describes a callable that returns a variable
as a bool. If it returns true, the variables will be saved.
:param vars: variables need to be saved. If specify vars, program & predicate
as a bool. If it returns true, the corresponding input variable will be saved.
:param vars: variables need to be saved. If vars is specified, program & predicate
will be ignored
:param save_file_name: The name of a single file that all vars are saved to.
If it is None, save variables to separate files.

:return: None
"""
if vars is None:
Expand All @@ -83,21 +91,40 @@ def save_vars(executor, dirname, main_program=None, vars=None, predicate=None):
save_vars(
executor,
dirname=dirname,
vars=filter(predicate, main_program.list_vars()))
vars=filter(predicate, main_program.list_vars()),
save_file_name=save_file_name)
else:
save_program = Program()
save_block = save_program.global_block()
for each_var in vars:
new_var = _clone_var_in_block_(save_block, each_var)

if save_file_name is None:
for each_var in vars:
new_var = _clone_var_in_block_(save_block, each_var)
save_block.append_op(
type='save',
inputs={'X': [new_var]},
outputs={},
attrs={'file_path': os.path.join(dirname, new_var.name)})
else:
save_var_map = {}
for each_var in vars:
new_var = _clone_var_in_block_(save_block, each_var)
save_var_map[new_var.name] = new_var

save_var_list = []
for name in sorted(save_var_map.keys()):
save_var_list.append(save_var_map[name])

save_block.append_op(
type='save',
inputs={'X': [new_var]},
type='save_combine',
inputs={'X': save_var_list},
outputs={},
attrs={'file_path': os.path.join(dirname, new_var.name)})
attrs={'file_path': os.path.join(dirname, save_file_name)})

executor.run(save_program)


def save_params(executor, dirname, main_program=None):
def save_params(executor, dirname, main_program=None, save_file_name=None):
"""
Save all parameters to directory with executor.
"""
Expand All @@ -106,10 +133,12 @@ def save_params(executor, dirname, main_program=None):
dirname=dirname,
main_program=main_program,
vars=None,
predicate=is_parameter)
predicate=is_parameter,
save_file_name=save_file_name)


def save_persistables(executor, dirname, main_program=None):
def save_persistables(executor, dirname, main_program=None,
save_file_name=None):
"""
Save all persistables to directory with executor.
"""
Expand All @@ -118,21 +147,30 @@ def save_persistables(executor, dirname, main_program=None):
dirname=dirname,
main_program=main_program,
vars=None,
predicate=is_persistable)
predicate=is_persistable,
save_file_name=save_file_name)


def load_vars(executor, dirname, main_program=None, vars=None, predicate=None):
def load_vars(executor,
dirname,
main_program=None,
vars=None,
predicate=None,
load_file_name=None):
"""
Load variables from directory by executor.

:param executor: executor that save variable
:param executor: executor that load variable
:param dirname: directory path
:param main_program: program. If vars is None, then filter all variables in this
program which fit `predicate`. Default default_main_program().
:param predicate: The Predicate describes a callable that returns a variable
as a bool. If it returns true, the variables will be loaded.
:param vars: variables need to be loaded. If specify vars, program &
as a bool. If it returns true, the corresponding input variable will be loaded.
:param vars: variables need to be loaded. If vars is specified, program &
predicate will be ignored
:param load_file_name: The name of the single file that all vars are loaded from.
If it is None, load variables from separate files.

:return: None
"""
if vars is None:
Expand All @@ -144,42 +182,64 @@ def load_vars(executor, dirname, main_program=None, vars=None, predicate=None):
load_vars(
executor,
dirname=dirname,
vars=filter(predicate, main_program.list_vars()))
vars=filter(predicate, main_program.list_vars()),
load_file_name=load_file_name)
else:
load_prog = Program()
load_block = load_prog.global_block()
for each_var in vars:
assert isinstance(each_var, Variable)
new_var = _clone_var_in_block_(load_block, each_var)

if load_file_name is None:
for each_var in vars:
assert isinstance(each_var, Variable)
new_var = _clone_var_in_block_(load_block, each_var)
load_block.append_op(
type='load',
inputs={},
outputs={'Out': [new_var]},
attrs={'file_path': os.path.join(dirname, new_var.name)})
else:
load_var_map = {}
for each_var in vars:
assert isinstance(each_var, Variable)
new_var = _clone_var_in_block_(load_block, each_var)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could the common code on lines 202–204 be moved out of the if statement?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

load_var_map[new_var.name] = new_var

load_var_list = []
for name in sorted(load_var_map.keys()):
load_var_list.append(load_var_map[name])

load_block.append_op(
type='load',
type='load_combine',
inputs={},
outputs={"Out": [new_var]},
attrs={'file_path': os.path.join(dirname, new_var.name)})
outputs={"Out": load_var_list},
attrs={'file_path': os.path.join(dirname, load_file_name)})

executor.run(load_prog)


def load_params(executor, dirname, main_program=None):
def load_params(executor, dirname, main_program=None, load_file_name=None):
"""
load all parameters from directory by executor.
"""
load_vars(
executor,
dirname=dirname,
main_program=main_program,
predicate=is_parameter)
predicate=is_parameter,
load_file_name=load_file_name)


def load_persistables(executor, dirname, main_program=None):
def load_persistables(executor, dirname, main_program=None,
load_file_name=None):
"""
load all persistables from directory by executor.
"""
load_vars(
executor,
dirname=dirname,
main_program=main_program,
predicate=is_persistable)
predicate=is_persistable,
load_file_name=load_file_name)


def get_inference_program(target_vars, main_program=None):
Expand Down Expand Up @@ -234,11 +294,27 @@ def append_fetch_ops(inference_program,
attrs={'col': i})


def get_parameters(program):
parameter_list = []
input_args = set()
for block in program.blocks:
for op in block.ops:
if op.desc.type() != 'feed':
input_args.update(op.desc.input_arg_names())

for var in program.list_vars():
if is_persistable(var) and var.name in input_args:
parameter_list.append(var)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is actually not a parameter list but a list of persistable variables. Normally, the program should not contain unreferenced variables, so the `if var.name in input_args` check should be removed. When loading, if a persistable variable is absent, there should be an error message.

Copy link
Contributor Author

@kexinzhao kexinzhao Jan 31, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the explanation. Then I guess we don't need to define this function. There is already a save_persistables method, so I will use that one instead.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But I think we still need to exclude the 'feed' and 'fetch' variables, right (because they have been added to the program desc)? They are also persistable, and we don't want to store them.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried removing the `var.name in input_args` check, but then the 'feed' and 'fetch' variables are also included, which causes the error "The type of var fetch is unsupported", since the type of feed/fetch is vector<LoDTensor>, which is not supported by the load / save ops.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this implementation can potentially solve the problem described in PR #8020

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will try including these changes to actually verify.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But I think we still need to exclude 'feed' and 'fetch' variables right

@kexinzhao Can we change the function is_persistable() or rewrite load/save_persistables() to exclude feed and fetch variables?

I think this implementation can potentially solve the problem described in PR #8020

@sidgoyal78 I think the problem in #8020 is that the means of bn are not parameters but persistable variables. In fact, we should save all persistable variables in save_inference_model, not only parameters. I have been thinking about this for a long time, and there are some related issues: #7931 #7163

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. 👍

Copy link
Contributor Author

@kexinzhao kexinzhao Feb 1, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we change the function is_persistable() or rewrite load/save_persistables() to exclude feed and fetch variables?

I think of three ways to redefine load/save_persistables():

  1. One is like this:
def load_persistables(xxx):
     parameter_list = get_parameters(program)
     load_var(xxx, parameter_list)

which basically moves the use of get_parameters inside save/load_persistables.
You could also get rid of the get_parameters method and move its code directly into save/load_persistables, but we prefer not to do this because it would duplicate code in two functions.

  2. If we don't want to use the code in the get_parameters() method, then basically we want to modify the is_persistable(var) predicate function so that it can exclude the 'feed' and 'fetch' vars.
    Note that since this is a predicate, we only have the variable as input.
    Although, as shown in framework.py, class Variable has a data member op, this op will only be set to the operator that outputs this variable. This means that if a variable is not an output of any operator (e.g., feed and weight parameters), then var.op == None.

So we cannot use code like below to exclude 'feed' and 'fetch'

def is_persistable(var):
    """Return whether `var` is persistable, excluding outputs of feed/fetch ops.

    NOTE: this is the sketch that does not work in practice: `var.op` is only
    set for variables that are the output of some operator, so it is None for
    e.g. feed variables and weight parameters.

    :param var: the variable to test
    :return: False for feed/fetch op outputs, otherwise `var.persistable`
    """
    # `False` is the Python boolean literal; lowercase `false` is an undefined
    # name and would raise NameError when this branch is taken.
    if var.op.desc.type() == 'feed' or var.op.desc.type() == 'fetch':
        return False
    return var.persistable
  3. Just as we define the feed/fetch operator type to be fixed as 'feed'/'fetch', we can also fix the name of the feed/fetch variable to be 'feed'/'fetch' (or some better names). This means that we can get rid of the API in the current Inference Design that optionally allows the user to provide their own 'feed_holder_name' and 'fetch_holder_name'. For this design, we can simply modify is_persistable as follows:
def is_persistable(var):
    """Return whether `var` is persistable, excluding the feed/fetch holders.

    The feed and fetch holder variables are identified by their fixed names
    'feed' and 'fetch'.

    :param var: the variable to test
    :return: False for variables named 'feed' or 'fetch', otherwise
        `var.persistable`
    """
    # `False` is the Python boolean literal; lowercase `false` is an undefined
    # name and would raise NameError when this branch is taken.
    if var.desc.name() == 'feed' or var.desc.name() == 'fetch':
        return False
    return var.persistable

If we want to go with this option, we can first do a quick fix in this PR using the code above. Then, in a future PR, we can fix the feed/fetch var names, modify the API accordingly, define a global const kFeedVarName in C++ and pybind it to Python, etc.

@Xreki @luotao1 @sidgoyal78, which option do you prefer, or do you have other suggestions?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer the 2nd method. But I am not sure whether it is suitable to use var.op.desc.type(). In the C++ definition of VarDesc, there is no member that records the op the variable belongs to. Also, a variable may be shared among multiple ops.

As for the 3rd method, I think it is not suitable to use the name, but maybe we can use the type, which should be FEED_MINIBATCH.


return parameter_list


def save_inference_model(dirname,
feeded_var_names,
target_vars,
executor,
main_program=None):
main_program=None,
save_file_name=None):
"""
Build a model especially for inference,
and save it to directory by the executor.
Expand All @@ -249,6 +325,8 @@ def save_inference_model(dirname,
:param executor: executor that save inference model
:param main_program: original program, which will be pruned to build the inference model.
Default default_main_program().
:param save_file_name: The name of a single file that all parameters are saved to.
If it is None, save parameters to separate files.

:return: None
"""
Expand Down Expand Up @@ -283,25 +361,13 @@ def save_inference_model(dirname,
with open(model_file_name, "wb") as f:
f.write(inference_program.desc.serialize_to_string())

save_params(executor, dirname, main_program)


def load_persistables_if_exist(executor, dirname, main_program=None):
filenames = next(os.walk(dirname))[2]
filenames = set(filenames)

def _is_presistable_and_exist_(var):
if not is_persistable(var):
return False
else:
return var.name in filenames

load_vars(
parameter_list = get_parameters(inference_program)
save_vars(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can try to call save_persistables here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

executor,
dirname,
main_program=main_program,
vars=None,
predicate=_is_presistable_and_exist_)
inference_program,
parameter_list,
save_file_name=save_file_name)


def get_feed_targets_names(program):
Expand All @@ -322,13 +388,15 @@ def get_fetch_targets_names(program):
return fetch_targets_names


def load_inference_model(dirname, executor):
def load_inference_model(dirname, executor, load_file_name=None):
"""
Load inference model from a directory

:param dirname: directory path
:param executor: executor that load inference model

:param load_file_name: The name of the single file that all parameters are loaded from.
If it is None, load parameters from separate files.

:return: [program, feed_target_names, fetch_targets]
program: program especially for inference.
feed_target_names: Names of variables that need to feed data
Expand All @@ -342,7 +410,13 @@ def load_inference_model(dirname, executor):
program_desc_str = f.read()

program = Program.parse_from_string(program_desc_str)
load_persistables_if_exist(executor, dirname, program)
parameter_list = get_parameters(program)
load_vars(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can also try to call load_persistables here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

executor,
dirname,
program,
parameter_list,
load_file_name=load_file_name)

feed_target_names = get_feed_targets_names(program)
fetch_target_names = get_fetch_targets_names(program)
Expand All @@ -359,6 +433,7 @@ def get_parameter_value(para, executor):

:param executor: executor for retrieving the value
:param para: the given parameter

:return: the LoDTensor for the parameter
"""
assert is_parameter(para)
Expand All @@ -377,6 +452,7 @@ def get_parameter_value_by_name(name, executor, program=None):
:param name: the name of the parameter
:param program: the program where the variable is found
Default default_main_program().

:return: the LoDTensor for the variable
"""
if program is None:
Expand Down