From 71c95d0b784e8e2035c4264544b1bb5c799122d3 Mon Sep 17 00:00:00 2001 From: zyt1024 <1522063645@qq.com> Date: Thu, 9 Nov 2023 02:18:48 +0000 Subject: [PATCH] support_complex_unzip_op --- paddle/fluid/operators/unzip_op.cc | 18 +++++-- paddle/fluid/operators/unzip_op.cu | 8 ++- python/paddle/incubate/operators/unzip.py | 11 +++- test/legacy_test/test_unzip_op.py | 66 +++++++++++++++++++++++ 4 files changed, 96 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/unzip_op.cc b/paddle/fluid/operators/unzip_op.cc index 556b345c17b0a..d871030646dbe 100644 --- a/paddle/fluid/operators/unzip_op.cc +++ b/paddle/fluid/operators/unzip_op.cc @@ -162,7 +162,17 @@ REGISTER_OPERATOR(unzip, REGISTER_OPERATOR(unzip_grad, ops::unzipGradientOp); -PD_REGISTER_STRUCT_KERNEL(unzip, CPU, ALL_LAYOUT, ops::unzipOpKernel, int64_t) { -} -PD_REGISTER_STRUCT_KERNEL( - unzip_grad, CPU, ALL_LAYOUT, ops::unzipGradOpKernel, int64_t) {} +PD_REGISTER_STRUCT_KERNEL(unzip, + CPU, + ALL_LAYOUT, + ops::unzipOpKernel, + int64_t, + phi::dtype::complex, + phi::dtype::complex) {} +PD_REGISTER_STRUCT_KERNEL(unzip_grad, + CPU, + ALL_LAYOUT, + ops::unzipGradOpKernel, + int64_t, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/fluid/operators/unzip_op.cu b/paddle/fluid/operators/unzip_op.cu index d60af556cd279..c67c4bd0af334 100644 --- a/paddle/fluid/operators/unzip_op.cu +++ b/paddle/fluid/operators/unzip_op.cu @@ -91,7 +91,9 @@ PD_REGISTER_STRUCT_KERNEL(unzip, plat::float16, bool, int, - int64_t) {} + int64_t, + phi::dtype::complex, + phi::dtype::complex) {} PD_REGISTER_STRUCT_KERNEL(unzip_grad, GPU, ALL_LAYOUT, @@ -101,4 +103,6 @@ PD_REGISTER_STRUCT_KERNEL(unzip_grad, plat::float16, bool, int, - int64_t) {} + int64_t, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/python/paddle/incubate/operators/unzip.py b/python/paddle/incubate/operators/unzip.py index 2edf3a392025d..6134c0fe3908c 100644 --- a/python/paddle/incubate/operators/unzip.py +++ b/python/paddle/incubate/operators/unzip.py @@ -66,7 +66,16 @@ def unzip(input, lod): check_variable_and_dtype( input, 'input', - ['float16', 'float32', 'float64', 'int', 'bool', 'int64'], + [ + 'float16', + 'float32', + 'float64', + 'int', + 'bool', + 'int64', + 'complex64', + 'complex128', + ], 'unzip', ) check_variable_and_dtype(lod, 'lod', ['int', 'int64'], 'unzip') diff --git a/test/legacy_test/test_unzip_op.py b/test/legacy_test/test_unzip_op.py index 8125edc15125a..7b6f1576e5024 100644 --- a/test/legacy_test/test_unzip_op.py +++ b/test/legacy_test/test_unzip_op.py @@ -64,5 +64,71 @@ def test_result(self): assert (res == out_np).all(), "output is not right" +class TestUnzipOp_Complex(unittest.TestCase): + def test_result(self): + """ + For unzip op + """ + self.dtype = self.get_dtype() + paddle.enable_static() + prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(prog, startup_prog): + if core.is_compiled_with_cuda(): + place = base.CUDAPlace(0) + x = paddle.static.data( + name='Complex64_X', shape=[3, 4], dtype=self.dtype + ) + lod = paddle.static.data(name='lodx', shape=[11], dtype='int64') + output = paddle.incubate.operators.unzip(x, lod) + input = [ + [1.0 + 1.0j, 2.0 + 2.0j, 3.0 + 3.0j, 4.0 + 4.0j], + [10.0 + 10.0j, 20.0 + 20.0j, 30.0 + 30.0j, 40.0 + 40.0j], + [ + 100.0 + 100.0j, + 200.0 + 200.0j, + 300.0 + 300.0j, + 400.0 + 400.0j, + ], + ] + lod = [0, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12] + + feed = { + 'Complex64_X': np.array(input).astype(self.dtype), + 'lodx': np.array(lod).astype("int64"), + } + + exe = base.Executor(place=place) + exe.run(base.default_startup_program()) + res = exe.run(prog, feed=feed, fetch_list=[output]) + out = [ + [1.0 + 1.0j, 2.0 + 2.0j, 3.0 + 3.0j, 4.0 + 4.0j], + [0.0j, 0.0j, 0.0j, 0.0j], + [10.0 + 10.0j, 20.0 + 20.0j, 30.0 + 30.0j, 40.0 + 40.0j], + [0.0j, 0.0j, 0.0j, 0.0j], + [0.0j, 0.0j, 0.0j, 0.0j], + [0.0j, 0.0j, 0.0j, 0.0j], + [ + 100.0 + 100.0j, + 200.0 + 200.0j, + 300.0 + 300.0j, + 400.0 + 400.0j, + ], + [0.0j, 0.0j, 0.0j, 0.0j], + [0.0j, 0.0j, 0.0j, 0.0j], + [0.0j, 0.0j, 0.0j, 0.0j], + ] + out_np = np.array(out, dtype=self.dtype) + assert (res == out_np).all(), "output is not right" + + def get_dtype(self): + return np.complex64 + + +class TestUnzipOp_Complex128(TestUnzipOp_Complex): + def get_dtype(self): + return np.complex128 + + if __name__ == '__main__': unittest.main()