-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Prune Design Doc #4732
Prune Design Doc #4732
Conversation
## Motivation | ||
|
||
We want to support running inference, training and checkpointing in one `ProgramDesc`. We implement | ||
`void Prune(const ProgramDesc* input, ProgramDesc* output)` function, which takes a `ProgramDesc` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
1, need target
in the interface
2, if we will not change input, than we should use reference
void Prune(
const ProgramDesc& input,
const std::vector<OpDesc*>& targets,
ProgramDesc* output)
{...}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- I added
is_target
field inOpDesc
- Sure.
doc/design/prune.md
Outdated
repeated Var inputs = 1; | ||
repeated Var outputs = 2; | ||
repeated Attr attrs = 4; | ||
required bool is_target = 5 [ default = false ]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we still use is_target
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I want the ProgramDesc
to be self-complete, in case we need to do Prune
on the master in the future.
@@ -0,0 +1,110 @@ | |||
# Prune | |||
|
|||
## Motivation |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
where will Pruning happen, on the client or on the master?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the current design, prune is based only on ProgramDesc
. So both designs are fine.
doc/design/prune.md
Outdated
repeated Var inputs = 1; | ||
repeated Var outputs = 2; | ||
repeated Attr attrs = 4; | ||
required bool is_target = 5 [ default = false ]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When will is_target
be set?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When a client calls sess.run(target=loss)
, PaddlePaddle will make a copy of ProgramDesc
and set is_target
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So the sess.run(target=loss)
is supposed to be in a for
loop and it will take the same target repeated. We can optimize it by caching targets and their corresponding pruned ProgramDesc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
right.
doc/design/prune.md
Outdated
// erase its output to the dependency graph | ||
for (auto& var : op_desc.outputs()) { | ||
for (auto& argu : var.arguments()) { | ||
dependent_vars.erase(argu); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we have to erase them?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. In order to be on the safe side, I will NOT erase them. Even though in all the cases I have thought of, erasing works fine.
doc/design/prune.md
Outdated
// add pruned ops to output | ||
for (size_t i = 0; i < should_run.size(); ++i) { | ||
if (should_run[i]) { | ||
output->AppendOp(input->block[0].ops(i)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make sure output
is empty before appending any op into it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure. The check has been added to the implementation.
doc/design/prune.md
Outdated
Then the whole algorithm can be implemented as the following | ||
|
||
```c++ | ||
void Prune(const ProgramDesc& input, ProgramDesc* output, int id) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the purpose of having id
as an argument for Prune
? How should we use the Prune
function with the id
argument?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
id
is the block_id. Just like executor::run
takes a block_id, it denotes which block, usually the root block, to be pruned.
I will change the name to block_id
to make it less confusing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought Prune
will always take the entire given ProgramDesc
, and outputs a minimal ProgramDesc
. Otherwise if the user pass in argument block_id = 2
, what should Prune
do to other blocks (e.g., block 0, block 1): keep it or discard it / prune it or not?
Adding one more argument adds more complexity, do we really need block_id
for Prune
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I want Prune
to be called recursively. Say it encounters a RNNOp
in block 0, then we can simply call prune(rnn_sub_block_id)
to do the pruning there. Although the support for RNN and IfElse block is still under development, I believe we should leave the block_id
as an argument here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One alternative to consider is something like this:
// private
void pruneImpl(const ProgramDesc& input, ProgramDesc* output, int id) {
if (rnn) {
pruneImpl(input, output, id+1);
return;
}
}
// public
void Prune(const ProgramDesc& input, ProgramDesc* output) {
pruneImpl(input, output, 0);
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can't agree more. It is awesome.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not relevant to this PR, I think the same design concept could be applied to executor
as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good idea! If you want to send a PR changing the executor design doc, I will instantly approve :p
LGTM, @jacquesqiao has some comment, maybe he can approve. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Implementation merged at #4738