Backward on parallel do using nccl #8361

Merged: 38 commits, Feb 20, 2018

Conversation

@tonyyang-svail commented Feb 11, 2018

Part of the optimization of parallel_do

This PR contains the following:

  1. add the nccl library to the framework
  2. add an nccl callback on backward
  3. add an nccl flag to parallel_do
  4. use an assign op to overwrite the reduced gradient (see the sketch after this list)
  5. verify the correctness of parallel_do with nccl
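For orientation, here is a minimal sketch of the per-gradient pattern that items 2 and 4 describe: allreduce into a scratch output, then an assign op that overwrites the original gradient. This is plain Python with op descs modeled as dicts, not the actual fluid _create_op_desc_ API; the scratch-name suffix is hypothetical.

def nccl_backward_ops(grad_name):
    # ncclAllReduce cannot run in place, so reduce into a scratch output first
    allreduce_out = grad_name + '__nccl_all_reduce__'  # hypothetical name
    return [
        {'type': 'ncclAllReduce',
         'inputs': {'X': [grad_name],
                    'Communicator': ['nccl_com__do_not_change_']},
         'outputs': {'Out': [allreduce_out]}},
        # assign writes the reduced result back over the original gradient
        {'type': 'assign',
         'inputs': {'X': [allreduce_out]},
         'outputs': {'Out': [grad_name]}},
    ]

print(nccl_backward_ops('fc_0.w_0@GRAD'))

After backward runs these ops, each device's gradient holds the cross-device sum.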

@CLAassistant commented Feb 11, 2018

CLA assistant check: all committers have signed the CLA.

@@ -239,7 +304,8 @@ def empty_callback(block, context):
         sub_block = program.block(op.block_attr("sub_block"))
         grad_sub_block = program.create_block(parent_idx=sub_block.idx)
         _append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
-                              no_grad_dict, grad_to_var)
+                              no_grad_dict, grad_to_var,
+                              _callback_lookup_(op))
@JiayiFeng (Collaborator) commented Feb 11, 2018:

How can we apply more than one callback here? E.g., we would like to apply the nccl and error-clip callbacks at the same time.

Also, I think op is a bad name. The parameter actually means the op that owns the current block; op is too broad for that.

@tonyyang-svail (Author) replied:

@JiayiFeng Thank you for your review.

> How can we apply more than one callback here? E.g., we would like to apply the nccl and error-clip callbacks at the same time.

I will change the callback to a list of callbacks.

> Also, I think op is a bad name. The parameter actually means the op that owns the current block; op is too broad for that.

What do you mean by "The parameter"?
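A hedged sketch of the "list of callbacks" change being discussed: every callback keeps the (block, context) signature of empty_callback shown in the diff above, and all of them fire when a grad op is appended. The names here are illustrative, not the actual fluid helpers.

def run_callbacks(callbacks, block, context):
    # invoke every registered callback in order
    for cb in callbacks:
        cb(block, context)

def error_clip_callback(block, context):
    pass  # would append error-clipping ops

def nccl_callback(block, context):
    pass  # would append ncclAllReduce + assign ops

run_callbacks([error_clip_callback, nccl_callback], block=None, context={})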

    }
  }
  for (auto &s : Outputs(framework::GradVarName(kParameters))) {
    if (s == "@EMPTY@") {
@tonyyang-svail (Author) commented Feb 16, 2018:

Backward will change some of the gradients to @EMPTY@ if we don't need to calculate them. For example, we don't need to calculate the gradient of layer.data. In that case, parallel_do should skip them.
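A tiny illustration of that skip (the gradient names are made up):

param_grads = ['fc_0.w_0@GRAD', '@EMPTY@', 'fc_0.b_0@GRAD']
# parallel_do only reduces and copies out gradients that were actually computed
needed = [g for g in param_grads if g != '@EMPTY@']
assert needed == ['fc_0.w_0@GRAD', 'fc_0.b_0@GRAD']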

@@ -28,8 +28,8 @@ using Tensor = framework::Tensor;

 // Base convolution operator definations for other conv
 // like operators to reuse the implementation.
-inline int OutputSize(int input_size, int filter_size, int dilation,
-                      int padding, int stride) {
+inline int ConvOutputSize(int input_size, int filter_size, int dilation,
@tonyyang-svail (Author) commented Feb 16, 2018:

The name OutputSize is too general, and it sits in the paddle namespace, which is too broad a scope.

@@ -60,8 +60,9 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
                    "Due to the settings of paddings, filter_dims and "
                    "dilations, the output size is less than 0, please check "
                    "again.");
-    output_shape.push_back(OutputSize(in_dims[i + 2], filter_dims[i + 2],
-                                      dilations[i], paddings[i], strides[i]));
+    output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2],
@tonyyang-svail (Author) replied:

When I was debugging conv_op.cc, it looked like OutputSize had been linked to another function...

@@ -14,10 +14,13 @@ limitations under the License. */

 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
+#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
A contributor commented:

paddle/fluid/operators/nccl/nccl_gpu_common.h is included twice.

@tonyyang-svail (Author) replied:

Typo, thanks for pointing it out.

  PADDLE_ENFORCE(!gpus.empty(), "Attr(gpus) should not be empty.");
  // A parallel do may not use all the gpus. For example, the batch size is 7
  // in the last batch while we have 8 gpus. In this case, parallel_do will
  // create 7 parallel scopes, so ncclInitOp should create 7 gpu peers.
A contributor commented:

You mentioned the "last batch"; does that imply ncclInitOp will be called for every mini-batch?

@tonyyang-svail (Author) replied:

Yes.

  // in the last batch while we have 8 gpus. In this case, parallel_do will
  // create 7 parallel scopes, so ncclInitOp should create 7 gpu peers.
  auto &parallel_scopes = scope.FindVar(Input(kParallelScopes))
                              ->Get<std::vector<framework::Scope *>>();
A contributor commented:

Does Scope support serialization?

@tonyyang-svail (Author) replied:

When do we need to serialize a scope?

The contributor replied:

Sorry, never mind, I got confused.

  auto &parallel_scopes = scope.FindVar(Input(kParallelScopes))
                              ->Get<std::vector<framework::Scope *>>();
  std::vector<int> gpus(parallel_scopes.size());
  for (int i = 0; i < static_cast<int>(parallel_scopes.size()); ++i) {
A contributor commented:

Since only parallel_scopes.size() is used, could we just pass kNumParallelScopes instead of kParallelScopes?

@tonyyang-svail (Author) replied:

We don't know kNumParallelScopes at compile time.
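In other words, the peer count is derived at run time from the scopes themselves. A minimal sketch of that sizing, assuming one peer per scope as in the C++ above:

def num_gpu_peers(parallel_scopes):
    # one nccl peer per parallel scope created for this mini-batch
    return len(parallel_scopes)

# a batch of 7 on an 8-GPU machine yields 7 scopes, hence 7 peers
assert num_gpu_peers(['scope_%d' % i for i in range(7)]) == 7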

    }
  }
  for (auto &s : Outputs(framework::GradVarName(kParameters))) {
    if (s == "@EMPTY@") {
A contributor commented:

Can @EMPTY@ be put into a constant?

@tonyyang-svail (Author) replied:

Sure.
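For example, a one-line sketch (the constant name is hypothetical):

EMPTY_VAR_NAME = '@EMPTY@'  # hypothetical constant replacing the magic string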

    VLOG(3) << "Moving " << s;
    CopyOrShare(*sub_scopes[0]->FindVar(s), place, scope.FindVar(s));
  }
  WaitOnPlaces(places);
A contributor commented:

Not related to this PR, but I am curious why parallel_do has to wait for all streams to complete. I thought even the executor does not wait.

@tonyyang-svail (Author) replied:

I can't think of a case where not waiting would be wrong. But just as we always join threads after launching them, it feels natural for parallel_do to wait for all its streams.

The contributor replied:

That could affect performance: we introduced a synchronization point that we are not sure we need.

@tonyyang-svail (Author) replied:

There are several places in parallel_do where we have to wait, so this line won't have a large effect.

def _callback_lookup_(op):
    """
    Only used in _append_backward_ops_.
    Builds and returns a callback function for a certain op. For example,
A contributor commented:

Can you add more comments about what a callback function is? (E.g., is it something that gets called after an op is completed?)
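For reference, the contract suggested by the surrounding code is callback(block, context), invoked by _append_backward_ops_ after each grad op is appended. A hedged sketch of what _callback_lookup_ might return; the op-type check and names are illustrative:

def _callback_lookup_sketch(op_type):
    # real code would inspect the op desc, not a bare string
    if op_type == 'parallel_do':
        def parallel_do_grad_callback(block, context):
            pass  # would append ncclInit / ncclAllReduce / assign op descs
        return parallel_do_grad_callback
    return None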

        op_desc = _create_op_desc_(
            "ncclInit",
            {"parallel_scopes": self.parallel_scopes_name},
            {"Communicator": ['nccl_com__do_not_change_']}, {})
A contributor commented:

Can you put nccl_com__do_not_change_ into a constant?
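E.g., a one-line sketch (the constant name is hypothetical):

NCCL_COMMUNICATOR_NAME = 'nccl_com__do_not_change_'  # hypothetical constant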

"X": [o_argu],
"Communicator":
['nccl_com__do_not_change_']
}, {"Out": [allreduce_out_name]},
A contributor commented:

allreduce_out_name is assigned to o_argu in the next op; why can't o_argu be the output here, so that we don't need the next assign op?

@tonyyang-svail (Author) replied:

ncclAllreduce requires a separate buffer to hold the result, i.e. it doesn't support in-place operation.

    else:
        new_callbacks = callbacks + [_callback_lookup_(op)]
    _append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
                          no_grad_dict, grad_to_var, new_callbacks)
A contributor commented:

Do we need to do callbacks = new_callbacks, since callbacks is used later as well?

@tonyyang-svail (Author) replied:

The callbacks used later should not contain _callback_lookup_(op).

helinwang previously approved these changes Feb 17, 2018

@helinwang (Contributor) left a comment:

LGTM!

@helinwang merged commit 633756a into PaddlePaddle:develop Feb 20, 2018