Skip to content

Commit

Permalink
【prim】Slice grad (#50771)
Browse files Browse the repository at this point in the history
* support prim test in OpTest

* fix cmake

* fix op test

* fix test_input_spec

* disable cinn in reduce_sum unit test

* add bfloat16 dtype for sum

* add approve rules

* polish code

* add clear jit program function

* convert grad out from tensor to numpy

* remove unnecessary code

* add only_prim flag

* fix flag

* fix op test

* add attr

* fix optest comp inplace error

* fix op test

* fix op test with guard

* add initialization of check_comp flag

* fix comp inplace error in op test

* rename check_comp with check_prim and add bfloat16 dtype convert

* rename comp_op_type to prim_op_type

* rename comp to prim

* remove useless code

* skip ci check for only prim

* add no_grad_vars and grad_outputs in prim test

* fix var_dict

* fix op test for only_prim

* fix dy2static bugs

* polish some code

* temp

* modify op test

* except cinn test

* modify bfloat16

* modify pad grad

* add pad_grad dtype

* start cinn part

---------

Co-authored-by: Charles-hit <wanghao107@baidu.com>
  • Loading branch information
xiaoguoguo626807 and Charles-hit committed Feb 24, 2023
1 parent 0d956e1 commit f6dea80
Show file tree
Hide file tree
Showing 9 changed files with 179 additions and 23 deletions.
31 changes: 31 additions & 0 deletions paddle/fluid/operators/slice_op.cc
Expand Up @@ -18,6 +18,9 @@ limitations under the License. */
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/kernels/funcs/slice_utils.h"

namespace paddle {
Expand Down Expand Up @@ -409,6 +412,34 @@ class SliceOpGradMaker : public framework::SingleGradOpMaker<T> {
}
};

// Composite grad-op maker for slice: instead of emitting the fused
// slice_grad kernel, it expands the backward computation into primitive
// ops (see prim::slice_grad, which reconstructs the input gradient via pad).
class SliceCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;

 public:
  void Apply() override {
    // Forward input, incoming output gradient, and the gradient slot to fill.
    paddle::experimental::Tensor input = this->GetSingleForwardInput("Input");
    paddle::experimental::Tensor out_grad = this->GetSingleOutputGrad("Out");
    paddle::experimental::Tensor input_grad = this->GetSingleInputGrad("Input");

    auto dx_ptr = this->GetOutputPtr(&input_grad);
    std::string dx_name = this->GetOutputName(input_grad);
    // Forward slice attributes, passed through verbatim to the composite rule.
    auto axes = this->Attr<std::vector<int64_t>>("axes");
    auto starts = this->Attr<std::vector<int64_t>>("starts");
    auto ends = this->Attr<std::vector<int64_t>>("ends");
    auto infer_flags = this->Attr<std::vector<int64_t>>("infer_flags");
    auto decrease_axis = this->Attr<std::vector<int64_t>>("decrease_axis");
    // Fixed log-message typo: "Runing" -> "Running".
    VLOG(6) << "Running slice_grad composite func";
    prim::slice_grad<prim::DescTensor>(input,
                                       out_grad,
                                       axes,
                                       paddle::experimental::IntArray(starts),
                                       paddle::experimental::IntArray(ends),
                                       infer_flags,
                                       decrease_axis,
                                       dx_ptr);
    this->RecoverOutputName(input_grad, dx_name);
  }
};
template <typename T>
class SliceDoubleOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/prim/api/api.yaml
Expand Up @@ -25,3 +25,5 @@
- scatter_nd_add
- tile
- transpose
- subtract
- pad
61 changes: 61 additions & 0 deletions paddle/fluid/prim/api/composite_backward/composite_backward_api.h
Expand Up @@ -323,5 +323,66 @@ void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
}
}

// Composite (primitive-op) backward rule for slice.
//
// The gradient of slice w.r.t. its input is the output gradient padded with
// zeros back to the input's shape, with the padding offsets given by the
// slice's start positions.
//
// input         : forward input tensor; only its shape is used here.
// out_grad      : gradient w.r.t. the slice output.
// axes          : axes the forward slice operated on.
// starts        : per-axis slice starts (negative values wrap, clamped at 0).
// ends          : per-axis slice ends; not read here — the pad amounts are
//                 derived from the shape difference instead.
// infer_flags   : not read here; kept for signature parity with the forward
//                 op's attribute set.
// decrease_axis : axes squeezed out of the forward output; they are
//                 re-inserted (as size-1 dims) before padding.
// input_grad    : output slot; set to pad(out_grad) when non-null, otherwise
//                 the function is a no-op.
template <typename T>
void slice_grad(const Tensor& input,
                const Tensor& out_grad,
                const std::vector<int64_t>& axes,
                const IntArray& starts,
                const IntArray& ends,
                const std::vector<int64_t>& infer_flags,
                const std::vector<int64_t>& decrease_axis,
                Tensor* input_grad) {
  if (input_grad) {
    size_t rank = input.dims().size();
    auto out_dims = out_grad.dims();
    auto in_dims = input.dims();

    // Re-insert the dims removed by decrease_axis so that out_dims has the
    // same rank as in_dims before computing the paddings.
    auto decrease_size = decrease_axis.size();
    if (decrease_size > 0) {
      if (decrease_size == static_cast<size_t>(in_dims.size())) {
        // Every dim was decreased: the un-squeezed out shape is all ones.
        out_dims = phi::make_ddim(std::vector<int>(decrease_size, 1));
      } else {
        // Mark decreased positions with 1; fill the rest (-1 sentinels) with
        // the surviving out_dims in order.
        std::vector<int> origin_out_shape(out_dims.size() + decrease_size, -1);
        for (size_t i = 0; i < decrease_size; ++i) {
          origin_out_shape[decrease_axis[i]] = 1;
        }

        int index = 0;
        for (size_t i = 0; i < origin_out_shape.size(); ++i) {
          if (origin_out_shape[i] == -1) {
            origin_out_shape[i] = out_dims[index];
            ++index;
          }
        }
        out_dims = phi::make_ddim(origin_out_shape);
      }
    }

    // Offset of the slice window inside the input, per dimension; axes not
    // sliced keep offset 0. (Removed an `extents` vector that was filled but
    // never read.)
    std::vector<int> offsets(rank, 0);
    for (size_t i = 0; i < axes.size(); ++i) {
      // Keep int64_t to avoid narrowing the attribute value.
      int64_t axis = axes[i];
      int64_t start = starts[i] < 0 ? (starts[i] + in_dims[axis]) : starts[i];
      start = std::max(start, static_cast<int64_t>(0));
      offsets[axis] = static_cast<int>(start);
    }

    // pad expects interleaved (before, after) amounts for each dimension.
    std::vector<int> paddings;
    paddings.reserve(2 * rank);
    for (size_t i = 0; i < rank; ++i) {
      paddings.push_back(offsets[i]);
      paddings.push_back((in_dims[i] - out_dims[i]) - offsets[i]);
    }

    auto out_tmp = pad<T>(out_grad, paddings, 0.0);
    set_output<T>(out_tmp, input_grad);
  }
}

} // namespace prim
} // namespace paddle
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/legacy_backward.yaml
Expand Up @@ -1172,6 +1172,7 @@
param : [input]
kernel :
func : slice_grad
composite: slice_grad(input, out_grad, axes, starts, ends, infer_flags, decrease_axis)
backward : slice_double_grad
no_need_buffer : input

Expand Down
5 changes: 4 additions & 1 deletion paddle/phi/kernels/cpu/pad_grad_kernel.cc
Expand Up @@ -24,5 +24,8 @@ PD_REGISTER_KERNEL(pad_grad,
phi::PadGradKernel,
float,
double,
int,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::dtype::complex<double>,
phi::dtype::bfloat16) {}
3 changes: 2 additions & 1 deletion paddle/phi/kernels/cpu/pad_kernel.cc
Expand Up @@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(pad,
int,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::dtype::complex<double>,
phi::dtype::bfloat16) {}
3 changes: 2 additions & 1 deletion python/paddle/fluid/tests/unittests/CMakeLists.txt
Expand Up @@ -1202,7 +1202,8 @@ if($ENV{USE_STANDALONE_EXECUTOR})
PROPERTIES ENVIRONMENT FLAGS_USE_STANDALONE_EXECUTOR=0)
endif()

set(TEST_CINN_OPS test_softmax_op test_expand_v2_op test_reduce_op)
set(TEST_CINN_OPS test_softmax_op test_expand_v2_op test_reduce_op
test_slice_op)

foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
if(WITH_CINN)
Expand Down

0 comments on commit f6dea80

Please sign in to comment.