Skip to content

Commit

Permalink
mean: not support int32, int64; add check for axis (#26401)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhupengyang committed Aug 21, 2020
1 parent 6e6567f commit 6e5670b
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 16 deletions.
10 changes: 2 additions & 8 deletions paddle/fluid/operators/reduce_ops/reduce_mean_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,18 +103,12 @@ REGISTER_OP_CPU_KERNEL(reduce_mean,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
float, ops::MeanFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
double, ops::MeanFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
int, ops::MeanFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
int64_t, ops::MeanFunctor>);
double, ops::MeanFunctor>);

template <typename T>
using CPUReduceMeanGradKernel =
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext, T,
ops::MeanGradFunctor, true>;

REGISTER_OP_CPU_KERNEL(reduce_mean_grad, CPUReduceMeanGradKernel<float>,
CPUReduceMeanGradKernel<double>,
CPUReduceMeanGradKernel<int>,
CPUReduceMeanGradKernel<int64_t>);
CPUReduceMeanGradKernel<double>);
4 changes: 1 addition & 3 deletions paddle/fluid/operators/reduce_ops/reduce_mean_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,4 @@ class ReduceMeanKernel : public framework::OpKernel<T> {
} // namespace paddle

REGISTER_OP_CUDA_KERNEL(reduce_mean, ops::ReduceMeanKernel<float>,
ops::ReduceMeanKernel<double>,
ops::ReduceMeanKernel<int>,
ops::ReduceMeanKernel<int64_t>);
ops::ReduceMeanKernel<double>);
6 changes: 6 additions & 0 deletions paddle/fluid/operators/reduce_ops/reduce_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,12 @@ class ReduceOp : public framework::OperatorWithKernel {
"range [-dimension(X), dimension(X)] "
"which dimesion = %d. But received dim index = %d.",
i, x_rank, dims[i]));
PADDLE_ENFORCE_GE(dims[i], -x_rank,
platform::errors::InvalidArgument(
"The reduce dim index %d should be in the "
"range [-dimension(X), dimension(X)] "
"which dimesion = %d. But received dim index = %d.",
i, x_rank, dims[i]));
if (dims[i] < 0) dims[i] = x_rank + dims[i];
}
sort(dims.begin(), dims.end());
Expand Down
7 changes: 6 additions & 1 deletion python/paddle/fluid/tests/unittests/test_mean_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,14 @@ def test_case(x, axis=None, keepdim=False):
paddle.enable_static()

def test_errors(self):
paddle.disable_static()
x = np.random.uniform(-1, 1, [10, 12]).astype('float32')
x = paddle.to_tensor(x)
self.assertRaises(Exception, paddle.mean, x, -3)
self.assertRaises(Exception, paddle.mean, x, 2)
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data('X', [10, 12], 'int8')
x = paddle.data('X', [10, 12], 'int32')
self.assertRaises(TypeError, paddle.mean, x)


Expand Down
10 changes: 6 additions & 4 deletions python/paddle/tensor/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ def mean(x, axis=None, keepdim=False, name=None):
Computes the mean of the input tensor's elements along ``axis``.
Args:
x (Tensor): The input Tensor with data type float32, float64, int32,
int64.
x (Tensor): The input Tensor with data type float32, float64.
axis (int|list|tuple, optional): The axis along which to perform mean
calculations. ``axis`` should be int, list(int) or tuple(int). If
``axis`` is a list/tuple of dimension(s), mean is calculated along
Expand Down Expand Up @@ -97,9 +96,12 @@ def mean(x, axis=None, keepdim=False, name=None):
return core.ops.reduce_mean(x, 'dim', axis, 'keep_dim', keepdim,
'reduce_all', reduce_all)

check_variable_and_dtype(x, 'x/input',
['float32', 'float64', 'int32', 'int64'],
check_variable_and_dtype(x, 'x/input', ['float32', 'float64'],
'mean/reduce_mean')
check_type(axis, 'axis/dim', (int, list, tuple), 'mean/reduce_mean')
if isinstance(axis, (list, tuple)):
for item in axis:
check_type(item, 'elements of axis/dim', (int), 'mean/reduce_mean')

helper = LayerHelper('mean', **locals())
attrs = {'dim': axis, 'keep_dim': keepdim, 'reduce_all': reduce_all}
Expand Down

0 comments on commit 6e5670b

Please sign in to comment.