-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【PaddlePaddle Hackathon 3 No.16】为 Paddle 新增 API paddle.take #44741
Changes from 25 commits
982d01e
69b0a3e
b07c062
09d2836
c8482f6
0665e50
c5a9e16
10b41c4
6852760
9649b87
ec1cfd7
6806a8f
27b6943
5d32c52
b35d831
cc2f4f4
c4161f2
aaee858
ca2604f
5979d5f
7b3fc1d
668964d
64b688a
cdd1080
eca0483
4ca5c41
9fb6896
7fd6c85
046ff44
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,220 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import print_function | ||
|
||
import unittest | ||
import numpy as np | ||
import paddle | ||
import paddle.fluid as fluid | ||
import paddle.fluid.core as core | ||
from paddle.fluid import Program, program_guard | ||
|
||
|
||
class TestTakeAPI(unittest.TestCase): | ||
|
||
def set_mode(self): | ||
self.mode = 'raise' | ||
|
||
def set_dtype(self): | ||
self.input_dtype = 'float64' | ||
self.index_dtype = 'int64' | ||
|
||
def set_input(self): | ||
self.input_shape = [3, 4] | ||
self.index_shape = [2, 3] | ||
self.input_np = np.arange(0, 12).reshape(self.input_shape).astype( | ||
self.input_dtype) | ||
self.index_np = np.arange(-4, 2).reshape(self.index_shape).astype( | ||
self.index_dtype) | ||
|
||
def setUp(self): | ||
self.set_mode() | ||
self.set_dtype() | ||
self.set_input() | ||
self.place = fluid.CUDAPlace( | ||
0) if core.is_compiled_with_cuda() else fluid.CPUPlace() | ||
|
||
def test_static_graph(self): | ||
paddle.enable_static() | ||
startup_program = Program() | ||
train_program = Program() | ||
with program_guard(startup_program, train_program): | ||
x = fluid.data(name='input', | ||
dtype=self.input_dtype, | ||
shape=self.input_shape) | ||
index = fluid.data(name='index', | ||
dtype=self.index_dtype, | ||
shape=self.index_shape) | ||
out = paddle.take(x, index, mode=self.mode) | ||
|
||
exe = fluid.Executor(self.place) | ||
st_result = exe.run(fluid.default_main_program(), | ||
feed={ | ||
'input': self.input_np, | ||
'index': self.index_np | ||
}, | ||
fetch_list=out) | ||
np.testing.assert_allclose( | ||
st_result[0], | ||
np.take(self.input_np, self.index_np, mode=self.mode)) | ||
|
||
def test_dygraph(self): | ||
paddle.disable_static(self.place) | ||
x = paddle.to_tensor(self.input_np) | ||
index = paddle.to_tensor(self.index_np) | ||
dy_result = paddle.take(x, index, mode=self.mode) | ||
np.testing.assert_allclose( | ||
np.take(self.input_np, self.index_np, mode=self.mode), | ||
dy_result.numpy()) | ||
|
||
|
||
class TestTakeInt32(TestTakeAPI): | ||
"""Test take API with data type int32""" | ||
|
||
def set_dtype(self): | ||
self.input_dtype = 'int32' | ||
self.index_dtype = 'int64' | ||
|
||
|
||
class TestTakeInt64(TestTakeAPI): | ||
"""Test take API with data type int64""" | ||
|
||
def set_dtype(self): | ||
self.input_dtype = 'int64' | ||
self.index_dtype = 'int64' | ||
|
||
|
||
class TestTakeFloat32(TestTakeAPI): | ||
"""Test take API with data type float32""" | ||
|
||
def set_dtype(self): | ||
self.input_dtype = 'float32' | ||
self.index_dtype = 'int64' | ||
|
||
|
||
class TestTakeTypeError(TestTakeAPI): | ||
"""Test take Type Error""" | ||
|
||
def test_static_type_error(self): | ||
"""Argument 'index' must be Tensor""" | ||
paddle.enable_static() | ||
with program_guard(Program()): | ||
x = fluid.data(name='input', | ||
dtype=self.input_dtype, | ||
shape=self.input_shape) | ||
self.assertRaises(TypeError, paddle.take, x, self.index_np, | ||
self.mode) | ||
|
||
def test_dygraph_type_error(self): | ||
paddle.disable_static(self.place) | ||
x = paddle.to_tensor(self.input_np) | ||
self.assertRaises(TypeError, paddle.take, x, self.index_np, self.mode) | ||
|
||
def test_static_dtype_error(self): | ||
"""Data type of argument 'index' must be in [paddle.int32, paddle.int64]""" | ||
paddle.enable_static() | ||
with program_guard(Program()): | ||
x = fluid.data(name='input', | ||
dtype='float64', | ||
shape=self.input_shape) | ||
index = fluid.data(name='index', | ||
dtype='float32', | ||
shape=self.index_shape) | ||
self.assertRaises(TypeError, paddle.take, x, index, self.mode) | ||
|
||
def test_dygraph_dtype_error(self): | ||
paddle.disable_static(self.place) | ||
x = paddle.to_tensor(self.input_np) | ||
index = paddle.to_tensor(self.index_np, dtype='float32') | ||
self.assertRaises(TypeError, paddle.take, x, index, self.mode) | ||
|
||
|
||
class TestTakeModeRaise(unittest.TestCase): | ||
"""Test take index out of range error""" | ||
|
||
def set_mode(self): | ||
self.mode = 'raise' | ||
|
||
def set_dtype(self): | ||
self.input_dtype = 'float64' | ||
self.index_dtype = 'int64' | ||
|
||
def set_input(self): | ||
self.input_shape = [3, 4] | ||
self.index_shape = [5, 8] | ||
self.input_np = np.arange(0, 12).reshape(self.input_shape).astype( | ||
self.input_dtype) | ||
self.index_np = np.arange(-20, 20).reshape(self.index_shape).astype( | ||
self.index_dtype) # Both ends of the index are out of bounds | ||
|
||
def setUp(self): | ||
self.set_mode() | ||
self.set_dtype() | ||
self.set_input() | ||
self.place = fluid.CUDAPlace( | ||
0) if core.is_compiled_with_cuda() else fluid.CPUPlace() | ||
|
||
def test_static_index_error(self): | ||
"""When the index is out of range, | ||
an error is reported directly through `paddle.index_select`""" | ||
paddle.enable_static() | ||
with program_guard(Program()): | ||
x = fluid.data(name='input', | ||
dtype=self.input_dtype, | ||
shape=self.input_shape) | ||
index = fluid.data(name='index', | ||
dtype=self.index_dtype, | ||
shape=self.index_shape) | ||
self.assertRaises(ValueError, paddle.index_select, x, index) | ||
|
||
def test_dygraph_index_error(self): | ||
paddle.disable_static(self.place) | ||
x = paddle.to_tensor(self.input_np) | ||
index = paddle.to_tensor(self.index_np, dtype=self.index_dtype) | ||
self.assertRaises(ValueError, paddle.index_select, x, index) | ||
|
||
|
||
class TestTakeModeWrap(TestTakeAPI): | ||
"""Test take index out of range mode""" | ||
|
||
def set_mode(self): | ||
self.mode = 'wrap' | ||
|
||
def set_input(self): | ||
self.input_shape = [3, 4] | ||
self.index_shape = [5, 8] | ||
self.input_np = np.arange(0, 12).reshape(self.input_shape).astype( | ||
self.input_dtype) | ||
self.index_np = np.arange(-20, 20).reshape(self.index_shape).astype( | ||
self.index_dtype) # Both ends of the index are out of bounds | ||
|
||
|
||
class TestTakeModeClip(TestTakeAPI): | ||
"""Test take index out of range mode""" | ||
|
||
def set_mode(self): | ||
self.mode = 'clip' | ||
|
||
def set_input(self): | ||
self.input_shape = [3, 4] | ||
self.index_shape = [5, 8] | ||
self.input_np = np.arange(0, 12).reshape(self.input_shape).astype( | ||
self.input_dtype) | ||
self.index_np = np.arange(-20, 20).reshape(self.index_shape).astype( | ||
self.index_dtype) # Both ends of the index are out of bounds | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4735,7 +4735,6 @@ def frac(x, name=None): | |
type="trunc", inputs=inputs, attrs=attrs, outputs={"Out": y}) | ||
return _elementwise_op(LayerHelper(op_type, **locals())) | ||
|
||
|
||
def sgn(x, name=None): | ||
""" | ||
For complex tensor, this API returns a new tensor whose elements have the same angles as the corresponding | ||
|
@@ -4776,3 +4775,107 @@ def sgn(x, name=None): | |
return paddle.as_complex(output) | ||
else: | ||
return paddle.sign(x) | ||
|
||
def take(x, index, mode='raise', name=None): | ||
""" | ||
Returns a new tensor with the elements of input tensor x at the given index. | ||
The input tensor is treated as if it were viewed as a 1-D tensor. | ||
The result takes the same shape as the index. | ||
|
||
Args: | ||
x (Tensor): An N-D Tensor, its data type should be int32, int64, float32, float64. | ||
index (Tensor): An N-D Tensor, its data type should be int32, int64. | ||
mode (str, optional): Specifies how out-of-bounds index will behave. the candicates are ``'raise'``, ``'wrap'`` and ``'clip'``. | ||
|
||
- ``'raise'``: raise an error (default); | ||
- ``'wrap'``: wrap around; | ||
- ``'clip'``: clip to the range. ``'clip'`` mode means that all indices that are too large are replaced by the index that addresses the last element. Note that this disables indexing with negative numbers. | ||
|
||
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. | ||
|
||
Returns: | ||
Tensor, Tensor with the same shape as index, the data type is the same with input. | ||
|
||
Examples: | ||
.. code-block:: python | ||
|
||
import paddle | ||
|
||
x_int = paddle.arange(0, 12).reshape([3, 4]) | ||
x_float = x_int.astype(paddle.float64) | ||
|
||
idx_pos = paddle.arange(4, 10).reshape([2, 3]) # positive index | ||
idx_neg = paddle.arange(-2, 4).reshape([2, 3]) # negative index | ||
idx_err = paddle.arange(-2, 13).reshape([3, 5]) # index out of range | ||
|
||
paddle.take(x_int, idx_pos) | ||
# Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True, | ||
# [[4, 5, 6], | ||
# [7, 8, 9]]) | ||
|
||
paddle.take(x_int, idx_neg) | ||
# Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True, | ||
# [[10, 11, 0 ], | ||
# [1 , 2 , 3 ]]) | ||
|
||
paddle.take(x_float, idx_pos) | ||
# Tensor(shape=[2, 3], dtype=float64, place=Place(cpu), stop_gradient=True, | ||
# [[4., 5., 6.], | ||
# [7., 8., 9.]]) | ||
|
||
x_int.take(idx_pos) | ||
# Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True, | ||
# [[4, 5, 6], | ||
# [7, 8, 9]]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 示例可增加一个negative index和float类型的input There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
|
||
paddle.take(x_int, idx_err, mode='wrap') | ||
# Tensor(shape=[3, 5], dtype=int32, place=Place(cpu), stop_gradient=True, | ||
# [[10, 11, 0 , 1 , 2 ], | ||
# [3 , 4 , 5 , 6 , 7 ], | ||
# [8 , 9 , 10, 11, 0 ]]) | ||
|
||
paddle.take(x_int, idx_err, mode='clip') | ||
# Tensor(shape=[3, 5], dtype=int32, place=Place(cpu), stop_gradient=True, | ||
# [[0 , 0 , 0 , 1 , 2 ], | ||
# [3 , 4 , 5 , 6 , 7 ], | ||
# [8 , 9 , 10, 11, 11]]) | ||
|
||
""" | ||
if mode not in ['raise', 'wrap', 'clip']: | ||
raise ValueError( | ||
"'mode' in 'take' should be 'raise', 'wrap', 'clip', but received {}.".format(mode)) | ||
|
||
if paddle.in_dynamic_mode(): | ||
if not isinstance(index, (paddle.Tensor, Variable)): | ||
raise TypeError( | ||
"The type of 'index' must be Tensor, but got {}".format(type(index))) | ||
if index.dtype not in [paddle.int32, paddle.int64]: | ||
raise TypeError( | ||
"The data type of 'index' must be one of ['int32', 'int64'], but got {}".format( | ||
index.dtype)) | ||
|
||
else: | ||
check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'take') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. index索引越界时需要报错 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
|
||
input_1d = x.flatten() | ||
index_1d = index.flatten() | ||
max_index = input_1d.shape[-1] | ||
|
||
if mode == 'raise': | ||
# This processing enables 'take' to handle negative indexes within the correct range. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 可以补充下注释,negative indexes可以enable,但越界的索引会在下面的index_select报错 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. THX,Done |
||
# Negative indexes can be enabled, | ||
# but out-of-range indexes will report an error in the following paddle.index_select | ||
index_1d = paddle.where(index_1d < 0, index_1d % max_index, index_1d) | ||
elif mode == 'wrap': | ||
# The out of range indices are constrained by taking the remainder. | ||
index_1d = paddle.where(index_1d < 0, | ||
index_1d % max_index, index_1d) | ||
index_1d = paddle.where(index_1d >= max_index, | ||
index_1d % max_index, index_1d) | ||
elif mode == 'clip': | ||
# 'clip' mode disables indexing with negative numbers. | ||
index_1d = clip(index_1d, 0, max_index - 1) | ||
|
||
out = input_1d.index_select(index_1d).reshape(index.shape) | ||
|
||
return out |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the name of parameter needs to be consistent with rfc,
input
in rfc whilex
here, andmode
is not in rfc.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jeff41404 根据之前的修改意见 PaddlePaddle/community#186 (review) 更新过RFC:PaddlePaddle/community#217
参数的名字按照新的RFC内容进行修改的。
@S-HuaBomb 请先修改完RFC的评审意见吧。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rfc is still old now, should update and merge rfc first
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the modified RFC PaddlePaddle/community#217 with instructions added