Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add paddle.gather for API2.0 #26455

Merged
merged 19 commits into from
Aug 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 129 additions & 0 deletions paddle/fluid/operators/gather.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/place.h"

Expand Down Expand Up @@ -158,5 +159,133 @@ void GPUGatherNd(const framework::ExecutionContext& context,
end_size);
}

// Grid-stride gather kernel. Viewing (row-major) input as
// [inner_dim_size, input_index_dim_size, outer_dim_size] and out as
// [inner_dim_size, out_index_dim_size, outer_dim_size], each output element
// out[inner][j][outer] copies input[inner][index[j]][outer].
// `size` is out->numel(); `index` holds out_index_dim_size entries.
// NOTE: inner_dim_size is kept for signature compatibility but is not read —
// the inner position is recovered from the flat thread index instead.
template <typename T, typename U>
__global__ void GatherGPUKernel(const T* input, const U* index, T* out,
                                int outer_dim_size, int inner_dim_size,
                                int out_index_dim_size,
                                int input_index_dim_size, int size) {
  const int stride = blockDim.x * gridDim.x;
  const int out_block = outer_dim_size * out_index_dim_size;
  const int in_block = outer_dim_size * input_index_dim_size;
  for (int tid = blockDim.x * blockIdx.x + threadIdx.x; tid < size;
       tid += stride) {
    // Decompose the flat output offset into (inner, gathered, outer).
    const int inner = tid / out_block;
    const int rest = tid % out_block;
    const int gathered = rest / outer_dim_size;
    const int outer = rest % outer_dim_size;
    // Map the gathered position through the index tensor to the source slice.
    const int src =
        inner * in_block + index[gathered] * outer_dim_size + outer;
    out[tid] = input[src];
  }
}

// Grid-stride scatter-add kernel for the gather gradient. `input` is the
// upstream gradient laid out as
// [inner_dim_size, input_index_dim_size, outer_dim_size]; `out` is the input
// gradient laid out as [inner_dim_size, out_index_dim_size, outer_dim_size].
// Each element input[inner][j][outer] is atomically accumulated into
// out[inner][index[j]][outer] (atomics because several j may map to the same
// index value). `size` is input->numel().
// NOTE: inner_dim_size is kept for signature compatibility but is not read.
template <typename T, typename U>
__global__ void GatherGradGPUKernel(const T* input, const U* index, T* out,
                                    int outer_dim_size, int inner_dim_size,
                                    int input_index_dim_size,
                                    int out_index_dim_size, int size) {
  const int stride = blockDim.x * gridDim.x;
  const int in_block = outer_dim_size * input_index_dim_size;
  const int out_block = outer_dim_size * out_index_dim_size;
  for (int tid = blockDim.x * blockIdx.x + threadIdx.x; tid < size;
       tid += stride) {
    // Decompose the flat upstream-gradient offset into (inner, j, outer).
    const int inner = tid / in_block;
    const int rest = tid % in_block;
    const int gathered = rest / outer_dim_size;
    const int outer = rest % outer_dim_size;
    const int dst =
        inner * out_block + index[gathered] * outer_dim_size + outer;
    paddle::platform::CudaAtomicAdd(out + dst, input[tid]);
  }
}

// GPU implementation of gather along a runtime axis: out = input.gather(index,
// axis). `axis` is a 1-element tensor (may live on device) holding the axis;
// `index` holds the positions to gather along that axis. Resizes `out` to
// input's shape with dim[axis] replaced by index->numel(), then launches
// GatherGPUKernel on the context's stream.
// T: data type, U: index element type, V: axis element type.
template <typename T, typename U, typename V>
void GatherV2CUDAFunction(const Tensor* input, const Tensor* index,
                          const Tensor* axis, Tensor* out,
                          const paddle::platform::Place& place,
                          const framework::ExecutionContext& ctx) {
  int axis_size = axis->numel();
  int index_size = index->numel();
  auto input_dim = input->dims();
  auto* input_data = input->data<T>();
  auto* index_data = index->data<U>();

  if (input->numel() == 0) return;
  PADDLE_ENFORCE_EQ(axis_size, 1,
                    platform::errors::InvalidArgument(
                        "Axis size should be 1, but received %d", axis_size));
  Tensor cpu_axis;
  // Must be the synchronous copy: TensorCopy is asynchronous, so reading
  // cpu_axis immediately below would race with the device-to-host transfer.
  framework::TensorCopySync(*axis, platform::CPUPlace(), &cpu_axis);
  int axis_index = cpu_axis.data<V>()[0];
  int index_dim_size = input_dim[axis_index];

  // inner = product of dims before the axis, outer = product of dims after it.
  int inner_dim_size = 1;
  int outer_dim_size = 1;
  std::vector<int> out_dim_vec;

  for (int i = 0; i < axis_index; i++) {
    inner_dim_size *= input_dim[i];
    out_dim_vec.push_back(input_dim[i]);
  }
  // The gathered axis takes the size of the index tensor.
  out_dim_vec.push_back(index_size);
  for (int i = axis_index + 1; i < input_dim.size(); i++) {
    outer_dim_size *= input_dim[i];
    out_dim_vec.push_back(input_dim[i]);
  }
  auto out_dim = framework::make_ddim(out_dim_vec);

  out->Resize(out_dim);
  auto* out_data = out->mutable_data<T>(place);
  int out_size = out->numel();

  // One thread per output element; grid-stride loop in the kernel covers any
  // remainder.
  int threads = 512;
  int grid = (out_size + threads - 1) / threads;
  auto stream = ctx.cuda_device_context().stream();
  GatherGPUKernel<T, U><<<grid, threads, 0, stream>>>(
      input_data, index_data, out_data, outer_dim_size, inner_dim_size,
      index_size, index_dim_size, out_size);
}

