Skip to content

Commit

Permalink
refine LinkMemories
Browse files Browse the repository at this point in the history
  • Loading branch information
luotao1 committed Jul 14, 2017
1 parent 3a27b02 commit aede869
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 41 deletions.
60 changes: 23 additions & 37 deletions paddle/framework/recurrent_network_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,10 @@ void RecurrentOp::Run(OpContext* contex) const {
auto step_scopes = GetStepScopes(scope);
for (size_t step_id = 0; step_id < max_seq_len; step_id++) {
LOG(INFO) << "run step " << step_id;
ScopePtr step_scope = step_scopes[step_id];
// TODO replace memories' copy with reference
LinkMemories(scope, step_scopes, step_id);
LinkMemories(step_scopes, step_id);

net->GetMutable<PlainNet>()->Run(step_scope);
net->GetMutable<PlainNet>()->Run(step_scopes[step_id]);
}

LOG(INFO) << "concat outputs";
Expand Down Expand Up @@ -177,51 +176,38 @@ void RecurrentOp::ConcatOutputs(ScopePtr scope) const {
}
}

void RecurrentOp::LinkMemories(ScopePtr scope,
std::vector<ScopePtr>& step_scopes,
size_t step) const {
PADDLE_ENFORCE(step < step_scopes.size(),
"step [%d] out of range of step scopes' size [%d]", step,
void RecurrentOp::LinkMemories(std::vector<ScopePtr>& step_scopes,
size_t step_id) const {
PADDLE_ENFORCE(step_id < step_scopes.size(),
"step [%d] out of range of step scopes' size [%d]", step_id,
step_scopes.size());
auto step_scope = step_scopes[step];
// copy boot memory
ScopePtr step_scope = step_scopes[step_id];
for (auto& attr : memory_attrs_) {
Tensor* boot_tensor{nullptr};
if (step == 0) {
PADDLE_ENFORCE(scope->HasVariable(attr.boot_var),
Tensor* pre_memory_tensor =
step_scope->CreateVariable(attr.pre_var)->GetMutable<Tensor>();

if (step_id == 0) {
PADDLE_ENFORCE(step_scope->HasVariable(attr.boot_var),
"memory [%s]'s boot variable [%s] not exists", attr.var,
attr.boot_var);
// update memory's ddim
boot_tensor = scope->CreateVariable(attr.boot_var)->GetMutable<Tensor>();
attr.dims = boot_tensor->dims();
}
Variable* memory_var = step_scope->CreateVariable(attr.pre_var);

// TODO the memory of current step should be allocaled in step net ?
Tensor* cur_memory =
step_scopes[step]->CreateVariable(attr.var)->GetMutable<Tensor>();
cur_memory->mutable_data<float>(attr.dims, platform::CPUPlace());

// copy from boot memory
// TODO support more device
// TODO mutable_data is currently invalid
float* memory_tensor_val =
memory_var->GetMutable<Tensor>()->mutable_data<float>(
attr.dims, platform::CPUPlace());
if (step == 0) {
Tensor* boot_tensor =
step_scope->CreateVariable(attr.boot_var)->GetMutable<Tensor>();
PADDLE_ENFORCE(boot_tensor, "boot_tensor should be retrieved before");
// copy from boot memory
std::memcpy(memory_tensor_val, boot_tensor->data<float>(),
product(attr.dims));
pre_memory_tensor->ShareDataFrom(*boot_tensor);
} else {
// copy from previous step scope's memory to this scope's
// `pre - memory`
Tensor* pre_step_memory =
step_scopes[step - 1]->GetVariable(attr.var)->GetMutable<Tensor>();

std::memcpy(memory_tensor_val, pre_step_memory->data<float>(),
product(attr.dims));
step_scopes[step_id - 1]->GetVariable(attr.var)->GetMutable<Tensor>();
pre_memory_tensor->ShareDataFrom(*pre_step_memory);
}

// TODO the memory of current step should be allocated in step net ?
Tensor* cur_memory_tensor =
step_scopes[step_id]->CreateVariable(attr.var)->GetMutable<Tensor>();
cur_memory_tensor->mutable_data<float>(pre_memory_tensor->dims(),
platform::CPUPlace());
}
}

Expand Down
5 changes: 1 addition & 4 deletions paddle/framework/recurrent_network_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,7 @@ class RecurrentOp : public OperatorBase {
/*
* Link memory in previous step scope to current scope.
*/
void LinkMemories(ScopePtr scope, std::vector<ScopePtr>& step_scopes,
size_t step) const;
void LinkMemories(std::vector<ScopePtr>& step_scopes, size_t step_id) const;

private:
/*
Expand All @@ -188,8 +187,6 @@ class RecurrentOp : public OperatorBase {
// name of the variable used to init this memory (same role as `boot_layer` in
// PaddlePaddle), which is stored in the father's scope.
std::string boot_var;
// this dim will infered from boot memories's tensor in the first step.
DDim dims;
};

/*
Expand Down

0 comments on commit aede869

Please sign in to comment.