init reorder boot
Superjomn committed Oct 10, 2017
1 parent b5e8a8a commit 50c033b
Showing 1 changed file with 17 additions and 2 deletions.
paddle/operators/dynamic_recurrent_op.cc
@@ -28,6 +28,17 @@ inline void CreateVariables(Scope& scope,
}
}

/*
 * Inputs that carry sequences are reordered when they are split into steps,
 * so the boot_states must be reordered with the same permutation.
 *
 * NOTE This may require that the `pre_state` of the first time step copy the
 * `boot_state` rather than reference it: the contents have to be reordered,
 * but the RNN op must not modify the content of `boot_state`, which is an
 * input variable.
 */
inline void ReorderBootState();

} // namespace detail

class DynamicRecurrentOpProtoAndCheckerMaker
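The commit only declares `ReorderBootState`; no definition is added yet. As a rough illustration of what the comment above describes, here is a minimal, framework-agnostic sketch (the names, signature, and `order` permutation are assumptions, not Paddle's API): it supposes the splitting step yields a permutation `order`, where `order[i]` is the original index of the i-th sequence after reordering, and that the boot state is a flat row-major buffer with `state_dim` floats per sequence.

#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical helper: copy boot-state rows into a new buffer following the
// sequence order produced by splitting. Copying into a fresh buffer (rather
// than permuting in place) keeps the original boot_state intact, per the
// NOTE above.
std::vector<float> ReorderedBootState(const std::vector<float>& boot_state,
                                      const std::vector<size_t>& order,
                                      size_t state_dim) {
  std::vector<float> reordered(order.size() * state_dim);
  for (size_t i = 0; i < order.size(); ++i) {
    // Row i of the reordered state comes from row order[i] of boot_state.
    const float* src = &boot_state[order[i] * state_dim];
    std::copy(src, src + state_dim, &reordered[i * state_dim]);
  }
  return reordered;
}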
@@ -79,8 +90,7 @@ void DynamicRecurrentOp::SplitInputs() const {
// TODO(superjom) make level a config
// TODO(superjom) check that all the inputs have the same LoD
int level = 0;
-  const auto& inlinks = cache_.inlinks;
-  for (auto& item : inlinks) {
+  for (auto& item : cache_.inlinks) {
const auto& var = item.second;
const auto& tensor = var->Get<LoDTensor>();
TensorArray& ta = step_inputs_[item.first];
@@ -115,6 +125,11 @@ void DynamicRecurrentOp::WriteStepInputs() const {
}

void DynamicRecurrentOp::WriteStepOutputs() const {
// initialize step outputs
PADDLE_ENFORCE_GT(cache_.outlinks.size(), 0UL);
for (const auto& item : cache_.outlinks) {
step_outputs_.emplace(item.first, TensorArray());
}
for (size_t step = 0; step < cache_.scopes->size(); step++) {
auto& scope = cache_.GetScope(step);
for (auto& item : step_outputs_) {
