Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

modify roll test=develop #25321

Merged
merged 4 commits into from
Jul 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions paddle/fluid/operators/roll_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class RollOp : public framework::OperatorWithKernel {
platform::errors::InvalidArgument(
"Output(Out) of RollOp should not be null."));

auto dims = ctx->Attrs().Get<std::vector<int64_t>>("dims");
auto dims = ctx->Attrs().Get<std::vector<int64_t>>("axis");
auto shifts = ctx->Attrs().Get<std::vector<int64_t>>("shifts");

PADDLE_ENFORCE_EQ(dims.size(), shifts.size(),
Expand Down Expand Up @@ -92,7 +92,7 @@ class RollOpMaker : public framework::OpProtoAndCheckerMaker {
"of the tensor are shifted.")
.SetDefault({});
AddAttr<std::vector<int64_t>>(
"dims",
"axis",
"Axis along which to roll. It must have the same size "
"with shifts.")
.SetDefault({});
Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/operators/roll_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class RollKernel : public framework::OpKernel<T> {
auto& input = input_var->Get<LoDTensor>();
auto* output = output_var->GetMutable<LoDTensor>();
std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("dims");
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");

std::vector<T> out_vec;
TensorToVector(input, context.device_context(), &out_vec);
Expand All @@ -94,8 +94,8 @@ class RollKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(
dims[i] < input_dim.size() && dims[i] >= (0 - input_dim.size()), true,
platform::errors::OutOfRange(
"Attr(dims[%d]) is out of range, It's expected "
"to be in range of [-%d, %d]. But received Attr(dims[%d]) = %d.",
"Attr(axis[%d]) is out of range, It's expected "
"to be in range of [-%d, %d]. But received Attr(axis[%d]) = %d.",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Python端也增加对axis范围的检查,防止打印出来call stack。

i, input_dim.size(), input_dim.size() - 1, i, dims[i]));
shift_along_dim(out_vec.data(), input_dim, dims[i], shifts[i]);
}
Expand All @@ -114,7 +114,7 @@ class RollGradKernel : public framework::OpKernel<T> {
auto& input = input_var->Get<LoDTensor>();
auto* output = output_var->GetMutable<LoDTensor>();
std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("dims");
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");

std::vector<T> out_vec;
TensorToVector(input, context.device_context(), &out_vec);
Expand Down
26 changes: 20 additions & 6 deletions python/paddle/fluid/tests/unittests/test_roll_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,17 @@ def setUp(self):
self.op_type = "roll"
self.init_dtype_type()
self.inputs = {'X': np.random.random(self.x_shape).astype(self.dtype)}
self.attrs = {'shifts': self.shifts, 'dims': self.dims}
self.attrs = {'shifts': self.shifts, 'axis': self.axis}
self.outputs = {
'Out': np.roll(self.inputs['X'], self.attrs['shifts'],
self.attrs['dims'])
self.attrs['axis'])
}

def init_dtype_type(self):
self.dtype = np.float64
self.x_shape = (100, 4, 5)
self.shifts = [101, -1]
self.dims = [0, -2]
self.axis = [0, -2]

def test_check_output(self):
self.check_output()
Expand All @@ -52,7 +52,7 @@ def init_dtype_type(self):
self.dtype = np.float32
self.x_shape = (100, 10, 5)
self.shifts = [8, -1]
self.dims = [-1, -2]
self.axis = [-1, -2]


class TestRollAPI(unittest.TestCase):
Expand All @@ -78,7 +78,7 @@ def test_roll_op_api(self):
# case 2:
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[-1, 3])
z = paddle.roll(x, shifts=1, dims=0)
z = paddle.roll(x, shifts=1, axis=0)
exe = fluid.Executor(fluid.CPUPlace())
res, = exe.run(feed={'x': self.data_x},
fetch_list=[z.name],
Expand All @@ -101,12 +101,26 @@ def test_dygraph_api(self):
# case 2:
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(self.data_x)
z = paddle.roll(x, shifts=1, dims=0)
z = paddle.roll(x, shifts=1, axis=0)
np_z = z.numpy()
expect_out = np.array([[7.0, 8.0, 9.0], [1.0, 2.0, 3.0],
[4.0, 5.0, 6.0]])
self.assertTrue(np.allclose(expect_out, np_z))

def test_roll_op_false(self):
self.input_data()

def test_axis_out_range():
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[-1, 3])
z = paddle.roll(x, shifts=1, axis=10)
exe = fluid.Executor(fluid.CPUPlace())
res, = exe.run(feed={'x': self.data_x},
fetch_list=[z.name],
return_numpy=False)

