Skip to content

Commit

Permalink
change the device setting in test op
Browse files Browse the repository at this point in the history
  • Loading branch information
bapijun committed Mar 29, 2024
1 parent d11cb5b commit e589c98
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 44 deletions.
35 changes: 1 addition & 34 deletions paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,46 +91,13 @@ DEFINE_SPARSE_UNARY_GRAD_KERNEL(Atanh)
// Instantiate sparse unary grad kernels from the project macro.
// NOTE(review): the macro bodies are defined elsewhere in this header —
// presumably each expands to Coo/Csr grad kernel pairs; confirm against the
// macro definition above this hunk.
DEFINE_SPARSE_UNARY_GRAD_KERNEL(Sqrt)
DEFINE_SPARSE_UNARY_GRAD_KERNEL(Square)
DEFINE_SPARSE_UNARY_GRAD_KERNEL(Log1p)
DEFINE_SPARSE_UNARY_GRAD_KERNEL(Abs)
DEFINE_SPARSE_UNARY_GRAD_KERNEL(Relu)
DEFINE_SPARSE_UNARY_GRAD_KERNEL(Expm1)
DEFINE_SPARSE_UNARY_GRAD_KERNEL(Relu6)
// These two ops carry one scalar attribute that the grad kernel forwards.
DEFINE_SPARSE_UNARY_GRAD_KERNEL_WITH_ONE_ATTR(Pow, factor)
DEFINE_SPARSE_UNARY_GRAD_KERNEL_WITH_ONE_ATTR(LeakyRelu, alpha)

template <typename T, typename Context>
void AbsCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x_or_out,
const SparseCooTensor& dout,
SparseCooTensor* dx) {
EmptyLikeCooKernel<T, Context>(dev_ctx, x_or_out, dx);
phi::AbsGradKernel<T, Context>(dev_ctx,
x_or_out.non_zero_elements(),
dout.non_zero_elements(),
dx->mutable_non_zero_elements());
if (dx->dtype() == DataType::COMPLEX64 ||
dx->dtype() == DataType::COMPLEX128) {
DenseTensor* out_values = dx->mutable_non_zero_elements();
dx->set_type(out_values->dtype());
}
}

template <typename T, typename Context>
void AbsCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& x_or_out,
const SparseCsrTensor& dout,
SparseCsrTensor* dx) {
EmptyLikeCsrKernel<T, Context>(dev_ctx, x_or_out, dx);
phi::AbsGradKernel<T, Context>(dev_ctx,
x_or_out.non_zero_elements(),
dout.non_zero_elements(),
dx->mutable_non_zero_elements());
if (dx->dtype() == DataType::COMPLEX64 ||
dx->dtype() == DataType::COMPLEX128) {
DenseTensor* out_values = dx->mutable_non_zero_elements();
dx->set_type(out_values->dtype());
}
}

template <typename T, typename Context>
void CastCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
Expand Down
36 changes: 26 additions & 10 deletions test/legacy_test/test_sparse_unary_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,13 @@ def to_sparse(self, x, format):
return x.detach().to_sparse_csr()

def check_result(
self, dense_func, sparse_func, format, dtype='float32', *args
self,
dense_func,
sparse_func,
format,
device='cpu',
dtype='float32',
*args
):
if dtype == 'complex64':
origin_x_real = paddle.rand([8, 16, 32], 'float32')
Expand All @@ -54,6 +60,7 @@ def check_result(

# --- check sparse coo with dense --- #
dense_x = origin_x * mask
dense_x.to(device)
sp_x = self.to_sparse(dense_x, format)
sp_x.stop_gradient = False
if len(args) == 0:
Expand Down Expand Up @@ -103,21 +110,19 @@ def compare_with_dense(self, dense_func, sparse_func, dtype='float32'):
if (device == 'cpu' and dtype != 'float16') or (
device == 'gpu' and paddle.is_compiled_with_cuda()
):
paddle.set_device(device)
self.check_result(dense_func, sparse_func, 'coo', dtype)
self.check_result(dense_func, sparse_func, 'csr', dtype)
self.check_result(dense_func, sparse_func, 'coo', device, dtype)
self.check_result(dense_func, sparse_func, 'csr', device, dtype)

def compare_with_dense_one_attr(self, dense_func, sparse_func, attr1):
for device in devices:
if device == 'cpu' or (
device == 'gpu' and paddle.is_compiled_with_cuda()
):
paddle.set_device(device)
self.check_result(
dense_func, sparse_func, 'coo', 'float32', attr1
dense_func, sparse_func, 'coo', device, 'float32', attr1
)
self.check_result(
dense_func, sparse_func, 'csr', 'float32', attr1
dense_func, sparse_func, 'csr', device, 'float32', attr1
)

def compare_with_dense_two_attr(
Expand All @@ -127,12 +132,23 @@ def compare_with_dense_two_attr(
if device == 'cpu' or (
device == 'gpu' and paddle.is_compiled_with_cuda()
):
paddle.set_device(device)
self.check_result(
dense_func, sparse_func, 'coo', 'float32', attr1, attr2
dense_func,
sparse_func,
'coo',
device,
'float32',
attr1,
attr2,
)
self.check_result(
dense_func, sparse_func, 'csr', 'float32', attr1, attr2
dense_func,
sparse_func,
'csr',
device,
'float32',
attr1,
attr2,
)

def test_sparse_abs(self):
Expand Down

0 comments on commit e589c98

Please sign in to comment.