// GPU gradient of gather along a runtime axis. `input` is the upstream
// gradient (shape of the forward output), `out` is the input gradient (shape
// of the forward input, already resized by the caller). Zero-fills `out`,
// then scatter-adds every upstream element through `index` with
// GatherGradGPUKernel.
// T: data type, U: index element type, V: axis element type.
template <typename T, typename U, typename V>
void GatherV2GradCUDAFunction(const Tensor* input, const Tensor* index,
                              const Tensor* axis, Tensor* out,
                              const paddle::platform::Place& place,
                              const framework::ExecutionContext& ctx) {
  auto* index_data = index->data<U>();

  int axis_size = axis->numel();
  int input_size = input->numel();
  auto input_dim = input->dims();
  auto* input_data = input->data<T>();

  if (input->numel() == 0) return;
  PADDLE_ENFORCE_EQ(axis_size, 1,
                    platform::errors::InvalidArgument(
                        "Axis size should be 1, but received %d", axis_size));
  Tensor cpu_axis;
  // Must be the synchronous copy: TensorCopy is asynchronous, so reading
  // cpu_axis immediately below would race with the device-to-host transfer.
  framework::TensorCopySync(*axis, platform::CPUPlace(), &cpu_axis);
  int axis_index = cpu_axis.data<V>()[0];
  // Along the axis, the upstream gradient has one slice per index entry.
  int input_index_dim_size = input_dim[axis_index];

  // inner = product of dims before the axis, outer = product of dims after it.
  int inner_dim_size = 1;
  int outer_dim_size = 1;

  for (int i = 0; i < axis_index; i++) {
    inner_dim_size *= input_dim[i];
  }
  for (int i = axis_index + 1; i < input_dim.size(); i++) {
    outer_dim_size *= input_dim[i];
  }

  auto* out_data = out->mutable_data<T>(place);
  auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
  auto out_dim = out->dims();
  int out_index_dim_size = out_dim[axis_index];
  // Gradients accumulate (several index entries may coincide), so start from
  // zero.
  operators::math::set_constant(*dev_ctx, out, 0.0);

  // One thread per upstream-gradient element.
  int threads = 512;
  int grid = (input_size + threads - 1) / threads;
  auto stream = ctx.cuda_device_context().stream();
  GatherGradGPUKernel<T, U><<<grid, threads, 0, stream>>>(
      input_data, index_data, out_data, outer_dim_size, inner_dim_size,
      input_index_dim_size, out_index_dim_size, input_size);
}
} // namespace operators
} // namespace paddle
107 changes: 107 additions & 0 deletions paddle/fluid/operators/gather.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ limitations under the License. */
#pragma once
#include <memory.h>
#include <cstring>
#include <vector>

#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/place.h"

namespace paddle {
Expand Down Expand Up @@ -124,5 +126,110 @@ void CPUGatherNd(const platform::DeviceContext& ctx, const Tensor& input,
}
}

