
Commit

add implementation
Superjomn committed Sep 30, 2017
1 parent 661f30b commit d77fb89
Showing 6 changed files with 296 additions and 4 deletions.
4 changes: 2 additions & 2 deletions paddle/framework/tensor_array.h
@@ -90,12 +90,12 @@ class TensorArray {
LoDTensor Stack() const;

/*
- * Unpacks the given division of a rank-`R` tensor into rank-`(R-1)` tensors.
+ * Unstacks the given division of a rank-`R` tensor into rank-`(R-1)` tensors.
*/
void Unstack(const LoDTensor &source) const;

/*
- * Unpacks the given division of a rank-`R` tensor into rank-`(R-1)` tensors,
+ * Unstacks the given division of a rank-`R` tensor into rank-`(R-1)` tensors,
* with memory of tensors shared.
*/
void UnstackShared(const LoDTensor &source) const;
164 changes: 164 additions & 0 deletions paddle/operators/dynamic_recurrent_op.cc
@@ -0,0 +1,164 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/dynamic_recurrent_op.h"

namespace paddle {
namespace operators {
void DynamicRecurrentOp::Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const {
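// Cache the operator arguments, split the inputs into per-step tensors,
// create the step scopes, and write each step's inputs into its scope;
// invoking the stepnet is still a placeholder below.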
arg_cache_.Init(kArgName, *this, scope, &arg_);
SplitInputs(scope);
CreateScopes(scope);
WriteStepInputs(scope);

// call stepnet
}

void DynamicRecurrentOp::SplitInputs(const Scope& scope) const {
// TODO(superjom) make level a config
// TODO(superjom) check that all the inputs have the same LoD
int level = 0;
const auto& inlinks = arg_cache_.inlinks;
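// Unpack each input LoDTensor at `level` into per-step tensors stored in a
// TensorArray; the returned DySeqMeta records how the original sequences map
// to steps and is reused later when packing the outputs.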
for (auto& item : inlinks) {
const auto& var = item.second;
const auto& tensor = var->Get<LoDTensor>();
TensorArray& ta = step_inputs_[item.first];
dy_seq_metas_[item.first] =
ta.Unpack(tensor, level, true /*length_descend*/);
}
}

void DynamicRecurrentOp::WriteStepInputs(const Scope& scope) const {
const auto& inlinks = arg_cache_.inlinks;
for (auto& item : inlinks) {
TensorArray& ta = step_inputs_[item.first];
for (size_t step = 0; step < ta.size(); step++) {
auto tensor = ta.Read(step);
auto& step_scope = arg_cache_.GetScope(step);
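// Share the step input's memory with the variable in the step scope instead
// of copying the data.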
step_scope.FindVar(item.first)
->GetMutable<LoDTensor>()
->ShareDataWith<value_type>(tensor);
}
}
}

void DynamicRecurrentOp::WriteStepOutputs(const Scope& scope) const {
for (size_t step = 0; step < arg_cache_.scopes->size(); step++) {
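// Write each output variable of this step scope into the corresponding
// TensorArray slot, sharing memory with the step-scope tensor.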
auto& step_scope = arg_cache_.GetScope(step);
for (auto& item : step_outputs_) {
const auto& step_output = step_scope.FindVar(item.first)->Get<LoDTensor>();
item.second.WriteShared(step, step_output);
}
}
}

void DynamicRecurrentOp::CreateScopes(const Scope& scope) const {
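// Prepare any step scopes that do not exist yet and declare the stepnet's
// temporary input/output variables inside each of them.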
for (size_t step = arg_cache_.scopes->size(); step < step_inputs_.size();
step++) {
CreateTempInputsInScope(arg_cache_.GetScope(step));
CreateTempOutputsInScope(arg_cache_.GetScope(step));
}
}

void DynamicRecurrentOp::ConcatOutputs(const Scope& scope) const {
// TODO(superjom) transform this to a config
int level = 0;
// TODO(superjom) pass in some lod
// just a placeholder
framework::LoD lod;
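// Pack the per-step outputs back into a single LoDTensor using the sequence
// metadata recorded by SplitInputs, and share its memory with the operator's
// output variable.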
for (auto& item : step_outputs_) {
auto tensor = item.second.Pack(level, dy_seq_metas_[item.first], lod);
auto* output = arg_cache_.outlinks[item.first]->GetMutable<LoDTensor>();
output->ShareDataWith<value_type>(tensor);
}
}

void DynamicRecurrentOp::InitStates(Scope* step_scopes) const {}

void DynamicRecurrentOp::ArgCache::Init(
const rnn::ArgumentName& name, const paddle::framework::OperatorBase& op,
const paddle::framework::Scope& scope, const rnn::Argument* arg) {
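// Resolve the operator's argument names, then cache the step-scope vector
// and the inlink/outlink variables so later steps can use the cached
// pointers directly.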
InitArgument(name, op, arg);
CacheScopes(scope, *arg);
CacheInlinks(scope, arg->inlinks);
CacheOutlinks(scope, arg->outlinks);
}

// NOTE(superjom) should be called after SplitInputs
void DynamicRecurrentOp::CreateTempInputsInScope(Scope& scope) const {
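// Declare a variable for every stepnet input that does not exist in the
// step scope yet, so the stepnet can run inside that scope.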
for (auto& input : stepnet_->Inputs()) {
for (const std::string& var : input.second) {
if (!scope.FindVar(var)) {
scope.NewVar(var)->GetMutable<LoDTensor>();
}
}
}
}

void DynamicRecurrentOp::CreateTempOutputsInScope(Scope& scope) const {
for (auto& output : stepnet_->Outputs()) {
for (const std::string& var : output.second) {
if (!scope.FindVar(var)) {
scope.NewVar(var)->GetMutable<LoDTensor>();
}
}
}
}

void DynamicRecurrentOp::ArgCache::InitArgument(const rnn::ArgumentName& name,
const OperatorBase& op,
const rnn::Argument* arg) {
rnn::InitArgument(name, arg, op, false /*is_grad*/);
}

void DynamicRecurrentOp::ArgCache::CacheScopes(const Scope& scope,
const rnn::Argument& arg) {
auto scopes_var = scope.FindVar(arg.step_scopes);
PADDLE_ENFORCE(scopes_var != nullptr,
"the step_scopes output argument [%s] should be created first "
"by the framework.",
arg.step_scopes);
scopes = scopes_var->GetMutable<std::vector<Scope*>>();
}

void DynamicRecurrentOp::ArgCache::CacheInlinks(
const Scope& scope, const std::vector<std::string>& names) {
for (auto name : names) {
auto* var = GetVariable(scope, name);
inlinks[name] = var;
}
}

void DynamicRecurrentOp::ArgCache::CacheOutlinks(
const Scope& scope, const std::vector<std::string>& names) {
for (auto name : names) {
auto* var = GetVariable(scope, name);
outlinks[name] = var;
}
}

Variable* DynamicRecurrentOp::ArgCache::GetVariable(const Scope& scope,
const std::string& name) {
auto* var = scope.FindVar(name);
PADDLE_ENFORCE_NOT_NULL(var, "variable [%s] does not exist in the scope", name);
return var;
}

const rnn::ArgumentName DynamicRecurrentOp::kArgName{
"step_net", "step_scopes", "inlinks", "outlinks",
"memories", "pre_memories", "boot_memories"};

} // namespace operators
} // namespace paddle
128 changes: 128 additions & 0 deletions paddle/operators/dynamic_recurrent_op.h
@@ -0,0 +1,128 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/tensor_array.h"
#include "paddle/framework/variable.h"
#include "paddle/operators/rnn/recurrent_op_utils.h"

namespace paddle {
namespace operators {

using framework::Scope;
using framework::TensorArray;
using framework::LoDTensor;
using framework::Variable;

class DynamicRecurrentOp : public framework::OperatorBase {
public:
static const rnn::ArgumentName kArgName;
using value_type = float;

void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override;

/*
* Split the inputs (LoDTensors) into segments for each time step.
*/
void SplitInputs(const Scope& scope) const;

/*
* Create step-scopes to store the temporary outputs of each time step.
*/
void CreateScopes(const Scope& scope) const;

/*
* Link TensorArray steps to the corresponding variables located in
* step-scopes.
*/
void WriteStepInputs(const Scope& scope) const;

/*
* Write the output of each step to the corresponding TensorArray.
*/
void WriteStepOutputs(const Scope& scope) const;

/*
* Initialize the states; each state will have a corresponding pre-state
* that shares memory with the state in the previous time step. The
* pre-state in the first time step will be initialized with a zero tensor,
* or with a tensor from the parent scope if one is provided.
*/
void InitStates(Scope* step_scopes) const;

/*
* Concatenate the outputs of all time steps and generate a LoDTensor.
*/
void ConcatOutputs(const Scope& scope) const;

/*
* Set a stepnet that is created according to a RecurrentOp's stepnet.
*/
void SetStepNet(std::unique_ptr<OperatorBase> net) {
stepnet_ = std::move(net);
}
const OperatorBase& GetStepNet() const { return *stepnet_; }

/*
* Create the temporary inputs of a step-net in a step-scope.
*/
void CreateTempInputsInScope(Scope& scope) const;

/*
* Create the temporary outputs of a step-net in a step-scope.
*/
void CreateTempOutputsInScope(Scope& scope) const;

protected:
struct ArgCache {
std::vector<Scope*>* scopes;
std::map<std::string, Variable*> inlinks;
std::map<std::string, Variable*> outlinks;

void Init(const rnn::ArgumentName& name, const OperatorBase& op,
const Scope& scope, const rnn::Argument* arg);

Scope& GetScope(size_t index) {
PADDLE_ENFORCE_LT(index, scopes->size());
return *scopes->at(index);
}

protected:
void InitArgument(const rnn::ArgumentName& name, const OperatorBase& op,
const rnn::Argument* arg);
void CacheScopes(const Scope& scope, const rnn::Argument& arg);
void CacheInlinks(const Scope& scope,
const std::vector<std::string>& names);
void CacheOutlinks(const Scope& scope,
const std::vector<std::string>& names);
Variable* GetVariable(const Scope& scope, const std::string& name);
};

private:
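// Run() is const, so the per-run caches below are declared mutable.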
std::unique_ptr<OperatorBase> stepnet_;
mutable TensorArray states_;
mutable std::map<std::string, TensorArray> step_inputs_;
mutable std::map<std::string, TensorArray> step_outputs_;
mutable std::map<std::string, std::vector<framework::DySeqMeta>>
dy_seq_metas_;
rnn::Argument arg_;
mutable ArgCache arg_cache_;
};

} // namespace operators
} // namespace paddle
Empty file.
2 changes: 1 addition & 1 deletion paddle/operators/rnn/recurrent_op_utils.cc
@@ -108,7 +108,7 @@ void LinkMemories(const std::vector<Scope*>& scopes,
}
}

-void InitArgument(const ArgumentName& name, Argument* arg,
+void InitArgument(const ArgumentName& name, const Argument* arg,
const framework::OperatorBase& op, bool is_grad) {
arg->step_scopes =
is_grad ? op.Input(name.step_scopes) : op.Output(name.step_scopes);
2 changes: 1 addition & 1 deletion paddle/operators/rnn/recurrent_op_utils.h
@@ -77,7 +77,7 @@ void LinkMemories(const std::vector<Scope*>& step_scopes,
const std::vector<MemoryAttr>& memories, const size_t step_id,
const int offset, bool infer_shape_mode);

-void InitArgument(const ArgumentName& name, Argument* arg,
+void InitArgument(const ArgumentName& name, const Argument* arg,
const framework::OperatorBase& op, bool is_grad = false);

} // namespace rnn
