Skip to content

Commit

Permalink
modify flip test=develop (#25312)
Browse files Browse the repository at this point in the history
According to paddle 2.0 standard
1, change flip api attr name 'dim' to 'axis'.
2, support empty axis
3, change example code to imperative mode.
  • Loading branch information
yaoxuefeng6 committed Jul 14, 2020
1 parent f8eccb0 commit 5d3766f
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 73 deletions.
86 changes: 46 additions & 40 deletions paddle/fluid/operators/flip_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,46 +36,52 @@ class FlipOp : public framework::OperatorWithKernel {
platform::errors::NotFound(
"Output(Out) of FlipOp should not be null."));
auto x_dims = ctx->GetInputDim("X");
auto flip_dims = ctx->Attrs().Get<std::vector<int>>("dims");
auto flip_dims = ctx->Attrs().Get<std::vector<int>>("axis");
size_t flip_dims_size = flip_dims.size();

// check if dims axis within range
auto min_max_d = std::minmax_element(flip_dims.begin(), flip_dims.end());
PADDLE_ENFORCE_LT(*min_max_d.first, x_dims.size(),
platform::errors::InvalidArgument(
"min(dims) should be less than the input tensor X's "
"dimensions of FlipOp. But received min(dims) = %d, "
"X's dimensions = %d, X's shape = [%s]",
*min_max_d.first, x_dims.size(), x_dims));
PADDLE_ENFORCE_GE(
*min_max_d.first, x_dims.size() * -1,
platform::errors::InvalidArgument(
"min(dims) should be greater than or equal to the input tensor X's "
"dimensions of FlipOp times -1. But received min(dims) = %d, X's "
"dimensions = %d, X's shape = [%s]",
*min_max_d.first, x_dims.size() * -1, x_dims));
PADDLE_ENFORCE_LT(*min_max_d.second, x_dims.size(),
platform::errors::InvalidArgument(
"max(dims) should be less than the input tensor X's "
"dimensions of FlipOp. But received max(dims) = %d, "
"X's dimensions = %d, X's shape = [%s]",
*min_max_d.second, x_dims.size(), x_dims));
PADDLE_ENFORCE_GE(
*min_max_d.second, x_dims.size() * -1,
platform::errors::InvalidArgument(
"max(dims) should be greater than or equal to the input tensor X's "
"dimensions of FlipOp times -1. But received max(dims) = %d, X's "
"dimensions = %d, X's shape = [%s]",
*min_max_d.second, x_dims.size() * -1, x_dims));

// check duplicates in dims
flip_dims.erase(std::unique(flip_dims.begin(), flip_dims.end()),
flip_dims.end());
PADDLE_ENFORCE_EQ(flip_dims.size(), flip_dims_size,
platform::errors::InvalidArgument(
"dims has duplicates, original flip dims size=%d, "
"but unique flip dims size=%d.)",
flip_dims_size, flip_dims.size()));
if (flip_dims_size > 0) {
// check if dims axis within range
auto min_max_d = std::minmax_element(flip_dims.begin(), flip_dims.end());
PADDLE_ENFORCE_LT(
*min_max_d.first, x_dims.size(),
platform::errors::InvalidArgument(
"min(axes) should be less than the input tensor X's "
"axes of FlipOp. But received min(axes) = %d, "
"X's axes = %d, X's shape = [%s]",
*min_max_d.first, x_dims.size(), x_dims));
PADDLE_ENFORCE_GE(*min_max_d.first, x_dims.size() * -1,
platform::errors::InvalidArgument(
"min(axes) should be greater than or equal to the "
"input tensor X's "
"axes of FlipOp times -1. But received "
"min(axes) = %d, X's "
"axes = %d, X's shape = [%s]",
*min_max_d.first, x_dims.size() * -1, x_dims));
PADDLE_ENFORCE_LT(
*min_max_d.second, x_dims.size(),
platform::errors::InvalidArgument(
"max(axes) should be less than the input tensor X's "
"axes of FlipOp. But received max(axes) = %d, "
"X's axes = %d, X's shape = [%s]",
*min_max_d.second, x_dims.size(), x_dims));
PADDLE_ENFORCE_GE(*min_max_d.second, x_dims.size() * -1,
platform::errors::InvalidArgument(
"max(axes) should be greater than or equal to the "
"input tensor X's "
"axes of FlipOp times -1. But received "
"max(axes) = %d, X's "
"axes = %d, X's shape = [%s]",
*min_max_d.second, x_dims.size() * -1, x_dims));

// check duplicates in dims
flip_dims.erase(std::unique(flip_dims.begin(), flip_dims.end()),
flip_dims.end());
PADDLE_ENFORCE_EQ(flip_dims.size(), flip_dims_size,
platform::errors::InvalidArgument(
"axes has duplicates, original flip axes size=%d, "
"but unique flip axes size=%d.)",
flip_dims_size, flip_dims.size()));
}

VLOG(3) << "flip operator x.shape=" << x_dims;

Expand Down Expand Up @@ -104,10 +110,10 @@ class FlipOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
AddInput("X", "(Tensor), The input tensor of flip op.");
AddOutput("Out", "(Tensor), The output tensor of flip op.");
AddAttr<std::vector<int>>("dims", "The axes to flip on.");
AddAttr<std::vector<int>>("axis", "The axes to flip on.");
AddComment(R"DOC(
Flip Operator.
Reverse the order of a n-D tensor along given axis in dims.
Reverse the order of a n-D tensor along given axis in axes.
)DOC");
}
};
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/flip_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class FlipKernel<platform::CUDADeviceContext, T>
Tensor* out = ctx.Output<Tensor>("Out");
auto* in_data = x->data<T>();
auto* out_data = out->mutable_data<T>(ctx.GetPlace());
auto flip_dims = ctx.template Attr<std::vector<int>>("dims");
auto flip_dims = ctx.template Attr<std::vector<int>>("axis");

const int flip_dims_size = static_cast<int>(flip_dims.size());
auto x_dims = x->dims();
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/flip_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class FlipKernel<platform::CPUDeviceContext, T>
void Compute(const framework::ExecutionContext& ctx) const override {
const Tensor* x = ctx.Input<Tensor>("X");
Tensor* out = ctx.Output<Tensor>("Out");
auto flip_dims = ctx.template Attr<std::vector<int>>("dims");
auto flip_dims = ctx.template Attr<std::vector<int>>("axis");

auto x_dims = x->dims();
const int total_dims = x_dims.size();
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@
from .tensor.manipulation import gather #DEFINE_ALIAS
from .tensor.manipulation import gather_nd #DEFINE_ALIAS
from .tensor.manipulation import reshape #DEFINE_ALIAS
from .tensor.manipulation import reverse #DEFINE_ALIAS
from .tensor.manipulation import flip as reverse #DEFINE_ALIAS
from .tensor.manipulation import scatter #DEFINE_ALIAS
from .tensor.manipulation import scatter_nd_add #DEFINE_ALIAS
from .tensor.manipulation import scatter_nd #DEFINE_ALIAS
Expand Down
30 changes: 21 additions & 9 deletions python/paddle/fluid/tests/unittests/test_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ def test_static_graph(self):
startup_program = fluid.Program()
train_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
dims = [0]
axis = [0]
input = fluid.data(name='input', dtype='float32', shape=[2, 3])
output = paddle.flip(input, dims)
output = paddle.flip(input, axis)
place = fluid.CPUPlace()
if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
Expand Down Expand Up @@ -68,7 +68,7 @@ def setUp(self):
self.outputs = {'Out': self.calc_ref_res()}

def init_attrs(self):
self.attrs = {"dims": self.dims}
self.attrs = {"axis": self.axis}

def test_check_output(self):
self.check_output()
Expand All @@ -78,37 +78,49 @@ def test_check_grad(self):

def init_test_case(self):
self.in_shape = (6, 4, 2, 3)
self.dims = [0, 1]
self.axis = [0, 1]

def calc_ref_res(self):
res = self.inputs['X']
for axis in self.dims:
for axis in self.axis:
res = np.flip(res, axis)
return res


class TestFlipOpAxis1(TestFlipOp):
def init_test_case(self):
self.in_shape = (2, 4, 4)
self.dims = [0]
self.axis = [0]


class TestFlipOpAxis2(TestFlipOp):
def init_test_case(self):
self.in_shape = (4, 4, 6, 3)
self.dims = [0, 2]
self.axis = [0, 2]


class TestFlipOpAxis3(TestFlipOp):
def init_test_case(self):
self.in_shape = (4, 3, 1)
self.dims = [0, 1, 2]
self.axis = [0, 1, 2]


class TestFlipOpAxis4(TestFlipOp):
def init_test_case(self):
self.in_shape = (6, 4, 2, 2)
self.dims = [0, 1, 2, 3]
self.axis = [0, 1, 2, 3]


class TestFlipOpEmptyAxis(TestFlipOp):
def init_test_case(self):
self.in_shape = (6, 4, 2, 2)
self.axis = []


class TestFlipOpNegAxis(TestFlipOp):
def init_test_case(self):
self.in_shape = (6, 4, 2, 2)
self.axis = [-1]


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
from .manipulation import gather #DEFINE_ALIAS
from .manipulation import gather_nd #DEFINE_ALIAS
from .manipulation import reshape #DEFINE_ALIAS
from .manipulation import reverse #DEFINE_ALIAS
from .manipulation import flip as reverse #DEFINE_ALIAS
from .manipulation import scatter #DEFINE_ALIAS
from .manipulation import scatter_nd_add #DEFINE_ALIAS
from .manipulation import scatter_nd #DEFINE_ALIAS
Expand Down
43 changes: 23 additions & 20 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from ..fluid.layers import expand_as #DEFINE_ALIAS
from ..fluid.layers import flatten #DEFINE_ALIAS
from ..fluid.layers import reshape #DEFINE_ALIAS
from ..fluid.layers import reverse #DEFINE_ALIAS
from ..fluid.layers import scatter #DEFINE_ALIAS
from ..fluid.layers import slice #DEFINE_ALIAS
from ..fluid.layers import strided_slice #DEFINE_ALIAS
Expand All @@ -51,59 +50,63 @@
]


