diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt
index 8415ce67e9039..95638ebcdf8ae 100644
--- a/paddle/framework/CMakeLists.txt
+++ b/paddle/framework/CMakeLists.txt
@@ -19,6 +19,8 @@ py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.
 # Generate an empty __init__.py to make framework_py_proto as a valid python module.
 add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
 add_dependencies(framework_py_proto framework_py_proto_init)
+cc_library(recurrent_network_op SRCS recurrent_network_op.cc DEPS op_desc place)
+cc_test(recurrent_network_op_test SRCS recurrent_network_op_test.cc DEPS recurrent_network_op glog gtest gflags ddim op_desc)
 proto_library(net_proto SRCS net_proto.proto DEPS op_proto)
 cc_library(net SRCS net.cc DEPS net_proto)
diff --git a/paddle/framework/net.h b/paddle/framework/net.h
index 0481d8f47ccbc..e6b94e062fad6 100644
--- a/paddle/framework/net.h
+++ b/paddle/framework/net.h
@@ -15,6 +15,7 @@
 #pragma once
 
 #include "paddle/framework/net_proto.pb.h"
+#include "paddle/framework/op_desc.pb.h"
 #include "paddle/framework/op_proto.pb.h"
 #include "paddle/framework/scope.h"
 #include "paddle/platform/device_context.h"
@@ -31,7 +32,6 @@ typedef int OpIndex;
  * keep updating if the concepts related are implemented.
  */
 
-struct OpDesc;
 struct OpAttrs {};
 
 class Operator {
@@ -74,7 +74,7 @@ class Net {
   /**
    * @brief Add an Operator according to `def`.
    */
-  virtual OpIndex AddOp(const OpProto &def) = 0;
+  virtual OpIndex AddOp(const OpDesc &def) = 0;
 
   /**
    * @brief Add optimizer operators acctording to `attrs`.
@@ -129,7 +129,7 @@ class PlainNet : public Net {
   /**
    * @brief Add an operator to this network.
    */
-  virtual OpIndex AddOp(const OpProto &def) override;
+  virtual OpIndex AddOp(const OpProto &def);
 
   /**
    * @brief Add all optimizer operators related into the network.
diff --git a/paddle/framework/recurrent_network_op.cc b/paddle/framework/recurrent_network_op.cc
new file mode 100644
index 0000000000000..92e7c019a2500
--- /dev/null
+++ b/paddle/framework/recurrent_network_op.cc
@@ -0,0 +1,248 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/framework/recurrent_network_op.h"
+#include "paddle/framework/tensor.h"
+
+#include <glog/logging.h>
+#include <cstring>
+
+namespace paddle {
+namespace framework {
+
+void RecurrentOp::Run(OpContext* contex) const {
+  auto scope = contex->scope;
+
+  PADDLE_ENFORCE(scope->HasVariable(net_name_), "step net is not in scope.");
+  Variable* net = scope->GetVariable(net_name_);
+  PADDLE_ENFORCE(net, "failed to get step net");
+
+  LOG(INFO) << "create scopes";
+  CreateScopes(scope);
+  LOG(INFO) << "segment input";
+  SegmentInputs(scope);
+
+  // forward
+  size_t max_seq_len = GetMaxSeqLen(scope);
+  LOG(INFO) << "sequence length " << max_seq_len;
+  auto step_scopes = GetStepScopes(scope);
+  for (size_t step_id = 0; step_id < max_seq_len; step_id++) {
+    LOG(INFO) << "run step " << step_id;
+    LinkMemories(step_scopes, step_id);
+
+    net->GetMutable<PlainNet>()->Run(step_scopes[step_id]);
+  }
+
+  LOG(INFO) << "concat outputs";
+  // prepare outputs
+  ConcatOutputs(scope);
+}
+
+void RecurrentOp::Init(const OpDesc& op_desc, AttributeMap& attrs) {
+  OperatorBase::Init(op_desc, attrs);
+
+  // set original inputs
+  for (const std::string& input : op_desc.inputs()) {
+    LOG(INFO) << "set input " << input;
+    inputs_.push_back(input);
+  }
+  // set original outputs
+  for (const std::string& output : op_desc.outputs()) {
+    LOG(INFO) << "set output " << output;
+    outputs_.push_back(output);
+  }
+
+  net_name_ = inputs_.at(GetAttr<int>("step_net"));
+  step_scopes_name_ = outputs_.back();
+
+  // prepare inlinks
+  PADDLE_ENFORCE(inlinks_.empty(), "RecurrentOp is initialized twice.");
+  LOG(INFO) << "set inlinks";
+  for (auto id : GetAttr<std::vector<int>>("in_links")) {
+    inlinks_.push_back(inputs_[id]);
+  }
+  auto inlink_alias = GetAttr<std::vector<std::string>>("in_link_alias");
+  in_link_alias_ =
+      std::vector<std::string>{inlink_alias.begin(), inlink_alias.end()};
+  PADDLE_ENFORCE(inlinks_.size() == in_link_alias_.size(),
+                 "in_links/in_link_alias mismatch.");
+
+  PADDLE_ENFORCE(
+      outputs_.size() > 1,
+      "more than 1 output should be provided and the last is `step_scopes`");
+  outlinks_ = std::vector<std::string>{outputs_.begin(), outputs_.end() - 1};
+
+  auto outlink_alias = GetAttr<std::vector<std::string>>("out_link_alias");
+  out_link_alias_ =
+      std::vector<std::string>{outlink_alias.begin(), outlink_alias.end()};
+  PADDLE_ENFORCE(outlinks_.size() == outlink_alias.size(),
+                 "out_links/out_link_alias mismatch.");
+
+  // set memories
+  auto memories = GetAttr<std::vector<std::string>>("memories");
+  auto pre_memories = GetAttr<std::vector<std::string>>("pre_memories");
+  PADDLE_ENFORCE(memories.size() == pre_memories.size(),
+                 "The size of memories and pre_memories doesn't match: %d,%d.",
+                 memories.size(), pre_memories.size());
+
+  std::vector<std::string> boot_memories;
+  LOG(INFO) << "set boot_memories";
+  for (auto id : GetAttr<std::vector<int>>("boot_memories")) {
+    boot_memories.push_back(inputs_[id]);
+  }
+  PADDLE_ENFORCE(memories.size() == boot_memories.size(),
+                 "the size of memories and boot_memories doesn't match: %d,%d",
+                 memories.size(), boot_memories.size());
+  for (size_t i = 0; i < memories.size(); ++i) {
+    details::MemoryAttr mem_attr;
+    mem_attr.var = memories[i];
+    mem_attr.pre_var = pre_memories[i];
+    mem_attr.boot_var = boot_memories[i];
+    memory_attrs_.push_back(mem_attr);
+    LOG(INFO) << "set memories:\t"
+              << "memory:" << mem_attr.var << "\tboot:" << mem_attr.boot_var;
+  }
+}
+
+size_t RecurrentOp::GetMaxSeqLen(ScopePtr scope) const {
+  // TODO update this function when using variable-length sequences.
+  return Input(scope, inlinks_[0])->GetMutable<Tensor>()->dims()[0];
+}
+
+void RecurrentOp::CreateScopes(ScopePtr scope) const {
+  size_t max_seq_len = GetMaxSeqLen(scope);
+  std::vector<ScopePtr>* step_scopes =
+      scope->GetVariable(step_scopes_name_)
+          ->GetMutable<std::vector<ScopePtr>>();
+  // TODO Only two scopes are needed for inference, this case will be
+  // supported later.
+  if (max_seq_len > step_scopes->size()) {
+    for (size_t i = step_scopes->size(); i < max_seq_len; ++i) {
+      step_scopes->push_back(std::make_shared<Scope>(scope));
+    }
+  }
+}
+
+void RecurrentOp::SegmentInputs(ScopePtr scope) const {
+  PADDLE_ENFORCE(!inlinks_.empty(), "no in links are provided.");
+  auto step_scopes = GetStepScopes(scope);
+  size_t max_seq_len = GetMaxSeqLen(scope);
+  for (size_t i = 0; i < inlinks_.size(); ++i) {
+    Tensor* input_tensor = Input(scope, inlinks_[i])->GetMutable<Tensor>();
+    for (size_t j = 0; j < max_seq_len; j++) {
+      Variable* input_var = step_scopes[j]->CreateVariable(in_link_alias_[i]);
+      Tensor* step_input_tensor = input_var->GetMutable<Tensor>();
+      *step_input_tensor = input_tensor->Slice(j, j + 1);
+      // TODO (luotao1): use reshape function to decrease the dims of
+      // step_input_tensor.
+    }
+  }
+}
+
+void RecurrentOp::ConcatOutputs(ScopePtr scope) const {
+  auto step_scopes = GetStepScopes(scope);
+  size_t max_seq_len = GetMaxSeqLen(scope);
+  // TODO (luotao1): update using CopyFrom function in tensor.
+  auto dims = Input(scope, inlinks_[0])->GetMutable<Tensor>()->dims();
+  int batch_size = dims[1];
+  for (size_t i = 0; i < outlinks_.size(); i++) {
+    auto output_dims = step_scopes[0]
+                           ->GetVariable(out_link_alias_[i])
+                           ->GetMutable<Tensor>()
+                           ->dims();
+    int output_dim = output_dims[1];
+    int length = batch_size * output_dim;
+    Tensor* output_tensor =
+        scope->CreateVariable(outlinks_[i])->GetMutable<Tensor>();
+    float* output = output_tensor->mutable_data<float>(
+        make_ddim({(int)max_seq_len, batch_size, output_dim}),
+        platform::CPUPlace());
+    for (size_t j = 0; j < max_seq_len; j++) {
+      Variable* output_var = step_scopes[j]->GetVariable(out_link_alias_[i]);
+      const float* step_output =
+          output_var->GetMutable<Tensor>()->data<float>();
+      std::memcpy(output + j * length, step_output, length * sizeof(float));
+    }
+  }
+}
+
+void RecurrentOp::LinkMemories(std::vector<ScopePtr>& step_scopes,
+                               size_t step_id) const {
+  PADDLE_ENFORCE(step_id < step_scopes.size(),
+                 "step [%d] out of range of step scopes' size [%d]", step_id,
+                 step_scopes.size());
+  ScopePtr step_scope = step_scopes[step_id];
+  for (auto& attr : memory_attrs_) {
+    Tensor* pre_memory_tensor =
+        step_scope->CreateVariable(attr.pre_var)->GetMutable<Tensor>();
+
+    if (step_id == 0) {
+      PADDLE_ENFORCE(step_scope->HasVariable(attr.boot_var),
+                     "memory [%s]'s boot variable [%s] does not exist",
+                     attr.var, attr.boot_var);
+      Tensor* boot_tensor =
+          step_scope->CreateVariable(attr.boot_var)->GetMutable<Tensor>();
+      PADDLE_ENFORCE(boot_tensor, "boot_tensor should be retrieved before");
+      // copy from boot memory
+      pre_memory_tensor->ShareDataFrom(*boot_tensor);
+    } else {
+      // copy from the previous step scope's memory to this scope's `pre-memory`
+      Tensor* pre_step_memory =
+          step_scopes[step_id - 1]->GetVariable(attr.var)->GetMutable<Tensor>();
+      pre_memory_tensor->ShareDataFrom(*pre_step_memory);
+    }
+
+    // TODO the memory of the current step should be allocated in the step net
+    Tensor* cur_memory_tensor =
+        step_scopes[step_id]->CreateVariable(attr.var)->GetMutable<Tensor>();
+    cur_memory_tensor->mutable_data<float>(pre_memory_tensor->dims(),
+                                           platform::CPUPlace());
+  }
+}
+
+// TODO testing when including operator.h
+
+// class RecurrentOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
+//  public:
+//   RecurrentOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
+//       : OpProtoAndCheckerMaker(proto, op_checker) {
+//     // AddInput("input", "input of test op");    // need to support dynamic number
+//     // AddOutput("output", "output of test op"); // need to support dynamic number
+//     AddAttr<std::vector<int>>("in_links", "The input link positions in all the inputs.")
+//         .SetDefault({0});
+//     AddAttr<std::vector<int>>("boot_memories", "The initial memory positions in all the inputs.");
+//     AddAttr<int>("step_net", "The step net position in all the inputs.");
+//
+//     AddAttr<std::vector<std::string>>("in_link_alias", "The input link alias in the step network.");
+//     AddAttr<std::vector<std::string>>("out_link_alias", "The output link alias in the step network.");
+//     AddAttr<std::vector<std::string>>("memories", "The memory names.");
+//     AddAttr<std::vector<std::string>>("pre_memories", "The history/previous memory names.");
+//
+//     AddType("recurrent_op");
+//     AddComment("This is a recurrent group operator.");
+//   }
+// };
+//
+// REGISTER_OP(recurrent_op, RecurrentOp, RecurrentOpProtoAndCheckerMaker);
+
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/framework/recurrent_network_op.h b/paddle/framework/recurrent_network_op.h
new file mode 100644
index 0000000000000..161408482fc92
--- /dev/null
+++ b/paddle/framework/recurrent_network_op.h
@@ -0,0 +1,275 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+
+#include <glog/logging.h>
+#include "paddle/framework/attr_checker.h"
+#include "paddle/framework/ddim.h"
+#include "paddle/framework/enforce.h"
+#include "paddle/framework/scope.h"
+#include "paddle/framework/variable.h"
+
+#include <memory>
+#include "paddle/framework/op_desc.pb.h"
+
+namespace paddle {
+namespace framework {
+
+// --------------------------------------------------------------------
+// fake interfaces that have not been implemented by other modules.
+// TODO keep updating according to other modules' designs.
+typedef std::shared_ptr<Scope> ScopePtr;
+struct OpContext {
+  ScopePtr scope;
+};
+
+class OperatorBase {
+ public:
+  virtual ~OperatorBase() {}
+  void Init(const OpDesc& op_desc, AttributeMap& attrs) { attrs_ = attrs; }
+  virtual void Run(OpContext* context) const = 0;
+  virtual void InferShape(ScopePtr scope) const = 0;
+  inline Variable* Input(ScopePtr scope, std::string name) const {
+    return scope->GetVariable(name);
+  };
+
+  template <typename T>
+  inline const T& GetAttr(const std::string& name) const {
+    PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
+                   name);
+    return boost::get<T>(attrs_.at(name));
+  }
+
+ protected:
+  std::vector<std::string> inputs_;
+  std::vector<std::string> outputs_;
+  AttributeMap attrs_;
+};
+
+struct NetDesc {
+  std::string name_;
+  std::vector<OpDesc> op_descs;
+};
+
+class PlainNet {
+ public:
+  PlainNet() {}
+  PlainNet(const NetDesc& desc) {
+    for (const OpDesc& proto : desc.op_descs) {
+      AddOp(proto);
+    }
+  }
+  // PlainNet(const std::string desc) {}
+  void AddOp(const OpDesc& desc);
+  void Run(ScopePtr scope) {
+    OpContext ctx;
+    ctx.scope = scope;
+    for (auto& op : ops_) {
+      op->Run(&ctx);
+    }
+  }
+
+ private:
+  std::vector<std::unique_ptr<OperatorBase>> ops_;
+};
+
+namespace details {
+
+/*
+ * Memory of a RNN (same as the role of `Memory` in PaddlePaddle).
+ *
+ * Memory attributes cached by this op; dims will be inferred from the
+ * boot memories in the father scope. Other attributes are copied from the
+ * Op's proto attributes.
+ */
+struct MemoryAttr {
+  // name of current state variable
+  std::string var;
+  // name of previous step's state variable
+  std::string pre_var;
+  // name of the variables to init this memory (same role of `boot_layer` in
+  // PaddlePaddle), which is stored in the father's scope.
+  std::string boot_var;
+};
+
+};  // namespace details
+
+// fake interfaces end
+// --------------------------------------------------------------------
+// The sequence format in RecurrentOp is Tensor now.
+// TODO:
+// 1. No-padding computing for sequences with indefinite length in one batch.
+// 2. Hierarchical RNN for sequence with sub-sequence.
+// 3. External Memory.
+// 4. More Complex RNN architecture, such as Gated Feedback RNN.
+//    Refer to: https://arxiv.org/pdf/1502.02367.pdf
+
+/*
+ * RecurrentOp inputs stored in proto:
+ * - in_links : real inputs that need to be segmented to steps.
+ * - boot memories
+ * - all weights in step net
+ * - step net
+ *
+ * outputs:
+ * - out_links : real outputs
+ * - step scopes
+ *
+ * Attributes stored in AttributeMap:
+ * - in_links: vector<int>
+ * - boot_memories: vector<int>
+ * - step_net: int
+ * - in_link_alias: vector<string> the alias of in_links in the step net.
+ * - out_link_alias: vector<string> the alias of out_links in the step net
+ * - memories: vector<string> the memory names
+ * - pre_memories: vector<string> the previous memory names
+ *
+ * see RecurrentOpProtoAndCheckerMaker
+ */
+
+class RecurrentOp : public OperatorBase {
+ public:
+  /*
+   * Initialize the recurrent operator from the operator protobuf
+   * and attributes.
+   */
+  void Init(const OpDesc& op_desc, AttributeMap& attrs);
+
+  virtual void InferShape(ScopePtr scope) const override {}
+
+  /*
+   * Forward run the RNN.
+   *
+   * NOTE the context's scope is not given until `Run` is called, so the step
+   * scopes' father should be set/updated in this method.
+   */
+  virtual void Run(OpContext* contex) const override;
+
+  virtual ~RecurrentOp() {}
+
+ protected:
+  /*
+   * Get the max sequence length of the scope.
+   */
+  size_t GetMaxSeqLen(ScopePtr scope) const;
+
+  /*
+   * Prepare inputs for each stepnet.
+   */
+  void SegmentInputs(ScopePtr scope) const;
+
+  /*
+   * Process outputs of stepnets and merge them into the output variables.
+   */
+  void ConcatOutputs(ScopePtr scope) const;
+
+  /*
+   * Create the step scopes in the father scope. The step scopes will be stored
+   * in the father scope as a variable whose name is specified by
+   * `step_scopes_name_`.
+   *
+   * NOTE the scopes are reused by both `Forward` and `Backward`, so they are
+   * created once and expanded if more steps are needed.
+   */
+  void CreateScopes(ScopePtr scope) const;
+
+  /*
+   * Get the step scopes.
+   */
+  inline const std::vector<ScopePtr>& GetStepScopes(ScopePtr scope) const {
+    return *(scope->GetVariable(step_scopes_name_))
+                ->GetMutable<std::vector<ScopePtr>>();
+  }
+
+  /*
+   * Link memories in the previous step scope to the current scope.
+   */
+  void LinkMemories(std::vector<ScopePtr>& step_scopes, size_t step_id) const;
+
+ private:
+  /*
+   * The attributes in protobuf about the memory description and the initial
+   * memory description are as follows. The number of initial memories should
+   * equal the number of memories.
+   *
+   * arg {
+   *   name: "memories"
+   *   strings: "hidden"
+   *   strings: "state"
+   * }
+   * arg {
+   *   name: "pre_memories"
+   *   strings: "pre_hidden"
+   *   strings: "pre_state"
+   * }
+   * arg {
+   *   name: "boot_memories"
+   *   strings: "boot_hidden"
+   *   strings: "boot_state"
+   * }
+   */
+  mutable std::vector<details::MemoryAttr> memory_attrs_;
+
+  // name of the rnn op's step net; the step net will be shared by both
+  // `Forward` and `Backward`, so we store it as a variable in the father's
+  // scope, with a unique key specified by `net_name_`.
+  std::string net_name_;
+  // name of the steps' scopes, which is stored in the father scope with a
+  // unique key specified by `step_scopes_name_`.
+  std::string step_scopes_name_;
+  // real inputs that need to be segmented.
+  std::vector<std::string> inlinks_;
+  std::vector<std::string> outlinks_;
+
+  std::vector<std::string> in_link_alias_;
+  std::vector<std::string> out_link_alias_;
+};
+
+/*
+ * RNN's backward algorithm.
+ *
+ * To accelerate the development of RecurrentBackwardOp, we decouple the RNN's
+ * algorithm and `RecurrentBackwardAlgorithm`; the former contains the core
+ * implementation of a RNN and will keep stable even if the framework changes a
+ * lot, while the latter is a wrapper that acts like an adapter to make RNN an
+ * operator.
+ */
+class RecurrentBackwardAlgorithm {
+ public:
+ private:
+  // stepnet for backward
+  // NOTE this stepnet is created by others and should insert AddOp for its
+  // weights' gradient updating; RNN backward just runs it.
+  std::string stepnet_name_;
+  // step scopes that are shared by both the forward and backward operators.
+  std::string step_scopes_name_;
+
+  // inputs (gradients of the forward operator's outputs) that need to be
+  // segmented for each step.
+  std::vector<std::string> inlinks_;
+  // outputs (gradients of the forward operator's inputs) of each step that
+  // need to be concatenated.
+  std::vector<std::string> outlinks_;
+
+  // alias to avoid duplicate keys in scopes.
+  std::vector<std::string> inlink_alias_;
+  std::vector<std::string> outlink_alias_;
+
+  // NOTE the first step's boot memories' gradients should be output.
+  std::vector<std::string> memories_;
+};
+
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/framework/recurrent_network_op_test.cc b/paddle/framework/recurrent_network_op_test.cc
new file mode 100644
index 0000000000000..ce65235c1e40b
--- /dev/null
+++ b/paddle/framework/recurrent_network_op_test.cc
@@ -0,0 +1,245 @@
+/*
+  Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+  Licensed under the Apache License, Version 2.0 (the "License");
+  you may not use this file except in compliance with the License.
+  You may obtain a copy of the License at
+  http://www.apache.org/licenses/LICENSE-2.0
+  Unless required by applicable law or agreed to in writing, software
+  distributed under the License is distributed on an "AS IS" BASIS,
+  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  See the License for the specific language governing permissions and
+  limitations under the License.
+*/
+
+#include <glog/logging.h>
+#include <gtest/gtest.h>
+
+#include "paddle/framework/recurrent_network_op.h"
+#include "paddle/framework/tensor.h"
+
+namespace paddle {
+namespace framework {
+
+// fake op implementations
+namespace fake {
+class FcOp : public OperatorBase {
+ public:
+  FcOp(const OpDesc& desc) {}
+
+  virtual void InferShape(ScopePtr scope) const override {
+    for (const auto& output : outputs_) {
+      LOG(INFO) << "fc [" << name_ << "]"
+                << " create output variable [" << output << "]";
+      scope->CreateVariable(output);
+    }
+  }
+
+  virtual void Run(OpContext* contex) const override {
+    LOG(INFO) << "run fc op";
+    for (const auto& input : inputs_) {
+      PADDLE_ENFORCE(contex->scope->HasVariable(input),
+                     "no input variable [%s] exists");
+      LOG(INFO) << "fc [" << name_ << "] read input [" << input << "]";
+    }
+    for (const auto& output : outputs_) {
+      PADDLE_ENFORCE(contex->scope->HasVariable(output),
+                     "no output variable [%s] exists");
+      LOG(INFO) << "fc [" << name_ << "] write output [" << output << "]";
+    }
+  }
+
+ private:
+  std::string name_;
+};
+
+class AddOp : public OperatorBase {
+ public:
+  AddOp(const OpDesc& desc) {}
+
+  virtual void InferShape(ScopePtr scope) const override {
+    for (const auto& output : outputs_) {
+      LOG(INFO) << "add [" << name_ << "]"
+                << " create output variable [" << output << "]";
+      scope->CreateVariable(output);
+    }
+  }
+
+  virtual void Run(OpContext* contex) const override {
+    LOG(INFO) << "run add op";
+    for (const auto& input : inputs_) {
+      PADDLE_ENFORCE(contex->scope->HasVariable(input),
+                     "no input variable [%s] exists");
+      LOG(INFO) << "add [" << name_ << "] read input [" << input << "]";
+    }
+    for (const auto& output : outputs_) {
+      PADDLE_ENFORCE(contex->scope->HasVariable(output),
+                     "no output variable [%s] exists");
+      LOG(INFO) << "add [" << name_ << "] write output [" << output << "]";
+    }
+  }
+
+ private:
+  std::string name_;
+};
+}  // namespace fake
+
+void PlainNet::AddOp(const OpDesc& desc) {
+  if (desc.type() == "fc") {
+    ops_.emplace_back(new fake::FcOp(desc));
+  } else if (desc.type() == "add") {
+    ops_.emplace_back(new fake::AddOp(desc));
+  }
+}
+
+class RecurrentOpTest : public ::testing::Test {
+ protected:
+  virtual void SetUp() override {
+    CreateGlobalVariables();
+    CreateStepNet();
+    CreateRNNOp();
+  }
+
+  virtual void TearDown() override {}
+
+  void CreateGlobalVariables() {
+    scope_ = std::make_shared<Scope>();
+    LOG(INFO) << "create global variable h_boot";
+    // create boot memory
+    scope_->CreateVariable("h_boot");
+    // create input, and init content
+    LOG(INFO) << "create global variable x";
+    Variable* x = scope_->CreateVariable("x");
+    DDim dims = make_ddim(std::vector<int>{10 /*sent size*/, 20 /*batch size*/,
+                                           30 /*input dim*/});
+    x->GetMutable<Tensor>()->mutable_data<float>(dims, platform::CPUPlace());
+
+    LOG(INFO) << "create global variable w";
+    Variable* w = scope_->CreateVariable("rnn/w");
+    w->GetMutable<Tensor>()->mutable_data<float>(
+        make_ddim(std::vector<int>{30, 30}), platform::CPUPlace());
+
+    LOG(INFO) << "create global variable h_boot";
+    Variable* h_boot = scope_->CreateVariable("h_boot");
+    h_boot->GetMutable<Tensor>()->mutable_data<float>(
+        make_ddim(std::vector<int>{20 /*batch size*/, 30 /*input dim*/}),
+        platform::CPUPlace());
+
+    LOG(INFO) << "create variable step_scopes";
+    scope_->CreateVariable("step_scopes");
+
+    LOG(INFO) << "create variable h";
+    scope_->CreateVariable("h");
+  }
+
+  void CreateRNNOp() {
+    OpDesc op_desc;
+
+    op_desc.set_type("rnn_op");
+    op_desc.add_inputs("x");
+    op_desc.add_inputs("h_boot");    // initial memory
+    op_desc.add_inputs("step_net");  // step net
+    // output hidden vectors
+    op_desc.add_outputs("h");
+    op_desc.add_outputs("step_scopes");  // step scopes
+
+    // add real input
+    auto input_attr = op_desc.mutable_attrs()->Add();
+    input_attr->set_type(paddle::framework::AttrType::INTS);
+    *input_attr->mutable_ints()->Add() = 0;
+    input_attr->set_name("in_links");
+
+    // add input alias, this alias is used in step net.
+    auto input_alias_attr = op_desc.mutable_attrs()->Add();
+    input_alias_attr->set_type(paddle::framework::AttrType::STRINGS);
+    *input_alias_attr->mutable_strings()->Add() = "rnn/x";
+    input_alias_attr->set_name("in_link_alias");
+
+    // add output alias, this alias is used in step net.
+    auto output_alias_attr = op_desc.mutable_attrs()->Add();
+    output_alias_attr->set_type(paddle::framework::AttrType::STRINGS);
+    *output_alias_attr->mutable_strings()->Add() = "rnn/h";
+    output_alias_attr->set_name("out_link_alias");
+
+    // add memories
+    auto memories_attr = op_desc.mutable_attrs()->Add();
+    memories_attr->set_type(paddle::framework::AttrType::STRINGS);
+    *memories_attr->mutable_strings()->Add() = "rnn/h";
+    memories_attr->set_name("memories");
+
+    // add history/previous memories
+    auto pre_memories_attr = op_desc.mutable_attrs()->Add();
+    pre_memories_attr->set_type(paddle::framework::AttrType::STRINGS);
+    *pre_memories_attr->mutable_strings()->Add() = "rnn/h_pre";
+    pre_memories_attr->set_name("pre_memories");
+
+    // add initial memories
+    auto boot_memories_attr = op_desc.mutable_attrs()->Add();
+    boot_memories_attr->set_type(paddle::framework::AttrType::INTS);
+    *boot_memories_attr->mutable_ints()->Add() = 1;
+    boot_memories_attr->set_name("boot_memories");
+
+    // add step net desc
+    auto step_net_attr = op_desc.mutable_attrs()->Add();
+    step_net_attr->set_type(paddle::framework::AttrType::INT);
+    step_net_attr->set_i(2);
+    step_net_attr->set_name("step_net");
+
+    AttributeMap attrs;
+    attrs["in_links"] = std::vector<int>{0};
+    attrs["in_link_alias"] = std::vector<std::string>{"rnn/x"};
+    attrs["out_link_alias"] = std::vector<std::string>{"rnn/h"};
+    attrs["memories"] = std::vector<std::string>{"rnn/h"};
+    attrs["pre_memories"] = std::vector<std::string>{"h_pre"};
+    attrs["boot_memories"] = std::vector<int>{1};
+    attrs["step_net"] = 2;
+
+    LOG(INFO) << "rnn_op to init";
+    rnn_op_.Init(op_desc, attrs);
+    LOG(INFO) << "rnn_op finish init";
+  }
+
+  OpDesc CreateFcOpDesc() {
+    OpDesc op_desc;
+    op_desc.set_type("fc");
+    op_desc.add_inputs("rnn/h_pre");
+    op_desc.add_inputs("rnn/w");
+    op_desc.add_outputs("rnn/s");
+    // rnn/s = rnn/h_pre * rnn/w
+    return op_desc;
+  }
+
+  OpDesc CreateAddOpDesc() {
+    OpDesc op_desc;
+    op_desc.set_type("add");
+    op_desc.add_inputs("rnn/x");
+    op_desc.add_inputs("rnn/s");
+    op_desc.add_outputs("rnn/h");
+    // rnn/h = rnn/x + rnn/s
+    return op_desc;
+  }
+
+  void CreateStepNet() {
+    LOG(INFO) << "create variable step_net";
+    Variable* net_var = scope_->CreateVariable("step_net");
+    NetDesc net_desc;
+    net_desc.name_ = "rnn";
+    net_desc.op_descs.push_back(CreateFcOpDesc());
+    net_desc.op_descs.push_back(CreateAddOpDesc());
+    net_var->Reset(new PlainNet(net_desc));
+  }
+
+  // father scope
+  std::shared_ptr<Scope> scope_;
+  RecurrentOp rnn_op_;
+};
+
+// TEST_F(RecurrentOpTest, create_op) {}
+
+TEST_F(RecurrentOpTest, Run) {
+  OpContext ctx;
+  ctx.scope = scope_;
+  rnn_op_.Run(&ctx);
+}
+
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h
index 62e0710a8244c..8756c5d33c68d 100644
--- a/paddle/framework/tensor.h
+++ b/paddle/framework/tensor.h
@@ -29,8 +29,6 @@ class Tensor {
  public:
   Tensor() : numel_(0), offset_(0) {}
 
-  Tensor& operator=(const Tensor& src) = delete;
-
   template <typename T>
   const T* data() const {
     CheckDims<T>();
@@ -130,7 +128,8 @@ class Tensor {
    public:
     Deleter(platform::Place place) : place_(place) {}
     void operator()(T* ptr) {
-      paddle::memory::Free(place_, static_cast<void*>(ptr));
+      // paddle::memory::Free(place_, static_cast<void*>(ptr));
+      free(static_cast<void*>(ptr));
     }
 
    private:
@@ -138,9 +137,12 @@ class Tensor {
   };
 
  public:
+  // PlaceholderImpl(paddle::platform::Place place, size_t size)
+  //     : ptr_(static_cast<T*>(paddle::memory::Alloc(place, size)),
+  //            Deleter(place)),
+
   PlaceholderImpl(paddle::platform::Place place, size_t size)
-      : ptr_(static_cast<T*>(paddle::memory::Alloc(place, size)),
-             Deleter(place)),
+      : ptr_(static_cast<T*>(malloc(size * sizeof(T))), Deleter(place)),
        place_(place),
        size_(size) {}
diff --git a/paddle/framework/variable.h b/paddle/framework/variable.h
index 72c4a7a2a1d1c..adc00f5492fd4 100644
--- a/paddle/framework/variable.h
+++ b/paddle/framework/variable.h
@@ -29,6 +29,11 @@ class Variable {
     return *static_cast<const T*>(holder_->Ptr());
   }
 
+  template <typename T>
+  void Reset(T* p) {
+    holder_.reset(new PlaceholderImpl<T>(p));
+  }
+
   template <typename T>
   T* GetMutable() {
     if (!IsType<T>()) {
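
The test parks the step net in the father scope through the new `Variable::Reset` added in variable.h, and `RecurrentOp::Run` later looks it up by `net_name_` and runs it once per time step. Below is a minimal illustrative sketch of that wiring, not part of the patch; `WireStepNet` is a hypothetical helper, and it uses only the `Scope`, `Variable`, `PlainNet`, and `NetDesc` interfaces declared in recurrent_network_op.h, mirroring `RecurrentOpTest::CreateStepNet` above.

    #include "paddle/framework/recurrent_network_op.h"

    namespace paddle {
    namespace framework {

    // Store the step net in the father scope under the name the RecurrentOp was
    // configured with ("step_net" in the test). RecurrentOp::Run later resolves
    // it via GetVariable(net_name_)->GetMutable<PlainNet>() and runs it once per
    // time step against that step's scope.
    void WireStepNet(ScopePtr scope, const NetDesc& step_net_desc) {
      Variable* net_var = scope->CreateVariable("step_net");
      net_var->Reset(new PlainNet(step_net_desc));
    }

    }  // namespace framework
    }  // namespace paddle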