-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add rnn op interfaces #2775
add rnn op interfaces #2775
Changes from 4 commits
c418dac
6042795
13d8ca9
a645ae6
8640f96
d4cde51
6e99289
63b5841
08f69f6
007ca1e
2538b2f
5eb87f0
4dcb02e
ca53f3a
671cc26
1e48cc8
e0cbcd0
f7916a6
089c448
bffd11e
c7947de
94766b6
6dca711
eabf1bf
d210b0b
6674fee
778ebb4
c60ed35
8642b27
b0938ed
3921fbb
244fe51
020c189
8e70b37
4150fa7
1584414
ce802c0
a883b4c
b98cae4
a81be58
acde9b7
638384e
82464f5
bbcc149
c92ce74
5c5d890
522445b
01f20be
08003de
a6483e8
7b1d123
bcd03bf
de319bb
0a4a502
e64b5d3
e700bf6
f525390
3a27b02
aede869
45682d2
497c7ff
fc5acee
14dd843
3c15641
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
|
||
#include "paddle/framework/enforce.h" | ||
#include "paddle/framework/scope.h" | ||
#include "paddle/framework/variable.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
// fake interfaces that has not be implemented by other modules. | ||
struct OpRunContext { | ||
Scope* scope; | ||
}; | ||
|
||
// TODO replace this with Net's proto. | ||
struct NetDesc { | ||
std::string name; | ||
} | ||
|
||
class OperatorBase { | ||
public: | ||
virtual ~OperatorBase() {} | ||
virtual void Run(OpRunContext* context) const = 0; | ||
virtual void InferShape(const Scope* scope) const = 0; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what does InferShape do? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the purpose of InferShape is to inference the size of inputs/outputs from some of them that we already know the size. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. InferShape will set the output variable dim according to the input variable dim. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. RNNOp.InferShape will just call its step net's InferShape, and will
It is offered as a public method because we want to keep checking dynamically during user adding operators. |
||
|
||
protected: | ||
std::vector<std::string> inputs_; | ||
std::vector<std::string> outputs_; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add attributes |
||
} | ||
|
||
class RecurrentGroupForwardOp { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. RecurrentGroupForwardOp => RecurrentOp There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good, short enough. the backward op's name?
|
||
public: | ||
RecurrentGroupForwardOp(NetDesc& net_desc) | ||
: name_(net_desc.name), | ||
net_name_(net_desc.name + "__net__"), | ||
step_scopes_name_(net_desc.name + "__step_scopes_") {} | ||
|
||
virtual void InferShape(const Scope* scope) = 0; | ||
/* | ||
* Forward run the RNN. | ||
* | ||
* NOTE the context's scope is not given until `Run` called, so step scopes' | ||
* father should be set/updated in this method. | ||
*/ | ||
virtual void Run(OpRunContext* contex) const { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should be in .cpp There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, will move to .cpp later. We are working on a simple implementation to verify the whole process and will give a version soon. |
||
auto scope = contex.scope; | ||
|
||
Variable* net = scope->GetVariable(net_name_); | ||
if (net == nullptr) { | ||
BuildStepNet(scope); | ||
net = scope->GetVariable(net_name_); | ||
} | ||
PADDLE_ENFORCE(net); | ||
|
||
// expand lazily. | ||
CreateScopes(scope); | ||
ScatterLinks(scope); | ||
PrepareMemories(scope); | ||
Variable* step_scopes = scope->GetVariable(step_scopes_name_); | ||
PADDLE_ENFORCE(step_scopes); | ||
|
||
// forward | ||
for (Scope* step_scope : step_scopes->GetMutable<std::vector<Scope*>>()) { | ||
net->Run(step_scope); | ||
} | ||
|
||
// prepare outputs | ||
GatherOutLinks(scope); | ||
} | ||
|
||
protected: | ||
/* | ||
* Prepare inputs for each stepnet. | ||
*/ | ||
void ScatterInLinks(Scope* scope); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ScatterInLinks => SegmentInputs. Let us use accurate English wording. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
|
||
/* | ||
* Process outputs of stepnets and merge to variables. | ||
*/ | ||
void GatherOutLinks(Scope* scope); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. GatherOutLinks => ConcatenateOutputs |
||
|
||
/* | ||
* Build a `Net` which is shared across all steps. | ||
*/ | ||
void BuildStepNet(Scope* scope); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. BuildStepNet => CreateStepNet |
||
|
||
/* | ||
* Create a scope for each step, the context's scope is shared across all | ||
* the step scopes as the father scope. The step scopes will be stored in | ||
* the father scope as a variable. | ||
*/ | ||
void CreateScopes(Scope* scope); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
||
/* | ||
* Prepare steps' states and relations. | ||
*/ | ||
void PrepareMemories(Scope* scope); | ||
|
||
protected: | ||
/* | ||
* these are defined in BaseOperator | ||
* | ||
* std::vector<std::string> inputs_; | ||
* std::vector<std::string> outputs_; | ||
*/ | ||
|
||
// Memory of a RNN (same as the role of `Momory` in PaddlePaddle) | ||
struct MemoryAttr { | ||
// name of current state variable | ||
std::string var; | ||
// name of previous step's state variable | ||
std::string pre_var; | ||
// name of the variable to init a state, which is store in context's | ||
// scope. | ||
std::string boot_var; | ||
}; | ||
|
||
std::vector<MemoryAttr> memories_; | ||
std::string name_; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove this name. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
|
||
const std::string net_name_; | ||
const std::string step_scopes_name_; | ||
}; | ||
|
||
class RecurrentGroupBackwardOp; | ||
} // namespace framework | ||
} // namespace paddle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I looks to me that the constructor needs a parameter
paddle::framework::proto::OperatorDesc
so could it possble to callInferShape
, which saves sizes of inputs/outputs into the desc. Only if so, we could have all necessary information for callingOperatorBase::Run
:So the information in
proto::OperatorDesc
propagates along the path:@Superjom @reyoung @jacquesqiao
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the new design of Operator, OpDesc will store in Op, and InferShape can get the information from scope, but it seems that it need not store the shape into the desc
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jacquesqiao You are right.
The first clue about input/output sizes is in training data instances, and we get the instance when we do training, i.e,. call operator's
Run
.Should we just remove
InferShape
and let each operator defines its own shape inference methods, i.e., one method for an output, so to shorten code in its Run method like this: