Skip to content

Commit

Permalink
Sync the pull request PaddlePaddle#51903.
Browse files Browse the repository at this point in the history
  • Loading branch information
Xreki committed Apr 6, 2023
1 parent 0373a2c commit 0bd1aa1
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 3 deletions.
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/gather_nd_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "paddle/phi/kernels/gather_nd_grad_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/scatter.cu.h"
Expand Down Expand Up @@ -63,4 +64,5 @@ PD_REGISTER_KERNEL(gather_nd_grad,
double,
int64_t,
int,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
2 changes: 2 additions & 0 deletions paddle/phi/kernels/gpu/gather_nd_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "paddle/phi/kernels/gather_nd_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/gather.cu.h"
#include "paddle/phi/kernels/funcs/scatter.cu.h"
Expand Down Expand Up @@ -58,4 +59,5 @@ PD_REGISTER_KERNEL(gather_nd,
int,
int16_t,
bool,
phi::dtype::float16,
phi::dtype::float16) {}
13 changes: 11 additions & 2 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3791,8 +3791,17 @@ def gather_nd(x, index, name=None):
check_variable_and_dtype(
x,
'x',
['bool', 'float32', 'float64', 'int16', 'int32', 'int64'],
'gather_np',
[
'bool',
'float16',
'uint16',
'float32',
'float64',
'int16',
'int32',
'int64',
],
'gather_nd',
)
check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'gather_np')
helper = LayerHelper('gather_nd', **locals())
Expand Down

0 comments on commit 0bd1aa1

Please sign in to comment.