Skip to content

Commit

Permalink
【Complex op】add complex support for index_select and index_sample (#56457)
Browse files Browse the repository at this point in the history

* support index_select op

* index_sample in cpu

* support index_sample in gpu

* change data_transform

* fix api gen and use skip_transform in yaml
  • Loading branch information
ScottWong98 committed Sep 1, 2023
1 parent 7635af0 commit 0b60839
Show file tree
Hide file tree
Showing 15 changed files with 127 additions and 17 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/eager/grad_node_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,7 @@ void GradNodeBase::HandleComplexGradToRealGrad(
for (size_t slot_id = 0; slot_id < out_grads->size(); slot_id++) {
const std::vector<paddle::Tensor>& slot_out_grads = (*out_grads)[slot_id];
for (size_t rank_id = 0; rank_id < slot_out_grads.size(); rank_id++) {
if (bwd_out_meta_[slot_id].size() == 0) continue;
const GradSlotMeta& slot_meta = bwd_out_meta_[slot_id][rank_id];

PADDLE_ENFORCE(
Expand Down
4 changes: 4 additions & 0 deletions paddle/phi/api/yaml/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1120,6 +1120,8 @@
func : index_sample_grad
data_type : out_grad
no_need_buffer : x
data_transform :
skip_transform : index

- backward_op : index_select_grad
forward : index_select(Tensor x, Tensor index, int axis) -> Tensor(out)
Expand All @@ -1132,6 +1134,8 @@
func : index_select_grad
data_type : out_grad
no_need_buffer : x
data_transform :
skip_transform : index

- backward_op : index_select_strided_grad
forward : index_select_strided(Tensor x, int64_t index, int axis) -> Tensor(out)
Expand Down
25 changes: 18 additions & 7 deletions paddle/phi/api/yaml/generator/api_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,23 @@

import collections
import re
from typing import List

PREFIX_TENSOR_NAME = 'input_'
PREFIX_META_TENSOR_NAME = 'meta_'


def parse_plain_list(s: str, sep=",") -> List[str]:
    """Split a delimiter-separated string into a list of stripped items.

    Copy from `paddle/fluid/operators/generator/parse_utils.py`.

    When ``sep`` is a comma, commas that appear inside curly braces are NOT
    treated as separators, so yaml default-value lists survive intact,
    e.g. ``"int[] a={1,2}"`` remains a single item.

    Args:
        s: The raw string to split.
        sep: The separator character. Defaults to ``","``.

    Returns:
        The items of ``s`` with surrounding whitespace removed.
    """
    if sep == ",":
        # Negative lookahead: skip any comma that is followed by a closing
        # brace before an opening one, i.e. a comma inside "{...}".
        pattern = re.compile(r',(?![^{]*\})')
        return [item.strip() for item in re.split(pattern, s.strip())]
    return [item.strip() for item in s.strip().split(sep)]


class BaseAPI:
def __init__(self, api_item_yaml):
self.api = self.get_api_name(api_item_yaml)
Expand Down Expand Up @@ -367,14 +379,13 @@ def parse_data_transform(self, api_item_yaml):
data_transform = {'skip_transform': [], 'support_trans_dtype': []}
if 'data_transform' in api_item_yaml:
if 'skip_transform' in api_item_yaml['data_transform']:
data_transform['skip_transform'] = api_item_yaml[
'data_transform'
]['skip_transform']
data_transform['skip_transform'] = parse_plain_list(
api_item_yaml['data_transform']['skip_transform']
)
if 'support_trans_dtype' in api_item_yaml['data_transform']:
data_transform['support_trans_dtype'] = api_item_yaml[
'data_transform'
]['support_trans_dtype']

data_transform['support_trans_dtype'] = parse_plain_list(
api_item_yaml['data_transform']['support_trans_dtype']
)
return data_transform

# Override by child class
Expand Down
4 changes: 4 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1238,6 +1238,8 @@
func : index_sample
data_type : x
backward : index_sample_grad
data_transform :
skip_transform : index

- op : index_select
args : (Tensor x, Tensor index, int axis = 0)
Expand All @@ -1248,6 +1250,8 @@
func : index_select
data_type : x
backward : index_select_grad
data_transform :
skip_transform : index

- op : index_select_strided
args : (Tensor x, int64_t index, int axis = 0)
Expand Down
4 changes: 3 additions & 1 deletion paddle/phi/kernels/cpu/index_sample_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,6 @@ PD_REGISTER_KERNEL(index_sample_grad,
float,
double,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/cpu/index_sample_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,4 +115,6 @@ PD_REGISTER_KERNEL(index_sample,
float,
double,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
2 changes: 2 additions & 0 deletions paddle/phi/kernels/cpu/index_select_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,7 @@ PD_REGISTER_KERNEL(index_select_grad,
float,
double,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
int,
int64_t) {}
2 changes: 2 additions & 0 deletions paddle/phi/kernels/cpu/index_select_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,5 +59,7 @@ PD_REGISTER_KERNEL(index_select,
float,
double,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
int,
int64_t) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/index_sample_grad_kernel.cu
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -135,4 +135,6 @@ PD_REGISTER_KERNEL(index_sample_grad,
float,
double,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/index_sample_kernel.cu
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -108,4 +108,6 @@ PD_REGISTER_KERNEL(index_sample,
float,
double,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
2 changes: 2 additions & 0 deletions paddle/phi/kernels/gpu/index_select_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -132,5 +132,7 @@ PD_REGISTER_KERNEL(index_select_grad,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
int,
int64_t) {}
2 changes: 2 additions & 0 deletions paddle/phi/kernels/gpu/index_select_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -85,5 +85,7 @@ PD_REGISTER_KERNEL(index_select,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
int,
int64_t) {}
26 changes: 22 additions & 4 deletions python/paddle/tensor/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def index_select(x, index, axis=0, name=None):
size as the length of ``index``; other dimensions have the same size as in the ``x`` tensor.
Args:
x (Tensor): The input Tensor to be operated. The data of ``x`` can be one of float16, float32, float64, int32, int64.
x (Tensor): The input Tensor to be operated. The data of ``x`` can be one of float16, float32, float64, int32, int64, complex64 and complex128.
index (Tensor): The 1-D Tensor containing the indices to index. The data type of ``index`` must be int32 or int64.
axis (int, optional): The dimension in which we index. Default: if None, the ``axis`` is 0.
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
Expand Down Expand Up @@ -353,7 +353,16 @@ def index_select(x, index, axis=0, name=None):
check_variable_and_dtype(
x,
'x',
['uint16', 'float16', 'float32', 'float64', 'int32', 'int64'],
[
'uint16',
'float16',
'float32',
'float64',
'int32',
'int64',
'complex64',
'complex128',
],
'paddle.tensor.search.index_select',
)
check_variable_and_dtype(
Expand Down Expand Up @@ -771,7 +780,7 @@ def index_sample(x, index):
Args:
x (Tensor): The source input tensor with 2-D shape. Supported data type is
int32, int64, bfloat16, float16, float32, float64.
int32, int64, bfloat16, float16, float32, float64, complex64, complex128.
index (Tensor): The index input tensor with 2-D shape, first dimension should be same with X.
Data type is int32 or int64.
Expand Down Expand Up @@ -826,7 +835,16 @@ def index_sample(x, index):
check_variable_and_dtype(
x,
'x',
['uint16', 'float16', 'float32', 'float64', 'int32', 'int64'],
[
'uint16',
'float16',
'float32',
'float64',
'int32',
'int64',
'complex64',
'complex128',
],
'paddle.tensor.search.index_sample',
)
check_variable_and_dtype(
Expand Down
27 changes: 27 additions & 0 deletions test/legacy_test/test_index_sample_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ def setUp(self):
self.python_api = paddle.index_sample
self.config()
xnp = np.random.random(self.x_shape).astype(self.x_type)
if self.x_type == np.complex64 or self.x_type == np.complex128:
xnp = (
np.random.random(self.x_shape)
+ 1j * np.random.random(self.x_shape)
).astype(self.x_type)
indexnp = np.random.randint(
low=0, high=self.x_shape[1], size=self.index_shape
).astype(self.index_type)
Expand Down Expand Up @@ -122,6 +127,28 @@ def config(self):
self.index_type = "int64"


class TestIndexSampleComplex64(TestIndexSampleOp):
    def config(self):
        """Configure the op test to exercise a complex64 ``x`` input."""
        self.x_type = np.complex64
        self.index_type = "int64"
        self.x_shape = (10, 128)
        self.index_shape = (10, 64)


class TestIndexSampleComplex128(TestIndexSampleOp):
    def config(self):
        """
        For complex128 x type
        """
        self.x_shape = (10, 128)
        self.x_type = np.complex128
        self.index_shape = (10, 64)
        self.index_type = "int64"


@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
Expand Down
33 changes: 31 additions & 2 deletions test/legacy_test/test_index_select_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ def setUp(self):
low=0, high=self.x_shape[self.dim], size=self.index_size
)
x_np = np.random.random(self.x_shape).astype(self.x_type)
if self.dtype == np.complex64 or self.dtype == np.complex128:
x_np = (
np.random.random(self.x_shape)
+ 1j * np.random.random(self.x_shape)
).astype(self.x_type)
self.inputs = {'X': x_np, 'Index': index_np}
self.attrs = {'dim': self.dim}
outer_loop = np.prod(self.x_shape[: self.dim])
Expand All @@ -60,10 +65,16 @@ def init_dtype_type(self):
self.index_size = 100

def test_check_output(self):
self.check_output(check_prim=True)
if self.x_type == np.complex64 or self.x_type == np.complex128:
self.check_output(check_prim=False)
else:
self.check_output(check_prim=True)

def test_check_grad_normal(self):
self.check_grad(['X'], 'Out', check_prim=True)
if self.x_type == np.complex64 or self.x_type == np.complex128:
self.check_grad(['X'], 'Out', check_prim=False)
else:
self.check_grad(['X'], 'Out', check_prim=True)


class TestIndexSelectOpCase2(TestIndexSelectOp):
Expand Down Expand Up @@ -146,6 +157,24 @@ def test_check_grad_normal(self):
self.check_grad_with_place(place, ['X'], 'Out', check_prim=True)


class TestIndexSelectComplex64(TestIndexSelectOp):
    def init_dtype_type(self):
        """Configure the op test to exercise a complex64 ``x`` input."""
        self.x_type = np.complex64
        self.index_type = np.int32
        self.x_shape = (10, 10, 4, 10)
        self.index_size = 10
        self.dim = -2


class TestIndexSelectComplex128(TestIndexSelectOp):
    def init_dtype_type(self):
        """Configure the op test to exercise a complex128 ``x`` input."""
        self.x_type = np.complex128
        self.index_type = np.int32
        self.x_shape = (10, 10, 4, 10)
        self.index_size = 10
        self.dim = -2


class TestIndexSelectAPI(unittest.TestCase):
def input_data(self):
self.data_x = np.array(
Expand Down

0 comments on commit 0b60839

Please sign in to comment.