From 3a27b0200ff7a88a30aef1f86e7211d2a4f34640 Mon Sep 17 00:00:00 2001
From: Luo Tao
Date: Fri, 14 Jul 2017 16:30:44 +0800
Subject: [PATCH] Abstract GetStepScopes and GetMaxSeqLen function

---
 paddle/framework/recurrent_network_op.cc | 53 +++++++++++++-----------
 paddle/framework/recurrent_network_op.h  | 13 ++++++
 2 files changed, 41 insertions(+), 25 deletions(-)

diff --git a/paddle/framework/recurrent_network_op.cc b/paddle/framework/recurrent_network_op.cc
index ae4f8000d2f71..52fb869663308 100644
--- a/paddle/framework/recurrent_network_op.cc
+++ b/paddle/framework/recurrent_network_op.cc
@@ -34,18 +34,15 @@ void RecurrentOp::Run(OpContext* contex) const {
   LOG(INFO) << "segment input";
   SegmentInputs(scope);
 
-  Variable* step_scopes = scope->GetVariable(step_scopes_name_);
-  PADDLE_ENFORCE(step_scopes, "failed to get step scopes");
   // forward
-  auto dims = Input(scope, inlinks_[0])->GetMutable<Tensor>()->dims();
-  size_t seq_len = dims[0];
-  LOG(INFO) << "sequence length " << seq_len;
-  auto& scopes = *step_scopes->GetMutable<std::vector<ScopePtr>>();
-  for (size_t step_id = 0; step_id < seq_len; step_id++) {
+  size_t max_seq_len = GetMaxSeqLen(scope);
+  LOG(INFO) << "sequence length " << max_seq_len;
+  auto step_scopes = GetStepScopes(scope);
+  for (size_t step_id = 0; step_id < max_seq_len; step_id++) {
     LOG(INFO) << "run step " << step_id;
-    ScopePtr step_scope = scopes[step_id];
+    ScopePtr step_scope = step_scopes[step_id];
     // TODO replace memorys' copy with reference
-    LinkMemories(scope, scopes, step_id);
+    LinkMemories(scope, step_scopes, step_id);
 
     net->GetMutable<PlainNet>()->Run(step_scope);
   }
@@ -109,15 +106,20 @@ void RecurrentOp::Init(const OpDesc& op_desc, AttributeMap& attrs) {
   }
 }
 
+size_t RecurrentOp::GetMaxSeqLen(ScopePtr scope) const {
+  // TODO update this function when using variable-length of sequence.
+  return Input(scope, inlinks_[0])->GetMutable<Tensor>()->dims()[0];
+}
+
 void RecurrentOp::CreateScopes(ScopePtr scope) const {
-  auto dims = Input(scope, inlinks_[0])->GetMutable<Tensor>()->dims();
-  size_t seq_len = dims[0];
-  Variable* scopes_var = scope->GetVariable(step_scopes_name_);
-  auto step_scopes = scopes_var->GetMutable<std::vector<ScopePtr>>();
+  size_t max_seq_len = GetMaxSeqLen(scope);
+  std::vector<ScopePtr>* step_scopes =
+      scope->GetVariable(step_scopes_name_)
+          ->GetMutable<std::vector<ScopePtr>>();
   // TODO Only two scopes are needed for inference, this case will be
   // supported later.
-  if (seq_len > step_scopes->size()) {
-    for (size_t i = step_scopes->size(); i < seq_len; ++i) {
+  if (max_seq_len > step_scopes->size()) {
+    for (size_t i = step_scopes->size(); i < max_seq_len; ++i) {
       step_scopes->push_back(std::make_shared<Scope>(scope));
     }
   }
@@ -129,17 +131,17 @@ void RecurrentOp::SegmentInputs(ScopePtr scope) const {
   PADDLE_ENFORCE(inlinks_.size() == inlink_alias.size(),
                  "in_links/in_link_alias mismatch.");
 
-  Variable* scopes_var = scope->GetVariable(step_scopes_name_);
-  auto& step_scopes = *scopes_var->GetMutable<std::vector<ScopePtr>>();
-  auto dims = Input(scope, inlinks_[0])->GetMutable<Tensor>()->dims();
-  int seq_len = dims[0];
+  auto step_scopes = GetStepScopes(scope);
+  size_t max_seq_len = GetMaxSeqLen(scope);
   for (size_t i = 0; i < inlinks_.size(); ++i) {
     Tensor* scope_input_tensor =
         Input(scope, inlinks_[i])->GetMutable<Tensor>();
-    for (int j = 0; j < seq_len; j++) {
+    for (size_t j = 0; j < max_seq_len; j++) {
       Variable* input_var = step_scopes[j]->CreateVariable(inlink_alias[i]);
       Tensor* step_input_tensor = input_var->GetMutable<Tensor>();
       *step_input_tensor = scope_input_tensor->Slice<float>(j, j + 1);
+      // TODO (luotao1): use reshape function to decrease the dims of
+      // step_input_tensor.
     }
   }
 }
@@ -149,10 +151,10 @@ void RecurrentOp::ConcatOutputs(ScopePtr scope) const {
   PADDLE_ENFORCE(outlinks_.size() == outlink_alias.size(),
                  "out_links/out_link_alias mismatch.");
 
-  Variable* scopes_var = scope->GetVariable(step_scopes_name_);
-  auto& step_scopes = *scopes_var->GetMutable<std::vector<ScopePtr>>();
+  auto step_scopes = GetStepScopes(scope);
+  size_t max_seq_len = GetMaxSeqLen(scope);
+  // TODO (luotao1): update using CopyFrom function in tensor.
   auto dims = Input(scope, inlinks_[0])->GetMutable<Tensor>()->dims();
-  int seq_len = dims[0];
   int batch_size = dims[1];
   for (size_t i = 0; i < outlinks_.size(); i++) {
     auto output_dims = step_scopes[0]
@@ -164,8 +166,9 @@ void RecurrentOp::ConcatOutputs(ScopePtr scope) const {
     Tensor* output_tensor =
         scope->CreateVariable(outlinks_[i])->GetMutable<Tensor>();
     float* output = output_tensor->mutable_data<float>(
-        make_ddim({seq_len, batch_size, output_dim}), platform::CPUPlace());
-    for (int j = 0; j < seq_len; j++) {
+        make_ddim({(int)max_seq_len, batch_size, output_dim}),
+        platform::CPUPlace());
+    for (size_t j = 0; j < max_seq_len; j++) {
       Variable* output_var = step_scopes[j]->GetVariable(outlink_alias[i]);
       const float* step_output =
           output_var->GetMutable<Tensor>()->data<float>();
diff --git a/paddle/framework/recurrent_network_op.h b/paddle/framework/recurrent_network_op.h
index b81ed49e7f3e7..2f62f365e42b8 100644
--- a/paddle/framework/recurrent_network_op.h
+++ b/paddle/framework/recurrent_network_op.h
@@ -133,6 +133,11 @@ class RecurrentOp : public OperatorBase {
   virtual ~RecurrentOp() {}
 
  protected:
+  /*
+   * Get the max sequence length of the scope.
+   */
+  size_t GetMaxSeqLen(ScopePtr scope) const;
+
   /*
    * Prepare inputs for each stepnet.
    */
@@ -153,6 +158,14 @@ class RecurrentOp : public OperatorBase {
    */
   void CreateScopes(ScopePtr scope) const;
 
+  /*
+   * Get the step scopes.
+   */
+  inline const std::vector<ScopePtr>& GetStepScopes(ScopePtr scope) const {
+    return *(scope->GetVariable(step_scopes_name_))
+                ->GetMutable<std::vector<ScopePtr>>();
+  }
+
   /*
    * Link memory in previous step scope to current scope.
    */
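
Note for reviewers: the patch replaces several copies of "read dims()[0] of the first inlink" and "fetch the step-scope vector from the scope" with the two shared helpers GetMaxSeqLen and GetStepScopes. Below is a minimal, self-contained C++ sketch of that pattern, not the real paddle::framework code; Scope, RecurrentOpSketch, and every member name in it are simplified stand-ins invented for illustration.

#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Simplified stand-in for a scope that can cache a vector of child scopes
// under a name; the real Scope/Variable machinery is elided.
struct Scope {
  std::map<std::string, std::vector<std::shared_ptr<Scope>>> vars;
};
using ScopePtr = std::shared_ptr<Scope>;

class RecurrentOpSketch {
 public:
  explicit RecurrentOpSketch(std::vector<int> input_dims)
      : input_dims_(std::move(input_dims)) {}

  void Run(ScopePtr scope) const {
    CreateScopes(scope);
    // Run and CreateScopes now consult the same two accessors instead of
    // each re-deriving the length and re-fetching the scope vector.
    size_t max_seq_len = GetMaxSeqLen(scope);
    const auto& step_scopes = GetStepScopes(scope);
    for (size_t step_id = 0; step_id < max_seq_len; step_id++) {
      std::cout << "run step " << step_id << " in scope "
                << step_scopes[step_id].get() << "\n";
    }
  }

 protected:
  // Analogue of GetMaxSeqLen: the sequence length is dim 0 of the input.
  size_t GetMaxSeqLen(ScopePtr) const {
    return static_cast<size_t>(input_dims_[0]);
  }

  // Analogue of GetStepScopes: return the cached per-step scope vector.
  const std::vector<ScopePtr>& GetStepScopes(ScopePtr scope) const {
    return scope->vars[step_scopes_name_];
  }

  // Analogue of CreateScopes: grow the step-scope vector lazily so that
  // repeated runs reuse already-created scopes.
  void CreateScopes(ScopePtr scope) const {
    auto& step_scopes = scope->vars[step_scopes_name_];
    for (size_t i = step_scopes.size(); i < GetMaxSeqLen(scope); ++i) {
      step_scopes.push_back(std::make_shared<Scope>());
    }
  }

 private:
  std::string step_scopes_name_ = "step_scopes";
  std::vector<int> input_dims_;  // {seq_len, batch_size, dim}
};

int main() {
  ScopePtr scope = std::make_shared<Scope>();
  RecurrentOpSketch op({3 /*seq_len*/, 2 /*batch*/, 4 /*dim*/});
  op.Run(scope);  // prints three "run step" lines
}

Compiled standalone with any C++11 compiler, the sketch prints one "run step" line per time step. The design point matches the patch: once all call sites go through the shared accessors, the variable-length TODO in GetMaxSeqLen only has to be resolved in one place.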