Abstract GetStepScopes and GetMaxSeqLen function
luotao1 committed Jul 14, 2017
1 parent f525390 commit 3a27b02
Showing 2 changed files with 41 additions and 25 deletions.
53 changes: 28 additions & 25 deletions paddle/framework/recurrent_network_op.cc
@@ -34,18 +34,15 @@ void RecurrentOp::Run(OpContext* contex) const {
   LOG(INFO) << "segment input";
   SegmentInputs(scope);
 
-  Variable* step_scopes = scope->GetVariable(step_scopes_name_);
-  PADDLE_ENFORCE(step_scopes, "failed to get step scopes");
   // forward
-  auto dims = Input(scope, inlinks_[0])->GetMutable<Tensor>()->dims();
-  size_t seq_len = dims[0];
-  LOG(INFO) << "sequence length " << seq_len;
-  auto& scopes = *step_scopes->GetMutable<std::vector<ScopePtr>>();
-  for (size_t step_id = 0; step_id < seq_len; step_id++) {
+  size_t max_seq_len = GetMaxSeqLen(scope);
+  LOG(INFO) << "sequence length " << max_seq_len;
+  auto step_scopes = GetStepScopes(scope);
+  for (size_t step_id = 0; step_id < max_seq_len; step_id++) {
     LOG(INFO) << "run step " << step_id;
-    ScopePtr step_scope = scopes[step_id];
+    ScopePtr step_scope = step_scopes[step_id];
     // TODO replace memorys' copy with reference
-    LinkMemories(scope, scopes, step_id);
+    LinkMemories(scope, step_scopes, step_id);
 
     net->GetMutable<PlainNet>()->Run(step_scope);
   }
@@ -109,15 +106,20 @@ void RecurrentOp::Init(const OpDesc& op_desc, AttributeMap& attrs) {
   }
 }
 
+size_t RecurrentOp::GetMaxSeqLen(ScopePtr scope) const {
+  // TODO update this function when using variable-length of sequence.
+  return Input(scope, inlinks_[0])->GetMutable<Tensor>()->dims()[0];
+}
+
 void RecurrentOp::CreateScopes(ScopePtr scope) const {
-  auto dims = Input(scope, inlinks_[0])->GetMutable<Tensor>()->dims();
-  size_t seq_len = dims[0];
-  Variable* scopes_var = scope->GetVariable(step_scopes_name_);
-  auto step_scopes = scopes_var->GetMutable<std::vector<ScopePtr>>();
+  size_t max_seq_len = GetMaxSeqLen(scope);
+  std::vector<ScopePtr>* step_scopes =
+      scope->GetVariable(step_scopes_name_)
+          ->GetMutable<std::vector<ScopePtr>>();
   // TODO Only two scopes are needed for inference, this case will be
   // supported later.
-  if (seq_len > step_scopes->size()) {
-    for (size_t i = step_scopes->size(); i < seq_len; ++i) {
+  if (max_seq_len > step_scopes->size()) {
+    for (size_t i = step_scopes->size(); i < max_seq_len; ++i) {
       step_scopes->push_back(std::make_shared<Scope>(scope));
     }
   }
@@ -129,17 +131,17 @@ void RecurrentOp::SegmentInputs(ScopePtr scope) const {
   PADDLE_ENFORCE(inlinks_.size() == inlink_alias.size(),
                  "in_links/in_link_alias mismatch.");
 
-  Variable* scopes_var = scope->GetVariable(step_scopes_name_);
-  auto& step_scopes = *scopes_var->GetMutable<std::vector<ScopePtr>>();
-  auto dims = Input(scope, inlinks_[0])->GetMutable<Tensor>()->dims();
-  int seq_len = dims[0];
+  auto step_scopes = GetStepScopes(scope);
+  size_t max_seq_len = GetMaxSeqLen(scope);
   for (size_t i = 0; i < inlinks_.size(); ++i) {
     Tensor* scope_input_tensor =
         Input(scope, inlinks_[i])->GetMutable<Tensor>();
-    for (int j = 0; j < seq_len; j++) {
+    for (size_t j = 0; j < max_seq_len; j++) {
       Variable* input_var = step_scopes[j]->CreateVariable(inlink_alias[i]);
       Tensor* step_input_tensor = input_var->GetMutable<Tensor>();
       *step_input_tensor = scope_input_tensor->Slice(j, j + 1);
+      // TODO (luotao1): use reshape function to decrease the dims of
+      // step_input_tensor.
     }
   }
 }
@@ -149,10 +151,10 @@ void RecurrentOp::ConcatOutputs(ScopePtr scope) const {
   PADDLE_ENFORCE(outlinks_.size() == outlink_alias.size(),
                  "out_links/out_link_alias mismatch.");
 
-  Variable* scopes_var = scope->GetVariable(step_scopes_name_);
-  auto& step_scopes = *scopes_var->GetMutable<std::vector<ScopePtr>>();
+  auto step_scopes = GetStepScopes(scope);
+  size_t max_seq_len = GetMaxSeqLen(scope);
+  // TODO (luotao1): update using CopyFrom function in tensor.
   auto dims = Input(scope, inlinks_[0])->GetMutable<Tensor>()->dims();
-  int seq_len = dims[0];
   int batch_size = dims[1];
   for (size_t i = 0; i < outlinks_.size(); i++) {
     auto output_dims = step_scopes[0]
@@ -164,8 +166,9 @@ void RecurrentOp::ConcatOutputs(ScopePtr scope) const {
     Tensor* output_tensor =
         scope->CreateVariable(outlinks_[i])->GetMutable<Tensor>();
     float* output = output_tensor->mutable_data<float>(
-        make_ddim({seq_len, batch_size, output_dim}), platform::CPUPlace());
-    for (int j = 0; j < seq_len; j++) {
+        make_ddim({(int)max_seq_len, batch_size, output_dim}),
+        platform::CPUPlace());
+    for (size_t j = 0; j < max_seq_len; j++) {
       Variable* output_var = step_scopes[j]->GetVariable(outlink_alias[i]);
       const float* step_output =
           output_var->GetMutable<Tensor>()->data<float>();
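
Taken together, the .cc changes replace two repeated inline lookups, the Input(scope, inlinks_[0])->GetMutable<Tensor>()->dims()[0] expression for the sequence length and the step_scopes_name_ variable fetch for the scope list, with named accessors. Below is a self-contained sketch of the resulting shape of Run(); the Scope struct and RecurrentOpSketch class are simplified stand-ins for illustration, not Paddle's real API.

// Sketch of the accessor-extraction pattern in this commit. All types here
// are simplified stand-ins, not Paddle's real Scope/Variable/Tensor classes.
#include <iostream>
#include <memory>
#include <vector>

struct Scope;
using ScopePtr = std::shared_ptr<Scope>;

struct Scope {
  std::vector<int> input_dims;        // stands in for inlinks_[0]'s tensor dims
  std::vector<ScopePtr> step_scopes;  // stands in for the step_scopes_name_ variable
};

class RecurrentOpSketch {
 public:
  void Run(ScopePtr scope) const {
    // After the refactor, Run() is two named queries plus the step loop.
    size_t max_seq_len = GetMaxSeqLen(scope);
    const auto& step_scopes = GetStepScopes(scope);
    for (size_t step_id = 0; step_id < max_seq_len; step_id++) {
      std::cout << "run step " << step_id << " in scope "
                << step_scopes[step_id].get() << "\n";
    }
  }

 protected:
  // dims[0] of the first input is taken as the max sequence length, matching
  // the TODO about variable-length sequences in the real GetMaxSeqLen.
  size_t GetMaxSeqLen(ScopePtr scope) const {
    return static_cast<size_t>(scope->input_dims[0]);
  }

  // Returning const& avoids copying the scope vector on every query.
  const std::vector<ScopePtr>& GetStepScopes(ScopePtr scope) const {
    return scope->step_scopes;
  }
};

int main() {
  auto scope = std::make_shared<Scope>();
  scope->input_dims = {3, 2};  // max_seq_len 3, batch_size 2
  for (int i = 0; i < 3; ++i) {
    scope->step_scopes.push_back(std::make_shared<Scope>());
  }
  RecurrentOpSketch().Run(scope);
  return 0;
}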
13 changes: 13 additions & 0 deletions paddle/framework/recurrent_network_op.h
@@ -133,6 +133,11 @@ class RecurrentOp : public OperatorBase {
   virtual ~RecurrentOp() {}
 
  protected:
+  /*
+   * Get the max sequence length of the scope.
+   */
+  size_t GetMaxSeqLen(ScopePtr scope) const;
+
   /*
    * Prepare inputs for each stepnet.
    */
@@ -153,6 +158,14 @@ class RecurrentOp : public OperatorBase {
    */
   void CreateScopes(ScopePtr scope) const;
 
+  /*
+   * Get the step scopes.
+   */
+  inline const std::vector<ScopePtr>& GetStepScopes(ScopePtr scope) const {
+    return *(scope->GetVariable(step_scopes_name_))
+                ->GetMutable<std::vector<ScopePtr>>();
+  }
+
   /*
    * Link memory in previous step scope to current scope.
    */
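One usage note on the new header: GetStepScopes returns const std::vector<ScopePtr>&, while the call sites in the .cc file bind the result with plain auto, which deduces a by-value copy of the vector (the elements are shared_ptrs, so the copied vector still refers to the same underlying step scopes). A minimal standalone sketch of that distinction, with a plain reference parameter standing in for the Variable lookup:

// Standalone sketch, not Paddle code: how the binding choice at a call site
// of a const& accessor decides between aliasing and copying.
#include <cassert>
#include <memory>
#include <vector>

struct Scope {};
using ScopePtr = std::shared_ptr<Scope>;

// Stands in for the real accessor's Variable lookup.
const std::vector<ScopePtr>& GetStepScopes(const std::vector<ScopePtr>& stored) {
  return stored;
}

int main() {
  std::vector<ScopePtr> stored{std::make_shared<Scope>()};
  auto copied = GetStepScopes(stored);          // auto drops const&: a copy
  const auto& aliased = GetStepScopes(stored);  // const auto& keeps the reference
  assert(&aliased == &stored);
  assert(&copied != &stored);
  // Both containers hold shared_ptrs to the same step scope.
  assert(copied[0].get() == aliased[0].get());
  return 0;
}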
