-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add operator base #2725
add operator base #2725
Changes from 22 commits
ff92a58
3f75b66
28e2313
22c2bec
c6407c3
a591e78
8d8a448
cb7c234
57f1f6a
c248ba0
57d1fbd
f7d825b
4f23ec3
88f6621
42afab2
0a662ff
474debd
7b77d76
231e056
2d1d021
1e35e83
4c68deb
d0fae91
40022f7
e34c579
b710d5b
cd2338f
dfbf13d
c9aa77c
6a893c1
0fc8207
9324915
f0332fb
c1d8cbb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,25 +2,13 @@ | |
|
||
#include "paddle/framework/attr_checker.h" | ||
|
||
//#include "paddle/framework/op_base.h" | ||
#include "paddle/framework/op_desc.pb.h" | ||
#include "paddle/framework/op_proto.pb.h" | ||
#include "paddle/framework/operator.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
//==================For test================// | ||
class OpBase { | ||
public: | ||
std::vector<std::string> inputs_; | ||
std::vector<std::string> outputs_; | ||
AttributeMap attr_map_; | ||
|
||
virtual std::string Run() const = 0; | ||
virtual ~OpBase() {} | ||
}; | ||
//=========================================// | ||
|
||
// helper class to set attribute type | ||
struct AttrTypeHelper { | ||
template <typename T> | ||
|
@@ -134,7 +122,7 @@ class OpProtoAndCheckerMaker { | |
}; | ||
|
||
class OpRegistry { | ||
typedef std::function<OpBase*()> OpCreator; | ||
typedef std::function<OperatorBase*()> OpCreator; | ||
|
||
public: | ||
template <typename OpType, typename ProtoMakerType> | ||
|
@@ -143,28 +131,22 @@ class OpRegistry { | |
OpProto& op_proto = protos_[op_type]; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The creators_ map's key is not op_type, but op_type+device_type+(data_type). |
||
OpAttrChecker& op_checker = op_checkers_[op_type]; | ||
ProtoMakerType(&op_proto, &op_checker); | ||
PADDLE_ENFORCE(op_proto.IsInitialized() == true, | ||
PADDLE_ENFORCE(op_proto.IsInitialized(), | ||
"Fail to initialize %s's OpProto !", op_type); | ||
} | ||
|
||
static OpBase* CreateOp(const OpDesc& op_desc) { | ||
static OperatorBase* CreateOp(const OpDesc& op_desc) { | ||
std::string op_type = op_desc.type(); | ||
OpBase* op = (creators_.at(op_type))(); | ||
(op->inputs_).resize(op_desc.inputs_size()); | ||
for (int i = 0; i < op_desc.inputs_size(); ++i) { | ||
(op->inputs_)[i] = op_desc.inputs(i); | ||
} | ||
(op->outputs_).resize(op_desc.outputs_size()); | ||
for (int i = 0; i < op_desc.outputs_size(); ++i) { | ||
(op->outputs_)[i] = op_desc.outputs(i); | ||
} | ||
OperatorBase* op = (creators_.at(op_type))(); | ||
AttributeMap attrs; | ||
for (int i = 0; i < op_desc.attrs_size(); ++i) { | ||
const AttrDesc& ith_attr = op_desc.attrs(i); | ||
std::string name = ith_attr.name(); | ||
(op->attr_map_)[name] = AttrTypeHelper::GetAttrValue(ith_attr); | ||
(attrs)[name] = AttrTypeHelper::GetAttrValue(ith_attr); | ||
} | ||
const OpAttrChecker& op_checker = op_checkers_.at(op_type); | ||
op_checker.Check(op->attr_map_); | ||
const OpAttrChecker& op_checker = OpRegistry::op_checkers_.at(op_type); | ||
op_checker.Check(attrs); | ||
op->Init(op_desc, attrs); | ||
return op; | ||
} | ||
|
||
|
@@ -174,7 +156,8 @@ class OpRegistry { | |
static std::unordered_map<std::string, OpAttrChecker> op_checkers_; | ||
}; | ||
|
||
std::unordered_map<std::string, std::function<OpBase*()>> OpRegistry::creators_; | ||
std::unordered_map<std::string, std::function<OperatorBase*()>> | ||
OpRegistry::creators_; | ||
std::unordered_map<std::string, OpProto> OpRegistry::protos_; | ||
std::unordered_map<std::string, OpAttrChecker> OpRegistry::op_checkers_; | ||
|
||
|
@@ -194,60 +177,5 @@ class OpRegisterHelper { | |
const OpRegisterHelper<__op_class, __op_maker_class> \ | ||
__op_class##Register::reg(#__op_type); | ||
|
||
// Demos | ||
|
||
class CosineOp : public OpBase { | ||
public: | ||
virtual std::string Run() const { | ||
std::string msg = "CosineOp runs! scale = " + | ||
std::to_string(boost::get<float>(attr_map_.at("scale"))); | ||
return msg; | ||
} | ||
}; | ||
|
||
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { | ||
public: | ||
CosineOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) | ||
: OpProtoAndCheckerMaker(proto, op_checker) { | ||
AddInput("input", "input of cosine op"); | ||
AddOutput("output", "output of cosine op"); | ||
AddAttr<float>("scale", "scale of cosine op") | ||
.SetDefault(1.0) | ||
.LargerThan(0.0); | ||
AddType("cos"); | ||
AddComment("This is cos op"); | ||
} | ||
}; | ||
|
||
REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim) | ||
|
||
class MyTestOp : public OpBase { | ||
public: | ||
virtual std::string Run() const { | ||
std::string msg = | ||
"MyTestOp runs! test_attr = " + | ||
std::to_string(boost::get<int>(attr_map_.at("test_attr"))); | ||
return msg; | ||
} | ||
}; | ||
|
||
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { | ||
public: | ||
MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) | ||
: OpProtoAndCheckerMaker(proto, op_checker) { | ||
AddInput("input", "input of cosine op"); | ||
AddOutput("output", "output of cosine op"); | ||
auto my_checker = [](int i) { | ||
PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!"); | ||
}; | ||
AddAttr<int>("test_attr", "a simple test attribute") | ||
.AddCustomChecker(my_checker); | ||
AddType("my_test_op"); | ||
AddComment("This is my_test op"); | ||
} | ||
}; | ||
|
||
REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op) | ||
|
||
} // namespace framework | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/framework/operator.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
void OperatorBase::Init(const OpDesc& op_desc, AttributeMap& attrs) { | ||
desc_ = op_desc; | ||
inputs_.reserve(desc_.inputs_size()); | ||
for (auto& input : desc_.inputs()) { | ||
inputs_.push_back(input); | ||
} | ||
outputs_.reserve(desc_.outputs_size()); | ||
for (auto& output : desc_.outputs()) { | ||
outputs_.push_back(output); | ||
} | ||
attrs_.insert(attrs.begin(), attrs.end()); | ||
} | ||
|
||
void OperatorBase::InferShape(Scope* scope) const {} | ||
|
||
const std::string OperatorBase::DebugString() const { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why should this be a the returned string is copied, and free to mutate? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. const removed |
||
std::stringstream ss; | ||
ss << "=================\n"; | ||
ss << "type = " << desc_.type() << "\n"; | ||
ss << "inputs = ["; | ||
for (auto& ipt : inputs_) { | ||
ss << ipt << ", "; | ||
} | ||
ss << "]\n"; | ||
ss << "outputs = ["; | ||
for (auto& opt : outputs_) { | ||
ss << opt << ", "; | ||
} | ||
ss << "]\n"; | ||
ss << "attr_keys = ["; | ||
for (auto& attr : attrs_) { | ||
ss << attr.first << ", "; | ||
} | ||
ss << "]\n"; | ||
return ss.str(); | ||
} | ||
|
||
const Variable* OpContext::Input(int index) const { | ||
return scope->GetVariable(op->inputs()[index]); | ||
} | ||
|
||
Variable* OpContext::Output(int index) const { | ||
return scope->GetVariable(op->outputs()[index]); | ||
} | ||
|
||
} // namespace framework | ||
} // namespace paddle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OperatroBase is part of
framework
notoperators
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, there is a demo operator under operators, operator.h is still under paddle/framework