Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add backward gradient computation for op argsort #22203

Merged
merged 2 commits into from
Jan 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 50 additions & 5 deletions paddle/fluid/operators/argsort_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/argsort_op.h"
#include <memory>

namespace paddle {
namespace operators {
Expand All @@ -21,7 +22,7 @@ class ArgsortOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext *ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ArgsortOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
Expand Down Expand Up @@ -49,6 +50,24 @@ class ArgsortOp : public framework::OperatorWithKernel {
}
};

class ArgsortGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*-->*/ framework::GradVarName("X"));
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};

class ArgsortOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
Expand Down Expand Up @@ -83,16 +102,42 @@ Output(Indices) gives the sorted order along the given axis Attr(axis).
}
};

template <typename T>
class ArgsortGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
op->SetType("argsort_grad");
op->SetInput("Indices", this->Output("Indices"));
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
return op;
}
};

DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ArgsortGradNoNeedBufferVarInference, "X");

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(
argsort, ops::ArgsortOp, ops::ArgsortOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(argsort, ops::ArgsortOp, ops::ArgsortOpMaker,
ops::ArgsortGradOpMaker<paddle::framework::OpDesc>,
ops::ArgsortGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(argsort_grad, ops::ArgsortGradOp,
ops::ArgsortGradNoNeedBufferVarInference);
REGISTER_OP_CPU_KERNEL(argsort,
ops::ArgsortKernel<paddle::platform::CPUPlace, float>,
ops::ArgsortKernel<paddle::platform::CPUPlace, double>,
ops::ArgsortKernel<paddle::platform::CPUPlace, int>,
ops::ArgsortKernel<paddle::platform::CPUPlace, int64_t>);
REGISTER_OP_CPU_KERNEL(
argsort_grad, ops::ArgsortGradientKernel<paddle::platform::CPUPlace, float>,
ops::ArgsortGradientKernel<paddle::platform::CPUPlace, double>,
ops::ArgsortGradientKernel<paddle::platform::CPUPlace, int>,
ops::ArgsortGradientKernel<paddle::platform::CPUPlace, int64_t>);
123 changes: 123 additions & 0 deletions paddle/fluid/operators/argsort_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,19 @@ static __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
}
}

template <typename T, typename IndType>
static __global__ void FillGrad(const T* dO, const IndType* indices, T* dX,
IndType num_rows, IndType num_cols) {
int col_id = threadIdx.x;
int row_id = blockIdx.x;

for (IndType j = row_id; j < num_rows; j += gridDim.x) {
for (IndType i = col_id; i < num_cols; i += blockDim.x) {
dX[j * num_cols + indices[j * num_cols + i]] = dO[j * num_cols + i];
}
}
}

// Sort by flag descending, True: descending. False: Ascending.
// Default is false.
template <typename T, typename IndType>
Expand Down Expand Up @@ -160,6 +173,35 @@ void ArgFullSort(const platform::CUDADeviceContext& ctx, const Tensor* input,
temp_storage_bytes, cudaGetErrorString(err));
}

template <typename T, typename IndType>
void ArgFullAssign(const platform::CUDADeviceContext& ctx, const Tensor* dO,
const Tensor* indices, Tensor* dX, const IndType num_rows,
const IndType num_cols) {
auto cu_stream = ctx.stream();

auto ComputeBlockSize = [](IndType col) {
if (col > 512)
return 1024;
else if (col > 256 && col <= 512)
return 512;
else if (col > 128 && col <= 256)
return 256;
else if (col > 64 && col <= 128)
return 128;
else
return 64;
};

int block_size = ComputeBlockSize(num_cols);

int maxGridDimX = ctx.GetCUDAMaxGridDimSize().x;
// actually, int num_rows < max_grid_size
int grid_size = num_rows < maxGridDimX ? num_rows : maxGridDimX;
FillGrad<<<grid_size, block_size, 0, cu_stream>>>(
dO->data<T>(), indices->data<IndType>(), dX->data<T>(), num_rows,
num_cols);
}