// CPU implementation of gather along a runtime axis: out = input.gather(index,
// axis). `axis` is a 1-element CPU tensor holding the axis; `index` holds the
// positions to gather along that axis. Resizes `out` to input's shape with
// dim[axis] replaced by index->numel().
// T: data type, U: index element type, V: axis element type.
template <typename T, typename U, typename V>
void GatherV2Function(const Tensor* input, const Tensor* index,
                      const Tensor* axis, Tensor* out,
                      const paddle::platform::Place& place) {
  auto* axis_data = axis->data<V>();
  auto* index_data = index->data<U>();

  int axis_size = axis->numel();
  int index_size = index->numel();
  int input_size = input->numel();
  auto input_dim = input->dims();
  auto* input_data = input->data<T>();

  if (input->numel() == 0) return;
  PADDLE_ENFORCE_EQ(axis_size, 1,
                    platform::errors::InvalidArgument(
                        "Axis size should be 1, but received %d", axis_size));
  int axis_index = axis_data[0];

  int input_index_dim_size = input_dim[axis_index];
  // Validate every index up front: both bounds matter — a negative index
  // would read out of range just like an index >= dim size.
  for (int i = 0; i < index_size; i++) {
    PADDLE_ENFORCE_GE(index_data[i], 0,
                      platform::errors::InvalidArgument(
                          "The element of Index must be non-negative, but "
                          "received index element which is %d in the %d index.",
                          index_data[i], i));
    PADDLE_ENFORCE_LT(index_data[i], input_index_dim_size,
                      platform::errors::InvalidArgument(
                          "The element of Index must be less than the size of "
                          "input dim size of axis which is %d, but received "
                          "index element which is %d in the %d index.",
                          input_index_dim_size, index_data[i], i));
  }

  // inner = product of dims before the axis, outer = product of dims after it.
  int inner_dim_size = 1;
  int outer_dim_size = 1;
  std::vector<int> out_dim_vec;

  for (int i = 0; i < axis_index; i++) {
    inner_dim_size *= input_dim[i];
    out_dim_vec.push_back(input_dim[i]);
  }
  // The gathered axis takes the size of the index tensor.
  out_dim_vec.push_back(index_size);
  for (int i = axis_index + 1; i < input_dim.size(); i++) {
    outer_dim_size *= input_dim[i];
    out_dim_vec.push_back(input_dim[i]);
  }
  auto out_dim = framework::make_ddim(out_dim_vec);

  out->Resize(out_dim);
  auto* out_data = out->mutable_data<T>(place);

  // Walk the output in row-major order [inner, index, outer]; for each
  // gathered slice, the source offset replaces j with index_data[j].
  // (input_size / inner_dim_size) == input_index_dim_size * outer_dim_size,
  // i.e. the stride of one inner slice of the input.
  int out_index = 0;
  for (int i = 0; i < inner_dim_size; i++) {
    for (int j = 0; j < index_size; j++) {
      for (int k = 0; k < outer_dim_size; k++) {
        int index = k + index_data[j] * outer_dim_size +
                    (i * input_size / inner_dim_size);
        out_data[out_index] = input_data[index];
        out_index++;
      }
    }
  }
}

