diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 4409c6feae218..e6e3b79d7bd11 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -16,3 +16,6 @@ py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc. # Generate an empty __init__.py to make framework_py_proto as a valid python module. add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) add_dependencies(framework_py_proto framework_py_proto_init) + +proto_library(net_proto SRCS net_proto.proto DEPS op_proto) +cc_library(net SRCS net.cc DEPS net_proto attr_type op_proto) diff --git a/paddle/framework/net.cc b/paddle/framework/net.cc new file mode 100644 index 0000000000000..73b3051235ee9 --- /dev/null +++ b/paddle/framework/net.cc @@ -0,0 +1,20 @@ +#include "paddle/framework/net.h" + +namespace paddle { +namespace framework { + +PlainNet::PlainNet(const NetDesc& def) {} + +void PlainNet::InferShape(Scope* scope) { + for (auto& op : ops_) { + op.InferShape(); + } +} + +void PlainNet::Run(std::shared_ptr scope, DeviceContext* ctx) { + for (auto& op : ops_) { + op.Run(ctx); + } +} +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/net.h b/paddle/framework/net.h new file mode 100644 index 0000000000000..76992e0728290 --- /dev/null +++ b/paddle/framework/net.h @@ -0,0 +1,171 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/framework/net_proto.pb.h" +#include "paddle/framework/op_proto.pb.h" +#include "paddle/framework/scope.h" +#include "paddle/platform/device_context.h" + +namespace paddle { +namespace framework { +using namespace paddle::platform; + +// operator's index stored in a network. +typedef int OpIndex; +/** + * NOTE following codes are some definitions of unimplemented concepts. + * We write some basic implementation to make Net compilable. These APIs will + * keep updating if the concepts related are implemented. + */ + +struct OpDesc; +struct OpAttrs {}; + +class Operator { + public: + Operator(const OpDesc &def) {} + void InferShape() {} + void Run(DeviceContext *ctx) {} +}; + +/** + * @brief Network that manage the operators it has. + * + * Network is the container and controller of a set of operators, user can build + * a real network from a NetDesc which is a protobuf message and use + * Network.Run() * to run all the operators in the network. + + * A network object knows all Operators belonging to this network. Variables, + * which are inputs and outputs of these operators, are created and managed by a + * hierarchy of Scope objects. + * + * This is the base class of network, all the networks should implement the apis + * it defines. + */ +class Net { + public: + /** + * @brief Infer shapes of all inputs and outputs of operators. + */ + virtual void InferShape(Scope *scope) = 0; + /** + * @brief Run the network. + * + * Run all the operators and return success(true) or not, with all the + * variables are located in `scope`. `context` describes the detail execution + * environment for ops. `begin` and `end` specify the scope of `ops_` to run, + * If no positive indexes are provided, all operators in `ops_` will run. + */ + virtual void Run(std::shared_ptr scope, DeviceContext *ctx) = 0; + + /** + * @brief Add an Operator according to `def`. + */ + virtual OpIndex AddOp(const OpProto &def) = 0; + + /** + * @brief Add optimizer operators acctording to `attrs`. + */ + virtual void AddOptimizerOps(const OpAttrs &attrs) = 0; + + /** + * @brief Add backward operators. + */ + virtual void AddBackwardOps() = 0; + + /** + * @brief Create a network. + */ + static std::unique_ptr Create(const NetDesc &def = NetDesc()); + + virtual ~Net() {} +}; + +/** + * @brief a basic implementation of Net. + * + * PlainNet is a very simple Net, it create a list of operators, and run them + * sequentially following the order they added. + */ +class PlainNet : public Net { + public: + /** + * @brief Initialize a PlainNet. + * + * Initialize from a network describe by `def`. NetDesc is the definition of + * a network. + */ + PlainNet(const NetDesc &def); + + /** + * Infer all the operators' input and output varialbes' shapes, will be called + * before every mini-batch + */ + virtual void InferShape(Scope *scope) override; + + /** + * @brief Run the network. + * + * Run all the operators with the `scope`, if no scope is provided, default + * scope will be used instead. If no OpContext is provicded, default context + * will be used. + */ + virtual void Run(std::shared_ptr scope, DeviceContext *ctx) override; + + /** + * @brief Add an operator to this network. + */ + virtual OpIndex AddOp(const OpProto &def) override; + + /** + * @brief Add all optimizer operators related into the network. + */ + virtual void AddOptimizerOps(const OpAttrs &attrs) override; + + /** + * @brief Add all backward operators related into the network. + */ + virtual void AddBackwardOps() override; + + virtual ~PlainNet() override {} + + protected: + /** + * @brief Build the network. + * + * Create operators accordding to `def`, will be called by the constructor. + */ + void BuildNet(const NetDesc &def); + + /** + * @brief Add an operator into this network. + * + * Add a operator which is identified as `type` and has attributes described + * in `attrs`, the `inputs` are the keys of readonly input variables, + * `outputs` are keys of mutable output variables. An `OpIndex` will be + * returned to indicate the offset of the new operator in `ops_`. + */ + OpIndex AddOp(const std::string &type, const std::vector &inputs, + const std::vector &outputs, + const OpAttrs &attrs = OpAttrs()); + + private: + // the operators owned by `Network`. + std::vector ops_; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/net_proto.proto b/paddle/framework/net_proto.proto new file mode 100644 index 0000000000000..0779f49fe2a9a --- /dev/null +++ b/paddle/framework/net_proto.proto @@ -0,0 +1,15 @@ +syntax="proto2"; +package paddle.framework; + +import "op_proto.proto"; + +message NetDesc { + // network identification + optional string name = 1; + // operator contains in network + repeated OpProto operators = 2; + // network type to run with. e.g "plainNet", "DAG" + optional string net_type = 3; + // num worker always + optional int32 num_workers = 4; +} diff --git a/paddle/framework/net_test.cc b/paddle/framework/net_test.cc new file mode 100644 index 0000000000000..a8e31c1497519 --- /dev/null +++ b/paddle/framework/net_test.cc @@ -0,0 +1,24 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/net.h" +#include "paddle/framework/op_registry.h" + +#include + +namespace paddle { +namespace framework { +class FakeFC : public Operator {} +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index 17849ca0191db..ae6b7387129c6 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -119,4 +119,4 @@ TEST(OpRegistry, CustomChecker) { for (size_t i = 0; i < debug_str.length(); ++i) { ASSERT_EQ(debug_str[i], str[i]); } -} \ No newline at end of file +} diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index fcef0a5e3058f..160eb4e12060b 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -36,6 +36,7 @@ class DeviceContext { class CPUDeviceContext : public DeviceContext {}; #ifndef PADDLE_ONLY_CPU + class GPUPlaceGuard { public: explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) {