Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add operator base #2725

Merged
merged 34 commits into from
Jul 11, 2017
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
ff92a58
add operator base
jacquesqiao Jul 4, 2017
3f75b66
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Jul 4, 2017
28e2313
import attr_type.proto
jacquesqiao Jul 4, 2017
22c2bec
add test
jacquesqiao Jul 4, 2017
c6407c3
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Jul 5, 2017
a591e78
do not return error when run
jacquesqiao Jul 5, 2017
8d8a448
remove Error of InitializeAttrs
jacquesqiao Jul 5, 2017
cb7c234
interface of operator
jacquesqiao Jul 6, 2017
57f1f6a
refactor of operator
jacquesqiao Jul 6, 2017
c248ba0
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Jul 6, 2017
57d1fbd
add comment, optimize code style
jacquesqiao Jul 6, 2017
f7d825b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Jul 7, 2017
4f23ec3
net interface of op
jacquesqiao Jul 7, 2017
88f6621
refine operator interface
jacquesqiao Jul 8, 2017
42afab2
optimize operator test
jacquesqiao Jul 9, 2017
0a662ff
optimize code
jacquesqiao Jul 9, 2017
474debd
change test op name to test_operator
jacquesqiao Jul 10, 2017
7b77d76
add optional op name
jacquesqiao Jul 10, 2017
231e056
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Jul 10, 2017
2d1d021
rm name, use DeviceContext
jacquesqiao Jul 10, 2017
1e35e83
prepare for opkernel
jacquesqiao Jul 10, 2017
4c68deb
change op in OpContext from ref to const pointer
jacquesqiao Jul 10, 2017
d0fae91
change op member to public
jacquesqiao Jul 10, 2017
40022f7
remove New OpContext
jacquesqiao Jul 10, 2017
e34c579
fix style problem
jacquesqiao Jul 10, 2017
b710d5b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Jul 11, 2017
cd2338f
use fake DeviceContext
jacquesqiao Jul 11, 2017
dfbf13d
use const func call
jacquesqiao Jul 11, 2017
c9aa77c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Jul 11, 2017
6a893c1
add operator with kernel
jacquesqiao Jul 11, 2017
0fc8207
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Jul 11, 2017
9324915
merge new attr
jacquesqiao Jul 11, 2017
f0332fb
optimize code
jacquesqiao Jul 11, 2017
c1d8cbb
fix cpp style
jacquesqiao Jul 11, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ if(Boost_FOUND)
add_subdirectory(memory)
add_subdirectory(platform)
add_subdirectory(framework)
add_subdirectory(operators)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OperatroBase is part of framework not operators

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, there is a demo operator under operators, operator.h is still under paddle/framework

endif()

if(WITH_C_API)
Expand Down
4 changes: 3 additions & 1 deletion paddle/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ proto_library(op_proto SRCS op_proto.proto DEPS attr_type)
cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
proto_library(op_desc SRCS op_desc.proto DEPS attr_type)
cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_proto op_desc)
cc_library(operator SRCS operator.cc DEPS op_desc protobuf)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_proto op_desc attr_type protobuf)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_proto op_desc operator)
py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)
# Generate an empty __init__.py to make framework_py_proto as a valid python module.
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
Expand Down
96 changes: 12 additions & 84 deletions paddle/framework/op_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,13 @@

#include "paddle/framework/attr_checker.h"

//#include "paddle/framework/op_base.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/operator.h"