// CPU gradient of gather along a runtime axis. `input` is the upstream
// gradient (shape of the forward output, so dim[axis] == index->numel()),
// `out` is the input gradient (shape of the forward input, already resized by
// the caller). Zero-fills `out`, then accumulates each upstream slice into
// the slice selected by `index` (+= because index entries may repeat).
// T: data type, U: index element type, V: axis element type.
template <typename T, typename U, typename V>
void GatherV2GradFunction(const Tensor* input, const Tensor* index,
                          const Tensor* axis, Tensor* out,
                          const paddle::platform::Place& place) {
  auto* axis_data = axis->data<V>();
  auto* index_data = index->data<U>();

  int axis_size = axis->numel();
  auto input_dim = input->dims();
  auto* input_data = input->data<T>();

  if (input->numel() == 0) return;
  PADDLE_ENFORCE_EQ(axis_size, 1,
                    platform::errors::InvalidArgument(
                        "Axis size should be 1, but received %d", axis_size));
  int axis_index = axis_data[0];
  // Along the axis, the upstream gradient has one slice per index entry.
  int input_index_dim_size = input_dim[axis_index];

  // inner = product of dims before the axis, outer = product of dims after it.
  int inner_dim_size = 1;
  int outer_dim_size = 1;

  for (int i = 0; i < axis_index; i++) {
    inner_dim_size *= input_dim[i];
  }
  for (int i = axis_index + 1; i < input_dim.size(); i++) {
    outer_dim_size *= input_dim[i];
  }

  auto* out_data = out->mutable_data<T>(place);
  auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
  auto out_dim = out->dims();
  int out_index_dim_size = out_dim[axis_index];
  // Gradients accumulate, so start from zero.
  operators::math::set_constant(*dev_ctx, out, 0.0);

  for (int i = 0; i < inner_dim_size; i++) {
    for (int j = 0; j < input_index_dim_size; j++) {
      for (int k = 0; k < outer_dim_size; k++) {
        // BUGFIX: the read offset must include the inner-dimension stride
        // (i * input_index_dim_size * outer_dim_size); the original read only
        // j * outer_dim_size + k, so every inner slice (i > 0) wrongly
        // accumulated the upstream gradient of slice 0. This mirrors the
        // index arithmetic of the CUDA GatherGradGPUKernel above.
        int in_index = k + j * outer_dim_size +
                       i * input_index_dim_size * outer_dim_size;
        int out_index = k + index_data[j] * outer_dim_size +
                        i * outer_dim_size * out_index_dim_size;
        out_data[out_index] += input_data[in_index];
      }
    }
  }
}

} // namespace operators
} // namespace paddle
5 changes: 5 additions & 0 deletions paddle/fluid/operators/gather_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ class GatherOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
AddInput("X", "The source input of gather op");
AddInput("Index", "The index input of gather op");
AddInput("Axis",
"The Tensor which contains the axis that we do gather operation.")
.AsDispensable();
AddOutput("Out", "The output of gather op");
AddAttr<bool>(
"overwrite",
Expand Down Expand Up @@ -120,6 +123,8 @@ class GatherGradOpMaker : public framework::SingleGradOpMaker<T> {
void Apply(GradOpPtr<T> op) const override {
op->SetType("gather_grad");
op->SetInput("Index", this->Input("Index"));
op->SetInput("Axis", this->Input("Axis"));

op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
Expand Down
55 changes: 55 additions & 0 deletions paddle/fluid/operators/gather_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,33 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> {
auto *index = ctx.Input<Tensor>("Index");
auto *output = ctx.Output<Tensor>("Out");

if (ctx.HasInput("Axis")) {
const Tensor *axis = ctx.Input<Tensor>("Axis");
const auto &index_type = index->type();
const auto &axis_type = axis->type();
auto place = ctx.GetPlace();
if (index_type == framework::proto::VarType::INT32 &&
axis_type == framework::proto::VarType::INT32) {
GatherV2CUDAFunction<T, int32_t, int32_t>(x, index, axis, output, place,
ctx);
}
if (index_type == framework::proto::VarType::INT32 &&
axis_type == framework::proto::VarType::INT64) {
GatherV2CUDAFunction<T, int32_t, int64_t>(x, index, axis, output, place,
ctx);
}
if (index_type == framework::proto::VarType::INT64 &&
axis_type == framework::proto::VarType::INT32) {
GatherV2CUDAFunction<T, int64_t, int32_t>(x, index, axis, output, place,
ctx);
}
if (index_type == framework::proto::VarType::INT64 &&
axis_type == framework::proto::VarType::INT64) {
GatherV2CUDAFunction<T, int64_t, int64_t>(x, index, axis, output, place,
ctx);
}
return;
}
output->mutable_data<T>(ctx.GetPlace());
if (x->numel() == 0) return;
const auto &index_type = index->type();
Expand Down Expand Up @@ -64,6 +91,34 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));

if (ctx.HasInput("Axis")) {
const Tensor *axis = ctx.Input<Tensor>("Axis");
const auto &index_type = index->type();
const auto &axis_type = axis->type();
auto place = ctx.GetPlace();
if (index_type == framework::proto::VarType::INT32 &&
axis_type == framework::proto::VarType::INT32) {
GatherV2GradCUDAFunction<T, int32_t, int32_t>(dO, index, axis, dX,
place, ctx);
}
if (index_type == framework::proto::VarType::INT32 &&
axis_type == framework::proto::VarType::INT64) {
GatherV2GradCUDAFunction<T, int32_t, int64_t>(dO, index, axis, dX,
place, ctx);
}
if (index_type == framework::proto::VarType::INT64 &&
axis_type == framework::proto::VarType::INT32) {
GatherV2GradCUDAFunction<T, int64_t, int32_t>(dO, index, axis, dX,
place, ctx);
}
if (index_type == framework::proto::VarType::INT64 &&
axis_type == framework::proto::VarType::INT64) {
GatherV2GradCUDAFunction<T, int64_t, int64_t>(dO, index, axis, dX,
place, ctx);
}
return;
}

dX->mutable_data<T>(ctx.GetPlace());
auto dxt = framework::EigenVector<T>::Flatten(*dX);
auto &place = *ctx.template device_context<platform::CUDADeviceContext>()
Expand Down
Loading