From aede869805d67b9869912eacaad0c2b090f9508f Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Fri, 14 Jul 2017 17:59:02 +0800 Subject: [PATCH] refine LinkMemories --- paddle/framework/recurrent_network_op.cc | 60 +++++++++--------------- paddle/framework/recurrent_network_op.h | 5 +- 2 files changed, 24 insertions(+), 41 deletions(-) diff --git a/paddle/framework/recurrent_network_op.cc b/paddle/framework/recurrent_network_op.cc index 52fb869663308..316d5deeea503 100644 --- a/paddle/framework/recurrent_network_op.cc +++ b/paddle/framework/recurrent_network_op.cc @@ -40,11 +40,10 @@ void RecurrentOp::Run(OpContext* contex) const { auto step_scopes = GetStepScopes(scope); for (size_t step_id = 0; step_id < max_seq_len; step_id++) { LOG(INFO) << "run step " << step_id; - ScopePtr step_scope = step_scopes[step_id]; // TODO replace memorys' copy with reference - LinkMemories(scope, step_scopes, step_id); + LinkMemories(step_scopes, step_id); - net->GetMutable()->Run(step_scope); + net->GetMutable()->Run(step_scopes[step_id]); } LOG(INFO) << "concat outputs"; @@ -177,51 +176,38 @@ void RecurrentOp::ConcatOutputs(ScopePtr scope) const { } } -void RecurrentOp::LinkMemories(ScopePtr scope, - std::vector& step_scopes, - size_t step) const { - PADDLE_ENFORCE(step < step_scopes.size(), - "step [%d] out of range of step scopes' size [%d]", step, +void RecurrentOp::LinkMemories(std::vector& step_scopes, + size_t step_id) const { + PADDLE_ENFORCE(step_id < step_scopes.size(), + "step [%d] out of range of step scopes' size [%d]", step_id, step_scopes.size()); - auto step_scope = step_scopes[step]; - // copy boot memory + ScopePtr step_scope = step_scopes[step_id]; for (auto& attr : memory_attrs_) { - Tensor* boot_tensor{nullptr}; - if (step == 0) { - PADDLE_ENFORCE(scope->HasVariable(attr.boot_var), + Tensor* pre_memory_tensor = + step_scope->CreateVariable(attr.pre_var)->GetMutable(); + + if (step_id == 0) { + PADDLE_ENFORCE(step_scope->HasVariable(attr.boot_var), "memory [%s]'s boot variable [%s] not exists", attr.var, attr.boot_var); - // update memory's ddim - boot_tensor = scope->CreateVariable(attr.boot_var)->GetMutable(); - attr.dims = boot_tensor->dims(); - } - Variable* memory_var = step_scope->CreateVariable(attr.pre_var); - - // TODO the memory of current step should be allocaled in step net ? - Tensor* cur_memory = - step_scopes[step]->CreateVariable(attr.var)->GetMutable(); - cur_memory->mutable_data(attr.dims, platform::CPUPlace()); - - // copy from boot memory - // TODO support more device - // TODO mutable_data is currently invalid - float* memory_tensor_val = - memory_var->GetMutable()->mutable_data( - attr.dims, platform::CPUPlace()); - if (step == 0) { + Tensor* boot_tensor = + step_scope->CreateVariable(attr.boot_var)->GetMutable(); PADDLE_ENFORCE(boot_tensor, "boot_tensor should be retrieved before"); // copy from boot memory - std::memcpy(memory_tensor_val, boot_tensor->data(), - product(attr.dims)); + pre_memory_tensor->ShareDataFrom(*boot_tensor); } else { // copy from previous step scope's memory to this scope's // `pre - memory` Tensor* pre_step_memory = - step_scopes[step - 1]->GetVariable(attr.var)->GetMutable(); - - std::memcpy(memory_tensor_val, pre_step_memory->data(), - product(attr.dims)); + step_scopes[step_id - 1]->GetVariable(attr.var)->GetMutable(); + pre_memory_tensor->ShareDataFrom(*pre_step_memory); } + + // TODO the memory of current step should be allocated in step net ? + Tensor* cur_memory_tensor = + step_scopes[step_id]->CreateVariable(attr.var)->GetMutable(); + cur_memory_tensor->mutable_data(pre_memory_tensor->dims(), + platform::CPUPlace()); } } diff --git a/paddle/framework/recurrent_network_op.h b/paddle/framework/recurrent_network_op.h index 2f62f365e42b8..857bb3164d638 100644 --- a/paddle/framework/recurrent_network_op.h +++ b/paddle/framework/recurrent_network_op.h @@ -169,8 +169,7 @@ class RecurrentOp : public OperatorBase { /* * Link memory in previous step scope to current scope. */ - void LinkMemories(ScopePtr scope, std::vector& step_scopes, - size_t step) const; + void LinkMemories(std::vector& step_scopes, size_t step_id) const; private: /* @@ -188,8 +187,6 @@ class RecurrentOp : public OperatorBase { // name of the variables to init this memory (same role of `boot_layer` in // PaddlePaddle), which is store in father's scope. std::string boot_var; - // this dim will infered from boot memories's tensor in the first step. - DDim dims; }; /*