namespace paddle {
namespace framework {

//==================For test================//
class OpBase {
public:
std::vector<std::string> inputs_;
std::vector<std::string> outputs_;
AttributeMap attr_map_;

virtual std::string Run() const = 0;
virtual ~OpBase() {}
};
//=========================================//

// helper class to set attribute type
struct AttrTypeHelper {
template <typename T>
Expand Down Expand Up @@ -134,7 +122,7 @@ class OpProtoAndCheckerMaker {
};

class OpRegistry {
typedef std::function<OpBase*()> OpCreator;
typedef std::function<OperatorBase*()> OpCreator;

public:
template <typename OpType, typename ProtoMakerType>
Expand All @@ -143,28 +131,22 @@ class OpRegistry {
OpProto& op_proto = protos_[op_type];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The creators_ map's key is not op_type, but op_type+device_type+(data_type).
Here, we may consider data_type later(only float is enough for now)

OpAttrChecker& op_checker = op_checkers_[op_type];
ProtoMakerType(&op_proto, &op_checker);
PADDLE_ENFORCE(op_proto.IsInitialized() == true,
PADDLE_ENFORCE(op_proto.IsInitialized(),
"Fail to initialize %s's OpProto !", op_type);
}

static OpBase* CreateOp(const OpDesc& op_desc) {
static OperatorBase* CreateOp(const OpDesc& op_desc) {
std::string op_type = op_desc.type();
OpBase* op = (creators_.at(op_type))();
(op->inputs_).resize(op_desc.inputs_size());
for (int i = 0; i < op_desc.inputs_size(); ++i) {
(op->inputs_)[i] = op_desc.inputs(i);
}
(op->outputs_).resize(op_desc.outputs_size());
for (int i = 0; i < op_desc.outputs_size(); ++i) {
(op->outputs_)[i] = op_desc.outputs(i);
}
OperatorBase* op = (creators_.at(op_type))();
AttributeMap attrs;
for (int i = 0; i < op_desc.attrs_size(); ++i) {
const AttrDesc& ith_attr = op_desc.attrs(i);
std::string name = ith_attr.name();
(op->attr_map_)[name] = AttrTypeHelper::GetAttrValue(ith_attr);
(attrs)[name] = AttrTypeHelper::GetAttrValue(ith_attr);
}
const OpAttrChecker& op_checker = op_checkers_.at(op_type);
op_checker.Check(op->attr_map_);
const OpAttrChecker& op_checker = OpRegistry::op_checkers_.at(op_type);
op_checker.Check(attrs);
op->Init(op_desc, attrs);
return op;
}

Expand All @@ -174,7 +156,8 @@ class OpRegistry {
static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
};

std::unordered_map<std::string, std::function<OpBase*()>> OpRegistry::creators_;
std::unordered_map<std::string, std::function<OperatorBase*()>>
OpRegistry::creators_;
std::unordered_map<std::string, OpProto> OpRegistry::protos_;
std::unordered_map<std::string, OpAttrChecker> OpRegistry::op_checkers_;

Expand All @@ -194,60 +177,5 @@ class OpRegisterHelper {
const OpRegisterHelper<__op_class, __op_maker_class> \
__op_class##Register::reg(#__op_type);

// Demos

class CosineOp : public OpBase {
public:
virtual std::string Run() const {
std::string msg = "CosineOp runs! scale = " +
std::to_string(boost::get<float>(attr_map_.at("scale")));
return msg;
}
};

class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
CosineOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op");
AddOutput("output", "output of cosine op");
AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0)
.LargerThan(0.0);
AddType("cos");
AddComment("This is cos op");
}
};

REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim)

class MyTestOp : public OpBase {
public:
virtual std::string Run() const {
std::string msg =
"MyTestOp runs! test_attr = " +
std::to_string(boost::get<int>(attr_map_.at("test_attr")));
return msg;
}
};

class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op");
AddOutput("output", "output of cosine op");
auto my_checker = [](int i) {
PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
};
AddAttr<int>("test_attr", "a simple test attribute")
.AddCustomChecker(my_checker);
AddType("my_test_op");
AddComment("This is my_test op");
}
};

REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op)

} // namespace framework
} // namespace paddle
45 changes: 19 additions & 26 deletions paddle/framework/op_registry_test.cc
Original file line number Diff line number Diff line change
@@ -1,25 +1,24 @@
#include "paddle/framework/op_registry.h"
#include <gtest/gtest.h>
#include "paddle/operators/demo_op.h"

TEST(OpRegistry, CreateOp) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim");
op_desc.add_inputs("aa");
op_desc.add_outputs("bb");

float scale = 3.3;
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_f(3.3);
attr->set_f(scale);

paddle::framework::OpBase* op =
paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc);
std::string debug_str = op->Run();
std::string str = "CosineOp runs! scale = " + std::to_string(3.3);
ASSERT_EQ(str.size(), debug_str.size());
for (size_t i = 0; i < debug_str.length(); ++i) {
ASSERT_EQ(debug_str[i], str[i]);
}
op->Run(nullptr);
float scale_get = op->GetAttr<float>("scale");
ASSERT_EQ(scale_get, scale);
}

