diff --git a/paddle/CMakeLists.txt b/paddle/CMakeLists.txt index 58a35564f8392..2c1eb7521d896 100644 --- a/paddle/CMakeLists.txt +++ b/paddle/CMakeLists.txt @@ -15,6 +15,7 @@ if(Boost_FOUND) add_subdirectory(memory) add_subdirectory(platform) add_subdirectory(framework) + add_subdirectory(operators) add_subdirectory(pybind) endif() diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 0a5edba6ef3c5..aac49fdb7a04a 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -11,8 +11,10 @@ proto_library(op_proto SRCS op_proto.proto DEPS attr_type) cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf) proto_library(op_desc SRCS op_desc.proto DEPS attr_type) cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf) +cc_library(operator SRCS operator.cc DEPS op_desc protobuf) +cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc) -cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) +cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator) py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto) # Generate an empty __init__.py to make framework_py_proto as a valid python module. add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) diff --git a/paddle/framework/op_registry.cc b/paddle/framework/op_registry.cc index bc6a0dda57d4d..4b35e04e681b4 100644 --- a/paddle/framework/op_registry.cc +++ b/paddle/framework/op_registry.cc @@ -32,5 +32,5 @@ template <> void AttrTypeHelper::SetAttrType>(AttrProto* attr) { attr->set_type(paddle::framework::AttrType::STRINGS); } -} -} \ No newline at end of file +} // namespace framework +} // namespace paddle \ No newline at end of file diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index a782834693c39..02c99d50bb50c 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -1,27 +1,14 @@ #pragma once -#include "paddle/framework/attr_checker.h" - -//#include "paddle/framework/op_base.h" #include +#include "paddle/framework/attr_checker.h" #include "paddle/framework/op_desc.pb.h" #include "paddle/framework/op_proto.pb.h" +#include "paddle/framework/operator.h" namespace paddle { namespace framework { -//==================For test================// -class OpBase { - public: - std::vector inputs_; - std::vector outputs_; - AttributeMap attr_map_; - - virtual std::string Run() const = 0; - virtual ~OpBase() {} -}; -//=========================================// - // helper class to set attribute type struct AttrTypeHelper { template @@ -105,7 +92,7 @@ class OpProtoAndCheckerMaker { }; class OpRegistry { - using OpCreator = std::function; + using OpCreator = std::function; public: template @@ -118,9 +105,10 @@ class OpRegistry { "Fail to initialize %s's OpProto !", op_type); } - static OpBase* CreateOp(const OpDesc& op_desc) { + static OperatorBase* CreateOp(const OpDesc& op_desc) { std::string op_type = op_desc.type(); - OpBase* op = creators().at(op_type)(); + OperatorBase* op = creators().at(op_type)(); + op->desc_ = op_desc; op->inputs_.reserve((size_t)op_desc.inputs_size()); std::copy(op_desc.inputs().begin(), op_desc.inputs().end(), std::back_inserter(op->inputs_)); @@ -128,9 +116,9 @@ class OpRegistry { std::copy(op_desc.outputs().begin(), op_desc.outputs().end(), std::back_inserter(op->outputs_)); for (auto& attr : op_desc.attrs()) { - op->attr_map_[attr.name()] = AttrTypeHelper::GetAttrValue(attr); + op->attrs_[attr.name()] = AttrTypeHelper::GetAttrValue(attr); } - op_checkers().at(op_type).Check(op->attr_map_); + op_checkers().at(op_type).Check(op->attrs_); return op; } diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index a92f1feb47660..c4baafc2aebc8 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -1,14 +1,16 @@ #include "paddle/framework/op_registry.h" #include +#include "paddle/framework/operator.h" +#include "paddle/operators/demo_op.h" + +using namespace paddle::framework; namespace paddle { namespace framework { -class CosineOp : public OpBase { +class CosineOp : public OperatorWithKernel { public: - virtual std::string Run() const { - std::string msg = "CosineOp runs! scale = " + - std::to_string(boost::get(attr_map_.at("scale"))); - return msg; + void Run(const OpRunContext* context) const override { + printf("%s\n", DebugString().c_str()); } }; @@ -28,13 +30,11 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim) -class MyTestOp : public OpBase { +class MyTestOp : public OperatorWithKernel { public: - virtual std::string Run() const { - std::string msg = - "MyTestOp runs! test_attr = " + - std::to_string(boost::get(attr_map_.at("test_attr"))); - return msg; + void Run(const OpRunContext* ctx) const override { + printf("%s\n", DebugString().c_str()); + printf("test_attr = %d\n", ctx->op_->GetAttr("test_attr")); } }; @@ -64,19 +64,19 @@ TEST(OpRegistry, CreateOp) { op_desc.add_inputs("aa"); op_desc.add_outputs("bb"); + float scale = 3.3; auto attr = op_desc.mutable_attrs()->Add(); attr->set_name("scale"); attr->set_type(paddle::framework::AttrType::FLOAT); - attr->set_f(3.3); + attr->set_f(scale); - paddle::framework::OpBase* op = + paddle::framework::OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); - std::string debug_str = op->Run(); - std::string str = "CosineOp runs! scale = " + std::to_string(3.3); - ASSERT_EQ(str.size(), debug_str.size()); - for (size_t i = 0; i < debug_str.length(); ++i) { - ASSERT_EQ(debug_str[i], str[i]); - } + auto scope = std::make_shared(); + auto dev_ctx = DeviceContext(); + op->Run(scope, &dev_ctx); + float scale_get = op->GetAttr("scale"); + ASSERT_EQ(scale_get, scale); } TEST(OpRegistry, IllegalAttr) { @@ -92,7 +92,7 @@ TEST(OpRegistry, IllegalAttr) { bool caught = false; try { - paddle::framework::OpBase* op __attribute__((unused)) = + paddle::framework::OperatorBase* op __attribute__((unused)) = paddle::framework::OpRegistry::CreateOp(op_desc); } catch (paddle::framework::EnforceNotMet err) { caught = true; @@ -111,15 +111,14 @@ TEST(OpRegistry, DefaultValue) { op_desc.add_inputs("aa"); op_desc.add_outputs("bb"); - paddle::framework::OpBase* op = + ASSERT_TRUE(op_desc.IsInitialized()); + + paddle::framework::OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); - std::string debug_str = op->Run(); - float default_value = 1.0; - std::string str = "CosineOp runs! scale = " + std::to_string(default_value); - ASSERT_EQ(str.size(), debug_str.size()); - for (size_t i = 0; i < debug_str.length(); ++i) { - ASSERT_EQ(debug_str[i], str[i]); - } + auto scope = std::make_shared(); + auto dev_ctx = DeviceContext(); + op->Run(scope, &dev_ctx); + ASSERT_EQ(op->GetAttr("scale"), 1.0); } TEST(OpRegistry, CustomChecker) { @@ -131,7 +130,7 @@ TEST(OpRegistry, CustomChecker) { // attr 'test_attr' is not set bool caught = false; try { - paddle::framework::OpBase* op __attribute__((unused)) = + paddle::framework::OperatorBase* op __attribute__((unused)) = paddle::framework::OpRegistry::CreateOp(op_desc); } catch (paddle::framework::EnforceNotMet err) { caught = true; @@ -150,7 +149,7 @@ TEST(OpRegistry, CustomChecker) { attr->set_i(3); caught = false; try { - paddle::framework::OpBase* op __attribute__((unused)) = + paddle::framework::OperatorBase* op __attribute__((unused)) = paddle::framework::OpRegistry::CreateOp(op_desc); } catch (paddle::framework::EnforceNotMet err) { caught = true; @@ -168,14 +167,13 @@ TEST(OpRegistry, CustomChecker) { attr->set_name("test_attr"); attr->set_type(paddle::framework::AttrType::INT); attr->set_i(4); - paddle::framework::OpBase* op = + paddle::framework::OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); - std::string debug_str = op->Run(); - std::string str = "MyTestOp runs! test_attr = " + std::to_string(4); - ASSERT_EQ(str.size(), debug_str.size()); - for (size_t i = 0; i < debug_str.length(); ++i) { - ASSERT_EQ(debug_str[i], str[i]); - } + auto dev_ctx = DeviceContext(); + auto scope = std::make_shared(); + op->Run(scope, &dev_ctx); + int test_attr = op->GetAttr("test_attr"); + ASSERT_EQ(test_attr, 4); } int main(int argc, char** argv) { diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc new file mode 100644 index 0000000000000..3db3706e47dfa --- /dev/null +++ b/paddle/framework/operator.cc @@ -0,0 +1,51 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/operator.h" + +namespace paddle { +namespace framework { + +std::string OperatorBase::DebugString() const { + std::stringstream ss; + ss << "=================\n"; + ss << "type = " << desc_.type() << "\n"; + ss << "inputs = ["; + for (auto& ipt : inputs_) { + ss << ipt << ", "; + } + ss << "]\n"; + ss << "outputs = ["; + for (auto& opt : outputs_) { + ss << opt << ", "; + } + ss << "]\n"; + ss << "attr_keys = ["; + for (auto& attr : attrs_) { + ss << attr.first << ", "; + } + ss << "]\n"; + return ss.str(); +} + +const Variable* OpRunContext::Input(int index) const { + return scope_->GetVariable(op_->inputs_[index]); +} + +Variable* OpRunContext::Output(int index) const { + return scope_->GetVariable(op_->outputs_[index]); +} + +} // namespace framework +} // namespace paddle \ No newline at end of file diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h new file mode 100644 index 0000000000000..6570d5869814a --- /dev/null +++ b/paddle/framework/operator.h @@ -0,0 +1,107 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include + +#include "paddle/framework/attr_checker.h" +#include "paddle/framework/op_desc.pb.h" +#include "paddle/framework/scope.h" +#include "paddle/utils/Error.h" + +namespace paddle { +namespace framework { + +class OperatorBase; + +class DeviceContext {}; + +/** + * OpRunContext is the only parameter of Operator's Run function. + * Run will get input/output variables, state such as momentum and + * device resource such as CUDA stream, cublas handle, etc. from + * OpRunContext. User should construct it before run the Operator. + */ +class OpRunContext { + public: + OpRunContext(const OperatorBase* op, const std::shared_ptr scope, + const DeviceContext* device_context) + : op_(op), scope_(scope), device_context_(device_context) {} + + const Variable* Input(int index) const; + Variable* Output(int index) const; + + public: + const OperatorBase* op_; + const std::shared_ptr scope_; + const DeviceContext* device_context_; +}; + +/** + * OperatorBase has the basic element that Net will call to do computation. + * Only CreateOperator from OpRegistry will new Operator directly. User + * should always construct a proto message OpDesc and call + * OpRegistry::CreateOp(op_desc) to get an Operator instance. + */ +class OperatorBase { + public: + virtual ~OperatorBase() {} + + template + inline const T& GetAttr(const std::string& name) const { + PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap", + name); + return boost::get(attrs_.at(name)); + } + + std::string DebugString() const; + + /// InferShape infer the size of Variables used by this Operator with + /// information inside scope + virtual void InferShape(const std::shared_ptr& scope) const = 0; + + /// Net will call this function to Run an op. + virtual void Run(const std::shared_ptr& scope, + const DeviceContext* dev_ctx) const = 0; + + public: + OpDesc desc_; + std::vector inputs_; + std::vector outputs_; + AttributeMap attrs_; +}; + +class OperatorWithKernel : public OperatorBase { + public: + virtual ~OperatorWithKernel() {} + + virtual void InferShape(const std::shared_ptr& scope) const {} + + void Run(const std::shared_ptr& scope, + const DeviceContext* dev_ctx) const { + OpRunContext op_ctx(this, scope, dev_ctx); + Run(&op_ctx); + } + + /// when implement an Op, your should implement this function. + /// this function should be moved to OpKernel later + virtual void Run(const OpRunContext* context) const = 0; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc new file mode 100644 index 0000000000000..48808dabb2711 --- /dev/null +++ b/paddle/framework/operator_test.cc @@ -0,0 +1,80 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/operator.h" +#include "gtest/gtest.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace framework { + +class OperatorTest : public OperatorWithKernel { + public: + void Run(const OpRunContext* ctx) const override { + float scale = ctx->op_->GetAttr("scale"); + PADDLE_ENFORCE(ctx->Input(0) == nullptr, "Input(0) should not initialized"); + PADDLE_ENFORCE(ctx->Output(0) == nullptr, + "Output(1) should not initialized"); + auto output1 = ctx->scope_->CreateVariable("output1"); + PADDLE_ENFORCE(output1 != nullptr, "should create output1 from scope"); + printf("get attr %s = %f\n", "scale", scale); + printf("%s\n", DebugString().c_str()); + } +}; + +class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { + public: + OperatorTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("input", "input of test op"); + AddOutput("output", "output of test op"); + AddAttr("scale", "scale of cosine op") + .SetDefault(1.0) + .LargerThan(0.0); + AddType("test_operator"); + AddComment("This is test op"); + } +}; + +REGISTER_OP(OperatorTest, OperatorTestProtoAndCheckerMaker, test_operator) + +TEST(OperatorBase, DebugString) { + OpDesc op_desc; + op_desc.set_type("test_operator"); + std::vector inputs = {"IN1", "IN2"}; + for (auto& input : inputs) { + op_desc.add_inputs(input); + } + std::vector outputs = {"OUT1", "OUT2"}; + for (auto& output : outputs) { + op_desc.add_outputs(output); + } + auto attr = op_desc.mutable_attrs()->Add(); + attr->set_name("scale"); + attr->set_type(paddle::framework::AttrType::FLOAT); + float scale = 3.14; + attr->set_f(scale); + + DeviceContext device_context; + auto scope = std::make_shared(); + + OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); + ASSERT_EQ(op->inputs_, inputs); + ASSERT_EQ(op->outputs_, outputs); + ASSERT_EQ(op->GetAttr("scale"), scale); + op->Run(scope, &device_context); +} + +} // namespace framework +} // namespace paddle \ No newline at end of file diff --git a/paddle/operators/.clang-format b/paddle/operators/.clang-format new file mode 100644 index 0000000000000..29282dc87e2c4 --- /dev/null +++ b/paddle/operators/.clang-format @@ -0,0 +1,5 @@ +--- +Language: Cpp +BasedOnStyle: Google +Standard: Cpp11 +... diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/paddle/operators/demo_op.h b/paddle/operators/demo_op.h new file mode 100644 index 0000000000000..d0b7420b4e25d --- /dev/null +++ b/paddle/operators/demo_op.h @@ -0,0 +1,59 @@ +#pragma once + +#include "paddle/framework/op_registry.h" + +using namespace paddle::framework; + +namespace paddle { +namespace operators { + +class CosineOp : public OperatorWithKernel { + public: + void Run(const OpRunContext *context) const override { + printf("%s\n", DebugString().c_str()); + } +}; + +class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { + public: + CosineOpProtoAndCheckerMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("input", "input of cosine op"); + AddOutput("output", "output of cosine op"); + AddAttr("scale", "scale of cosine op") + .SetDefault(1.0) + .LargerThan(0.0); + AddType("cos"); + AddComment("This is cos op"); + } +}; + +REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim) + +class MyTestOp : public OperatorWithKernel { + public: + void Run(const OpRunContext *context) const override { + printf("%s\n", DebugString().c_str()); + } +}; + +class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { + public: + MyTestOpProtoAndCheckerMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("input", "input of cosine op"); + AddOutput("output", "output of cosine op"); + auto my_checker = [](int i) { + PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!"); + }; + AddAttr("test_attr", "a simple test attribute") + .AddCustomChecker(my_checker); + AddType("my_test_op"); + AddComment("This is my_test op"); + } +}; + +REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op) + +} // namespace operators +} // namespace operators