diff --git a/paddle/phi/kernels/cpu/gather_tree_kernel.cc b/paddle/phi/kernels/cpu/gather_tree_kernel.cc index 6f3cac6c4aa10..250ee1b1e8a2e 100644 --- a/paddle/phi/kernels/cpu/gather_tree_kernel.cc +++ b/paddle/phi/kernels/cpu/gather_tree_kernel.cc @@ -14,6 +14,7 @@ #include "paddle/phi/kernels/gather_tree_kernel.h" +#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/kernel_registry.h" namespace phi { @@ -49,6 +50,15 @@ void GatherTreeKernel(const Context &dev_ctx, out_data[idx] = ids_data[idx]; auto parent = parents_data[idx]; for (int step = max_length - 2; step >= 0; step--) { + PADDLE_ENFORCE_LT( + parent, + beam_size, + phi::errors::InvalidArgument( + "The parents must be less than beam size, but recieved" + "parents %d is greater than or equal to beam size %d. ", + parent, + beam_size)); + idx = step * batch_size * beam_size + batch * beam_size; out_data[idx + beam] = ids_data[idx + parent]; parent = parents_data[idx + parent]; diff --git a/paddle/phi/kernels/gpu/gather_tree_kernel.cu b/paddle/phi/kernels/gpu/gather_tree_kernel.cu index 22b174b5f0bc2..585376ee6688d 100644 --- a/paddle/phi/kernels/gpu/gather_tree_kernel.cu +++ b/paddle/phi/kernels/gpu/gather_tree_kernel.cu @@ -16,6 +16,7 @@ #include +#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/kernel_registry.h" namespace phi { @@ -35,6 +36,12 @@ __global__ void GatherTree(const T *ids_data, out_data[idx] = ids_data[idx]; auto parent = parents_data[idx]; for (int step = max_length - 2; step >= 0; step--) { + PADDLE_ENFORCE((parent < beam_size), + "The parents must be less than beam size, but recieved" + "parents %ld is greater than or equal to beam size %ld. ", + parent, + beam_size); + idx = step * batch_size * beam_size + batch * beam_size; out_data[idx + beam] = ids_data[idx + parent]; parent = parents_data[idx + parent]; diff --git a/python/paddle/fluid/tests/unittests/test_gather_tree_op.py b/python/paddle/fluid/tests/unittests/test_gather_tree_op.py index f3a5acc048404..242ed7e4a745d 100644 --- a/python/paddle/fluid/tests/unittests/test_gather_tree_op.py +++ b/python/paddle/fluid/tests/unittests/test_gather_tree_op.py @@ -125,6 +125,25 @@ def test_type_parents(): fluid.layers.gather_tree(ids, bad_parents) self.assertRaises(TypeError, test_type_parents) + + def test_ids_ndim(): + bad_ids = fluid.layers.data(name='bad_test_ids', + shape=[5, 2], + dtype='int64', + append_batch_size=False) + paddle.nn.functional.gather_tree(bad_ids, parents) + + self.assertRaises(ValueError, test_ids_ndim) + + def test_parents_ndim(): + bad_parents = fluid.layers.data(name='bad_test_parents', + shape=[5, 2], + dtype='int64', + append_batch_size=False) + paddle.nn.functional.gather_tree(ids, bad_parents) + + self.assertRaises(ValueError, test_parents_ndim) + paddle.disable_static() diff --git a/python/paddle/nn/functional/extension.py b/python/paddle/nn/functional/extension.py index 7ae35666c8612..602175c98f515 100644 --- a/python/paddle/nn/functional/extension.py +++ b/python/paddle/nn/functional/extension.py @@ -305,6 +305,13 @@ def gather_tree(ids, parents): # [[[2, 2], [1, 6]], [[3, 3], [6, 1]], [[0, 1], [9, 0]]] """ + if ids.ndim != 3: + raise ValueError( + "The input ids must be a 3D tensor with shape [length, batch_size, beam_size]" + ) + if ids.ndim != parents.ndim: + raise ValueError("The ids's shape must be the same as parents' shape. ") + if in_dygraph_mode(): return _C_ops.gather_tree(ids, parents) else: