From 40573cd56f723ebde6328ccd5dabe4a363c9f3db Mon Sep 17 00:00:00 2001 From: Superjom Date: Mon, 3 Jul 2017 14:41:43 +0800 Subject: [PATCH 01/13] add net headers --- paddle/framework/net.cc | 23 +++++ paddle/framework/net.h | 182 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 205 insertions(+) create mode 100644 paddle/framework/net.cc create mode 100644 paddle/framework/net.h diff --git a/paddle/framework/net.cc b/paddle/framework/net.cc new file mode 100644 index 0000000000000..0ce92968200dc --- /dev/null +++ b/paddle/framework/net.cc @@ -0,0 +1,23 @@ +#include "paddle/framework/net.h" + +namespace paddle { +namespace framework { + +PlainNet::PlainNet(const NetDesc& def) {} + +virtual Error PlainNet::InferShape() { + for (auto& op : ops_) { + // wrong shape + auto err = op.InferShape(); + if (!err) return err; + } + // ok + return Error(); +} + +virtual Error PlainNet::Run(Scope* scope = nullptr, + OpContext* context = nullptr, OpIndex begin = -1, + OpIndex end = -1) const {} + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/net.h b/paddle/framework/net.h new file mode 100644 index 0000000000000..88bdf0bb68bff --- /dev/null +++ b/paddle/framework/net.h @@ -0,0 +1,182 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/framework/scope.h" + +namespace paddle { +namespace framework { + +// operator's index stored in a network. +typedef int OpIndex; +/** + * NOTE following codes are some definitions of unimplemented concepts. + * We write some basic implementation to make Net compilable. These APIs will + * keep updating if the concepts related are implemented. + */ + +// Operator's runtime context. +struct OpContext { + int dev_id; + DevType dev_type{kCPU}; + enum DevType { kCPU, kGPU }; +}; + +// Proto definitions, use `struct`s for simpility. +struct VarDesc { + std::string type; + std::vector dims; +}; +struct OpDesc { + std::string type; + std::vector inputs; + std::vector outputs; +}; +struct struct NetDesc { + std::vector ops; +}; +class Operator { + public: + Operator(const OpDesc &def) {} + Error InferShape() {} + Error Run() {} +}; + +/** + * @brief Network that manage the operators it has. + * + * Network is the container and controller of a set of operators, user can build + * a real network from a NetDesc which is a protobuf message and use + * Network.Run() * to run all the operators in the network. + + * A network object knows all Operators belonging to this network. Variables, + * which are inputs and outputs of these operators, are created and managed by a + * hierarchy of Scope objects. + * + * This is the base class of network, all the networks should implement the apis + * it defines. + */ +class Net { + public: + /** + * @brief Infer shapes of all inputs and outputs of operators. + */ + virtual Error InferShape(Scope *scope) override; + /** + * @brief Run the network. + * + * Run all the operators and return success(true) or not, with all the + * variables are located in `scope`. `context` describes the detail execution + * environment for ops. `begin` and `end` specify the scope of `ops_` to run, + * If no positive indexes are provided, all operators in `ops_` will run. + */ + virtual Error Run(Scope *scope, OpContext *context, OpIndex begin = -1, + OpIndex end = -1) const = 0; + + /** + * @brief Add an Operator according to `def`. + */ + virtual OpIndex AddOp(const proto::OpDef &def) = 0; + + /** + * @brief Add optimizer operators acctording to `attrs`. + */ + virtual Error AddOptimizerOps(const OptAttrs &attrs) = 0; + + /** + * @brief Add backward operators. + */ + virtual Error AddBackwardOps() = 0; + + /** + * @brief Create a network. + */ + static std::unique_ptr Create(const NetDesc &def = NetDesc()); +}; + +/** + * @brief a basic implementation of Net. + * + * PlainNet is a very simple Net, it create a list of operators, and run them + * sequentially following the order they added. + */ +class PlainNet : public Net { + public: + /** + * @brief Initialize a PlainNet. + * + * Initialize from a network describe by `def`. NetDesc is the definition of + * a network. + */ + PlainNet(const NetDesc &def); + + /** + * Infer all the operators' input and output varialbes' shapes, will be called + * before every mini-batch + */ + virtual Error InferShape(Scope *scope) override; + + /** + * @brief Run the network. + * + * Run all the operators with the `scope`, if no scope is provided, default + * scope will be used instead. If no OpContext is provicded, default context + * will be used. + */ + virtual Error Run(Scope *scope = nullptr, OpContext *context = nullptr, + OpIndex begin = -1, OpIndex end = -1) const override; + + /** + * @brief Add an operator to this network. + */ + virtual OpIndex AddOp(const proto::OpDef &def) override; + + /** + * @brief Add all optimizer operators related into the network. + */ + virtual Error AddOptimizerOps(const OptAttrs &attrs) override; + + /** + * @brief Add all backward operators related into the network. + */ + virtual Error AddBackwardOps() override; + + protected: + /** + * @brief Build the network. + * + * Create operators accordding to `def`, will be called by the constructor. + */ + Error BuildNet(const NetDesc &def); + + /** + * @brief Add an operator into this network. + * + * Add a operator which is identified as `type` and has attributes described + * in `attrs`, the `inputs` are the keys of readonly input variables, + * `outputs` are keys of mutable output variables. An `OpIndex` will be + * returned to indicate the offset of the new operator in `ops_`. + */ + OpIndex AddOp(const std::string &type, const std::vector &inputs, + const std::vector &outputs, + const OprAttr &attrs = OprAttr()); + + private: + // the operators owned by `Network`. + std::vector ops_; +}; + +} // namespace framework +} // namespace paddle From 9f365d36364d34f2cf186d5bc0569189145c612d Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Tue, 4 Jul 2017 11:23:49 +0800 Subject: [PATCH 02/13] "add net proto" --- paddle/framework/CMakeLists.txt | 4 +++ paddle/framework/net.h | 48 ++++++++++---------------------- paddle/framework/net_proto.proto | 16 +++++++++++ 3 files changed, 35 insertions(+), 33 deletions(-) create mode 100644 paddle/framework/net_proto.proto diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index f7e5753ac2c23..8c34a77c20787 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -7,4 +7,8 @@ cc_test(scope_test SRCS scope_test.cc) cc_test(enforce_test SRCS enforce_test.cc) proto_library(attr_type SRCS attr_type.proto) proto_library(op_proto SRCS op_proto.proto) + cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto attr_type protobuf) + +proto_library(net_proto SRCS net_proto.proto) +cc_library(net SRCS net.cc DEPS net_proto attr_type op_proto) diff --git a/paddle/framework/net.h b/paddle/framework/net.h index 88bdf0bb68bff..b3064e4f90b77 100644 --- a/paddle/framework/net.h +++ b/paddle/framework/net.h @@ -14,6 +14,8 @@ #pragma once +#include "paddle/framework/net_proto.pb.h" +#include "paddle/framework/op_proto.pb.h" #include "paddle/framework/scope.h" namespace paddle { @@ -27,31 +29,11 @@ typedef int OpIndex; * keep updating if the concepts related are implemented. */ -// Operator's runtime context. -struct OpContext { - int dev_id; - DevType dev_type{kCPU}; - enum DevType { kCPU, kGPU }; -}; - -// Proto definitions, use `struct`s for simpility. -struct VarDesc { - std::string type; - std::vector dims; -}; -struct OpDesc { - std::string type; - std::vector inputs; - std::vector outputs; -}; -struct struct NetDesc { - std::vector ops; -}; class Operator { public: Operator(const OpDesc &def) {} - Error InferShape() {} - Error Run() {} + bool InferShape() {} + bool Run() {} }; /** @@ -73,7 +55,7 @@ class Net { /** * @brief Infer shapes of all inputs and outputs of operators. */ - virtual Error InferShape(Scope *scope) override; + virtual bool InferShape(Scope *scope) override; /** * @brief Run the network. * @@ -82,8 +64,8 @@ class Net { * environment for ops. `begin` and `end` specify the scope of `ops_` to run, * If no positive indexes are provided, all operators in `ops_` will run. */ - virtual Error Run(Scope *scope, OpContext *context, OpIndex begin = -1, - OpIndex end = -1) const = 0; + virtual bool Run(Scope *scope, OpContext *context, OpIndex begin = -1, + OpIndex end = -1) const = 0; /** * @brief Add an Operator according to `def`. @@ -93,12 +75,12 @@ class Net { /** * @brief Add optimizer operators acctording to `attrs`. */ - virtual Error AddOptimizerOps(const OptAttrs &attrs) = 0; + virtual bool AddOptimizerOps(const OptAttrs &attrs) = 0; /** * @brief Add backward operators. */ - virtual Error AddBackwardOps() = 0; + virtual bool AddBackwardOps() = 0; /** * @brief Create a network. @@ -126,7 +108,7 @@ class PlainNet : public Net { * Infer all the operators' input and output varialbes' shapes, will be called * before every mini-batch */ - virtual Error InferShape(Scope *scope) override; + virtual bool InferShape(Scope *scope) override; /** * @brief Run the network. @@ -135,8 +117,8 @@ class PlainNet : public Net { * scope will be used instead. If no OpContext is provicded, default context * will be used. */ - virtual Error Run(Scope *scope = nullptr, OpContext *context = nullptr, - OpIndex begin = -1, OpIndex end = -1) const override; + virtual bool Run(Scope *scope = nullptr, OpContext *context = nullptr, + OpIndex begin = -1, OpIndex end = -1) const override; /** * @brief Add an operator to this network. @@ -146,12 +128,12 @@ class PlainNet : public Net { /** * @brief Add all optimizer operators related into the network. */ - virtual Error AddOptimizerOps(const OptAttrs &attrs) override; + virtual bool AddOptimizerOps(const OptAttrs &attrs) override; /** * @brief Add all backward operators related into the network. */ - virtual Error AddBackwardOps() override; + virtual bool AddBackwardOps() override; protected: /** @@ -159,7 +141,7 @@ class PlainNet : public Net { * * Create operators accordding to `def`, will be called by the constructor. */ - Error BuildNet(const NetDesc &def); + bool BuildNet(const NetDesc &def); /** * @brief Add an operator into this network. diff --git a/paddle/framework/net_proto.proto b/paddle/framework/net_proto.proto new file mode 100644 index 0000000000000..e9aed8f349b80 --- /dev/null +++ b/paddle/framework/net_proto.proto @@ -0,0 +1,16 @@ +syntax="proto2"; +package paddle.framework; + +import "op_proto.proto" + +message NetDesc { + // network identification + optional string name = 1; + // operator contains in network + repeated OpProto operators = 2; + // network type to run with. e.g "plainNet", "DAG" + optional string type = 3; + // num worker always + optional int32 num_workers = 4; +} + From c602e046132b7e4e38c34f348b2a7fa290d67361 Mon Sep 17 00:00:00 2001 From: Superjom Date: Tue, 4 Jul 2017 13:35:21 +0800 Subject: [PATCH 03/13] add fake interfaces to make compilable --- paddle/framework/net.cc | 10 +++++--- paddle/framework/net.h | 44 +++++++++++++++++++------------- paddle/framework/net_proto.proto | 3 +-- 3 files changed, 33 insertions(+), 24 deletions(-) diff --git a/paddle/framework/net.cc b/paddle/framework/net.cc index 0ce92968200dc..2d9e099dc0c2b 100644 --- a/paddle/framework/net.cc +++ b/paddle/framework/net.cc @@ -5,7 +5,7 @@ namespace framework { PlainNet::PlainNet(const NetDesc& def) {} -virtual Error PlainNet::InferShape() { +Error PlainNet::InferShape(Scope* scope) { for (auto& op : ops_) { // wrong shape auto err = op.InferShape(); @@ -15,9 +15,11 @@ virtual Error PlainNet::InferShape() { return Error(); } -virtual Error PlainNet::Run(Scope* scope = nullptr, - OpContext* context = nullptr, OpIndex begin = -1, - OpIndex end = -1) const {} +Error PlainNet::Run(Scope* scope, OpContext* context, OpIndex begin, + OpIndex end) const { + // TODO Add implementation here. + return Error(); +} } // namespace framework } // namespace paddle diff --git a/paddle/framework/net.h b/paddle/framework/net.h index b3064e4f90b77..76e0ed9330716 100644 --- a/paddle/framework/net.h +++ b/paddle/framework/net.h @@ -17,6 +17,7 @@ #include "paddle/framework/net_proto.pb.h" #include "paddle/framework/op_proto.pb.h" #include "paddle/framework/scope.h" +#include "paddle/utils/Error.h" namespace paddle { namespace framework { @@ -29,11 +30,16 @@ typedef int OpIndex; * keep updating if the concepts related are implemented. */ +struct OpDesc; +struct OpDef; +struct OpContext; +struct OpAttrs {}; + class Operator { public: Operator(const OpDesc &def) {} - bool InferShape() {} - bool Run() {} + Error InferShape() { return Error(); } + Error Run() { return Error(); } }; /** @@ -55,7 +61,7 @@ class Net { /** * @brief Infer shapes of all inputs and outputs of operators. */ - virtual bool InferShape(Scope *scope) override; + virtual Error InferShape(Scope *scope) = 0; /** * @brief Run the network. * @@ -64,28 +70,30 @@ class Net { * environment for ops. `begin` and `end` specify the scope of `ops_` to run, * If no positive indexes are provided, all operators in `ops_` will run. */ - virtual bool Run(Scope *scope, OpContext *context, OpIndex begin = -1, - OpIndex end = -1) const = 0; + virtual Error Run(Scope *scope, OpContext *context, OpIndex begin = -1, + OpIndex end = -1) const = 0; /** * @brief Add an Operator according to `def`. */ - virtual OpIndex AddOp(const proto::OpDef &def) = 0; + virtual OpIndex AddOp(const OpDef &def) = 0; /** * @brief Add optimizer operators acctording to `attrs`. */ - virtual bool AddOptimizerOps(const OptAttrs &attrs) = 0; + virtual Error AddOptimizerOps(const OpAttrs &attrs) = 0; /** * @brief Add backward operators. */ - virtual bool AddBackwardOps() = 0; + virtual Error AddBackwardOps() = 0; /** * @brief Create a network. */ static std::unique_ptr Create(const NetDesc &def = NetDesc()); + + virtual ~Net() = 0; }; /** @@ -108,7 +116,7 @@ class PlainNet : public Net { * Infer all the operators' input and output varialbes' shapes, will be called * before every mini-batch */ - virtual bool InferShape(Scope *scope) override; + virtual Error InferShape(Scope *scope) override; /** * @brief Run the network. @@ -117,23 +125,23 @@ class PlainNet : public Net { * scope will be used instead. If no OpContext is provicded, default context * will be used. */ - virtual bool Run(Scope *scope = nullptr, OpContext *context = nullptr, - OpIndex begin = -1, OpIndex end = -1) const override; + virtual Error Run(Scope *scope = nullptr, OpContext *context = nullptr, + OpIndex begin = -1, OpIndex end = -1) const override; /** * @brief Add an operator to this network. */ - virtual OpIndex AddOp(const proto::OpDef &def) override; + virtual OpIndex AddOp(const OpDef &def) override; /** * @brief Add all optimizer operators related into the network. */ - virtual bool AddOptimizerOps(const OptAttrs &attrs) override; + virtual Error AddOptimizerOps(const OpAttrs &attrs) override; /** * @brief Add all backward operators related into the network. */ - virtual bool AddBackwardOps() override; + virtual Error AddBackwardOps() override; protected: /** @@ -141,7 +149,7 @@ class PlainNet : public Net { * * Create operators accordding to `def`, will be called by the constructor. */ - bool BuildNet(const NetDesc &def); + Error BuildNet(const NetDesc &def); /** * @brief Add an operator into this network. @@ -151,9 +159,9 @@ class PlainNet : public Net { * `outputs` are keys of mutable output variables. An `OpIndex` will be * returned to indicate the offset of the new operator in `ops_`. */ - OpIndex AddOp(const std::string &type, const std::vector &inputs, - const std::vector &outputs, - const OprAttr &attrs = OprAttr()); + OpIndex AddOp(const std::string &type, const std::vector &inputs, + const std::vector &outputs, + const OpAttrs &attrs = OpAttrs()); private: // the operators owned by `Network`. diff --git a/paddle/framework/net_proto.proto b/paddle/framework/net_proto.proto index e9aed8f349b80..2d042457e3306 100644 --- a/paddle/framework/net_proto.proto +++ b/paddle/framework/net_proto.proto @@ -1,7 +1,7 @@ syntax="proto2"; package paddle.framework; -import "op_proto.proto" +import "op_proto.proto"; message NetDesc { // network identification @@ -13,4 +13,3 @@ message NetDesc { // num worker always optional int32 num_workers = 4; } - From 04e20034dfcbb0ceb1de30ddd5b1f8b8ee811d4f Mon Sep 17 00:00:00 2001 From: Superjom Date: Tue, 4 Jul 2017 13:44:01 +0800 Subject: [PATCH 04/13] replace Error with void --- paddle/framework/net.cc | 11 +++-------- paddle/framework/net.h | 23 +++++++++++------------ 2 files changed, 14 insertions(+), 20 deletions(-) diff --git a/paddle/framework/net.cc b/paddle/framework/net.cc index 2d9e099dc0c2b..d49861c343ef1 100644 --- a/paddle/framework/net.cc +++ b/paddle/framework/net.cc @@ -5,20 +5,15 @@ namespace framework { PlainNet::PlainNet(const NetDesc& def) {} -Error PlainNet::InferShape(Scope* scope) { +void PlainNet::InferShape(Scope* scope) { for (auto& op : ops_) { - // wrong shape - auto err = op.InferShape(); - if (!err) return err; + op.InferShape(); } - // ok - return Error(); } -Error PlainNet::Run(Scope* scope, OpContext* context, OpIndex begin, +void PlainNet::Run(Scope* scope, OpContext* context, OpIndex begin, OpIndex end) const { // TODO Add implementation here. - return Error(); } } // namespace framework diff --git a/paddle/framework/net.h b/paddle/framework/net.h index 76e0ed9330716..55dcf147e1d4e 100644 --- a/paddle/framework/net.h +++ b/paddle/framework/net.h @@ -17,7 +17,6 @@ #include "paddle/framework/net_proto.pb.h" #include "paddle/framework/op_proto.pb.h" #include "paddle/framework/scope.h" -#include "paddle/utils/Error.h" namespace paddle { namespace framework { @@ -38,8 +37,8 @@ struct OpAttrs {}; class Operator { public: Operator(const OpDesc &def) {} - Error InferShape() { return Error(); } - Error Run() { return Error(); } + void InferShape() {} + void Run() {} }; /** @@ -61,7 +60,7 @@ class Net { /** * @brief Infer shapes of all inputs and outputs of operators. */ - virtual Error InferShape(Scope *scope) = 0; + virtual void InferShape(Scope *scope) = 0; /** * @brief Run the network. * @@ -70,7 +69,7 @@ class Net { * environment for ops. `begin` and `end` specify the scope of `ops_` to run, * If no positive indexes are provided, all operators in `ops_` will run. */ - virtual Error Run(Scope *scope, OpContext *context, OpIndex begin = -1, + virtual void Run(Scope *scope, OpContext *context, OpIndex begin = -1, OpIndex end = -1) const = 0; /** @@ -81,12 +80,12 @@ class Net { /** * @brief Add optimizer operators acctording to `attrs`. */ - virtual Error AddOptimizerOps(const OpAttrs &attrs) = 0; + virtual void AddOptimizerOps(const OpAttrs &attrs) = 0; /** * @brief Add backward operators. */ - virtual Error AddBackwardOps() = 0; + virtual void AddBackwardOps() = 0; /** * @brief Create a network. @@ -116,7 +115,7 @@ class PlainNet : public Net { * Infer all the operators' input and output varialbes' shapes, will be called * before every mini-batch */ - virtual Error InferShape(Scope *scope) override; + virtual void InferShape(Scope *scope) override; /** * @brief Run the network. @@ -125,7 +124,7 @@ class PlainNet : public Net { * scope will be used instead. If no OpContext is provicded, default context * will be used. */ - virtual Error Run(Scope *scope = nullptr, OpContext *context = nullptr, + virtual void Run(Scope *scope = nullptr, OpContext *context = nullptr, OpIndex begin = -1, OpIndex end = -1) const override; /** @@ -136,12 +135,12 @@ class PlainNet : public Net { /** * @brief Add all optimizer operators related into the network. */ - virtual Error AddOptimizerOps(const OpAttrs &attrs) override; + virtual void AddOptimizerOps(const OpAttrs &attrs) override; /** * @brief Add all backward operators related into the network. */ - virtual Error AddBackwardOps() override; + virtual void AddBackwardOps() override; protected: /** @@ -149,7 +148,7 @@ class PlainNet : public Net { * * Create operators accordding to `def`, will be called by the constructor. */ - Error BuildNet(const NetDesc &def); + void BuildNet(const NetDesc &def); /** * @brief Add an operator into this network. From 109937b8d512904d04a1773bdf19ddb756ecd087 Mon Sep 17 00:00:00 2001 From: Superjom Date: Tue, 4 Jul 2017 15:18:20 +0800 Subject: [PATCH 05/13] fix ci error --- paddle/framework/CMakeLists.txt | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 42600277f6685..ceff1d3581bc6 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -7,8 +7,6 @@ cc_test(scope_test SRCS scope_test.cc) cc_test(enforce_test SRCS enforce_test.cc) proto_library(attr_type SRCS attr_type.proto) proto_library(op_proto SRCS op_proto.proto DEPS attr_type) - +proto_library(net_proto SRCS net_proto.proto DEPS op_proto) cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto attr_type protobuf) - -proto_library(net_proto SRCS net_proto.proto) -cc_library(net SRCS net.cc DEPS net_proto attr_type op_proto) \ No newline at end of file +cc_library(net SRCS net.cc DEPS net_proto attr_type op_proto) From e95299b58300afda0d61e868998dfceb28e999da Mon Sep 17 00:00:00 2001 From: Superjom Date: Tue, 4 Jul 2017 16:28:21 +0800 Subject: [PATCH 06/13] fix ci error --- paddle/framework/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index ceff1d3581bc6..0abc63a831b71 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -7,6 +7,6 @@ cc_test(scope_test SRCS scope_test.cc) cc_test(enforce_test SRCS enforce_test.cc) proto_library(attr_type SRCS attr_type.proto) proto_library(op_proto SRCS op_proto.proto DEPS attr_type) +cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf) proto_library(net_proto SRCS net_proto.proto DEPS op_proto) -cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto attr_type protobuf) cc_library(net SRCS net.cc DEPS net_proto attr_type op_proto) From 9f2357561d939bdeae2a7bc0bd41be43d9ab0fe5 Mon Sep 17 00:00:00 2001 From: Superjom Date: Wed, 5 Jul 2017 10:08:23 +0800 Subject: [PATCH 07/13] fix ci error --- paddle/framework/net.cc | 2 +- paddle/framework/net.h | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/paddle/framework/net.cc b/paddle/framework/net.cc index d49861c343ef1..8c565c28cb986 100644 --- a/paddle/framework/net.cc +++ b/paddle/framework/net.cc @@ -12,7 +12,7 @@ void PlainNet::InferShape(Scope* scope) { } void PlainNet::Run(Scope* scope, OpContext* context, OpIndex begin, - OpIndex end) const { + OpIndex end) const { // TODO Add implementation here. } diff --git a/paddle/framework/net.h b/paddle/framework/net.h index 55dcf147e1d4e..9564c831eef04 100644 --- a/paddle/framework/net.h +++ b/paddle/framework/net.h @@ -70,7 +70,7 @@ class Net { * If no positive indexes are provided, all operators in `ops_` will run. */ virtual void Run(Scope *scope, OpContext *context, OpIndex begin = -1, - OpIndex end = -1) const = 0; + OpIndex end = -1) const = 0; /** * @brief Add an Operator according to `def`. @@ -125,7 +125,7 @@ class PlainNet : public Net { * will be used. */ virtual void Run(Scope *scope = nullptr, OpContext *context = nullptr, - OpIndex begin = -1, OpIndex end = -1) const override; + OpIndex begin = -1, OpIndex end = -1) const override; /** * @brief Add an operator to this network. @@ -142,6 +142,8 @@ class PlainNet : public Net { */ virtual void AddBackwardOps() override; + virtual ~PlainNet() override {} + protected: /** * @brief Build the network. From 5c10a5ad555d834dac4785d8cd2feac18da9b67b Mon Sep 17 00:00:00 2001 From: Superjom Date: Wed, 5 Jul 2017 10:34:49 +0800 Subject: [PATCH 08/13] remove virtual --- paddle/framework/net.h | 4 ---- 1 file changed, 4 deletions(-) diff --git a/paddle/framework/net.h b/paddle/framework/net.h index 9564c831eef04..e60356dc17284 100644 --- a/paddle/framework/net.h +++ b/paddle/framework/net.h @@ -91,8 +91,6 @@ class Net { * @brief Create a network. */ static std::unique_ptr Create(const NetDesc &def = NetDesc()); - - virtual ~Net() = 0; }; /** @@ -142,8 +140,6 @@ class PlainNet : public Net { */ virtual void AddBackwardOps() override; - virtual ~PlainNet() override {} - protected: /** * @brief Build the network. From 568c03ba1d311ac2af2cb9242cefb00537174e50 Mon Sep 17 00:00:00 2001 From: Superjom Date: Wed, 5 Jul 2017 10:51:47 +0800 Subject: [PATCH 09/13] add virtual implementation --- paddle/framework/net.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/paddle/framework/net.h b/paddle/framework/net.h index e60356dc17284..2025bfa4b2366 100644 --- a/paddle/framework/net.h +++ b/paddle/framework/net.h @@ -91,6 +91,8 @@ class Net { * @brief Create a network. */ static std::unique_ptr Create(const NetDesc &def = NetDesc()); + + virtual ~Net() {} }; /** @@ -140,6 +142,8 @@ class PlainNet : public Net { */ virtual void AddBackwardOps() override; + virtual ~PlainNet() override {} + protected: /** * @brief Build the network. From 1264480b048cf68e29f3dffa91e228425df55908 Mon Sep 17 00:00:00 2001 From: Superjom Date: Thu, 6 Jul 2017 10:48:00 +0800 Subject: [PATCH 10/13] fix ci --- paddle/framework/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index b33014210fb1c..fc2fbf88f1285 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -11,5 +11,6 @@ proto_library(op_proto SRCS op_proto.proto DEPS attr_type) cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf) proto_library(net_proto SRCS net_proto.proto DEPS op_proto) #cc_library(net SRCS net.cc DEPS net_proto attr_type op_proto) + proto_library(op_desc SRCS op_desc.proto DEPS attr_type) cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf) From bc021d775ed333dc9dca217203ee0d2999700813 Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Tue, 11 Jul 2017 09:42:07 +0800 Subject: [PATCH 11/13] "move opContext to DeviceContext" --- paddle/framework/net.cc | 5 +---- paddle/framework/net.h | 6 ++---- paddle/framework/net_proto.proto | 2 +- paddle/framework/net_test.cc | 24 ++++++++++++++++++++++++ paddle/framework/op_registry_test.cc | 2 +- 5 files changed, 29 insertions(+), 10 deletions(-) create mode 100644 paddle/framework/net_test.cc diff --git a/paddle/framework/net.cc b/paddle/framework/net.cc index 8c565c28cb986..20c0aef049ceb 100644 --- a/paddle/framework/net.cc +++ b/paddle/framework/net.cc @@ -11,10 +11,7 @@ void PlainNet::InferShape(Scope* scope) { } } -void PlainNet::Run(Scope* scope, OpContext* context, OpIndex begin, - OpIndex end) const { - // TODO Add implementation here. -} +void PlainNet::Run(Scope* scope) const {} } // namespace framework } // namespace paddle diff --git a/paddle/framework/net.h b/paddle/framework/net.h index 2025bfa4b2366..ef5013349196a 100644 --- a/paddle/framework/net.h +++ b/paddle/framework/net.h @@ -69,8 +69,7 @@ class Net { * environment for ops. `begin` and `end` specify the scope of `ops_` to run, * If no positive indexes are provided, all operators in `ops_` will run. */ - virtual void Run(Scope *scope, OpContext *context, OpIndex begin = -1, - OpIndex end = -1) const = 0; + virtual void Run(Scope *scope) const = 0; /** * @brief Add an Operator according to `def`. @@ -124,8 +123,7 @@ class PlainNet : public Net { * scope will be used instead. If no OpContext is provicded, default context * will be used. */ - virtual void Run(Scope *scope = nullptr, OpContext *context = nullptr, - OpIndex begin = -1, OpIndex end = -1) const override; + virtual void Run(Scope *scope) const override; /** * @brief Add an operator to this network. diff --git a/paddle/framework/net_proto.proto b/paddle/framework/net_proto.proto index 2d042457e3306..0779f49fe2a9a 100644 --- a/paddle/framework/net_proto.proto +++ b/paddle/framework/net_proto.proto @@ -9,7 +9,7 @@ message NetDesc { // operator contains in network repeated OpProto operators = 2; // network type to run with. e.g "plainNet", "DAG" - optional string type = 3; + optional string net_type = 3; // num worker always optional int32 num_workers = 4; } diff --git a/paddle/framework/net_test.cc b/paddle/framework/net_test.cc new file mode 100644 index 0000000000000..04f5efdf79bdc --- /dev/null +++ b/paddle/framework/net_test.cc @@ -0,0 +1,24 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/net.h" +#include "paddle/framework/op_registry.h" + +#include + +namespace paddle { +namespace framework { +class FakeFC : public OpBase {} +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index 17849ca0191db..ae6b7387129c6 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -119,4 +119,4 @@ TEST(OpRegistry, CustomChecker) { for (size_t i = 0; i < debug_str.length(); ++i) { ASSERT_EQ(debug_str[i], str[i]); } -} \ No newline at end of file +} From 18e65b0c084ef482492b528985173341a24284cc Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Tue, 11 Jul 2017 10:37:41 +0800 Subject: [PATCH 12/13] "support net_proto header" --- paddle/framework/CMakeLists.txt | 2 +- paddle/framework/net.cc | 7 +++++-- paddle/framework/net.h | 14 +++++++------- paddle/framework/net_test.cc | 2 +- paddle/platform/device_context.h | 1 + 5 files changed, 15 insertions(+), 11 deletions(-) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 39cfb4623795e..e6e3b79d7bd11 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -18,4 +18,4 @@ add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch add_dependencies(framework_py_proto framework_py_proto_init) proto_library(net_proto SRCS net_proto.proto DEPS op_proto) -#cc_library(net SRCS net.cc DEPS net_proto attr_type op_proto) +cc_library(net SRCS net.cc DEPS net_proto attr_type op_proto) diff --git a/paddle/framework/net.cc b/paddle/framework/net.cc index 20c0aef049ceb..f0c128d554b29 100644 --- a/paddle/framework/net.cc +++ b/paddle/framework/net.cc @@ -11,7 +11,10 @@ void PlainNet::InferShape(Scope* scope) { } } -void PlainNet::Run(Scope* scope) const {} - +void PlainNet::Run(Scope* scope, DeviceContext* ctx) { + for (auto& op : ops_) { + op.Run(ctx); + } +} } // namespace framework } // namespace paddle diff --git a/paddle/framework/net.h b/paddle/framework/net.h index ef5013349196a..b2894320dafdf 100644 --- a/paddle/framework/net.h +++ b/paddle/framework/net.h @@ -17,9 +17,11 @@ #include "paddle/framework/net_proto.pb.h" #include "paddle/framework/op_proto.pb.h" #include "paddle/framework/scope.h" +#include "paddle/platform/device_context.h" namespace paddle { namespace framework { +using namespace paddle::platform; // operator's index stored in a network. typedef int OpIndex; @@ -30,15 +32,13 @@ typedef int OpIndex; */ struct OpDesc; -struct OpDef; -struct OpContext; struct OpAttrs {}; class Operator { public: Operator(const OpDesc &def) {} void InferShape() {} - void Run() {} + void Run(DeviceContext *ctx) {} }; /** @@ -69,12 +69,12 @@ class Net { * environment for ops. `begin` and `end` specify the scope of `ops_` to run, * If no positive indexes are provided, all operators in `ops_` will run. */ - virtual void Run(Scope *scope) const = 0; + virtual void Run(Scope *scope, DeviceContext *ctx) = 0; /** * @brief Add an Operator according to `def`. */ - virtual OpIndex AddOp(const OpDef &def) = 0; + virtual OpIndex AddOp(const OpProto &def) = 0; /** * @brief Add optimizer operators acctording to `attrs`. @@ -123,12 +123,12 @@ class PlainNet : public Net { * scope will be used instead. If no OpContext is provicded, default context * will be used. */ - virtual void Run(Scope *scope) const override; + virtual void Run(Scope *scope, DeviceContext *ctx) override; /** * @brief Add an operator to this network. */ - virtual OpIndex AddOp(const OpDef &def) override; + virtual OpIndex AddOp(const OpProto &def) override; /** * @brief Add all optimizer operators related into the network. diff --git a/paddle/framework/net_test.cc b/paddle/framework/net_test.cc index 04f5efdf79bdc..a8e31c1497519 100644 --- a/paddle/framework/net_test.cc +++ b/paddle/framework/net_test.cc @@ -19,6 +19,6 @@ namespace paddle { namespace framework { -class FakeFC : public OpBase {} +class FakeFC : public Operator {} } // namespace framework } // namespace paddle diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index fcef0a5e3058f..160eb4e12060b 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -36,6 +36,7 @@ class DeviceContext { class CPUDeviceContext : public DeviceContext {}; #ifndef PADDLE_ONLY_CPU + class GPUPlaceGuard { public: explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) { From b871641a5315b10bfb1d0776e288dd25ef2969d2 Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Tue, 11 Jul 2017 10:53:48 +0800 Subject: [PATCH 13/13] "switch to shared_ptr" --- paddle/framework/net.cc | 2 +- paddle/framework/net.h | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/framework/net.cc b/paddle/framework/net.cc index f0c128d554b29..73b3051235ee9 100644 --- a/paddle/framework/net.cc +++ b/paddle/framework/net.cc @@ -11,7 +11,7 @@ void PlainNet::InferShape(Scope* scope) { } } -void PlainNet::Run(Scope* scope, DeviceContext* ctx) { +void PlainNet::Run(std::shared_ptr scope, DeviceContext* ctx) { for (auto& op : ops_) { op.Run(ctx); } diff --git a/paddle/framework/net.h b/paddle/framework/net.h index b2894320dafdf..76992e0728290 100644 --- a/paddle/framework/net.h +++ b/paddle/framework/net.h @@ -69,7 +69,7 @@ class Net { * environment for ops. `begin` and `end` specify the scope of `ops_` to run, * If no positive indexes are provided, all operators in `ops_` will run. */ - virtual void Run(Scope *scope, DeviceContext *ctx) = 0; + virtual void Run(std::shared_ptr scope, DeviceContext *ctx) = 0; /** * @brief Add an Operator according to `def`. @@ -123,7 +123,7 @@ class PlainNet : public Net { * scope will be used instead. If no OpContext is provicded, default context * will be used. */ - virtual void Run(Scope *scope, DeviceContext *ctx) override; + virtual void Run(std::shared_ptr scope, DeviceContext *ctx) override; /** * @brief Add an operator to this network.