From 45c02b217f7b42f294c14196be086746e8d10585 Mon Sep 17 00:00:00 2001 From: Guo Sheng Date: Sat, 12 Oct 2019 10:53:16 +0800 Subject: [PATCH 1/4] Add seq2seq api related code (#19820) --- paddle/fluid/API.spec | 38 +- paddle/fluid/operators/assign_op.cc | 6 +- .../fill_constant_batch_size_like_op.cc | 9 +- .../fill_constant_batch_size_like_op.cu.cc | 4 +- .../fill_constant_batch_size_like_op.h | 19 +- paddle/fluid/operators/fill_constant_op.cu.cc | 1 + paddle/fluid/operators/gather_nd_op.cc | 11 +- paddle/fluid/operators/gather_nd_op.cu | 1 + paddle/fluid/operators/gather_tree_op.cc | 78 ++ paddle/fluid/operators/gather_tree_op.cu | 80 ++ paddle/fluid/operators/gather_tree_op.h | 58 + .../operators/reduce_ops/reduce_all_op.cc | 4 +- .../operators/reduce_ops/reduce_any_op.cc | 4 +- paddle/fluid/operators/reduce_ops/reduce_op.h | 27 +- .../operators/tensor_array_to_tensor_op.cc | 75 +- python/paddle/fluid/layers/__init__.py | 4 + python/paddle/fluid/layers/nn.py | 76 ++ python/paddle/fluid/layers/rnn.py | 1165 +++++++++++++++++ python/paddle/fluid/layers/tensor.py | 111 +- python/paddle/fluid/layers/utils.py | 172 +++ .../tests/unittests/test_gather_tree_op.py | 65 + .../tests/unittests/test_rnn_cell_api.py | 249 ++++ .../tests/unittests/test_rnn_decode_api.py | 214 +++ .../unittests/test_tensor_array_to_tensor.py | 81 +- 24 files changed, 2480 insertions(+), 72 deletions(-) create mode 100644 paddle/fluid/operators/gather_tree_op.cc create mode 100644 paddle/fluid/operators/gather_tree_op.cu create mode 100644 paddle/fluid/operators/gather_tree_op.h create mode 100644 python/paddle/fluid/layers/rnn.py create mode 100644 python/paddle/fluid/tests/unittests/test_gather_tree_op.py create mode 100644 python/paddle/fluid/tests/unittests/test_rnn_cell_api.py create mode 100644 python/paddle/fluid/tests/unittests/test_rnn_decode_api.py diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 7394f4c2090b2..a6272ae5ca766 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -306,6 +306,7 @@ paddle.fluid.layers.deformable_roi_pooling (ArgSpec(args=['input', 'rois', 'tran paddle.fluid.layers.filter_by_instag (ArgSpec(args=['ins', 'ins_tag', 'filter_tag', 'is_lod'], varargs=None, keywords=None, defaults=None), ('document', '7703a2088af8de4128b143ff1164ca4a')) paddle.fluid.layers.shard_index (ArgSpec(args=['input', 'index_num', 'nshards', 'shard_id', 'ignore_value'], varargs=None, keywords=None, defaults=(-1,)), ('document', '3c6b30e9cd57b38d4a5fa1ade887f779')) paddle.fluid.layers.hard_swish (ArgSpec(args=['x', 'threshold', 'scale', 'offset', 'name'], varargs=None, keywords=None, defaults=(6.0, 6.0, 3.0, None)), ('document', 'bd763b9ca99239d624c3cb4626e3627a')) +paddle.fluid.layers.gather_tree (ArgSpec(args=['ids', 'parents'], varargs=None, keywords=None, defaults=None), ('document', '201b54fa7512305078c70a6610beaead')) paddle.fluid.layers.mse_loss (ArgSpec(args=['input', 'label'], varargs=None, keywords=None, defaults=None), ('document', '88b967ef5132567396062d5d654b3064')) paddle.fluid.layers.uniform_random (ArgSpec(args=['shape', 'dtype', 'min', 'max', 'seed'], varargs=None, keywords=None, defaults=('float32', -1.0, 1.0, 0)), ('document', '126ede8ce0e751244b1b54cd359c89d7')) paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', '9d7806e31bdf727c1a23b8782a09b545')) @@ -318,11 +319,11 @@ paddle.fluid.layers.create_tensor (ArgSpec(args=['dtype', 'name', 'persistable'] paddle.fluid.layers.create_parameter (ArgSpec(args=['shape', 'dtype', 'name', 'attr', 'is_bias', 'default_initializer'], varargs=None, keywords=None, defaults=(None, None, False, None)), ('document', '021272f30e0cdf7503586815378abfb8')) paddle.fluid.layers.create_global_var (ArgSpec(args=['shape', 'value', 'dtype', 'persistable', 'force_cpu', 'name'], varargs=None, keywords=None, defaults=(False, False, None)), ('document', '47ea8b8c91879e50c9036e418b00ef4a')) paddle.fluid.layers.cast (ArgSpec(args=['x', 'dtype'], varargs=None, keywords=None, defaults=None), ('document', '45df178cbd8c302f92c30ebdaaa6fa8a')) -paddle.fluid.layers.tensor_array_to_tensor (ArgSpec(args=['input', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None)), ('document', 'dd7d2f1e12a8a4225d017209866e5621')) +paddle.fluid.layers.tensor_array_to_tensor (ArgSpec(args=['input', 'axis', 'name', 'use_stack'], varargs=None, keywords=None, defaults=(1, None, False)), ('document', '4aa82374218ccf593bb8011df79c71e3')) paddle.fluid.layers.concat (ArgSpec(args=['input', 'axis', 'name'], varargs=None, keywords=None, defaults=(0, None)), ('document', 'ec7d6e716fb29ef1e73e1e3efa5ca46b')) paddle.fluid.layers.sums (ArgSpec(args=['input', 'out'], varargs=None, keywords=None, defaults=(None,)), ('document', '5df743d578638cd2bbb9369499b44af4')) paddle.fluid.layers.assign (ArgSpec(args=['input', 'output'], varargs=None, keywords=None, defaults=(None,)), ('document', '8bd94aef4e123986d9a8c29f67b5532b')) -paddle.fluid.layers.fill_constant_batch_size_like (ArgSpec(args=['input', 'shape', 'dtype', 'value', 'input_dim_idx', 'output_dim_idx'], varargs=None, keywords=None, defaults=(0, 0)), ('document', '37a288e4400f6d5510e982827461c11b')) +paddle.fluid.layers.fill_constant_batch_size_like (ArgSpec(args=['input', 'shape', 'dtype', 'value', 'input_dim_idx', 'output_dim_idx', 'force_cpu'], varargs=None, keywords=None, defaults=(0, 0, False)), ('document', '2bb57637664173fee5f654e55896aec6')) paddle.fluid.layers.fill_constant (ArgSpec(args=['shape', 'dtype', 'value', 'force_cpu', 'out'], varargs=None, keywords=None, defaults=(False, None)), ('document', '66e1e468666dd47e5b2715226cebeac0')) paddle.fluid.layers.argmin (ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=(0,)), ('document', '53629e27597e5dfb7020aac5bc639ebb')) paddle.fluid.layers.argmax (ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=(0,)), ('document', 'd9a89fbedbaebd5f65897ac75ee636f3')) @@ -467,6 +468,39 @@ paddle.fluid.layers.MultivariateNormalDiag.entropy (ArgSpec(args=['self'], varar paddle.fluid.layers.MultivariateNormalDiag.kl_divergence (ArgSpec(args=['self', 'other'], varargs=None, keywords=None, defaults=None), ('document', 'd9190d29dbd54c81f747a6436c35f062')) paddle.fluid.layers.MultivariateNormalDiag.log_prob (ArgSpec(args=['self', 'value'], varargs=None, keywords=None, defaults=None), ('document', 'c0edd2e2fc76711477b32dc4da9de768')) paddle.fluid.layers.MultivariateNormalDiag.sample (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '08a2bbcaa20ee176ee7ec3d05737a0f6')) +paddle.fluid.layers.RNNCell ('paddle.fluid.layers.rnn.RNNCell', ('document', '2c3a2d3ecb4a3cec130395e7df0bd5c9')) +paddle.fluid.layers.RNNCell.__init__ +paddle.fluid.layers.RNNCell.call (ArgSpec(args=['self', 'inputs', 'states'], varargs=None, keywords='kwargs', defaults=None), ('document', '3ac714b638258c520d66f682be67b658')) +paddle.fluid.layers.RNNCell.get_initial_states (ArgSpec(args=['self', 'batch_ref', 'shape', 'dtype', 'init_value'], varargs=None, keywords=None, defaults=(None, None, 0)), ('document', '003d1b4c99128f798ac0b0eecc81c489')) +paddle.fluid.layers.GRUCell ('paddle.fluid.layers.rnn.GRUCell', ('document', '7b2902a91258c4688a879805290adc00')) +paddle.fluid.layers.GRUCell.__init__ (ArgSpec(args=['self', 'hidden_size', 'param_attr', 'bias_attr', 'gate_activation', 'activation', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, 'float32', 'GRUCell')), ('document', '3624a6c93b4a999d0d809eb1a66d272e')) +paddle.fluid.layers.GRUCell.call (ArgSpec(args=['self', 'inputs', 'states'], varargs=None, keywords=None, defaults=None), ('document', '6094ab09a56c732c76abb5105327ea54')) +paddle.fluid.layers.GRUCell.get_initial_states (ArgSpec(args=['self', 'batch_ref', 'shape', 'dtype', 'init_value'], varargs=None, keywords=None, defaults=(None, None, 0)), ('document', '003d1b4c99128f798ac0b0eecc81c489')) +paddle.fluid.layers.LSTMCell ('paddle.fluid.layers.rnn.LSTMCell', ('document', '5cbd87bce446ba0f50398ce2772d43e9')) +paddle.fluid.layers.LSTMCell.__init__ (ArgSpec(args=['self', 'hidden_size', 'param_attr', 'bias_attr', 'gate_activation', 'activation', 'forget_bias', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, 1.0, 'float32', 'LSTMCell')), ('document', '9015961869b436d2739a0347618028e3')) +paddle.fluid.layers.LSTMCell.call (ArgSpec(args=['self', 'inputs', 'states'], varargs=None, keywords=None, defaults=None), ('document', '9c84a477021e4a7d0a497c1e6a31be2d')) +paddle.fluid.layers.LSTMCell.get_initial_states (ArgSpec(args=['self', 'batch_ref', 'shape', 'dtype', 'init_value'], varargs=None, keywords=None, defaults=(None, None, 0)), ('document', '003d1b4c99128f798ac0b0eecc81c489')) +paddle.fluid.layers.Decoder ('paddle.fluid.layers.rnn.Decoder', ('document', '23838bd065fddca1557a6a3368d9e365')) +paddle.fluid.layers.Decoder.__init__ +paddle.fluid.layers.Decoder.finalize (ArgSpec(args=['self', 'outputs', 'final_states', 'sequence_lengths'], varargs=None, keywords=None, defaults=None), ('document', 'cab7fc752a05db18e99258473f50359d')) +paddle.fluid.layers.Decoder.initialize (ArgSpec(args=['self', 'inits'], varargs=None, keywords=None, defaults=None), ('document', '68cf1846fb58056dbe5a524f1ca9dff5')) +paddle.fluid.layers.Decoder.step (ArgSpec(args=['self', 'time', 'inputs', 'states'], varargs=None, keywords=None, defaults=None), ('document', '151d0229930b9654689f86c85f7c4c3f')) +paddle.fluid.layers.BeamSearchDecoder ('paddle.fluid.layers.rnn.BeamSearchDecoder', ('document', 'd7ef0c9229bfe73e0daefcfda24a2635')) +paddle.fluid.layers.BeamSearchDecoder.OutputWrapper ('paddle.fluid.layers.rnn.OutputWrapper', ('document', 'a7141ebf1fb097fa71006cdd35bdc219')) +paddle.fluid.layers.BeamSearchDecoder.OutputWrapper.__init__ +paddle.fluid.layers.BeamSearchDecoder.OutputWrapper.count T.count(value) -> integer -- return number of occurrences of value +paddle.fluid.layers.BeamSearchDecoder.OutputWrapper.index T.index(value, [start, [stop]]) -> integer -- return first index of value. +paddle.fluid.layers.BeamSearchDecoder.StateWrapper ('paddle.fluid.layers.rnn.StateWrapper', ('document', '157731f37c88ea01bc746653125a41c8')) +paddle.fluid.layers.BeamSearchDecoder.StateWrapper.__init__ +paddle.fluid.layers.BeamSearchDecoder.StateWrapper.count T.count(value) -> integer -- return number of occurrences of value +paddle.fluid.layers.BeamSearchDecoder.StateWrapper.index T.index(value, [start, [stop]]) -> integer -- return first index of value. +paddle.fluid.layers.BeamSearchDecoder.__init__ (ArgSpec(args=['self', 'cell', 'start_token', 'end_token', 'beam_size', 'embedding_fn', 'output_fn'], varargs=None, keywords=None, defaults=(None, None)), ('document', '68951eaed573ec47c17a43155514b2f1')) +paddle.fluid.layers.BeamSearchDecoder.finalize (ArgSpec(args=['self', 'outputs', 'final_states', 'sequence_lengths'], varargs=None, keywords=None, defaults=None), ('document', '9a7f0a8fc5802bf860f2ac960466fb45')) +paddle.fluid.layers.BeamSearchDecoder.initialize (ArgSpec(args=['self', 'initial_cell_states'], varargs=None, keywords=None, defaults=None), ('document', '01ee508a9615e2483fe6ddcf14d5fa25')) +paddle.fluid.layers.BeamSearchDecoder.step (ArgSpec(args=['self', 'time', 'inputs', 'states'], varargs=None, keywords='kwargs', defaults=None), ('document', '35ee583c3c0fe7cceeafa289ed3374bd')) +paddle.fluid.layers.BeamSearchDecoder.tile_beam_merge_with_batch (ArgSpec(args=['x', 'beam_size'], varargs=None, keywords=None, defaults=None), ('document', 'ce7ffacba6f56f57acbf5d4dd82fe04d')) +paddle.fluid.layers.rnn (ArgSpec(args=['cell', 'inputs', 'initial_states', 'sequence_length', 'time_major', 'is_reverse'], varargs=None, keywords='kwargs', defaults=(None, None, False, False)), ('document', 'c36ade777ff43d2ba5542079b66a012b')) +paddle.fluid.layers.dynamic_decode (ArgSpec(args=['decoder', 'inits', 'max_step_num', 'output_time_major'], varargs=None, keywords='kwargs', defaults=(None, None, False)), ('document', '55b44de9d290c0c2ad8fdd635e6ab575')) paddle.fluid.contrib.InitState ('paddle.fluid.contrib.decoder.beam_search_decoder.InitState', ('document', '3afd1f84232718e628e9e566941c5f05')) paddle.fluid.contrib.InitState.__init__ (ArgSpec(args=['self', 'init', 'shape', 'value', 'init_boot', 'need_reorder', 'dtype'], varargs=None, keywords=None, defaults=(None, None, 0.0, None, False, 'float32')), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.contrib.StateCell ('paddle.fluid.contrib.decoder.beam_search_decoder.StateCell', ('document', 'ecd0066c02867d445d7b461e28220c50')) diff --git a/paddle/fluid/operators/assign_op.cc b/paddle/fluid/operators/assign_op.cc index ff423778c5982..221204878659e 100644 --- a/paddle/fluid/operators/assign_op.cc +++ b/paddle/fluid/operators/assign_op.cc @@ -154,10 +154,12 @@ REGISTER_OPERATOR(assign, ops::AssignOp, ops::AssignGradMaker, ops::AssignOpProtoMaker, ops::AssignOpInplaceInferer); REGISTER_OP_CPU_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double, ops::AssignKernel, int, ops::AssignKernel, - int64_t, ops::AssignKernel); + int64_t, ops::AssignKernel, bool, + ops::AssignKernel); #ifdef PADDLE_WITH_CUDA REGISTER_OP_CUDA_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double, ops::AssignKernel, int, ops::AssignKernel, - int64_t, ops::AssignKernel); + int64_t, ops::AssignKernel, bool, + ops::AssignKernel); #endif diff --git a/paddle/fluid/operators/fill_constant_batch_size_like_op.cc b/paddle/fluid/operators/fill_constant_batch_size_like_op.cc index b8921b171cf37..404c2a92a47f9 100644 --- a/paddle/fluid/operators/fill_constant_batch_size_like_op.cc +++ b/paddle/fluid/operators/fill_constant_batch_size_like_op.cc @@ -38,6 +38,11 @@ class FillConstantBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker { .SetDefault(framework::proto::VarType::FP32); AddAttr("value", "default 0. The value to be filled") .SetDefault(0.0f); + AddAttr("force_cpu", + "(bool, default false) Force fill output variable to cpu " + "memory. Otherwise, fill output variable to the running " + "device") + .SetDefault(false); AddComment(R"DOC( This function creates a tensor of specified *shape*, *dtype* and batch size, and initializes this with a constant supplied in *value*. The batch size is @@ -65,4 +70,6 @@ REGISTER_OP_CPU_KERNEL( ops::FillConstantBatchSizeLikeOpKernel, ops::FillConstantBatchSizeLikeOpKernel); + int64_t>, + ops::FillConstantBatchSizeLikeOpKernel); diff --git a/paddle/fluid/operators/fill_constant_batch_size_like_op.cu.cc b/paddle/fluid/operators/fill_constant_batch_size_like_op.cu.cc index 2cbbd05bfbb5d..353f73cdd6d05 100644 --- a/paddle/fluid/operators/fill_constant_batch_size_like_op.cu.cc +++ b/paddle/fluid/operators/fill_constant_batch_size_like_op.cu.cc @@ -25,4 +25,6 @@ REGISTER_OP_CUDA_KERNEL( ops::FillConstantBatchSizeLikeOpKernel, ops::FillConstantBatchSizeLikeOpKernel); + int64_t>, + ops::FillConstantBatchSizeLikeOpKernel); diff --git a/paddle/fluid/operators/fill_constant_batch_size_like_op.h b/paddle/fluid/operators/fill_constant_batch_size_like_op.h index 63ea60678f807..f915f37feab5d 100644 --- a/paddle/fluid/operators/fill_constant_batch_size_like_op.h +++ b/paddle/fluid/operators/fill_constant_batch_size_like_op.h @@ -23,6 +23,11 @@ template class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + auto data_type = + static_cast(ctx.Attr("dtype")); + auto value = ctx.Attr("value"); + auto force_cpu = ctx.Attr("force_cpu"); + auto* out = ctx.Output("Out"); auto* in = ctx.Input("Input"); if (in->lod().size() && ctx.Attr("input_dim_idx") == 0) { @@ -32,12 +37,16 @@ class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel { odims[output_dim_idx] = static_cast(in->lod().back().size()) - 1; out->mutable_data(odims, ctx.GetPlace()); } - out->mutable_data(ctx.GetPlace()); - auto value = ctx.Attr("value"); - math::SetConstant setter; - setter(ctx.template device_context(), out, - static_cast(value)); + if (force_cpu) { + out->mutable_data(platform::CPUPlace(), data_type); + } else { + out->mutable_data(ctx.GetPlace(), data_type); + } + + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto& dev_ctx = *pool.Get(ctx.GetPlace()); + math::set_constant(dev_ctx, out, value); } }; diff --git a/paddle/fluid/operators/fill_constant_op.cu.cc b/paddle/fluid/operators/fill_constant_op.cu.cc index 77027b5a87d4a..4a7b0110a1d96 100644 --- a/paddle/fluid/operators/fill_constant_op.cu.cc +++ b/paddle/fluid/operators/fill_constant_op.cu.cc @@ -19,4 +19,5 @@ REGISTER_OP_CUDA_KERNEL(fill_constant, ops::FillConstantKernel, ops::FillConstantKernel, ops::FillConstantKernel, ops::FillConstantKernel, + ops::FillConstantKernel, ops::FillConstantKernel); diff --git a/paddle/fluid/operators/gather_nd_op.cc b/paddle/fluid/operators/gather_nd_op.cc index aed0f824e6966..cbeefa0a7f651 100644 --- a/paddle/fluid/operators/gather_nd_op.cc +++ b/paddle/fluid/operators/gather_nd_op.cc @@ -60,8 +60,13 @@ class GatherNdOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + auto* x = ctx.Input("X"); + const auto& x_type = x->type(); + return framework::OpKernelType( + x_type, + x_type == framework::proto::VarType::BOOL + ? x->place() // to be consistent with compare and logical ops + : ctx.device_context().GetPlace()); } }; @@ -173,7 +178,7 @@ REGISTER_OPERATOR(gather_nd_grad, ops::GatherNdGradOp, REGISTER_OP_CPU_KERNEL(gather_nd, ops::GatherNdOpKernel, ops::GatherNdOpKernel, ops::GatherNdOpKernel, - ops::GatherNdOpKernel, + ops::GatherNdOpKernel, ops::GatherNdOpKernel, ops::GatherNdOpKernel); REGISTER_OP_CPU_KERNEL(gather_nd_grad, ops::GatherNdGradOpKernel, diff --git a/paddle/fluid/operators/gather_nd_op.cu b/paddle/fluid/operators/gather_nd_op.cu index 1ad335039a9cd..68f54a511597e 100644 --- a/paddle/fluid/operators/gather_nd_op.cu +++ b/paddle/fluid/operators/gather_nd_op.cu @@ -95,6 +95,7 @@ REGISTER_OP_CUDA_KERNEL(gather_nd, ops::GatherNdOpCUDAKernel, ops::GatherNdOpCUDAKernel, ops::GatherNdOpCUDAKernel, ops::GatherNdOpCUDAKernel, + ops::GatherNdOpCUDAKernel, ops::GatherNdOpCUDAKernel); REGISTER_OP_CUDA_KERNEL(gather_nd_grad, diff --git a/paddle/fluid/operators/gather_tree_op.cc b/paddle/fluid/operators/gather_tree_op.cc new file mode 100644 index 0000000000000..94fa3b6aa1e7e --- /dev/null +++ b/paddle/fluid/operators/gather_tree_op.cc @@ -0,0 +1,78 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/gather_tree_op.h" + +namespace paddle { +namespace operators { + +class GatherTreeOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Ids"), + "Input(Ids) of GatherTreeOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Parents"), + "Input(Parents) of GatherTreeOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of GatherTreeOp should not be null."); + + auto ids_dims = ctx->GetInputDim("Ids"); + auto parents_dims = ctx->GetInputDim("Parents"); + PADDLE_ENFORCE(ids_dims == parents_dims, + "The shape of Input(Parents) must be same with the shape of " + "Input(Ids)."); + ctx->SetOutputDim("Out", ids_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("Ids")->type(), + ctx.device_context()); + } +}; + +class GatherTreeOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Ids", + "The Tensor with shape [length, batch_size, beam_size] containing " + "the selected ids of all time steps."); + AddInput("Parents", + "The Tensor has the same shape as Ids and contains the parents " + "corresponding to selected ids when searching among beams."); + AddOutput( + "Out", + "A Tensor with shape [length, batch_size, beam_size] containing the " + "full sequences. The sequences is collected by backtracing from the " + "last time step of Ids."); + AddComment(R"DOC( +GatherTree Operator. + +Backtrace from the last time step and generate the full sequences by collecting beam search +selected ids. + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(gather_tree, ops::GatherTreeOp, ops::GatherTreeOpMaker); +REGISTER_OP_CPU_KERNEL(gather_tree, ops::GatherTreeOpKernel, + ops::GatherTreeOpKernel); diff --git a/paddle/fluid/operators/gather_tree_op.cu b/paddle/fluid/operators/gather_tree_op.cu new file mode 100644 index 0000000000000..7ea3641b99f1a --- /dev/null +++ b/paddle/fluid/operators/gather_tree_op.cu @@ -0,0 +1,80 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/gather_tree_op.h" + +namespace paddle { +namespace operators { + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +template +__global__ void GatherTree(const T *ids_data, const T *parents_data, + T *out_data, const int64_t max_length, + const int64_t batch_size, const int64_t beam_size) { + CUDA_1D_KERNEL_LOOP(i, batch_size * beam_size) { + int batch = i / beam_size; + int beam = i % beam_size; + auto idx = + (max_length - 1) * batch_size * beam_size + batch * beam_size + beam; + out_data[idx] = ids_data[idx]; + auto parent = parents_data[idx]; + for (int step = max_length - 2; step >= 0; step--) { + idx = step * batch_size * beam_size + batch * beam_size; + out_data[idx + beam] = ids_data[idx + parent]; + parent = parents_data[idx + parent]; + } + } +} + +template +class GatherTreeOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *ids = ctx.Input("Ids"); + auto *parents = ctx.Input("Parents"); + auto *out = ctx.Output("Out"); + + const auto *ids_data = ids->data(); + const auto *parents_data = parents->data(); + auto *out_data = out->mutable_data(ctx.GetPlace()); + + auto &ids_dims = ids->dims(); + int64_t max_length = ids_dims[0]; + int64_t batch_size = ids_dims[1]; + int64_t beam_size = ids_dims[2]; + + auto &dev_ctx = ctx.cuda_device_context(); + + const int block = 512; + int max_threads = + std::min(static_cast(dev_ctx.GetMaxPhysicalThreadCount()), + batch_size * beam_size); + const int grid = std::max(max_threads / block, 1); + GatherTree<<>>(ids_data, parents_data, out_data, max_length, + batch_size, beam_size); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL(gather_tree, ops::GatherTreeOpCUDAKernel, + ops::GatherTreeOpCUDAKernel); diff --git a/paddle/fluid/operators/gather_tree_op.h b/paddle/fluid/operators/gather_tree_op.h new file mode 100644 index 0000000000000..742a7ffcaae4c --- /dev/null +++ b/paddle/fluid/operators/gather_tree_op.h @@ -0,0 +1,58 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class GatherTreeOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *ids = ctx.Input("Ids"); + auto *parents = ctx.Input("Parents"); + auto *out = ctx.Output("Out"); + + const auto *ids_data = ids->data(); + const auto *parents_data = parents->data(); + auto *out_data = out->mutable_data(ctx.GetPlace()); + + auto &ids_dims = ids->dims(); + auto max_length = ids_dims[0]; + auto batch_size = ids_dims[1]; + auto beam_size = ids_dims[2]; + + for (int batch = 0; batch < batch_size; batch++) { + for (int beam = 0; beam < beam_size; beam++) { + auto idx = (max_length - 1) * batch_size * beam_size + + batch * beam_size + beam; + out_data[idx] = ids_data[idx]; + auto parent = parents_data[idx]; + for (int step = max_length - 2; step >= 0; step--) { + idx = step * batch_size * beam_size + batch * beam_size; + out_data[idx + beam] = ids_data[idx + parent]; + parent = parents_data[idx + parent]; + } + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/reduce_ops/reduce_all_op.cc b/paddle/fluid/operators/reduce_ops/reduce_all_op.cc index a3ca9ae067547..49d6e72988ee0 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_all_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_all_op.cc @@ -14,7 +14,9 @@ #include "paddle/fluid/operators/reduce_ops/reduce_all_op.h" -REGISTER_REDUCE_OP_WITHOUT_GRAD(reduce_all); +// kernel's device type is decided by input tensor place, to be consistent with +// compare and logical ops +REGISTER_REDUCE_OP_WITHOUT_GRAD(reduce_all, UseInputPlace); REGISTER_OP_CPU_KERNEL(reduce_all, ops::ReduceKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_any_op.cc b/paddle/fluid/operators/reduce_ops/reduce_any_op.cc index 34f0fffc9adef..516d3183fd614 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_any_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_any_op.cc @@ -14,7 +14,9 @@ #include "paddle/fluid/operators/reduce_ops/reduce_any_op.h" -REGISTER_REDUCE_OP_WITHOUT_GRAD(reduce_any); +// kernel's device type is decided by input tensor place, to be consistent with +// compare and logical ops +REGISTER_REDUCE_OP_WITHOUT_GRAD(reduce_any, UseInputPlace); REGISTER_OP_CPU_KERNEL(reduce_any, ops::ReduceKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index cbc4adf95881a..abf274c08c1db 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -217,6 +217,19 @@ class ReduceOp : public framework::OperatorWithKernel { } }; +class ReduceOpUseInputPlace : public ReduceOp { + public: + using ReduceOp::ReduceOp; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx); + kt.place_ = ctx.Input("X")->place(); + return kt; + } +}; + class ReduceGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -307,11 +320,11 @@ namespace ops = paddle::operators; paddle::framework::DefaultGradOpDescMaker); \ REGISTER_OPERATOR(op_name##_grad, ops::ReduceGradOp) -#define REGISTER_REDUCE_OP_WITHOUT_GRAD(op_name) \ - class __##op_name##Maker__ : public ops::ReduceOpMaker { \ - protected: \ - virtual std::string GetName() const { return #op_name; } \ - virtual std::string GetOpType() const { return "Reduce " #op_name; } \ - }; \ - REGISTER_OPERATOR(op_name, ops::ReduceOp, __##op_name##Maker__, \ +#define REGISTER_REDUCE_OP_WITHOUT_GRAD(op_name, ...) \ + class __##op_name##Maker__ : public ops::ReduceOpMaker { \ + protected: \ + virtual std::string GetName() const { return #op_name; } \ + virtual std::string GetOpType() const { return "Reduce " #op_name; } \ + }; \ + REGISTER_OPERATOR(op_name, ops::ReduceOp##__VA_ARGS__, __##op_name##Maker__, \ paddle::framework::EmptyGradOpMaker); diff --git a/paddle/fluid/operators/tensor_array_to_tensor_op.cc b/paddle/fluid/operators/tensor_array_to_tensor_op.cc index 8cba4961153cf..82741077d9cf0 100644 --- a/paddle/fluid/operators/tensor_array_to_tensor_op.cc +++ b/paddle/fluid/operators/tensor_array_to_tensor_op.cc @@ -120,11 +120,18 @@ class LoDTensorArray2TensorOp : public framework::OperatorBase { out.Resize(out_dims); LodTensorArray2LodTensorVector(scope, base_name, Input("X"), &names); - // Invoke concat Op - auto concat_op = framework::OpRegistry::CreateOp( - "concat", {{"X", names}}, {{"Out", {Output("Out")}}}, attrs); - concat_op->Run(scope, place); + auto use_stack = Attr("use_stack"); + + // Invoke concat Op or stack Op + auto op = + use_stack + ? framework::OpRegistry::CreateOp("stack", {{"X", names}}, + {{"Y", {Output("Out")}}}, attrs) + : framework::OpRegistry::CreateOp( + "concat", {{"X", names}}, {{"Out", {Output("Out")}}}, attrs); + + op->Run(scope, place); } }; @@ -139,17 +146,32 @@ class LoDTensorArray2TensorOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("axis", "The axis along which the input tensors will be concatenated.") .SetDefault(0); + AddAttr("use_stack", + "Act as concat_op or stack_op. For stack mode, all tensors " + "in the tensor array must have the same shape.") + .SetDefault(false); AddComment(R"DOC( tensor_array_to_tensor Operator. -Concatenate the input LoDTensorArray along dimension axis to the output Tensor. +If use concat mode, concatenate all tensors in the input LoDTensorArray along +axis into the output Tensor. + +Examples: + Input = {[1,2], [3,4], [5,6]} + axis = 0 + Output = [1,2,3,4,5,6] + OutputIndex = [2,2,2] + +If use stack mode, stack all tensors in the input LoDTensorArray along axis into +the output Tensor. + Examples: Input = {[1,2], [3,4], [5,6]} axis = 0 Output = [[1,2], [3,4], [5,6]] - OutputIndex = [1,1,1] + OutputIndex = [2,2,2] )DOC"); } @@ -157,12 +179,34 @@ Concatenate the input LoDTensorArray along dimension axis to the output Tensor. class LoDTensorArray2TensorOpInferShape : public framework::InferShapeBase { public: - void operator()(framework::InferShapeContext *ctx) const override {} + void operator()(framework::InferShapeContext *ctx) const override { + // in runtime, shape is determined by RunImpl + if (ctx->IsRuntime()) return; + auto dims = ctx->GetInputDim("X"); + // if the shape is empty + if (dims == framework::make_ddim({0UL})) return; + // otherwise, suppose the shape of array is the shape of tensor in the + // array, which is consistent with what tensor_array_read_write dose + auto axis = ctx->Attrs().Get("axis"); + auto use_stack = ctx->Attrs().Get("use_stack"); + if (use_stack) { + auto dim_vec = framework::vectorize(dims); + // use -1 for the stack dim size + dim_vec.insert(dim_vec.begin() + axis, -1); + dims = framework::make_ddim(dim_vec); + } else { + // use -1 for the concat dim size + dims[axis] = -1; + } + ctx->SetOutputDim("Out", dims); + } }; class LoDTensorArray2TensorGradInferShape : public framework::InferShapeBase { public: - void operator()(framework::InferShapeContext *context) const override {} + void operator()(framework::InferShapeContext *ctx) const override { + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } }; class LoDTensorArray2TensorGradInferVarType @@ -204,11 +248,18 @@ class LoDTensorArray2TensorGradOp : public framework::OperatorBase { LodTensorVectorResizeFromLodTensorArray(scope, "grad_name", Input("X"), &grad_names); - auto concat_grad_op = framework::OpRegistry::CreateOp( - "concat_grad", {{"X", names}, {"Out@GRAD", {dout_name}}}, - {{"X@GRAD", grad_names}}, attrs); + auto use_stack = Attr("use_stack"); + + auto grad_op = + use_stack + ? framework::OpRegistry::CreateOp( + "stack_grad", {{"X", names}, {"Y@GRAD", {dout_name}}}, + {{"X@GRAD", grad_names}}, attrs) + : framework::OpRegistry::CreateOp( + "concat_grad", {{"X", names}, {"Out@GRAD", {dout_name}}}, + {{"X@GRAD", grad_names}}, attrs); - concat_grad_op->Run(scope, place); + grad_op->Run(scope, place); LodTensorArrayCreateFromLodTensorArray(scope, Input("X"), dx_name); auto &grad_inx = diff --git a/python/paddle/fluid/layers/__init__.py b/python/paddle/fluid/layers/__init__.py index d17636d6d54f9..a1e560168f9de 100644 --- a/python/paddle/fluid/layers/__init__.py +++ b/python/paddle/fluid/layers/__init__.py @@ -35,6 +35,7 @@ from .learning_rate_scheduler import * from .collective import * from .distributions import * +from . import rnn __all__ = [] __all__ += nn.__all__ @@ -47,3 +48,6 @@ __all__ += metric_op.__all__ __all__ += learning_rate_scheduler.__all__ __all__ += distributions.__all__ +__all__ += rnn.__all__ + +from .rnn import * diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 54d9934ee7cc6..7c879a19e9c3f 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -221,6 +221,7 @@ 'filter_by_instag', 'shard_index', 'hard_swish', + 'gather_tree', 'mse_loss', 'uniform_random', ] @@ -16862,6 +16863,81 @@ def hard_swish(x, threshold=6.0, scale=6.0, offset=3.0, name=None): return out +def gather_tree(ids, parents): + """ + To be used after beam search. After beam search, we get selected ids at + each time step and the corresponding parents in the search tree. Both ids + and parents have the layout :attr:`[max_time, batch_size, beam_size]`. Then + :attr:`gather_tree` is used to backtrace from the last time step and + generate the full sequences by collecting selected ids. + + Here is an example: + + .. code-block:: text + + Given: + ids = [[[2 2] + [6 1]] + [[3 9] + [6 1]] + [[0 1] + [9 0]]] + parents = [[[0 0] + [1 1]] + [[1 0] + [1 0]] + [[0 0] + [0 1]]] + + Then: + gather_tree(ids, parents) + = [[[2 2] + [1 6]] + [[3 3] + [6 1]] + [[0 1] + [9 0]]] + + Args: + ids(Variable): A Tensor with shape :attr:`[length, batch_size, beam_size]` + and data type :attr:`int32` or :attr:`int64`. It contains the selected + ids of all time steps. + parents(Variable): A Tensor with the same shape and data type as :attr:`ids`, + It contains the parents corresponding to selected ids when searching + among beams. + + Returns: + Variable: A Tensor with the same shape and data type as :attr:`ids`. \ + It contains the full sequences. The sequences are collected from \ + :attr:`ids` by backtracing according to :attr:`parents`. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + + ids = fluid.layers.data(name='ids', + shape=[5, 2, 2], + dtype='int64', + append_batch_size=False) + parents = fluid.layers.data(name='parents', + shape=[5, 2, 2], + dtype='int64', + append_batch_size=False) + final_sequences = fluid.layers.gather_tree(ids, parents) + """ + helper = LayerHelper('gather_tree', **locals()) + out = helper.create_variable_for_type_inference(dtype=ids.dtype) + + helper.append_op( + type="gather_tree", + inputs={"Ids": ids, + "Parents": parents}, + outputs={"Out": out}) + + return out + + def mse_loss(input, label): """ This op accepts input predications and target label and returns the mean square error. diff --git a/python/paddle/fluid/layers/rnn.py b/python/paddle/fluid/layers/rnn.py new file mode 100644 index 0000000000000..40f5df60e4a02 --- /dev/null +++ b/python/paddle/fluid/layers/rnn.py @@ -0,0 +1,1165 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +from functools import partial, reduce + +from . import nn +from . import tensor +from . import control_flow +from . import utils +from .utils import * + +__all__ = [ + 'RNNCell', + 'GRUCell', + 'LSTMCell', + 'Decoder', + 'BeamSearchDecoder', + 'rnn', + 'dynamic_decode', +] + + +class RNNCell(object): + """ + RNNCell is the base class for abstraction representing the calculations + mapping the input and state to the output and new state. It is suitable to + and mostly used in RNN. + """ + + def call(self, inputs, states, **kwargs): + """ + Every cell must implement this method to do the calculations mapping the + inputs and states to the output and new states. + + To be more flexible, both inputs and states can be a tensor variable or + a nested structure (list|tuple|namedtuple|dict) of tensor variable, that + is, a (possibly nested structure of) tensor variable[s]. + + Parameters: + inputs: A (possibly nested structure of) tensor variable[s]. + states: A (possibly nested structure of) tensor variable[s]. + **kwargs: Additional keyword arguments, provided by the caller. + + Returns: + tuple: outputs and new_states pair. outputs and new_states both \ + can be nested structure of tensor variables. new_states must \ + have the same structure with states. + + """ + raise NotImplementedError("RNNCell must implent the call function.") + + def __call__(self, inputs, states, **kwargs): + return self.call(inputs, states, **kwargs) + + def get_initial_states(self, + batch_ref, + shape=None, + dtype=None, + init_value=0): + """ + Generate initialized states according to provided shape, data type and + value. + + Parameters: + batch_ref: A (possibly nested structure of) tensor variable[s]. + The first dimension of the tensor will be used as batch size to + initialize states. + shape: A (possiblely nested structure of) shape[s], where a shape is + represented as a list/tuple of integer). -1(for batch size) will + beautomatically inserted if shape is not started with it. If None, + property `state_shape` will be used. The default value is None. + dtype: A (possiblely nested structure of) data type[s]. The structure + must be same as that of `shape`, except when all tensors' in states + has the same data type, a single data type can be used. If None and + property `cell.state_shape` is not available, float32 will be used + as the data type. The default value is None. + init_value: A float value used to initialize states. + + Returns: + Variable: tensor variable[s] packed in the same structure provided \ + by shape, representing the initialized states. + """ + # TODO: use inputs and batch_size + batch_ref = flatten(batch_ref)[0] + + def _is_shape_sequence(seq): + """For shape, list/tuple of integer is the finest-grained objection""" + if (isinstance(seq, list) or isinstance(seq, tuple)): + if reduce(lambda flag, x: isinstance(x, int) and flag, seq, + True): + return False + # TODO: Add check for the illegal + if isinstance(seq, dict): + return True + return (isinstance(seq, collections.Sequence) and + not isinstance(seq, six.string_types)) + + class Shape(object): + def __init__(self, shape): + self.shape = shape if shape[0] == -1 else ([-1] + list(shape)) + + # nested structure of shapes + states_shapes = self.state_shape if shape is None else shape + is_sequence_ori = utils.is_sequence + utils.is_sequence = _is_shape_sequence + states_shapes = map_structure(lambda shape: Shape(shape), states_shapes) + utils.is_sequence = is_sequence_ori + + # nested structure of dtypes + try: + states_dtypes = self.state_dtype if dtype is None else dtype + except NotImplementedError: # use fp32 as default + states_dtypes = "float32" + if len(flatten(states_dtypes)) == 1: + dtype = flatten(states_dtypes)[0] + states_dtypes = map_structure(lambda shape: dtype, states_shapes) + + init_states = map_structure( + lambda shape, dtype: tensor.fill_constant_batch_size_like( + input=batch_ref, + shape=shape.shape, + dtype=dtype, + value=init_value), states_shapes, states_dtypes) + return init_states + + @property + def state_shape(self): + """ + Used to initialize states. + A (possiblely nested structure of) shape[s], where a shape is represented + as a list/tuple of integers (-1 for batch size would be automatically + inserted into a shape if shape is not started with it). + Not necessary to be implemented if states are not initialized by + `get_initial_states` or the `shape` argument is provided when using + `get_initial_states`. + """ + raise NotImplementedError + + @property + def state_dtype(self): + """ + Used to initialize states. + A (possiblely nested structure of) data types[s]. The structure must be + same as that of `shape`, except when all tensors' in states has the same + data type, a signle data type can be used. + Not necessary to be implemented if states are not initialized + by `get_initial_states` or the `dtype` argument is provided when using + `get_initial_states`. + """ + raise NotImplementedError + + +class GRUCell(RNNCell): + """ + Gated Recurrent Unit cell. It is a wrapper for + `fluid.contrib.layers.rnn_impl.BasicGRUUnit` to make it adapt to RNNCell. + + The formula used is as follow: + + .. math:: + + u_t & = act_g(W_{ux}x_{t} + W_{uh}h_{t-1} + b_u) + + r_t & = act_g(W_{rx}x_{t} + W_{rh}h_{t-1} + b_r) + + \\tilde{h_t} & = act_c(W_{cx}x_{t} + W_{ch}(r_t \odot h_{t-1}) + b_c) + + h_t & = u_t \odot h_{t-1} + (1-u_t) \odot \\tilde{h_t} + + For more details, please refer to `Learning Phrase Representations using + RNN Encoder Decoder for Statistical Machine Translation `_ + + Examples: + + .. code-block:: python + + import paddle.fluid.layers as layers + cell = layers.GRUCell(hidden_size=256) + """ + + def __init__(self, + hidden_size, + param_attr=None, + bias_attr=None, + gate_activation=None, + activation=None, + dtype="float32", + name="GRUCell"): + """ + Constructor of GRUCell. + + Parameters: + hidden_size (int): The hidden size in the GRU cell. + param_attr(ParamAttr, optional): The parameter attribute for the learnable + weight matrix. Default: None. + bias_attr (ParamAttr, optional): The parameter attribute for the bias + of GRU. Default: None. + gate_activation (function, optional): The activation function for :math:`act_g`. + Default: `fluid.layers.sigmoid`. + activation (function, optional): The activation function for :math:`act_c`. + Default: `fluid.layers.tanh`. + dtype(string, optional): The data type used in this cell. Default float32. + name(string, optional) : The name scope used to identify parameters and biases. + """ + self.hidden_size = hidden_size + from .. import contrib # TODO: resolve recurrent import + self.gru_unit = contrib.layers.rnn_impl.BasicGRUUnit( + name, hidden_size, param_attr, bias_attr, gate_activation, + activation, dtype) + + def call(self, inputs, states): + """ + Perform calculations of GRU. + + Parameters: + inputs(Variable): A tensor with shape `[batch_size, input_size]`, + corresponding to :math:`x_t` in the formula. The data type + should be float32. + states(Variable): A tensor with shape `[batch_size, hidden_size]`. + corresponding to :math:`h_{t-1}` in the formula. The data type + should be float32. + + Returns: + tuple: A tuple( :code:`(outputs, new_states)` ), where `outputs` and \ + `new_states` is the same tensor shaped `[batch_size, hidden_size]`, \ + corresponding to :math:`h_t` in the formula. The data type of the \ + tensor is same as that of `states`. + """ + new_hidden = self.gru_unit(inputs, states) + return new_hidden, new_hidden + + @property + def state_shape(self): + """ + The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch + size would be automatically inserted into shape). The shape corresponds + to :math:`h_{t-1}`. + """ + return [self.hidden_size] + + +class LSTMCell(RNNCell): + """ + Long-Short Term Memory cell. It is a wrapper for + `fluid.contrib.layers.rnn_impl.BasicLSTMUnit` to make it adapt to RNNCell. + + The formula used is as follow: + + .. math:: + + i_{t} & = act_g(W_{x_{i}}x_{t} + W_{h_{i}}h_{t-1} + b_{i}) + + f_{t} & = act_g(W_{x_{f}}x_{t} + W_{h_{f}}h_{t-1} + b_{f} + forget\\_bias) + + c_{t} & = f_{t}c_{t-1} + i_{t} act_c (W_{x_{c}}x_{t} + W_{h_{c}}h_{t-1} + b_{c}) + + o_{t} & = act_g(W_{x_{o}}x_{t} + W_{h_{o}}h_{t-1} + b_{o}) + + h_{t} & = o_{t} act_c (c_{t}) + + For more details, please refer to `RECURRENT NEURAL NETWORK REGULARIZATION `_ + + Examples: + + .. code-block:: python + + import paddle.fluid.layers as layers + cell = layers.LSTMCell(hidden_size=256) + """ + + def __init__(self, + hidden_size, + param_attr=None, + bias_attr=None, + gate_activation=None, + activation=None, + forget_bias=1.0, + dtype="float32", + name="LSTMCell"): + """ + Constructor of LSTMCell. + + Parameters: + hidden_size (int): The hidden size in the LSTM cell. + param_attr(ParamAttr, optional): The parameter attribute for the learnable + weight matrix. Default: None. + bias_attr (ParamAttr, optional): The parameter attribute for the bias + of LSTM. Default: None. + gate_activation (function, optional): The activation function for :math:`act_g`. + Default: 'fluid.layers.sigmoid'. + activation (function, optional): The activation function for :math:`act_h`. + Default: 'fluid.layers.tanh'. + forget_bias(float, optional): forget bias used when computing forget gate. + Default 1.0 + dtype(string, optional): The data type used in this cell. Default float32. + name(string, optional) : The name scope used to identify parameters and biases. + """ + self.hidden_size = hidden_size + from .. import contrib # TODO: resolve recurrent import + self.lstm_unit = contrib.layers.rnn_impl.BasicLSTMUnit( + name, hidden_size, param_attr, bias_attr, gate_activation, + activation, forget_bias, dtype) + + def call(self, inputs, states): + """ + Perform calculations of LSTM. + + Parameters: + inputs(Variable): A tensor with shape `[batch_size, input_size]`, + corresponding to :math:`x_t` in the formula. The data type + should be float32. + states(Variable): A list of containing two tensers, each shaped + `[batch_size, hidden_size]`, corresponding to :math:`h_{t-1}, c_{t-1}` + in the formula. The data type should be float32. + + Returns: + tuple: A tuple( :code:`(outputs, new_states)` ), where `outputs` is \ + a tensor with shape `[batch_size, hidden_size]`, corresponding \ + to :math:`h_{t}` in the formula; `new_states` is a list containing \ + two tenser variables shaped `[batch_size, hidden_size]`, corresponding \ + to :math:`h_{t}, c_{t}` in the formula. The data type of these \ + tensors all is same as that of `states`. + """ + pre_hidden, pre_cell = states + new_hidden, new_cell = self.lstm_unit(inputs, pre_hidden, pre_cell) + return new_hidden, [new_hidden, new_cell] + + @property + def state_shape(self): + """ + The `state_shape` of LSTMCell is a list with two shapes: `[[hidden_size], [hidden_size]]` + (-1 for batch size would be automatically inserted into shape). These two + shapes correspond to :math:`h_{t-1}` and :math:`c_{t-1}` separately. + """ + return [[self.hidden_size], [self.hidden_size]] + + +def rnn(cell, + inputs, + initial_states=None, + sequence_length=None, + time_major=False, + is_reverse=False, + **kwargs): + """ + rnn creates a recurrent neural network specified by RNNCell `cell`, + which performs :code:`cell.call()` repeatedly until reachs to the maximum + length of `inputs`. + + Parameters: + cell(RNNCell): An instance of `RNNCell`. + inputs(Variable): A (possibly nested structure of) tensor variable[s]. + The shape of tensor should be `[batch_size, sequence_length, ...]` + for `time_major == False` or `[sequence_length, batch_size, ...]` + for `time_major == True`. It represents the inputs to be unrolled + in RNN. + initial_states(Variable, optional): A (possibly nested structure of) + tensor variable[s], representing the initial state for RNN. + If not provided, `cell.get_initial_states` would be used to produce + the initial state. Default None. + sequence_length(Variable, optional): A tensor with shape `[batch_size]`. + It stores real length of each instance, thus enables users to extract + the last valid state when past a batch element's sequence length for + correctness. If not provided, the padddings would be treated same as + non-padding inputs. Default None. + time_major(bool, optional): Indicate the data layout of Tensor included + in `input` and `output` tensors. If `False`, the data layout would + be batch major with shape `[batch_size, sequence_length, ...]`. If + `True`, the data layout would be time major with shape + `[sequence_length, batch_size, ...]`. Default: `False`. + is_reverse(bool, optional): Indicate whether to calculate in the reverse + order of input sequences. Default: `False`. + **kwargs: Additional keyword arguments. Arguments passed to `cell.call`. + + Returns: + tuple: A tuple( :code:`(final_outputs, final_states)` ) including the final \ + outputs and states, both are Tensor or nested structure of Tensor. \ + `final_outputs` has the same structure and data types as \ + the returned `outputs` of :code:`cell.call` , and each Tenser in `final_outputs` \ + stacks all time steps' counterpart in `outputs` thus has shape `[batch_size, sequence_length, ...]` \ + for `time_major == False` or `[sequence_length, batch_size, ...]` for `time_major == True`. \ + `final_states` is the counterpart at last time step of initial states, \ + thus has the same structure with it and has tensors with same shapes \ + and data types. + + + Examples: + + .. code-block:: python + + import paddle.fluid as fluid + + inputs = fluid.data(name="inputs", + shape=[-1, 32, 128], + dtype="float32") + cell = fluid.layers.GRUCell(hidden_size=128) + outputs = fluid.layers.rnn(cell=cell, inputs=inputs) + """ + + def _maybe_copy(state, new_state, step_mask): + # TODO: use where_op + new_state = nn.elementwise_mul( + new_state, step_mask, axis=0) - nn.elementwise_mul( + state, (step_mask - 1), axis=0) + return new_state + + def _transpose_batch_time(x): + return nn.transpose(x, [1, 0] + list(range(2, len(x.shape)))) + + def _switch_grad(x, stop=False): + x.stop_gradient = stop + return x + + if initial_states is None: + initial_states = cell.get_initial_states(batch_ref=inputs) + initial_states = map_structure(_switch_grad, initial_states) + + if not time_major: + inputs = map_structure(_transpose_batch_time, inputs) + + if sequence_length: + max_seq_len = nn.shape(flatten(inputs)[0])[0] + mask = nn.sequence_mask( + sequence_length, + maxlen=max_seq_len, + dtype=flatten(initial_states)[0].dtype) + mask = nn.transpose(mask, [1, 0]) + if is_reverse: + inputs = map_structure(lambda x: tensor.reverse(x, axis=[0]), inputs) + mask = tensor.reverse(mask, axis=[0]) if sequence_length else None + + # StaticRNN + rnn = control_flow.StaticRNN() + with rnn.step(): + inputs = map_structure(rnn.step_input, inputs) + states = map_structure(rnn.memory, initial_states) + copy_states = map_structure(lambda x: x, states) + outputs, new_states = cell.call(inputs, copy_states, **kwargs) + assert_same_structure(states, new_states) + if sequence_length: + step_mask = rnn.step_input(mask) + new_states = map_structure( + partial( + _maybe_copy, step_mask=step_mask), states, new_states) + + map_structure(rnn.update_memory, states, new_states) + flat_outputs = flatten(outputs) + map_structure(rnn.step_output, outputs) + map_structure(rnn.step_output, new_states) + + rnn_out = rnn() + final_outputs = rnn_out[:len(flat_outputs)] + final_outputs = pack_sequence_as(outputs, final_outputs) + final_states = map_structure(lambda x: x[-1], rnn_out[len(flat_outputs):]) + final_states = pack_sequence_as(new_states, final_states) + + if is_reverse: + final_outputs = map_structure(lambda x: tensor.reverse(x, axis=[0]), + final_outputs) + + if not time_major: + final_outputs = map_structure(_transpose_batch_time, final_outputs) + + return (final_outputs, final_states) + + +class Decoder(object): + """ + Decoder is the base class for any decoder instance used in `dynamic_decode`. + It provides interface for output generation for one time step, which can be + used to generate sequences. + + The key abstraction provided by Decoder is: + + 1. :code:`(initial_input, initial_state, finished) = initialize(inits)` , + which generates the input and state for the first decoding step, and gives the + inintial status telling whether each sequence in the batch is finished. + It would be called once before the decoding iterations. + + 2. :code:`(output, next_state, next_input, finished) = step(time, input, state)` , + which transforms the input and state to the output and new state, generates + input for the next decoding step, and emits the flag indicating finished status. + It is the main part for each decoding iteration. + + 3. :code:`(final_outputs, final_state) = finalize(outputs, final_state, sequence_lengths)` , + which revises the outputs(stack of all time steps' output) and final state(state from the + last decoding step) to get the counterpart for special usage. + Not necessary to be implemented if no need to revise the stacked outputs and + state from the last decoding step. If implemented, it would be called after + the decoding iterations. + + Decoder is more general compared to RNNCell, since the returned `next_input` + and `finished` make it can determine the input and when to finish by itself + when used in dynamic decoding. Decoder always wraps a RNNCell instance though + not necessary. + """ + + def initialize(self, inits): + """ + Called once before the decoding iterations. + + Parameters: + inits: Argument provided by the caller. + + Returns: + tuple: A tuple( :code:(initial_inputs, initial_states, finished)` ). \ + `initial_inputs` and `initial_states` both are a (possibly nested \ + structure of) tensor variable[s], and `finished` is a tensor with \ + bool data type. + """ + raise NotImplementedError + + def step(self, time, inputs, states): + """ + Called per step of decoding. + + Parameters: + time(Variable): A Tensor with shape :math:`[1]` provided by the caller. + The data type is int64. + inputs(Variable): A (possibly nested structure of) tensor variable[s]. + states(Variable): A (possibly nested structure of) tensor variable[s]. + + Returns: + tuple: A tuple( :code:(outputs, next_states, next_inputs, finished)` ). \ + `next_inputs` and `next_states` both are a (possibly nested \ + structure of) tensor variable[s], and the structure, shape and \ + data type must be same as the counterpart from input arguments. \ + `outputs` is a (possibly nested structure of) tensor variable[s]. \ + `finished` is a Tensor with bool data type. + """ + raise NotImplementedError + + @property + def output_dtype(self): + """ + A (possiblely nested structure of) data type[s]. The structure must be + same as `outputs` returned by `decoder.step`. + """ + raise NotImplementedError + + def finalize(self, outputs, final_states, sequence_lengths): + """ + Called once after the decoding iterations if implemented. + + Parameters: + outputs(Variable): A (possibly nested structure of) tensor variable[s]. + The structure and data type is same as `output_dtype`. + The tensor stacks all time steps' output thus has shape + :math:`[time\_step, batch\_size, ...]` , which is done by the caller. + final_states(Variable): A (possibly nested structure of) tensor variable[s]. + It is the `next_states` returned by `decoder.step` at last decoding step, + thus has the same structrue, shape and data type with states at any time + step. + + Returns: + tuple: A tuple( :code:`(final_outputs, final_states)` ). \ + `final_outputs` and `final_states` both are a (possibly nested \ + structure of) tensor variable[s]. + """ + raise NotImplementedError + + +class BeamSearchDecoder(Decoder): + """ + Decoder with beam search decoding strategy. It wraps a cell to get probabilities, + and follows a beam search step to calculate scores and select candidate + token ids for each decoding step. + + Please refer to `Beam search `_ + for more details. + + **NOTE** When decoding with beam search, the `inputs` and `states` of cell + would be tiled to `beam_size` (unsqueeze and tile), resulting to shapes like + `[batch_size * beam_size, ...]` , which is built into `BeamSearchDecoder` and + done automatically. Thus any other tensor with shape `[batch_size, ...]` used + in `cell.call` needs to be tiled manually first, which can be completed by using + :code:`BeamSearchDecoder.tile_beam_merge_with_batch` . The most common case + for this is the encoder output in attention mechanism. + + + Examples: + + .. code-block:: python + + import paddle.fluid as fluid + from paddle.fluid.layers import GRUCell, BeamSearchDecoder + + trg_embeder = lambda x: fluid.embedding( + x, size=[10000, 128], param_attr=fluid.ParamAttr(name="trg_embedding")) + output_layer = lambda x: layers.fc(x, + size=10000, + num_flatten_dims=len(x.shape) - 1, + param_attr=fluid.ParamAttr(name= + "output_w"), + bias_attr=False) + decoder_cell = GRUCell(hidden_size=128) + decoder = BeamSearchDecoder(decoder_cell, + start_token=0, + end_token=1, + beam_size=4, + embedding_fn=trg_embeder, + output_fn=output_layer) + """ + + def __init__(self, + cell, + start_token, + end_token, + beam_size, + embedding_fn=None, + output_fn=None): + """ + Constructor of BeamSearchDecoder. + + Parameters: + cell(RNNCell): An instance of `RNNCell` or object with the same interface. + start_token(int): The start token id. + end_token(int): The end token id. + beam_size(int): The beam width used in beam search. + embedding_fn(optional): A callable to apply to selected candidate ids. + Mostly it is an embedding layer to transform ids to embeddings, + and the returned value acts as the `input` argument for `cell.call`. + **Note that fluid.embedding should be used here rather than + fluid.layers.embedding, since shape of ids is [batch_size, beam_size]. + when using fluid.layers.embedding, must unsqueeze in embedding_fn.** + If not provided, the id to embedding transfomation must be built into + `cell.call`. Default None. + output_fn(optional): A callable to apply to the cell's output prior to + calculate scores and select candidate token ids. Default None. + """ + self.cell = cell + self.embedding_fn = embedding_fn + self.output_fn = output_fn + self.start_token = start_token + self.end_token = end_token + self.beam_size = beam_size + + @staticmethod + def tile_beam_merge_with_batch(x, beam_size): + """ + Tile the batch dimension of a tensor. Specifically, this function takes + a tensor t shaped `[batch_size, s0, s1, ...]` composed of minibatch + entries `t[0], ..., t[batch_size - 1]` and tiles it to have a shape + `[batch_size * beam_size, s0, s1, ...]` composed of minibatch entries + `t[0], t[0], ..., t[1], t[1], ...` where each minibatch entry is repeated + `beam_size` times. + + Parameters: + x(Variable): A tenosr with shape `[batch_size, ...]`. The data type + should be float32, float64, int32, int64 or bool. + beam_size(int): The beam width used in beam search. + + Returns: + Variable: A tensor with shape `[batch_size * beam_size, ...]`, whose \ + data type is same as `x`. + """ + x = nn.unsqueeze(x, [1]) # [batch_size, 1, ...] + expand_times = [1] * len(x.shape) + expand_times[1] = beam_size + x = nn.expand(x, expand_times) # [batch_size, beam_size, ...] + x = nn.transpose(x, list(range(2, len(x.shape))) + + [0, 1]) # [..., batch_size, beam_size] + # use 0 to copy to avoid wrong shape + x = nn.reshape( + x, shape=[0] * + (len(x.shape) - 2) + [-1]) # [..., batch_size * beam_size] + x = nn.transpose( + x, [len(x.shape) - 1] + + list(range(0, len(x.shape) - 1))) # [batch_size * beam_size, ...] + return x + + def _split_batch_beams(self, x): + """ + Reshape a tensor with shape `[batch_size * beam_size, ...]` to a new + tensor with shape `[batch_size, beam_size, ...]`. + + Parameters: + x(Variable): A tenosr with shape `[batch_size * beam_size, ...]`. The + data type should be float32, float64, int32, int64 or bool. + + Returns: + Variable: A tensor with shape `[batch_size, beam_size, ...]`, whose \ + data type is same as `x`. + """ + # TODO: avoid fake shape in compile-time like tile_beam_merge_with_batch + return nn.reshape(x, shape=(-1, self.beam_size) + x.shape[1:]) + + def _merge_batch_beams(self, x): + """ + Reshape a tensor with shape `[batch_size, beam_size, ...]` to a new + tensor with shape `[batch_size * beam_size, ...]`. + + Parameters: + x(Variable): A tenosr with shape `[batch_size, beam_size, ...]`. The + data type should be float32, float64, int32, int64 or bool. + + Returns: + Variable: A tensor with shape `[batch_size * beam_size, ...]`, whose \ + data type is same as `x`. + """ + # TODO: avoid fake shape in compile-time like tile_beam_merge_with_batch + return nn.reshape(x, shape=(-1, ) + x.shape[2:]) + + def _expand_to_beam_size(self, x): + """ + This function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed + of minibatch entries `t[0], ..., t[batch_size - 1]` and tiles it to have a + shape `[batch_size, beam_size, s0, s1, ...]` composed of minibatch entries + `t[0], t[0], ..., t[1], t[1], ...` where each minibatch entry is repeated + `beam_size` times. + + Parameters: + probs(Variable): A tensor with shape `[batch_size, ...]`, representing + the log probabilities. Its data type should be float32. + finished(Variable): A tensor with shape `[batch_size, beam_size]`, + representing the finished status for all beams. Its data type + should be bool. + + Returns: + Variable: A tensor with shape `[batch_size, beam_size, ...]`, whose \ + data type is same as `x`. + """ + x = nn.unsqueeze(x, [1]) + expand_times = [1] * len(x.shape) + expand_times[1] = self.beam_size + x = nn.expand(x, expand_times) + return x + + def _mask_probs(self, probs, finished): + """ + Mask log probabilities. It forces finished beams to allocate all probability + mass to eos and unfinished beams to remain unchanged. + + Parameters: + probs(Variable): A tensor with shape `[batch_size, beam_size, vocab_size]`, + representing the log probabilities. Its data type should be float32. + finished(Variable): A tensor with shape `[batch_size, beam_size]`, + representing the finished status for all beams. Its data type + should be bool. + + Returns: + Variable: A tensor with the same shape and data type as `x`, \ + where unfinished beams stay unchanged and finished beams are \ + replaced with a tensor with all probability on the EOS token. + """ + # TODO: use where_op + finished = tensor.cast(finished, dtype=probs.dtype) + probs = nn.elementwise_mul( + nn.expand(nn.unsqueeze(finished, [2]), [1, 1, self.vocab_size]), + self.noend_mask_tensor, + axis=-1) - nn.elementwise_mul( + probs, (finished - 1), axis=0) + return probs + + def _gather(self, x, indices, batch_size): + """ + Gather from the tensor `x` using `indices`. + + Parameters: + x(Variable): A tensor with shape `[batch_size, beam_size, ...]`. + indices(Variable): A `int64` tensor with shape `[batch_size, beam_size]`, + representing the indices that we use to gather. + batch_size(Variable): A tensor with shape `[1]`. Its data type should + be int32 or int64. + + Returns: + Variable: A tensor with the same shape and data type as `x`, \ + representing the gathered tensor. + """ + # TODO: compatibility of int32 and int64 + batch_size = tensor.cast( + batch_size, + indices.dtype) if batch_size.dtype != indices.dtype else batch_size + batch_pos = nn.expand( + nn.unsqueeze( + tensor.range( + 0, batch_size, 1, dtype=indices.dtype), [1]), + [1, self.beam_size]) + topk_coordinates = nn.stack([batch_pos, indices], axis=2) + return nn.gather_nd(x, topk_coordinates) + + class OutputWrapper( + collections.namedtuple("OutputWrapper", + ("scores", "predicted_ids", "parent_ids"))): + """ + The structure for the returned value `outputs` of `decoder.step`. + A namedtuple includes scores, predicted_ids, parent_ids as fields. + """ + pass + + class StateWrapper( + collections.namedtuple( + "StateWrapper", + ("cell_states", "log_probs", "finished", "lengths"))): + """ + The structure for the argument `states` of `decoder.step`. + A namedtuple includes cell_states, log_probs, finished, lengths as fields. + """ + pass + + def initialize(self, initial_cell_states): + """ + Initialize the BeamSearchDecoder. + + Parameters: + initial_cell_states(Variable): A (possibly nested structure of) + tensor variable[s]. An argument provided by the caller. + + Returns: + tuple: A tuple( :code:`(initial_inputs, initial_states, finished)` ). \ + `initial_inputs` is a tensor t filled by `start_token` with shape \ + `[batch_size, beam_size, 1]` when `embedding_fn` is None, or the \ + returned value of `embedding_fn(t)` when `embedding_fn` is provided. \ + `initial_states` is a nested structure(namedtuple including cell_states, \ + log_probs, finished, lengths as fields) of tensor variables, where \ + `log_probs, finished, lengths` all has a tensor value shaped \ + `[batch_size, beam_size]` with data type `float32, bool, int64`. \ + cell_states has a value with the same structure as the input \ + argument `initial_cell_states` but with tiled shape `[batch_size, beam_size, ...]`. \ + `finished` is a `bool` tensor filled by False with shape `[batch_size, beam_size]`. + """ + self.kinf = 1e9 + state = flatten(initial_cell_states)[0] + self.batch_size = nn.shape(state)[0] + + self.start_token_tensor = tensor.fill_constant( + shape=[1], dtype="int64", value=self.start_token) + self.end_token_tensor = tensor.fill_constant( + shape=[1], dtype="int64", value=self.end_token) + + init_cell_states = map_structure(self._expand_to_beam_size, + initial_cell_states) + # TODO: use fill_constant when support variable shape + init_inputs = nn.expand( + nn.unsqueeze( + nn.expand(self.start_token_tensor, [self.batch_size]), [1]), + [1, self.beam_size]) + log_probs = nn.expand( + tensor.assign( + np.array( + [[0.] + [-self.kinf] * (self.beam_size - 1)], + dtype="float32")), [self.batch_size, 1]) + # TODO: remove the restriction of force_cpu + init_finished = tensor.fill_constant_batch_size_like( + input=state, + shape=[-1, self.beam_size], + dtype="bool", + value=False, + force_cpu=True) + init_lengths = tensor.zeros_like(init_inputs) + init_inputs = self.embedding_fn( + init_inputs) if self.embedding_fn else init_inputs + return init_inputs, self.StateWrapper(init_cell_states, log_probs, + init_finished, + init_lengths), init_finished + + def _beam_search_step(self, time, logits, next_cell_states, beam_state): + """ + Calculate scores and select candidate token ids. + + Parameters: + time(Variable): An `int64` tensor with shape `[1]` provided by the caller, + representing the current time step number of decoding. + logits(Variable): A tensor with shape `[batch_size, beam_size, vocab_size]`, + representing the logits at the current time step. Its data type is float32. + next_cell_states(Variable): A (possibly nested structure of) tensor variable[s]. + It has the same structure, shape and data type as the `cell_states` of + `initial_states` returned by `initialize()`. It represents the next state + from the cell. + beam_state(Variable): A structure of tensor variables. + It is same as the `initial_states` returned by `initialize()` for + the first decoding step and `beam_search_state` returned by + `initialize()` for the others. + + Returns: + tuple: A tuple( :code:`(beam_search_output, beam_search_state)` ). \ + `beam_search_output` is a namedtuple(including scores, predicted_ids, \ + parent_ids as fields) of tensor variables, where \ + `scores, predicted_ids, parent_ids` all has a tensor value shaped \ + `[batch_size, beam_size]` with data type `float32, int64, int64`. + `beam_search_state` has the same structure, shape and data type \ + as the input argument `beam_state`. + + """ + self.vocab_size = logits.shape[-1] + self.vocab_size_tensor = tensor.fill_constant( + shape=[1], dtype="int64", value=self.vocab_size) + noend_array = [-self.kinf] * self.vocab_size + noend_array[self.end_token] = 0 + self.noend_mask_tensor = tensor.assign(np.array(noend_array, "float32")) + + step_log_probs = nn.log(nn.softmax(logits)) + step_log_probs = self._mask_probs(step_log_probs, beam_state.finished) + log_probs = nn.elementwise_add( + x=step_log_probs, y=beam_state.log_probs, axis=0) + # TODO: length penalty + scores = log_probs + scores = nn.reshape(scores, [-1, self.beam_size * self.vocab_size]) + topk_scores, topk_indices = nn.topk(input=scores, k=self.beam_size) + beam_indices = nn.elementwise_floordiv(topk_indices, + self.vocab_size_tensor) + token_indices = nn.elementwise_mod(topk_indices, self.vocab_size_tensor) + next_log_probs = self._gather( + nn.reshape(log_probs, [-1, self.beam_size * self.vocab_size]), + topk_indices, self.batch_size) + next_cell_states = map_structure( + lambda x: self._gather(x, beam_indices, self.batch_size), + next_cell_states) + next_finished = self._gather(beam_state.finished, beam_indices, + self.batch_size) + next_lengths = self._gather(beam_state.lengths, beam_indices, + self.batch_size) + next_lengths = next_lengths + tensor.cast( + nn.logical_not(next_finished), beam_state.lengths.dtype) + next_finished = control_flow.logical_or( + next_finished, + control_flow.equal(token_indices, self.end_token_tensor)) + + beam_search_output = self.OutputWrapper(topk_scores, token_indices, + beam_indices) + beam_search_state = self.StateWrapper(next_cell_states, next_log_probs, + next_finished, next_lengths) + return beam_search_output, beam_search_state + + def step(self, time, inputs, states, **kwargs): + """ + Perform a beam search decoding step, which uses `cell` to get probabilities, + and follows a beam search step to calculate scores and select candidate + token ids. + + Parameters: + time(Variable): An `int64` tensor with shape `[1]` provided by the caller, + representing the current time step number of decoding. + inputs(Variable): A tensor variable. It is same as `initial_inputs` + returned by `initialize()` for the first decoding step and + `next_inputs` returned by `step()` for the others. + states(Variable): A structure of tensor variables. + It is same as the `initial_states` returned by `initialize()` for + the first decoding step and `beam_search_state` returned by + `step()` for the others. + **kwargs: Additional keyword arguments, provided by the caller. + + Returns: + tuple: A tuple( :code:`(beam_search_output, beam_search_state, next_inputs, finished)` ). \ + `beam_search_state` and `next_inputs` have the same structure, \ + shape and data type as the input arguments `states` and `inputs` separately. \ + `beam_search_output` is a namedtuple(including scores, predicted_ids, \ + parent_ids as fields) of tensor variables, where \ + `scores, predicted_ids, parent_ids` all has a tensor value shaped \ + `[batch_size, beam_size]` with data type `float32, int64, int64`. \ + `finished` is a `bool` tensor with shape `[batch_size, beam_size]`. + """ + inputs = map_structure(self._merge_batch_beams, inputs) + cell_states = map_structure(self._merge_batch_beams, states.cell_states) + cell_outputs, next_cell_states = self.cell(inputs, cell_states, + **kwargs) + cell_outputs = map_structure(self._split_batch_beams, cell_outputs) + next_cell_states = map_structure(self._split_batch_beams, + next_cell_states) + + if self.output_fn is not None: + cell_outputs = self.output_fn(cell_outputs) + + beam_search_output, beam_search_state = self._beam_search_step( + time=time, + logits=cell_outputs, + next_cell_states=next_cell_states, + beam_state=states) + finished = beam_search_state.finished + sample_ids = beam_search_output.predicted_ids + next_inputs = self.embedding_fn( + sample_ids) if self.embedding_fn else sample_ids + + return (beam_search_output, beam_search_state, next_inputs, finished) + + def finalize(self, outputs, final_states, sequence_lengths): + """ + Use `gather_tree` to backtrace along the beam search tree and construct + the full predicted sequences. + + Parameters: + outputs(Variable): A structure(namedtuple) of tensor variables, + The structure and data type is same as `output_dtype`. + The tensor stacks all time steps' output thus has shape + `[time_step, batch_size, ...]`, which is done by the caller. + final_states(Variable): A structure(namedtuple) of tensor variables. + It is the `next_states` returned by `decoder.step` at last + decoding step, thus has the same structrue, shape and data type + with states at any time step. + sequence_lengths(Variable): An `int64` tensor shaped `[batch_size, beam_size]`. + It contains sequence lengths for each beam determined during + decoding. + + Returns: + tuple: A tuple( :code:`(predicted_ids, final_states)` ). \ + `predicted_ids` is an `int64` tensor shaped \ + `[time_step, batch_size, beam_size]`. `final_states` is the same \ + as the input argument `final_states`. + """ + predicted_ids = nn.gather_tree(outputs.predicted_ids, + outputs.parent_ids) + # TODO: use FinalBeamSearchDecoderOutput as output + return predicted_ids, final_states + + @property + def output_dtype(self): + """ + The nested structure of data types for beam search output. It is a namedtuple + including scores, predicted_ids, parent_ids as fields. + """ + return self.OutputWrapper( + scores="float32", predicted_ids="int64", parent_ids="int64") + + +def dynamic_decode(decoder, + inits=None, + max_step_num=None, + output_time_major=False, + **kwargs): + """ + Dynamic decoding performs :code:`decoder.step()` repeatedly until the returned + Tensor indicating finished status contains all True values or the number of + decoding step reachs to :attr:`max_step_num`. + + :code:`decoder.initialize()` would be called once before the decoding loop. + If the `decoder` has implemented `finalize` method, :code:`decoder.finalize()` + would be called once after the decoding loop. + + Parameters: + decoder(Decoder): An instance of `Decoder`. + inits(object, optional): Argument passed to `decoder.initialize`. + Default `None`. + max_step_num(int, optional): The maximum number of steps. If not provided, + decode until the decoder is fully done, or in other words, the returned + Tensor by :code:`decoder.step()` indicating finished status contains + all True). Default `None`. + output_time_major(bool, optional): Indicate the data layout of Tensor included + in the final outpus(the first returned value of this method). If + attr:`False`, the data layout would be batch major with shape + `[batch_size, seq_len, ...]`. If attr:`True`, the data layout would + be time major with shape `[seq_len, batch_size, ...]`. Default: `False`. + **kwargs: Additional keyword arguments. Arguments passed to `decoder.step`. + + Returns: + tuple: A tuple( :code:`(final_outputs, final_states)` ) including the final \ + outputs and states, both are Tensor or nested structure of Tensor. \ + `final_outputs` has the same structure and data types as \ + :code:`decoder.output_dtype` , and each Tenser in `final_outputs` \ + is the stacked of all decoding steps' outputs, which might be revised \ + by :code:`decoder.finalize` . `final_states` is the counterpart \ + at last time step of initial states returned by :code:`decoder.initialize` , \ + thus has the same structure with it and has tensors with same shapes \ + and data types. + + + Examples: + + .. code-block:: python + + import paddle.fluid as fluid + import paddle.fluid.layers as layers + from paddle.fluid.layers import GRUCell, BeamSearchDecoder, dynamic_decode + + encoder_output = fluid.data(name="encoder_output", + shape=[-1, 32, 128], + dtype="float32") + trg_embeder = lambda x: fluid.embedding( + x, size=[10000, 128], param_attr=fluid.ParamAttr(name="trg_embedding")) + output_layer = lambda x: layers.fc(x, + size=10000, + num_flatten_dims=len(x.shape) - 1, + param_attr=fluid.ParamAttr(name= + "output_w"), + bias_attr=False) + decoder_cell = GRUCell(hidden_size=128) + decoder = BeamSearchDecoder(decoder_cell, + start_token=0, + end_token=1, + beam_size=4, + embedding_fn=trg_embeder, + output_fn=output_layer) + + outputs = dynamic_decode( + decoder=decoder, inits=decoder_cell.get_initial_states(encoder_output)) + """ + initial_inputs, initial_states, initial_finished = decoder.initialize(inits) + global_inputs, global_states, global_finished = ( + initial_inputs, initial_states, initial_finished) + + step_idx = tensor.fill_constant(shape=[1], dtype="int64", value=0) + cond = control_flow.logical_not((nn.reduce_all(initial_finished))) + if max_step_num is not None: + max_step_num = tensor.fill_constant( + shape=[1], dtype="int64", value=max_step_num) + while_op = control_flow.While(cond) + + inputs = map_structure(lambda x: x, initial_inputs) + states = map_structure(lambda x: x, initial_states) + outputs_arrays = map_structure( + lambda dtype: control_flow.create_array(dtype), decoder.output_dtype) + sequence_lengths = tensor.cast(tensor.zeros_like(initial_finished), "int64") + + def _maybe_copy(state, new_state, step_mask): + # TODO: use where_op + new_state = nn.elementwise_mul( + new_state, step_mask, axis=0) - nn.elementwise_mul( + state, (step_mask - 1), axis=0) + return new_state + + def _transpose_batch_time(x): + return nn.transpose(x, [1, 0] + list(range(2, len(x.shape)))) + + # While + with while_op.block(): + (outputs, next_states, next_inputs, + next_finished) = decoder.step(step_idx, inputs, states, **kwargs) + next_sequence_lengths = nn.elementwise_add( + sequence_lengths, + tensor.cast( + control_flow.logical_not(global_finished), + sequence_lengths.dtype)) + + map_structure( + lambda x, x_array: control_flow.array_write( + x, i=step_idx, array=x_array), outputs, outputs_arrays) + control_flow.increment(x=step_idx, value=1.0, in_place=True) + map_structure(tensor.assign, next_inputs, global_inputs) + map_structure(tensor.assign, next_states, global_states) + tensor.assign(next_finished, global_finished) + tensor.assign(next_sequence_lengths, sequence_lengths) + if max_step_num is not None: + control_flow.logical_and( + control_flow.logical_not(nn.reduce_all(next_finished)), + control_flow.less_equal(step_idx, max_step_num), cond) + else: + control_flow.logical_not(nn.reduce_all(next_finished), cond) + + final_outputs = map_structure( + lambda array: tensor.tensor_array_to_tensor( + array, axis=0, use_stack=True)[0], outputs_arrays) + final_states = global_states + + try: + final_outputs, final_states = decoder.finalize( + final_outputs, global_states, sequence_lengths) + except NotImplementedError: + pass + + if not output_time_major: + final_outputs = map_structure(_transpose_batch_time, final_outputs) + + return final_outputs, final_states diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 4523dfd9589e8..3de8abab47bab 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -256,50 +256,85 @@ def concat(input, axis=0, name=None): return out -def tensor_array_to_tensor(input, axis=1, name=None): +def tensor_array_to_tensor(input, axis=1, name=None, use_stack=False): """ - This OP concatenates the input LodTensorArray along the axis. + This function concatenates or stacks all tensors in the input LoDTensorArray + along the axis mentioned and returns that as the output. + + For Example: + + .. code-block:: text + + Case 1: + + Given: + + input.data = {[[0.6, 0.1, 0.3], + [0.5, 0.3, 0.2]], + [[1.3], + [1.8]], + [[2.3, 2.1], + [2.5, 2.4]]} + + axis = 1, use_stack = False + + Then: + + output.data = [[0.6, 0.1, 0.3, 1.3, 2.3, 2.1], + [0.5, 0.3, 0.2, 1.8, 2.5, 2.4]] + + output_index.data = [3, 1, 2] + + Case 2: + + Given: + + input.data = {[[0.6, 0.1], + [0.5, 0.3]], + [[0.3, 1.3], + [0.2, 1.8]], + [[2.3, 2.1], + [2.5, 2.4]]} + + axis = 1, use_stack = True + + Then: + + output.data = [[[0.6, 0.1] + [0.3, 1.3] + [2.3, 2.1], + [[0.5, 0.3] + [0.2, 1.8] + [2.5, 2.4]]] + + output_index.data = [2, 2, 2] Args: - input(Variable): A LodTensorArray with data type float32, float64, int32, - int64. - axis(int, optional): Axis to compute indices along. The effective range - is [-R, R), where R is Rank(x). when axis<0, it works the same way - as axis+R. Default is 1. - name (str, optional): The default value is None. Normally there is no - need for user to set this property. For more information, please - refer to :ref:`api_guide_Name`. + input(Variable): A LodTensorArray variable. + axis(int): The axis along which the tensors in attr::`input` will be + concatenated or stacked. + name(str|None): A name for this layer(optional). If set None, the layer + will be named automatically. + use_stack(bool): Act as concat_op or stack_op. For stack mode, all + tensors in the tensor array must have the same shape. Returns: - Variable: A LoDTensor with the same data type as input's - Variable: The input LodTensorArray items' dims along the axis. + Variable: The concatenated or stacked tensor variable. + Variable: A 1-D tensor variable with int32 data type. The data in this \ + tensor contains all input including tensors' sizes along the axis. Examples: .. code-block:: python import paddle.fluid as fluid import numpy as np - - place = fluid.CPUPlace() - - x1 = fluid.data(name="x", shape=[2,2], lod_level=0) - tmp = fluid.layers.fill_constant(shape=[2,3], dtype="float32", value=1) - x_arr = fluid.layers.create_array(dtype="float32") - c0 = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0) - fluid.layers.array_write(x=tmp, i=c0, array=x_arr) - c1 = fluid.layers.fill_constant(shape=[1], dtype='int64', value=1) - fluid.layers.array_write(x=x1, i=c1, array=x_arr) - output, output_index = fluid.layers.tensor_array_to_tensor(input=x_arr, axis=1) - - exe = fluid.Executor(place) - exe.run(fluid.default_startup_program()) - - feedx = fluid.LoDTensor() - feedx.set(np.array([[1.3,-2.4],[0,4]]).astype("float32"), place) - res = exe.run(fluid.default_main_program(), feed={'x':feedx}, fetch_list=[output], return_numpy=False) - print(np.array(res[0])) - # [[ 1. 1. 1. 1.3 -2.4] - # [ 1. 1. 1. 0. 4. ]] + x0 = fluid.layers.assign(np.random.rand(2, 2).astype("float32")) + x1 = fluid.layers.assign(np.random.rand(2, 2).astype("float32")) + i = fluid.layers.fill_constant(shape=[1], dtype="int64", value=0) + array = fluid.layers.create_array(dtype='float32') + fluid.layers.array_write(x0, i, array) + fluid.layers.array_write(x1, i + 1, array) + output, output_index = fluid.layers.tensor_array_to_tensor(input=array) """ helper = LayerHelper('tensor_array_to_tensor', **locals()) out = helper.create_variable_for_type_inference(dtype=helper.input_dtype()) @@ -309,7 +344,8 @@ def tensor_array_to_tensor(input, axis=1, name=None): inputs={'X': input}, outputs={'Out': [out], 'OutIndex': [out_index]}, - attrs={'axis': axis}) + attrs={'axis': axis, + 'use_stack': use_stack}) return out, out_index @@ -486,7 +522,8 @@ def fill_constant_batch_size_like(input, dtype, value, input_dim_idx=0, - output_dim_idx=0): + output_dim_idx=0, + force_cpu=False): """ This OP creates a Tesnor accroding the shape and dtype, and initializes the Tensor with the constants provided in ``value``. When the input is LoDTensor @@ -506,6 +543,7 @@ def fill_constant_batch_size_like(input, The default value is 0. output_dim_idx(int): Used to specify which dimension of Tensor is created to be set the value of batch_size of input Tensor. The default value is 0. + force_cpu(bool): data should be on CPU if it's true, defalut value is False. Returns: Variable: Tensor which will be created according to dtype. @@ -531,7 +569,8 @@ def fill_constant_batch_size_like(input, 'dtype': out.dtype, 'value': float(value), 'input_dim_idx': input_dim_idx, - 'output_dim_idx': output_dim_idx + 'output_dim_idx': output_dim_idx, + 'force_cpu': force_cpu or force_init_on_cpu() }) out.stop_gradient = True return out diff --git a/python/paddle/fluid/layers/utils.py b/python/paddle/fluid/layers/utils.py index 5688f04ab2382..c8270f88e4012 100644 --- a/python/paddle/fluid/layers/utils.py +++ b/python/paddle/fluid/layers/utils.py @@ -13,6 +13,8 @@ # limitations under the License. from __future__ import print_function +import collections +import six import numpy as np @@ -59,3 +61,173 @@ def convert_to_list(value, n, name, dtype=np.int): "including element " + str(single_value) + " of type" + " " + str(type(single_value))) return value_list + + +def is_sequence(seq): + """ + Whether `seq` is an entry or nested structure + """ + if isinstance(seq, dict): + return True + return (isinstance(seq, collections.Sequence) and + not isinstance(seq, six.string_types)) + + +def _sorted(dict_): + """ + Returns a sorted list of the dict keys, with error if keys not sortable. + """ + try: + return sorted(six.iterkeys(dict_)) + except TypeError: + raise TypeError("nest only supports dicts with sortable keys.") + + +def _yield_value(iterable): + if isinstance(iterable, dict): + # Iterate through dictionaries in a deterministic order by sorting the + # keys. Notice this means that we ignore the original order of `OrderedDict` + # instances. This is intentional, to avoid potential bugs caused by mixing + # ordered and plain dicts (e.g., flattening a dict but using a + # corresponding `OrderedDict` to pack it back). + for key in _sorted(iterable): + yield iterable[key] + else: + for value in iterable: + yield value + + +def _yield_flat_nest(nest): + for n in _yield_value(nest): + if is_sequence(n): + for ni in _yield_flat_nest(n): + yield ni + else: + yield n + + +def flatten(nest): + """ + Traverse all entries in the nested structure and put them into an list. + """ + if is_sequence(nest): + return list(_yield_flat_nest(nest)) + else: + return [nest] + + +def _sequence_like(instance, args): + """ + Convert the sequence `args` to the same type as `instance`. + """ + if isinstance(instance, dict): + # Pack dictionaries in a deterministic order by sorting the keys. + # Notice this means that we ignore the original order of `OrderedDict` + # instances. This is intentional, to avoid potential bugs caused by mixing + # ordered and plain dicts (e.g., flattening a dict but using a + # corresponding `OrderedDict` to pack it back). + result = dict(zip(_sorted(instance), args)) + return type(instance)((key, result[key]) + for key in six.iterkeys(instance)) + elif (isinstance(instance, tuple) and hasattr(instance, "_fields") and + isinstance(instance._fields, collections.Sequence) and + all(isinstance(f, six.string_types) for f in instance._fields)): + # This is a namedtuple + return type(instance)(*args) + else: + # Not a namedtuple + return type(instance)(args) + + +def _packed_nest_with_indices(structure, flat, index): + """ + Helper function for pack_sequence_as. + """ + packed = [] + for s in _yield_value(structure): + if is_sequence(s): + new_index, child = _packed_nest_with_indices(s, flat, index) + packed.append(_sequence_like(s, child)) + index = new_index + else: + packed.append(flat[index]) + index += 1 + return index, packed + + +def pack_sequence_as(structure, flat_sequence): + """ + Pack a given flattened sequence into a given structure. + """ + if not is_sequence(flat_sequence): + raise TypeError("flat_sequence must be a sequence") + if not is_sequence(structure): + if len(flat_sequence) != 1: + raise ValueError( + "Structure is a scalar but len(flat_sequence) == %d > 1" % + len(flat_sequence)) + return flat_sequence[0] + flat_structure = flatten(structure) + if len(flat_structure) != len(flat_sequence): + raise ValueError( + "Could not pack sequence. Structure had %d elements, but flat_sequence " + "had %d elements. Structure: %s, flat_sequence: %s." % + (len(flat_structure), len(flat_sequence), structure, flat_sequence)) + _, packed = _packed_nest_with_indices(structure, flat_sequence, 0) + return _sequence_like(structure, packed) + + +def map_structure(func, *structure): + """ + Apply `func` to each entry in `structure` and return a new structure. + """ + flat_structure = [flatten(s) for s in structure] + entries = zip(*flat_structure) + return pack_sequence_as(structure[0], [func(*x) for x in entries]) + + +def _recursive_assert_same_structure(nest1, nest2, check_types): + """ + Helper function for `assert_same_structure`. + """ + is_sequence_nest1 = is_sequence(nest1) + if is_sequence_nest1 != is_sequence(nest2): + raise ValueError( + "The two structures don't have the same nested structure.\n\n" + "First structure: %s\n\nSecond structure: %s." % (nest1, nest2)) + if not is_sequence_nest1: + return # finished checking + if check_types: + type_nest1 = type(nest1) + type_nest2 = type(nest2) + if type_nest1 != type_nest2: + raise TypeError( + "The two structures don't have the same sequence type. First " + "structure has type %s, while second structure has type %s." % + (type_nest1, type_nest2)) + if isinstance(nest1, dict): + keys1 = set(six.iterkeys(nest1)) + keys2 = set(six.iterkeys(nest2)) + if keys1 != keys2: + raise ValueError( + "The two dictionaries don't have the same set of keys. First " + "structure has keys {}, while second structure has keys {}." + .format(keys1, keys2)) + nest1_as_sequence = [n for n in _yield_value(nest1)] + nest2_as_sequence = [n for n in _yield_value(nest2)] + for n1, n2 in zip(nest1_as_sequence, nest2_as_sequence): + _recursive_assert_same_structure(n1, n2, check_types) + + +def assert_same_structure(nest1, nest2, check_types=True): + """ + Confirm two nested structures with the same structure. + """ + len_nest1 = len(flatten(nest1)) if is_sequence(nest1) else 1 + len_nest2 = len(flatten(nest2)) if is_sequence(nest2) else 1 + if len_nest1 != len_nest2: + raise ValueError("The two structures don't have the same number of " + "elements.\n\nFirst structure (%i elements): %s\n\n" + "Second structure (%i elements): %s" % + (len_nest1, nest1, len_nest2, nest2)) + _recursive_assert_same_structure(nest1, nest2, check_types) diff --git a/python/paddle/fluid/tests/unittests/test_gather_tree_op.py b/python/paddle/fluid/tests/unittests/test_gather_tree_op.py new file mode 100644 index 0000000000000..63c5fc395023e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_gather_tree_op.py @@ -0,0 +1,65 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle.fluid as fluid + + +class TestGatherTreeOp(OpTest): + def setUp(self): + self.op_type = "gather_tree" + max_length, batch_size, beam_size = 5, 2, 2 + ids = np.random.randint( + 0, high=10, size=(max_length, batch_size, beam_size)) + parents = np.random.randint( + 0, high=beam_size, size=(max_length, batch_size, beam_size)) + self.inputs = {"Ids": ids, "Parents": parents} + self.outputs = {'Out': self.backtrace(ids, parents)} + + def test_check_output(self): + self.check_output() + + @staticmethod + def backtrace(ids, parents): + out = np.zeros_like(ids) + (max_length, batch_size, beam_size) = ids.shape + for batch in range(batch_size): + for beam in range(beam_size): + out[max_length - 1, batch, beam] = ids[max_length - 1, batch, + beam] + parent = parents[max_length - 1, batch, beam] + for step in range(max_length - 2, -1, -1): + out[step, batch, beam] = ids[step, batch, parent] + parent = parents[step, batch, parent] + return out + + +class TestGatherTreeOpAPI(OpTest): + def test_case(self): + ids = fluid.layers.data( + name='ids', shape=[5, 2, 2], dtype='int64', append_batch_size=False) + parents = fluid.layers.data( + name='parents', + shape=[5, 2, 2], + dtype='int64', + append_batch_size=False) + final_sequences = fluid.layers.gather_tree(ids, parents) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_rnn_cell_api.py b/python/paddle/fluid/tests/unittests/test_rnn_cell_api.py new file mode 100644 index 0000000000000..f553b55ddde5f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_rnn_cell_api.py @@ -0,0 +1,249 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy + +import paddle.fluid as fluid +import paddle.fluid.layers as layers +import paddle.fluid.core as core + +from paddle.fluid.executor import Executor +from paddle.fluid import framework + +from paddle.fluid.layers.rnn import LSTMCell, GRUCell, RNNCell +from paddle.fluid.layers import rnn as dynamic_rnn +from paddle.fluid import contrib +from paddle.fluid.contrib.layers import basic_lstm +import paddle.fluid.layers.utils as utils + +import numpy as np + + +class TestLSTMCell(unittest.TestCase): + def setUp(self): + self.batch_size = 4 + self.input_size = 16 + self.hidden_size = 16 + + def test_run(self): + inputs = fluid.data( + name='inputs', shape=[None, self.input_size], dtype='float32') + pre_hidden = fluid.data( + name='pre_hidden', shape=[None, self.hidden_size], dtype='float32') + pre_cell = fluid.data( + name='pre_cell', shape=[None, self.hidden_size], dtype='float32') + + cell = LSTMCell(self.hidden_size) + lstm_hidden_new, lstm_states_new = cell(inputs, [pre_hidden, pre_cell]) + + lstm_unit = contrib.layers.rnn_impl.BasicLSTMUnit( + "basicLSTM", self.hidden_size, None, None, None, None, 1.0, + "float32") + lstm_hidden, lstm_cell = lstm_unit(inputs, pre_hidden, pre_cell) + + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + else: + place = core.CPUPlace() + exe = Executor(place) + exe.run(framework.default_startup_program()) + + inputs_np = np.random.uniform( + -0.1, 0.1, (self.batch_size, self.input_size)).astype('float32') + pre_hidden_np = np.random.uniform( + -0.1, 0.1, (self.batch_size, self.hidden_size)).astype('float32') + pre_cell_np = np.random.uniform( + -0.1, 0.1, (self.batch_size, self.hidden_size)).astype('float32') + + param_names = [[ + "LSTMCell/BasicLSTMUnit_0.w_0", "basicLSTM/BasicLSTMUnit_0.w_0" + ], ["LSTMCell/BasicLSTMUnit_0.b_0", "basicLSTM/BasicLSTMUnit_0.b_0"]] + + for names in param_names: + param = np.array(fluid.global_scope().find_var(names[0]).get_tensor( + )) + param = np.random.uniform( + -0.1, 0.1, size=param.shape).astype('float32') + fluid.global_scope().find_var(names[0]).get_tensor().set(param, + place) + fluid.global_scope().find_var(names[1]).get_tensor().set(param, + place) + + out = exe.run(feed={ + 'inputs': inputs_np, + 'pre_hidden': pre_hidden_np, + 'pre_cell': pre_cell_np + }, + fetch_list=[lstm_hidden_new, lstm_hidden]) + + self.assertTrue(np.allclose(out[0], out[1], rtol=1e-4, atol=0)) + + +class TestGRUCell(unittest.TestCase): + def setUp(self): + self.batch_size = 4 + self.input_size = 16 + self.hidden_size = 16 + + def test_run(self): + inputs = fluid.data( + name='inputs', shape=[None, self.input_size], dtype='float32') + pre_hidden = layers.data( + name='pre_hidden', + shape=[None, self.hidden_size], + append_batch_size=False, + dtype='float32') + + cell = GRUCell(self.hidden_size) + gru_hidden_new, _ = cell(inputs, pre_hidden) + + gru_unit = contrib.layers.rnn_impl.BasicGRUUnit( + "basicGRU", self.hidden_size, None, None, None, None, "float32") + gru_hidden = gru_unit(inputs, pre_hidden) + + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + else: + place = core.CPUPlace() + exe = Executor(place) + exe.run(framework.default_startup_program()) + + inputs_np = np.random.uniform( + -0.1, 0.1, (self.batch_size, self.input_size)).astype('float32') + pre_hidden_np = np.random.uniform( + -0.1, 0.1, (self.batch_size, self.hidden_size)).astype('float32') + + param_names = [ + ["GRUCell/BasicGRUUnit_0.w_0", "basicGRU/BasicGRUUnit_0.w_0"], + ["GRUCell/BasicGRUUnit_0.w_1", "basicGRU/BasicGRUUnit_0.w_1"], + ["GRUCell/BasicGRUUnit_0.b_0", "basicGRU/BasicGRUUnit_0.b_0"], + ["GRUCell/BasicGRUUnit_0.b_1", "basicGRU/BasicGRUUnit_0.b_1"] + ] + + for names in param_names: + param = np.array(fluid.global_scope().find_var(names[0]).get_tensor( + )) + param = np.random.uniform( + -0.1, 0.1, size=param.shape).astype('float32') + fluid.global_scope().find_var(names[0]).get_tensor().set(param, + place) + fluid.global_scope().find_var(names[1]).get_tensor().set(param, + place) + + out = exe.run(feed={'inputs': inputs_np, + 'pre_hidden': pre_hidden_np}, + fetch_list=[gru_hidden_new, gru_hidden]) + + self.assertTrue(np.allclose(out[0], out[1], rtol=1e-4, atol=0)) + + +class TestRnn(unittest.TestCase): + def setUp(self): + self.batch_size = 4 + self.input_size = 16 + self.hidden_size = 16 + self.seq_len = 4 + + def test_run(self): + inputs_basic_lstm = fluid.data( + name='inputs_basic_lstm', + shape=[None, None, self.input_size], + dtype='float32') + sequence_length = fluid.data( + name="sequence_length", shape=[None], dtype='int64') + + inputs_dynamic_rnn = layers.transpose(inputs_basic_lstm, perm=[1, 0, 2]) + cell = LSTMCell(self.hidden_size, name="LSTMCell_for_rnn") + output, final_state = dynamic_rnn( + cell=cell, + inputs=inputs_dynamic_rnn, + sequence_length=sequence_length, + is_reverse=False) + output_new = layers.transpose(output, perm=[1, 0, 2]) + + rnn_out, last_hidden, last_cell = basic_lstm(inputs_basic_lstm, None, None, self.hidden_size, num_layers=1, \ + batch_first = False, bidirectional=False, sequence_length=sequence_length, forget_bias = 1.0) + + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + else: + place = core.CPUPlace() + exe = Executor(place) + exe.run(framework.default_startup_program()) + + inputs_basic_lstm_np = np.random.uniform( + -0.1, 0.1, + (self.seq_len, self.batch_size, self.input_size)).astype('float32') + sequence_length_np = np.ones( + self.batch_size, dtype='int64') * self.seq_len + + inputs_np = np.random.uniform( + -0.1, 0.1, (self.batch_size, self.input_size)).astype('float32') + pre_hidden_np = np.random.uniform( + -0.1, 0.1, (self.batch_size, self.hidden_size)).astype('float32') + pre_cell_np = np.random.uniform( + -0.1, 0.1, (self.batch_size, self.hidden_size)).astype('float32') + + param_names = [[ + "LSTMCell_for_rnn/BasicLSTMUnit_0.w_0", + "basic_lstm_layers_0/BasicLSTMUnit_0.w_0" + ], [ + "LSTMCell_for_rnn/BasicLSTMUnit_0.b_0", + "basic_lstm_layers_0/BasicLSTMUnit_0.b_0" + ]] + + for names in param_names: + param = np.array(fluid.global_scope().find_var(names[0]).get_tensor( + )) + param = np.random.uniform( + -0.1, 0.1, size=param.shape).astype('float32') + fluid.global_scope().find_var(names[0]).get_tensor().set(param, + place) + fluid.global_scope().find_var(names[1]).get_tensor().set(param, + place) + + out = exe.run(feed={ + 'inputs_basic_lstm': inputs_basic_lstm_np, + 'sequence_length': sequence_length_np, + 'inputs': inputs_np, + 'pre_hidden': pre_hidden_np, + 'pre_cell': pre_cell_np + }, + fetch_list=[output_new, rnn_out]) + + self.assertTrue(np.allclose(out[0], out[1], rtol=1e-4)) + + +class TestRnnUtil(unittest.TestCase): + """ + Test cases for rnn apis' utility methods for coverage. + """ + + def test_case(self): + inputs = {"key1": 1, "key2": 2} + func = lambda x: x + 1 + outputs = utils.map_structure(func, inputs) + utils.assert_same_structure(inputs, outputs) + try: + inputs["key3"] = 3 + utils.assert_same_structure(inputs, outputs) + except ValueError as identifier: + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py b/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py new file mode 100644 index 0000000000000..55365abd4931f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py @@ -0,0 +1,214 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy + +import paddle.fluid as fluid +import paddle.fluid.layers as layers +import paddle.fluid.core as core + +from paddle.fluid.executor import Executor +from paddle.fluid import framework + +from paddle.fluid.layers.rnn import LSTMCell, GRUCell, RNNCell, BeamSearchDecoder, dynamic_decode +from paddle.fluid.layers import rnn as dynamic_rnn +from paddle.fluid import contrib +from paddle.fluid.contrib.layers import basic_lstm + +import numpy as np + + +class EncoderCell(RNNCell): + def __init__(self, num_layers, hidden_size, dropout_prob=0.): + self.num_layers = num_layers + self.hidden_size = hidden_size + self.dropout_prob = dropout_prob + self.lstm_cells = [] + for i in range(num_layers): + self.lstm_cells.append(LSTMCell(hidden_size)) + + def call(self, step_input, states): + new_states = [] + for i in range(self.num_layers): + out, new_state = self.lstm_cells[i](step_input, states[i]) + step_input = layers.dropout( + out, self.dropout_prob) if self.dropout_prob > 0 else out + new_states.append(new_state) + return step_input, new_states + + @property + def state_shape(self): + return [cell.state_shape for cell in self.lstm_cells] + + +class DecoderCell(RNNCell): + def __init__(self, num_layers, hidden_size, dropout_prob=0.): + self.num_layers = num_layers + self.hidden_size = hidden_size + self.dropout_prob = dropout_prob + self.lstm_cells = [] + for i in range(num_layers): + self.lstm_cells.append(LSTMCell(hidden_size)) + + def attention(self, hidden, encoder_output, encoder_padding_mask): + query = layers.fc(hidden, + size=encoder_output.shape[-1], + bias_attr=False) + attn_scores = layers.matmul( + layers.unsqueeze(query, [1]), encoder_output, transpose_y=True) + if encoder_padding_mask is not None: + attn_scores = layers.elementwise_add(attn_scores, + encoder_padding_mask) + attn_scores = layers.softmax(attn_scores) + attn_out = layers.squeeze( + layers.matmul(attn_scores, encoder_output), [1]) + attn_out = layers.concat([attn_out, hidden], 1) + attn_out = layers.fc(attn_out, size=self.hidden_size, bias_attr=False) + return attn_out + + def call(self, + step_input, + states, + encoder_output, + encoder_padding_mask=None): + lstm_states, input_feed = states + new_lstm_states = [] + step_input = layers.concat([step_input, input_feed], 1) + for i in range(self.num_layers): + out, new_lstm_state = self.lstm_cells[i](step_input, lstm_states[i]) + step_input = layers.dropout( + out, self.dropout_prob) if self.dropout_prob > 0 else out + new_lstm_states.append(new_lstm_state) + out = self.attention(step_input, encoder_output, encoder_padding_mask) + return out, [new_lstm_states, out] + + +class TestDynamicDecode(unittest.TestCase): + def setUp(self): + self.batch_size = 4 + self.input_size = 16 + self.hidden_size = 16 + self.seq_len = 4 + + def test_run(self): + start_token = 0 + end_token = 1 + src_vocab_size = 10 + trg_vocab_size = 10 + num_layers = 1 + hidden_size = self.hidden_size + beam_size = 8 + max_length = self.seq_len + + src = layers.data(name="src", shape=[-1, 1], dtype='int64') + src_len = layers.data(name="src_len", shape=[-1], dtype='int64') + + trg = layers.data(name="trg", shape=[-1, 1], dtype='int64') + trg_len = layers.data(name="trg_len", shape=[-1], dtype='int64') + + src_embeder = lambda x: fluid.embedding( + x, + size=[src_vocab_size, hidden_size], + param_attr=fluid.ParamAttr(name="src_embedding")) + + trg_embeder = lambda x: fluid.embedding( + x, + size=[trg_vocab_size, hidden_size], + param_attr=fluid.ParamAttr(name="trg_embedding")) + + # use basic_lstm + encoder_cell = EncoderCell(num_layers, hidden_size) + encoder_output, encoder_final_state = dynamic_rnn( + cell=encoder_cell, + inputs=src_embeder(src), + sequence_length=src_len, + is_reverse=False) + + src_mask = layers.sequence_mask( + src_len, maxlen=layers.shape(src)[1], dtype='float32') + encoder_padding_mask = (src_mask - 1.0) * 1000000000 + encoder_padding_mask = layers.unsqueeze(encoder_padding_mask, [1]) + + decoder_cell = DecoderCell(num_layers, hidden_size) + decoder_initial_states = [ + encoder_final_state, decoder_cell.get_initial_states( + batch_ref=encoder_output, shape=[hidden_size]) + ] + + decoder_output, _ = dynamic_rnn( + cell=decoder_cell, + inputs=trg_embeder(trg), + initial_states=decoder_initial_states, + sequence_length=None, + encoder_output=encoder_output, + encoder_padding_mask=encoder_padding_mask) + + output_layer = lambda x: layers.fc(x, + size=trg_vocab_size, + num_flatten_dims=len(x.shape) - 1, + param_attr=fluid.ParamAttr( + name="output_w"), + bias_attr=False) + + # inference + encoder_output = BeamSearchDecoder.tile_beam_merge_with_batch( + encoder_output, beam_size) + encoder_padding_mask = BeamSearchDecoder.tile_beam_merge_with_batch( + encoder_padding_mask, beam_size) + beam_search_decoder = BeamSearchDecoder( + decoder_cell, + start_token, + end_token, + beam_size, + embedding_fn=trg_embeder, + output_fn=output_layer) + outputs, _ = dynamic_decode( + beam_search_decoder, + inits=decoder_initial_states, + max_step_num=max_length, + encoder_output=encoder_output, + encoder_padding_mask=encoder_padding_mask) + + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + else: + place = core.CPUPlace() + exe = Executor(place) + exe.run(framework.default_startup_program()) + + src_np = np.random.randint( + 0, src_vocab_size, (self.batch_size, max_length)).astype('int64') + src_len_np = np.ones(self.batch_size, dtype='int64') * max_length + trg_np = np.random.randint( + 0, trg_vocab_size, (self.batch_size, max_length)).astype('int64') + trg_len_np = np.ones(self.batch_size, dtype='int64') * max_length + + out = exe.run(feed={ + 'src': src_np, + 'src_len': src_len_np, + 'trg': trg_np, + 'trg_len': trg_len_np + }, + fetch_list=[outputs]) + + self.assertTrue(out[0].shape[0] == self.batch_size) + self.assertTrue(out[0].shape[1] <= max_length + 1) + self.assertTrue(out[0].shape[2] == beam_size) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_tensor_array_to_tensor.py b/python/paddle/fluid/tests/unittests/test_tensor_array_to_tensor.py index 78b95de7e07b1..23859a6c7d785 100644 --- a/python/paddle/fluid/tests/unittests/test_tensor_array_to_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_tensor_array_to_tensor.py @@ -23,6 +23,8 @@ class TestLoDTensorArrayConcat(unittest.TestCase): + """Test case for concat mode of tensor_array_to_tensor.""" + def setUp(self): self.op_type = "tensor_array_to_tensor" self.attrs = {"axis": 0} @@ -98,7 +100,7 @@ def test_get_set(self): exe = fluid.Executor(fluid.CPUPlace()) out = exe.run(program, fetch_list=fetch_list, scope=scope) - #print ("index: ", numpy.array(out[1])) + #print ("index: ", numpy.array(out[1])) # test forward tensor_res = numpy.array(out[0]) @@ -138,5 +140,82 @@ def test_get_set(self): numpy.array(random_grad[i + 1])) +class TestLoDTensorArrayStack(unittest.TestCase): + """Test case for stack mode of tensor_array_to_tensor.""" + + def setUp(self): + self.op_type = "tensor_array_to_tensor" + self.attrs = {"axis": 1, "use_stack": True} + self.inputs = [ + numpy.random.rand(2, 3, 4).astype("float32"), + numpy.random.rand(2, 3, 4).astype("float32"), + numpy.random.rand(2, 3, 4).astype("float32") + ] + self.outputs = [ + numpy.stack( + self.inputs, axis=self.attrs["axis"]), numpy.array( + [x.shape[self.attrs["axis"]] for x in self.inputs], + dtype="int32") + ] + self.input_grads = [numpy.ones_like(x) for x in self.inputs] + self.set_program() + for var in self.program.list_vars(): + # to avoid scope clearing after execution + var.persistable = True + + def set_program(self): + self.program = fluid.Program() + with fluid.program_guard(self.program): + self.array = array = fluid.layers.create_array(dtype='float32') + idx = fluid.layers.fill_constant(shape=[1], dtype="int64", value=0) + for i, x in enumerate(self.inputs): + x = fluid.layers.assign(x) + fluid.layers.array_write(x, idx + i, array) + output, output_index = fluid.layers.tensor_array_to_tensor( + input=array, **self.attrs) + loss = fluid.layers.reduce_sum(output) + fluid.backward.append_backward(loss) + self.output_vars = [output, output_index] + + def run_check(self, executor, scope): + executor.run(self.program, scope=scope) + for i, output in enumerate(self.outputs): + numpy.allclose( + numpy.array(scope.var(self.output_vars[i].name).get_tensor()), + output, + atol=0) + tensor_array_grad = scope.var(self.array.name).get_lod_tensor_array() + for i, input_grad in enumerate(self.input_grads): + numpy.allclose( + numpy.array(tensor_array_grad[i]), input_grad, atol=0) + + def test_cpu(self): + scope = core.Scope() + place = core.CPUPlace() + executor = fluid.Executor(place) + self.run_check(executor, scope) + + def test_gpu(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + scope = core.Scope() + executor = fluid.Executor(place) + self.run_check(executor, scope) + + +class TestTensorArrayToTensorAPI(unittest.TestCase): + def test_case(self): + x0 = fluid.layers.assign(numpy.random.rand(2, 3, 4).astype("float32")) + x1 = fluid.layers.assign(numpy.random.rand(2, 3, 4).astype("float32")) + i = fluid.layers.fill_constant(shape=[1], dtype="int64", value=0) + array = fluid.layers.create_array(dtype='float32') + fluid.layers.array_write(x0, i, array) + fluid.layers.array_write(x1, i + 1, array) + output, output_index = fluid.layers.tensor_array_to_tensor( + input=array, axis=1, use_stack=True) + output, output_index = fluid.layers.tensor_array_to_tensor( + input=array, axis=1, use_stack=False) + + if __name__ == '__main__': unittest.main() From b5572cc6ab5dfabd13cc4765b7a95ef34ca51cfb Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Thu, 10 Oct 2019 21:25:09 +0800 Subject: [PATCH 2/4] fix expand bug (#20340) * fix expand bug test=develop * fix style test=develop * fix style test=develop * fix style test=develop * fix style test=develop --- paddle/fluid/operators/expand_op.cc | 5 ++- paddle/fluid/operators/expand_op.cu | 5 ++- python/paddle/fluid/layers/nn.py | 13 +++++++- .../fluid/tests/unittests/test_expand_op.py | 31 +++++++++++++++++++ 4 files changed, 51 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/expand_op.cc b/paddle/fluid/operators/expand_op.cc index b95373178d458..09c730db3951d 100644 --- a/paddle/fluid/operators/expand_op.cc +++ b/paddle/fluid/operators/expand_op.cc @@ -226,8 +226,11 @@ REGISTER_OP_CPU_KERNEL( expand, ops::ExpandKernel, ops::ExpandKernel, ops::ExpandKernel, + ops::ExpandKernel, ops::ExpandKernel); REGISTER_OP_CPU_KERNEL( expand_grad, ops::ExpandGradKernel, - ops::ExpandGradKernel); + ops::ExpandGradKernel, + ops::ExpandGradKernel, + ops::ExpandGradKernel); diff --git a/paddle/fluid/operators/expand_op.cu b/paddle/fluid/operators/expand_op.cu index 50a506b294db1..cf913f56dde80 100644 --- a/paddle/fluid/operators/expand_op.cu +++ b/paddle/fluid/operators/expand_op.cu @@ -18,8 +18,11 @@ REGISTER_OP_CUDA_KERNEL( expand, ops::ExpandKernel, ops::ExpandKernel, ops::ExpandKernel, + ops::ExpandKernel, ops::ExpandKernel); REGISTER_OP_CUDA_KERNEL( expand_grad, ops::ExpandGradKernel, - ops::ExpandGradKernel); + ops::ExpandGradKernel, + ops::ExpandGradKernel, + ops::ExpandGradKernel); diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 7c879a19e9c3f..6bef72c0c8aea 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -12470,10 +12470,21 @@ def expand(x, expand_times, name=None): expanded_2 = fluid.layers.expand(data_2, expand_times=expand_times) # the shape of expanded_2 is [48, 56]. """ - + if not isinstance(x, Variable): + raise TypeError( + "The type of 'input' in reduce_sum must be Variable, but received %s" + % (type(x))) if not isinstance(expand_times, (list, tuple, Variable)): raise ValueError( "Input expand_times must be an Variable, python list or tuple.") + if convert_dtype( + x.dtype) not in ['bool', 'float32', 'float64', 'int32', 'int64']: + raise TypeError( + "The data type of input in expand must be one of bool float32, float64, int32 or int64, but received %s." + % (convert_dtype(x.dtype))) + if convert_dtype(x.dtype) == 'bool' and x.stop_gradient == True: + raise ValueError( + "expand op bool date type must set the stop_gradient to be False") helper = LayerHelper('expand', input=x, **locals()) inputs = {"X": x} diff --git a/python/paddle/fluid/tests/unittests/test_expand_op.py b/python/paddle/fluid/tests/unittests/test_expand_op.py index 449cda29b45ba..b4efda63e10ca 100644 --- a/python/paddle/fluid/tests/unittests/test_expand_op.py +++ b/python/paddle/fluid/tests/unittests/test_expand_op.py @@ -18,6 +18,7 @@ import numpy as np from op_test import OpTest import paddle.fluid as fluid +from paddle.fluid import compiler, Program, program_guard # Situation 1: expand_times is a list(without tensor) @@ -176,6 +177,36 @@ def test_check_output(self): self.check_output() +# Situation 56: input x is Integer +class TestExpandOpInt64_t(OpTest): + def setUp(self): + self.op_type = "expand" + self.inputs = { + 'X': np.random.randint( + 10, size=(2, 4, 5)).astype("int64") + } + self.attrs = {'expand_times': [2, 1, 4]} + output = np.tile(self.inputs['X'], (2, 1, 4)) + self.outputs = {'Out': output} + + def test_check_output(self): + self.check_output() + + +class TestExpandError(OpTest): + def test_errors(self): + with program_guard(Program(), Program()): + x1 = fluid.create_lod_tensor( + np.array([[-1]]), [[1]], fluid.CPUPlace()) + expand_times = [2, 2] + self.assertRaises(TypeError, fluid.layers.expand, x1, expand_times) + x2 = fluid.layers.data(name='x2', shape=[4], dtype="uint8") + self.assertRaises(TypeError, fluid.layers.expand, x2, expand_times) + x3 = fluid.layers.data(name='x3', shape=[4], dtype="bool") + x3.stop_gradient = True + self.assertRaises(ValueError, fluid.layers.expand, x3, expand_times) + + # Test python API class TestExpandAPI(OpTest): def test_api(self): From 50ff1b11648d91edb7ef1dcc33506ed9454c2ded Mon Sep 17 00:00:00 2001 From: Guo Sheng Date: Sat, 12 Oct 2019 22:00:55 +0800 Subject: [PATCH 3/4] Fix the assign data check (#20564) * Fix the assign data check. test=develop * Fix test_assign_op.py. test=develop --- python/paddle/fluid/layers/tensor.py | 6 +++--- python/paddle/fluid/tests/unittests/test_assign_op.py | 4 +--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 1d0c266ccdcc0..cbc2e191829bc 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -438,12 +438,12 @@ def assign(input, output=None): helper = LayerHelper('assign', **locals()) if isinstance(input, Variable): if convert_dtype(input.dtype) not in [ - 'float32', 'float64', 'int32', 'int64' + 'float32', 'float64', 'int32', 'int64', 'bool' ]: raise TypeError( "When the type of 'input' in assign is Variable, the data " - "type of 'input' must be float32, float64, int32 or int64, " - "but received %s." % convert_dtype(input.dtype)) + "type of 'input' must be float32, float64, int32, int64 or " + "bool, but received %s." % convert_dtype(input.dtype)) if output is None: output = helper.create_variable_for_type_inference( dtype=input.dtype) diff --git a/python/paddle/fluid/tests/unittests/test_assign_op.py b/python/paddle/fluid/tests/unittests/test_assign_op.py index fce7331f509c5..4d43747676dfe 100644 --- a/python/paddle/fluid/tests/unittests/test_assign_op.py +++ b/python/paddle/fluid/tests/unittests/test_assign_op.py @@ -44,9 +44,7 @@ def test_errors(self): x1 = fluid.create_lod_tensor( np.array([[-1]]), [[1]], fluid.CPUPlace()) self.assertRaises(TypeError, fluid.layers.assign, x1) - # When the type of input is Variable, the dtype of input must be float32, float64, int32, int64. - x2 = fluid.layers.data(name='x2', shape=[4], dtype="bool") - self.assertRaises(TypeError, fluid.layers.assign, x2) + # When the type of input is Variable, the dtype of input must be float32, float64, int32, int64, bool. x3 = fluid.layers.data(name='x3', shape=[4], dtype="float16") self.assertRaises(TypeError, fluid.layers.assign, x3) x4 = fluid.layers.data(name='x4', shape=[4], dtype="uint8") From 22cc43b844af52a984c743eed996a8787a13c757 Mon Sep 17 00:00:00 2001 From: guosheng Date: Sun, 13 Oct 2019 02:01:20 +0800 Subject: [PATCH 4/4] Update API.spec --- paddle/fluid/API.spec | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index bf4d4f8833b73..118aea0b8799e 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -9,7 +9,7 @@ paddle.fluid.Program.parse_from_string (ArgSpec(args=['binary_str'], varargs=Non paddle.fluid.Program.to_string (ArgSpec(args=['self', 'throw_on_error', 'with_details'], varargs=None, keywords=None, defaults=(False,)), ('document', '7dde33f16b63aa50d474870a9cebb539')) paddle.fluid.default_startup_program (ArgSpec(args=[], varargs=None, keywords=None, defaults=None), ('document', 'f53890b2fb8c0642b6047e4fee2d6d58')) paddle.fluid.default_main_program (ArgSpec(args=[], varargs=None, keywords=None, defaults=None), ('document', '082aa471d247bd8d7c87814105439e1a')) -paddle.fluid.program_guard (ArgSpec(args=['main_program', 'startup_program'], varargs=None, keywords=None, defaults=(None,)), ('document', '78fb5c7f70ef76bcf4a1862c3f6b8191')) +paddle.fluid.program_guard (ArgSpec(args=['main_program', 'startup_program'], varargs=None, keywords=None, defaults=(None,)), ('document', 'eb4eabc13405a8c6dc2f14308ddf5ed8')) paddle.fluid.name_scope (ArgSpec(args=['prefix'], varargs=None, keywords=None, defaults=(None,)), ('document', '907a5f877206079d8e67ae69b06bb3ba')) paddle.fluid.cuda_places (ArgSpec(args=['device_ids'], varargs=None, keywords=None, defaults=(None,)), ('document', 'ab9bd2079536114aa7c1488a489ee87f')) paddle.fluid.cpu_places (ArgSpec(args=['device_count'], varargs=None, keywords=None, defaults=(None,)), ('document', 'a7352a3dd39308fde4fbbf6421a4193d')) @@ -189,7 +189,7 @@ paddle.fluid.layers.layer_norm (ArgSpec(args=['input', 'scale', 'shift', 'begin_ paddle.fluid.layers.group_norm (ArgSpec(args=['input', 'groups', 'epsilon', 'param_attr', 'bias_attr', 'act', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(1e-05, None, None, None, 'NCHW', None)), ('document', '87dd4b818f102bc1a780e1804c28bd38')) paddle.fluid.layers.spectral_norm (ArgSpec(args=['weight', 'dim', 'power_iters', 'eps', 'name'], varargs=None, keywords=None, defaults=(0, 1, 1e-12, None)), ('document', '7b3d14d6707d878923847ec617d7d521')) paddle.fluid.layers.softmax_with_cross_entropy (ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index', 'numeric_stable_mode', 'return_softmax', 'axis'], varargs=None, keywords=None, defaults=(False, -100, True, False, -1)), ('document', '6992e4140d667fdf816d0617648b5c00')) -paddle.fluid.layers.smooth_l1 (ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', 'ecb75c1b00c4c76c98b482f633b7a10c')) +paddle.fluid.layers.smooth_l1 (ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', 'cbe8940643ac80ef75e1abdfbdb09e88')) paddle.fluid.layers.one_hot (ArgSpec(args=['input', 'depth', 'allow_out_of_range'], varargs=None, keywords=None, defaults=(False,)), ('document', 'cdf5dc2078f1e20dc61dd0bec7e28a29')) paddle.fluid.layers.autoincreased_step_counter (ArgSpec(args=['counter_name', 'begin', 'step'], varargs=None, keywords=None, defaults=(None, 1, 1)), ('document', 'd016c137beb9a4528b7378b437d00151')) paddle.fluid.layers.reshape (ArgSpec(args=['x', 'shape', 'actual_shape', 'act', 'inplace', 'name'], varargs=None, keywords=None, defaults=(None, None, False, None)), ('document', 'd7a6d59e464a7ef1184eb6caefeb49f1')) @@ -291,7 +291,7 @@ paddle.fluid.layers.shuffle_channel (ArgSpec(args=['x', 'group', 'name'], vararg paddle.fluid.layers.temporal_shift (ArgSpec(args=['x', 'seg_num', 'shift_ratio', 'name'], varargs=None, keywords=None, defaults=(0.25, None)), ('document', 'd5945431cdcae3cda21914db5bbf383e')) paddle.fluid.layers.py_func (ArgSpec(args=['func', 'x', 'out', 'backward_func', 'skip_vars_in_backward_input'], varargs=None, keywords=None, defaults=(None, None)), ('document', '231f91231430f5dae2b757df22317c67')) paddle.fluid.layers.psroi_pool (ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '9bf0cc6b0717010b8ceec5dc2541d566')) -paddle.fluid.layers.prroi_pool (ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(1.0, 1, 1, None)), ('document', '454c7ea8c73313dd41513929d7526303')) +paddle.fluid.layers.prroi_pool (ArgSpec(args=['input', 'rois', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(1.0, 1, 1, None)), ('document', '466be691ac4c1cd7b88fccb40846afce')) paddle.fluid.layers.teacher_student_sigmoid_loss (ArgSpec(args=['input', 'label', 'soft_max_up_bound', 'soft_max_lower_bound'], varargs=None, keywords=None, defaults=(15.0, -15.0)), ('document', 'b0e07aa41caae04b07a8e8217cc96020')) paddle.fluid.layers.huber_loss (ArgSpec(args=['input', 'label', 'delta'], varargs=None, keywords=None, defaults=None), ('document', '9d93ee81f7a3e526d68bb280bc695d6c')) paddle.fluid.layers.kldiv_loss (ArgSpec(args=['x', 'target', 'reduction', 'name'], varargs=None, keywords=None, defaults=('mean', None)), ('document', '45f3ebbcb766fca84cb2fe6307086573')) @@ -317,14 +317,14 @@ paddle.fluid.layers.py_reader (ArgSpec(args=['capacity', 'shapes', 'dtypes', 'lo paddle.fluid.layers.create_py_reader_by_data (ArgSpec(args=['capacity', 'feed_list', 'name', 'use_double_buffer'], varargs=None, keywords=None, defaults=(None, True)), ('document', '1321d4ce89d82f96fcfd5601f816b0f3')) paddle.fluid.layers.load (ArgSpec(args=['out', 'file_path', 'load_as_fp16'], varargs=None, keywords=None, defaults=(None,)), ('document', '309f9e5249463e1b207a7347b2a91134')) paddle.fluid.layers.create_tensor (ArgSpec(args=['dtype', 'name', 'persistable'], varargs=None, keywords=None, defaults=(None, False)), ('document', 'fdc2d964488e99fb0743887454c34e36')) -paddle.fluid.layers.create_parameter (ArgSpec(args=['shape', 'dtype', 'name', 'attr', 'is_bias', 'default_initializer'], varargs=None, keywords=None, defaults=(None, None, False, None)), ('document', '021272f30e0cdf7503586815378abfb8')) -paddle.fluid.layers.create_global_var (ArgSpec(args=['shape', 'value', 'dtype', 'persistable', 'force_cpu', 'name'], varargs=None, keywords=None, defaults=(False, False, None)), ('document', '47ea8b8c91879e50c9036e418b00ef4a')) +paddle.fluid.layers.create_parameter (ArgSpec(args=['shape', 'dtype', 'name', 'attr', 'is_bias', 'default_initializer'], varargs=None, keywords=None, defaults=(None, None, False, None)), ('document', '727aa63c061919bee38547fb126d9428')) +paddle.fluid.layers.create_global_var (ArgSpec(args=['shape', 'value', 'dtype', 'persistable', 'force_cpu', 'name'], varargs=None, keywords=None, defaults=(False, False, None)), ('document', 'fa7f74cfb940521cc9fdffabc83debbf')) paddle.fluid.layers.cast (ArgSpec(args=['x', 'dtype'], varargs=None, keywords=None, defaults=None), ('document', '45df178cbd8c302f92c30ebdaaa6fa8a')) paddle.fluid.layers.tensor_array_to_tensor (ArgSpec(args=['input', 'axis', 'name', 'use_stack'], varargs=None, keywords=None, defaults=(1, None, False)), ('document', '4aa82374218ccf593bb8011df79c71e3')) paddle.fluid.layers.concat (ArgSpec(args=['input', 'axis', 'name'], varargs=None, keywords=None, defaults=(0, None)), ('document', 'ec7d6e716fb29ef1e73e1e3efa5ca46b')) paddle.fluid.layers.sums (ArgSpec(args=['input', 'out'], varargs=None, keywords=None, defaults=(None,)), ('document', '191164436efbc1b7bccc4190a88e7de2')) paddle.fluid.layers.assign (ArgSpec(args=['input', 'output'], varargs=None, keywords=None, defaults=(None,)), ('document', '98ce6e7c3659b8377c04cecfc72c2000')) -paddle.fluid.layers.fill_constant_batch_size_like (ArgSpec(args=['input', 'shape', 'dtype', 'value', 'input_dim_idx', 'output_dim_idx'], varargs=None, keywords=None, defaults=(0, 0)), ('document', '37a288e4400f6d5510e982827461c11b')) +paddle.fluid.layers.fill_constant_batch_size_like (ArgSpec(args=['input', 'shape', 'dtype', 'value', 'input_dim_idx', 'output_dim_idx', 'force_cpu'], varargs=None, keywords=None, defaults=(0, 0, False)), ('document', '2bb57637664173fee5f654e55896aec6')) paddle.fluid.layers.fill_constant (ArgSpec(args=['shape', 'dtype', 'value', 'force_cpu', 'out'], varargs=None, keywords=None, defaults=(False, None)), ('document', '66e1e468666dd47e5b2715226cebeac0')) paddle.fluid.layers.argmin (ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=(0,)), ('document', '53629e27597e5dfb7020aac5bc639ebb')) paddle.fluid.layers.argmax (ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=(0,)), ('document', 'd9a89fbedbaebd5f65897ac75ee636f3')) @@ -959,7 +959,7 @@ paddle.fluid.nets.sequence_conv_pool (ArgSpec(args=['input', 'num_filters', 'fil paddle.fluid.nets.glu (ArgSpec(args=['input', 'dim'], varargs=None, keywords=None, defaults=(-1,)), ('document', '3efe197c8e3e75f84a4c464d8b74e943')) paddle.fluid.nets.scaled_dot_product_attention (ArgSpec(args=['queries', 'keys', 'values', 'num_heads', 'dropout_rate'], varargs=None, keywords=None, defaults=(1, 0.0)), ('document', '375898e47266633635f4c2096e1ac296')) paddle.fluid.nets.img_conv_group (ArgSpec(args=['input', 'conv_num_filter', 'pool_size', 'conv_padding', 'conv_filter_size', 'conv_act', 'param_attr', 'conv_with_batchnorm', 'conv_batchnorm_drop_rate', 'pool_stride', 'pool_type', 'use_cudnn'], varargs=None, keywords=None, defaults=(1, 3, None, None, False, 0.0, 1, 'max', True)), ('document', 'a59c581d5969266427e841abe69f694a')) -paddle.fluid.optimizer.SGDOptimizer ('paddle.fluid.optimizer.SGDOptimizer', ('document', 'c3c8dd3193d991adf8bda505560371d6')) +paddle.fluid.optimizer.SGDOptimizer ('paddle.fluid.optimizer.SGDOptimizer', ('document', 'fc09d6e6c1083cec2dce51f6f9f4ecaf')) paddle.fluid.optimizer.SGDOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'regularization', 'name'], varargs=None, keywords=None, defaults=(None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.SGDOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '80ea99c9af7ef5fac7e57fb302103610')) paddle.fluid.optimizer.SGDOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae')) @@ -968,7 +968,7 @@ paddle.fluid.optimizer.SGDOptimizer.get_opti_var_name_list (ArgSpec(args=['self' paddle.fluid.optimizer.SGDOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'grad_clip'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', '8387af01322a6defc92c1832faccd304')) paddle.fluid.optimizer.SGDOptimizer.set_dict (ArgSpec(args=['self', 'state_dict'], varargs=None, keywords=None, defaults=None), ('document', '36aa497a2d29abaa4147987d71721d17')) paddle.fluid.optimizer.SGDOptimizer.state_dict (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', 'deca1537945d33940b350923fb16ddf8')) -paddle.fluid.optimizer.MomentumOptimizer ('paddle.fluid.optimizer.MomentumOptimizer', ('document', 'a72bd02e5459e64596897d190413d449')) +paddle.fluid.optimizer.MomentumOptimizer ('paddle.fluid.optimizer.MomentumOptimizer', ('document', '2bda0a60340fce6c8e594bb35b4e0fcd')) paddle.fluid.optimizer.MomentumOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'momentum', 'use_nesterov', 'regularization', 'name'], varargs=None, keywords=None, defaults=(False, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.MomentumOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '80ea99c9af7ef5fac7e57fb302103610')) paddle.fluid.optimizer.MomentumOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae')) @@ -1022,7 +1022,7 @@ paddle.fluid.optimizer.DecayedAdagradOptimizer.get_opti_var_name_list (ArgSpec(a paddle.fluid.optimizer.DecayedAdagradOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'grad_clip'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', '8387af01322a6defc92c1832faccd304')) paddle.fluid.optimizer.DecayedAdagradOptimizer.set_dict (ArgSpec(args=['self', 'state_dict'], varargs=None, keywords=None, defaults=None), ('document', '36aa497a2d29abaa4147987d71721d17')) paddle.fluid.optimizer.DecayedAdagradOptimizer.state_dict (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', 'deca1537945d33940b350923fb16ddf8')) -paddle.fluid.optimizer.FtrlOptimizer ('paddle.fluid.optimizer.FtrlOptimizer', ('document', 'cba8aae0a267b9a4d8833ae79a00fc55')) +paddle.fluid.optimizer.FtrlOptimizer ('paddle.fluid.optimizer.FtrlOptimizer', ('document', 'a2573c97cd45c2be0d33243cd1aa4a9b')) paddle.fluid.optimizer.FtrlOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'l1', 'l2', 'lr_power', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.0, 0.0, -0.5, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.FtrlOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '80ea99c9af7ef5fac7e57fb302103610')) paddle.fluid.optimizer.FtrlOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae')) @@ -1031,7 +1031,7 @@ paddle.fluid.optimizer.FtrlOptimizer.get_opti_var_name_list (ArgSpec(args=['self paddle.fluid.optimizer.FtrlOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'grad_clip'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', '8387af01322a6defc92c1832faccd304')) paddle.fluid.optimizer.FtrlOptimizer.set_dict (ArgSpec(args=['self', 'state_dict'], varargs=None, keywords=None, defaults=None), ('document', '36aa497a2d29abaa4147987d71721d17')) paddle.fluid.optimizer.FtrlOptimizer.state_dict (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', 'deca1537945d33940b350923fb16ddf8')) -paddle.fluid.optimizer.RMSPropOptimizer ('paddle.fluid.optimizer.RMSPropOptimizer', ('document', '5217bc4fc399010021d6b70541005780')) +paddle.fluid.optimizer.RMSPropOptimizer ('paddle.fluid.optimizer.RMSPropOptimizer', ('document', '6aeb527f958d1d6962d4e56751f44dbd')) paddle.fluid.optimizer.RMSPropOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'rho', 'epsilon', 'momentum', 'centered', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.95, 1e-06, 0.0, False, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.RMSPropOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '80ea99c9af7ef5fac7e57fb302103610')) paddle.fluid.optimizer.RMSPropOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae')) @@ -1060,7 +1060,7 @@ paddle.fluid.optimizer.ModelAverage.minimize (ArgSpec(args=['self', 'loss', 'sta paddle.fluid.optimizer.ModelAverage.restore (ArgSpec(args=['self', 'executor'], varargs=None, keywords=None, defaults=None), ('document', '7917cbe4d3ed7954ae73360fbccc39f6')) paddle.fluid.optimizer.ModelAverage.set_dict (ArgSpec(args=['self', 'state_dict'], varargs=None, keywords=None, defaults=None), ('document', '36aa497a2d29abaa4147987d71721d17')) paddle.fluid.optimizer.ModelAverage.state_dict (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', 'deca1537945d33940b350923fb16ddf8')) -paddle.fluid.optimizer.LarsMomentumOptimizer ('paddle.fluid.optimizer.LarsMomentumOptimizer', ('document', '030b9092a96a409b1bf5446bf45d0659')) +paddle.fluid.optimizer.LarsMomentumOptimizer ('paddle.fluid.optimizer.LarsMomentumOptimizer', ('document', '107d591a9b03264bfc0c55f424f90574')) paddle.fluid.optimizer.LarsMomentumOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'momentum', 'lars_coeff', 'lars_weight_decay', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.001, 0.0005, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.LarsMomentumOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '80ea99c9af7ef5fac7e57fb302103610')) paddle.fluid.optimizer.LarsMomentumOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae'))