add backward gradient computation for op argsort (#22203)
* add backward gradient computation for op argsort test=develop

* use pre-commit test=develop
FlyingQianMM committed Jan 10, 2020
1 parent 46189b1 commit 443a713
Showing 4 changed files with 287 additions and 6 deletions.
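
The backward rule being added, in one line: forward argsort returns Out and Indices with Out[i][j] = X[i][Indices[i][j]] along the sorted axis, so the gradient scatters dOut back through Indices: dX[i][Indices[i][j]] = dOut[i][j]. Each row of Indices is a permutation, so a plain assignment into a zero-filled dX covers every element. A minimal standalone C++ sketch of that rule for a single row (the helper name is illustrative, not part of the diff):

#include <cstddef>
#include <cstdint>
#include <vector>

// Scatter one row of the upstream gradient through the argsort indices:
// d_x[indices[j]] = d_out[j]; indices is a permutation of 0..n-1.
std::vector<float> ArgsortRowGrad(const std::vector<float>& d_out,
                                  const std::vector<int64_t>& indices) {
  std::vector<float> d_x(d_out.size(), 0.0f);
  for (std::size_t j = 0; j < d_out.size(); ++j) {
    d_x[indices[j]] = d_out[j];
  }
  return d_x;
}

The new FillGrad CUDA kernel and FullAssign CPU helper below apply exactly this rule row by row, after transposing the target axis to the innermost position.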
55 changes: 50 additions & 5 deletions paddle/fluid/operators/argsort_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/argsort_op.h"
#include <memory>

namespace paddle {
namespace operators {
@@ -21,7 +22,7 @@ class ArgsortOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

-  void InferShape(framework::InferShapeContext *ctx) const override {
+  void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ArgsortOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
@@ -49,6 +50,24 @@ class ArgsortOp : public framework::OperatorWithKernel {
}
};

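// Grad op shape inference: dX mirrors X's dims and LoD, and the kernel data
// type is taken from dOut so the backward kernel matches the incoming
// gradient's type.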
class ArgsortGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*-->*/ framework::GradVarName("X"));
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};

class ArgsortOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
@@ -83,16 +102,42 @@ Output(Indices) gives the sorted order along the given axis Attr(axis).
}
};

template <typename T>
class ArgsortGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
op->SetType("argsort_grad");
op->SetInput("Indices", this->Output("Indices"));
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
return op;
}
};

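// The grad kernels read only Indices and dOut; X is needed solely for its
// shape at InferShape time, so its buffer can be freed early: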
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ArgsortGradNoNeedBufferVarInference, "X");

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
- REGISTER_OPERATOR(
-     argsort, ops::ArgsortOp, ops::ArgsortOpMaker,
-     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+ REGISTER_OPERATOR(argsort, ops::ArgsortOp, ops::ArgsortOpMaker,
+                   ops::ArgsortGradOpMaker<paddle::framework::OpDesc>,
+                   ops::ArgsortGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(argsort_grad, ops::ArgsortGradOp,
ops::ArgsortGradNoNeedBufferVarInference);
REGISTER_OP_CPU_KERNEL(argsort,
ops::ArgsortKernel<paddle::platform::CPUPlace, float>,
ops::ArgsortKernel<paddle::platform::CPUPlace, double>,
ops::ArgsortKernel<paddle::platform::CPUPlace, int>,
ops::ArgsortKernel<paddle::platform::CPUPlace, int64_t>);
REGISTER_OP_CPU_KERNEL(
argsort_grad, ops::ArgsortGradientKernel<paddle::platform::CPUPlace, float>,
ops::ArgsortGradientKernel<paddle::platform::CPUPlace, double>,
ops::ArgsortGradientKernel<paddle::platform::CPUPlace, int>,
ops::ArgsortGradientKernel<paddle::platform::CPUPlace, int64_t>);
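
In short: the forward op now points at ArgsortGradOpMaker instead of EmptyGradOpMaker, and argsort_grad is registered with its no-need-buffer hint and CPU kernels for the same four element types as the forward op.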
123 changes: 123 additions & 0 deletions paddle/fluid/operators/argsort_op.cu
@@ -58,6 +58,19 @@ static __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
}
}

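// Gradient scatter kernel: blocks stride over rows and threads stride over
// columns, writing dX[row][indices[row][col]] = dO[row][col].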
template <typename T, typename IndType>
static __global__ void FillGrad(const T* dO, const IndType* indices, T* dX,
IndType num_rows, IndType num_cols) {
int col_id = threadIdx.x;
int row_id = blockIdx.x;

for (IndType j = row_id; j < num_rows; j += gridDim.x) {
for (IndType i = col_id; i < num_cols; i += blockDim.x) {
dX[j * num_cols + indices[j * num_cols + i]] = dO[j * num_cols + i];
}
}
}

// Sort by flag descending, True: descending. False: Ascending.
// Default is false.
template <typename T, typename IndType>
@@ -160,6 +173,35 @@ void ArgFullSort(const platform::CUDADeviceContext& ctx, const Tensor* input,
temp_storage_bytes, cudaGetErrorString(err));
}

template <typename T, typename IndType>
void ArgFullAssign(const platform::CUDADeviceContext& ctx, const Tensor* dO,
const Tensor* indices, Tensor* dX, const IndType num_rows,
const IndType num_cols) {
auto cu_stream = ctx.stream();

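  // Pick a power-of-two block size from the row width (minimum 64, maximum
  // 1024); rows wider than the block are covered by the column stride loop.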
  auto ComputeBlockSize = [](IndType col) {
    if (col > 512) return 1024;
    if (col > 256) return 512;
    if (col > 128) return 256;
    if (col > 64) return 128;
    return 64;
  };

int block_size = ComputeBlockSize(num_cols);

int maxGridDimX = ctx.GetCUDAMaxGridDimSize().x;
  // One block per row, capped at the device's grid-dimension limit; the
  // kernel's row loop strides over any remainder.
int grid_size = num_rows < maxGridDimX ? num_rows : maxGridDimX;
FillGrad<<<grid_size, block_size, 0, cu_stream>>>(
dO->data<T>(), indices->data<IndType>(), dX->data<T>(), num_rows,
num_cols);
}

template <typename T>
class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
public:
@@ -234,6 +276,81 @@ class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
}
};

template <typename T>
class ArgsortGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* indices = ctx.Input<Tensor>("Indices");
auto* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
int axis = ctx.Attr<int>("axis");

dX->mutable_data<T>(ctx.GetPlace());
auto dxt = framework::EigenVector<T>::Flatten(*dX);
auto& place = *ctx.template device_context<platform::CUDADeviceContext>()
.eigen_device();
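    // Zero-fill dX up front so an empty dOut still leaves a well-defined,
    // all-zero gradient.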
dxt.device(place) = dxt.constant(static_cast<T>(0));
if (dO->numel() == 0) return;

auto in_dims = indices->dims();
axis = (axis < 0) ? (in_dims.size() + axis) : axis;

int64_t numel = indices->numel();

// Special case for full sort, speedup ~190x.
if (axis == -1 || axis + 1 == in_dims.size()) {
const int64_t input_height = framework::product(
framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
const int64_t input_width = in_dims[in_dims.size() - 1];
const auto& dev_ctx = ctx.cuda_device_context();
ArgFullAssign<T, int64_t>(dev_ctx, dO, indices, dX, input_height,
input_width);
} else {
      // Not sorting along the last axis: transpose `axis` to the innermost
      // position, scatter there, then transpose back.
std::vector<int> trans;
for (int i = 0; i < axis; i++) {
trans.push_back(i);
}
trans.push_back(in_dims.size() - 1);
for (int i = axis + 1; i < in_dims.size() - 1; i++) {
trans.push_back(i);
}
trans.push_back(axis);
framework::DDim trans_dims(in_dims);
      for (size_t i = 0; i < trans.size(); i++) {
trans_dims[i] = in_dims[trans[i]];
}

Tensor trans_dO;
trans_dO.mutable_data<T>(trans_dims, ctx.GetPlace());
Tensor trans_ind;
trans_ind.mutable_data<int64_t>(trans_dims, ctx.GetPlace());
int ndims = trans.size();
const auto& dev_ctx = ctx.cuda_device_context();
// Do transpose
TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, *dO,
&trans_dO, trans);
TransCompute<platform::CUDADeviceContext, int64_t>(
ndims, dev_ctx, *indices, &trans_ind, trans);

const int64_t input_height = framework::product(
framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
const int64_t input_width = trans_dims[trans_dims.size() - 1];

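      // Scatter into a buffer laid out in the transposed order, then
      // transpose the result back into dX.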
Tensor tmp_out;
tmp_out.mutable_data<T>(trans_dims, ctx.GetPlace());

ArgFullAssign<T, int64_t>(dev_ctx, &trans_dO, &trans_ind, &tmp_out,
input_height, input_width);

// transpose back
TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, tmp_out, dX,
trans);
}
}
};

} // namespace operators
} // namespace paddle

@@ -243,3 +360,9 @@ REGISTER_OP_CUDA_KERNEL(
paddle::operators::ArgsortOpCUDAKernel<int>,
paddle::operators::ArgsortOpCUDAKernel<int64_t>,
paddle::operators::ArgsortOpCUDAKernel<paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL(
argsort_grad, paddle::operators::ArgsortGradOpCUDAKernel<float>,
paddle::operators::ArgsortGradOpCUDAKernel<double>,
paddle::operators::ArgsortGradOpCUDAKernel<int>,
paddle::operators::ArgsortGradOpCUDAKernel<int64_t>,
paddle::operators::ArgsortGradOpCUDAKernel<paddle::platform::float16>);
97 changes: 97 additions & 0 deletions paddle/fluid/operators/argsort_op.h
@@ -68,6 +68,31 @@ static void FullSort(Type input_height, Type input_width, int input_dim,
}
}
}

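// CPU scatter for the argsort gradient: for each row i, writes
// t_out[i][indices[i][j]] = input[i][j], where input is the (possibly
// transposed) upstream gradient.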
template <typename T, typename Type>
static void FullAssign(Type input_height, Type input_width, int input_dim,
const framework::Tensor* input,
const framework::Tensor* indices, T* t_out) {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (Type i = 0; i < input_height; ++i) {
if (input_dim == 1) {
auto e_input = EigenVector<T>::Flatten(*input);
auto e_indices = EigenVector<Type>::Flatten(*indices);
for (Type j = 0; j < input_width; ++j) {
        t_out[i * input_width + e_indices(j)] = e_input(j);
}
} else {
auto e_input = EigenMatrix<T>::Reshape(*input, input_dim - 1);
auto e_indices = EigenMatrix<Type>::Reshape(*indices, input_dim - 1);
for (Type j = 0; j < input_width; ++j) {
        t_out[i * input_width + e_indices(i, j)] = e_input(i, j);
}
}
}
}

template <typename DeviceContext, typename T>
class ArgsortKernel : public framework::OpKernel<T> {
public:
@@ -142,5 +167,77 @@ class ArgsortKernel : public framework::OpKernel<T> {
}
};

template <typename DeviceContext, typename T>
class ArgsortGradientKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* indices = ctx.Input<Tensor>("Indices");
auto* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
int axis = ctx.Attr<int>("axis");

auto in_dims = indices->dims();
axis = (axis < 0) ? (in_dims.size() + axis) : axis;

dX->mutable_data<T>(ctx.GetPlace());
auto dxt = framework::EigenVector<T>::Flatten(*dX);
auto& place = *ctx.template device_context<platform::CPUDeviceContext>()
.eigen_device();
dxt.device(place) = dxt.constant(static_cast<T>(0));
if (dO->numel() == 0) return;

// Do full assign
if (axis == -1 || axis + 1 == in_dims.size()) {
const int64_t input_height = framework::product(
framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
const int64_t input_width = in_dims[in_dims.size() - 1];

FullAssign<T, int64_t>(input_height, input_width, in_dims.size(), dO,
indices, dX->data<T>());
} else {
      // Not the last axis: transpose `axis` to the innermost position first.
std::vector<int> trans;
for (int i = 0; i < axis; i++) {
trans.push_back(i);
}
trans.push_back(in_dims.size() - 1);
for (int i = axis + 1; i < in_dims.size() - 1; i++) {
trans.push_back(i);
}
trans.push_back(axis);
framework::DDim trans_dims(in_dims);
for (size_t i = 0; i < trans.size(); i++) {
trans_dims[i] = in_dims[trans[i]];
}

Tensor trans_dO;
trans_dO.mutable_data<T>(trans_dims, ctx.GetPlace());
Tensor trans_ind;
trans_ind.mutable_data<int64_t>(trans_dims, ctx.GetPlace());
int ndims = trans.size();
auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
// Do transpose
TransCompute<platform::CPUDeviceContext, T>(ndims, dev_ctx, *dO,
&trans_dO, trans);
TransCompute<platform::CPUDeviceContext, int64_t>(
ndims, dev_ctx, *indices, &trans_ind, trans);

const int64_t input_height = framework::product(
framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
const int64_t input_width = trans_dims[trans_dims.size() - 1];

Tensor tmp_out;
T* t_out = tmp_out.mutable_data<T>(trans_dims, ctx.GetPlace());

FullAssign<T, int64_t>(input_height, input_width, in_dims.size(),
&trans_dO, &trans_ind, t_out);

// transpose back
TransCompute<platform::CPUDeviceContext, T>(ndims, dev_ctx, tmp_out, dX,
trans);
}
}
};

} // namespace operators
} // namespace paddle
18 changes: 17 additions & 1 deletion python/paddle/fluid/tests/unittests/test_argsort_op.py
100755 → 100644
@@ -48,14 +48,17 @@ def init_axis(self):
self.axis = -1

def init_datatype(self):
self.dtype = "float32"
self.dtype = "float64"

def init_direction(self):
self.descending = False

def test_check_output(self):
self.check_output()

def test_check_grad(self):
self.check_grad(['X'], 'Out')


class TestArgsortOpAxis0(TestArgsortOp):
def init_axis(self):
@@ -146,5 +149,18 @@ def init_direction(self):
self.descending = True


class TestArgsortOpFP32Axis(TestArgsortOp):
def init_datatype(self):
self.dtype = "float32"


class TestArgsortOpFP32DescendingAxis(TestArgsortOp):
def init_datatype(self):
self.dtype = "float32"

def init_direction(self):
self.descending = True

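# Note: the default test dtype above moves to float64, likely for numeric
# gradient-check precision; these FP32 subclasses keep float32 coverage.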

if __name__ == "__main__":
unittest.main()