self.assertRaises(ValueError, test_axis_out_range)


if __name__ == "__main__":
unittest.main()
79 changes: 44 additions & 35 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,23 +104,24 @@ def flip(input, dims, name=None):
return out


def roll(input, shifts, dims=None):
def roll(x, shifts, axis=None, name=None):
"""
Copy link
Contributor

@jzhang533 jzhang533 Jul 8, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it should be : paddle.tensor.roll(x, shifts, axis=None, name=None)
according to the latest argument convention.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

:alias_main: paddle.roll
:alias: paddle.roll,paddle.tensor.roll,paddle.tensor.manipulation.roll

Roll the `input` tensor along the given dimension(s). Elements that are shifted beyond
the last position are re-introduced at the first position. If a dimension is not specified,
Roll the `x` tensor along the given axis(axes). With specific 'shifts', Elements that
roll beyond the last position are re-introduced at the first according to 'shifts'.
If a axis is not specified,
the tensor will be flattened before rolling and then restored to the original shape.

Args:
input (Variable): The input tensor variable.
x (Variable): The x tensor variable as input.
shifts (int|list|tuple): The number of places by which the elements
of the `input` tensor are shifted.
dims (int|list|tuple|None): Dimentions along which to roll.
of the `x` tensor are shifted.
axis (int|list|tuple|None): axis(axes) along which to roll.

Returns:
Variable: A Tensor with same data type as `input`.
Variable: A Tensor with same data type as `x`.

Examples:
.. code-block:: python
Expand All @@ -131,48 +132,56 @@ def roll(input, shifts, dims=None):
data = np.array([[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0],
[7.0, 8.0, 9.0]])
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(data)
out_z1 = paddle.roll(x, shifts=1)
print(out_z1.numpy())
#[[9. 1. 2.]
# [3. 4. 5.]
# [6. 7. 8.]]
out_z2 = paddle.roll(x, shifts=1, dims=0)
print(out_z2.numpy())
#[[7. 8. 9.]
# [1. 2. 3.]
# [4. 5. 6.]]
paddle.enable_imperative()
x = paddle.imperative.to_variable(data)
out_z1 = paddle.roll(x, shifts=1)
print(out_z1.numpy())
#[[9. 1. 2.]
# [3. 4. 5.]
# [6. 7. 8.]]
out_z2 = paddle.roll(x, shifts=1, axis=0)
print(out_z2.numpy())
#[[7. 8. 9.]
# [1. 2. 3.]
# [4. 5. 6.]]
"""
helper = LayerHelper("roll", **locals())
origin_shape = input.shape
origin_shape = x.shape
if type(shifts) == int:
shifts = [shifts]
if type(dims) == int:
dims = [dims]

if dims:
check_type(dims, 'dims', (list, tuple), 'roll')
if type(axis) == int:
axis = [axis]

len_origin_shape = len(origin_shape)
if axis:
for i in range(len(axis)):
if axis[i] >= len_origin_shape or axis[i] < -len_origin_shape:
raise ValueError(
"axis is out of range, it should be in range [{}, {}), but received {}".
format(-len_origin_shape, len_origin_shape, axis))

if axis:
check_type(axis, 'axis', (list, tuple), 'roll')
check_type(shifts, 'shifts', (list, tuple), 'roll')

if in_dygraph_mode():
if dims is None:
input = core.ops.reshape(input, 'shape', [-1, 1])
dims = [0]
out = core.ops.roll(input, 'dims', dims, 'shifts', shifts)
if axis is None:
x = core.ops.reshape(x, 'shape', [-1, 1])
axis = [0]
out = core.ops.roll(x, 'axis', axis, 'shifts', shifts)
return core.ops.reshape(out, 'shape', origin_shape)

out = helper.create_variable_for_type_inference(input.dtype)
out = helper.create_variable_for_type_inference(x.dtype)

if dims is None:
input = reshape(input, shape=[-1, 1])
dims = [0]
if axis is None:
x = reshape(x, shape=[-1, 1])
axis = [0]

helper.append_op(
type='roll',
inputs={'X': input},
inputs={'X': x},
outputs={'Out': out},
attrs={'dims': dims,
attrs={'axis': axis,
'shifts': shifts})
out = reshape(out, shape=origin_shape, inplace=True)
return out
Expand Down