template <typename T>
class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
public:
Expand Down Expand Up @@ -234,6 +276,81 @@ class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
}
};

template <typename T>
class ArgsortGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* indices = ctx.Input<Tensor>("Indices");
auto* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
int axis = ctx.Attr<int>("axis");

dX->mutable_data<T>(ctx.GetPlace());
auto dxt = framework::EigenVector<T>::Flatten(*dX);
auto& place = *ctx.template device_context<platform::CUDADeviceContext>()
.eigen_device();
dxt.device(place) = dxt.constant(static_cast<T>(0));
if (dO->numel() == 0) return;

auto in_dims = indices->dims();
axis = (axis < 0) ? (in_dims.size() + axis) : axis;

int64_t numel = indices->numel();

// Special case for full sort, speedup ~190x.
if (axis == -1 || axis + 1 == in_dims.size()) {
const int64_t input_height = framework::product(
framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
const int64_t input_width = in_dims[in_dims.size() - 1];
const auto& dev_ctx = ctx.cuda_device_context();
ArgFullAssign<T, int64_t>(dev_ctx, dO, indices, dX, input_height,
input_width);
} else {
// if not full sort, do transpose first
std::vector<int> trans;
for (int i = 0; i < axis; i++) {
trans.push_back(i);
}
trans.push_back(in_dims.size() - 1);
for (int i = axis + 1; i < in_dims.size() - 1; i++) {
trans.push_back(i);
}
trans.push_back(axis);
framework::DDim trans_dims(in_dims);
for (int i = 0; i < trans.size(); i++) {
trans_dims[i] = in_dims[trans[i]];
}

Tensor trans_dO;
trans_dO.mutable_data<T>(trans_dims, ctx.GetPlace());
Tensor trans_ind;
trans_ind.mutable_data<int64_t>(trans_dims, ctx.GetPlace());
int ndims = trans.size();
const auto& dev_ctx = ctx.cuda_device_context();
// Do transpose
TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, *dO,
&trans_dO, trans);
TransCompute<platform::CUDADeviceContext, int64_t>(
ndims, dev_ctx, *indices, &trans_ind, trans);

const int64_t input_height = framework::product(
framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
const int64_t input_width = trans_dims[trans_dims.size() - 1];

Tensor tmp_out;
tmp_out.mutable_data<T>(trans_dims, ctx.GetPlace());

ArgFullAssign<T, int64_t>(dev_ctx, &trans_dO, &trans_ind, &tmp_out,
input_height, input_width);

// transpose back
TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, tmp_out, dX,
trans);
return;
}
}
};

} // namespace operators
} // namespace paddle

Expand All @@ -243,3 +360,9 @@ REGISTER_OP_CUDA_KERNEL(
paddle::operators::ArgsortOpCUDAKernel<int>,
paddle::operators::ArgsortOpCUDAKernel<int64_t>,
paddle::operators::ArgsortOpCUDAKernel<paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL(
argsort_grad, paddle::operators::ArgsortGradOpCUDAKernel<float>,
paddle::operators::ArgsortGradOpCUDAKernel<double>,
paddle::operators::ArgsortGradOpCUDAKernel<int>,
paddle::operators::ArgsortGradOpCUDAKernel<int64_t>,
paddle::operators::ArgsortGradOpCUDAKernel<paddle::platform::float16>);
97 changes: 97 additions & 0 deletions paddle/fluid/operators/argsort_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,31 @@ static void FullSort(Type input_height, Type input_width, int input_dim,
}
}
}

template <typename T, typename Type>
static void FullAssign(Type input_height, Type input_width, int input_dim,
const framework::Tensor* input,
const framework::Tensor* indices, T* t_out) {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (Type i = 0; i < input_height; ++i) {
if (input_dim == 1) {
auto e_input = EigenVector<T>::Flatten(*input);
auto e_indices = EigenVector<Type>::Flatten(*indices);
for (Type j = 0; j < input_width; ++j) {
t_out[i * input_width + e_indices(j)] = e_input(e_indices(j));
}
} else {
auto e_input = EigenMatrix<T>::Reshape(*input, input_dim - 1);
auto e_indices = EigenMatrix<Type>::Reshape(*indices, input_dim - 1);
for (Type j = 0; j < input_width; ++j) {
t_out[i * input_width + e_indices(i, j)] = e_input(i, e_indices(i, j));
}
}
}
}

