address reviews
tiancaishaonvjituizi committed May 9, 2022
1 parent e94f42c commit 8c680e6
Showing 6 changed files with 127 additions and 17 deletions.
4 changes: 2 additions & 2 deletions paddle/fluid/operators/cum_op.cc
@@ -49,7 +49,7 @@ class CumsumOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
The cumulative sum of the elements along a given axis.
By default, the first element of the result is the same of the first element of
the input. If exlusive is true, the first element of the result is 0.
the input. If exclusive is true, the first element of the result is 0.
)DOC");
}
};
@@ -97,7 +97,7 @@ class LogcumsumexpOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
Returns the logarithm of the cumulative summation of the exponentiation of elements of input along the given axis.
By default, the first element of the result is the same of the first element of
the input. If exlusive is true, the first element of the result is the minimum value of dtype.
the input. If exclusive is true, the first element of the result is the lowest finite value of the dtype of the output tensor.
)DOC");
}
};
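As a quick illustration of the exclusive attribute documented in the two op makers above, a rough numpy sketch of the intended semantics (the attribute itself lives on the C++ operator; the helper names and the float32 dtype here are illustrative only):

import numpy as np

x = np.array([1., 2., 3., 4.], dtype=np.float32)

# Inclusive cumsum (the default): the first output equals the first input.
inclusive = np.cumsum(x)                            # [ 1.  3.  6. 10.]

# Exclusive cumsum: results are shifted right and start from 0.
exclusive = np.concatenate(([0.], inclusive[:-1]))  # [0. 1. 3. 6.]

# For logcumsumexp the exclusive variant starts from the lowest finite value of
# the output dtype, which plays the role of -inf since exp(lowest) underflows to 0.
lowest = np.finfo(np.float32).min                   # ~ -3.4e38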
2 changes: 0 additions & 2 deletions paddle/phi/kernels/cpu/cum_kernel.cc
@@ -146,8 +146,6 @@ void CumsumKernel(const Context& dev_ctx,
dev_ctx, x, axis, flatten, exclusive, reverse, reducer, out);
}

// Copied from
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/scan_ops.h
template <typename T>
struct LogSumExp {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a,
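The LogSumExp functor above (shown truncated) combines two partial results without overflowing exp(). A minimal numpy sketch of the max-factored formulation it is assumed to follow (the real functor may additionally special-case infinite inputs):

import numpy as np

def pairwise_logsumexp(a, b):
    # log(exp(a) + exp(b)), computed by factoring out the larger argument
    # so the remaining exponent is <= 0 and cannot overflow.
    hi, lo = max(a, b), min(a, b)
    return hi + np.log1p(np.exp(lo - hi))

# Agrees with numpy's stable built-in even where a naive exp() would overflow.
print(pairwise_logsumexp(1000.0, 1001.0))  # ~ 1001.3133
print(np.logaddexp(1000.0, 1001.0))        # ~ 1001.3133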
31 changes: 31 additions & 0 deletions paddle/phi/kernels/gpu/cum_kernel.cu
@@ -241,6 +241,37 @@ void ScanKernel(const Context& dev_ctx,
T* out_data = dev_ctx.template Alloc<T>(out);
const T* in_data = x.data<T>();

// Use thrust for parallel acceleration when the input size is equal to the
// length of the 'axis' dimension.
if (std::is_same<Op, cub::Sum>::value && size == out_dims[axis]) {
#ifdef __HIPCC__
const auto& policy = thrust::hip::par.on(dev_ctx.stream());
#else
const auto& policy = thrust::cuda::par.on(dev_ctx.stream());
#endif
if (reverse) {
thrust::reverse_iterator<thrust::device_ptr<const T>> reversed_in(
thrust::device_pointer_cast(in_data) + size);
thrust::reverse_iterator<thrust::device_ptr<T>> reversed_out(
thrust::device_pointer_cast(out_data) + size);
if (exclusive) {
thrust::exclusive_scan(
policy, reversed_in, reversed_in + size, reversed_out);
} else {
thrust::inclusive_scan(
policy, reversed_in, reversed_in + size, reversed_out);
}
} else {
if (exclusive) {
thrust::exclusive_scan(policy, in_data, in_data + size, out_data);
} else {
thrust::inclusive_scan(policy, in_data, in_data + size, out_data);
}
}
return;
}


size_t height = 1;
size_t width = 1;
for (size_t i = 0; i <= axis; i++) {
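For reference, a rough numpy model of what the thrust fast path added above computes on a 1-D input (the cub::Sum branch only; the helper below is illustrative and not part of the kernel):

import numpy as np

def scan_1d(x, exclusive=False, reverse=False):
    # Mirror of the thrust branch: optionally walk the array back to front,
    # then apply an inclusive or exclusive prefix sum.
    v = x[::-1] if reverse else x
    s = np.cumsum(v)
    if exclusive:
        s = np.concatenate(([0], s[:-1]))  # exclusive scan starts from the identity (0 for sum)
    return s[::-1] if reverse else s

x = np.array([1, 2, 3, 4])
print(scan_1d(x))                                # [ 1  3  6 10]
print(scan_1d(x, exclusive=True, reverse=True))  # [9 7 4 0]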
9 changes: 0 additions & 9 deletions paddle/utils/variant.h
@@ -13,11 +13,6 @@

#pragma once

#if defined(__GNUC__) && !defined(__clang__) && __GNUC__ >= 9
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-copy"
#endif

/*
variant synopsis
@@ -2833,7 +2828,3 @@ struct hash<paddle::monostate> {
};

} // namespace std

#if defined(__GNUC__) && !defined(__clang__) && __GNUC__ >= 9
#pragma GCC diagnostic pop
#endif
57 changes: 53 additions & 4 deletions python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py
@@ -23,6 +23,11 @@
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
from paddle.fluid.framework import _test_eager_guard
from op_test import OpTest


def np_naive_logcumsumexp(x: np.ndarray, axis: Optional[int]=None):
return np.log(np.cumsum(np.exp(x), axis=axis))


def np_logcumsumexp(x: np.ndarray,
@@ -56,8 +61,8 @@ def np_logcumsumexp(x: np.ndarray,
return x


class TestLogcumsumexpOp(unittest.TestCase):
def run_cases(self):
class TestLogcumsumexp(unittest.TestCase):
def run_imperative(self):
data_np = np.arange(12, dtype=np.float32).reshape(3, 4)
data = paddle.to_tensor(data_np)

@@ -86,6 +91,17 @@ def run_cases(self):
with self.assertRaises(IndexError):
y = paddle.logcumsumexp(data, axis=2)

data_np = np.arange(10000, 10024, dtype=np.float32)
data = paddle.to_tensor(data_np)
y = paddle.logcumsumexp(data)
z = np_naive_logcumsumexp(data_np)
# check that naive algorithm overflows
self.assertTrue(all(z == np.inf))
z = np_logcumsumexp(data_np)
# check that our algorithm doesn't overflow
self.assertTrue(all(z != np.inf))
self.assertTrue(np.allclose(z, y.numpy()))

def run_static(self, use_gpu=False):
with fluid.program_guard(fluid.Program()):
data_np = np.random.random((100, 100)).astype(np.float32)
Expand Down Expand Up @@ -120,7 +136,7 @@ def run_static(self, use_gpu=False):

def test_cpu(self):
paddle.disable_static(paddle.fluid.CPUPlace())
self.run_cases()
self.run_imperative()
paddle.enable_static()

self.run_static()
@@ -129,7 +145,7 @@ def test_gpu(self):
if not fluid.core.is_compiled_with_cuda():
return
paddle.disable_static(paddle.fluid.CUDAPlace(0))
self.run_cases()
self.run_imperative()
paddle.enable_static()

self.run_static(use_gpu=True)
@@ -154,5 +170,38 @@ def test_type_error(self):
out = exe.run(feed={'X': data_np}, fetch_list=[y.name])


class BaseOpTest(OpTest):
def setUp(self):
self.op_type = "logcumsumexp"
input, attrs = self.input_and_attrs()
self.inputs = {'X': input}
self.attrs = attrs
self.outputs = {'Out': np_logcumsumexp(input)}

def test_check_output(self):
self.check_output()

def test_check_grad(self):
self.check_grad(['X'], 'Out')

def input_and_attrs(self):
raise NotImplementedError()


class TestLogcumsumexpOp1(BaseOpTest):
def input_and_attrs(self):
return np.random.randn(20, 6), {'axis': 0, 'flatten': True, 'reverse': True}


class TestLogcumsumexpOp2(BaseOpTest):
def input_and_attrs(self):
return np.random.randn(20, 6), {'axis': 1, 'flatten': False, 'reverse': True}


class TestLogcumsumexpOp3(BaseOpTest):
def input_and_attrs(self):
return np.random.randn(20, 6), {'axis': 1, 'flatten': False, 'reverse': False}


if __name__ == '__main__':
unittest.main()
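A standalone numpy sketch of the overflow behaviour that the new run_imperative check exercises; np.logaddexp.accumulate is used here as a stand-in for the stable np_logcumsumexp helper:

import numpy as np

x = np.arange(10000, 10024, dtype=np.float32)

# Naive formulation: exp(10000) already overflows float32, so every entry becomes inf.
naive = np.log(np.cumsum(np.exp(x)))

# Stable formulation: accumulate with the pairwise log-add-exp, which never
# exponentiates anything larger than 0 after factoring out the running maximum.
stable = np.logaddexp.accumulate(x)

print(np.all(np.isinf(naive)))      # True
print(np.all(np.isfinite(stable)))  # True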
41 changes: 41 additions & 0 deletions python/paddle/tensor/math.py
@@ -2971,6 +2971,47 @@ def cumsum(x, axis=None, dtype=None, name=None):


def logcumsumexp(x, axis=None, dtype=None, name=None):
"""
Returns the logarithm of the cumulative summation of the exponentiation of the elements along a given axis.
**Note**:
The first element of the result is the same as the first element of the input.
Args:
x (Tensor): The input tensor
axis (int, optional): The dimension to do the operation along. -1 means the last dimension. The default (None) is to compute the logcumsumexp over the flattened array.
dtype (str, optional): The data type of the output tensor, can be float32, float64. If specified, the input tensor is casted to dtype before the operation is performed. This is useful for preventing data type overflows. The default value is None.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor, the result of logcumsumexp operator.
Examples:
.. code-block:: python
import paddle
data = paddle.arange(12)
data = paddle.reshape(data, (3, 4))
y = paddle.logcumsumexp(data)
# [ 0. 1.3132617 2.4076061 3.4401898 4.4519143 5.4561934
# 6.4577627 7.4583397 8.458551 9.45863 10.458658 11.458669 ]
y = paddle.logcumsumexp(data, axis=0)
# [[ 0. 1. 2. 3. ]
# [ 4.01815 5.01815 6.01815 7.01815 ]
# [ 8.018479 9.018479 10.018479 11.018479]]
y = paddle.logcumsumexp(data, axis=-1)
# [[ 0. 1.3132617 2.4076061 3.4401898]
# [ 4. 5.3132615 6.407606 7.44019 ]
# [ 8. 9.313262 10.407606 11.440189 ]]
y = paddle.logcumsumexp(data, dtype='float64')
print(y.dtype)
# paddle.float64
"""
if axis is None:
flatten = True
else:
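A quick sketch of the axis=None branch shown above: when no axis is given the input is flattened before the scan, so the call is expected to match an explicit scan over the flattened tensor (assuming the paddle.logcumsumexp API added in this change):

import numpy as np
import paddle

x = paddle.arange(12, dtype='float32').reshape((3, 4))

# axis=None -> flatten=True: operate on the flattened tensor.
implicit = paddle.logcumsumexp(x)
explicit = paddle.logcumsumexp(paddle.flatten(x), axis=0)
print(np.allclose(implicit.numpy(), explicit.numpy()))  # expected: True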
