Skip to content

Commit

Permalink
[API 2.0] Add api paddle.reshape(x, shape, name) (#26338)
Browse files Browse the repository at this point in the history
(1) Add api paddle.reshape(x, shape, name); 

(2) Use Tensor replaces Variable. test=develop
  • Loading branch information
liym27 committed Aug 20, 2020
1 parent cefbb35 commit adba432
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 55 deletions.
17 changes: 8 additions & 9 deletions python/paddle/fluid/layers/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5965,7 +5965,6 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
"""
:alias_main: paddle.reshape
:alias: paddle.reshape,paddle.tensor.reshape,paddle.tensor.manipulation.reshape
:old_api: paddle.fluid.layers.reshape

This operator changes the shape of ``x`` without changing its data.

Expand Down Expand Up @@ -6008,14 +6007,14 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
The parameter ``actual_shape`` will be deprecated in the future and only use ``shape`` instead to represent the target shape.

Args:
x(Variable): A ``Tensor`` or ``LoDTensor`` . The data type is ``float32``, ``float64``, ``int32`` or ``int64``.
shape(list|tuple|Variable): Define the target shape. At most one dimension of the target shape can be -1.
x(Tensor): An N-D Tensor. The data type is ``float32``, ``float64``, ``int32`` or ``int64``.
shape(list|tuple|Tensor): Define the target shape. At most one dimension of the target shape can be -1.
The data type is ``int32`` . If ``shape`` is a list or tuple, the elements of it should be integers or Tensors with shape [1].
If ``shape`` is an Variable, it should be an 1-D Tensor .
If ``shape`` is an Tensor, it should be an 1-D Tensor .
actual_shape(variable, optional): An 1-D ``Tensor`` or ``LoDTensor`` . The data type is ``int32`` . If provided, reshape
according to this given shape rather than ``shape`` specifying shape.
That is to say ``actual_shape`` has a higher priority
than ``shape(list|tuple)`` but not ``shape(Variable)``. \
than ``shape(list|tuple)`` but not ``shape(Tensor)``. \
This argument ``actual_shape`` will be removed in a future version. \
Instructions for updating: ``actual_shape`` will be removed in future versions and replaced by ``shape``.
act (str, optional): The non-linear activation to be applied to the reshaped input. Default None.
Expand All @@ -6027,10 +6026,10 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
For more information, please refer to :ref:`api_guide_Name` .

Returns:
Variable: A ``Tensor`` or ``LoDTensor``. The data type is same as ``x``. It is a new tensor variable if ``inplace`` is ``False``, otherwise it is ``x``. If ``act`` is None, return the reshaped tensor variable, otherwise return the activated tensor variable.
Tensor: A reshaped Tensor with the same data type as ``x``. It is a new tensor variable if ``inplace`` is ``False``, otherwise it is ``x``. If ``act`` is None, return the reshaped tensor variable, otherwise return the activated tensor variable.

Raises:
TypeError: If actual_shape is neither Variable nor None.
TypeError: If actual_shape is neither Tensor nor None.
ValueError: If more than one elements of ``shape`` is -1.
ValueError: If the element of ``shape`` is 0, the corresponding dimension should be less than or equal to the dimension of ``x``.
ValueError: If the elements in ``shape`` is negative except -1.
Expand All @@ -6041,15 +6040,15 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
import paddle.fluid as fluid

# example 1:
# attr shape is a list which doesn't contain tensor Variable.
# attr shape is a list which doesn't contain Tensors.
data_1 = fluid.data(
name='data_1', shape=[2, 4, 6], dtype='float32')
reshaped_1 = fluid.layers.reshape(
x=data_1, shape=[-1, 0, 3, 2], inplace=True)
# the shape of reshaped_1 is [2,4,3,2].

# example 2:
# attr shape is a list which contains tensor Variable.
# attr shape is a list which contains Tensors.
data_2 = fluid.layers.fill_constant([2,25], "int32", 3)
dim = fluid.layers.fill_constant([1], "int32", 5)
reshaped_2 = fluid.layers.reshape(data_2, shape=[dim, 10])
Expand Down
127 changes: 84 additions & 43 deletions python/paddle/fluid/tests/unittests/test_reshape_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy as np

from op_test import OpTest
import paddle
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard

Expand Down Expand Up @@ -227,35 +228,43 @@ def init_dtype(self):

# Test python API
class TestReshapeAPI(unittest.TestCase):
# situation 1: have shape( list, no tensor), no actual shape(Tensor)
def test_1(self):
def _set_paddle_api(self):
self.fill_constant = paddle.fill_constant
self.data = paddle.data
self.reshape = paddle.reshape
self.to_tensor = paddle.to_tensor

def _set_fluid_api(self):
self.fill_constant = fluid.layers.fill_constant
self.data = fluid.data
self.reshape = fluid.layers.reshape

def _test_api(self):
input = np.random.random([2, 25]).astype("float32")
shape = [2, 5, 5]
positive_five = fluid.layers.fill_constant([1], "int32", 5)
x = fluid.layers.data(
name="x", shape=[2, 25], append_batch_size=False, dtype="float32")
main_prog = Program()
with program_guard(main_prog, Program()):
positive_five = self.fill_constant([1], "int32", 5)
x = self.data(name="x", shape=[2, 25], dtype="float32")

actual_shape = fluid.layers.data(
name="shape",
shape=[1, 3],
append_batch_size=False,
dtype="float32")
actual_shape = self.data(name="shape", shape=[3], dtype="int32")

# situation 1: have shape( list, no tensor), no actual shape(Tensor)
out_1 = fluid.layers.reshape(x, shape)
# situation 1: have shape( list, no tensor), no actual shape(Tensor)
out_1 = self.reshape(x, shape)

# situation 2: have shape(list, no tensor), have actual shape(Tensor)
out_2 = fluid.layers.reshape(x, shape=shape, actual_shape=actual_shape)
# situation 2: have shape(list, no tensor), have actual shape(Tensor)
out_2 = fluid.layers.reshape(
x, shape=shape, actual_shape=actual_shape)

# Situation 3: have shape(list, have tensor), no actual shape(Tensor)
out_3 = fluid.layers.reshape(x, shape=[positive_five, 10])
# Situation 3: have shape(list, have tensor), no actual shape(Tensor)
out_3 = self.reshape(x, shape=[positive_five, 10])

# Situation 4: have shape(Tensor), no actual shape(Tensor)
out_4 = fluid.layers.reshape(x, shape=actual_shape)
# Situation 4: have shape(Tensor), no actual shape(Tensor)
out_4 = self.reshape(x, shape=actual_shape)

exe = fluid.Executor(place=fluid.CPUPlace())
res_1, res_2, res_3, res_4 = exe.run(
fluid.default_main_program(),
main_prog,
feed={"x": input,
"shape": np.array([2, 5, 5]).astype("int32")},
fetch_list=[out_1, out_2, out_3, out_4])
Expand All @@ -265,76 +274,108 @@ def test_1(self):
assert np.array_equal(res_3, input.reshape([5, 10]))
assert np.array_equal(res_4, input.reshape(shape))

def test_paddle_api(self):
self._set_paddle_api()
self._test_api()

def test_fluid_api(self):
self._set_fluid_api()
self._test_api()

def test_imperative(self):
self._set_paddle_api()
input = np.random.random([2, 25]).astype("float32")
shape = [2, 5, 5]
with fluid.dygraph.guard():
x = self.to_tensor(input)
positive_five = self.fill_constant([1], "int32", 5)

out_1 = self.reshape(x, shape)

out_2 = self.reshape(x, shape=[positive_five, 10])

shape_tensor = self.to_tensor(np.array([2, 5, 5]).astype("int32"))
out_3 = self.reshape(x, shape=shape_tensor)

assert np.array_equal(out_1.numpy(), input.reshape(shape))
assert np.array_equal(out_2.numpy(), input.reshape([5, 10]))
assert np.array_equal(out_3.numpy(), input.reshape(shape))


# Test Input Error
class TestReshapeOpError(unittest.TestCase):
def test_errors(self):
def _set_paddle_api(self):
self.data = paddle.data
self.reshape = paddle.reshape

def _set_fluid_api(self):
self.data = fluid.data
self.reshape = fluid.layers.reshape

def _test_errors(self):
with program_guard(Program(), Program()):
# The x type of reshape_op must be Variable.
def test_x_type():
x1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
fluid.layers.reshape(x1, shape=[1])
self.reshape(x1, shape=[1])

self.assertRaises(TypeError, test_x_type)

# The x dtype of reshape_op must be float16, float32, float64, int32 or int64.
def test_x_dtype():
x2 = fluid.layers.data(
name="x2",
shape=[2, 25],
append_batch_size=False,
dtype="bool")
fluid.layers.reshape(x2, shape=[2, 5, 5])
x2 = self.data(name="x2", shape=[2, 25], dtype="bool")
self.reshape(x2, shape=[2, 5, 5])

self.assertRaises(TypeError, test_x_dtype)

def test_x_dtype_float16():
x_float16 = fluid.layers.data(
name="x_float16",
shape=[2, 25],
append_batch_size=False,
dtype="float16")
fluid.layers.reshape(x_float16, shape=[2, 5, 5])
x_float16 = self.data(
name="x_float16", shape=[2, 25], dtype="float16")
self.reshape(x_float16, shape=[2, 5, 5])

test_x_dtype_float16()

x3 = fluid.layers.data(
name="x3",
shape=[2, 25],
append_batch_size=False,
dtype="float32")
x3 = self.data(name="x3", shape=[2, 25], dtype="float32")

# The argument shape's type of reshape_op must be list, tuple or Variable.
def test_shape_type():
fluid.layers.reshape(x3, shape=1)
self.reshape(x3, shape=1)

self.assertRaises(TypeError, test_shape_type)

# The argument actual_shape's type of reshape_op must be Variable or None.
def test_actual_shape_type():
fluid.layers.reshape(x3, shape=[25, 2], actual_shape=1)
self.reshape(x3, shape=[25, 2], actual_shape=1)

self.assertRaises(TypeError, test_actual_shape_type)

# The argument shape have more than one -1.
def test_shape_1():
fluid.layers.reshape(x3, shape=[-1, -1, 5])
self.reshape(x3, shape=[-1, -1, 5])

self.assertRaises(AssertionError, test_shape_1)

# The argument shape have element 0 whose index exceed the input dimension.
def test_shape_2():
fluid.layers.reshape(x3, [2, 5, 5, 0])
self.reshape(x3, [2, 5, 5, 0])

self.assertRaises(AssertionError, test_shape_2)

# The argument shape have more than one negative value.
def test_shape_3():
fluid.layers.reshape(x3, [-1, -2, 5])
self.reshape(x3, [-1, -2, 5])

self.assertRaises(AssertionError, test_shape_3)

def test_paddle_api_error(self):
self._set_paddle_api()
self._test_errors()

def test_fluid_api_error(self):
self._set_fluid_api()
self._test_errors()


if __name__ == "__main__":
unittest.main()
84 changes: 81 additions & 3 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from __future__ import print_function

from ..fluid.layers import core, reshape
from ..fluid.layers import core
from ..fluid.layer_helper import LayerHelper
from ..fluid.framework import Variable, OpProtoHolder, in_dygraph_mode, convert_np_dtype_to_dtype_
from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
Expand All @@ -23,7 +23,7 @@
import numpy as np
# TODO: define functions to manipulate a tensor
from ..fluid.layers import cast #DEFINE_ALIAS
from ..fluid.layers import reshape #DEFINE_ALIAS
from ..fluid.layers import expand_as #DEFINE_ALIAS
from ..fluid.layers import scatter #DEFINE_ALIAS
from ..fluid.layers import slice #DEFINE_ALIAS
from ..fluid.layers import strided_slice #DEFINE_ALIAS
Expand Down Expand Up @@ -377,7 +377,7 @@ def roll(x, shifts, axis=None, name=None):
outputs={'Out': out},
attrs={'axis': axis,
'shifts': shifts})
out = reshape(out, shape=origin_shape, inplace=True)
out = layers.reshape(out, shape=origin_shape, inplace=True)
return out


Expand Down Expand Up @@ -1048,3 +1048,81 @@ def get_attr_expand_shape(list_expand_shape):


broadcast_to = expand


def reshape(x, shape, name=None):
"""
:alias_main: paddle.reshape
:alias: paddle.reshape,paddle.tensor.reshape,paddle.tensor.manipulation.reshape
This operator changes the shape of ``x`` without changing its data.
Some tricks exist when specifying the target shape.
1. -1 means the value of this dimension is inferred from the total element
number of x and remaining dimensions. Thus one and only one dimension can
be set -1.
2. 0 means the actual dimension value is going to be copied from the
corresponding dimension of x. The index of 0s in shape can not exceed
the dimension of x.
Here are some examples to explain it.
1. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape
is [6, 8], the reshape operator will transform x into a 2-D tensor with
shape [6, 8] and leaving x's data unchanged.
2. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape
specified is [2, 3, -1, 2], the reshape operator will transform x into a
4-D tensor with shape [2, 3, 4, 2] and leaving x's data unchanged. In this
case, one dimension of the target shape is set to -1, the value of this
dimension is inferred from the total element number of x and remaining
dimensions.
3. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape
is [-1, 0, 3, 2], the reshape operator will transform x into a 4-D tensor
with shape [2, 4, 3, 2] and leaving x's data unchanged. In this case,
besides -1, 0 means the actual dimension value is going to be copied from
the corresponding dimension of x.
Args:
x(Tensor): An N-D Tensor. The data type is ``float32``, ``float64``, ``int32`` or ``int64``.
shape(list|tuple|Tensor): Define the target shape. At most one dimension of the target shape can be -1.
The data type is ``int32`` . If ``shape`` is a list or tuple, the elements of it should be integers or Tensors with shape [1].
If ``shape`` is an Tensor, it should be an 1-D Tensor .
name(str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Returns:
Tensor: A reshaped Tensor with the same data type as ``x``.
Raises:
ValueError: If more than one elements of ``shape`` is -1.
ValueError: If the element of ``shape`` is 0, the corresponding dimension should be less than or equal to the dimension of ``x``.
ValueError: If the elements in ``shape`` is negative except -1.
Examples:
.. code-block:: python
import numpy as np
import paddle
paddle.disable_static()
data = np.random.random([2, 4, 6]).astype("float32")
x = paddle.to_tensor(data)
positive_four = paddle.fill_constant([1], "int32", 4)
out_1 = paddle.reshape(x, [-1, 0, 3, 2])
# the shape of out_1 is [2,4,3,2].
out_2 = paddle.reshape(x, shape=[positive_four, 12])
# the shape of out_2 is [4, 12].
shape_tensor = paddle.to_tensor(np.array([8, 6]).astype("int32"))
out_3 = paddle.reshape(x, shape=shape_tensor)
# the shape of out_2 is [8, 6].
"""
return paddle.fluid.layers.reshape(x=x, shape=shape, name=name)

0 comments on commit adba432

Please sign in to comment.