From 8c680e697e0d4ec93511edeb226ec1b8f08efe23 Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Mon, 9 May 2022 11:16:06 +0800 Subject: [PATCH] address reviews --- paddle/fluid/operators/cum_op.cc | 4 +- paddle/phi/kernels/cpu/cum_kernel.cc | 2 - paddle/phi/kernels/gpu/cum_kernel.cu | 31 ++++++++++ paddle/utils/variant.h | 9 --- .../tests/unittests/test_logcumsumexp_op.py | 57 +++++++++++++++++-- python/paddle/tensor/math.py | 41 +++++++++++++ 6 files changed, 127 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/operators/cum_op.cc b/paddle/fluid/operators/cum_op.cc index 7043d47a26b1e..c4e906c25d837 100644 --- a/paddle/fluid/operators/cum_op.cc +++ b/paddle/fluid/operators/cum_op.cc @@ -49,7 +49,7 @@ class CumsumOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( The cumulative sum of the elements along a given axis. By default, the first element of the result is the same of the first element of -the input. If exlusive is true, the first element of the result is 0. +the input. If exclusive is true, the first element of the result is 0. )DOC"); } }; @@ -97,7 +97,7 @@ class LogcumsumexpOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( Returns the logarithm of the cumulative summation of the exponentiation of elements of input along the given axis. By default, the first element of the result is the same of the first element of -the input. If exlusive is true, the first element of the result is the minimum value of dtype. +the input. If exclusive is true, the first element of the result is the the lowest finite value of the dtype of output tensor. )DOC"); } }; diff --git a/paddle/phi/kernels/cpu/cum_kernel.cc b/paddle/phi/kernels/cpu/cum_kernel.cc index 3cb406f8b2907..85a6ea5d8be1b 100644 --- a/paddle/phi/kernels/cpu/cum_kernel.cc +++ b/paddle/phi/kernels/cpu/cum_kernel.cc @@ -146,8 +146,6 @@ void CumsumKernel(const Context& dev_ctx, dev_ctx, x, axis, flatten, exclusive, reverse, reducer, out); } -// Copied from -// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/scan_ops.h template struct LogSumExp { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a, diff --git a/paddle/phi/kernels/gpu/cum_kernel.cu b/paddle/phi/kernels/gpu/cum_kernel.cu index 7ec60fa11cfb8..59cd4eb7abc59 100644 --- a/paddle/phi/kernels/gpu/cum_kernel.cu +++ b/paddle/phi/kernels/gpu/cum_kernel.cu @@ -241,6 +241,37 @@ void ScanKernel(const Context& dev_ctx, T* out_data = dev_ctx.template Alloc(out); const T* in_data = x.data(); + // Use thrust for parallel acceleration when the input size is equal to the + // length of the ‘axis’ dimension. + if (std::is_same::value && size == out_dims[axis]) { +#ifdef __HIPCC__ + const auto& policy = thrust::hip::par.on(dev_ctx.stream()); +#else + const auto& policy = thrust::cuda::par.on(dev_ctx.stream()); +#endif + if (reverse) { + thrust::reverse_iterator> reversed_in( + thrust::device_pointer_cast(in_data) + size); + thrust::reverse_iterator> reversed_out( + thrust::device_pointer_cast(out_data) + size); + if (exclusive) { + thrust::exclusive_scan( + policy, reversed_in, reversed_in + size, reversed_out); + } else { + thrust::inclusive_scan( + policy, reversed_in, reversed_in + size, reversed_out); + } + } else { + if (exclusive) { + thrust::exclusive_scan(policy, in_data, in_data + size, out_data); + } else { + thrust::inclusive_scan(policy, in_data, in_data + size, out_data); + } + } + return; + } + + size_t height = 1; size_t width = 1; for (size_t i = 0; i <= axis; i++) { diff --git a/paddle/utils/variant.h b/paddle/utils/variant.h index 7b11ae1bee88c..a7546d094c2ff 100644 --- a/paddle/utils/variant.h +++ b/paddle/utils/variant.h @@ -13,11 +13,6 @@ #pragma once -#if defined(__GNUC__) && !defined(__clang__) && __GNUC__ >= 9 -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wdeprecated-copy" -#endif - /* variant synopsis @@ -2833,7 +2828,3 @@ struct hash { }; } // namespace std - -#if defined(__GNUC__) && !defined(__clang__) && __GNUC__ >= 9 -#pragma GCC diagnostic pop -#endif diff --git a/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py b/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py index 21038b3e52730..28313674007dd 100644 --- a/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py +++ b/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py @@ -23,6 +23,11 @@ import paddle.fluid as fluid from paddle.fluid import compiler, Program, program_guard from paddle.fluid.framework import _test_eager_guard +from op_test import OpTest + + +def np_naive_logcumsumexp(x: np.ndarray, axis: Optional[int]=None): + return np.log(np.cumsum(np.exp(x), axis=axis)) def np_logcumsumexp(x: np.ndarray, @@ -56,8 +61,8 @@ def np_logcumsumexp(x: np.ndarray, return x -class TestLogcumsumexpOp(unittest.TestCase): - def run_cases(self): +class TestLogcumsumexp(unittest.TestCase): + def run_imperative(self): data_np = np.arange(12, dtype=np.float32).reshape(3, 4) data = paddle.to_tensor(data_np) @@ -86,6 +91,17 @@ def run_cases(self): with self.assertRaises(IndexError): y = paddle.logcumsumexp(data, axis=2) + data_np = np.arange(10000, 10024, dtype=np.float32) + data = paddle.to_tensor(data_np) + y = paddle.logcumsumexp(data) + z = np_naive_logcumsumexp(data_np) + # check that naive algorithm overflows + self.assertTrue(all(z == np.inf)) + z = np_logcumsumexp(data_np) + # check that our algorithm doesn't overflow + self.assertTrue(all(z != np.inf)) + self.assertTrue(np.allclose(z, y.numpy())) + def run_static(self, use_gpu=False): with fluid.program_guard(fluid.Program()): data_np = np.random.random((100, 100)).astype(np.float32) @@ -120,7 +136,7 @@ def run_static(self, use_gpu=False): def test_cpu(self): paddle.disable_static(paddle.fluid.CPUPlace()) - self.run_cases() + self.run_imperative() paddle.enable_static() self.run_static() @@ -129,7 +145,7 @@ def test_gpu(self): if not fluid.core.is_compiled_with_cuda(): return paddle.disable_static(paddle.fluid.CUDAPlace(0)) - self.run_cases() + self.run_imperative() paddle.enable_static() self.run_static(use_gpu=True) @@ -154,5 +170,38 @@ def test_type_error(self): out = exe.run(feed={'X': data_np}, fetch_list=[y.name]) +class BaseOpTest(OpTest): + def setUp(self): + self.op_type = "logcumsumexp" + input, attrs = self.input_and_attrs() + self.inputs = {'X': input} + self.attrs = attrs + self.outputs = {'Out': np_logcumsumexp(input)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + def input_and_attrs(self): + raise NotImplementedError() + + +def TestLogcumsumexpOp1(BaseOpTest): + def input_and_attrs(self): + return np.random.randn(20, 6), {'axis': 0, 'flatten': True, 'reverse': True} + + +def TestLogcumsumexpOp2(BaseOpTest): + def input_and_attrs(self): + return np.random.randn(20, 6), {'axis': 1, 'flatten': False, 'reverse': True} + + +def TestLogcumsumexpOp3(BaseOpTest): + def input_and_attrs(self): + return np.random.randn(20, 6), {'axis': 1, 'flatten': False, 'reverse': False} + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 2419494c11df4..7caf91556ab93 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -2971,6 +2971,47 @@ def cumsum(x, axis=None, dtype=None, name=None): def logcumsumexp(x, axis=None, dtype=None, name=None): + """ + The the logarithm of the cumulative summation of the exponentiation of the elements along a given axis. + + **Note**: + The first element of the result is the same of the first element of the input. + + Args: + x (Tensor): The input tensor + axis (int, optional): The dimension to do the operation along. -1 means the last dimension. The default (None) is to compute the cumsum over the flattened array. + dtype (str, optional): The data type of the output tensor, can be float32, float64. If specified, the input tensor is casted to dtype before the operation is performed. This is useful for preventing data type overflows. The default value is None. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor, the result of logcumsumexp operator. + + Examples: + .. code-block:: python + + import paddle + + data = paddle.arange(12) + data = paddle.reshape(data, (3, 4)) + + y = paddle.logcumsumexp(data) + # [ 0. 1.3132617 2.4076061 3.4401898 4.4519143 5.4561934 + # 6.4577627 7.4583397 8.458551 9.45863 10.458658 11.458669 ] + + y = paddle.logcumsumexp(data, axis=0) + # [[ 0. 1. 2. 3. ] + # [ 4.01815 5.01815 6.01815 7.01815 ] + # [ 8.018479 9.018479 10.018479 11.018479]] + + y = paddle.logcumsumexp(data, axis=-1) + # [[ 0. 1.3132617 2.4076061 3.4401898] + # [ 4. 5.3132615 6.407606 7.44019 ] + # [ 8. 9.313262 10.407606 11.440189 ]] + + y = paddle.logcumsumexp(data, dtype='float64') + print(y.dtype) + # paddle.float64 + """ if axis is None: flatten = True else: