Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev refine backward no grad var handling #7140

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions doc/design/backward.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,11 @@ See function `_addup_repetitive_outputs_` in `backward.py` for implementation de

In our framework, variables can be marked as *no_gradient*, it means that the gradient of this variable is unnecessary and can be considered as zero in model training. Apparently, when all the outputs of some `grad_op` are marked as *no_gradient*, the `grad_op` itself can be skipped in backward pass.

But these unnecessary gradients still need to be creating and initialized by something, otherwise following `grad_op`s who take these gradients as inputs take the risk of using uninitialized memory. In our code, we employ `fill_zeros_like_op` to initialize them as all zeros.
Another situation is all the gradient inputs of some `grad_op` are marked as *no_gradient*, which means all of them can be considered as zeros. For `grad_op`s are in essence the propagation of gradients, all the outputs are definitely zeros when all gradient inputs are zeros. Therefore the `grad_op` can also be skipped.

This features are implemented in function `_remove_no_grad_branch_`. It checks new created `grad_op`s one-by-one, removes whose outputs are all in `no_grad_set` or inserts `fill_zeros_like_op` when its necessary. We can get the `no_grad_set` from the `_append_backward_ops_` argument `no_grad_dict` or generate it on the fly by scanning all variables' `no_gradient` attribute(True or False).
It should be noted that all these zero gradients still need to be creating and initialized by something, otherwise following `grad_op`s who take these gradients as inputs take the risk of using uninitialized memory. In our code, we employ `fill_zeros_like_op` to initialize them as all zeros.

This features are implemented in function `_remove_no_grad_branch_`. It checks new created `grad_op`s one-by-one, removes who can be skipped and inserts `fill_zeros_like_op` when its necessary. We can get the `no_grad_set` from the `_append_backward_ops_` argument `no_grad_dict` or generate it on the fly by scanning all variables' `no_gradient` attribute(True or False).

### Creating Backward Variables

Expand Down
23 changes: 19 additions & 4 deletions python/paddle/v2/fluid/backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def _all_in_set_(cands, s):
"""
Test if all elements of 'cands' are in set 's'
"""
if len(cands) == 0:
return False
for c in cands:
if not c in s:
return False
Expand Down Expand Up @@ -136,12 +138,23 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
Remove unnecessary grad ops
A grad op can be removed in two cases:
1. all outputs of the grad op are in 'no_grad_set'
2. (TODO) all grad inputs of the grad op are in 'no_grad_set'
2. all grad inputs of the grad op are in 'no_grad_set'
"""

def _op_can_be_removed_(op_desc, no_grad_set):
out_arg_names = op_desc.output_arg_names()
if len(out_arg_names) == 0 or _all_in_set_(out_arg_names, no_grad_set):
return True
if _all_in_set_(
filter(lambda name: name.find(core.grad_var_suffix()) != -1,
op_desc.input_arg_names()), no_grad_set):
no_grad_set.union(out_arg_names)
return True
return False

# Remove ops whose outputs are all in no_grad_dict
op_descs = filter(
lambda op_desc: not _all_in_set_(op_desc.output_arg_names(), no_grad_set),
op_descs)
lambda op_desc: not _op_can_be_removed_(op_desc, no_grad_set), op_descs)
# Insert fill_zeros_like_op
to_insert = []
for idx, op_desc in enumerate(op_descs):
Expand Down Expand Up @@ -284,7 +297,9 @@ def append_backward(loss, parameter_list=None, no_grad_set=None):
block_no_grad_set.add(_append_grad_suffix_(var.name))
no_grad_dict[block.idx] = block_no_grad_set
elif isinstance(no_grad_set, set):
no_grad_dict = {0: no_grad_set}
no_grad_dict = {
0: set([_append_grad_suffix_(name) for name in no_grad_set])
}
else:
raise ValueError("'no_grad_set' should be a set or None.")

Expand Down