template <typename DeviceContext, typename T>
class ArgsortKernel : public framework::OpKernel<T> {
public:
Expand Down Expand Up @@ -142,5 +167,77 @@ class ArgsortKernel : public framework::OpKernel<T> {
}
};

template <typename DeviceContext, typename T>
class ArgsortGradientKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* indices = ctx.Input<Tensor>("Indices");
auto* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
int axis = ctx.Attr<int>("axis");

auto in_dims = indices->dims();
axis = (axis < 0) ? (in_dims.size() + axis) : axis;

dX->mutable_data<T>(ctx.GetPlace());
auto dxt = framework::EigenVector<T>::Flatten(*dX);
auto& place = *ctx.template device_context<platform::CPUDeviceContext>()
.eigen_device();
dxt.device(place) = dxt.constant(static_cast<T>(0));
if (dO->numel() == 0) return;

// Do full assign
if (axis == -1 || axis + 1 == in_dims.size()) {
const int64_t input_height = framework::product(
framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
const int64_t input_width = in_dims[in_dims.size() - 1];

FullAssign<T, int64_t>(input_height, input_width, in_dims.size(), dO,
indices, dX->data<T>());
} else {
// If not full assign do transpose
std::vector<int> trans;
for (int i = 0; i < axis; i++) {
trans.push_back(i);
}
trans.push_back(in_dims.size() - 1);
for (int i = axis + 1; i < in_dims.size() - 1; i++) {
trans.push_back(i);
}
trans.push_back(axis);
framework::DDim trans_dims(in_dims);
for (size_t i = 0; i < trans.size(); i++) {
trans_dims[i] = in_dims[trans[i]];
}

Tensor trans_dO;
trans_dO.mutable_data<T>(trans_dims, ctx.GetPlace());
Tensor trans_ind;
trans_ind.mutable_data<int64_t>(trans_dims, ctx.GetPlace());
int ndims = trans.size();
auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
// Do transpose
TransCompute<platform::CPUDeviceContext, T>(ndims, dev_ctx, *dO,
&trans_dO, trans);
TransCompute<platform::CPUDeviceContext, int64_t>(
ndims, dev_ctx, *indices, &trans_ind, trans);

const int64_t input_height = framework::product(
framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
const int64_t input_width = trans_dims[trans_dims.size() - 1];

Tensor tmp_out;
T* t_out = tmp_out.mutable_data<T>(trans_dims, ctx.GetPlace());

FullAssign<T, int64_t>(input_height, input_width, in_dims.size(),
&trans_dO, &trans_ind, t_out);

// transpose back
TransCompute<platform::CPUDeviceContext, T>(ndims, dev_ctx, tmp_out, dX,
trans);
}
}
};

} // namespace operators
} // namespace paddle
18 changes: 17 additions & 1 deletion python/paddle/fluid/tests/unittests/test_argsort_op.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,17 @@ def init_axis(self):
self.axis = -1

def init_datatype(self):
self.dtype = "float32"
self.dtype = "float64"

def init_direction(self):
self.descending = False

def test_check_output(self):
self.check_output()

def test_check_grad(self):
self.check_grad(['X'], 'Out')


class TestArgsortOpAxis0(TestArgsortOp):
def init_axis(self):
Expand Down Expand Up @@ -146,5 +149,18 @@ def init_direction(self):
self.descending = True


class TestArgsortOpFP32Axis(TestArgsortOp):
def init_datatype(self):
self.dtype = "float32"


class TestArgsortOpFP32DescendingAxis(TestArgsortOp):
def init_datatype(self):
self.dtype = "float32"

def init_direction(self):
self.descending = True


if __name__ == "__main__":
unittest.main()