def flip(input, dims, name=None):
def flip(x, axis, name=None):
"""
:alias_main: paddle.flip
:alias: paddle.flip,paddle.tensor.flip,paddle.tensor.manipulation.flip
Reverse the order of a n-D tensor along given axis in dims.
Reverse the order of a n-D tensor along given axis in axis.
Args:
input (Variable): A Tensor(or LoDTensor) with shape :math:`[N_1, N_2,..., N_k]` . The data type of the input Tensor
x (Variable): A Tensor(or LoDTensor) with shape :math:`[N_1, N_2,..., N_k]` . The data type of the input Tensor x
should be float32, float64, int32, int64, bool.
dims (list): The axis to flip on.
axis (list): The axis(axes) to flip on. Negative indices for indexing from the end are accepted.
name (str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Returns:
Variable: Tensor or LoDTensor calculated by flip layer. The data type is same with input.
Variable: Tensor or LoDTensor calculated by flip layer. The data type is same with input x.
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
import numpy as np
input = fluid.data(name="x", shape=[-1, 2, 2], dtype='float32')
output = paddle.flip(input, dims=[0, 1])
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
img = np.arange(12).reshape((3,2,2)).astype(np.float32)
res = exe.run(fluid.default_main_program(), feed={'x':img}, fetch_list=[output])
print(res) # [[[10,11][8, 9]],[[6, 7],[4, 5]] [[2, 3],[0, 1]]]
paddle.enable_imperative()
image_shape=(3, 2, 2)
x = np.arange(image_shape[0] * image_shape[1] * image_shape[2]).reshape(image_shape)
x = x.astype('float32')
img = paddle.imperative.to_variable(x)
out = paddle.flip(img, [0,1])
print(out) # [[[10,11][8, 9]],[[6, 7],[4, 5]] [[2, 3],[0, 1]]]
"""
helper = LayerHelper("flip", **locals())
check_type(input, 'X', (Variable), 'flip')
dtype = helper.input_dtype()
check_type(x, 'X', (Variable), 'flip')
dtype = helper.input_dtype('x')
check_dtype(dtype, 'X',
['float16', 'float32', 'float64', 'int32', 'int64', 'bool'],
'flip')
check_type(dims, 'dims', (list, tuple), 'flip')
assert len(dims) > 0, 'len(dims) must be greater than 0.'
check_type(axis, 'axis', (list, tuple), 'flip')
if name is None:
out = helper.create_variable_for_type_inference(dtype)
else:
out = helper.create_variable(name=name, dtype=dtype, persistable=False)

helper.append_op(
type="flip",
inputs={"X": input},
inputs={"X": x},
outputs={"Out": out},
attrs={"dims": dims})
attrs={"axis": axis})
return out


reverse = flip #DEFINE_ALIAS


def roll(x, shifts, axis=None, name=None):
"""
:alias_main: paddle.roll
Expand Down

0 comments on commit 5d3766f

Please sign in to comment.