From 5d3766ff3d391ca3990ee1a05186f19122c3daec Mon Sep 17 00:00:00 2001 From: yaoxuefeng Date: Tue, 14 Jul 2020 20:07:56 +0800 Subject: [PATCH] modify flip test=develop (#25312) According to paddle 2.0 standard 1, change flip api attr name 'dim' to 'axis'. 2, support empty axis 3, change example code to imperative mode. --- paddle/fluid/operators/flip_op.cc | 86 ++++++++++--------- paddle/fluid/operators/flip_op.cu | 2 +- paddle/fluid/operators/flip_op.h | 2 +- python/paddle/__init__.py | 2 +- .../paddle/fluid/tests/unittests/test_flip.py | 30 +++++-- python/paddle/tensor/__init__.py | 2 +- python/paddle/tensor/manipulation.py | 43 +++++----- 7 files changed, 94 insertions(+), 73 deletions(-) diff --git a/paddle/fluid/operators/flip_op.cc b/paddle/fluid/operators/flip_op.cc index 7d0df5ffbd894..fc17657594b7a 100644 --- a/paddle/fluid/operators/flip_op.cc +++ b/paddle/fluid/operators/flip_op.cc @@ -36,46 +36,52 @@ class FlipOp : public framework::OperatorWithKernel { platform::errors::NotFound( "Output(Out) of FlipOp should not be null.")); auto x_dims = ctx->GetInputDim("X"); - auto flip_dims = ctx->Attrs().Get>("dims"); + auto flip_dims = ctx->Attrs().Get>("axis"); size_t flip_dims_size = flip_dims.size(); - // check if dims axis within range - auto min_max_d = std::minmax_element(flip_dims.begin(), flip_dims.end()); - PADDLE_ENFORCE_LT(*min_max_d.first, x_dims.size(), - platform::errors::InvalidArgument( - "min(dims) should be less than the input tensor X's " - "dimensions of FlipOp. But received min(dims) = %d, " - "X's dimensions = %d, X's shape = [%s]", - *min_max_d.first, x_dims.size(), x_dims)); - PADDLE_ENFORCE_GE( - *min_max_d.first, x_dims.size() * -1, - platform::errors::InvalidArgument( - "min(dims) should be greater than or equal to the input tensor X's " - "dimensions of FlipOp times -1. But received min(dims) = %d, X's " - "dimensions = %d, X's shape = [%s]", - *min_max_d.first, x_dims.size() * -1, x_dims)); - PADDLE_ENFORCE_LT(*min_max_d.second, x_dims.size(), - platform::errors::InvalidArgument( - "max(dims) should be less than the input tensor X's " - "dimensions of FlipOp. But received max(dims) = %d, " - "X's dimensions = %d, X's shape = [%s]", - *min_max_d.second, x_dims.size(), x_dims)); - PADDLE_ENFORCE_GE( - *min_max_d.second, x_dims.size() * -1, - platform::errors::InvalidArgument( - "max(dims) should be greater than or equal to the input tensor X's " - "dimensions of FlipOp times -1. But received max(dims) = %d, X's " - "dimensions = %d, X's shape = [%s]", - *min_max_d.second, x_dims.size() * -1, x_dims)); - - // check duplicates in dims - flip_dims.erase(std::unique(flip_dims.begin(), flip_dims.end()), - flip_dims.end()); - PADDLE_ENFORCE_EQ(flip_dims.size(), flip_dims_size, - platform::errors::InvalidArgument( - "dims has duplicates, original flip dims size=%d, " - "but unique flip dims size=%d.)", - flip_dims_size, flip_dims.size())); + if (flip_dims_size > 0) { + // check if dims axis within range + auto min_max_d = std::minmax_element(flip_dims.begin(), flip_dims.end()); + PADDLE_ENFORCE_LT( + *min_max_d.first, x_dims.size(), + platform::errors::InvalidArgument( + "min(axes) should be less than the input tensor X's " + "axes of FlipOp. But received min(axes) = %d, " + "X's axes = %d, X's shape = [%s]", + *min_max_d.first, x_dims.size(), x_dims)); + PADDLE_ENFORCE_GE(*min_max_d.first, x_dims.size() * -1, + platform::errors::InvalidArgument( + "min(axes) should be greater than or equal to the " + "input tensor X's " + "axes of FlipOp times -1. But received " + "min(axes) = %d, X's " + "axes = %d, X's shape = [%s]", + *min_max_d.first, x_dims.size() * -1, x_dims)); + PADDLE_ENFORCE_LT( + *min_max_d.second, x_dims.size(), + platform::errors::InvalidArgument( + "max(axes) should be less than the input tensor X's " + "axes of FlipOp. But received max(axes) = %d, " + "X's axes = %d, X's shape = [%s]", + *min_max_d.second, x_dims.size(), x_dims)); + PADDLE_ENFORCE_GE(*min_max_d.second, x_dims.size() * -1, + platform::errors::InvalidArgument( + "max(axes) should be greater than or equal to the " + "input tensor X's " + "axes of FlipOp times -1. But received " + "max(axes) = %d, X's " + "axes = %d, X's shape = [%s]", + *min_max_d.second, x_dims.size() * -1, x_dims)); + + // check duplicates in dims + flip_dims.erase(std::unique(flip_dims.begin(), flip_dims.end()), + flip_dims.end()); + PADDLE_ENFORCE_EQ(flip_dims.size(), flip_dims_size, + platform::errors::InvalidArgument( + "axes has duplicates, original flip axes size=%d, " + "but unique flip axes size=%d.)", + flip_dims_size, flip_dims.size())); + } VLOG(3) << "flip operator x.shape=" << x_dims; @@ -104,10 +110,10 @@ class FlipOpMaker : public framework::OpProtoAndCheckerMaker { void Make() override { AddInput("X", "(Tensor), The input tensor of flip op."); AddOutput("Out", "(Tensor), The output tensor of flip op."); - AddAttr>("dims", "The axes to flip on."); + AddAttr>("axis", "The axes to flip on."); AddComment(R"DOC( Flip Operator. - Reverse the order of a n-D tensor along given axis in dims. + Reverse the order of a n-D tensor along given axis in axes. )DOC"); } }; diff --git a/paddle/fluid/operators/flip_op.cu b/paddle/fluid/operators/flip_op.cu index 41aae1e1f35a6..581a994ba84b5 100644 --- a/paddle/fluid/operators/flip_op.cu +++ b/paddle/fluid/operators/flip_op.cu @@ -81,7 +81,7 @@ class FlipKernel Tensor* out = ctx.Output("Out"); auto* in_data = x->data(); auto* out_data = out->mutable_data(ctx.GetPlace()); - auto flip_dims = ctx.template Attr>("dims"); + auto flip_dims = ctx.template Attr>("axis"); const int flip_dims_size = static_cast(flip_dims.size()); auto x_dims = x->dims(); diff --git a/paddle/fluid/operators/flip_op.h b/paddle/fluid/operators/flip_op.h index 73d73f5d0f2e0..b77827b782b1a 100644 --- a/paddle/fluid/operators/flip_op.h +++ b/paddle/fluid/operators/flip_op.h @@ -41,7 +41,7 @@ class FlipKernel void Compute(const framework::ExecutionContext& ctx) const override { const Tensor* x = ctx.Input("X"); Tensor* out = ctx.Output("Out"); - auto flip_dims = ctx.template Attr>("dims"); + auto flip_dims = ctx.template Attr>("axis"); auto x_dims = x->dims(); const int total_dims = x_dims.size(); diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index a2096a01ccdd9..2eed69c9df6be 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -108,7 +108,7 @@ from .tensor.manipulation import gather #DEFINE_ALIAS from .tensor.manipulation import gather_nd #DEFINE_ALIAS from .tensor.manipulation import reshape #DEFINE_ALIAS -from .tensor.manipulation import reverse #DEFINE_ALIAS +from .tensor.manipulation import flip as reverse #DEFINE_ALIAS from .tensor.manipulation import scatter #DEFINE_ALIAS from .tensor.manipulation import scatter_nd_add #DEFINE_ALIAS from .tensor.manipulation import scatter_nd #DEFINE_ALIAS diff --git a/python/paddle/fluid/tests/unittests/test_flip.py b/python/paddle/fluid/tests/unittests/test_flip.py index 77e416e5e6a73..6feee9ce57306 100644 --- a/python/paddle/fluid/tests/unittests/test_flip.py +++ b/python/paddle/fluid/tests/unittests/test_flip.py @@ -30,9 +30,9 @@ def test_static_graph(self): startup_program = fluid.Program() train_program = fluid.Program() with fluid.program_guard(train_program, startup_program): - dims = [0] + axis = [0] input = fluid.data(name='input', dtype='float32', shape=[2, 3]) - output = paddle.flip(input, dims) + output = paddle.flip(input, axis) place = fluid.CPUPlace() if fluid.core.is_compiled_with_cuda(): place = fluid.CUDAPlace(0) @@ -68,7 +68,7 @@ def setUp(self): self.outputs = {'Out': self.calc_ref_res()} def init_attrs(self): - self.attrs = {"dims": self.dims} + self.attrs = {"axis": self.axis} def test_check_output(self): self.check_output() @@ -78,11 +78,11 @@ def test_check_grad(self): def init_test_case(self): self.in_shape = (6, 4, 2, 3) - self.dims = [0, 1] + self.axis = [0, 1] def calc_ref_res(self): res = self.inputs['X'] - for axis in self.dims: + for axis in self.axis: res = np.flip(res, axis) return res @@ -90,25 +90,37 @@ def calc_ref_res(self): class TestFlipOpAxis1(TestFlipOp): def init_test_case(self): self.in_shape = (2, 4, 4) - self.dims = [0] + self.axis = [0] class TestFlipOpAxis2(TestFlipOp): def init_test_case(self): self.in_shape = (4, 4, 6, 3) - self.dims = [0, 2] + self.axis = [0, 2] class TestFlipOpAxis3(TestFlipOp): def init_test_case(self): self.in_shape = (4, 3, 1) - self.dims = [0, 1, 2] + self.axis = [0, 1, 2] class TestFlipOpAxis4(TestFlipOp): def init_test_case(self): self.in_shape = (6, 4, 2, 2) - self.dims = [0, 1, 2, 3] + self.axis = [0, 1, 2, 3] + + +class TestFlipOpEmptyAxis(TestFlipOp): + def init_test_case(self): + self.in_shape = (6, 4, 2, 2) + self.axis = [] + + +class TestFlipOpNegAxis(TestFlipOp): + def init_test_case(self): + self.in_shape = (6, 4, 2, 2) + self.axis = [-1] if __name__ == "__main__": diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index a96d112c8ea3b..62afe63471693 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -81,7 +81,7 @@ from .manipulation import gather #DEFINE_ALIAS from .manipulation import gather_nd #DEFINE_ALIAS from .manipulation import reshape #DEFINE_ALIAS -from .manipulation import reverse #DEFINE_ALIAS +from .manipulation import flip as reverse #DEFINE_ALIAS from .manipulation import scatter #DEFINE_ALIAS from .manipulation import scatter_nd_add #DEFINE_ALIAS from .manipulation import scatter_nd #DEFINE_ALIAS diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index f1fc40e3c2b9b..a98a07d3dbdcd 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -28,7 +28,6 @@ from ..fluid.layers import expand_as #DEFINE_ALIAS from ..fluid.layers import flatten #DEFINE_ALIAS from ..fluid.layers import reshape #DEFINE_ALIAS -from ..fluid.layers import reverse #DEFINE_ALIAS from ..fluid.layers import scatter #DEFINE_ALIAS from ..fluid.layers import slice #DEFINE_ALIAS from ..fluid.layers import strided_slice #DEFINE_ALIAS @@ -51,46 +50,47 @@ ] -def flip(input, dims, name=None): +def flip(x, axis, name=None): """ :alias_main: paddle.flip :alias: paddle.flip,paddle.tensor.flip,paddle.tensor.manipulation.flip - Reverse the order of a n-D tensor along given axis in dims. + Reverse the order of a n-D tensor along given axis in axis. Args: - input (Variable): A Tensor(or LoDTensor) with shape :math:`[N_1, N_2,..., N_k]` . The data type of the input Tensor + x (Variable): A Tensor(or LoDTensor) with shape :math:`[N_1, N_2,..., N_k]` . The data type of the input Tensor x should be float32, float64, int32, int64, bool. - dims (list): The axis to flip on. + axis (list): The axis(axes) to flip on. Negative indices for indexing from the end are accepted. name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` . Returns: - Variable: Tensor or LoDTensor calculated by flip layer. The data type is same with input. + Variable: Tensor or LoDTensor calculated by flip layer. The data type is same with input x. Examples: .. code-block:: python import paddle - import paddle.fluid as fluid import numpy as np - input = fluid.data(name="x", shape=[-1, 2, 2], dtype='float32') - output = paddle.flip(input, dims=[0, 1]) - exe = fluid.Executor(fluid.CPUPlace()) - exe.run(fluid.default_startup_program()) - img = np.arange(12).reshape((3,2,2)).astype(np.float32) - res = exe.run(fluid.default_main_program(), feed={'x':img}, fetch_list=[output]) - print(res) # [[[10,11][8, 9]],[[6, 7],[4, 5]] [[2, 3],[0, 1]]] + + paddle.enable_imperative() + + image_shape=(3, 2, 2) + x = np.arange(image_shape[0] * image_shape[1] * image_shape[2]).reshape(image_shape) + x = x.astype('float32') + img = paddle.imperative.to_variable(x) + out = paddle.flip(img, [0,1]) + + print(out) # [[[10,11][8, 9]],[[6, 7],[4, 5]] [[2, 3],[0, 1]]] """ helper = LayerHelper("flip", **locals()) - check_type(input, 'X', (Variable), 'flip') - dtype = helper.input_dtype() + check_type(x, 'X', (Variable), 'flip') + dtype = helper.input_dtype('x') check_dtype(dtype, 'X', ['float16', 'float32', 'float64', 'int32', 'int64', 'bool'], 'flip') - check_type(dims, 'dims', (list, tuple), 'flip') - assert len(dims) > 0, 'len(dims) must be greater than 0.' + check_type(axis, 'axis', (list, tuple), 'flip') if name is None: out = helper.create_variable_for_type_inference(dtype) else: @@ -98,12 +98,15 @@ def flip(input, dims, name=None): helper.append_op( type="flip", - inputs={"X": input}, + inputs={"X": x}, outputs={"Out": out}, - attrs={"dims": dims}) + attrs={"axis": axis}) return out +reverse = flip #DEFINE_ALIAS + + def roll(x, shifts, axis=None, name=None): """ :alias_main: paddle.roll