From ff92a58beac3023d6d9c7d99fdf6d595fea24b64 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 4 Jul 2017 11:41:14 +0800 Subject: [PATCH 01/26] add operator base --- paddle/framework/CMakeLists.txt | 4 ++ paddle/framework/op_desc.proto | 29 ++++++++++++++ paddle/framework/operator.cc | 38 ++++++++++++++++++ paddle/framework/operator.h | 65 +++++++++++++++++++++++++++++++ paddle/framework/operator_test.cc | 19 +++++++++ 5 files changed, 155 insertions(+) create mode 100644 paddle/framework/op_desc.proto create mode 100644 paddle/framework/operator.cc create mode 100644 paddle/framework/operator.h create mode 100644 paddle/framework/operator_test.cc diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 6aa6b9bc2db6a..dc2e958849a84 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -5,3 +5,7 @@ nv_test(dim_test SRCS dim_test.cu DEPS ddim) cc_test(variable_test SRCS variable_test.cc) cc_test(scope_test SRCS scope_test.cc) cc_test(enforce_test SRCS enforce_test.cc) + +proto_library(op_desc SRCS op_desc.proto) +cc_library(operator SRCS operator.cc DEPS op_desc protobuf) +cc_test(operator_test SRCS operator_test.cc DEPS operator) diff --git a/paddle/framework/op_desc.proto b/paddle/framework/op_desc.proto new file mode 100644 index 0000000000000..1ef0e531f20e6 --- /dev/null +++ b/paddle/framework/op_desc.proto @@ -0,0 +1,29 @@ +syntax="proto2"; +package paddle.framework; + +enum AttrType { + INT = 1; + FLOAT = 2; + STRING = 3; + INTS = 4; + FLOATS = 5; + STRINGS = 6; +} + +message AttrDesc { + required AttrType type = 1; + optional int32 i = 2; + optional float f = 3; + optional string s = 4; + repeated int32 ints = 5; + repeated float floats = 6; + repeated string strings = 7; + required string name = 8; +}; + +message OpDesc { + repeated string inputs = 1; + repeated string outputs = 2; + required string type = 3; + repeated AttrDesc attrs = 4; +}; \ No newline at end of file diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc new file mode 100644 index 0000000000000..fc9c14f55b777 --- /dev/null +++ b/paddle/framework/operator.cc @@ -0,0 +1,38 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/operator.h" + +namespace paddle { +namespace framework { + +Operator::Operator(const OpDesc &desc): op_desc_(desc) { + inputs_.reserve(desc.inputs_size()); + for(const std::string& input: op_desc_.inputs()) { + inputs_.push_back(input); + } + outputs_.reserve(op_desc_.outputs_size()); + for(const std::string& output: op_desc_.outputs()) { + outputs_.push_back(output); + } + std::vector attrs; + attrs.reserve(desc.attrs_size()); + for(const AttrDesc& attr: op_desc_.attrs()) { + attrs.push_back(attr); + } + InitializeAttrs(attrs); +} + +} // namespace framework +} // namespace paddle \ No newline at end of file diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h new file mode 100644 index 0000000000000..a7b6a3e9f3506 --- /dev/null +++ b/paddle/framework/operator.h @@ -0,0 +1,65 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include "paddle/utils/Error.h" + +#include + +namespace paddle { +namespace framework { + +class OpContext {}; + +/** + * @brief Operator is used to do some computation. + * + * We use a OpDesc proto Message to describe and create a operator. + * Operator will get the Variables and computing resource from OpContext when Run. + */ +class Operator { + public: + explicit Operator(const OpDesc& desc); + virtual ~Operator() {} + + Error InitializeAttrs(const std::vector& attrs); + + /** + * InferShape is used to infer the shape of tensors related to this Operator. + */ + virtual Error InferShape() = 0; + + /** + * Run take a OpContext as parameter. + * + * 1. it will get input/output variable from OpContext.scope + * 2. It will get computing resource such as cpu/gpu from OpContext. + */ + virtual Error Run(OpContext *context) const = 0; + const std::string DebugString() const { + return op_desc_.ShortDebugString(); + } + +protected: + OpDesc op_desc_; + std::vector inputs_; + std::vector outputs_; +}; + +} // namespace framework +} // namespace paddle \ No newline at end of file diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc new file mode 100644 index 0000000000000..390cf6b31baf8 --- /dev/null +++ b/paddle/framework/operator_test.cc @@ -0,0 +1,19 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/operator.h" +#include "gtest/gtest.h" + +TEST(Operator, Create) { +} From 28e231359867870d1911974612e45ce09b326970 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 4 Jul 2017 11:45:39 +0800 Subject: [PATCH 02/26] import attr_type.proto --- paddle/framework/op_desc.proto | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/paddle/framework/op_desc.proto b/paddle/framework/op_desc.proto index 1ef0e531f20e6..76b2e130c1129 100644 --- a/paddle/framework/op_desc.proto +++ b/paddle/framework/op_desc.proto @@ -1,14 +1,7 @@ syntax="proto2"; package paddle.framework; -enum AttrType { - INT = 1; - FLOAT = 2; - STRING = 3; - INTS = 4; - FLOATS = 5; - STRINGS = 6; -} +import "attr_type.proto"; message AttrDesc { required AttrType type = 1; From 22c2becd0b0dc8ffe4af2ff194d61982e751c947 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 4 Jul 2017 15:59:38 +0800 Subject: [PATCH 03/26] add test --- paddle/framework/op_desc_test.cc | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 paddle/framework/op_desc_test.cc diff --git a/paddle/framework/op_desc_test.cc b/paddle/framework/op_desc_test.cc new file mode 100644 index 0000000000000..6a9530f80451a --- /dev/null +++ b/paddle/framework/op_desc_test.cc @@ -0,0 +1,31 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/operator.h" +#include "gtest/gtest.h" + +TEST(Operator, Create) { + using paddle::framework::Operator; + paddle::framework::OpDesc op_desc; + op_desc.set_type("ADD"); + op_desc.add_inputs("X"); + op_desc.add_inputs("Y"); + op_desc.add_outputs("Z"); + + auto attr = op_desc.mutable_attrs()->Add(); + attr->set_type(paddle::framework::AttrType::FLOAT); + attr->set_f(3.14); + + auto op = new Operator(op_desc); +} \ No newline at end of file From a591e78d8b981bf661248679a55e4811dda6f557 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Wed, 5 Jul 2017 16:09:24 +0800 Subject: [PATCH 04/26] do not return error when run --- paddle/framework/operator.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index a7b6a3e9f3506..c1be94a833cf8 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -42,7 +42,7 @@ class Operator { /** * InferShape is used to infer the shape of tensors related to this Operator. */ - virtual Error InferShape() = 0; + virtual void InferShape() = 0; /** * Run take a OpContext as parameter. @@ -50,7 +50,7 @@ class Operator { * 1. it will get input/output variable from OpContext.scope * 2. It will get computing resource such as cpu/gpu from OpContext. */ - virtual Error Run(OpContext *context) const = 0; + virtual void Run(OpContext *context) const = 0; const std::string DebugString() const { return op_desc_.ShortDebugString(); } From 8d8a448e0f61f0237b8ffd6d6171b56b986f06aa Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Wed, 5 Jul 2017 17:17:53 +0800 Subject: [PATCH 05/26] remove Error of InitializeAttrs --- paddle/framework/operator.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index c1be94a833cf8..0ea0ee2753a66 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -37,7 +37,7 @@ class Operator { explicit Operator(const OpDesc& desc); virtual ~Operator() {} - Error InitializeAttrs(const std::vector& attrs); + void InitializeAttrs(); /** * InferShape is used to infer the shape of tensors related to this Operator. From cb7c23442082704b36281878222cc352ee385ea5 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Thu, 6 Jul 2017 15:48:00 +0800 Subject: [PATCH 06/26] interface of operator --- paddle/framework/CMakeLists.txt | 1 + paddle/framework/operator.cc | 25 +++++++++------- paddle/framework/operator.h | 48 +++++++++++++++---------------- paddle/framework/operator_test.cc | 19 ------------ 4 files changed, 39 insertions(+), 54 deletions(-) delete mode 100644 paddle/framework/operator_test.cc diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index dcd70d285174a..596dbf4601227 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -12,3 +12,4 @@ cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf) proto_library(op_desc SRCS op_desc.proto DEPS attr_type) cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf) +cc_library(operator SRCS operator.cc) diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index fc9c14f55b777..9db8b1e8c52a8 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -17,22 +17,27 @@ limitations under the License. */ namespace paddle { namespace framework { -Operator::Operator(const OpDesc &desc): op_desc_(desc) { +OperatorBase::OperatorBase(const OpDesc &desc): desc_(desc) { inputs_.reserve(desc.inputs_size()); - for(const std::string& input: op_desc_.inputs()) { + for(const std::string& input: desc_.inputs()) { inputs_.push_back(input); } - outputs_.reserve(op_desc_.outputs_size()); - for(const std::string& output: op_desc_.outputs()) { + outputs_.reserve(desc_.outputs_size()); + for(const std::string& output: desc_.outputs()) { outputs_.push_back(output); } - std::vector attrs; - attrs.reserve(desc.attrs_size()); - for(const AttrDesc& attr: op_desc_.attrs()) { - attrs.push_back(attr); - } - InitializeAttrs(attrs); } +std::string OperatorBase::DebugString() { + return desc_.DebugString(); +} + +Variable* OperatorBase::input(Scope *scope, int index) { + return scope->CreateVariable(inputs_[index]); +} + +Variable* OperatorBase::output(Scope *scope, int index) { + return scope->CreateVariable(outputs_[index]); +} } // namespace framework } // namespace paddle \ No newline at end of file diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 0ea0ee2753a66..2a56c8d113062 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -18,47 +18,45 @@ limitations under the License. */ #include #include #include "paddle/utils/Error.h" - -#include +#include "paddle/framework/scope.h" +#include "paddle/framework/ddim.h" +#include "paddle/framework/op_desc.pb.h" namespace paddle { namespace framework { -class OpContext {}; +class DeviceContext {}; +class CpuContext : public DeviceContext {}; +class GpuContext : public DeviceContext {}; /** * @brief Operator is used to do some computation. * * We use a OpDesc proto Message to describe and create a operator. - * Operator will get the Variables and computing resource from OpContext when Run. + * Operator will get the Variables from scope and computing resource from DeviceContext. */ -class Operator { +class OperatorBase { public: - explicit Operator(const OpDesc& desc); - virtual ~Operator() {} + explicit OperatorBase(const OpDesc& desc); + virtual ~OperatorBase() {} + + /// initialize Attributes of this OP from proto message desc.attrs() + /// you should derive this function to init the attr you need in OP. + virtual void InitializeAttributes() = 0; - void InitializeAttrs(); + virtual void InferShape(const Scope* scope) const = 0; - /** - * InferShape is used to infer the shape of tensors related to this Operator. - */ - virtual void InferShape() = 0; + /// when implement an Op, your should implement this function. + virtual void Run(Scope* scope, DeviceContext* device_context) const = 0; - /** - * Run take a OpContext as parameter. - * - * 1. it will get input/output variable from OpContext.scope - * 2. It will get computing resource such as cpu/gpu from OpContext. - */ - virtual void Run(OpContext *context) const = 0; - const std::string DebugString() const { - return op_desc_.ShortDebugString(); - } + std::string DebugString(); + Variable* input(Scope* scope, int index); + Variable* output(Scope* scope, int index); protected: - OpDesc op_desc_; - std::vector inputs_; - std::vector outputs_; + const OpDesc desc_; + std::vector inputs_; + std::vector outputs_; }; } // namespace framework diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc deleted file mode 100644 index 390cf6b31baf8..0000000000000 --- a/paddle/framework/operator_test.cc +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/framework/operator.h" -#include "gtest/gtest.h" - -TEST(Operator, Create) { -} From 57f1f6ae9e227ab96c78f19ab3e3aeeefa9f783c Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Thu, 6 Jul 2017 20:31:44 +0800 Subject: [PATCH 07/26] refactor of operator --- paddle/framework/CMakeLists.txt | 3 ++- paddle/framework/operator.cc | 35 ++++++++++++------------ paddle/framework/operator.h | 39 +++++++++++++-------------- paddle/framework/operator_test.cc | 45 +++++++++++++++++++++++++++++++ 4 files changed, 83 insertions(+), 39 deletions(-) create mode 100644 paddle/framework/operator_test.cc diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 596dbf4601227..6e154a3b89946 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -12,4 +12,5 @@ cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf) proto_library(op_desc SRCS op_desc.proto DEPS attr_type) cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf) -cc_library(operator SRCS operator.cc) +cc_library(operator SRCS operator.cc DEPS op_desc protobuf) +cc_test(operator_test SRCS operator_test.cc DEPS operator op_desc protobuf) diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 9db8b1e8c52a8..d7dcdc5b4bf76 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -13,31 +13,30 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/framework/operator.h" +#include namespace paddle { namespace framework { -OperatorBase::OperatorBase(const OpDesc &desc): desc_(desc) { - inputs_.reserve(desc.inputs_size()); - for(const std::string& input: desc_.inputs()) { - inputs_.push_back(input); +std::string OperatorBase::DebugString() const { + std::stringstream ss; + ss << "inputs = ["; + for(auto& ipt : inputs) { + ss << ipt << ", "; } - outputs_.reserve(desc_.outputs_size()); - for(const std::string& output: desc_.outputs()) { - outputs_.push_back(output); + ss << "]\n"; + ss << "outputs = ["; + for(auto& opt : outputs) { + ss << opt << ", "; } + ss << "]\n"; + ss << "attr_keys = ["; + for(auto& attr : attrs) { + ss << attr.first << ", "; + } + ss << "]\n"; + return ss.str(); } -std::string OperatorBase::DebugString() { - return desc_.DebugString(); -} - -Variable* OperatorBase::input(Scope *scope, int index) { - return scope->CreateVariable(inputs_[index]); -} - -Variable* OperatorBase::output(Scope *scope, int index) { - return scope->CreateVariable(outputs_[index]); -} } // namespace framework } // namespace paddle \ No newline at end of file diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 2a56c8d113062..6de72b6a2cfe3 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -19,8 +19,9 @@ limitations under the License. */ #include #include "paddle/utils/Error.h" #include "paddle/framework/scope.h" -#include "paddle/framework/ddim.h" #include "paddle/framework/op_desc.pb.h" +#include + namespace paddle { namespace framework { @@ -29,34 +30,32 @@ class DeviceContext {}; class CpuContext : public DeviceContext {}; class GpuContext : public DeviceContext {}; +class OpRunContext { + public: + Scope* scope; + DeviceContext* device_context; +}; + +using Attribute = boost::variant, std::vector, std::vector>; +using AttributeMap = std::unordered_map; + /** - * @brief Operator is used to do some computation. - * - * We use a OpDesc proto Message to describe and create a operator. - * Operator will get the Variables from scope and computing resource from DeviceContext. + * all the init will be done by CreateOperator. */ class OperatorBase { public: - explicit OperatorBase(const OpDesc& desc); virtual ~OperatorBase() {} - - /// initialize Attributes of this OP from proto message desc.attrs() - /// you should derive this function to init the attr you need in OP. - virtual void InitializeAttributes() = 0; + /// when implement an Op, your should implement this function. + virtual void Run(OpRunContext* context) const = 0; virtual void InferShape(const Scope* scope) const = 0; - /// when implement an Op, your should implement this function. - virtual void Run(Scope* scope, DeviceContext* device_context) const = 0; + std::string DebugString() const; - std::string DebugString(); - Variable* input(Scope* scope, int index); - Variable* output(Scope* scope, int index); - -protected: - const OpDesc desc_; - std::vector inputs_; - std::vector outputs_; + public: + std::vector inputs; + std::vector outputs; + AttributeMap attrs; }; } // namespace framework diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc new file mode 100644 index 0000000000000..322a9ce897219 --- /dev/null +++ b/paddle/framework/operator_test.cc @@ -0,0 +1,45 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/operator.h" +#include "paddle/framework/scope.h" +#include "gtest/gtest.h" + +namespace paddle { +namespace framework { + +class OperatorTest: public OperatorBase { + void Run(OpRunContext* context) const override {} + void InferShape(const Scope* scope) const override {} +}; + +TEST(OperatorBase, DebugString) { + Scope* scope = new Scope(); + DeviceContext* device_context = new DeviceContext(); + OpRunContext* op_context = new OpRunContext(); + op_context->scope = scope; + op_context->device_context = device_context; + + auto op = new OperatorTest(); + op->inputs.push_back("X"); + op->inputs.push_back("Y"); + op->outputs.push_back("O"); + op->attrs["scale"] = 0; + + printf("%s\n", op->DebugString().c_str()); + +} + +} // namespace framework +} // namespace paddle \ No newline at end of file From 57d1fbd6bebc675eb842d4541d8b53d29f0f2731 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Thu, 6 Jul 2017 23:46:03 +0800 Subject: [PATCH 08/26] add comment, optimize code style --- paddle/framework/operator.cc | 10 ++++----- paddle/framework/operator.h | 35 +++++++++++++++++++++---------- paddle/framework/operator_test.cc | 17 +++++++-------- 3 files changed, 37 insertions(+), 25 deletions(-) diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index d7dcdc5b4bf76..a09787c1c6405 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -21,22 +21,22 @@ namespace framework { std::string OperatorBase::DebugString() const { std::stringstream ss; ss << "inputs = ["; - for(auto& ipt : inputs) { + for (auto& ipt : inputs_) { ss << ipt << ", "; } ss << "]\n"; ss << "outputs = ["; - for(auto& opt : outputs) { + for (auto& opt : outputs_) { ss << opt << ", "; } ss << "]\n"; ss << "attr_keys = ["; - for(auto& attr : attrs) { + for (auto& attr : attrs_) { ss << attr.first << ", "; } ss << "]\n"; return ss.str(); } -} // namespace framework -} // namespace paddle \ No newline at end of file +} // namespace framework +} // namespace paddle \ No newline at end of file diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 6de72b6a2cfe3..5c70c12be77c6 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -14,14 +14,13 @@ limitations under the License. */ #pragma once +#include #include #include #include -#include "paddle/utils/Error.h" -#include "paddle/framework/scope.h" #include "paddle/framework/op_desc.pb.h" -#include - +#include "paddle/framework/scope.h" +#include "paddle/utils/Error.h" namespace paddle { namespace framework { @@ -30,17 +29,27 @@ class DeviceContext {}; class CpuContext : public DeviceContext {}; class GpuContext : public DeviceContext {}; +/** + * OpRunContext is the only parameter of Operator's Run function. + * Run will get input/output variables, state such as momentum and + * device resource such as CUDA stream, cublas handle, etc. from + * OpRunContext. User should construct it before run the Operator. + */ class OpRunContext { public: Scope* scope; DeviceContext* device_context; }; -using Attribute = boost::variant, std::vector, std::vector>; +using Attribute = + boost::variant, + std::vector, std::vector>; using AttributeMap = std::unordered_map; /** - * all the init will be done by CreateOperator. + * OperatorBase has the basic element that Net will call to do compute. + * It have no construct function because CreateOperator(const& op_desc) + * will parse op_desc and set the input/output/attr properly. */ class OperatorBase { public: @@ -48,15 +57,19 @@ class OperatorBase { /// when implement an Op, your should implement this function. virtual void Run(OpRunContext* context) const = 0; + /// InferShape infer the size of Variables used by this Operator with + /// information + /// inside scope virtual void InferShape(const Scope* scope) const = 0; std::string DebugString() const; public: - std::vector inputs; - std::vector outputs; - AttributeMap attrs; + std::string type_; + std::vector inputs_; + std::vector outputs_; + AttributeMap attrs_; }; -} // namespace framework -} // namespace paddle \ No newline at end of file +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 322a9ce897219..cc0dca2dceae2 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/framework/operator.h" -#include "paddle/framework/scope.h" #include "gtest/gtest.h" +#include "paddle/framework/scope.h" namespace paddle { namespace framework { -class OperatorTest: public OperatorBase { +class OperatorTest : public OperatorBase { void Run(OpRunContext* context) const override {} void InferShape(const Scope* scope) const override {} }; @@ -32,14 +32,13 @@ TEST(OperatorBase, DebugString) { op_context->device_context = device_context; auto op = new OperatorTest(); - op->inputs.push_back("X"); - op->inputs.push_back("Y"); - op->outputs.push_back("O"); - op->attrs["scale"] = 0; + op->inputs_.push_back("X"); + op->inputs_.push_back("Y"); + op->outputs_.push_back("O"); + op->attrs_["scale"] = 0; printf("%s\n", op->DebugString().c_str()); - } -} // namespace framework -} // namespace paddle \ No newline at end of file +} // namespace framework +} // namespace paddle \ No newline at end of file From 4f23ec333314ac9b1dc952b2a5d189ff3df9f455 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Fri, 7 Jul 2017 22:08:53 +0800 Subject: [PATCH 09/26] net interface of op --- paddle/framework/CMakeLists.txt | 2 +- paddle/framework/op_registry.h | 57 +++++++++------------------- paddle/framework/op_registry_test.cc | 44 +++++++++------------ paddle/framework/operator.cc | 31 +++++++++++++++ paddle/framework/operator.h | 46 +++++++++++++++------- paddle/framework/operator_test.cc | 15 ++++---- 6 files changed, 107 insertions(+), 88 deletions(-) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index d7854dd52d613..a4ad172092fd3 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -13,7 +13,7 @@ proto_library(op_desc SRCS op_desc.proto DEPS attr_type) cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf) cc_library(operator SRCS operator.cc DEPS op_desc protobuf) cc_test(operator_test SRCS operator_test.cc DEPS operator op_desc protobuf) -cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_proto op_desc) +cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_proto op_desc operator) py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto) # Generate an empty __init__.py to make framework_py_proto as a valid python module. add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 81241b5342d89..956c3c91f2d49 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -2,25 +2,13 @@ #include "paddle/framework/attr_checker.h" -//#include "paddle/framework/op_base.h" +#include "paddle/framework/operator.h" #include "paddle/framework/op_desc.pb.h" #include "paddle/framework/op_proto.pb.h" namespace paddle { namespace framework { -//==================For test================// -class OpBase { - public: - std::vector inputs_; - std::vector outputs_; - AttributeMap attr_map_; - - virtual std::string Run() const = 0; - virtual ~OpBase() {} -}; -//=========================================// - // helper class to set attribute type struct AttrTypeHelper { template @@ -134,7 +122,7 @@ class OpProtoAndCheckerMaker { }; class OpRegistry { - typedef std::function OpCreator; + typedef std::function OpCreator; public: template @@ -143,28 +131,22 @@ class OpRegistry { OpProto& op_proto = protos_[op_type]; OpAttrChecker& op_checker = op_checkers_[op_type]; ProtoMakerType(&op_proto, &op_checker); - PADDLE_ENFORCE(op_proto.IsInitialized() == true, + PADDLE_ENFORCE(op_proto.IsInitialized(), "Fail to initialize %s's OpProto !", op_type); } - static OpBase* CreateOp(const OpDesc& op_desc) { + static OperatorBase* CreateOp(const OpDesc& op_desc) { std::string op_type = op_desc.type(); - OpBase* op = (creators_.at(op_type))(); - (op->inputs_).resize(op_desc.inputs_size()); - for (int i = 0; i < op_desc.inputs_size(); ++i) { - (op->inputs_)[i] = op_desc.inputs(i); - } - (op->outputs_).resize(op_desc.outputs_size()); - for (int i = 0; i < op_desc.outputs_size(); ++i) { - (op->outputs_)[i] = op_desc.outputs(i); - } + OperatorBase* op = (creators_.at(op_type))(); + AttributeMap attrs; for (int i = 0; i < op_desc.attrs_size(); ++i) { const AttrDesc& ith_attr = op_desc.attrs(i); std::string name = ith_attr.name(); - (op->attr_map_)[name] = AttrTypeHelper::GetAttrValue(ith_attr); + (attrs)[name] = AttrTypeHelper::GetAttrValue(ith_attr); } - const OpAttrChecker& op_checker = op_checkers_.at(op_type); - op_checker.Check(op->attr_map_); + const OpAttrChecker& op_checker = OpRegistry::op_checkers_.at(op_type); + op_checker.Check(attrs); + op->Init(op_desc, attrs); return op; } @@ -174,7 +156,7 @@ class OpRegistry { static std::unordered_map op_checkers_; }; -std::unordered_map> OpRegistry::creators_; +std::unordered_map> OpRegistry::creators_; std::unordered_map OpRegistry::protos_; std::unordered_map OpRegistry::op_checkers_; @@ -196,12 +178,10 @@ class OpRegisterHelper { // Demos -class CosineOp : public OpBase { +class CosineOp : public OperatorBase { public: - virtual std::string Run() const { - std::string msg = "CosineOp runs! scale = " + - std::to_string(boost::get(attr_map_.at("scale"))); - return msg; + void Run(OpRunContext* context) const override { + printf("%s\n", DebugString().c_str()); } }; @@ -221,13 +201,10 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim) -class MyTestOp : public OpBase { +class MyTestOp : public OperatorBase { public: - virtual std::string Run() const { - std::string msg = - "MyTestOp runs! test_attr = " + - std::to_string(boost::get(attr_map_.at("test_attr"))); - return msg; + void Run(OpRunContext* context) const override { + printf("%s\n", DebugString().c_str()); } }; diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index 17849ca0191db..d67cd9c6810f8 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -7,19 +7,17 @@ TEST(OpRegistry, CreateOp) { op_desc.add_inputs("aa"); op_desc.add_outputs("bb"); + float scale = 3.3; auto attr = op_desc.mutable_attrs()->Add(); attr->set_name("scale"); attr->set_type(paddle::framework::AttrType::FLOAT); - attr->set_f(3.3); + attr->set_f(scale); - paddle::framework::OpBase* op = + paddle::framework::OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); - std::string debug_str = op->Run(); - std::string str = "CosineOp runs! scale = " + std::to_string(3.3); - ASSERT_EQ(str.size(), debug_str.size()); - for (size_t i = 0; i < debug_str.length(); ++i) { - ASSERT_EQ(debug_str[i], str[i]); - } + op->Run(nullptr); + float scale_get = boost::get(op->GetAttr("scale")); + ASSERT_EQ(scale_get, scale); } TEST(OpRegistry, IllegalAttr) { @@ -35,7 +33,7 @@ TEST(OpRegistry, IllegalAttr) { bool caught = false; try { - paddle::framework::OpBase* op __attribute__((unused)) = + paddle::framework::OperatorBase* op __attribute__((unused)) = paddle::framework::OpRegistry::CreateOp(op_desc); } catch (paddle::framework::EnforceNotMet err) { caught = true; @@ -54,15 +52,12 @@ TEST(OpRegistry, DefaultValue) { op_desc.add_inputs("aa"); op_desc.add_outputs("bb"); - paddle::framework::OpBase* op = + ASSERT_TRUE(op_desc.IsInitialized()); + + paddle::framework::OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); - std::string debug_str = op->Run(); - float default_value = 1.0; - std::string str = "CosineOp runs! scale = " + std::to_string(default_value); - ASSERT_EQ(str.size(), debug_str.size()); - for (size_t i = 0; i < debug_str.length(); ++i) { - ASSERT_EQ(debug_str[i], str[i]); - } + op->Run(nullptr); + ASSERT_EQ(boost::get(op->GetAttr("scale")), 1.0); } TEST(OpRegistry, CustomChecker) { @@ -74,7 +69,7 @@ TEST(OpRegistry, CustomChecker) { // attr 'test_attr' is not set bool caught = false; try { - paddle::framework::OpBase* op __attribute__((unused)) = + paddle::framework::OperatorBase* op __attribute__((unused)) = paddle::framework::OpRegistry::CreateOp(op_desc); } catch (paddle::framework::EnforceNotMet err) { caught = true; @@ -93,7 +88,7 @@ TEST(OpRegistry, CustomChecker) { attr->set_i(3); caught = false; try { - paddle::framework::OpBase* op __attribute__((unused)) = + paddle::framework::OperatorBase* op __attribute__((unused)) = paddle::framework::OpRegistry::CreateOp(op_desc); } catch (paddle::framework::EnforceNotMet err) { caught = true; @@ -111,12 +106,9 @@ TEST(OpRegistry, CustomChecker) { attr->set_name("test_attr"); attr->set_type(paddle::framework::AttrType::INT); attr->set_i(4); - paddle::framework::OpBase* op = + paddle::framework::OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); - std::string debug_str = op->Run(); - std::string str = "MyTestOp runs! test_attr = " + std::to_string(4); - ASSERT_EQ(str.size(), debug_str.size()); - for (size_t i = 0; i < debug_str.length(); ++i) { - ASSERT_EQ(debug_str[i], str[i]); - } + op->Run(nullptr); + int test_attr = boost::get(op->GetAttr("test_attr")); + ASSERT_EQ(test_attr, 4); } \ No newline at end of file diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index a09787c1c6405..9af8d809e3237 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -18,8 +18,39 @@ limitations under the License. */ namespace paddle { namespace framework { +void OperatorBase::Init(const OpDesc &op_desc, AttributeMap& attrs) { + desc_ = op_desc; + inputs_.reserve(desc_.inputs_size()); + for (auto& input : desc_.inputs()) { + inputs_.push_back(input); + } + outputs_.reserve(desc_.outputs_size()); + for (auto& output : desc_.outputs()) { + outputs_.push_back(output); + } + for(auto it = attrs.begin(); it != attrs.end(); ++it) { + attrs_[it->first] = it->second; + } +} + +Variable* OperatorBase::Input(Scope* scope, int index) const { + return scope->GetVariable(inputs_[index]); +} + +Variable* OperatorBase::Output(Scope* scope, int index) const { + return scope->GetVariable(outputs_[index]); +} + +Attribute OperatorBase::GetAttr(std::string name) { + return attrs_[name]; +} + +void OperatorBase::InferShape(Scope *scope) const {} + std::string OperatorBase::DebugString() const { std::stringstream ss; + ss << "=================\n"; + ss << "type = " << type() << "\n"; ss << "inputs = ["; for (auto& ipt : inputs_) { ss << ipt << ", "; diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 5c70c12be77c6..f5fd684b5be60 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -21,13 +21,14 @@ limitations under the License. */ #include "paddle/framework/op_desc.pb.h" #include "paddle/framework/scope.h" #include "paddle/utils/Error.h" +#include "paddle/framework/attr_checker.h" namespace paddle { namespace framework { class DeviceContext {}; -class CpuContext : public DeviceContext {}; -class GpuContext : public DeviceContext {}; +class CPUContext : public DeviceContext {}; +class GPUContext : public DeviceContext {}; /** * OpRunContext is the only parameter of Operator's Run function. @@ -41,11 +42,6 @@ class OpRunContext { DeviceContext* device_context; }; -using Attribute = - boost::variant, - std::vector, std::vector>; -using AttributeMap = std::unordered_map; - /** * OperatorBase has the basic element that Net will call to do compute. * It have no construct function because CreateOperator(const& op_desc) @@ -54,18 +50,42 @@ using AttributeMap = std::unordered_map; class OperatorBase { public: virtual ~OperatorBase() {} - /// when implement an Op, your should implement this function. - virtual void Run(OpRunContext* context) const = 0; + + void Init(const OpDesc& op_desc, AttributeMap& attrs); + + std::string type() const { + return desc_.type(); + } + + Variable* Input(Scope* scope, int index) const; + Variable* Output(Scope* scope, int index) const; + + Attribute GetAttr(std::string name); + + inline const AttributeMap attrs() const { + return attrs_; + } + + inline const std::vector inputs() const { + return inputs_; + } + + inline const std::vector outputs() const { + return outputs_; + } + + std::string DebugString() const; /// InferShape infer the size of Variables used by this Operator with /// information /// inside scope - virtual void InferShape(const Scope* scope) const = 0; + void InferShape(Scope* scope) const; - std::string DebugString() const; + /// when implement an Op, your should implement this function. + virtual void Run(OpRunContext* context) const = 0; - public: - std::string type_; + protected: + OpDesc desc_; std::vector inputs_; std::vector outputs_; AttributeMap attrs_; diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index cc0dca2dceae2..74d6dc13b500b 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -20,10 +20,14 @@ namespace paddle { namespace framework { class OperatorTest : public OperatorBase { - void Run(OpRunContext* context) const override {} - void InferShape(const Scope* scope) const override {} + public: + void Run(OpRunContext* context) const override { + printf("%s\n", DebugString().c_str()); + } }; + + TEST(OperatorBase, DebugString) { Scope* scope = new Scope(); DeviceContext* device_context = new DeviceContext(); @@ -32,12 +36,7 @@ TEST(OperatorBase, DebugString) { op_context->device_context = device_context; auto op = new OperatorTest(); - op->inputs_.push_back("X"); - op->inputs_.push_back("Y"); - op->outputs_.push_back("O"); - op->attrs_["scale"] = 0; - - printf("%s\n", op->DebugString().c_str()); + op->Run(op_context); } } // namespace framework From 88f66219856f98a9c03463a2d3938d1a4254806c Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Sat, 8 Jul 2017 09:23:28 +0800 Subject: [PATCH 10/26] refine operator interface --- paddle/CMakeLists.txt | 1 + paddle/framework/CMakeLists.txt | 2 +- paddle/framework/op_registry.h | 55 ++------------------------ paddle/framework/op_registry_test.cc | 7 ++-- paddle/framework/operator.cc | 23 ++--------- paddle/framework/operator.h | 48 ++++++++++++---------- paddle/framework/operator_test.cc | 35 +++++++++++++++-- paddle/operators/.clang-format | 5 +++ paddle/operators/CMakeLists.txt | 0 paddle/operators/demo_op.h | 59 ++++++++++++++++++++++++++++ 10 files changed, 135 insertions(+), 100 deletions(-) create mode 100644 paddle/operators/.clang-format create mode 100644 paddle/operators/CMakeLists.txt create mode 100644 paddle/operators/demo_op.h diff --git a/paddle/CMakeLists.txt b/paddle/CMakeLists.txt index 307e99bbe3a83..60977cbab933f 100644 --- a/paddle/CMakeLists.txt +++ b/paddle/CMakeLists.txt @@ -15,6 +15,7 @@ if(Boost_FOUND) add_subdirectory(memory) add_subdirectory(platform) add_subdirectory(framework) + add_subdirectory(operators) endif() if(WITH_C_API) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index a4ad172092fd3..524fbbabcba04 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -12,7 +12,7 @@ cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf) proto_library(op_desc SRCS op_desc.proto DEPS attr_type) cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf) cc_library(operator SRCS operator.cc DEPS op_desc protobuf) -cc_test(operator_test SRCS operator_test.cc DEPS operator op_desc protobuf) +cc_test(operator_test SRCS operator_test.cc DEPS operator op_proto op_desc attr_type protobuf) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_proto op_desc operator) py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto) # Generate an empty __init__.py to make framework_py_proto as a valid python module. diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 956c3c91f2d49..1c1e4ccf4940e 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -2,9 +2,9 @@ #include "paddle/framework/attr_checker.h" -#include "paddle/framework/operator.h" #include "paddle/framework/op_desc.pb.h" #include "paddle/framework/op_proto.pb.h" +#include "paddle/framework/operator.h" namespace paddle { namespace framework { @@ -156,7 +156,8 @@ class OpRegistry { static std::unordered_map op_checkers_; }; -std::unordered_map> OpRegistry::creators_; +std::unordered_map> + OpRegistry::creators_; std::unordered_map OpRegistry::protos_; std::unordered_map OpRegistry::op_checkers_; @@ -176,55 +177,5 @@ class OpRegisterHelper { const OpRegisterHelper<__op_class, __op_maker_class> \ __op_class##Register::reg(#__op_type); -// Demos - -class CosineOp : public OperatorBase { - public: - void Run(OpRunContext* context) const override { - printf("%s\n", DebugString().c_str()); - } -}; - -class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { - public: - CosineOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("input", "input of cosine op"); - AddOutput("output", "output of cosine op"); - AddAttr("scale", "scale of cosine op") - .SetDefault(1.0) - .LargerThan(0.0); - AddType("cos"); - AddComment("This is cos op"); - } -}; - -REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim) - -class MyTestOp : public OperatorBase { - public: - void Run(OpRunContext* context) const override { - printf("%s\n", DebugString().c_str()); - } -}; - -class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { - public: - MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("input", "input of cosine op"); - AddOutput("output", "output of cosine op"); - auto my_checker = [](int i) { - PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!"); - }; - AddAttr("test_attr", "a simple test attribute") - .AddCustomChecker(my_checker); - AddType("my_test_op"); - AddComment("This is my_test op"); - } -}; - -REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op) - } // namespace framework } // namespace paddle diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index d67cd9c6810f8..d97f621d29c71 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -1,5 +1,6 @@ #include "paddle/framework/op_registry.h" #include +#include "paddle/operators/demo_op.h" TEST(OpRegistry, CreateOp) { paddle::framework::OpDesc op_desc; @@ -16,7 +17,7 @@ TEST(OpRegistry, CreateOp) { paddle::framework::OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); op->Run(nullptr); - float scale_get = boost::get(op->GetAttr("scale")); + float scale_get = op->GetAttr("scale"); ASSERT_EQ(scale_get, scale); } @@ -57,7 +58,7 @@ TEST(OpRegistry, DefaultValue) { paddle::framework::OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); op->Run(nullptr); - ASSERT_EQ(boost::get(op->GetAttr("scale")), 1.0); + ASSERT_EQ(op->GetAttr("scale"), 1.0); } TEST(OpRegistry, CustomChecker) { @@ -109,6 +110,6 @@ TEST(OpRegistry, CustomChecker) { paddle::framework::OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); op->Run(nullptr); - int test_attr = boost::get(op->GetAttr("test_attr")); + int test_attr = op->GetAttr("test_attr"); ASSERT_EQ(test_attr, 4); } \ No newline at end of file diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 9af8d809e3237..57f8aaa48e397 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -13,12 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/framework/operator.h" -#include namespace paddle { namespace framework { -void OperatorBase::Init(const OpDesc &op_desc, AttributeMap& attrs) { +void OperatorBase::Init(const OpDesc& op_desc, AttributeMap& attrs) { desc_ = op_desc; inputs_.reserve(desc_.inputs_size()); for (auto& input : desc_.inputs()) { @@ -28,29 +27,15 @@ void OperatorBase::Init(const OpDesc &op_desc, AttributeMap& attrs) { for (auto& output : desc_.outputs()) { outputs_.push_back(output); } - for(auto it = attrs.begin(); it != attrs.end(); ++it) { - attrs_[it->first] = it->second; - } -} - -Variable* OperatorBase::Input(Scope* scope, int index) const { - return scope->GetVariable(inputs_[index]); -} - -Variable* OperatorBase::Output(Scope* scope, int index) const { - return scope->GetVariable(outputs_[index]); -} - -Attribute OperatorBase::GetAttr(std::string name) { - return attrs_[name]; + attrs_.insert(attrs.begin(), attrs.end()); } -void OperatorBase::InferShape(Scope *scope) const {} +void OperatorBase::InferShape(Scope* scope) const {} std::string OperatorBase::DebugString() const { std::stringstream ss; ss << "=================\n"; - ss << "type = " << type() << "\n"; + ss << "type = " << desc_.type() << "\n"; ss << "inputs = ["; for (auto& ipt : inputs_) { ss << ipt << ", "; diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index f5fd684b5be60..40e1e66fca67c 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -18,10 +18,10 @@ limitations under the License. */ #include #include #include +#include "paddle/framework/attr_checker.h" #include "paddle/framework/op_desc.pb.h" #include "paddle/framework/scope.h" #include "paddle/utils/Error.h" -#include "paddle/framework/attr_checker.h" namespace paddle { namespace framework { @@ -36,44 +36,50 @@ class GPUContext : public DeviceContext {}; * device resource such as CUDA stream, cublas handle, etc. from * OpRunContext. User should construct it before run the Operator. */ -class OpRunContext { +class OpContext { public: Scope* scope; DeviceContext* device_context; }; /** - * OperatorBase has the basic element that Net will call to do compute. - * It have no construct function because CreateOperator(const& op_desc) - * will parse op_desc and set the input/output/attr properly. + * OperatorBase has the basic element that Net will call to do computation. + * Only CreateOperator from OpRegistry will new Operator directly. User + * should always construct a proto message OpDesc and call + * OpRegistry::CreateOp(op_desc) to get an Operator instance. */ class OperatorBase { public: virtual ~OperatorBase() {} + /// We do not use ctor but an init function to construct an Operator. + /// There is no need for all sub operators to have a constructor and + /// write this init parameters. void Init(const OpDesc& op_desc, AttributeMap& attrs); - std::string type() const { - return desc_.type(); - } - - Variable* Input(Scope* scope, int index) const; - Variable* Output(Scope* scope, int index) const; + inline const OpDesc op_desc() const { return desc_; } - Attribute GetAttr(std::string name); - - inline const AttributeMap attrs() const { - return attrs_; + inline const Variable* Input(Scope* scope, int index) const { + PADDLE_ENFORCE(scope != nullptr, "scope should not be nullptr"); + return scope->GetVariable(inputs_[index]); } - inline const std::vector inputs() const { - return inputs_; + inline Variable* Output(Scope* scope, int index) const { + PADDLE_ENFORCE(scope != nullptr, "scope should not be nullptr"); + return scope->GetVariable(outputs_[index]); } - inline const std::vector outputs() const { - return outputs_; + template + inline const T GetAttr(const std::string& name) const { + PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap", + name); + return boost::get(attrs_.at(name)); } + inline const std::vector inputs() const { return inputs_; } + + inline const std::vector outputs() const { return outputs_; } + std::string DebugString() const; /// InferShape infer the size of Variables used by this Operator with @@ -82,9 +88,9 @@ class OperatorBase { void InferShape(Scope* scope) const; /// when implement an Op, your should implement this function. - virtual void Run(OpRunContext* context) const = 0; + virtual void Run(OpContext* context) const = 0; - protected: + private: OpDesc desc_; std::vector inputs_; std::vector outputs_; diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 74d6dc13b500b..f3c0dbceca35d 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -14,28 +14,55 @@ limitations under the License. */ #include "paddle/framework/operator.h" #include "gtest/gtest.h" -#include "paddle/framework/scope.h" +#include "paddle/framework/op_registry.h" namespace paddle { namespace framework { class OperatorTest : public OperatorBase { public: - void Run(OpRunContext* context) const override { + void Run(OpContext* context) const override { + float scale = GetAttr("scale"); + printf("get attr %s = %f\n", "scale", scale); printf("%s\n", DebugString().c_str()); } }; +class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { + public: + OperatorTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("input", "input of test op"); + AddOutput("output", "output of test op"); + AddAttr("scale", "scale of cosine op") + .SetDefault(1.0) + .LargerThan(0.0); + AddType("cos"); + AddComment("This is test op"); + } +}; +REGISTER_OP(OperatorTest, OperatorTestProtoAndCheckerMaker, test_operator) TEST(OperatorBase, DebugString) { + OpDesc op_desc; + op_desc.set_type("test_operator"); + op_desc.add_inputs("IN1"); + op_desc.add_inputs("IN2"); + op_desc.add_outputs("OUT1"); + op_desc.add_outputs("OUT2"); + auto attr = op_desc.mutable_attrs()->Add(); + attr->set_name("scale"); + attr->set_type(paddle::framework::AttrType::FLOAT); + attr->set_f(3.14); + Scope* scope = new Scope(); DeviceContext* device_context = new DeviceContext(); - OpRunContext* op_context = new OpRunContext(); + OpContext* op_context = new OpContext(); op_context->scope = scope; op_context->device_context = device_context; - auto op = new OperatorTest(); + OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); op->Run(op_context); } diff --git a/paddle/operators/.clang-format b/paddle/operators/.clang-format new file mode 100644 index 0000000000000..29282dc87e2c4 --- /dev/null +++ b/paddle/operators/.clang-format @@ -0,0 +1,5 @@ +--- +Language: Cpp +BasedOnStyle: Google +Standard: Cpp11 +... diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/paddle/operators/demo_op.h b/paddle/operators/demo_op.h new file mode 100644 index 0000000000000..1c1704c3e582d --- /dev/null +++ b/paddle/operators/demo_op.h @@ -0,0 +1,59 @@ +#pragma once + +#include "paddle/framework/op_registry.h" + +using namespace paddle::framework; + +namespace paddle { +namespace operators { + +class CosineOp : public OperatorBase { + public: + void Run(OpContext *context) const override { + printf("%s\n", DebugString().c_str()); + } +}; + +class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { + public: + CosineOpProtoAndCheckerMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("input", "input of cosine op"); + AddOutput("output", "output of cosine op"); + AddAttr("scale", "scale of cosine op") + .SetDefault(1.0) + .LargerThan(0.0); + AddType("cos"); + AddComment("This is cos op"); + } +}; + +REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim) + +class MyTestOp : public OperatorBase { + public: + void Run(OpContext *context) const override { + printf("%s\n", DebugString().c_str()); + } +}; + +class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { + public: + MyTestOpProtoAndCheckerMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("input", "input of cosine op"); + AddOutput("output", "output of cosine op"); + auto my_checker = [](int i) { + PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!"); + }; + AddAttr("test_attr", "a simple test attribute") + .AddCustomChecker(my_checker); + AddType("my_test_op"); + AddComment("This is my_test op"); + } +}; + +REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op) + +} // namespace operators +} // namespace operators From 42afab20ef8bdff16775e1e1ffc1c13a4070ac90 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Sun, 9 Jul 2017 12:56:30 +0800 Subject: [PATCH 11/26] optimize operator test --- paddle/framework/operator_test.cc | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index f3c0dbceca35d..03646d3c97753 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -47,14 +47,19 @@ REGISTER_OP(OperatorTest, OperatorTestProtoAndCheckerMaker, test_operator) TEST(OperatorBase, DebugString) { OpDesc op_desc; op_desc.set_type("test_operator"); - op_desc.add_inputs("IN1"); - op_desc.add_inputs("IN2"); - op_desc.add_outputs("OUT1"); - op_desc.add_outputs("OUT2"); + std::vector inputs = {"IN1", "IN2"}; + for (auto& input : inputs) { + op_desc.add_inputs(input); + } + std::vector outputs = {"OUT1", "OUT2"}; + for (auto& output : outputs) { + op_desc.add_outputs(output); + } auto attr = op_desc.mutable_attrs()->Add(); attr->set_name("scale"); attr->set_type(paddle::framework::AttrType::FLOAT); - attr->set_f(3.14); + float scale = 3.14; + attr->set_f(scale); Scope* scope = new Scope(); DeviceContext* device_context = new DeviceContext(); @@ -63,6 +68,9 @@ TEST(OperatorBase, DebugString) { op_context->device_context = device_context; OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); + ASSERT_EQ(op->inputs(), inputs); + ASSERT_EQ(op->outputs(), outputs); + ASSERT_EQ(op->GetAttr("scale"), scale); op->Run(op_context); } From 0a662ff4bb0230f1f69f8f4253703dd09924cfcc Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Sun, 9 Jul 2017 13:38:37 +0800 Subject: [PATCH 12/26] optimize code --- paddle/framework/operator.cc | 2 +- paddle/framework/operator.h | 10 +++++++--- paddle/framework/operator_test.cc | 2 ++ 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 57f8aaa48e397..ad8c51162ce12 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -32,7 +32,7 @@ void OperatorBase::Init(const OpDesc& op_desc, AttributeMap& attrs) { void OperatorBase::InferShape(Scope* scope) const {} -std::string OperatorBase::DebugString() const { +const std::string OperatorBase::DebugString() const { std::stringstream ss; ss << "=================\n"; ss << "type = " << desc_.type() << "\n"; diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 40e1e66fca67c..917115672e6b2 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -61,18 +61,22 @@ class OperatorBase { inline const Variable* Input(Scope* scope, int index) const { PADDLE_ENFORCE(scope != nullptr, "scope should not be nullptr"); + PADDLE_ENFORCE(index >= 0, "input index should not be negative"); + PADDLE_ENFORCE(index < (int)inputs().size(), "input index should less then %d", inputs().size()); return scope->GetVariable(inputs_[index]); } inline Variable* Output(Scope* scope, int index) const { PADDLE_ENFORCE(scope != nullptr, "scope should not be nullptr"); + PADDLE_ENFORCE(index >= 0, "output index should not be negative"); + PADDLE_ENFORCE(index < (int)outputs().size(), "output index should less then %d", outputs().size()); return scope->GetVariable(outputs_[index]); } template inline const T GetAttr(const std::string& name) const { - PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap", - name); + PADDLE_ENFORCE(attrs_.count(name) != 0, + "%s should be in AttributeMap", name); return boost::get(attrs_.at(name)); } @@ -80,7 +84,7 @@ class OperatorBase { inline const std::vector outputs() const { return outputs_; } - std::string DebugString() const; + const std::string DebugString() const; /// InferShape infer the size of Variables used by this Operator with /// information diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 03646d3c97753..1c9b045b70ca2 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -23,6 +23,8 @@ class OperatorTest : public OperatorBase { public: void Run(OpContext* context) const override { float scale = GetAttr("scale"); + PADDLE_ENFORCE(Input(context->scope, 0) == nullptr, "Input(0) should not initialized"); + PADDLE_ENFORCE(Input(context->scope, 1) == nullptr, "Input(1) should not initialized"); printf("get attr %s = %f\n", "scale", scale); printf("%s\n", DebugString().c_str()); } From 474debd4a51f3e1ca60a1383aefd0eb152adcc1d Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 10 Jul 2017 09:43:01 +0800 Subject: [PATCH 13/26] change test op name to test_operator --- paddle/framework/operator_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 1c9b045b70ca2..86b547dcdec1c 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -39,7 +39,7 @@ class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { AddAttr("scale", "scale of cosine op") .SetDefault(1.0) .LargerThan(0.0); - AddType("cos"); + AddType("test_operator"); AddComment("This is test op"); } }; From 7b77d76938061b9807c6489ece492fe6d0775a77 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 10 Jul 2017 10:02:26 +0800 Subject: [PATCH 14/26] add optional op name --- paddle/framework/op_desc.proto | 5 ++++- paddle/framework/operator.cc | 1 + paddle/framework/operator.h | 5 ++--- paddle/framework/operator_test.cc | 3 +++ 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/paddle/framework/op_desc.proto b/paddle/framework/op_desc.proto index 89497f3c16bc2..d3eb9dfa8da04 100644 --- a/paddle/framework/op_desc.proto +++ b/paddle/framework/op_desc.proto @@ -51,6 +51,9 @@ message OpDesc { // type of this Operator, such as "add", "sub", "fc". required string type = 3; + // the operator name, such as "fc1", "fc_1", etc. This is optional. + optional string name = 4; + // Attributes of this Operator. e.g., scale=3.0 in cosine op. - repeated AttrDesc attrs = 4; + repeated AttrDesc attrs = 5; }; \ No newline at end of file diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index ad8c51162ce12..941f79cc9d9a0 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -36,6 +36,7 @@ const std::string OperatorBase::DebugString() const { std::stringstream ss; ss << "=================\n"; ss << "type = " << desc_.type() << "\n"; + ss << "name = " << desc_.name() << "\n"; ss << "inputs = ["; for (auto& ipt : inputs_) { ss << ipt << ", "; diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 917115672e6b2..754d8f22b794a 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -57,7 +57,7 @@ class OperatorBase { /// write this init parameters. void Init(const OpDesc& op_desc, AttributeMap& attrs); - inline const OpDesc op_desc() const { return desc_; } + inline const OpDesc desc() const { return desc_; } inline const Variable* Input(Scope* scope, int index) const { PADDLE_ENFORCE(scope != nullptr, "scope should not be nullptr"); @@ -87,8 +87,7 @@ class OperatorBase { const std::string DebugString() const; /// InferShape infer the size of Variables used by this Operator with - /// information - /// inside scope + /// information inside scope void InferShape(Scope* scope) const; /// when implement an Op, your should implement this function. diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 86b547dcdec1c..24764e639c367 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -48,7 +48,9 @@ REGISTER_OP(OperatorTest, OperatorTestProtoAndCheckerMaker, test_operator) TEST(OperatorBase, DebugString) { OpDesc op_desc; + std::string op_name = "op1"; op_desc.set_type("test_operator"); + op_desc.set_name(op_name); std::vector inputs = {"IN1", "IN2"}; for (auto& input : inputs) { op_desc.add_inputs(input); @@ -70,6 +72,7 @@ TEST(OperatorBase, DebugString) { op_context->device_context = device_context; OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); + ASSERT_EQ(op->desc().name(), op_name); ASSERT_EQ(op->inputs(), inputs); ASSERT_EQ(op->outputs(), outputs); ASSERT_EQ(op->GetAttr("scale"), scale); From 2d1d021526db510177637107b69c100f58f730ed Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 10 Jul 2017 14:16:53 +0800 Subject: [PATCH 15/26] rm name, use DeviceContext --- paddle/framework/op_desc.proto | 5 +---- paddle/framework/operator.cc | 1 - paddle/framework/operator.h | 12 +++++------- paddle/framework/operator_test.cc | 6 +----- 4 files changed, 7 insertions(+), 17 deletions(-) diff --git a/paddle/framework/op_desc.proto b/paddle/framework/op_desc.proto index d3eb9dfa8da04..89497f3c16bc2 100644 --- a/paddle/framework/op_desc.proto +++ b/paddle/framework/op_desc.proto @@ -51,9 +51,6 @@ message OpDesc { // type of this Operator, such as "add", "sub", "fc". required string type = 3; - // the operator name, such as "fc1", "fc_1", etc. This is optional. - optional string name = 4; - // Attributes of this Operator. e.g., scale=3.0 in cosine op. - repeated AttrDesc attrs = 5; + repeated AttrDesc attrs = 4; }; \ No newline at end of file diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 941f79cc9d9a0..ad8c51162ce12 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -36,7 +36,6 @@ const std::string OperatorBase::DebugString() const { std::stringstream ss; ss << "=================\n"; ss << "type = " << desc_.type() << "\n"; - ss << "name = " << desc_.name() << "\n"; ss << "inputs = ["; for (auto& ipt : inputs_) { ss << ipt << ", "; diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 754d8f22b794a..07354185dc306 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -22,13 +22,11 @@ limitations under the License. */ #include "paddle/framework/op_desc.pb.h" #include "paddle/framework/scope.h" #include "paddle/utils/Error.h" +#include "paddle/platform/device_context.h" namespace paddle { namespace framework { - -class DeviceContext {}; -class CPUContext : public DeviceContext {}; -class GPUContext : public DeviceContext {}; +using paddle::platform::DeviceContext; /** * OpRunContext is the only parameter of Operator's Run function. @@ -38,7 +36,7 @@ class GPUContext : public DeviceContext {}; */ class OpContext { public: - Scope* scope; + std::shared_ptr scope; DeviceContext* device_context; }; @@ -59,14 +57,14 @@ class OperatorBase { inline const OpDesc desc() const { return desc_; } - inline const Variable* Input(Scope* scope, int index) const { + inline const Variable* Input(std::shared_ptr scope, int index) const { PADDLE_ENFORCE(scope != nullptr, "scope should not be nullptr"); PADDLE_ENFORCE(index >= 0, "input index should not be negative"); PADDLE_ENFORCE(index < (int)inputs().size(), "input index should less then %d", inputs().size()); return scope->GetVariable(inputs_[index]); } - inline Variable* Output(Scope* scope, int index) const { + inline Variable* Output(std::shared_ptr scope, int index) const { PADDLE_ENFORCE(scope != nullptr, "scope should not be nullptr"); PADDLE_ENFORCE(index >= 0, "output index should not be negative"); PADDLE_ENFORCE(index < (int)outputs().size(), "output index should less then %d", outputs().size()); diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 24764e639c367..0a06c62c20a91 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -48,9 +48,7 @@ REGISTER_OP(OperatorTest, OperatorTestProtoAndCheckerMaker, test_operator) TEST(OperatorBase, DebugString) { OpDesc op_desc; - std::string op_name = "op1"; op_desc.set_type("test_operator"); - op_desc.set_name(op_name); std::vector inputs = {"IN1", "IN2"}; for (auto& input : inputs) { op_desc.add_inputs(input); @@ -65,14 +63,12 @@ TEST(OperatorBase, DebugString) { float scale = 3.14; attr->set_f(scale); - Scope* scope = new Scope(); DeviceContext* device_context = new DeviceContext(); OpContext* op_context = new OpContext(); - op_context->scope = scope; + op_context->scope = std::make_shared(); op_context->device_context = device_context; OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); - ASSERT_EQ(op->desc().name(), op_name); ASSERT_EQ(op->inputs(), inputs); ASSERT_EQ(op->outputs(), outputs); ASSERT_EQ(op->GetAttr("scale"), scale); From 1e35e83836b4d9c50c3c171fe826a9029664b903 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 10 Jul 2017 14:59:25 +0800 Subject: [PATCH 16/26] prepare for opkernel --- paddle/framework/operator.cc | 8 +++++++ paddle/framework/operator.h | 37 +++++++++++++++++-------------- paddle/framework/operator_test.cc | 13 +++++------ 3 files changed, 34 insertions(+), 24 deletions(-) diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index ad8c51162ce12..312daddd917c4 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -54,5 +54,13 @@ const std::string OperatorBase::DebugString() const { return ss.str(); } +const Variable* OpContext::Input(int index) const { + return scope->GetVariable(op.inputs()[index]); +} + +Variable* OpContext::Output(int index) const { + return scope->GetVariable(op.outputs()[index]); +} + } // namespace framework } // namespace paddle \ No newline at end of file diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 07354185dc306..e46cac71b4e32 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -21,12 +21,14 @@ limitations under the License. */ #include "paddle/framework/attr_checker.h" #include "paddle/framework/op_desc.pb.h" #include "paddle/framework/scope.h" -#include "paddle/utils/Error.h" #include "paddle/platform/device_context.h" +#include "paddle/utils/Error.h" namespace paddle { namespace framework { + using paddle::platform::DeviceContext; +class OperatorBase; /** * OpRunContext is the only parameter of Operator's Run function. @@ -36,6 +38,15 @@ using paddle::platform::DeviceContext; */ class OpContext { public: + OpContext(OperatorBase& op, std::shared_ptr scope, + DeviceContext* device_context) + : op(op), scope(scope), device_context(device_context) {} + + const Variable* Input(int index) const; + Variable* Output(int index) const; + + public: + OperatorBase& op; std::shared_ptr scope; DeviceContext* device_context; }; @@ -57,24 +68,10 @@ class OperatorBase { inline const OpDesc desc() const { return desc_; } - inline const Variable* Input(std::shared_ptr scope, int index) const { - PADDLE_ENFORCE(scope != nullptr, "scope should not be nullptr"); - PADDLE_ENFORCE(index >= 0, "input index should not be negative"); - PADDLE_ENFORCE(index < (int)inputs().size(), "input index should less then %d", inputs().size()); - return scope->GetVariable(inputs_[index]); - } - - inline Variable* Output(std::shared_ptr scope, int index) const { - PADDLE_ENFORCE(scope != nullptr, "scope should not be nullptr"); - PADDLE_ENFORCE(index >= 0, "output index should not be negative"); - PADDLE_ENFORCE(index < (int)outputs().size(), "output index should less then %d", outputs().size()); - return scope->GetVariable(outputs_[index]); - } - template inline const T GetAttr(const std::string& name) const { - PADDLE_ENFORCE(attrs_.count(name) != 0, - "%s should be in AttributeMap", name); + PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap", + name); return boost::get(attrs_.at(name)); } @@ -88,7 +85,13 @@ class OperatorBase { /// information inside scope void InferShape(Scope* scope) const; + void Run(std::shared_ptr scope, DeviceContext* dev_ctx) { + OpContext* op_ctx = new OpContext(*this, scope, dev_ctx); + Run(op_ctx); + } + /// when implement an Op, your should implement this function. + /// this function should be moved to OpKernel later virtual void Run(OpContext* context) const = 0; private: diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 0a06c62c20a91..3e9a35aba6bbd 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -21,10 +21,11 @@ namespace framework { class OperatorTest : public OperatorBase { public: - void Run(OpContext* context) const override { + void Run(OpContext* ctx) const override { float scale = GetAttr("scale"); - PADDLE_ENFORCE(Input(context->scope, 0) == nullptr, "Input(0) should not initialized"); - PADDLE_ENFORCE(Input(context->scope, 1) == nullptr, "Input(1) should not initialized"); + PADDLE_ENFORCE(ctx->Input(0) == nullptr, "Input(0) should not initialized"); + PADDLE_ENFORCE(ctx->Output(0) == nullptr, + "Output(1) should not initialized"); printf("get attr %s = %f\n", "scale", scale); printf("%s\n", DebugString().c_str()); } @@ -64,15 +65,13 @@ TEST(OperatorBase, DebugString) { attr->set_f(scale); DeviceContext* device_context = new DeviceContext(); - OpContext* op_context = new OpContext(); - op_context->scope = std::make_shared(); - op_context->device_context = device_context; + auto scope = std::make_shared(); OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); ASSERT_EQ(op->inputs(), inputs); ASSERT_EQ(op->outputs(), outputs); ASSERT_EQ(op->GetAttr("scale"), scale); - op->Run(op_context); + op->Run(scope, device_context); } } // namespace framework From 4c68debf8ed4441bc76ce3dd5e02c5319abb3cff Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 10 Jul 2017 15:06:46 +0800 Subject: [PATCH 17/26] change op in OpContext from ref to const pointer --- paddle/framework/operator.cc | 4 ++-- paddle/framework/operator.h | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 312daddd917c4..9e4ae5dafea3c 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -55,11 +55,11 @@ const std::string OperatorBase::DebugString() const { } const Variable* OpContext::Input(int index) const { - return scope->GetVariable(op.inputs()[index]); + return scope->GetVariable(op->inputs()[index]); } Variable* OpContext::Output(int index) const { - return scope->GetVariable(op.outputs()[index]); + return scope->GetVariable(op->outputs()[index]); } } // namespace framework diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index e46cac71b4e32..a9714829a4130 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -38,7 +38,7 @@ class OperatorBase; */ class OpContext { public: - OpContext(OperatorBase& op, std::shared_ptr scope, + OpContext(const OperatorBase* op, std::shared_ptr scope, DeviceContext* device_context) : op(op), scope(scope), device_context(device_context) {} @@ -46,7 +46,7 @@ class OpContext { Variable* Output(int index) const; public: - OperatorBase& op; + const OperatorBase* op; std::shared_ptr scope; DeviceContext* device_context; }; @@ -86,7 +86,7 @@ class OperatorBase { void InferShape(Scope* scope) const; void Run(std::shared_ptr scope, DeviceContext* dev_ctx) { - OpContext* op_ctx = new OpContext(*this, scope, dev_ctx); + OpContext* op_ctx = new OpContext(this, scope, dev_ctx); Run(op_ctx); } From d0fae91004ee30856372d6038c6175f1a136d53d Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 10 Jul 2017 17:25:55 +0800 Subject: [PATCH 18/26] change op member to public --- paddle/framework/op_registry.h | 15 +++++++++++---- paddle/framework/operator.cc | 21 ++++----------------- paddle/framework/operator.h | 18 ++++-------------- paddle/framework/operator_test.cc | 4 ++-- 4 files changed, 21 insertions(+), 37 deletions(-) diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 1c1e4ccf4940e..5424bdbb026cd 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -138,15 +138,22 @@ class OpRegistry { static OperatorBase* CreateOp(const OpDesc& op_desc) { std::string op_type = op_desc.type(); OperatorBase* op = (creators_.at(op_type))(); - AttributeMap attrs; + // init attrs for (int i = 0; i < op_desc.attrs_size(); ++i) { const AttrDesc& ith_attr = op_desc.attrs(i); std::string name = ith_attr.name(); - (attrs)[name] = AttrTypeHelper::GetAttrValue(ith_attr); + (op->attrs_)[name] = AttrTypeHelper::GetAttrValue(ith_attr); } const OpAttrChecker& op_checker = OpRegistry::op_checkers_.at(op_type); - op_checker.Check(attrs); - op->Init(op_desc, attrs); + // check attrs + op_checker.Check(op->attrs_); + op->desc_ = op_desc; + for (auto& input : op_desc.inputs()) { + op->inputs_.push_back(input); + } + for (auto& output : op_desc.outputs()) { + op->outputs_.push_back(output); + } return op; } diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 9e4ae5dafea3c..d2afc03b8ee31 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -17,22 +17,9 @@ limitations under the License. */ namespace paddle { namespace framework { -void OperatorBase::Init(const OpDesc& op_desc, AttributeMap& attrs) { - desc_ = op_desc; - inputs_.reserve(desc_.inputs_size()); - for (auto& input : desc_.inputs()) { - inputs_.push_back(input); - } - outputs_.reserve(desc_.outputs_size()); - for (auto& output : desc_.outputs()) { - outputs_.push_back(output); - } - attrs_.insert(attrs.begin(), attrs.end()); -} - -void OperatorBase::InferShape(Scope* scope) const {} +void OperatorBase::InferShape(std::shared_ptr scope) const {} -const std::string OperatorBase::DebugString() const { +std::string OperatorBase::DebugString() const { std::stringstream ss; ss << "=================\n"; ss << "type = " << desc_.type() << "\n"; @@ -55,11 +42,11 @@ const std::string OperatorBase::DebugString() const { } const Variable* OpContext::Input(int index) const { - return scope->GetVariable(op->inputs()[index]); + return scope->GetVariable(op->inputs_[index]); } Variable* OpContext::Output(int index) const { - return scope->GetVariable(op->outputs()[index]); + return scope->GetVariable(op->outputs_[index]); } } // namespace framework diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index a9714829a4130..d6398bfa06a30 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -61,13 +61,6 @@ class OperatorBase { public: virtual ~OperatorBase() {} - /// We do not use ctor but an init function to construct an Operator. - /// There is no need for all sub operators to have a constructor and - /// write this init parameters. - void Init(const OpDesc& op_desc, AttributeMap& attrs); - - inline const OpDesc desc() const { return desc_; } - template inline const T GetAttr(const std::string& name) const { PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap", @@ -75,16 +68,13 @@ class OperatorBase { return boost::get(attrs_.at(name)); } - inline const std::vector inputs() const { return inputs_; } - - inline const std::vector outputs() const { return outputs_; } - - const std::string DebugString() const; + std::string DebugString() const; /// InferShape infer the size of Variables used by this Operator with /// information inside scope - void InferShape(Scope* scope) const; + virtual void InferShape(std::shared_ptr scope) const; + /// Net will call this function to Run an op. void Run(std::shared_ptr scope, DeviceContext* dev_ctx) { OpContext* op_ctx = new OpContext(this, scope, dev_ctx); Run(op_ctx); @@ -94,7 +84,7 @@ class OperatorBase { /// this function should be moved to OpKernel later virtual void Run(OpContext* context) const = 0; - private: + public: OpDesc desc_; std::vector inputs_; std::vector outputs_; diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 3e9a35aba6bbd..0c23c42473b8a 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -68,8 +68,8 @@ TEST(OperatorBase, DebugString) { auto scope = std::make_shared(); OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); - ASSERT_EQ(op->inputs(), inputs); - ASSERT_EQ(op->outputs(), outputs); + ASSERT_EQ(op->inputs_, inputs); + ASSERT_EQ(op->outputs_, outputs); ASSERT_EQ(op->GetAttr("scale"), scale); op->Run(scope, device_context); } From 40022f79d575f9b9e6578515ba4444e84b3b8c97 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 10 Jul 2017 18:13:01 +0800 Subject: [PATCH 19/26] remove New OpContext --- paddle/framework/operator.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index d6398bfa06a30..6274708be607c 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -18,6 +18,7 @@ limitations under the License. */ #include #include #include + #include "paddle/framework/attr_checker.h" #include "paddle/framework/op_desc.pb.h" #include "paddle/framework/scope.h" @@ -76,8 +77,8 @@ class OperatorBase { /// Net will call this function to Run an op. void Run(std::shared_ptr scope, DeviceContext* dev_ctx) { - OpContext* op_ctx = new OpContext(this, scope, dev_ctx); - Run(op_ctx); + OpContext op_ctx(this, scope, dev_ctx); + Run(&op_ctx); } /// when implement an Op, your should implement this function. From e34c5790c1995e5a1a21141934a2a43f64db7309 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 11 Jul 2017 06:38:25 +0800 Subject: [PATCH 20/26] fix style problem --- python/paddle/trainer_config_helpers/networks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/trainer_config_helpers/networks.py b/python/paddle/trainer_config_helpers/networks.py index b77932ce5f094..f0b6625dc3736 100755 --- a/python/paddle/trainer_config_helpers/networks.py +++ b/python/paddle/trainer_config_helpers/networks.py @@ -1395,7 +1395,7 @@ def inputs(layers, *args): if len(args) != 0: layers.extend(args) - Inputs(* [l.name for l in layers]) + Inputs(*[l.name for l in layers]) def outputs(layers, *args): @@ -1438,7 +1438,7 @@ def __dfs_travel__(layer, assert len(layers) > 0 if HasInputsSet(): # input already set - Outputs(* [l.name for l in layers]) + Outputs(*[l.name for l in layers]) return # just return outputs. if len(layers) != 1: From cd2338f350d8a6fc9a3da43e2af5df6b99768c04 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 11 Jul 2017 14:26:43 +0800 Subject: [PATCH 21/26] use fake DeviceContext --- paddle/framework/CMakeLists.txt | 2 +- paddle/framework/operator.h | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 9f2eb6cbfae31..e844efd4240b2 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -12,7 +12,7 @@ cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf) proto_library(op_desc SRCS op_desc.proto DEPS attr_type) cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf) cc_library(operator SRCS operator.cc DEPS op_desc protobuf) -cc_test(operator_test SRCS operator_test.cc DEPS operator op_proto op_desc attr_type protobuf) +cc_test(operator_test SRCS operator_test.cc DEPS operator op_proto) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_proto op_desc operator) py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto) # Generate an empty __init__.py to make framework_py_proto as a valid python module. diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 6274708be607c..521a0f1a5e6d8 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -22,15 +22,15 @@ limitations under the License. */ #include "paddle/framework/attr_checker.h" #include "paddle/framework/op_desc.pb.h" #include "paddle/framework/scope.h" -#include "paddle/platform/device_context.h" #include "paddle/utils/Error.h" namespace paddle { namespace framework { -using paddle::platform::DeviceContext; class OperatorBase; +class DeviceContext {}; + /** * OpRunContext is the only parameter of Operator's Run function. * Run will get input/output variables, state such as momentum and From dfbf13d4a8d78104ea173a633246940f4dc1651c Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 11 Jul 2017 16:43:29 +0800 Subject: [PATCH 22/26] use const func call --- paddle/framework/operator.cc | 10 +++++----- paddle/framework/operator.h | 22 +++++++++++----------- paddle/framework/operator_test.cc | 2 +- paddle/operators/demo_op.h | 4 ++-- 4 files changed, 19 insertions(+), 19 deletions(-) diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index d2afc03b8ee31..a036fb38f24c5 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -17,7 +17,7 @@ limitations under the License. */ namespace paddle { namespace framework { -void OperatorBase::InferShape(std::shared_ptr scope) const {} +void OperatorBase::InferShape(const std::shared_ptr& scope) const {} std::string OperatorBase::DebugString() const { std::stringstream ss; @@ -41,12 +41,12 @@ std::string OperatorBase::DebugString() const { return ss.str(); } -const Variable* OpContext::Input(int index) const { - return scope->GetVariable(op->inputs_[index]); +const Variable* OpRunContext::Input(int index) const { + return scope_->GetVariable(op_->inputs_[index]); } -Variable* OpContext::Output(int index) const { - return scope->GetVariable(op->outputs_[index]); +Variable* OpRunContext::Output(int index) const { + return scope_->GetVariable(op_->outputs_[index]); } } // namespace framework diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 521a0f1a5e6d8..03210a0c708f1 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -37,19 +37,19 @@ class DeviceContext {}; * device resource such as CUDA stream, cublas handle, etc. from * OpRunContext. User should construct it before run the Operator. */ -class OpContext { +class OpRunContext { public: - OpContext(const OperatorBase* op, std::shared_ptr scope, - DeviceContext* device_context) - : op(op), scope(scope), device_context(device_context) {} + OpRunContext(const OperatorBase* op, std::shared_ptr scope, + DeviceContext* device_context) + : op_(op), scope_(scope), device_context_(device_context) {} const Variable* Input(int index) const; Variable* Output(int index) const; public: - const OperatorBase* op; - std::shared_ptr scope; - DeviceContext* device_context; + const OperatorBase* op_; + std::shared_ptr scope_; + DeviceContext* device_context_; }; /** @@ -73,17 +73,17 @@ class OperatorBase { /// InferShape infer the size of Variables used by this Operator with /// information inside scope - virtual void InferShape(std::shared_ptr scope) const; + virtual void InferShape(const std::shared_ptr& scope) const; /// Net will call this function to Run an op. - void Run(std::shared_ptr scope, DeviceContext* dev_ctx) { - OpContext op_ctx(this, scope, dev_ctx); + void Run(const std::shared_ptr& scope, DeviceContext* dev_ctx) { + OpRunContext op_ctx(this, scope, dev_ctx); Run(&op_ctx); } /// when implement an Op, your should implement this function. /// this function should be moved to OpKernel later - virtual void Run(OpContext* context) const = 0; + virtual void Run(OpRunContext* context) const = 0; public: OpDesc desc_; diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 0c23c42473b8a..fce84e3128d46 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -21,7 +21,7 @@ namespace framework { class OperatorTest : public OperatorBase { public: - void Run(OpContext* ctx) const override { + void Run(OpRunContext* ctx) const override { float scale = GetAttr("scale"); PADDLE_ENFORCE(ctx->Input(0) == nullptr, "Input(0) should not initialized"); PADDLE_ENFORCE(ctx->Output(0) == nullptr, diff --git a/paddle/operators/demo_op.h b/paddle/operators/demo_op.h index 1c1704c3e582d..48f79c804094c 100644 --- a/paddle/operators/demo_op.h +++ b/paddle/operators/demo_op.h @@ -9,7 +9,7 @@ namespace operators { class CosineOp : public OperatorBase { public: - void Run(OpContext *context) const override { + void Run(OpRunContext *context) const override { printf("%s\n", DebugString().c_str()); } }; @@ -32,7 +32,7 @@ REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim) class MyTestOp : public OperatorBase { public: - void Run(OpContext *context) const override { + void Run(OpRunContext *context) const override { printf("%s\n", DebugString().c_str()); } }; From 6a893c15c786f82a222b0f65ad7cacb8b3833053 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 11 Jul 2017 17:56:56 +0800 Subject: [PATCH 23/26] add operator with kernel --- paddle/framework/op_registry_test.cc | 15 ++++++++++++--- paddle/framework/operator.cc | 2 -- paddle/framework/operator.h | 24 +++++++++++++++++------- paddle/framework/operator_test.cc | 6 +++--- paddle/operators/demo_op.h | 4 ++-- 5 files changed, 34 insertions(+), 17 deletions(-) diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index 18063aaf5db99..b0c40aa30ccbf 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -1,7 +1,10 @@ #include "paddle/framework/op_registry.h" #include +#include "paddle/framework/operator.h" #include "paddle/operators/demo_op.h" +using namespace paddle::framework; + TEST(OpRegistry, CreateOp) { paddle::framework::OpDesc op_desc; op_desc.set_type("cos_sim"); @@ -16,7 +19,9 @@ TEST(OpRegistry, CreateOp) { paddle::framework::OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); - op->Run(nullptr); + auto scope = std::make_shared(); + auto dev_ctx = DeviceContext(); + op->Run(scope, &dev_ctx); float scale_get = op->GetAttr("scale"); ASSERT_EQ(scale_get, scale); } @@ -57,7 +62,9 @@ TEST(OpRegistry, DefaultValue) { paddle::framework::OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); - op->Run(nullptr); + auto scope = std::make_shared(); + auto dev_ctx = DeviceContext(); + op->Run(scope, &dev_ctx); ASSERT_EQ(op->GetAttr("scale"), 1.0); } @@ -109,7 +116,9 @@ TEST(OpRegistry, CustomChecker) { attr->set_i(4); paddle::framework::OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); - op->Run(nullptr); + auto dev_ctx = DeviceContext(); + auto scope = std::make_shared(); + op->Run(scope, &dev_ctx); int test_attr = op->GetAttr("test_attr"); ASSERT_EQ(test_attr, 4); } diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index a036fb38f24c5..3db3706e47dfa 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -17,8 +17,6 @@ limitations under the License. */ namespace paddle { namespace framework { -void OperatorBase::InferShape(const std::shared_ptr& scope) const {} - std::string OperatorBase::DebugString() const { std::stringstream ss; ss << "=================\n"; diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 03210a0c708f1..5ff4e9f618c7b 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -73,9 +73,25 @@ class OperatorBase { /// InferShape infer the size of Variables used by this Operator with /// information inside scope - virtual void InferShape(const std::shared_ptr& scope) const; + virtual void InferShape(const std::shared_ptr& scope) const = 0; /// Net will call this function to Run an op. + virtual void Run(const std::shared_ptr& scope, + DeviceContext* dev_ctx) = 0; + + public: + OpDesc desc_; + std::vector inputs_; + std::vector outputs_; + AttributeMap attrs_; +}; + +class OperatorWithKernel : public OperatorBase { + public: + virtual ~OperatorWithKernel() {} + + virtual void InferShape(const std::shared_ptr& scope) const {} + void Run(const std::shared_ptr& scope, DeviceContext* dev_ctx) { OpRunContext op_ctx(this, scope, dev_ctx); Run(&op_ctx); @@ -84,12 +100,6 @@ class OperatorBase { /// when implement an Op, your should implement this function. /// this function should be moved to OpKernel later virtual void Run(OpRunContext* context) const = 0; - - public: - OpDesc desc_; - std::vector inputs_; - std::vector outputs_; - AttributeMap attrs_; }; } // namespace framework diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index fce84e3128d46..1f370043a24e9 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -19,7 +19,7 @@ limitations under the License. */ namespace paddle { namespace framework { -class OperatorTest : public OperatorBase { +class OperatorTest : public OperatorWithKernel { public: void Run(OpRunContext* ctx) const override { float scale = GetAttr("scale"); @@ -64,14 +64,14 @@ TEST(OperatorBase, DebugString) { float scale = 3.14; attr->set_f(scale); - DeviceContext* device_context = new DeviceContext(); + DeviceContext device_context; auto scope = std::make_shared(); OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); ASSERT_EQ(op->inputs_, inputs); ASSERT_EQ(op->outputs_, outputs); ASSERT_EQ(op->GetAttr("scale"), scale); - op->Run(scope, device_context); + op->Run(scope, &device_context); } } // namespace framework diff --git a/paddle/operators/demo_op.h b/paddle/operators/demo_op.h index 48f79c804094c..919066bfb510c 100644 --- a/paddle/operators/demo_op.h +++ b/paddle/operators/demo_op.h @@ -7,7 +7,7 @@ using namespace paddle::framework; namespace paddle { namespace operators { -class CosineOp : public OperatorBase { +class CosineOp : public OperatorWithKernel { public: void Run(OpRunContext *context) const override { printf("%s\n", DebugString().c_str()); @@ -30,7 +30,7 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim) -class MyTestOp : public OperatorBase { +class MyTestOp : public OperatorWithKernel { public: void Run(OpRunContext *context) const override { printf("%s\n", DebugString().c_str()); From 93249157d8f64c1ba67b33eaf6c40daaa058d555 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 11 Jul 2017 19:26:39 +0800 Subject: [PATCH 24/26] merge new attr --- paddle/framework/CMakeLists.txt | 4 ++-- paddle/framework/op_registry.cc | 4 ++-- paddle/framework/op_registry.h | 2 +- paddle/framework/op_registry_test.cc | 18 +++++++----------- paddle/framework/operator_test.cc | 2 +- 5 files changed, 13 insertions(+), 17 deletions(-) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index f34b377b242d4..aac49fdb7a04a 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -12,9 +12,9 @@ cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf) proto_library(op_desc SRCS op_desc.proto DEPS attr_type) cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf) cc_library(operator SRCS operator.cc DEPS op_desc protobuf) -cc_test(operator_test SRCS operator_test.cc DEPS operator op_proto) +cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc) -cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) +cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator) py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto) # Generate an empty __init__.py to make framework_py_proto as a valid python module. add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) diff --git a/paddle/framework/op_registry.cc b/paddle/framework/op_registry.cc index bc6a0dda57d4d..4b35e04e681b4 100644 --- a/paddle/framework/op_registry.cc +++ b/paddle/framework/op_registry.cc @@ -32,5 +32,5 @@ template <> void AttrTypeHelper::SetAttrType>(AttrProto* attr) { attr->set_type(paddle::framework::AttrType::STRINGS); } -} -} \ No newline at end of file +} // namespace framework +} // namespace paddle \ No newline at end of file diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 0664e5b569564..02c99d50bb50c 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -1,7 +1,7 @@ #pragma once -#include "paddle/framework/attr_checker.h" #include +#include "paddle/framework/attr_checker.h" #include "paddle/framework/op_desc.pb.h" #include "paddle/framework/op_proto.pb.h" #include "paddle/framework/operator.h" diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index 302695cbddefb..f177b76c3b04b 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -7,12 +7,10 @@ using namespace paddle::framework; namespace paddle { namespace framework { -class CosineOp : public OpBase { +class CosineOp : public OperatorWithKernel { public: - virtual std::string Run() const { - std::string msg = "CosineOp runs! scale = " + - std::to_string(boost::get(attr_map_.at("scale"))); - return msg; + void Run(OpRunContext* context) const override { + printf("%s\n", DebugString().c_str()); } }; @@ -32,13 +30,11 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim) -class MyTestOp : public OpBase { +class MyTestOp : public OperatorWithKernel { public: - virtual std::string Run() const { - std::string msg = - "MyTestOp runs! test_attr = " + - std::to_string(boost::get(attr_map_.at("test_attr"))); - return msg; + void Run(OpRunContext* ctx) const override { + printf("%s\n", DebugString().c_str()); + printf("test_attr = %d\n", ctx->op_->GetAttr("test_attr")); } }; diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 1f370043a24e9..3683ef4e2cd9f 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -22,7 +22,7 @@ namespace framework { class OperatorTest : public OperatorWithKernel { public: void Run(OpRunContext* ctx) const override { - float scale = GetAttr("scale"); + float scale = ctx->op_->GetAttr("scale"); PADDLE_ENFORCE(ctx->Input(0) == nullptr, "Input(0) should not initialized"); PADDLE_ENFORCE(ctx->Output(0) == nullptr, "Output(1) should not initialized"); From f0332fb6a2f52606529bdf9297ab047152a7f233 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Wed, 12 Jul 2017 00:04:33 +0800 Subject: [PATCH 25/26] optimize code --- paddle/framework/op_registry_test.cc | 4 ++-- paddle/framework/operator.h | 17 +++++++++-------- paddle/framework/operator_test.cc | 4 +++- paddle/operators/demo_op.h | 4 ++-- 4 files changed, 16 insertions(+), 13 deletions(-) diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index f177b76c3b04b..c4baafc2aebc8 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -9,7 +9,7 @@ namespace paddle { namespace framework { class CosineOp : public OperatorWithKernel { public: - void Run(OpRunContext* context) const override { + void Run(const OpRunContext* context) const override { printf("%s\n", DebugString().c_str()); } }; @@ -32,7 +32,7 @@ REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim) class MyTestOp : public OperatorWithKernel { public: - void Run(OpRunContext* ctx) const override { + void Run(const OpRunContext* ctx) const override { printf("%s\n", DebugString().c_str()); printf("test_attr = %d\n", ctx->op_->GetAttr("test_attr")); } diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 5ff4e9f618c7b..c38dde20d49ae 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -39,8 +39,9 @@ class DeviceContext {}; */ class OpRunContext { public: - OpRunContext(const OperatorBase* op, std::shared_ptr scope, - DeviceContext* device_context) + OpRunContext(const OperatorBase* op, + const std::shared_ptr scope, + const DeviceContext* device_context) : op_(op), scope_(scope), device_context_(device_context) {} const Variable* Input(int index) const; @@ -48,8 +49,8 @@ class OpRunContext { public: const OperatorBase* op_; - std::shared_ptr scope_; - DeviceContext* device_context_; + const std::shared_ptr scope_; + const DeviceContext* device_context_; }; /** @@ -63,7 +64,7 @@ class OperatorBase { virtual ~OperatorBase() {} template - inline const T GetAttr(const std::string& name) const { + inline const T& GetAttr(const std::string& name) const { PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap", name); return boost::get(attrs_.at(name)); @@ -77,7 +78,7 @@ class OperatorBase { /// Net will call this function to Run an op. virtual void Run(const std::shared_ptr& scope, - DeviceContext* dev_ctx) = 0; + const DeviceContext* dev_ctx) const = 0; public: OpDesc desc_; @@ -92,14 +93,14 @@ class OperatorWithKernel : public OperatorBase { virtual void InferShape(const std::shared_ptr& scope) const {} - void Run(const std::shared_ptr& scope, DeviceContext* dev_ctx) { + void Run(const std::shared_ptr& scope, const DeviceContext* dev_ctx) const { OpRunContext op_ctx(this, scope, dev_ctx); Run(&op_ctx); } /// when implement an Op, your should implement this function. /// this function should be moved to OpKernel later - virtual void Run(OpRunContext* context) const = 0; + virtual void Run(const OpRunContext* context) const = 0; }; } // namespace framework diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 3683ef4e2cd9f..48808dabb2711 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -21,11 +21,13 @@ namespace framework { class OperatorTest : public OperatorWithKernel { public: - void Run(OpRunContext* ctx) const override { + void Run(const OpRunContext* ctx) const override { float scale = ctx->op_->GetAttr("scale"); PADDLE_ENFORCE(ctx->Input(0) == nullptr, "Input(0) should not initialized"); PADDLE_ENFORCE(ctx->Output(0) == nullptr, "Output(1) should not initialized"); + auto output1 = ctx->scope_->CreateVariable("output1"); + PADDLE_ENFORCE(output1 != nullptr, "should create output1 from scope"); printf("get attr %s = %f\n", "scale", scale); printf("%s\n", DebugString().c_str()); } diff --git a/paddle/operators/demo_op.h b/paddle/operators/demo_op.h index 919066bfb510c..d0b7420b4e25d 100644 --- a/paddle/operators/demo_op.h +++ b/paddle/operators/demo_op.h @@ -9,7 +9,7 @@ namespace operators { class CosineOp : public OperatorWithKernel { public: - void Run(OpRunContext *context) const override { + void Run(const OpRunContext *context) const override { printf("%s\n", DebugString().c_str()); } }; @@ -32,7 +32,7 @@ REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim) class MyTestOp : public OperatorWithKernel { public: - void Run(OpRunContext *context) const override { + void Run(const OpRunContext *context) const override { printf("%s\n", DebugString().c_str()); } }; From c1d8cbb3f9cbf4567928e9685ff7465d77fd09e3 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Wed, 12 Jul 2017 00:41:59 +0800 Subject: [PATCH 26/26] fix cpp style --- paddle/framework/operator.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index c38dde20d49ae..6570d5869814a 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -39,8 +39,7 @@ class DeviceContext {}; */ class OpRunContext { public: - OpRunContext(const OperatorBase* op, - const std::shared_ptr scope, + OpRunContext(const OperatorBase* op, const std::shared_ptr scope, const DeviceContext* device_context) : op_(op), scope_(scope), device_context_(device_context) {} @@ -93,7 +92,8 @@ class OperatorWithKernel : public OperatorBase { virtual void InferShape(const std::shared_ptr& scope) const {} - void Run(const std::shared_ptr& scope, const DeviceContext* dev_ctx) const { + void Run(const std::shared_ptr& scope, + const DeviceContext* dev_ctx) const { OpRunContext op_ctx(this, scope, dev_ctx); Run(&op_ctx); }