Commit: polish code

hbwx24 committed Jul 29, 2021
1 parent 39aa575 commit 3b6ad37
Showing 4 changed files with 148 additions and 84 deletions.
13 changes: 13 additions & 0 deletions paddle/fluid/operators/strided_slice_op.cc
@@ -163,6 +163,19 @@ class StridedSliceOp : public framework::OperatorWithKernel {
auto *in_var = ctx.InputVar("Input");
auto is_in_var_array = in_var->IsType<framework::LoDTensorArray>();
if (is_in_var_array) {
+      auto &tensor_array = in_var->Get<framework::LoDTensorArray>();
+      for (auto &tensor : tensor_array) {
+        if (!platform::is_cuda_pinned_place(tensor.place())) {
+          PADDLE_ENFORCE_EQ(
+              platform::is_same_place(tensor.place(),
+                                      ctx.device_context().GetPlace()),
+              true, platform::errors::InvalidArgument(
+                        "Place of context is %s. Place of input tensor is %s. "
+                        "They should be the same, but received different places.",
+                        string::to_string(ctx.device_context().GetPlace()),
+                        string::to_string(tensor.place())));
+        }
+      }
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
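The check above requires every tensor in the input LoDTensorArray to live on the kernel's place; CUDA-pinned host memory is exempt because it can be copied to the device transparently. A rough Python mirror of the logic (the function name and place strings are illustrative, not Paddle API):

    # Illustrative sketch of the place check above, not actual Paddle API.
    def check_tensor_array_places(tensor_places, context_place):
        for place in tensor_places:
            if place == "cuda_pinned":
                continue  # pinned host memory is exempt
            if place != context_place:
                raise ValueError(
                    "Place of context is {}. Place of input tensor is {}. "
                    "They should be the same.".format(context_place, place))

    check_tensor_array_places(["gpu:0", "cuda_pinned"], "gpu:0")  # passes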
122 changes: 61 additions & 61 deletions paddle/fluid/operators/strided_slice_op.h
@@ -127,6 +127,9 @@ static void StridedSliceFunctor(int64_t* starts, int64_t* ends,
if (!(ends[axis_index] == -1 &&
strides[axis_index] < 0)) { // skip None stop condition
ends[axis_index] = ends[axis_index] + axis_size;
+        if (ends[axis_index] < 0) {
+          ends[axis_index] = 0;
+        }
}
}
if (decrease_axis_affect) {
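The three added lines clamp an end index that is still negative after adding axis_size, so an out-of-range negative end produces an empty slice rather than a bogus extent. This matches Python/numpy slicing; a minimal numpy check:

    import numpy as np

    x = np.arange(4)  # axis_size = 4
    # end = -10 normalizes to -10 + 4 = -6, still negative, so clamp to 0:
    # the resulting slice x[0:0:1] is empty, exactly as numpy computes it.
    assert x[0:-10:1].size == 0
    assert x[0:max(-10 + 4, 0):1].size == 0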
@@ -147,9 +150,8 @@ static void StridedSliceFunctor(int64_t* starts, int64_t* ends,
strides[axis_index] = -strides[axis_index];
if (starts[axis_index] > ends[axis_index]) {
// swap the reverse
-        auto end_dim = dims[axis_index] - 1 < starts[axis_index]
-                           ? dims[axis_index] - 1
-                           : starts[axis_index];
+        auto end_dim = axis_size - 1 < starts[axis_index] ? axis_size - 1
+                                                          : starts[axis_index];
auto offset = (end_dim - ends[axis_index]) % strides[axis_index];
offset = offset == 0 ? strides[axis_index] : offset;

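This hunk replaces dims[axis_index] with axis_size, presumably because axis_index indexes the axes list rather than the tensor's dims. The surrounding swap turns a negative-stride slice into an equivalent positive-stride slice that is reversed afterwards; end_dim clamps the start to the last valid index and offset realigns the lower bound. A worked numpy check of the equivalence this arithmetic relies on (variable names mirror the C++; the final start index end + offset is inferred from the visible code):

    import numpy as np

    x = np.arange(10)
    start, end, stride = 8, 1, -3        # x[8:1:-3] -> [8, 5, 2]
    axis_size = x.size

    s = -stride                          # flipped positive stride
    end_dim = min(axis_size - 1, start)  # clamp start to last valid index
    offset = (end_dim - end) % s or s    # realign the lower bound

    forward = x[end + offset:end_dim + 1:s]  # positive-stride pass
    assert list(forward[::-1]) == list(x[start:end:stride])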
@@ -378,33 +380,32 @@ class StridedSliceKernel : public framework::OpKernel<T> {
TensorCopy(in_tensor, context.GetPlace(), out_tensor);
}

-      return;
-    }
-    auto in = context.Input<framework::Tensor>("Input");
-    auto out = context.Output<framework::Tensor>("Out");
-    out->Resize(out_dims);
-    out->mutable_data<T>(context.GetPlace());
-    auto in_t =
-        framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
-            *in);
-    auto out_t =
-        framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
-            *out, out_dims);
-    if (need_reverse) {
-      framework::Tensor tmp;
-      tmp.mutable_data<T>(out_dims, context.GetPlace());
-      auto tmp_t = framework::EigenTensor<T, D, Eigen::RowMajor,
-                                          Eigen::DenseIndex>::From(tmp);
-      tmp_t.device(place) =
-          in_t.stridedSlice(starts_indices, ends_indices, strides_indices);
-      out_t.device(place) = tmp_t.reverse(reverse_axis);
-    } else {
-      out_t.device(place) =
-          in_t.stridedSlice(starts_indices, ends_indices, strides_indices);
-    }
-
-    if (decrease_axis.size() > 0) {
-      out->Resize(out_dims_origin);
-    }
+    } else {
+      auto in = context.Input<framework::Tensor>("Input");
+      auto out = context.Output<framework::Tensor>("Out");
+      out->Resize(out_dims);
+      out->mutable_data<T>(context.GetPlace());
+      auto in_t = framework::EigenTensor<T, D, Eigen::RowMajor,
+                                         Eigen::DenseIndex>::From(*in);
+      auto out_t =
+          framework::EigenTensor<T, D, Eigen::RowMajor,
+                                 Eigen::DenseIndex>::From(*out, out_dims);
+      if (need_reverse) {
+        framework::Tensor tmp;
+        tmp.mutable_data<T>(out_dims, context.GetPlace());
+        auto tmp_t = framework::EigenTensor<T, D, Eigen::RowMajor,
+                                            Eigen::DenseIndex>::From(tmp);
+        tmp_t.device(place) =
+            in_t.stridedSlice(starts_indices, ends_indices, strides_indices);
+        out_t.device(place) = tmp_t.reverse(reverse_axis);
+      } else {
+        out_t.device(place) =
+            in_t.stridedSlice(starts_indices, ends_indices, strides_indices);
+      }
+
+      if (decrease_axis.size() > 0) {
+        out->Resize(out_dims_origin);
+      }
+    }
}
};
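Eigen's stridedSlice, as used here, is driven with positive strides; when the original slice had negative strides the kernel first materializes the positive-stride slice into tmp and then reverses the flagged axes. The numpy analogue of that two-step evaluation:

    import numpy as np

    x = np.arange(10)
    # Goal: x[8:1:-3] == [8, 5, 2]
    tmp = x[2:9:3]   # positive-stride slice computed first: [2, 5, 8]
    out = tmp[::-1]  # reverse pass: [8, 5, 2]
    assert list(out) == list(x[8:1:-3])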
@@ -453,11 +454,11 @@ class StridedSliceGradKernel : public framework::OpKernel<T> {
auto* out_var = context.OutputVar(framework::GradVarName("Input"));
bool is_out_var_array = out_var->IsType<LoDTensorArray>();
if (is_out_var_array) {
-      // Since the shape of `framework::GradVarName("Input")` of
-      // StridedSliceGrad
-      // cannot be calculated by `framework::GradVarName("Output")`,
-      // the dim of "Input" is used to calculate the output shape.
-      // when set it to inplace OP, there may be some problems.
+      // Note(weixin): Since the shape of `framework::GradVarName("Input")` of
+      // StridedSliceGrad cannot be calculated by
+      // `framework::GradVarName("Output")`, the dims of "Input" are used to
+      // calculate the output shape. When it is set as an inplace OP, there may
+      // be some problems.
const int64_t size =
context.Input<framework::LoDTensorArray>("Input")->size();

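The note is worth a concrete illustration: for a TensorArray, the length of the sliced output alone does not determine the input length, which is why the gradient kernel reads the size of the forward "Input". Plain Python lists standing in for TensorArray:

    # Arrays of length 5 and 6 both slice down to 3 elements with [::2], so
    # len(output_grad) == 3 cannot tell the gradient which input length to rebuild.
    assert len(list(range(5))[::2]) == 3
    assert len(list(range(6))[::2]) == 3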
@@ -621,40 +622,39 @@ class StridedSliceGradKernel : public framework::OpKernel<T> {
set_zero(dev_ctx, d_out_tensor, static_cast<T>(0));
}
}
-      return;
-    }
-
-    auto* d_input =
-        context.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto* d_out =
-        context.Output<framework::Tensor>(framework::GradVarName("Input"));
-
-    d_out->mutable_data<T>(context.GetPlace());
-
-    math::SetConstant<DeviceContext, T> set_zero;
-    set_zero(dev_ctx, d_out, static_cast<T>(0));
-
-    auto in_dims = d_input->dims();
-
-    auto in_t =
-        framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
-            *d_input);
-    auto out_t =
-        framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
-            *d_out, out_dims);
-    if (need_reverse) {
-      framework::Tensor reverse_input;
-      reverse_input.mutable_data<T>(in_dims, context.GetPlace());
-      auto reverse_in_t =
-          framework::EigenTensor<T, D, Eigen::RowMajor,
-                                 Eigen::DenseIndex>::From(reverse_input);
-
-      reverse_in_t.device(place) = in_t.reverse(reverse_axis);
-      out_t.stridedSlice(starts_indices, ends_indices, strides_indices)
-          .device(place) = reverse_in_t;
-    } else {
-      out_t.stridedSlice(starts_indices, ends_indices, strides_indices)
-          .device(place) = in_t;
-    }
+    } else {
+      auto* d_input =
+          context.Input<framework::Tensor>(framework::GradVarName("Out"));
+      auto* d_out =
+          context.Output<framework::Tensor>(framework::GradVarName("Input"));
+
+      d_out->mutable_data<T>(context.GetPlace());
+
+      math::SetConstant<DeviceContext, T> set_zero;
+      set_zero(dev_ctx, d_out, static_cast<T>(0));
+
+      auto in_dims = d_input->dims();
+
+      auto in_t = framework::EigenTensor<T, D, Eigen::RowMajor,
+                                         Eigen::DenseIndex>::From(*d_input);
+      auto out_t =
+          framework::EigenTensor<T, D, Eigen::RowMajor,
+                                 Eigen::DenseIndex>::From(*d_out, out_dims);
+      if (need_reverse) {
+        framework::Tensor reverse_input;
+        reverse_input.mutable_data<T>(in_dims, context.GetPlace());
+        auto reverse_in_t =
+            framework::EigenTensor<T, D, Eigen::RowMajor,
+                                   Eigen::DenseIndex>::From(reverse_input);
+
+        reverse_in_t.device(place) = in_t.reverse(reverse_axis);
+        out_t.stridedSlice(starts_indices, ends_indices, strides_indices)
+            .device(place) = reverse_in_t;
+      } else {
+        out_t.stridedSlice(starts_indices, ends_indices, strides_indices)
+            .device(place) = in_t;
+      }
+    }
}
};
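The dense (non-TensorArray) branch is the usual strided-slice gradient: zero-fill a tensor shaped like the forward input, then scatter the upstream gradient into the sliced positions, reversing it first when the forward pass reversed. A numpy sketch of the same computation:

    import numpy as np

    # Upstream gradient for y = x[8:1:-3], i.e. y = [x[8], x[5], x[2]].
    d_y = np.array([3., 2., 1.])

    d_x = np.zeros(10)      # set_zero(d_x)
    d_x[2:9:3] = d_y[::-1]  # reverse, then scatter with positive stride
    assert d_x[8] == 3. and d_x[5] == 2. and d_x[2] == 1.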
17 changes: 16 additions & 1 deletion python/paddle/fluid/tests/unittests/dygraph_to_static/test_slice.py
@@ -177,7 +177,8 @@ def test_set_value_with_save(self):
output_spec=None)


-class TestSliceSupplementCase(unittest.TestCase):
+class TestSliceSupplementSpecialCase(unittest.TestCase):
+    # unittest for slice indices with abs(step) > 1, e.g. x[::2]
def test_static_slice_step(self):
paddle.enable_static()
array = np.arange(4**3).reshape((4, 4, 4)).astype('int64')
@@ -242,6 +243,20 @@ def test_compare_paddle_strided_slice_with_numpy(self):
np.array_equal(sl.numpy(), array[s2[0]:e2[0]:stride2[0], s2[1]:e2[
1]:stride2[1]]))

+        array = np.arange(6 * 7 * 8).reshape((6, 7, 8))
+        pt = paddle.to_tensor(array)
+        s2 = [7, -1]
+        e2 = [2, -5]
+        stride2 = [-2, -3]
+        sl = paddle.strided_slice(
+            pt, axes=[0, 2], starts=s2, ends=e2, strides=stride2)
+
+        array_slice = array[s2[0]:e2[0]:stride2[0], ::, s2[1]:e2[1]:stride2[1]]
+        self.assertTrue(
+            np.array_equal(sl.numpy(), array_slice),
+            msg="paddle.strided_slice:\n {} \n numpy slice:\n{}".format(
+                sl.numpy(), array_slice))
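The comparison above pins down the intended contract: with axes=[0, 2], paddle.strided_slice must agree with numpy basic slicing on those axes while untouched axes take the full range. In slice-object form (a sketch using the same array as the test):

    import numpy as np

    array = np.arange(6 * 7 * 8).reshape((6, 7, 8))
    # axes=[0, 2]: axis 1 gets the full-range slice(None)
    expected = array[slice(7, 2, -2), slice(None), slice(-1, -5, -3)]
    assert np.array_equal(expected, array[7:2:-2, :, -1:-5:-3])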


if __name__ == '__main__':
unittest.main()
80 changes: 58 additions & 22 deletions python/paddle/fluid/tests/unittests/test_strided_slice_op.py
@@ -701,6 +701,49 @@ def create_case(self, net):
msg="dygraph graph result:\n{} \nstatic dygraph result:\n{}".format(
l1.numpy(), l2.numpy()))

+    def test_strided_slice_tensor_array_cuda_pinned_place(self):
+        if paddle.device.is_compiled_with_cuda():
+            with paddle.fluid.dygraph.guard():
+
+                class Simple(paddle.nn.Layer):
+                    def __init__(self):
+                        super(Simple, self).__init__()
+
+                    def forward(self, inps):
+                        tensor_array = None
+                        for i, tensor in enumerate(inps):
+                            index = paddle.full(
+                                shape=[1], dtype='int64', fill_value=i)
+                            if tensor_array is None:
+                                tensor_array = paddle.tensor.array_write(
+                                    tensor, i=index)
+                            else:
+                                paddle.tensor.array_write(
+                                    tensor, i=index, array=tensor_array)
+
+                        array1 = paddle.concat(tensor_array)
+                        array2 = paddle.concat(tensor_array[::-1])
+                        return array1 + array2 * array2
+
+                net = Simple()
+                func = paddle.jit.to_static(net.forward)
+
+                inps1 = paddle.to_tensor(
+                    np.random.randn(2, 10),
+                    place=paddle.CUDAPinnedPlace(),
+                    stop_gradient=False)
+                inps2 = paddle.to_tensor(
+                    np.random.randn(2, 10),
+                    place=paddle.CUDAPinnedPlace(),
+                    stop_gradient=False)
+
+                self.assertTrue(inps1.place.is_cuda_pinned_place())
+                self.assertTrue(inps2.place.is_cuda_pinned_place())
+
+                result = func([inps1, inps2])
+
+                self.assertFalse(result.place.is_cuda_pinned_place())
+
def test_strided_slice_tensor_array(self):
class Net(ArrayLayer):
def array_slice(self, tensors):
@@ -854,28 +897,21 @@ def array_slice(self, tensors):

self.create_case(Net(input_size=112, array_size=13))

-        # TODO(weixin):Currently, the case that the start index is
-        # less than `-array_size` is not supported.
-        # The index parsed from the slice of the VarBase/Variable
-        # is processed before being passed to `strided_slice_op`.
-        # The slice may be processed uniformly, instead of
-        # processing separately for TensorArray\VarBase\Variable.
-        #
-        # class Net(ArrayLayer):
-        #
-        #     def array_slice(self,tensors):
-        #         return tensors[-60:20:3]
-        #     self.create_case(Net(input_size=112,array_size=13))
-
-        # class Net(ArrayLayer):
-        #     def array_slice(self, tensors):
-        #         return tensors[-3:-60:-3]
-
-        #     self.create_case(Net(input_size=112, array_size=13))
-
-        # class Net(ArrayLayer):
-        #     def array_slice(self, tensors):
-        #         return tensors[-1:-60:-3]
+        class Net(ArrayLayer):
+            def array_slice(self, tensors):
+                return tensors[-60:20:3]
+
+        self.create_case(Net(input_size=112, array_size=13))
+
+        class Net(ArrayLayer):
+            def array_slice(self, tensors):
+                return tensors[-3:-60:-3]
+
+        self.create_case(Net(input_size=112, array_size=13))
+
+        class Net(ArrayLayer):
+            def array_slice(self, tensors):
+                return tensors[-1:-60:-3]


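The re-enabled cases exercise Python's clamping of out-of-range negative indices, which the TensorArray slice path now has to reproduce: a start below -len(arr) clamps to the front for positive steps, and an end below -len(arr) lets a negative-step slice run all the way to index 0. The list-slicing ground truth these cases encode:

    arr = list(range(13))                 # array_size = 13
    assert arr[-60:20:3] == arr[0:13:3]   # start < -len clamps to 0
    assert arr[-3:-60:-3] == arr[10::-3]  # end < -len runs to the front
    assert arr[-1:-60:-3] == arr[12::-3]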
if __name__ == "__main__":
    unittest.main()
