Feature/shrink memory op #5419

Merged: 8 commits merged into PaddlePaddle:develop from feature/shrink_memory_op on Nov 8, 2017

Conversation

@reyoung (Collaborator) commented Nov 7, 2017

No description provided.

@tonyyang-svail left a comment

LGTM.

PADDLE_ENFORCE(context->HasInput("X"));
PADDLE_ENFORCE(context->HasInput("I"));
PADDLE_ENFORCE(context->HasInput("RankTable"));
context->SetOutputDim("Out", context->GetInputDim("X"));

@tonyyang-svail commented Nov 8, 2017

ShrinkStateOpInferShape looks difficult to implement.

At compile time, we don't have enough information to infer the shape of Out, because we don't know what will be filled into the RankTable. At runtime we could use the RankTable, but that would make the infer-shape code differ between the two phases.

Collaborator:

The RankTable only affects the value of Out's first dimension; the other dimensions of Out are the same as those of X. Inferring the other dimensions and leaving the first dimension to the runtime is exactly what InferShape should do.

Reply:

I see. Good point.
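
A hedged sketch of the compile-time InferShape this thread converges on, assuming the input/output names from the hunk above and Paddle's InferShapeBase interface of the time; this is an illustration, not the PR's code:

class ShrinkStateOpInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *context) const override {
    PADDLE_ENFORCE(context->HasInput("X"));
    PADDLE_ENFORCE(context->HasInput("I"));
    PADDLE_ENFORCE(context->HasInput("RankTable"));
    // Every dimension of Out except the first matches X. The first
    // dimension depends on the RankTable contents, which are only known
    // at runtime, so X's first dimension serves as a placeholder here.
    context->SetOutputDim("Out", context->GetInputDim("X"));
  }
};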

auto *out_var = scope.FindVar(Output("Out"));
PADDLE_ENFORCE(out_var != nullptr, "Output Out must be set");
auto &out_tensor = *out_var->GetMutable<framework::LoDTensor>();
out_tensor.ShareDataWith(x_tensor.Slice(0, dst_num_rows));

Comment:

Note: the shape of out_tensor has been changed by ShareDataWith. This violates the general design that a tensor's shape should only be modified by InferShape. But as pointed out below, it is relatively hard to infer the shape of out_tensor, so we consider this behavior as an exception.

tonyyang-svail previously approved these changes Nov 8, 2017

Used to shrink the memory state in DyRNN. The height of the state can be shrunk after running a step block.

{
auto &rank_items = rank_table.items();
for (auto &rank_item : rank_items) {
Collaborator:

Using binary search can make it faster when rank_table is big.

Collaborator (Author):

Done.
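
As a hedged illustration of the binary-search suggestion (not the PR's code): assuming the rank table items are sorted by sequence length in descending order, the linear scan can be replaced with std::lower_bound. The RankItem struct and function name below are hypothetical stand-ins for the real LoDRankTable types.

#include <algorithm>
#include <vector>

// Hypothetical item type; the real rank table item also carries an index.
struct RankItem {
  size_t index;
  size_t length;
};

// Items are sorted by length in descending order. Return how many
// sequences are still active at `step`, i.e. how many items have
// length > step.
size_t ActiveRows(const std::vector<RankItem> &items, size_t step) {
  auto it = std::lower_bound(
      items.begin(), items.end(), step,
      [](const RankItem &item, size_t s) { return item.length > s; });
  return static_cast<size_t>(it - items.begin());
}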

namespace paddle {
namespace operators {

class ShrinkStateOp : public ArrayOp {
Collaborator:

I don't think ShrinkStateOp is a good name. We don't really 'shrink' the memory block; we only take a slice of the original one. And 'state' is also confusing. Maybe we can call it RearrangeRnnMemoryOp.

Reply:

We should indicate RNN in the name, since the Op will only be used in RNN. We could even put DyRNN in the name.

Collaborator (Author):

Done.

const platform::DeviceContext &dev_ctx) const override {
auto *dout_var = scope.FindVar(Input(framework::GradVarName("Out")));
auto dx_name = Output(framework::GradVarName("X"));
auto *dx_var = scope.FindVar(dx_name);
Collaborator:

Why not

auto *dx_var = scope.FindVar(Output(framework::GradVarName("X")));

Collaborator (Author):

Done.

auto height = dout_tensor.dims()[0];
dx_tensor.Slice(0, static_cast<int>(height))
.CopyFrom(dout_tensor, dout_tensor.place(), dev_ctx);
if (height < dout_tensor.dims()[0]) {
Collaborator:

How could height < dout_tensor.dims()[0]? In line 110: auto height = dout_tensor.dims()[0];

Collaborator (Author):

Done.
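
A hedged reading of the fix (illustrative, not necessarily the merged code): the comparison presumably should be against dx_tensor's height rather than dout_tensor's, so that the rows of dx beyond dout's height, which receive no gradient, get zero-filled.

auto height = dout_tensor.dims()[0];
// Copy the gradient rows present in dout into the front of dx.
dx_tensor.Slice(0, static_cast<int>(height))
    .CopyFrom(dout_tensor, dout_tensor.place(), dev_ctx);
if (height < dx_tensor.dims()[0]) {
  // Rows of dx beyond dout's height received no gradient; zero-fill them
  // with whatever constant-fill helper fits the place and data type.
  auto rest = dx_tensor.Slice(static_cast<int>(height),
                              static_cast<int>(dx_tensor.dims()[0]));
  // ... zero-fill `rest` here ...
}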

@@ -115,20 +85,21 @@ class WriteToArrayInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDescBind &op_desc,
framework::BlockDescBind *block) const override {
VLOG(10) << "I am here?";
Collaborator:

Can this be more meaningful?

Collaborator (Author):

Done.
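
For illustration only, one shape a more descriptive message could take, assuming this VarTypeInference sets the op's Out variable to LOD_TENSOR_ARRAY; the exact wording in the merged code may differ.

// Hypothetical replacement for VLOG(10) << "I am here?";
VLOG(10) << "WriteToArrayInferVarType: setting type of output "
         << op_desc.Output("Out")[0] << " to LOD_TENSOR_ARRAY";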

@reyoung merged commit 2a76b42 into PaddlePaddle:develop Nov 8, 2017
@reyoung deleted the feature/shrink_memory_op branch December 26, 2017 09:31