-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add operator base #2725
Merged
Merged
add operator base #2725
Changes from 29 commits
Commits
Show all changes
34 commits
Select commit
Hold shift + click to select a range
ff92a58
add operator base
jacquesqiao 3f75b66
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao 28e2313
import attr_type.proto
jacquesqiao 22c2bec
add test
jacquesqiao c6407c3
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao a591e78
do not return error when run
jacquesqiao 8d8a448
remove Error of InitializeAttrs
jacquesqiao cb7c234
interface of operator
jacquesqiao 57f1f6a
refactor of operator
jacquesqiao c248ba0
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao 57d1fbd
add comment, optimize code style
jacquesqiao f7d825b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao 4f23ec3
net interface of op
jacquesqiao 88f6621
refine operator interface
jacquesqiao 42afab2
optimize operator test
jacquesqiao 0a662ff
optimize code
jacquesqiao 474debd
change test op name to test_operator
jacquesqiao 7b77d76
add optional op name
jacquesqiao 231e056
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao 2d1d021
rm name, use DeviceContext
jacquesqiao 1e35e83
prepare for opkernel
jacquesqiao 4c68deb
change op in OpContext from ref to const pointer
jacquesqiao d0fae91
change op member to public
jacquesqiao 40022f7
remove New OpContext
jacquesqiao e34c579
fix style problem
jacquesqiao b710d5b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao cd2338f
use fake DeviceContext
jacquesqiao dfbf13d
use const func call
jacquesqiao c9aa77c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao 6a893c1
add operator with kernel
jacquesqiao 0fc8207
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao 9324915
merge new attr
jacquesqiao f0332fb
optimize code
jacquesqiao c1d8cbb
fix cpp style
jacquesqiao File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,25 +2,13 @@ | |
|
||
#include "paddle/framework/attr_checker.h" | ||
|
||
//#include "paddle/framework/op_base.h" | ||
#include "paddle/framework/op_desc.pb.h" | ||
#include "paddle/framework/op_proto.pb.h" | ||
#include "paddle/framework/operator.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
//==================For test================// | ||
class OpBase { | ||
public: | ||
std::vector<std::string> inputs_; | ||
std::vector<std::string> outputs_; | ||
AttributeMap attr_map_; | ||
|
||
virtual std::string Run() const = 0; | ||
virtual ~OpBase() {} | ||
}; | ||
//=========================================// | ||
|
||
// helper class to set attribute type | ||
struct AttrTypeHelper { | ||
template <typename T> | ||
|
@@ -134,7 +122,7 @@ class OpProtoAndCheckerMaker { | |
}; | ||
|
||
class OpRegistry { | ||
typedef std::function<OpBase*()> OpCreator; | ||
typedef std::function<OperatorBase*()> OpCreator; | ||
|
||
public: | ||
template <typename OpType, typename ProtoMakerType> | ||
|
@@ -143,28 +131,29 @@ class OpRegistry { | |
OpProto& op_proto = protos_[op_type]; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The creators_ map's key is not op_type, but op_type+device_type+(data_type). |
||
OpAttrChecker& op_checker = op_checkers_[op_type]; | ||
ProtoMakerType(&op_proto, &op_checker); | ||
PADDLE_ENFORCE(op_proto.IsInitialized() == true, | ||
PADDLE_ENFORCE(op_proto.IsInitialized(), | ||
"Fail to initialize %s's OpProto !", op_type); | ||
} | ||
|
||
static OpBase* CreateOp(const OpDesc& op_desc) { | ||
static OperatorBase* CreateOp(const OpDesc& op_desc) { | ||
std::string op_type = op_desc.type(); | ||
OpBase* op = (creators_.at(op_type))(); | ||
(op->inputs_).resize(op_desc.inputs_size()); | ||
for (int i = 0; i < op_desc.inputs_size(); ++i) { | ||
(op->inputs_)[i] = op_desc.inputs(i); | ||
} | ||
(op->outputs_).resize(op_desc.outputs_size()); | ||
for (int i = 0; i < op_desc.outputs_size(); ++i) { | ||
(op->outputs_)[i] = op_desc.outputs(i); | ||
} | ||
OperatorBase* op = (creators_.at(op_type))(); | ||
// init attrs | ||
for (int i = 0; i < op_desc.attrs_size(); ++i) { | ||
const AttrDesc& ith_attr = op_desc.attrs(i); | ||
std::string name = ith_attr.name(); | ||
(op->attr_map_)[name] = AttrTypeHelper::GetAttrValue(ith_attr); | ||
(op->attrs_)[name] = AttrTypeHelper::GetAttrValue(ith_attr); | ||
} | ||
const OpAttrChecker& op_checker = OpRegistry::op_checkers_.at(op_type); | ||
// check attrs | ||
op_checker.Check(op->attrs_); | ||
op->desc_ = op_desc; | ||
for (auto& input : op_desc.inputs()) { | ||
op->inputs_.push_back(input); | ||
} | ||
for (auto& output : op_desc.outputs()) { | ||
op->outputs_.push_back(output); | ||
} | ||
const OpAttrChecker& op_checker = op_checkers_.at(op_type); | ||
op_checker.Check(op->attr_map_); | ||
return op; | ||
} | ||
|
||
|
@@ -174,7 +163,8 @@ class OpRegistry { | |
static std::unordered_map<std::string, OpAttrChecker> op_checkers_; | ||
}; | ||
|
||
std::unordered_map<std::string, std::function<OpBase*()>> OpRegistry::creators_; | ||
std::unordered_map<std::string, std::function<OperatorBase*()>> | ||
OpRegistry::creators_; | ||
std::unordered_map<std::string, OpProto> OpRegistry::protos_; | ||
std::unordered_map<std::string, OpAttrChecker> OpRegistry::op_checkers_; | ||
|
||
|
@@ -194,60 +184,5 @@ class OpRegisterHelper { | |
const OpRegisterHelper<__op_class, __op_maker_class> \ | ||
__op_class##Register::reg(#__op_type); | ||
|
||
// Demos | ||
|
||
class CosineOp : public OpBase { | ||
public: | ||
virtual std::string Run() const { | ||
std::string msg = "CosineOp runs! scale = " + | ||
std::to_string(boost::get<float>(attr_map_.at("scale"))); | ||
return msg; | ||
} | ||
}; | ||
|
||
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { | ||
public: | ||
CosineOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) | ||
: OpProtoAndCheckerMaker(proto, op_checker) { | ||
AddInput("input", "input of cosine op"); | ||
AddOutput("output", "output of cosine op"); | ||
AddAttr<float>("scale", "scale of cosine op") | ||
.SetDefault(1.0) | ||
.LargerThan(0.0); | ||
AddType("cos"); | ||
AddComment("This is cos op"); | ||
} | ||
}; | ||
|
||
REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim) | ||
|
||
class MyTestOp : public OpBase { | ||
public: | ||
virtual std::string Run() const { | ||
std::string msg = | ||
"MyTestOp runs! test_attr = " + | ||
std::to_string(boost::get<int>(attr_map_.at("test_attr"))); | ||
return msg; | ||
} | ||
}; | ||
|
||
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { | ||
public: | ||
MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) | ||
: OpProtoAndCheckerMaker(proto, op_checker) { | ||
AddInput("input", "input of cosine op"); | ||
AddOutput("output", "output of cosine op"); | ||
auto my_checker = [](int i) { | ||
PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!"); | ||
}; | ||
AddAttr<int>("test_attr", "a simple test attribute") | ||
.AddCustomChecker(my_checker); | ||
AddType("my_test_op"); | ||
AddComment("This is my_test op"); | ||
} | ||
}; | ||
|
||
REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op) | ||
|
||
} // namespace framework | ||
} // namespace paddle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/framework/operator.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
void OperatorBase::InferShape(const std::shared_ptr<Scope>& scope) const {} | ||
|
||
std::string OperatorBase::DebugString() const { | ||
std::stringstream ss; | ||
ss << "=================\n"; | ||
ss << "type = " << desc_.type() << "\n"; | ||
ss << "inputs = ["; | ||
for (auto& ipt : inputs_) { | ||
ss << ipt << ", "; | ||
} | ||
ss << "]\n"; | ||
ss << "outputs = ["; | ||
for (auto& opt : outputs_) { | ||
ss << opt << ", "; | ||
} | ||
ss << "]\n"; | ||
ss << "attr_keys = ["; | ||
for (auto& attr : attrs_) { | ||
ss << attr.first << ", "; | ||
} | ||
ss << "]\n"; | ||
return ss.str(); | ||
} | ||
|
||
const Variable* OpRunContext::Input(int index) const { | ||
return scope_->GetVariable(op_->inputs_[index]); | ||
} | ||
|
||
Variable* OpRunContext::Output(int index) const { | ||
return scope_->GetVariable(op_->outputs_[index]); | ||
} | ||
|
||
} // namespace framework | ||
} // namespace paddle |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OperatroBase is part of
framework
notoperators
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, there is a demo operator under operators, operator.h is still under paddle/framework