diff --git a/paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h b/paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h index dbb2730d94f51..763293de836e5 100644 --- a/paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h +++ b/paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h @@ -91,46 +91,13 @@ DEFINE_SPARSE_UNARY_GRAD_KERNEL(Atanh) DEFINE_SPARSE_UNARY_GRAD_KERNEL(Sqrt) DEFINE_SPARSE_UNARY_GRAD_KERNEL(Square) DEFINE_SPARSE_UNARY_GRAD_KERNEL(Log1p) +DEFINE_SPARSE_UNARY_GRAD_KERNEL(Abs) DEFINE_SPARSE_UNARY_GRAD_KERNEL(Relu) DEFINE_SPARSE_UNARY_GRAD_KERNEL(Expm1) DEFINE_SPARSE_UNARY_GRAD_KERNEL(Relu6) DEFINE_SPARSE_UNARY_GRAD_KERNEL_WITH_ONE_ATTR(Pow, factor) DEFINE_SPARSE_UNARY_GRAD_KERNEL_WITH_ONE_ATTR(LeakyRelu, alpha) -template -void AbsCooGradKernel(const Context& dev_ctx, - const SparseCooTensor& x_or_out, - const SparseCooTensor& dout, - SparseCooTensor* dx) { - EmptyLikeCooKernel(dev_ctx, x_or_out, dx); - phi::AbsGradKernel(dev_ctx, - x_or_out.non_zero_elements(), - dout.non_zero_elements(), - dx->mutable_non_zero_elements()); - if (dx->dtype() == DataType::COMPLEX64 || - dx->dtype() == DataType::COMPLEX128) { - DenseTensor* out_values = dx->mutable_non_zero_elements(); - dx->set_type(out_values->dtype()); - } -} - -template -void AbsCsrGradKernel(const Context& dev_ctx, - const SparseCsrTensor& x_or_out, - const SparseCsrTensor& dout, - SparseCsrTensor* dx) { - EmptyLikeCsrKernel(dev_ctx, x_or_out, dx); - phi::AbsGradKernel(dev_ctx, - x_or_out.non_zero_elements(), - dout.non_zero_elements(), - dx->mutable_non_zero_elements()); - if (dx->dtype() == DataType::COMPLEX64 || - dx->dtype() == DataType::COMPLEX128) { - DenseTensor* out_values = dx->mutable_non_zero_elements(); - dx->set_type(out_values->dtype()); - } -} - template void CastCooGradKernel(const Context& dev_ctx, const SparseCooTensor& x, diff --git a/test/legacy_test/test_sparse_unary_op.py b/test/legacy_test/test_sparse_unary_op.py index 26e867c966e51..70a29ce0c62e3 100644 --- a/test/legacy_test/test_sparse_unary_op.py +++ b/test/legacy_test/test_sparse_unary_op.py @@ -30,7 +30,13 @@ def to_sparse(self, x, format): return x.detach().to_sparse_csr() def check_result( - self, dense_func, sparse_func, format, dtype='float32', *args + self, + dense_func, + sparse_func, + format, + device='cpu', + dtype='float32', + *args ): if dtype == 'complex64': origin_x_real = paddle.rand([8, 16, 32], 'float32') @@ -54,6 +60,7 @@ def check_result( # --- check sparse coo with dense --- # dense_x = origin_x * mask + dense_x.to(device) sp_x = self.to_sparse(dense_x, format) sp_x.stop_gradient = False if len(args) == 0: @@ -103,21 +110,19 @@ def compare_with_dense(self, dense_func, sparse_func, dtype='float32'): if (device == 'cpu' and dtype != 'float16') or ( device == 'gpu' and paddle.is_compiled_with_cuda() ): - paddle.set_device(device) - self.check_result(dense_func, sparse_func, 'coo', dtype) - self.check_result(dense_func, sparse_func, 'csr', dtype) + self.check_result(dense_func, sparse_func, 'coo', device, dtype) + self.check_result(dense_func, sparse_func, 'csr', device, dtype) def compare_with_dense_one_attr(self, dense_func, sparse_func, attr1): for device in devices: if device == 'cpu' or ( device == 'gpu' and paddle.is_compiled_with_cuda() ): - paddle.set_device(device) self.check_result( - dense_func, sparse_func, 'coo', 'float32', attr1 + dense_func, sparse_func, 'coo', device, 'float32', attr1 ) self.check_result( - dense_func, sparse_func, 'csr', 'float32', attr1 + dense_func, sparse_func, 'csr', device, 'float32', attr1 ) def compare_with_dense_two_attr( @@ -127,12 +132,23 @@ def compare_with_dense_two_attr( if device == 'cpu' or ( device == 'gpu' and paddle.is_compiled_with_cuda() ): - paddle.set_device(device) self.check_result( - dense_func, sparse_func, 'coo', 'float32', attr1, attr2 + dense_func, + sparse_func, + 'coo', + device, + 'float32', + attr1, + attr2, ) self.check_result( - dense_func, sparse_func, 'csr', 'float32', attr1, attr2 + dense_func, + sparse_func, + 'csr', + device, + 'float32', + attr1, + attr2, ) def test_sparse_abs(self):