Add compile time infershape #4569
… add_compile_time_infershape
paddle/framework/block_desc.cc
Outdated
@@ -34,6 +34,11 @@ VarDescBind *BlockDescBind::Var(const std::string &name) const {
  return it->second.get();
}

bool BlockDescBind::HasVar(const std::string &name) const {
  auto it = vars_.find(name);
return vars_.count(name) != 0;
return vars_.count(name);
is simpler : )
find better describes the semantics of "has".
Since a map can contain a given key at most once, count will essentially stop after one element has been found. However, for more general containers such as multimaps and multisets, find is strictly better if you only care whether some element with this key exists, since it can stop as soon as the first matching element has been found.
For map, the implementation of count is just:
size_type count(const key_type& __x) const
{ return _M_t.find(__x) == _M_t.end() ? 0 : 1;
}
https://stackoverflow.com/questions/25490357/checking-for-existence-in-stdmap-count-vs-find
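To make the find-versus-count point concrete, here is a minimal sketch. The `Block` and `VarDesc` types are hypothetical stand-ins for `BlockDescBind` and `VarDescBind`, and the container type of `vars_` is an assumption, not taken from the PR. For a unique-key map both checks give the same answer; for a multimap, `count` has to visit every equal key while `find` can stop at the first.

```cpp
#include <map>
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for VarDescBind.
struct VarDesc {
  std::string name;
};

// Hypothetical stand-in for BlockDescBind; the type of vars_ is assumed.
class Block {
 public:
  void AddVar(const std::string& name) {
    vars_[name] = std::unique_ptr<VarDesc>(new VarDesc{name});
  }

  // Existence check via find(): reads as "is there an entry for this key?".
  bool HasVar(const std::string& name) const {
    return vars_.find(name) != vars_.end();
  }

  // Equivalent check via count(); for unique-key maps the result is 0 or 1.
  bool HasVarByCount(const std::string& name) const {
    return vars_.count(name) != 0;
  }

 private:
  std::unordered_map<std::string, std::unique_ptr<VarDesc>> vars_;
};

// With a multimap, count() must tally every element with an equal key,
// whereas find() only needs to locate one of them.
inline bool HasKey(const std::multimap<std::string, int>& m,
                   const std::string& key) {
  return m.find(key) != m.end();
}
```

Either form is correct for `HasVar`; the choice here is mostly about which one states the intent more directly.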
I see. find could be more efficient than count. So maybe use return vars_.find(name) != vars_.end(); here?
Yes, I have changed it to this.
done
paddle/framework/operator.h
Outdated
bool HasInput(const std::string& name) const override {
  const std::vector<std::string>& input_names = op_.Input(name);
  PADDLE_ENFORCE_EQ(input_names.size(), 1UL, "Inputs(%s) length is not 1",
The message "Inputs(%s) length is not 1" lacks information; explain why the length should be 1.
I second @Superjom that this error message could be more informative. Why must the input length be 1? What is the "input" here?
done
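As an illustration of what a more informative message might look like, here is a small self-contained sketch. `EnforceSingleVar` is a hypothetical helper written for this example and is not Paddle's `PADDLE_ENFORCE_EQ`; the wording of the message is likewise only a suggestion. It states what was expected, why, and what was actually found.

```cpp
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical helper (not part of Paddle): verifies that a slot such as "X"
// holds exactly one variable and explains the requirement when it does not.
inline void EnforceSingleVar(const std::vector<std::string>& var_names,
                             const std::string& slot,
                             const std::string& kind) {
  if (var_names.size() == 1UL) return;
  std::ostringstream msg;
  msg << kind << "(" << slot << ") should hold exactly one variable so that "
      << "it can be read as a single value, but it holds " << var_names.size()
      << "; use the plural accessor for slots with several variables";
  throw std::runtime_error(msg.str());
}
```

Calling `EnforceSingleVar(names, "X", "Input")` with two names would report `Input(X)` together with the actual count, which answers both "why must it be 1?" and "what is the input here?".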
paddle/framework/operator.h
Outdated
bool HasOutput(const std::string& name) const override {
  const std::vector<std::string>& output_names = op_.Output(name);
  PADDLE_ENFORCE_EQ(output_names.size(), 1UL, "Outputs(%s) length is not 1",
same as comment above
paddle/framework/operator.h
Outdated
DDim GetInputDim(const std::string& name) const override {
  std::vector<DDim> ddims = GetInputsDim(name);
  PADDLE_ENFORCE_EQ(ddims.size(), 1UL, "Inputs(%s) length is not 1", name);
xxx is not 1 -> xxx should be 1 so that xxx xxx.
paddle/framework/operator.h
Outdated
DDim GetOutputDim(const std::string& name) const override {
  std::vector<DDim> ddims = GetOutputsDim(name);
  PADDLE_ENFORCE_EQ(ddims.size(), 1UL, "Outputs(%s) length is not 1", name);
same as above
paddle/framework/operator.h
Outdated
  return block_.HasVar(output_names[0]);
}

bool HasInputs(const std::string& name) const override {
HasInputs() seems to be a more general case of HasInput(). Why not remove HasInput() and keep only this one?
Most of the time, one input has only one value:
X = [x1]
HasInput checks that X has exactly one value.
HasInputs only requires that the length is non-zero, because the length is not known in advance.
I don't understand what the example X = [x1] means.
Also, what's the difference between HasInput and HasInputs? Going by the English names, they appear to have the same functionality.
In our current implementation, all inputs and outputs are repeated fields in protobuf, so the real data structure for an input/output is:
VarName = [v1, v2, v3 ...]
Most of the time, one VarName has only one corresponding variable, but in some cases the input length is not fixed. For example, mul_op will have data like:
X = [x1, x2, x3, ...]
Y = [y1, y2, y3, ...]
Out = [x1+y1, x2+y2, x3+y3, ...]
But most operators just need the first variable under a VarName, so we currently add a helper function:
def input(varname):
    # check that varname has exactly one variable
···
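The singular/plural split described in this comment can be sketched as follows. `OpSlots` and its methods are hypothetical names invented for this example and only mirror the idea of a repeated field per slot name; they are not the PR's actual classes.

```cpp
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical simplification of an operator description: each slot name
// (for example "X") maps to a list of variable names, like a repeated field.
class OpSlots {
 public:
  void SetInput(const std::string& slot, std::vector<std::string> vars) {
    inputs_[slot] = std::move(vars);
  }

  // Plural accessor: returns every variable name stored under the slot.
  const std::vector<std::string>& Inputs(const std::string& slot) const {
    auto it = inputs_.find(slot);
    if (it == inputs_.end()) {
      throw std::out_of_range("unknown input slot: " + slot);
    }
    return it->second;
  }

  // Singular accessor: a convenience for the common case, which first
  // checks that the slot holds exactly one variable.
  const std::string& Input(const std::string& slot) const {
    const std::vector<std::string>& vars = Inputs(slot);
    if (vars.size() != 1UL) {
      throw std::logic_error("Input(" + slot +
                             ") should hold exactly one variable");
    }
    return vars[0];
  }

 private:
  std::map<std::string, std::vector<std::string>> inputs_;
};
```

With `X = [x1]`, `Input("X")` returns `x1`; with `Y = [y1, y2]`, only `Inputs("Y")` is valid and `Input("Y")` reports the mismatch.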
  return true;
}

DDim GetInputDim(const std::string& name) const override {
GetInputDim and GetInputsDim are very easy to confuse because the names are almost identical. If both really need to stay, please give GetInputsDim a more distinctive name, such as GetMultInputDim.
I would suggest that we merge all Inputs/Input methods into one. For example, HasInput and HasInputs could be one (pseudo code only):
bool HasInput(name) {
  if (FindVarDesc(name).duplicable == true) {
    CHECK_LT(0, vars_[name].size());
  } else {
    CHECK_EQ(1, vars_[name].size());
  }
}
In our current design, VarDesc does not have a duplicable flag, so we cannot use it yet. We can optimize the whole interface in the next PR.
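For reference, the merged check proposed in the pseudo code could look roughly like the sketch below if such a flag existed. The `Slot` struct and its `duplicable` field are hypothetical, since the flag is not available on VarDesc at this point, and returning false for an absent or empty slot (rather than failing a check) is one possible reading of the proposal, not the reviewer's stated intent.

```cpp
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical slot description; `duplicable` is the flag that the
// pseudo code assumes and that VarDesc does not currently provide.
struct Slot {
  bool duplicable;
  std::vector<std::string> vars;
};

// Single entry point covering both the one-variable and many-variable cases.
inline bool HasInput(const std::map<std::string, Slot>& slots,
                     const std::string& name) {
  auto it = slots.find(name);
  if (it == slots.end()) return false;
  const Slot& slot = it->second;
  if (slot.duplicable) {
    // A duplicable slot may hold any positive number of variables.
    return !slot.vars.empty();
  }
  // A non-duplicable slot must hold exactly one variable.
  if (slot.vars.size() != 1UL) {
    throw std::logic_error("Input(" + name +
                           ") is not duplicable and should hold exactly "
                           "one variable");
  }
  return true;
}
```

The design trade-off is that the caller no longer chooses between two similarly named methods; the slot's own metadata decides which rule applies.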
… add_compile_time_infershape
@@ -317,26 +318,108 @@ class ExecutionContext : public InferShapeContext {
  const platform::DeviceContext& device_context_;
};

class CompileTimeInferShapeContext : public InferShapeContextBase {
How about adding a TODO comment here:
// TODO(longfei): Once CompileTimeInferShapeContext and
// RuntimeInferShapeContext are both merged, we can rename InferShapeContextBase to
// InferShapeContext to replace the current InferShapeContext.
ok
done, added to shape_inference.h
An additional suggestion: move InferShapeContext, InferShapeContextBase, and CompileTimeInferShapeContext out of operator.h into infershape.h. Otherwise, I am afraid that operator.h would become too long.
Yes, this is one thing we need to do after the interface is done.
I had a discussion with @wangkuiyi. With the current interface, std::vector<Type> Input(const std::string& name);, users should know what they want to use: if they know the input holds a single value, they can write Type x = Input("X")[0]; to get it. I think modifying the interface definition and all the related operators should be done in the next PR.
fix: #4183