TEST(OpRegistry, IllegalAttr) {
Expand All @@ -35,7 +34,7 @@ TEST(OpRegistry, IllegalAttr) {

bool caught = false;
try {
paddle::framework::OpBase* op __attribute__((unused)) =
paddle::framework::OperatorBase* op __attribute__((unused)) =
paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::framework::EnforceNotMet err) {
caught = true;
Expand All @@ -54,15 +53,12 @@ TEST(OpRegistry, DefaultValue) {
op_desc.add_inputs("aa");
op_desc.add_outputs("bb");

paddle::framework::OpBase* op =
ASSERT_TRUE(op_desc.IsInitialized());

paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc);
std::string debug_str = op->Run();
float default_value = 1.0;
std::string str = "CosineOp runs! scale = " + std::to_string(default_value);
ASSERT_EQ(str.size(), debug_str.size());
for (size_t i = 0; i < debug_str.length(); ++i) {
ASSERT_EQ(debug_str[i], str[i]);
}
op->Run(nullptr);
ASSERT_EQ(op->GetAttr<float>("scale"), 1.0);
}

TEST(OpRegistry, CustomChecker) {
Expand All @@ -74,7 +70,7 @@ TEST(OpRegistry, CustomChecker) {
// attr 'test_attr' is not set
bool caught = false;
try {
paddle::framework::OpBase* op __attribute__((unused)) =
paddle::framework::OperatorBase* op __attribute__((unused)) =
paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::framework::EnforceNotMet err) {
caught = true;
Expand All @@ -93,7 +89,7 @@ TEST(OpRegistry, CustomChecker) {
attr->set_i(3);
caught = false;
try {
paddle::framework::OpBase* op __attribute__((unused)) =
paddle::framework::OperatorBase* op __attribute__((unused)) =
paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::framework::EnforceNotMet err) {
caught = true;
Expand All @@ -111,12 +107,9 @@ TEST(OpRegistry, CustomChecker) {
attr->set_name("test_attr");
attr->set_type(paddle::framework::AttrType::INT);
attr->set_i(4);
paddle::framework::OpBase* op =
paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc);
std::string debug_str = op->Run();
std::string str = "MyTestOp runs! test_attr = " + std::to_string(4);
ASSERT_EQ(str.size(), debug_str.size());
for (size_t i = 0; i < debug_str.length(); ++i) {
ASSERT_EQ(debug_str[i], str[i]);
}
op->Run(nullptr);
int test_attr = op->GetAttr<int>("test_attr");
ASSERT_EQ(test_attr, 4);
}
66 changes: 66 additions & 0 deletions paddle/framework/operator.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/operator.h"

namespace paddle {
namespace framework {

void OperatorBase::Init(const OpDesc& op_desc, AttributeMap& attrs) {
desc_ = op_desc;
inputs_.reserve(desc_.inputs_size());
for (auto& input : desc_.inputs()) {
inputs_.push_back(input);
}
outputs_.reserve(desc_.outputs_size());
for (auto& output : desc_.outputs()) {
outputs_.push_back(output);
}
attrs_.insert(attrs.begin(), attrs.end());
}

void OperatorBase::InferShape(Scope* scope) const {}

const std::string OperatorBase::DebugString() const {
Copy link
Contributor

@Superjomn Superjomn Jul 10, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why should this be a const?

the returned string is copied, and free to mutate?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const removed

std::stringstream ss;
ss << "=================\n";
ss << "type = " << desc_.type() << "\n";
ss << "inputs = [";
for (auto& ipt : inputs_) {
ss << ipt << ", ";
}
ss << "]\n";
ss << "outputs = [";
for (auto& opt : outputs_) {
ss << opt << ", ";
}
ss << "]\n";
ss << "attr_keys = [";
for (auto& attr : attrs_) {
ss << attr.first << ", ";
}
ss << "]\n";
return ss.str();
}

const Variable* OpContext::Input(int index) const {
return scope->GetVariable(op->inputs()[index]);
}

Variable* OpContext::Output(int index) const {
return scope->GetVariable(op->outputs()[index]);
}

} // namespace framework
} // namespace paddle
Loading