Add compile time infershape #4569

Merged

Member

@jacquesqiao jacquesqiao commented Oct 3, 2017

fix: #4183

@jacquesqiao jacquesqiao changed the title from "[WIP]Add compile time infershape" to "Add compile time infershape" Oct 5, 2017
@@ -34,6 +34,11 @@ VarDescBind *BlockDescBind::Var(const std::string &name) const {
return it->second.get();
}

bool BlockDescBind::HasVar(const std::string &name) const {
auto it = vars_.find(name);
Contributor

return vars_.count(name) != 0;

Collaborator

return vars_.count(name); is simpler : )

Member Author

find better describes the semantics of "has".

Since a map can hold at most one element per key, count will essentially stop after one element has been found. However, in view of more general containers such as multimaps and multisets, find is strictly better if you only care whether some element with this key exists, since it can really stop once the first matching element has been found.

For map, the implementation of count is just:

size_type count(const key_type& __x) const
{ return _M_t.find(__x) == _M_t.end() ? 0 : 1; }

https://stackoverflow.com/questions/25490357/checking-for-existence-in-stdmap-count-vs-find
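
To make the comparison concrete, here is a minimal sketch of both forms of the existence check. `Block` and `VarDesc` are simplified, hypothetical stand-ins for `BlockDescBind` and `VarDescBind`, and the use of `std::unordered_map` is an assumption rather than something confirmed by this diff.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for VarDescBind; only its presence in the map matters.
struct VarDesc {};

// Simplified stand-in for BlockDescBind. The map type is an assumption.
class Block {
 public:
  void NewVar(const std::string& name) {
    vars_[name] = std::unique_ptr<VarDesc>(new VarDesc());
  }

  // find() states the intent directly: "is there an element with this key?"
  bool HasVar(const std::string& name) const {
    return vars_.find(name) != vars_.end();
  }

  // count() gives the same answer here: a unique-key map returns 0 or 1.
  bool HasVarByCount(const std::string& name) const {
    return vars_.count(name) != 0;
  }

 private:
  std::unordered_map<std::string, std::unique_ptr<VarDesc>> vars_;
};
```

For a unique-key map the two checks should be interchangeable, so the choice is mostly about readability.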

Collaborator

I see. find could be more efficient than count. So here maybe return vars_.find(name) != vars_.end();?

Member Author

Yes, I have changed it to this.

Member Author

done


bool HasInput(const std::string& name) const override {
const std::vector<std::string>& input_names = op_.Input(name);
PADDLE_ENFORCE_EQ(input_names.size(), 1UL, "Inputs(%s) length is not 1",
Contributor

The message "Inputs(%s) length is not 1" lacks information; it should say why the length should be 1.

Collaborator

I second @Superjom that this error message could be more informative. Why must the input length be 1? What is the "input" here?
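
As an illustration of what a more informative message might look like, here is a minimal sketch. `EnforceSingleInput` is a hypothetical helper and is not part of Paddle; it uses a plain exception in place of `PADDLE_ENFORCE_EQ`, and the wording of the message is only a suggestion.

```cpp
#include <cassert>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical helper, not part of Paddle: checks that an input slot holds
// exactly one variable name and explains the requirement when it does not.
inline void EnforceSingleInput(const std::string& slot,
                               const std::vector<std::string>& var_names) {
  if (var_names.size() == 1UL) return;
  std::ostringstream msg;
  msg << "Input(" << slot << ") should hold exactly one variable because "
      << "the singular accessors address a single variable, but it holds "
      << var_names.size() << "; use the plural accessors instead.";
  throw std::runtime_error(msg.str());
}
```

The message names the slot, states the expected count and the reason, and reports the actual count, which is the information both reviewers asked for.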

Member Author

done


bool HasOutput(const std::string& name) const override {
const std::vector<std::string>& output_names = op_.Output(name);
PADDLE_ENFORCE_EQ(output_names.size(), 1UL, "Outputs(%s) length is not 1",
Contributor

same as comment above


DDim GetInputDim(const std::string& name) const override {
std::vector<DDim> ddims = GetInputsDim(name);
PADDLE_ENFORCE_EQ(ddims.size(), 1UL, "Inputs(%s) length is not 1", name);
Contributor

Change "xxx is not 1" to "xxx should be 1 so that xxx" (state the reason).


DDim GetOutputDim(const std::string& name) const override {
std::vector<DDim> ddims = GetOutputsDim(name);
PADDLE_ENFORCE_EQ(ddims.size(), 1UL, "Outputs(%s) length is not 1", name);
Contributor

same as above

return block_.HasVar(output_names[0]);
}

bool HasInputs(const std::string& name) const override {
Collaborator

HasInputs() seems to be a more general case of HasInput(). Why not remove HasInput() and keep only this one?

Member Author

Most of the time, an input has only one value:

X = [x1]

HasInput checks that X holds exactly one value.

HasInputs only requires that the length is not zero, because the length is not fixed.

Collaborator

I don't understand what the example X = [x1] means.

Collaborator

Also, what's the difference between HasInput and HasInputs? From the English language, it seems that they would have the same functionality.

Member Author

In our current implementation, all inputs and outputs are repeated fields in protobuf, so the real data structure for an input/output is:

VarName = [v1, v2, v3 ...]

Most of the time, one VarName has only one corresponding variable, but in some cases the input length is not fixed. For example, mul_op will have data like:

X = [x1, x2, x3, ...]
Y = [y1, y2, y3, ...]
Out = [x1+y1, x2+y2, x3+y3, ...]

But most operators just need to get the first variable for a VarName, so currently we add a helper function

def input(varname):
   # check varname only have one variable
···
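
The singular/plural split described in this comment can be sketched roughly as follows. `OpInputs` is a hypothetical stand-in for the op description, and its method names are chosen for illustration only; the real accessors in this PR may differ.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for an op description: every input slot maps to a
// list of variable names, mirroring a repeated protobuf field.
class OpInputs {
 public:
  void Set(const std::string& slot, std::vector<std::string> names) {
    slots_[slot] = std::move(names);
  }

  // Plural accessor: the slot may hold any number of variables.
  const std::vector<std::string>& Inputs(const std::string& slot) const {
    static const std::vector<std::string> kEmpty;
    auto it = slots_.find(slot);
    return it == slots_.end() ? kEmpty : it->second;
  }

  // Singular accessor: requires exactly one variable in the slot.
  const std::string& Input(const std::string& slot) const {
    const std::vector<std::string>& names = Inputs(slot);
    if (names.size() != 1UL) {
      throw std::logic_error("Input(" + slot +
                             ") should hold exactly one variable");
    }
    return names[0];
  }

 private:
  std::map<std::string, std::vector<std::string>> slots_;
};
```

With this shape, `Input("X")` succeeds only for `X = [x1]`, while `Inputs("X")` accepts a list of any length.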

return true;
}

DDim GetInputDim(const std::string& name) const override {
Collaborator

GetInputDim and GetInputsDim are sooo easy to confuse because their names are almost the same. If it's really essential to keep both, please give GetInputsDim a more distinctive name, such as GetMultInputDim.

Collaborator

I would suggest that we merge all Inputs/Input methods into one. For example, HasInput and HasInputs could be one (pseudo code only):

bool HasInput(name) {
  if (FindVarDesc(name).duplicable == true) {
    // duplicable input: at least one variable
    CHECK_LT(0, vars_[name].size());
  } else {
    // ordinary input: exactly one variable
    CHECK_EQ(1, vars_[name].size());
  }
  return true;
}
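
A compilable variant of this idea, as a rough sketch only. `InputSlot` and its `duplicable` flag are hypothetical and are modelled here as a plain struct for illustration. Unlike the pseudocode, this version returns `false` rather than failing a check, which is a choice made for the sketch.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical slot description carrying an assumed "duplicable" flag.
struct InputSlot {
  bool duplicable;                // may the slot hold several variables?
  std::vector<std::string> vars;  // variable names bound to the slot
};

// Single merged check: duplicable slots need at least one variable,
// ordinary slots need exactly one.
inline bool HasInput(const std::map<std::string, InputSlot>& slots,
                     const std::string& name) {
  auto it = slots.find(name);
  if (it == slots.end()) return false;
  const InputSlot& slot = it->second;
  return slot.duplicable ? !slot.vars.empty() : slot.vars.size() == 1UL;
}
```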

Member Author

In our current design, VarDesc does not have a duplicable flag, so we cannot use it yet. We can optimize the whole interface in the next PR.

@@ -317,26 +318,108 @@ class ExecutionContext : public InferShapeContext {
const platform::DeviceContext& device_context_;
};

class CompileTimeInferShapeContext : public InferShapeContextBase {
Collaborator

How about adding a TODO comment here:

// TODO(longfei): Once both CompileTimeInferShapeContext and
// RuntimeInferShapeContext are merged, we can rename InferShapeContextBase to
// InferShapeContext to replace the current InferShapeContext.

Member Author

ok

Member Author

done, added to shape_inference.h
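
To illustrate why a shared base class is useful, here is a simplified sketch in which a shape rule is written once against the base interface. All names below (`ShapeContextBase`, `DescShapeContext`, `Dims`, `InferSameShape`) are hypothetical stand-ins and not the classes from this PR, and only a compile-time style context is shown.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <utility>
#include <vector>

using Dims = std::vector<std::int64_t>;  // simplified stand-in for DDim

// Simplified stand-in for InferShapeContextBase.
class ShapeContextBase {
 public:
  virtual ~ShapeContextBase() {}
  virtual bool HasInput(const std::string& name) const = 0;
  virtual Dims GetInputDim(const std::string& name) const = 0;
  virtual void SetOutputDim(const std::string& name, const Dims& dims) = 0;
};

// Compile-time flavour: shapes come from declared descriptions, not tensors.
class DescShapeContext : public ShapeContextBase {
 public:
  explicit DescShapeContext(std::map<std::string, Dims> declared)
      : dims_(std::move(declared)) {}

  bool HasInput(const std::string& name) const override {
    return dims_.find(name) != dims_.end();
  }
  Dims GetInputDim(const std::string& name) const override {
    return dims_.at(name);  // throws std::out_of_range for unknown names
  }
  void SetOutputDim(const std::string& name, const Dims& dims) override {
    dims_[name] = dims;
  }
  const Dims& DimOf(const std::string& name) const { return dims_.at(name); }

 private:
  std::map<std::string, Dims> dims_;
};

// An operator's shape rule, written once against the base interface.
inline void InferSameShape(ShapeContextBase* ctx) {
  assert(ctx->HasInput("X"));
  ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
}
```

A runtime context would implement the same three methods on top of real variables, so `InferSameShape` would not need to change.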


@wangkuiyi
Collaborator

An additional suggestion -- move InferShapeContext, InferShapeContextBase, CompileTimeInferShapeContext out from operator.h into infershape.h. Otherwise, I am afraid that operator.h would be too lengthy.

@jacquesqiao
Member Author

An additional suggestion -- move InferShapeContext, InferShapeContextBase, CompileTimeInferShapeContext out from operator.h into infershape.h. Otherwise, I am afraid that operator.h would be too lengthy.

Yes, this is one thing we need to do after the interface is done.

@jacquesqiao
Member Author

jacquesqiao commented Oct 6, 2017

Had a discussion with @wangkuiyi.

In the current interface, Input and Inputs are used to distinguish an input/output that has only one value from one that has multiple values. This is confusing, so we decided that there should be only one interface:

std::vector<Type> Input(const std::string& name);

Users should know what they want to use. If they know that input X has only one value, they can write:

Type x = Input("X")[0];

to get what they want.

I think this should be done in the next PR, which will modify the interface definition and all the related operators.
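
A minimal sketch of how the single accessor could look. `InputContext` is a hypothetical class used only for illustration; the actual follow-up PR may define the interface differently.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical context exposing only the single accessor proposed above.
template <typename Type>
class InputContext {
 public:
  void Set(const std::string& name, std::vector<Type> values) {
    inputs_[name] = std::move(values);
  }

  // Always returns the full list; an unknown name yields an empty list.
  std::vector<Type> Input(const std::string& name) const {
    auto it = inputs_.find(name);
    return it == inputs_.end() ? std::vector<Type>() : it->second;
  }

 private:
  std::map<std::string, std::vector<Type>> inputs_;
};
```

A caller that expects one value would write `int x = ctx.Input("X")[0];`. Indexing an empty list with `[0]` is undefined behavior, so `ctx.Input("X").at(0)`, which throws `std::out_of_range`, may be the safer form when the input could be missing.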

@wangkuiyi wangkuiyi merged commit f8b5d54 into PaddlePaddle:develop Oct 6, 2017

Successfully merging this pull request may close these issues.

Implement Compile time & Runtime Infershape
4 participants