From 3d1bf95691824890f8df9cb5b5abcf87347b91c8 Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Tue, 18 Aug 2020 04:04:11 +0000 Subject: [PATCH 01/17] refine the doc --- python/paddle/tensor/manipulation.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 9e2b7286ba677..09e232a73a695 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -658,10 +658,8 @@ def unsqueeze(x, axis, name=None): return layers.unsqueeze(x, axis, name) -def gather(input, index, overwrite=True): +def gather(x, index, axis=None, name=None): """ - :alias_main: paddle.gather - :alias: paddle.gather,paddle.tensor.gather,paddle.tensor.manipulation.gather **Gather Layer** @@ -689,19 +687,13 @@ def gather(input, index, overwrite=True): Out = [[3, 4], [5, 6]] Args: - input (Variable): The source input tensor with rank>=1. Supported data type is + input (Tensor): The source input tensor with rank>=1. Supported data type is int32, int64, float32, float64 and uint8 (only for CPU), float16 (only for GPU). - index (Variable): The index input tensor with rank=1. Data type is int32 or int64. - overwrite (bool, optional): The mode that updating the grad when has same index. - If True, use the overwrite mode to update the grad of the same index, - if False, use the accumulate mode to update the grad of the same index. - Default value is True. - - + index (Tensor): The index input tensor with rank=1. Data type is int32 or int64. Returns: - output (Variable): The output is a tensor with the same rank as input. + output (Tensor): The output is a tensor with the same rank as input. Examples: From 5351674d319b8a0c027279f50cd45683f5aca6b5 Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Tue, 18 Aug 2020 12:45:23 +0000 Subject: [PATCH 02/17] refine the code test=develop --- paddle/fluid/operators/gather_v2_op.cc | 125 +++++++++++++++++++++++ paddle/fluid/operators/gather_v2_op.cu | 117 ++++++++++++++++++++++ paddle/fluid/operators/gather_v2_op.h | 131 +++++++++++++++++++++++++ 3 files changed, 373 insertions(+) create mode 100644 paddle/fluid/operators/gather_v2_op.cc create mode 100644 paddle/fluid/operators/gather_v2_op.cu create mode 100644 paddle/fluid/operators/gather_v2_op.h diff --git a/paddle/fluid/operators/gather_v2_op.cc b/paddle/fluid/operators/gather_v2_op.cc new file mode 100644 index 0000000000000..620760671f3d7 --- /dev/null +++ b/paddle/fluid/operators/gather_v2_op.cc @@ -0,0 +1,125 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/gather_v2_op.h" +#include +#include +#include +#include "paddle/fluid/framework/ddim.h" + +namespace paddle { +namespace operators { + +class GatherV2Op : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, + platform::errors::InvalidArgument( + "Input(X) of GatherOp should not be null.")); + PADDLE_ENFORCE_EQ(ctx->HasInput("Index"), true, + platform::errors::InvalidArgument( + "Input(Index) of GatherOp should not be null.")); + PADDLE_ENFORCE_EQ(ctx->HasInput("Axis"), true, + platform::errors::InvalidArgument( + "Input(Axis) of GatherOp should not be null.")); + PADDLE_ENFORCE_EQ(ctx->HasOutput("Y"), true, + platform::errors::InvalidArgument( + "Output(Y) of GatherOp should not be null.")); + + auto index_dims = ctx->GetInputDim("Index"); + PADDLE_ENFORCE(index_dims.size() == 1 || + (index_dims.size() == 2 && index_dims[1] == 1)); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); + } +}; + +class GatherV2GradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + ctx->ShareLoD("X", /*-->*/ framework::GradVarName("X")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Y")), + ctx.device_context()); + } +}; + +class GatherV2OpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "The source input of gather op"); + AddInput("Index", "The index input of gather op"); + AddOutput("Y", "The output of gather op"); + AddInput("axis", + "The Tensor which contains the axis that we do gather operation."); + AddComment(R"DOC( +Y is obtained by gathering entries of the axis dimension +of X indexed by Index and concatenate them together. +)DOC"); + } +}; + +template +class GatherV2GradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("gather_v2_grad"); + op->SetInput("Index", this->Input("Index")); + op->SetInput("X", this->Input("X")); + op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetAttrMap(this->Attrs()); + } +}; + +DECLARE_NO_NEED_BUFFER_VARS_INFERER(GatherV2GradNoNeedBufferVarInferer, "X"); + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(gather_v2, ops::GatherV2Op, ops::GatherV2OpMaker, + ops::GatherV2GradOpMaker, + ops::GatherV2GradOpMaker); +REGISTER_OPERATOR(gather_v2_grad, ops::GatherV2GradOp, + ops::GatherV2GradNoNeedBufferVarInferer); +REGISTER_OP_CPU_KERNEL(gather_v2, ops::GatherV2OpKernel, + ops::GatherV2OpKernel, + ops::GatherV2OpKernel, + ops::GatherV2OpKernel, + ops::GatherV2OpKernel); +REGISTER_OP_CPU_KERNEL(gather_grad_v2, ops::GatherV2GradientOpKernel, + ops::GatherV2GradientOpKernel, + ops::GatherV2GradientOpKernel, + ops::GatherV2GradientOpKernel, + ops::GatherV2GradientOpKernel); diff --git a/paddle/fluid/operators/gather_v2_op.cu b/paddle/fluid/operators/gather_v2_op.cu new file mode 100644 index 0000000000000..26a97f810d448 --- /dev/null +++ b/paddle/fluid/operators/gather_v2_op.cu @@ -0,0 +1,117 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/operators/gather_v2_op.h" + +namespace paddle { +namespace operators { + +template +class GatherV2OpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + /* + PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true, + platform::errors::PreconditionNotMet( + "This kernel only runs on GPU device.")); + auto *x = ctx.Input("X"); + auto *index = ctx.Input("Index"); + auto *output = ctx.Output("Out"); + + output->mutable_data(ctx.GetPlace()); + if (x->numel() == 0) return; + const auto &index_type = index->type(); + bool index_type_match = index_type == framework::proto::VarType::INT32 || + index_type == framework::proto::VarType::INT64; + PADDLE_ENFORCE_EQ(index_type_match, true, + platform::errors::InvalidArgument( + "Index holds the wrong type, it holds [%s], but " + "desires to be [%s] or [%s].", + paddle::framework::DataTypeToString(index_type), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT32), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT64))); + if (index_type == framework::proto::VarType::INT32) { + GPUGatherNd(ctx, *x, *index, output); + } else if (index_type == framework::proto::VarType::INT64) { + GPUGatherNd(ctx, *x, *index, output); + } + */ + } +}; + +template +class GatherV2GradOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + /* + PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true, + platform::errors::PreconditionNotMet( + "This kernel only runs on GPU device.")); + auto *index = ctx.Input("Index"); + auto *dX = ctx.Output(framework::GradVarName("X")); + auto *dO = ctx.Input(framework::GradVarName("Out")); + + dX->mutable_data(ctx.GetPlace()); + auto dxt = framework::EigenVector::Flatten(*dX); + auto &place = *ctx.template device_context() + .eigen_device(); + dxt.device(place) = dxt.constant(static_cast(0)); + if (dO->numel() == 0) return; + + const auto &index_type = index->type(); + bool index_type_match = index_type == framework::proto::VarType::INT32 || + index_type == framework::proto::VarType::INT64; + + PADDLE_ENFORCE_EQ(index_type_match, true, + platform::errors::InvalidArgument( + "Index holds the wrong type, it holds [%s]," + "but desires to be [%s] or [%s].", + paddle::framework::DataTypeToString(index_type), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT32), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT64))); + + if (index_type == framework::proto::VarType::INT32) { + GPUScatterNdAdd(ctx, *dO, *index, dX); + } else if (index_type == framework::proto::VarType::INT64) { + GPUScatterNdAdd(ctx, *dO, *index, dX); + } + + */ + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +using CUDA = paddle::platform::CUDADeviceContext; +REGISTER_OP_CUDA_KERNEL(gather_v2, ops::GatherV2OpCUDAKernel, + ops::GatherV2OpCUDAKernel, + ops::GatherV2OpCUDAKernel, + ops::GatherV2OpCUDAKernel, + ops::GatherV2OpCUDAKernel, + ops::GatherV2OpCUDAKernel); + +REGISTER_OP_CUDA_KERNEL(gather_v2_grad, + ops::GatherV2GradOpCUDAKernel, + ops::GatherV2GradOpCUDAKernel, + ops::GatherV2GradOpCUDAKernel, + ops::GatherV2GradOpCUDAKernel, + ops::GatherV2GradOpCUDAKernel); diff --git a/paddle/fluid/operators/gather_v2_op.h b/paddle/fluid/operators/gather_v2_op.h new file mode 100644 index 0000000000000..2a5bcec118210 --- /dev/null +++ b/paddle/fluid/operators/gather_v2_op.h @@ -0,0 +1,131 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/gather.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class GatherV2OpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("X"); + auto* index = ctx.Input("Index"); + auto* axis = ctx.Input("Axis"); + auto* out = ctx.Output("Y"); + auto* axis_data = axis->data(); + auto* index_data = index->data(); + auto* input_data = input->data(); + + int axis_size = axis->numel(); + int index_size = index->numel(); + int input_size = input->numel(); + auto input_dim = input->dims(); + if (input->numel() == 0) return; + PADDLE_ENFORCE_EQ(axis_size, 1, + platform::errors::InvalidArgument( + "Axis size should be 1, but received %d", axis_size)); + int axis_index = axis; + auto index_dim_size = input_dim[axis]; + PADDLE_ENFORCE_LE( + index_size, index_dim_size, + platform::errors::InvalidArgument( + "The size that index should be less equal than the dim size of " + "input," + "but received index size:%d, the dim size of input %d.", + axis_size, index_dim_size)); + + int inner_dim_size = 1; + int outer_dim_size = 1; + std::vector out_dim_vec = {input_dim_size}; + + for (int i = 0; i < axis_index; i++) { + inner_dim_size *= input_dim[i]; + } + for (int i = axis_index + 1; i < input_dim.size(); i++) { + outer_dim_size *= input_dim[i]; + out_dim_vec.push_back(input_dim[i]); + } + auto out_dim = framework::make_ddim(out_dim_vec); + + out->Resize(out_dim); + auto* out_data = out->mutable_data(ctx.GetPlace()); + + for (int i = 0; i < inner_dim_size; i++) { + for (int j = 0; j < outer_dim_size) { + } + } + } +}; + +template +class GatherV2GradientOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + /* + PADDLE_ENFORCE_EQ( + platform::is_cpu_place(ctx.GetPlace()), true, + platform::errors::PreconditionNotMet("This kernel only runs on CPU.")); + + auto *index = ctx.Input("Index"); + auto *dX = ctx.Output(framework::GradVarName("X")); + auto *dO = ctx.Input(framework::GradVarName("Out")); + + dX->mutable_data(ctx.GetPlace()); + auto dxt = framework::EigenVector::Flatten(*dX); + auto &place = *ctx.template device_context() + .eigen_device(); + dxt.device(place) = dxt.constant(static_cast(0)); + if (dO->numel() == 0) return; + bool overwrite = ctx.Attr("overwrite"); + + const auto &index_type = index->type(); + bool index_type_match = index_type == framework::proto::VarType::INT32 || + index_type == framework::proto::VarType::INT64; + PADDLE_ENFORCE_EQ(index_type_match, true, + platform::errors::InvalidArgument( + "Index holds the wrong type, it holds [%s]," + "but desires to be [%s] or [%s].", + paddle::framework::DataTypeToString(index_type), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT32), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT64))); + if (index_type == framework::proto::VarType::INT32) { + if (overwrite) { + ScatterAssign(ctx.device_context(), *dO, *index, dX); + } else { + ScatterAssignAdd(ctx, *dO, *index, dX); + } + } else if (index_type == framework::proto::VarType::INT64) { + if (overwrite) { + ScatterAssign(ctx.device_context(), *dO, *index, dX); + } else { + ScatterAssignAdd(ctx, *dO, *index, dX); + } + } + + */ + } +}; + +} // namespace operators +} // namespace paddle From 57e14dc666a77010ad0a0c491c3b4e997a4e3bfb Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Tue, 18 Aug 2020 14:14:52 +0000 Subject: [PATCH 03/17] refine the code test=develop --- paddle/fluid/operators/gather_v2_op.h | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/gather_v2_op.h b/paddle/fluid/operators/gather_v2_op.h index 2a5bcec118210..bbdc202d8f1b4 100644 --- a/paddle/fluid/operators/gather_v2_op.h +++ b/paddle/fluid/operators/gather_v2_op.h @@ -43,8 +43,8 @@ class GatherV2OpKernel : public framework::OpKernel { PADDLE_ENFORCE_EQ(axis_size, 1, platform::errors::InvalidArgument( "Axis size should be 1, but received %d", axis_size)); - int axis_index = axis; - auto index_dim_size = input_dim[axis]; + int axis_index = axis_data[0]; + int index_dim_size = input_dim[axis_index]; PADDLE_ENFORCE_LE( index_size, index_dim_size, platform::errors::InvalidArgument( @@ -55,7 +55,7 @@ class GatherV2OpKernel : public framework::OpKernel { int inner_dim_size = 1; int outer_dim_size = 1; - std::vector out_dim_vec = {input_dim_size}; + std::vector out_dim_vec{index_dim_size}; for (int i = 0; i < axis_index; i++) { inner_dim_size *= input_dim[i]; @@ -70,7 +70,12 @@ class GatherV2OpKernel : public framework::OpKernel { auto* out_data = out->mutable_data(ctx.GetPlace()); for (int i = 0; i < inner_dim_size; i++) { - for (int j = 0; j < outer_dim_size) { + for (int j = 0; j < index_size; j++) { + for (int k = 0; k < outer_dim_size; k++) { + int index = k + index_data[j] * outer_dim_size + + (i * input_size / inner_dim_size); + out_data[i] = input_data[index]; + } } } } From b14cbe1a32ef83b24ecd6fc0c4729ff87d6406a8 Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Tue, 18 Aug 2020 16:15:58 +0000 Subject: [PATCH 04/17] refine the code test=develop --- paddle/fluid/operators/gather_v2_op.cc | 4 +- paddle/fluid/operators/gather_v2_op.h | 171 ++++++++++++------------- 2 files changed, 81 insertions(+), 94 deletions(-) diff --git a/paddle/fluid/operators/gather_v2_op.cc b/paddle/fluid/operators/gather_v2_op.cc index 620760671f3d7..c965edfd26f9f 100644 --- a/paddle/fluid/operators/gather_v2_op.cc +++ b/paddle/fluid/operators/gather_v2_op.cc @@ -77,7 +77,7 @@ class GatherV2OpMaker : public framework::OpProtoAndCheckerMaker { AddInput("X", "The source input of gather op"); AddInput("Index", "The index input of gather op"); AddOutput("Y", "The output of gather op"); - AddInput("axis", + AddInput("Axis", "The Tensor which contains the axis that we do gather operation."); AddComment(R"DOC( Y is obtained by gathering entries of the axis dimension @@ -96,9 +96,9 @@ class GatherV2GradOpMaker : public framework::SingleGradOpMaker { op->SetType("gather_v2_grad"); op->SetInput("Index", this->Input("Index")); op->SetInput("X", this->Input("X")); + op->SetInput("Axis", this->Input("Axis")); op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - op->SetAttrMap(this->Attrs()); } }; diff --git a/paddle/fluid/operators/gather_v2_op.h b/paddle/fluid/operators/gather_v2_op.h index bbdc202d8f1b4..0da4429875d93 100644 --- a/paddle/fluid/operators/gather_v2_op.h +++ b/paddle/fluid/operators/gather_v2_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once #include +#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/gather.h" @@ -22,61 +23,87 @@ namespace paddle { namespace operators { using Tensor = framework::Tensor; +template +void GatherV2Function(const Tensor* input, const Tensor* index, + const Tensor* axis, Tensor* out, + const paddle::platform::Place& place) { + auto* axis_data = axis->data(); + auto* index_data = index->data(); + + int axis_size = axis->numel(); + int index_size = index->numel(); + int input_size = input->numel(); + auto input_dim = input->dims(); + auto* input_data = input->data(); + + if (input->numel() == 0) return; + PADDLE_ENFORCE_EQ(axis_size, 1, + platform::errors::InvalidArgument( + "Axis size should be 1, but received %d", axis_size)); + int axis_index = axis_data[0]; + int index_dim_size = input_dim[axis_index]; + PADDLE_ENFORCE_LE( + index_size, index_dim_size, + platform::errors::InvalidArgument( + "The size that index should be less equal than the dim size of " + "input," + "but received index size:%d, the dim size of input %d.", + axis_size, index_dim_size)); + + int inner_dim_size = 1; + int outer_dim_size = 1; + std::vector out_dim_vec{index_dim_size}; + + for (int i = 0; i < axis_index; i++) { + inner_dim_size *= input_dim[i]; + } + for (int i = axis_index + 1; i < input_dim.size(); i++) { + outer_dim_size *= input_dim[i]; + out_dim_vec.push_back(input_dim[i]); + } + auto out_dim = framework::make_ddim(out_dim_vec); + + out->Resize(out_dim); + auto* out_data = out->mutable_data(place); + + for (int i = 0; i < inner_dim_size; i++) { + for (int j = 0; j < index_size; j++) { + for (int k = 0; k < outer_dim_size; k++) { + int index = k + index_data[j] * outer_dim_size + + (i * input_size / inner_dim_size); + out_data[i] = input_data[index]; + } + } + } +} template class GatherV2OpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* input = ctx.Input("X"); - auto* index = ctx.Input("Index"); - auto* axis = ctx.Input("Axis"); - auto* out = ctx.Output("Y"); - auto* axis_data = axis->data(); - auto* index_data = index->data(); - auto* input_data = input->data(); - - int axis_size = axis->numel(); - int index_size = index->numel(); - int input_size = input->numel(); - auto input_dim = input->dims(); - if (input->numel() == 0) return; - PADDLE_ENFORCE_EQ(axis_size, 1, - platform::errors::InvalidArgument( - "Axis size should be 1, but received %d", axis_size)); - int axis_index = axis_data[0]; - int index_dim_size = input_dim[axis_index]; - PADDLE_ENFORCE_LE( - index_size, index_dim_size, - platform::errors::InvalidArgument( - "The size that index should be less equal than the dim size of " - "input," - "but received index size:%d, the dim size of input %d.", - axis_size, index_dim_size)); - - int inner_dim_size = 1; - int outer_dim_size = 1; - std::vector out_dim_vec{index_dim_size}; - - for (int i = 0; i < axis_index; i++) { - inner_dim_size *= input_dim[i]; + const Tensor* input = ctx.Input("X"); + const Tensor* index = ctx.Input("Index"); + const Tensor* axis = ctx.Input("Axis"); + Tensor* out = ctx.Output("Y"); + + const auto& index_type = index->type(); + const auto& axis_type = axis->type(); + auto place = ctx.GetPlace(); + if (index_type == framework::proto::VarType::INT32 && + axis_type == framework::proto::VarType::INT32) { + GatherV2Function(input, index, axis, out, place); } - for (int i = axis_index + 1; i < input_dim.size(); i++) { - outer_dim_size *= input_dim[i]; - out_dim_vec.push_back(input_dim[i]); + if (index_type == framework::proto::VarType::INT32 && + axis_type == framework::proto::VarType::INT64) { + GatherV2Function(input, index, axis, out, place); } - auto out_dim = framework::make_ddim(out_dim_vec); - - out->Resize(out_dim); - auto* out_data = out->mutable_data(ctx.GetPlace()); - - for (int i = 0; i < inner_dim_size; i++) { - for (int j = 0; j < index_size; j++) { - for (int k = 0; k < outer_dim_size; k++) { - int index = k + index_data[j] * outer_dim_size + - (i * input_size / inner_dim_size); - out_data[i] = input_data[index]; - } - } + if (index_type == framework::proto::VarType::INT64 && + axis_type == framework::proto::VarType::INT32) { + GatherV2Function(input, index, axis, out, place); + } + if (index_type == framework::proto::VarType::INT64 && + axis_type == framework::proto::VarType::INT64) { + GatherV2Function(input, index, axis, out, place); } } }; @@ -85,50 +112,10 @@ template class GatherV2GradientOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - /* - PADDLE_ENFORCE_EQ( - platform::is_cpu_place(ctx.GetPlace()), true, - platform::errors::PreconditionNotMet("This kernel only runs on CPU.")); - - auto *index = ctx.Input("Index"); - auto *dX = ctx.Output(framework::GradVarName("X")); - auto *dO = ctx.Input(framework::GradVarName("Out")); - - dX->mutable_data(ctx.GetPlace()); - auto dxt = framework::EigenVector::Flatten(*dX); - auto &place = *ctx.template device_context() - .eigen_device(); - dxt.device(place) = dxt.constant(static_cast(0)); - if (dO->numel() == 0) return; - bool overwrite = ctx.Attr("overwrite"); - - const auto &index_type = index->type(); - bool index_type_match = index_type == framework::proto::VarType::INT32 || - index_type == framework::proto::VarType::INT64; - PADDLE_ENFORCE_EQ(index_type_match, true, - platform::errors::InvalidArgument( - "Index holds the wrong type, it holds [%s]," - "but desires to be [%s] or [%s].", - paddle::framework::DataTypeToString(index_type), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT32), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT64))); - if (index_type == framework::proto::VarType::INT32) { - if (overwrite) { - ScatterAssign(ctx.device_context(), *dO, *index, dX); - } else { - ScatterAssignAdd(ctx, *dO, *index, dX); - } - } else if (index_type == framework::proto::VarType::INT64) { - if (overwrite) { - ScatterAssign(ctx.device_context(), *dO, *index, dX); - } else { - ScatterAssignAdd(ctx, *dO, *index, dX); - } - } - - */ + auto* index = ctx.Input("Index"); + auto* axis = ctx.Input("Axis"); + auto* out = ctx.Output(framework::GradVarName("X")); + auto* input = ctx.Input(framework::GradVarName("Y")); } }; From d9daeefc07f2dd487e7f349a9551848f09d81570 Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Wed, 19 Aug 2020 07:39:28 +0000 Subject: [PATCH 05/17] add cuda implementation for gather op --- paddle/fluid/operators/gather_v2_op.cu | 260 +++++++++++++++++++------ paddle/fluid/operators/gather_v2_op.h | 81 +++++++- 2 files changed, 274 insertions(+), 67 deletions(-) diff --git a/paddle/fluid/operators/gather_v2_op.cu b/paddle/fluid/operators/gather_v2_op.cu index 26a97f810d448..9b7bea327e070 100644 --- a/paddle/fluid/operators/gather_v2_op.cu +++ b/paddle/fluid/operators/gather_v2_op.cu @@ -13,86 +13,218 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/operators/gather.cu.h" +#include "paddle/fluid/operators/gather_op.h" #include "paddle/fluid/operators/gather_v2_op.h" +#include "paddle/fluid/operators/scatter.cu.h" namespace paddle { namespace operators { +using Tensor = framework::Tensor; + +template +__global__ void GatherGPUKernel(const T* input, const U* index, T* out, + int outer_dim_size, int inner_dim_size, + int index_dim_size, int size) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + for (; idx < size; idx += blockDim.x * gridDim.x) { + int inner_dim_index = idx / (outer_dim_size * index_dim_size); + int out_dim_index = idx % outer_dim_size; + int input_dim_index = idx / outer_dim_size; + int input_index = inner_dim_index * (outer_dim_size * index_dim_size) + + index[input_dim_index] * outer_dim_size + out_dim_index; + out[idx] = input[input_index]; + } +} + +template +__global__ void GatherGradGPUKernel(const T* input, const U* index, T* out, + int outer_dim_size, int inner_dim_size, + int index_dim_size, int size) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + for (; idx < size; idx += blockDim.x * gridDim.x) { + int inner_dim_index = idx / (outer_dim_size * index_dim_size); + int out_dim_index = idx % outer_dim_size; + int input_dim_index = idx / outer_dim_size; + int out_index = inner_dim_index * (outer_dim_size * index_dim_size) + + index[input_dim_index] * outer_dim_size + out_dim_index; + paddle::platform::CudaAtomicAdd(out + out_index, *(input + idx)); + } +} + +template +void GatherV2CUDAFunction(const Tensor* input, const Tensor* index, + const Tensor* axis, Tensor* out, + const paddle::platform::Place& place, + const framework::ExecutionContext& ctx) { + int axis_size = axis->numel(); + int index_size = index->numel(); + int input_size = input->numel(); + auto input_dim = input->dims(); + auto* input_data = input->data(); + auto* index_data = index->data(); + + if (input->numel() == 0) return; + PADDLE_ENFORCE_EQ(axis_size, 1, + platform::errors::InvalidArgument( + "Axis size should be 1, but received %d", axis_size)); + Tensor cpu_axis; + framework::TensorCopy(*axis, platform::CPUPlace(), &cpu_axis); + int axis_index = cpu_axis.data()[0]; + int index_dim_size = input_dim[axis_index]; + PADDLE_ENFORCE_LE( + index_size, index_dim_size, + platform::errors::InvalidArgument( + "The size that index should be less equal than the dim size of " + "input," + "but received index size:%d, the dim size of input %d.", + axis_size, index_dim_size)); + + int inner_dim_size = 1; + int outer_dim_size = 1; + std::vector out_dim_vec{index_dim_size}; + + for (int i = 0; i < axis_index; i++) { + inner_dim_size *= input_dim[i]; + } + for (int i = axis_index + 1; i < input_dim.size(); i++) { + outer_dim_size *= input_dim[i]; + out_dim_vec.push_back(input_dim[i]); + } + auto out_dim = framework::make_ddim(out_dim_vec); + + out->Resize(out_dim); + auto* out_data = out->mutable_data(place); + int out_size = out->numel(); + + int threads = 512; + int grid = (out_size + threads - 1) / threads; + auto stream = ctx.cuda_device_context().stream(); + GatherGPUKernel<<>>( + input_data, index_data, out_data, outer_dim_size, inner_dim_size, + index_dim_size, out_size); +} + +template +void GatherV2GradCUDAFunction(const Tensor* input, const Tensor* index, + const Tensor* axis, Tensor* out, + const paddle::platform::Place& place, + const framework::ExecutionContext& ctx) { + auto* axis_data = axis->data(); + auto* index_data = index->data(); + + int axis_size = axis->numel(); + int index_size = index->numel(); + int input_size = input->numel(); + auto input_dim = input->dims(); + auto* input_data = input->data(); + + if (input->numel() == 0) return; + PADDLE_ENFORCE_EQ(axis_size, 1, + platform::errors::InvalidArgument( + "Axis size should be 1, but received %d", axis_size)); + Tensor cpu_axis; + framework::TensorCopy(*axis, platform::CPUPlace(), &cpu_axis); + int axis_index = cpu_axis.data()[0]; + int index_dim_size = input_dim[axis_index]; + PADDLE_ENFORCE_LE( + index_size, index_dim_size, + platform::errors::InvalidArgument( + "The size that index should be less equal than the dim size of " + "input," + "but received index size:%d, the dim size of input %d.", + axis_size, index_dim_size)); + + int inner_dim_size = 1; + int outer_dim_size = 1; + + for (int i = 0; i < axis_index; i++) { + inner_dim_size *= input_dim[i]; + } + for (int i = axis_index + 1; i < input_dim.size(); i++) { + outer_dim_size *= input_dim[i]; + } + + auto* out_data = out->mutable_data(place); + auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place); + operators::math::set_constant(*dev_ctx, out, 0.0); + + int threads = 512; + int grid = (input_size + threads - 1) / threads; + auto stream = ctx.cuda_device_context().stream(); + GatherGradGPUKernel<<>>( + input_data, index_data, out_data, outer_dim_size, inner_dim_size, + index_dim_size, input_size); +} + template class GatherV2OpCUDAKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext &ctx) const override { - /* - PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true, - platform::errors::PreconditionNotMet( - "This kernel only runs on GPU device.")); - auto *x = ctx.Input("X"); - auto *index = ctx.Input("Index"); - auto *output = ctx.Output("Out"); - - output->mutable_data(ctx.GetPlace()); - if (x->numel() == 0) return; - const auto &index_type = index->type(); - bool index_type_match = index_type == framework::proto::VarType::INT32 || - index_type == framework::proto::VarType::INT64; - PADDLE_ENFORCE_EQ(index_type_match, true, - platform::errors::InvalidArgument( - "Index holds the wrong type, it holds [%s], but " - "desires to be [%s] or [%s].", - paddle::framework::DataTypeToString(index_type), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT32), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT64))); - if (index_type == framework::proto::VarType::INT32) { - GPUGatherNd(ctx, *x, *index, output); - } else if (index_type == framework::proto::VarType::INT64) { - GPUGatherNd(ctx, *x, *index, output); + void Compute(const framework::ExecutionContext& ctx) const override { + const Tensor* input = ctx.Input("X"); + const Tensor* index = ctx.Input("Index"); + const Tensor* axis = ctx.Input("Axis"); + Tensor* out = ctx.Output("Y"); + + const auto& index_type = index->type(); + const auto& axis_type = axis->type(); + auto place = ctx.GetPlace(); + if (index_type == framework::proto::VarType::INT32 && + axis_type == framework::proto::VarType::INT32) { + GatherV2CUDAFunction(input, index, axis, out, place, + ctx); + } + if (index_type == framework::proto::VarType::INT32 && + axis_type == framework::proto::VarType::INT64) { + GatherV2CUDAFunction(input, index, axis, out, place, + ctx); + } + if (index_type == framework::proto::VarType::INT64 && + axis_type == framework::proto::VarType::INT32) { + GatherV2CUDAFunction(input, index, axis, out, place, + ctx); + } + if (index_type == framework::proto::VarType::INT64 && + axis_type == framework::proto::VarType::INT64) { + GatherV2CUDAFunction(input, index, axis, out, place, + ctx); } - */ } }; template class GatherV2GradOpCUDAKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext &ctx) const override { - /* - PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true, - platform::errors::PreconditionNotMet( - "This kernel only runs on GPU device.")); - auto *index = ctx.Input("Index"); - auto *dX = ctx.Output(framework::GradVarName("X")); - auto *dO = ctx.Input(framework::GradVarName("Out")); - - dX->mutable_data(ctx.GetPlace()); - auto dxt = framework::EigenVector::Flatten(*dX); - auto &place = *ctx.template device_context() - .eigen_device(); - dxt.device(place) = dxt.constant(static_cast(0)); - if (dO->numel() == 0) return; - - const auto &index_type = index->type(); - bool index_type_match = index_type == framework::proto::VarType::INT32 || - index_type == framework::proto::VarType::INT64; - - PADDLE_ENFORCE_EQ(index_type_match, true, - platform::errors::InvalidArgument( - "Index holds the wrong type, it holds [%s]," - "but desires to be [%s] or [%s].", - paddle::framework::DataTypeToString(index_type), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT32), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT64))); - - if (index_type == framework::proto::VarType::INT32) { - GPUScatterNdAdd(ctx, *dO, *index, dX); - } else if (index_type == framework::proto::VarType::INT64) { - GPUScatterNdAdd(ctx, *dO, *index, dX); + void Compute(const framework::ExecutionContext& ctx) const override { + const Tensor* input = ctx.Input("X"); + const Tensor* index = ctx.Input("Index"); + const Tensor* axis = ctx.Input("Axis"); + Tensor* out = ctx.Output("Y"); + + const auto& index_type = index->type(); + const auto& axis_type = axis->type(); + auto place = ctx.GetPlace(); + if (index_type == framework::proto::VarType::INT32 && + axis_type == framework::proto::VarType::INT32) { + GatherV2GradCUDAFunction(input, index, axis, out, + place, ctx); + } + if (index_type == framework::proto::VarType::INT32 && + axis_type == framework::proto::VarType::INT64) { + GatherV2GradCUDAFunction(input, index, axis, out, + place, ctx); + } + if (index_type == framework::proto::VarType::INT64 && + axis_type == framework::proto::VarType::INT32) { + GatherV2GradCUDAFunction(input, index, axis, out, + place, ctx); + } + if (index_type == framework::proto::VarType::INT64 && + axis_type == framework::proto::VarType::INT64) { + GatherV2GradCUDAFunction(input, index, axis, out, + place, ctx); } - - */ } }; diff --git a/paddle/fluid/operators/gather_v2_op.h b/paddle/fluid/operators/gather_v2_op.h index 0da4429875d93..3ae8178db5012 100644 --- a/paddle/fluid/operators/gather_v2_op.h +++ b/paddle/fluid/operators/gather_v2_op.h @@ -18,17 +18,19 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/gather.h" +#include "paddle/fluid/operators/math/math_function.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; + template void GatherV2Function(const Tensor* input, const Tensor* index, const Tensor* axis, Tensor* out, const paddle::platform::Place& place) { - auto* axis_data = axis->data(); - auto* index_data = index->data(); + auto* axis_data = axis->data(); + auto* index_data = index->data(); int axis_size = axis->numel(); int index_size = index->numel(); @@ -66,12 +68,14 @@ void GatherV2Function(const Tensor* input, const Tensor* index, out->Resize(out_dim); auto* out_data = out->mutable_data(place); + int out_index = 0; for (int i = 0; i < inner_dim_size; i++) { for (int j = 0; j < index_size; j++) { for (int k = 0; k < outer_dim_size; k++) { int index = k + index_data[j] * outer_dim_size + (i * input_size / inner_dim_size); - out_data[i] = input_data[index]; + out_data[out_index] = input_data[index]; + out_index++; } } } @@ -108,6 +112,57 @@ class GatherV2OpKernel : public framework::OpKernel { } }; +template +void GatherV2GradFunction(const Tensor* input, const Tensor* index, + const Tensor* axis, Tensor* out, + const paddle::platform::Place& place) { + auto* axis_data = axis->data(); + auto* index_data = index->data(); + + int axis_size = axis->numel(); + int index_size = index->numel(); + auto input_dim = input->dims(); + auto* input_data = input->data(); + + if (input->numel() == 0) return; + PADDLE_ENFORCE_EQ(axis_size, 1, + platform::errors::InvalidArgument( + "Axis size should be 1, but received %d", axis_size)); + int axis_index = axis_data[0]; + int index_dim_size = input_dim[axis_index]; + PADDLE_ENFORCE_LE( + index_size, index_dim_size, + platform::errors::InvalidArgument( + "The size that index should be less equal than the dim size of " + "input," + "but received index size:%d, the dim size of input %d.", + axis_size, index_dim_size)); + + int inner_dim_size = 1; + int outer_dim_size = 1; + + for (int i = 0; i < axis_index; i++) { + inner_dim_size *= input_dim[i]; + } + for (int i = axis_index + 1; i < input_dim.size(); i++) { + outer_dim_size *= input_dim[i]; + } + + auto* out_data = out->mutable_data(place); + auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place); + operators::math::set_constant(*dev_ctx, out, 0.0); + + for (int i = 0; i < index_size; i++) { + for (int j = 0; j < inner_dim_size; j++) { + for (int k = 0; k < outer_dim_size; k++) { + int index = k + index_data[i] * outer_dim_size + + j * outer_dim_size * index_dim_size; + out_data[index] += input_data[i]; + } + } + } +} + template class GatherV2GradientOpKernel : public framework::OpKernel { public: @@ -116,6 +171,26 @@ class GatherV2GradientOpKernel : public framework::OpKernel { auto* axis = ctx.Input("Axis"); auto* out = ctx.Output(framework::GradVarName("X")); auto* input = ctx.Input(framework::GradVarName("Y")); + + const auto& index_type = index->type(); + const auto& axis_type = axis->type(); + auto place = ctx.GetPlace(); + if (index_type == framework::proto::VarType::INT32 && + axis_type == framework::proto::VarType::INT32) { + GatherV2GradFunction(input, index, axis, out, place); + } + if (index_type == framework::proto::VarType::INT32 && + axis_type == framework::proto::VarType::INT64) { + GatherV2GradFunction(input, index, axis, out, place); + } + if (index_type == framework::proto::VarType::INT64 && + axis_type == framework::proto::VarType::INT32) { + GatherV2GradFunction(input, index, axis, out, place); + } + if (index_type == framework::proto::VarType::INT64 && + axis_type == framework::proto::VarType::INT64) { + GatherV2GradFunction(input, index, axis, out, place); + } } }; From 5ebfb10e37f5651a6fc7b84bb755c0ad5f3ccf95 Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Wed, 19 Aug 2020 09:50:48 +0000 Subject: [PATCH 06/17] refine the code test=develop --- paddle/fluid/operators/gather_v2_op.h | 28 ++++----- .../tests/unittests/test_gather_v2_op.py | 62 +++++++++++++++++++ 2 files changed, 76 insertions(+), 14 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_gather_v2_op.py diff --git a/paddle/fluid/operators/gather_v2_op.h b/paddle/fluid/operators/gather_v2_op.h index 3ae8178db5012..df4f7636dc39a 100644 --- a/paddle/fluid/operators/gather_v2_op.h +++ b/paddle/fluid/operators/gather_v2_op.h @@ -26,19 +26,19 @@ namespace operators { using Tensor = framework::Tensor; template -void GatherV2Function(const Tensor* input, const Tensor* index, - const Tensor* axis, Tensor* out, +void GatherV2Function(const Tensor& input, const Tensor& index, + const Tensor& axis, Tensor* out, const paddle::platform::Place& place) { - auto* axis_data = axis->data(); - auto* index_data = index->data(); + auto* axis_data = axis.data(); + auto* index_data = index.data(); - int axis_size = axis->numel(); - int index_size = index->numel(); - int input_size = input->numel(); - auto input_dim = input->dims(); - auto* input_data = input->data(); + int axis_size = axis.numel(); + int index_size = index.numel(); + int input_size = input.numel(); + auto input_dim = input.dims(); + auto* input_data = input.data(); - if (input->numel() == 0) return; + if (input.numel() == 0) return; PADDLE_ENFORCE_EQ(axis_size, 1, platform::errors::InvalidArgument( "Axis size should be 1, but received %d", axis_size)); @@ -95,19 +95,19 @@ class GatherV2OpKernel : public framework::OpKernel { auto place = ctx.GetPlace(); if (index_type == framework::proto::VarType::INT32 && axis_type == framework::proto::VarType::INT32) { - GatherV2Function(input, index, axis, out, place); + GatherV2Function(*input, *index, *axis, out, place); } if (index_type == framework::proto::VarType::INT32 && axis_type == framework::proto::VarType::INT64) { - GatherV2Function(input, index, axis, out, place); + GatherV2Function(*input, *index, *axis, out, place); } if (index_type == framework::proto::VarType::INT64 && axis_type == framework::proto::VarType::INT32) { - GatherV2Function(input, index, axis, out, place); + GatherV2Function(*input, *index, *axis, out, place); } if (index_type == framework::proto::VarType::INT64 && axis_type == framework::proto::VarType::INT64) { - GatherV2Function(input, index, axis, out, place); + GatherV2Function(*input, *index, *axis, out, place); } } }; diff --git a/python/paddle/fluid/tests/unittests/test_gather_v2_op.py b/python/paddle/fluid/tests/unittests/test_gather_v2_op.py new file mode 100644 index 0000000000000..8716b821c930d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_gather_v2_op.py @@ -0,0 +1,62 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle +import paddle.fluid as fluid + + +def gather_numpy(x, axis, index): + result = x[:, index, :] + return result + + +class TestGatherOp(OpTest): + def setUp(self): + self.op_type = "gather_v2" + self.config() + xnp = np.random.random(self.x_shape).astype(self.x_type) + axis_np = np.array(self.axis).astype(self.index_type) + index_np = np.array(self.index).astype(self.index_type) + self.inputs = {'X': xnp, 'Index': index_np, 'Axis': axis_np} + out = gather_numpy(xnp, axis_np, index_np) + print(out.shape) + self.outputs = {'Y': out} + + def test_check_output(self): + self.check_output_with_place(paddle.CPUPlace()) + + """ + def test_check_grad(self): + self.check_grad(['X'], 'Out') + """ + + def config(self): + """ + For multi-dimension input + """ + self.x_shape = (10, 20, 10) + self.x_type = "float64" + self.index = [1, 3, 5] + self.index_type = "int32" + self.axis = [1] + self.axis_type = "int32" + + +if __name__ == "__main__": + unittest.main() From e0d751c7ff75c1e52bc56f9136d00a4622e48591 Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Wed, 19 Aug 2020 09:59:05 +0000 Subject: [PATCH 07/17] refine the code test=develop --- paddle/fluid/operators/gather_v2_op.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/paddle/fluid/operators/gather_v2_op.cc b/paddle/fluid/operators/gather_v2_op.cc index c965edfd26f9f..2cd22e2101cae 100644 --- a/paddle/fluid/operators/gather_v2_op.cc +++ b/paddle/fluid/operators/gather_v2_op.cc @@ -42,6 +42,9 @@ class GatherV2Op : public framework::OperatorWithKernel { auto index_dims = ctx->GetInputDim("Index"); PADDLE_ENFORCE(index_dims.size() == 1 || (index_dims.size() == 2 && index_dims[1] == 1)); + framework::DDim output_dims(ctx->GetInputDim("X")); + ctx->SetOutputDim("Out", output_dims); + ctx->ShareLoD("X", /*->*/ "Out"); } protected: From 0d5aacd247e96a3fc08fa8fb26cd5c33dd54b057 Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Wed, 19 Aug 2020 16:01:45 +0000 Subject: [PATCH 08/17] refine --- paddle/fluid/operators/gather_v2_op.cc | 4 +-- paddle/fluid/operators/gather_v2_op.cu | 6 ++-- paddle/fluid/operators/gather_v2_op.h | 32 ++++++++++--------- .../tests/unittests/test_gather_v2_op.py | 2 +- 4 files changed, 24 insertions(+), 20 deletions(-) diff --git a/paddle/fluid/operators/gather_v2_op.cc b/paddle/fluid/operators/gather_v2_op.cc index 2cd22e2101cae..656f6d2190d1e 100644 --- a/paddle/fluid/operators/gather_v2_op.cc +++ b/paddle/fluid/operators/gather_v2_op.cc @@ -43,8 +43,8 @@ class GatherV2Op : public framework::OperatorWithKernel { PADDLE_ENFORCE(index_dims.size() == 1 || (index_dims.size() == 2 && index_dims[1] == 1)); framework::DDim output_dims(ctx->GetInputDim("X")); - ctx->SetOutputDim("Out", output_dims); - ctx->ShareLoD("X", /*->*/ "Out"); + ctx->SetOutputDim("Y", output_dims); + ctx->ShareLoD("X", /*->*/ "Y"); } protected: diff --git a/paddle/fluid/operators/gather_v2_op.cu b/paddle/fluid/operators/gather_v2_op.cu index 9b7bea327e070..4035fe4404e51 100644 --- a/paddle/fluid/operators/gather_v2_op.cu +++ b/paddle/fluid/operators/gather_v2_op.cu @@ -34,7 +34,7 @@ __global__ void GatherGPUKernel(const T* input, const U* index, T* out, int input_dim_index = idx / outer_dim_size; int input_index = inner_dim_index * (outer_dim_size * index_dim_size) + index[input_dim_index] * outer_dim_size + out_dim_index; - out[idx] = input[input_index]; + out[idx] = input[0]; } } @@ -83,11 +83,13 @@ void GatherV2CUDAFunction(const Tensor* input, const Tensor* index, int inner_dim_size = 1; int outer_dim_size = 1; - std::vector out_dim_vec{index_dim_size}; + std::vector out_dim_vec; for (int i = 0; i < axis_index; i++) { inner_dim_size *= input_dim[i]; + out_dim_vec.push_back(input_dim[i]); } + out_dim_vec.push_back(index_size); for (int i = axis_index + 1; i < input_dim.size(); i++) { outer_dim_size *= input_dim[i]; out_dim_vec.push_back(input_dim[i]); diff --git a/paddle/fluid/operators/gather_v2_op.h b/paddle/fluid/operators/gather_v2_op.h index df4f7636dc39a..a84de07292d85 100644 --- a/paddle/fluid/operators/gather_v2_op.h +++ b/paddle/fluid/operators/gather_v2_op.h @@ -26,19 +26,19 @@ namespace operators { using Tensor = framework::Tensor; template -void GatherV2Function(const Tensor& input, const Tensor& index, - const Tensor& axis, Tensor* out, +void GatherV2Function(const Tensor* input, const Tensor* index, + const Tensor* axis, Tensor* out, const paddle::platform::Place& place) { - auto* axis_data = axis.data(); - auto* index_data = index.data(); + auto* axis_data = axis->data(); + auto* index_data = index->data(); - int axis_size = axis.numel(); - int index_size = index.numel(); - int input_size = input.numel(); - auto input_dim = input.dims(); - auto* input_data = input.data(); + int axis_size = axis->numel(); + int index_size = index->numel(); + int input_size = input->numel(); + auto input_dim = input->dims(); + auto* input_data = input->data(); - if (input.numel() == 0) return; + if (input->numel() == 0) return; PADDLE_ENFORCE_EQ(axis_size, 1, platform::errors::InvalidArgument( "Axis size should be 1, but received %d", axis_size)); @@ -54,11 +54,13 @@ void GatherV2Function(const Tensor& input, const Tensor& index, int inner_dim_size = 1; int outer_dim_size = 1; - std::vector out_dim_vec{index_dim_size}; + std::vector out_dim_vec; for (int i = 0; i < axis_index; i++) { inner_dim_size *= input_dim[i]; + out_dim_vec.push_back(input_dim[i]); } + out_dim_vec.push_back(index_size); for (int i = axis_index + 1; i < input_dim.size(); i++) { outer_dim_size *= input_dim[i]; out_dim_vec.push_back(input_dim[i]); @@ -95,19 +97,19 @@ class GatherV2OpKernel : public framework::OpKernel { auto place = ctx.GetPlace(); if (index_type == framework::proto::VarType::INT32 && axis_type == framework::proto::VarType::INT32) { - GatherV2Function(*input, *index, *axis, out, place); + GatherV2Function(input, index, axis, out, place); } if (index_type == framework::proto::VarType::INT32 && axis_type == framework::proto::VarType::INT64) { - GatherV2Function(*input, *index, *axis, out, place); + GatherV2Function(input, index, axis, out, place); } if (index_type == framework::proto::VarType::INT64 && axis_type == framework::proto::VarType::INT32) { - GatherV2Function(*input, *index, *axis, out, place); + GatherV2Function(input, index, axis, out, place); } if (index_type == framework::proto::VarType::INT64 && axis_type == framework::proto::VarType::INT64) { - GatherV2Function(*input, *index, *axis, out, place); + GatherV2Function(input, index, axis, out, place); } } }; diff --git a/python/paddle/fluid/tests/unittests/test_gather_v2_op.py b/python/paddle/fluid/tests/unittests/test_gather_v2_op.py index 8716b821c930d..53cbc31e6e9e6 100644 --- a/python/paddle/fluid/tests/unittests/test_gather_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_gather_v2_op.py @@ -39,7 +39,7 @@ def setUp(self): self.outputs = {'Y': out} def test_check_output(self): - self.check_output_with_place(paddle.CPUPlace()) + self.check_output() """ def test_check_grad(self): From 8bf07d96a425af67410ef58998b8afb4df08b0b5 Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Thu, 20 Aug 2020 06:00:02 +0000 Subject: [PATCH 09/17] refine --- paddle/fluid/operators/gather_v2_op.cc | 2 +- paddle/fluid/operators/gather_v2_op.cu | 17 ++++++++------ paddle/fluid/operators/gather_v2_op.h | 22 +++++++------------ .../tests/unittests/test_gather_v2_op.py | 6 ++--- 4 files changed, 21 insertions(+), 26 deletions(-) diff --git a/paddle/fluid/operators/gather_v2_op.cc b/paddle/fluid/operators/gather_v2_op.cc index 656f6d2190d1e..257401f31fe5c 100644 --- a/paddle/fluid/operators/gather_v2_op.cc +++ b/paddle/fluid/operators/gather_v2_op.cc @@ -121,7 +121,7 @@ REGISTER_OP_CPU_KERNEL(gather_v2, ops::GatherV2OpKernel, ops::GatherV2OpKernel, ops::GatherV2OpKernel, ops::GatherV2OpKernel); -REGISTER_OP_CPU_KERNEL(gather_grad_v2, ops::GatherV2GradientOpKernel, +REGISTER_OP_CPU_KERNEL(gather_v2_grad, ops::GatherV2GradientOpKernel, ops::GatherV2GradientOpKernel, ops::GatherV2GradientOpKernel, ops::GatherV2GradientOpKernel, diff --git a/paddle/fluid/operators/gather_v2_op.cu b/paddle/fluid/operators/gather_v2_op.cu index 4035fe4404e51..b24f649c8403b 100644 --- a/paddle/fluid/operators/gather_v2_op.cu +++ b/paddle/fluid/operators/gather_v2_op.cu @@ -26,15 +26,18 @@ using Tensor = framework::Tensor; template __global__ void GatherGPUKernel(const T* input, const U* index, T* out, int outer_dim_size, int inner_dim_size, - int index_dim_size, int size) { + int index_dim_size, int input_index_dim_size, + int size) { int idx = blockDim.x * blockIdx.x + threadIdx.x; for (; idx < size; idx += blockDim.x * gridDim.x) { int inner_dim_index = idx / (outer_dim_size * index_dim_size); - int out_dim_index = idx % outer_dim_size; - int input_dim_index = idx / outer_dim_size; - int input_index = inner_dim_index * (outer_dim_size * index_dim_size) + - index[input_dim_index] * outer_dim_size + out_dim_index; - out[idx] = input[0]; + int next_idx = idx % (outer_dim_size * index_dim_size); + int index_dim_index = next_idx / (outer_dim_size); + int out_dim_index = next_idx % outer_dim_size; + int input_index = + inner_dim_index * (outer_dim_size * input_index_dim_size) + + index[index_dim_index] * outer_dim_size + out_dim_index; + out[idx] = input[input_index]; } } @@ -105,7 +108,7 @@ void GatherV2CUDAFunction(const Tensor* input, const Tensor* index, auto stream = ctx.cuda_device_context().stream(); GatherGPUKernel<<>>( input_data, index_data, out_data, outer_dim_size, inner_dim_size, - index_dim_size, out_size); + index_size, index_dim_size, out_size); } template diff --git a/paddle/fluid/operators/gather_v2_op.h b/paddle/fluid/operators/gather_v2_op.h index a84de07292d85..68ac128bb7dcd 100644 --- a/paddle/fluid/operators/gather_v2_op.h +++ b/paddle/fluid/operators/gather_v2_op.h @@ -122,7 +122,6 @@ void GatherV2GradFunction(const Tensor* input, const Tensor* index, auto* index_data = index->data(); int axis_size = axis->numel(); - int index_size = index->numel(); auto input_dim = input->dims(); auto* input_data = input->data(); @@ -131,14 +130,7 @@ void GatherV2GradFunction(const Tensor* input, const Tensor* index, platform::errors::InvalidArgument( "Axis size should be 1, but received %d", axis_size)); int axis_index = axis_data[0]; - int index_dim_size = input_dim[axis_index]; - PADDLE_ENFORCE_LE( - index_size, index_dim_size, - platform::errors::InvalidArgument( - "The size that index should be less equal than the dim size of " - "input," - "but received index size:%d, the dim size of input %d.", - axis_size, index_dim_size)); + int input_index_dim_size = input_dim[axis_index]; int inner_dim_size = 1; int outer_dim_size = 1; @@ -152,14 +144,16 @@ void GatherV2GradFunction(const Tensor* input, const Tensor* index, auto* out_data = out->mutable_data(place); auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place); + auto out_dim = out->dims(); + int out_index_dim_size = out_dim[axis_index]; operators::math::set_constant(*dev_ctx, out, 0.0); - for (int i = 0; i < index_size; i++) { - for (int j = 0; j < inner_dim_size; j++) { + for (int i = 0; i < inner_dim_size; i++) { + for (int j = 0; j < input_index_dim_size; j++) { for (int k = 0; k < outer_dim_size; k++) { - int index = k + index_data[i] * outer_dim_size + - j * outer_dim_size * index_dim_size; - out_data[index] += input_data[i]; + int index = k + index_data[j] * outer_dim_size + + i * outer_dim_size * out_index_dim_size; + out_data[index] += input_data[j * outer_dim_size + k]; } } } diff --git a/python/paddle/fluid/tests/unittests/test_gather_v2_op.py b/python/paddle/fluid/tests/unittests/test_gather_v2_op.py index 53cbc31e6e9e6..cbfd31063ca51 100644 --- a/python/paddle/fluid/tests/unittests/test_gather_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_gather_v2_op.py @@ -41,16 +41,14 @@ def setUp(self): def test_check_output(self): self.check_output() - """ def test_check_grad(self): - self.check_grad(['X'], 'Out') - """ + self.check_grad(['X'], 'Y') def config(self): """ For multi-dimension input """ - self.x_shape = (10, 20, 10) + self.x_shape = (3, 20, 3) self.x_type = "float64" self.index = [1, 3, 5] self.index_type = "int32" From b13eea488df36eac7fbc97ace4f5542aa8e3f8b7 Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Thu, 20 Aug 2020 09:30:59 +0000 Subject: [PATCH 10/17] refine the code --- paddle/fluid/operators/gather_v2_op.cu | 44 ++++++++++++-------------- paddle/fluid/operators/gather_v2_op.h | 2 +- 2 files changed, 21 insertions(+), 25 deletions(-) diff --git a/paddle/fluid/operators/gather_v2_op.cu b/paddle/fluid/operators/gather_v2_op.cu index b24f649c8403b..6a1dc871f97ac 100644 --- a/paddle/fluid/operators/gather_v2_op.cu +++ b/paddle/fluid/operators/gather_v2_op.cu @@ -26,12 +26,12 @@ using Tensor = framework::Tensor; template __global__ void GatherGPUKernel(const T* input, const U* index, T* out, int outer_dim_size, int inner_dim_size, - int index_dim_size, int input_index_dim_size, - int size) { + int out_index_dim_size, + int input_index_dim_size, int size) { int idx = blockDim.x * blockIdx.x + threadIdx.x; for (; idx < size; idx += blockDim.x * gridDim.x) { - int inner_dim_index = idx / (outer_dim_size * index_dim_size); - int next_idx = idx % (outer_dim_size * index_dim_size); + int inner_dim_index = idx / (outer_dim_size * out_index_dim_size); + int next_idx = idx % (outer_dim_size * out_index_dim_size); int index_dim_index = next_idx / (outer_dim_size); int out_dim_index = next_idx % outer_dim_size; int input_index = @@ -44,14 +44,16 @@ __global__ void GatherGPUKernel(const T* input, const U* index, T* out, template __global__ void GatherGradGPUKernel(const T* input, const U* index, T* out, int outer_dim_size, int inner_dim_size, - int index_dim_size, int size) { + int input_index_dim_size, + int out_index_dim_size, int size) { int idx = blockDim.x * blockIdx.x + threadIdx.x; for (; idx < size; idx += blockDim.x * gridDim.x) { - int inner_dim_index = idx / (outer_dim_size * index_dim_size); - int out_dim_index = idx % outer_dim_size; - int input_dim_index = idx / outer_dim_size; - int out_index = inner_dim_index * (outer_dim_size * index_dim_size) + - index[input_dim_index] * outer_dim_size + out_dim_index; + int inner_dim_index = idx / (outer_dim_size * input_index_dim_size); + int next_idx = idx % (outer_dim_size * input_index_dim_size); + int index_dim_index = next_idx / (outer_dim_size); + int out_dim_index = next_idx % outer_dim_size; + int out_index = inner_dim_index * (outer_dim_size * out_index_dim_size) + + index[index_dim_index] * outer_dim_size + out_dim_index; paddle::platform::CudaAtomicAdd(out + out_index, *(input + idx)); } } @@ -116,7 +118,6 @@ void GatherV2GradCUDAFunction(const Tensor* input, const Tensor* index, const Tensor* axis, Tensor* out, const paddle::platform::Place& place, const framework::ExecutionContext& ctx) { - auto* axis_data = axis->data(); auto* index_data = index->data(); int axis_size = axis->numel(); @@ -132,14 +133,7 @@ void GatherV2GradCUDAFunction(const Tensor* input, const Tensor* index, Tensor cpu_axis; framework::TensorCopy(*axis, platform::CPUPlace(), &cpu_axis); int axis_index = cpu_axis.data()[0]; - int index_dim_size = input_dim[axis_index]; - PADDLE_ENFORCE_LE( - index_size, index_dim_size, - platform::errors::InvalidArgument( - "The size that index should be less equal than the dim size of " - "input," - "but received index size:%d, the dim size of input %d.", - axis_size, index_dim_size)); + int input_index_dim_size = input_dim[axis_index]; int inner_dim_size = 1; int outer_dim_size = 1; @@ -153,24 +147,26 @@ void GatherV2GradCUDAFunction(const Tensor* input, const Tensor* index, auto* out_data = out->mutable_data(place); auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place); - operators::math::set_constant(*dev_ctx, out, 0.0); + auto out_dim = out->dims(); + int out_index_dim_size = out_dim[axis_index]; + // operators::math::set_constant(*dev_ctx, out, 0.0); int threads = 512; int grid = (input_size + threads - 1) / threads; auto stream = ctx.cuda_device_context().stream(); GatherGradGPUKernel<<>>( input_data, index_data, out_data, outer_dim_size, inner_dim_size, - index_dim_size, input_size); + input_index_dim_size, out_index_dim_size, input_size); } template class GatherV2OpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - const Tensor* input = ctx.Input("X"); const Tensor* index = ctx.Input("Index"); const Tensor* axis = ctx.Input("Axis"); Tensor* out = ctx.Output("Y"); + const Tensor* input = ctx.Input("X"); const auto& index_type = index->type(); const auto& axis_type = axis->type(); @@ -202,10 +198,10 @@ template class GatherV2GradOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - const Tensor* input = ctx.Input("X"); const Tensor* index = ctx.Input("Index"); const Tensor* axis = ctx.Input("Axis"); - Tensor* out = ctx.Output("Y"); + auto* out = ctx.Output(framework::GradVarName("X")); + auto* input = ctx.Input(framework::GradVarName("Y")); const auto& index_type = index->type(); const auto& axis_type = axis->type(); diff --git a/paddle/fluid/operators/gather_v2_op.h b/paddle/fluid/operators/gather_v2_op.h index 68ac128bb7dcd..bdee5eeec2434 100644 --- a/paddle/fluid/operators/gather_v2_op.h +++ b/paddle/fluid/operators/gather_v2_op.h @@ -87,10 +87,10 @@ template class GatherV2OpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - const Tensor* input = ctx.Input("X"); const Tensor* index = ctx.Input("Index"); const Tensor* axis = ctx.Input("Axis"); Tensor* out = ctx.Output("Y"); + const Tensor* input = ctx.Input("X"); const auto& index_type = index->type(); const auto& axis_type = axis->type(); From b1e68b0b886d1be8e403c99a1185e81010c48b10 Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Thu, 20 Aug 2020 09:48:33 +0000 Subject: [PATCH 11/17] refine the code test=develop --- python/paddle/fluid/tests/unittests/test_gather_v2_op.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_gather_v2_op.py b/python/paddle/fluid/tests/unittests/test_gather_v2_op.py index cbfd31063ca51..03e10afb2effc 100644 --- a/python/paddle/fluid/tests/unittests/test_gather_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_gather_v2_op.py @@ -34,8 +34,8 @@ def setUp(self): axis_np = np.array(self.axis).astype(self.index_type) index_np = np.array(self.index).astype(self.index_type) self.inputs = {'X': xnp, 'Index': index_np, 'Axis': axis_np} - out = gather_numpy(xnp, axis_np, index_np) - print(out.shape) + out = xnp[:, self.index, :] + print(out) self.outputs = {'Y': out} def test_check_output(self): @@ -48,9 +48,9 @@ def config(self): """ For multi-dimension input """ - self.x_shape = (3, 20, 3) + self.x_shape = (3, 88, 3) self.x_type = "float64" - self.index = [1, 3, 5] + self.index = [1, 1, 1] self.index_type = "int32" self.axis = [1] self.axis_type = "int32" From 8faf356b673bfe524712c375e90e3dce52f8d5d9 Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Thu, 20 Aug 2020 14:42:57 +0000 Subject: [PATCH 12/17] refine --- paddle/fluid/operators/gather.cu.h | 129 +++++++++ paddle/fluid/operators/gather.h | 107 ++++++++ paddle/fluid/operators/gather_op.cc | 5 + paddle/fluid/operators/gather_op.cu | 55 ++++ paddle/fluid/operators/gather_op.h | 48 ++++ paddle/fluid/operators/gather_v2_op.cc | 128 --------- paddle/fluid/operators/gather_v2_op.cu | 250 ------------------ paddle/fluid/operators/gather_v2_op.h | 194 -------------- .../fluid/tests/unittests/test_gather_op.py | 62 +++++ python/paddle/tensor/manipulation.py | 10 +- 10 files changed, 412 insertions(+), 576 deletions(-) delete mode 100644 paddle/fluid/operators/gather_v2_op.cc delete mode 100644 paddle/fluid/operators/gather_v2_op.cu delete mode 100644 paddle/fluid/operators/gather_v2_op.h diff --git a/paddle/fluid/operators/gather.cu.h b/paddle/fluid/operators/gather.cu.h index f59d46ec79bd0..c4bdd9e439c54 100644 --- a/paddle/fluid/operators/gather.cu.h +++ b/paddle/fluid/operators/gather.cu.h @@ -18,6 +18,7 @@ limitations under the License. */ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/memory/malloc.h" +#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/place.h" @@ -158,5 +159,133 @@ void GPUGatherNd(const framework::ExecutionContext& context, end_size); } +template +__global__ void GatherGPUKernel(const T* input, const U* index, T* out, + int outer_dim_size, int inner_dim_size, + int out_index_dim_size, + int input_index_dim_size, int size) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + for (; idx < size; idx += blockDim.x * gridDim.x) { + int inner_dim_index = idx / (outer_dim_size * out_index_dim_size); + int next_idx = idx % (outer_dim_size * out_index_dim_size); + int index_dim_index = next_idx / (outer_dim_size); + int out_dim_index = next_idx % outer_dim_size; + int input_index = + inner_dim_index * (outer_dim_size * input_index_dim_size) + + index[index_dim_index] * outer_dim_size + out_dim_index; + out[idx] = input[input_index]; + } +} + +template +__global__ void GatherGradGPUKernel(const T* input, const U* index, T* out, + int outer_dim_size, int inner_dim_size, + int input_index_dim_size, + int out_index_dim_size, int size) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + for (; idx < size; idx += blockDim.x * gridDim.x) { + int inner_dim_index = idx / (outer_dim_size * input_index_dim_size); + int next_idx = idx % (outer_dim_size * input_index_dim_size); + int index_dim_index = next_idx / (outer_dim_size); + int out_dim_index = next_idx % outer_dim_size; + int out_index = inner_dim_index * (outer_dim_size * out_index_dim_size) + + index[index_dim_index] * outer_dim_size + out_dim_index; + paddle::platform::CudaAtomicAdd(out + out_index, *(input + idx)); + } +} + +template +void GatherV2CUDAFunction(const Tensor* input, const Tensor* index, + const Tensor* axis, Tensor* out, + const paddle::platform::Place& place, + const framework::ExecutionContext& ctx) { + int axis_size = axis->numel(); + int index_size = index->numel(); + int input_size = input->numel(); + auto input_dim = input->dims(); + auto* input_data = input->data(); + auto* index_data = index->data(); + + if (input->numel() == 0) return; + PADDLE_ENFORCE_EQ(axis_size, 1, + platform::errors::InvalidArgument( + "Axis size should be 1, but received %d", axis_size)); + Tensor cpu_axis; + framework::TensorCopy(*axis, platform::CPUPlace(), &cpu_axis); + int axis_index = cpu_axis.data()[0]; + int index_dim_size = input_dim[axis_index]; + + int inner_dim_size = 1; + int outer_dim_size = 1; + std::vector out_dim_vec; + + for (int i = 0; i < axis_index; i++) { + inner_dim_size *= input_dim[i]; + out_dim_vec.push_back(input_dim[i]); + } + out_dim_vec.push_back(index_size); + for (int i = axis_index + 1; i < input_dim.size(); i++) { + outer_dim_size *= input_dim[i]; + out_dim_vec.push_back(input_dim[i]); + } + auto out_dim = framework::make_ddim(out_dim_vec); + + out->Resize(out_dim); + auto* out_data = out->mutable_data(place); + int out_size = out->numel(); + + int threads = 512; + int grid = (out_size + threads - 1) / threads; + auto stream = ctx.cuda_device_context().stream(); + GatherGPUKernel<<>>( + input_data, index_data, out_data, outer_dim_size, inner_dim_size, + index_size, index_dim_size, out_size); +} + +template +void GatherV2GradCUDAFunction(const Tensor* input, const Tensor* index, + const Tensor* axis, Tensor* out, + const paddle::platform::Place& place, + const framework::ExecutionContext& ctx) { + auto* index_data = index->data(); + + int axis_size = axis->numel(); + int index_size = index->numel(); + int input_size = input->numel(); + auto input_dim = input->dims(); + auto* input_data = input->data(); + + if (input->numel() == 0) return; + PADDLE_ENFORCE_EQ(axis_size, 1, + platform::errors::InvalidArgument( + "Axis size should be 1, but received %d", axis_size)); + Tensor cpu_axis; + framework::TensorCopy(*axis, platform::CPUPlace(), &cpu_axis); + int axis_index = cpu_axis.data()[0]; + int input_index_dim_size = input_dim[axis_index]; + + int inner_dim_size = 1; + int outer_dim_size = 1; + + for (int i = 0; i < axis_index; i++) { + inner_dim_size *= input_dim[i]; + } + for (int i = axis_index + 1; i < input_dim.size(); i++) { + outer_dim_size *= input_dim[i]; + } + + auto* out_data = out->mutable_data(place); + auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place); + auto out_dim = out->dims(); + int out_index_dim_size = out_dim[axis_index]; + operators::math::set_constant(*dev_ctx, out, 0.0); + + int threads = 512; + int grid = (input_size + threads - 1) / threads; + auto stream = ctx.cuda_device_context().stream(); + GatherGradGPUKernel<<>>( + input_data, index_data, out_data, outer_dim_size, inner_dim_size, + input_index_dim_size, out_index_dim_size, input_size); +} } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/gather.h b/paddle/fluid/operators/gather.h index f5a7bffe47453..c12a3b8adc978 100644 --- a/paddle/fluid/operators/gather.h +++ b/paddle/fluid/operators/gather.h @@ -15,10 +15,12 @@ limitations under the License. */ #pragma once #include #include +#include #include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/place.h" namespace paddle { @@ -124,5 +126,110 @@ void CPUGatherNd(const platform::DeviceContext& ctx, const Tensor& input, } } +template +void GatherV2Function(const Tensor* input, const Tensor* index, + const Tensor* axis, Tensor* out, + const paddle::platform::Place& place) { + auto* axis_data = axis->data(); + auto* index_data = index->data(); + + int axis_size = axis->numel(); + int index_size = index->numel(); + int input_size = input->numel(); + auto input_dim = input->dims(); + auto* input_data = input->data(); + + if (input->numel() == 0) return; + PADDLE_ENFORCE_EQ(axis_size, 1, + platform::errors::InvalidArgument( + "Axis size should be 1, but received %d", axis_size)); + int axis_index = axis_data[0]; + + int input_index_dim_size = input_dim[axis_index]; + for (int i = 0; i < index_size; i++) { + PADDLE_ENFORCE_LT(index_data[i], input_index_dim_size, + platform::errors::InvalidArgument( + "The element of Index must be less than the size of " + "input dim size of axis which is %d, but received " + "index element which is %d in the %d index.", + input_index_dim_size, index_data[i], i)); + } + + int inner_dim_size = 1; + int outer_dim_size = 1; + std::vector out_dim_vec; + + for (int i = 0; i < axis_index; i++) { + inner_dim_size *= input_dim[i]; + out_dim_vec.push_back(input_dim[i]); + } + out_dim_vec.push_back(index_size); + for (int i = axis_index + 1; i < input_dim.size(); i++) { + outer_dim_size *= input_dim[i]; + out_dim_vec.push_back(input_dim[i]); + } + auto out_dim = framework::make_ddim(out_dim_vec); + + out->Resize(out_dim); + auto* out_data = out->mutable_data(place); + + int out_index = 0; + for (int i = 0; i < inner_dim_size; i++) { + for (int j = 0; j < index_size; j++) { + for (int k = 0; k < outer_dim_size; k++) { + int index = k + index_data[j] * outer_dim_size + + (i * input_size / inner_dim_size); + out_data[out_index] = input_data[index]; + out_index++; + } + } + } +} + +template +void GatherV2GradFunction(const Tensor* input, const Tensor* index, + const Tensor* axis, Tensor* out, + const paddle::platform::Place& place) { + auto* axis_data = axis->data(); + auto* index_data = index->data(); + + int axis_size = axis->numel(); + auto input_dim = input->dims(); + auto* input_data = input->data(); + + if (input->numel() == 0) return; + PADDLE_ENFORCE_EQ(axis_size, 1, + platform::errors::InvalidArgument( + "Axis size should be 1, but received %d", axis_size)); + int axis_index = axis_data[0]; + int input_index_dim_size = input_dim[axis_index]; + + int inner_dim_size = 1; + int outer_dim_size = 1; + + for (int i = 0; i < axis_index; i++) { + inner_dim_size *= input_dim[i]; + } + for (int i = axis_index + 1; i < input_dim.size(); i++) { + outer_dim_size *= input_dim[i]; + } + + auto* out_data = out->mutable_data(place); + auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place); + auto out_dim = out->dims(); + int out_index_dim_size = out_dim[axis_index]; + operators::math::set_constant(*dev_ctx, out, 0.0); + + for (int i = 0; i < inner_dim_size; i++) { + for (int j = 0; j < input_index_dim_size; j++) { + for (int k = 0; k < outer_dim_size; k++) { + int index = k + index_data[j] * outer_dim_size + + i * outer_dim_size * out_index_dim_size; + out_data[index] += input_data[j * outer_dim_size + k]; + } + } + } +} + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/gather_op.cc b/paddle/fluid/operators/gather_op.cc index 6a3abaa600281..8a3450d1df97a 100644 --- a/paddle/fluid/operators/gather_op.cc +++ b/paddle/fluid/operators/gather_op.cc @@ -78,6 +78,9 @@ class GatherOpMaker : public framework::OpProtoAndCheckerMaker { void Make() override { AddInput("X", "The source input of gather op"); AddInput("Index", "The index input of gather op"); + AddInput("Axis", + "The Tensor which contains the axis that we do gather operation.") + .AsDispensable(); AddOutput("Out", "The output of gather op"); AddAttr( "overwrite", @@ -120,6 +123,8 @@ class GatherGradOpMaker : public framework::SingleGradOpMaker { void Apply(GradOpPtr op) const override { op->SetType("gather_grad"); op->SetInput("Index", this->Input("Index")); + op->SetInput("Axis", this->Input("Axis")); + op->SetInput("X", this->Input("X")); op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); diff --git a/paddle/fluid/operators/gather_op.cu b/paddle/fluid/operators/gather_op.cu index 5bef547c0542b..37fbfb21f60a0 100644 --- a/paddle/fluid/operators/gather_op.cu +++ b/paddle/fluid/operators/gather_op.cu @@ -31,6 +31,33 @@ class GatherOpCUDAKernel : public framework::OpKernel { auto *index = ctx.Input("Index"); auto *output = ctx.Output("Out"); + if (ctx.HasInput("Axis")) { + const Tensor *axis = ctx.Input("Axis"); + const auto &index_type = index->type(); + const auto &axis_type = axis->type(); + auto place = ctx.GetPlace(); + if (index_type == framework::proto::VarType::INT32 && + axis_type == framework::proto::VarType::INT32) { + GatherV2CUDAFunction(x, index, axis, output, place, + ctx); + } + if (index_type == framework::proto::VarType::INT32 && + axis_type == framework::proto::VarType::INT64) { + GatherV2CUDAFunction(x, index, axis, output, place, + ctx); + } + if (index_type == framework::proto::VarType::INT64 && + axis_type == framework::proto::VarType::INT32) { + GatherV2CUDAFunction(x, index, axis, output, place, + ctx); + } + if (index_type == framework::proto::VarType::INT64 && + axis_type == framework::proto::VarType::INT64) { + GatherV2CUDAFunction(x, index, axis, output, place, + ctx); + } + return; + } output->mutable_data(ctx.GetPlace()); if (x->numel() == 0) return; const auto &index_type = index->type(); @@ -64,6 +91,34 @@ class GatherGradOpCUDAKernel : public framework::OpKernel { auto *dX = ctx.Output(framework::GradVarName("X")); auto *dO = ctx.Input(framework::GradVarName("Out")); + if (ctx.HasInput("Axis")) { + const Tensor *axis = ctx.Input("Axis"); + const auto &index_type = index->type(); + const auto &axis_type = axis->type(); + auto place = ctx.GetPlace(); + if (index_type == framework::proto::VarType::INT32 && + axis_type == framework::proto::VarType::INT32) { + GatherV2GradCUDAFunction(dO, index, axis, dX, + place, ctx); + } + if (index_type == framework::proto::VarType::INT32 && + axis_type == framework::proto::VarType::INT64) { + GatherV2GradCUDAFunction(dO, index, axis, dX, + place, ctx); + } + if (index_type == framework::proto::VarType::INT64 && + axis_type == framework::proto::VarType::INT32) { + GatherV2GradCUDAFunction(dO, index, axis, dX, + place, ctx); + } + if (index_type == framework::proto::VarType::INT64 && + axis_type == framework::proto::VarType::INT64) { + GatherV2GradCUDAFunction(dO, index, axis, dX, + place, ctx); + } + return; + } + dX->mutable_data(ctx.GetPlace()); auto dxt = framework::EigenVector::Flatten(*dX); auto &place = *ctx.template device_context() diff --git a/paddle/fluid/operators/gather_op.h b/paddle/fluid/operators/gather_op.h index e4ce13ca8fc0b..8ec0d6ce0b69c 100644 --- a/paddle/fluid/operators/gather_op.h +++ b/paddle/fluid/operators/gather_op.h @@ -35,6 +35,30 @@ class GatherOpKernel : public framework::OpKernel { auto *index = ctx.Input("Index"); auto *output = ctx.Output("Out"); + if (ctx.HasInput("Axis")) { + const Tensor *axis = ctx.Input("Axis"); + const auto &index_type = index->type(); + const auto &axis_type = axis->type(); + auto place = ctx.GetPlace(); + if (index_type == framework::proto::VarType::INT32 && + axis_type == framework::proto::VarType::INT32) { + GatherV2Function(x, index, axis, output, place); + } + if (index_type == framework::proto::VarType::INT32 && + axis_type == framework::proto::VarType::INT64) { + GatherV2Function(x, index, axis, output, place); + } + if (index_type == framework::proto::VarType::INT64 && + axis_type == framework::proto::VarType::INT32) { + GatherV2Function(x, index, axis, output, place); + } + if (index_type == framework::proto::VarType::INT64 && + axis_type == framework::proto::VarType::INT64) { + GatherV2Function(x, index, axis, output, place); + } + return; + } + output->mutable_data(ctx.GetPlace()); if (x->numel() == 0) return; @@ -70,6 +94,30 @@ class GatherGradientOpKernel : public framework::OpKernel { auto *dX = ctx.Output(framework::GradVarName("X")); auto *dO = ctx.Input(framework::GradVarName("Out")); + if (ctx.HasInput("Axis")) { + const Tensor *axis = ctx.Input("Axis"); + const auto &index_type = index->type(); + const auto &axis_type = axis->type(); + auto place = ctx.GetPlace(); + if (index_type == framework::proto::VarType::INT32 && + axis_type == framework::proto::VarType::INT32) { + GatherV2GradFunction(dO, index, axis, dX, place); + } + if (index_type == framework::proto::VarType::INT32 && + axis_type == framework::proto::VarType::INT64) { + GatherV2GradFunction(dO, index, axis, dX, place); + } + if (index_type == framework::proto::VarType::INT64 && + axis_type == framework::proto::VarType::INT32) { + GatherV2GradFunction(dO, index, axis, dX, place); + } + if (index_type == framework::proto::VarType::INT64 && + axis_type == framework::proto::VarType::INT64) { + GatherV2GradFunction(dO, index, axis, dX, place); + } + return; + } + dX->mutable_data(ctx.GetPlace()); auto dxt = framework::EigenVector::Flatten(*dX); auto &place = *ctx.template device_context() diff --git a/paddle/fluid/operators/gather_v2_op.cc b/paddle/fluid/operators/gather_v2_op.cc deleted file mode 100644 index 257401f31fe5c..0000000000000 --- a/paddle/fluid/operators/gather_v2_op.cc +++ /dev/null @@ -1,128 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/gather_v2_op.h" -#include -#include -#include -#include "paddle/fluid/framework/ddim.h" - -namespace paddle { -namespace operators { - -class GatherV2Op : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, - platform::errors::InvalidArgument( - "Input(X) of GatherOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasInput("Index"), true, - platform::errors::InvalidArgument( - "Input(Index) of GatherOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasInput("Axis"), true, - platform::errors::InvalidArgument( - "Input(Axis) of GatherOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasOutput("Y"), true, - platform::errors::InvalidArgument( - "Output(Y) of GatherOp should not be null.")); - - auto index_dims = ctx->GetInputDim("Index"); - PADDLE_ENFORCE(index_dims.size() == 1 || - (index_dims.size() == 2 && index_dims[1] == 1)); - framework::DDim output_dims(ctx->GetInputDim("X")); - ctx->SetOutputDim("Y", output_dims); - ctx->ShareLoD("X", /*->*/ "Y"); - } - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); - } -}; - -class GatherV2GradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); - ctx->ShareLoD("X", /*-->*/ framework::GradVarName("X")); - } - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Y")), - ctx.device_context()); - } -}; - -class GatherV2OpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "The source input of gather op"); - AddInput("Index", "The index input of gather op"); - AddOutput("Y", "The output of gather op"); - AddInput("Axis", - "The Tensor which contains the axis that we do gather operation."); - AddComment(R"DOC( -Y is obtained by gathering entries of the axis dimension -of X indexed by Index and concatenate them together. -)DOC"); - } -}; - -template -class GatherV2GradOpMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr op) const override { - op->SetType("gather_v2_grad"); - op->SetInput("Index", this->Input("Index")); - op->SetInput("X", this->Input("X")); - op->SetInput("Axis", this->Input("Axis")); - op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y")); - op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - } -}; - -DECLARE_NO_NEED_BUFFER_VARS_INFERER(GatherV2GradNoNeedBufferVarInferer, "X"); - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OPERATOR(gather_v2, ops::GatherV2Op, ops::GatherV2OpMaker, - ops::GatherV2GradOpMaker, - ops::GatherV2GradOpMaker); -REGISTER_OPERATOR(gather_v2_grad, ops::GatherV2GradOp, - ops::GatherV2GradNoNeedBufferVarInferer); -REGISTER_OP_CPU_KERNEL(gather_v2, ops::GatherV2OpKernel, - ops::GatherV2OpKernel, - ops::GatherV2OpKernel, - ops::GatherV2OpKernel, - ops::GatherV2OpKernel); -REGISTER_OP_CPU_KERNEL(gather_v2_grad, ops::GatherV2GradientOpKernel, - ops::GatherV2GradientOpKernel, - ops::GatherV2GradientOpKernel, - ops::GatherV2GradientOpKernel, - ops::GatherV2GradientOpKernel); diff --git a/paddle/fluid/operators/gather_v2_op.cu b/paddle/fluid/operators/gather_v2_op.cu deleted file mode 100644 index 6a1dc871f97ac..0000000000000 --- a/paddle/fluid/operators/gather_v2_op.cu +++ /dev/null @@ -1,250 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/operators/gather.cu.h" -#include "paddle/fluid/operators/gather_op.h" -#include "paddle/fluid/operators/gather_v2_op.h" -#include "paddle/fluid/operators/scatter.cu.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -__global__ void GatherGPUKernel(const T* input, const U* index, T* out, - int outer_dim_size, int inner_dim_size, - int out_index_dim_size, - int input_index_dim_size, int size) { - int idx = blockDim.x * blockIdx.x + threadIdx.x; - for (; idx < size; idx += blockDim.x * gridDim.x) { - int inner_dim_index = idx / (outer_dim_size * out_index_dim_size); - int next_idx = idx % (outer_dim_size * out_index_dim_size); - int index_dim_index = next_idx / (outer_dim_size); - int out_dim_index = next_idx % outer_dim_size; - int input_index = - inner_dim_index * (outer_dim_size * input_index_dim_size) + - index[index_dim_index] * outer_dim_size + out_dim_index; - out[idx] = input[input_index]; - } -} - -template -__global__ void GatherGradGPUKernel(const T* input, const U* index, T* out, - int outer_dim_size, int inner_dim_size, - int input_index_dim_size, - int out_index_dim_size, int size) { - int idx = blockDim.x * blockIdx.x + threadIdx.x; - for (; idx < size; idx += blockDim.x * gridDim.x) { - int inner_dim_index = idx / (outer_dim_size * input_index_dim_size); - int next_idx = idx % (outer_dim_size * input_index_dim_size); - int index_dim_index = next_idx / (outer_dim_size); - int out_dim_index = next_idx % outer_dim_size; - int out_index = inner_dim_index * (outer_dim_size * out_index_dim_size) + - index[index_dim_index] * outer_dim_size + out_dim_index; - paddle::platform::CudaAtomicAdd(out + out_index, *(input + idx)); - } -} - -template -void GatherV2CUDAFunction(const Tensor* input, const Tensor* index, - const Tensor* axis, Tensor* out, - const paddle::platform::Place& place, - const framework::ExecutionContext& ctx) { - int axis_size = axis->numel(); - int index_size = index->numel(); - int input_size = input->numel(); - auto input_dim = input->dims(); - auto* input_data = input->data(); - auto* index_data = index->data(); - - if (input->numel() == 0) return; - PADDLE_ENFORCE_EQ(axis_size, 1, - platform::errors::InvalidArgument( - "Axis size should be 1, but received %d", axis_size)); - Tensor cpu_axis; - framework::TensorCopy(*axis, platform::CPUPlace(), &cpu_axis); - int axis_index = cpu_axis.data()[0]; - int index_dim_size = input_dim[axis_index]; - PADDLE_ENFORCE_LE( - index_size, index_dim_size, - platform::errors::InvalidArgument( - "The size that index should be less equal than the dim size of " - "input," - "but received index size:%d, the dim size of input %d.", - axis_size, index_dim_size)); - - int inner_dim_size = 1; - int outer_dim_size = 1; - std::vector out_dim_vec; - - for (int i = 0; i < axis_index; i++) { - inner_dim_size *= input_dim[i]; - out_dim_vec.push_back(input_dim[i]); - } - out_dim_vec.push_back(index_size); - for (int i = axis_index + 1; i < input_dim.size(); i++) { - outer_dim_size *= input_dim[i]; - out_dim_vec.push_back(input_dim[i]); - } - auto out_dim = framework::make_ddim(out_dim_vec); - - out->Resize(out_dim); - auto* out_data = out->mutable_data(place); - int out_size = out->numel(); - - int threads = 512; - int grid = (out_size + threads - 1) / threads; - auto stream = ctx.cuda_device_context().stream(); - GatherGPUKernel<<>>( - input_data, index_data, out_data, outer_dim_size, inner_dim_size, - index_size, index_dim_size, out_size); -} - -template -void GatherV2GradCUDAFunction(const Tensor* input, const Tensor* index, - const Tensor* axis, Tensor* out, - const paddle::platform::Place& place, - const framework::ExecutionContext& ctx) { - auto* index_data = index->data(); - - int axis_size = axis->numel(); - int index_size = index->numel(); - int input_size = input->numel(); - auto input_dim = input->dims(); - auto* input_data = input->data(); - - if (input->numel() == 0) return; - PADDLE_ENFORCE_EQ(axis_size, 1, - platform::errors::InvalidArgument( - "Axis size should be 1, but received %d", axis_size)); - Tensor cpu_axis; - framework::TensorCopy(*axis, platform::CPUPlace(), &cpu_axis); - int axis_index = cpu_axis.data()[0]; - int input_index_dim_size = input_dim[axis_index]; - - int inner_dim_size = 1; - int outer_dim_size = 1; - - for (int i = 0; i < axis_index; i++) { - inner_dim_size *= input_dim[i]; - } - for (int i = axis_index + 1; i < input_dim.size(); i++) { - outer_dim_size *= input_dim[i]; - } - - auto* out_data = out->mutable_data(place); - auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place); - auto out_dim = out->dims(); - int out_index_dim_size = out_dim[axis_index]; - // operators::math::set_constant(*dev_ctx, out, 0.0); - - int threads = 512; - int grid = (input_size + threads - 1) / threads; - auto stream = ctx.cuda_device_context().stream(); - GatherGradGPUKernel<<>>( - input_data, index_data, out_data, outer_dim_size, inner_dim_size, - input_index_dim_size, out_index_dim_size, input_size); -} - -template -class GatherV2OpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const Tensor* index = ctx.Input("Index"); - const Tensor* axis = ctx.Input("Axis"); - Tensor* out = ctx.Output("Y"); - const Tensor* input = ctx.Input("X"); - - const auto& index_type = index->type(); - const auto& axis_type = axis->type(); - auto place = ctx.GetPlace(); - if (index_type == framework::proto::VarType::INT32 && - axis_type == framework::proto::VarType::INT32) { - GatherV2CUDAFunction(input, index, axis, out, place, - ctx); - } - if (index_type == framework::proto::VarType::INT32 && - axis_type == framework::proto::VarType::INT64) { - GatherV2CUDAFunction(input, index, axis, out, place, - ctx); - } - if (index_type == framework::proto::VarType::INT64 && - axis_type == framework::proto::VarType::INT32) { - GatherV2CUDAFunction(input, index, axis, out, place, - ctx); - } - if (index_type == framework::proto::VarType::INT64 && - axis_type == framework::proto::VarType::INT64) { - GatherV2CUDAFunction(input, index, axis, out, place, - ctx); - } - } -}; - -template -class GatherV2GradOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const Tensor* index = ctx.Input("Index"); - const Tensor* axis = ctx.Input("Axis"); - auto* out = ctx.Output(framework::GradVarName("X")); - auto* input = ctx.Input(framework::GradVarName("Y")); - - const auto& index_type = index->type(); - const auto& axis_type = axis->type(); - auto place = ctx.GetPlace(); - if (index_type == framework::proto::VarType::INT32 && - axis_type == framework::proto::VarType::INT32) { - GatherV2GradCUDAFunction(input, index, axis, out, - place, ctx); - } - if (index_type == framework::proto::VarType::INT32 && - axis_type == framework::proto::VarType::INT64) { - GatherV2GradCUDAFunction(input, index, axis, out, - place, ctx); - } - if (index_type == framework::proto::VarType::INT64 && - axis_type == framework::proto::VarType::INT32) { - GatherV2GradCUDAFunction(input, index, axis, out, - place, ctx); - } - if (index_type == framework::proto::VarType::INT64 && - axis_type == framework::proto::VarType::INT64) { - GatherV2GradCUDAFunction(input, index, axis, out, - place, ctx); - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; -using CUDA = paddle::platform::CUDADeviceContext; -REGISTER_OP_CUDA_KERNEL(gather_v2, ops::GatherV2OpCUDAKernel, - ops::GatherV2OpCUDAKernel, - ops::GatherV2OpCUDAKernel, - ops::GatherV2OpCUDAKernel, - ops::GatherV2OpCUDAKernel, - ops::GatherV2OpCUDAKernel); - -REGISTER_OP_CUDA_KERNEL(gather_v2_grad, - ops::GatherV2GradOpCUDAKernel, - ops::GatherV2GradOpCUDAKernel, - ops::GatherV2GradOpCUDAKernel, - ops::GatherV2GradOpCUDAKernel, - ops::GatherV2GradOpCUDAKernel); diff --git a/paddle/fluid/operators/gather_v2_op.h b/paddle/fluid/operators/gather_v2_op.h deleted file mode 100644 index bdee5eeec2434..0000000000000 --- a/paddle/fluid/operators/gather_v2_op.h +++ /dev/null @@ -1,194 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include "paddle/fluid/framework/ddim.h" -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/gather.h" -#include "paddle/fluid/operators/math/math_function.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -void GatherV2Function(const Tensor* input, const Tensor* index, - const Tensor* axis, Tensor* out, - const paddle::platform::Place& place) { - auto* axis_data = axis->data(); - auto* index_data = index->data(); - - int axis_size = axis->numel(); - int index_size = index->numel(); - int input_size = input->numel(); - auto input_dim = input->dims(); - auto* input_data = input->data(); - - if (input->numel() == 0) return; - PADDLE_ENFORCE_EQ(axis_size, 1, - platform::errors::InvalidArgument( - "Axis size should be 1, but received %d", axis_size)); - int axis_index = axis_data[0]; - int index_dim_size = input_dim[axis_index]; - PADDLE_ENFORCE_LE( - index_size, index_dim_size, - platform::errors::InvalidArgument( - "The size that index should be less equal than the dim size of " - "input," - "but received index size:%d, the dim size of input %d.", - axis_size, index_dim_size)); - - int inner_dim_size = 1; - int outer_dim_size = 1; - std::vector out_dim_vec; - - for (int i = 0; i < axis_index; i++) { - inner_dim_size *= input_dim[i]; - out_dim_vec.push_back(input_dim[i]); - } - out_dim_vec.push_back(index_size); - for (int i = axis_index + 1; i < input_dim.size(); i++) { - outer_dim_size *= input_dim[i]; - out_dim_vec.push_back(input_dim[i]); - } - auto out_dim = framework::make_ddim(out_dim_vec); - - out->Resize(out_dim); - auto* out_data = out->mutable_data(place); - - int out_index = 0; - for (int i = 0; i < inner_dim_size; i++) { - for (int j = 0; j < index_size; j++) { - for (int k = 0; k < outer_dim_size; k++) { - int index = k + index_data[j] * outer_dim_size + - (i * input_size / inner_dim_size); - out_data[out_index] = input_data[index]; - out_index++; - } - } - } -} - -template -class GatherV2OpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const Tensor* index = ctx.Input("Index"); - const Tensor* axis = ctx.Input("Axis"); - Tensor* out = ctx.Output("Y"); - const Tensor* input = ctx.Input("X"); - - const auto& index_type = index->type(); - const auto& axis_type = axis->type(); - auto place = ctx.GetPlace(); - if (index_type == framework::proto::VarType::INT32 && - axis_type == framework::proto::VarType::INT32) { - GatherV2Function(input, index, axis, out, place); - } - if (index_type == framework::proto::VarType::INT32 && - axis_type == framework::proto::VarType::INT64) { - GatherV2Function(input, index, axis, out, place); - } - if (index_type == framework::proto::VarType::INT64 && - axis_type == framework::proto::VarType::INT32) { - GatherV2Function(input, index, axis, out, place); - } - if (index_type == framework::proto::VarType::INT64 && - axis_type == framework::proto::VarType::INT64) { - GatherV2Function(input, index, axis, out, place); - } - } -}; - -template -void GatherV2GradFunction(const Tensor* input, const Tensor* index, - const Tensor* axis, Tensor* out, - const paddle::platform::Place& place) { - auto* axis_data = axis->data(); - auto* index_data = index->data(); - - int axis_size = axis->numel(); - auto input_dim = input->dims(); - auto* input_data = input->data(); - - if (input->numel() == 0) return; - PADDLE_ENFORCE_EQ(axis_size, 1, - platform::errors::InvalidArgument( - "Axis size should be 1, but received %d", axis_size)); - int axis_index = axis_data[0]; - int input_index_dim_size = input_dim[axis_index]; - - int inner_dim_size = 1; - int outer_dim_size = 1; - - for (int i = 0; i < axis_index; i++) { - inner_dim_size *= input_dim[i]; - } - for (int i = axis_index + 1; i < input_dim.size(); i++) { - outer_dim_size *= input_dim[i]; - } - - auto* out_data = out->mutable_data(place); - auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place); - auto out_dim = out->dims(); - int out_index_dim_size = out_dim[axis_index]; - operators::math::set_constant(*dev_ctx, out, 0.0); - - for (int i = 0; i < inner_dim_size; i++) { - for (int j = 0; j < input_index_dim_size; j++) { - for (int k = 0; k < outer_dim_size; k++) { - int index = k + index_data[j] * outer_dim_size + - i * outer_dim_size * out_index_dim_size; - out_data[index] += input_data[j * outer_dim_size + k]; - } - } - } -} - -template -class GatherV2GradientOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* index = ctx.Input("Index"); - auto* axis = ctx.Input("Axis"); - auto* out = ctx.Output(framework::GradVarName("X")); - auto* input = ctx.Input(framework::GradVarName("Y")); - - const auto& index_type = index->type(); - const auto& axis_type = axis->type(); - auto place = ctx.GetPlace(); - if (index_type == framework::proto::VarType::INT32 && - axis_type == framework::proto::VarType::INT32) { - GatherV2GradFunction(input, index, axis, out, place); - } - if (index_type == framework::proto::VarType::INT32 && - axis_type == framework::proto::VarType::INT64) { - GatherV2GradFunction(input, index, axis, out, place); - } - if (index_type == framework::proto::VarType::INT64 && - axis_type == framework::proto::VarType::INT32) { - GatherV2GradFunction(input, index, axis, out, place); - } - if (index_type == framework::proto::VarType::INT64 && - axis_type == framework::proto::VarType::INT64) { - GatherV2GradFunction(input, index, axis, out, place); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_gather_op.py b/python/paddle/fluid/tests/unittests/test_gather_op.py index f8763e731eeed..468fe14592750 100644 --- a/python/paddle/fluid/tests/unittests/test_gather_op.py +++ b/python/paddle/fluid/tests/unittests/test_gather_op.py @@ -21,6 +21,13 @@ import paddle.fluid as fluid +def gather_numpy(x, index, axis): + x_transpose = np.swapaxes(x, 0, axis) + tmp_gather = x_transpose[index, ...] + gather = np.swapaxes(tmp_gather, 0, axis) + return gather + + class TestGatherOp(OpTest): def setUp(self): self.op_type = "gather" @@ -108,6 +115,61 @@ def config(self): self.index_type = "int32" +class TestGatherOp1(OpTest): + def setUp(self): + self.op_type = "gather" + self.config() + xnp = np.random.random(self.x_shape).astype(self.x_type) + axis_np = np.array(self.axis).astype(self.index_type) + index_np = np.array(self.index).astype(self.index_type) + out = gather_numpy(xnp, index_np, axis_np[0]) + self.inputs = {'X': xnp, 'Index': index_np, 'Axis': axis_np} + self.outputs = {'Out': out} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + def config(self): + """ + For multi-dimension input + """ + self.x_shape = (3, 88, 3) + self.x_type = "float64" + self.index = [1, 3, 5] + self.index_type = "int32" + self.axis = [1] + self.axis_type = "int32" + + +class TestGatherOp2(TestGatherOp1): + def config(self): + """ + For multi-dimension input + """ + self.x_shape = (10, 88, 10) + self.x_type = "float64" + self.index = [1, 3, 5] + self.index_type = "int64" + self.axis = [0] + self.axis_type = "int32" + + +class TestGatherOp2(TestGatherOp1): + def config(self): + """ + For multi-dimension input + """ + self.x_shape = (3, 88, 10) + self.x_type = "float64" + self.index = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] + self.index_type = "int64" + self.axis = [0] + self.axis_type = "int32" + + class API_TestGather(unittest.TestCase): def test_out(self): with fluid.program_guard(fluid.Program(), fluid.Program()): diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 09e232a73a695..76409cabb64af 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -712,15 +712,17 @@ def gather(x, index, axis=None, name=None): output = paddle.gather(input, index) # expected output: [[1,2],[3,4]] """ + if in_dygraph_mode(): + return core.ops.gather(x, index, axis) helper = LayerHelper('gather', **locals()) dtype = helper.input_dtype() out = helper.create_variable_for_type_inference(dtype) helper.append_op( type="gather", - inputs={"X": input, - "Index": index}, - outputs={"Out": out}, - attrs={'overwrite': overwrite}) + inputs={"X": x, + "Index": index, + "Axis": axis}, + outputs={"Out": out}) return out From b5f24fe270ced697dbe30d4a797be617fcd337d4 Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Thu, 20 Aug 2020 16:06:35 +0000 Subject: [PATCH 13/17] refine the code --- .../fluid/tests/unittests/test_gather_op.py | 60 +++++++++++++++---- python/paddle/tensor/manipulation.py | 49 ++++++++++----- 2 files changed, 80 insertions(+), 29 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_gather_op.py b/python/paddle/fluid/tests/unittests/test_gather_op.py index 468fe14592750..f52dff7681ce3 100644 --- a/python/paddle/fluid/tests/unittests/test_gather_op.py +++ b/python/paddle/fluid/tests/unittests/test_gather_op.py @@ -157,12 +157,25 @@ def config(self): self.axis_type = "int32" -class TestGatherOp2(TestGatherOp1): +class TestGatherOp3(TestGatherOp1): def config(self): """ For multi-dimension input """ - self.x_shape = (3, 88, 10) + self.x_shape = (10, 88, 10) + self.x_type = "float64" + self.index = [1, 3, 5] + self.index_type = "int64" + self.axis = [2] + self.axis_type = "int32" + + +class TestGatherOp4(TestGatherOp1): + def config(self): + """ + For multi-dimension input + """ + self.x_shape = (3, 100, 10) self.x_type = "float64" self.index = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] self.index_type = "int64" @@ -171,11 +184,11 @@ def config(self): class API_TestGather(unittest.TestCase): - def test_out(self): + def test_out1(self): with fluid.program_guard(fluid.Program(), fluid.Program()): data1 = fluid.layers.data('data1', shape=[-1, 2], dtype='float64') - index = fluid.layers.data('index', shape=[-1, 1], dtype='float64') - out = paddle.gather(data1, index) + index = fluid.layers.data('index', shape=[-1, 1], dtype='int32') + out = paddle.fluid.layers.gather(data1, index) place = fluid.CPUPlace() exe = fluid.Executor(place) input = np.array([[1, 2], [3, 4], [5, 6]]) @@ -186,18 +199,39 @@ def test_out(self): expected_output = np.array([[3, 4], [5, 6]]) self.assertTrue(np.allclose(result, expected_output)) + def test_out2(self): + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + x = paddle.data('x', shape=[-1, 2], dtype='float64') + index = paddle.data('index', shape=[-1, 1], dtype='int32') + axis = paddle.data('axis', shape=[1], dtype='int32') + out = paddle.gather(x, index, axis) + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + x_np = np.array([[1, 2], [3, 4], [5, 6]]).astype('float64') + index_np = np.array([1, 1]).astype('int32') + axis_np = np.array([1]).astype('int32') + result, = exe.run( + feed={"x": x_np, + "index": index_np, + 'axis': axis_np}, + fetch_list=[out]) + expected_output = gather_numpy(x_np, index_np, axis_np) + self.assertTrue(np.allclose(result, expected_output)) + class API_TestDygraphGather(unittest.TestCase): def test_out(self): - with fluid.dygraph.guard(): - input_1 = np.array([[1, 2], [3, 4], [5, 6]]) - index_1 = np.array([1, 2]) - input = fluid.dygraph.to_variable(input_1) - index = fluid.dygraph.to_variable(index_1) - output = paddle.fluid.layers.gather(input, index) - output_np = output.numpy() - expected_output = np.array([[3, 4], [5, 6]]) + paddle.disable_static() + input_1 = np.array([[1, 2], [3, 4], [5, 6]]) + index_1 = np.array([1, 2]) + input = paddle.to_tensor(input_1) + index = paddle.to_tensor(index_1) + output = paddle.fluid.layers.gather(input, index) + output_np = output.numpy() + expected_output = np.array([[3, 4], [5, 6]]) self.assertTrue(np.allclose(output_np, expected_output)) + paddle.enable_static() if __name__ == "__main__": diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 76409cabb64af..d66bcb078f4b4 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -676,24 +676,28 @@ def gather(x, index, axis=None, name=None): Given: - X = [[1, 2], + x = [[1, 2], [3, 4], [5, 6]] - Index = [1, 2] + index = [1, 2] + axis=[0] Then: - Out = [[3, 4], + out = [[3, 4], [5, 6]] Args: - input (Tensor): The source input tensor with rank>=1. Supported data type is + x (Tensor): The source input tensor with rank>=1. Supported data type is int32, int64, float32, float64 and uint8 (only for CPU), float16 (only for GPU). index (Tensor): The index input tensor with rank=1. Data type is int32 or int64. + axis (Tensor|int, optional): The axis of input to be gathered, it's can be int or a Tensor with data type is int32 or int64. Default: if None, the axis is 0. + name (str, optional): The default value is None. Normally there is no need for user to set this property. + For more information, please refer to :ref:`api_guide_Name` . Returns: - output (Tensor): The output is a tensor with the same rank as input. + output (Tensor): The output is a tensor with the same rank as ``x``. Examples: @@ -701,19 +705,32 @@ def gather(x, index, axis=None, name=None): import numpy as np import paddle - import paddle.fluid as fluid - - with fluid.dygraph.guard(): - input_1 = np.array([[1,2],[3,4],[5,6]]) - index_1 = np.array([0,1]) - input = fluid.dygraph.to_variable(input_1) - index = fluid.dygraph.to_variable(index_1) - output = paddle.gather(input, index) - # expected output: [[1,2],[3,4]] + paddle.disable_static() + input_1 = np.array([[1,2],[3,4],[5,6]]) + index_1 = np.array([0,1]) + input = fluid.to_tensor(input_1) + index = fluid.to_tensor(index_1) + output = paddle.gather(input, index, axis=0) + # expected output: [[1,2],[3,4]] """ + if axis is None: + axis = 0 + axis_tensor = axis + if not isinstance(axis, Variable): + axis_tensor = fill_constant(shape=[1], dtype='int64', value=axis) if in_dygraph_mode(): - return core.ops.gather(x, index, axis) + return core.ops.gather(x, index, axis_tensor) + + check_variable_and_dtype( + x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'uint8'], + 'gather') + check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'gather') + if isinstance(axis, Variable): + check_variable_and_dtype(axis, 'axis', ['int32', 'int64'], 'gather') + else: + check_type(axis, 'axis', (int), 'gather') + helper = LayerHelper('gather', **locals()) dtype = helper.input_dtype() out = helper.create_variable_for_type_inference(dtype) @@ -721,7 +738,7 @@ def gather(x, index, axis=None, name=None): type="gather", inputs={"X": x, "Index": index, - "Axis": axis}, + "Axis": axis_tensor}, outputs={"Out": out}) return out From 9525621c9732a2f424253a7483c59d5751c91d84 Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Fri, 21 Aug 2020 03:25:17 +0000 Subject: [PATCH 14/17] delete unsed test test=develop --- .../tests/unittests/test_gather_v2_op.py | 60 ------------------- 1 file changed, 60 deletions(-) delete mode 100644 python/paddle/fluid/tests/unittests/test_gather_v2_op.py diff --git a/python/paddle/fluid/tests/unittests/test_gather_v2_op.py b/python/paddle/fluid/tests/unittests/test_gather_v2_op.py deleted file mode 100644 index 03e10afb2effc..0000000000000 --- a/python/paddle/fluid/tests/unittests/test_gather_v2_op.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import print_function - -import unittest -import numpy as np -from op_test import OpTest -import paddle -import paddle.fluid as fluid - - -def gather_numpy(x, axis, index): - result = x[:, index, :] - return result - - -class TestGatherOp(OpTest): - def setUp(self): - self.op_type = "gather_v2" - self.config() - xnp = np.random.random(self.x_shape).astype(self.x_type) - axis_np = np.array(self.axis).astype(self.index_type) - index_np = np.array(self.index).astype(self.index_type) - self.inputs = {'X': xnp, 'Index': index_np, 'Axis': axis_np} - out = xnp[:, self.index, :] - print(out) - self.outputs = {'Y': out} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(['X'], 'Y') - - def config(self): - """ - For multi-dimension input - """ - self.x_shape = (3, 88, 3) - self.x_type = "float64" - self.index = [1, 1, 1] - self.index_type = "int32" - self.axis = [1] - self.axis_type = "int32" - - -if __name__ == "__main__": - unittest.main() From a3ec71a52e79fb3f93e22e3f47e3f28214cce329 Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Fri, 21 Aug 2020 06:16:09 +0000 Subject: [PATCH 15/17] refine the code and doc --- paddle/fluid/pybind/op_function_generator.cc | 1 + python/paddle/fluid/layers/nn.py | 19 ++++++++++++++----- .../fluid/tests/unittests/test_gather_op.py | 14 +++++++++++++- python/paddle/tensor/manipulation.py | 5 +++++ 4 files changed, 33 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index 93ba9feedf95b..5507ab813c3e2 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -41,6 +41,7 @@ std::map> op_ins_map = { {"fake_quantize_dequantize_moving_average_abs_max", {"X", "InScale", "InAccum", "InState"}}, {"nll_loss", {"X", "Label", "Weight"}}, + {"gather", {"X", "Index", "Axis"}}, }; // NOTE(zhiqiu): Like op_ins_map. diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 446510121e72a..cb378c3babfe4 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -8168,19 +8168,21 @@ def gather(input, index, overwrite=True): [5, 6]] Args: - input (Variable): The source input tensor with rank>=1. Supported data type is + input (Temspr): The source input tensor with rank>=1. Supported data type is int32, int64, float32, float64 and uint8 (only for CPU), float16 (only for GPU). - index (Variable): The index input tensor with rank=1. Data type is int32 or int64. + index (Tensor): The index input tensor with rank=1. Data type is int32 or int64. overwrite (bool, optional): The mode that updating the grad when has same index. If True, use the overwrite mode to update the grad of the same index, if False, use the accumulate mode to update the grad of the same index. Default value is True. - - Returns: - output (Variable): The output is a tensor with the same rank as input. + output (Tensor): The output is a tensor with the same rank as input. + + Raises: + TypeError: ``x`` must be a Tensor and the data type of ``x`` must to be one of float16, float32, float64, int32, int64, uint8. + TypeError: ``index`` must be a Tensor and the data type of ``index`` must be int32 or int64. Examples: @@ -8191,6 +8193,13 @@ def gather(input, index, overwrite=True): index = fluid.data(name='index', shape=[-1, 1], dtype='int32') output = fluid.layers.gather(x, index) """ + if in_dygraph_mode(): + return core.ops.gather(input, index, None) + + check_variable_and_dtype( + input, 'x', + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint8'], 'gather') + check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'gather') helper = LayerHelper('gather', **locals()) dtype = helper.input_dtype() out = helper.create_variable_for_type_inference(dtype) diff --git a/python/paddle/fluid/tests/unittests/test_gather_op.py b/python/paddle/fluid/tests/unittests/test_gather_op.py index f52dff7681ce3..e60da20e88ee9 100644 --- a/python/paddle/fluid/tests/unittests/test_gather_op.py +++ b/python/paddle/fluid/tests/unittests/test_gather_op.py @@ -221,7 +221,7 @@ def test_out2(self): class API_TestDygraphGather(unittest.TestCase): - def test_out(self): + def test_out1(self): paddle.disable_static() input_1 = np.array([[1, 2], [3, 4], [5, 6]]) index_1 = np.array([1, 2]) @@ -233,6 +233,18 @@ def test_out(self): self.assertTrue(np.allclose(output_np, expected_output)) paddle.enable_static() + def test_out12(self): + paddle.disable_static() + input_1 = np.array([[1, 2], [3, 4], [5, 6]]) + index_1 = np.array([1, 2]) + x = paddle.to_tensor(input_1) + index = paddle.to_tensor(index_1) + output = paddle.gather(x, index, axis=0) + output_np = output.numpy() + expected_output = gather_numpy(input_1, index_1, axis=0) + self.assertTrue(np.allclose(output_np, expected_output)) + paddle.enable_static() + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index d66bcb078f4b4..a53a060f664bd 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -698,6 +698,11 @@ def gather(x, index, axis=None, name=None): Returns: output (Tensor): The output is a tensor with the same rank as ``x``. + + Raises: + TypeError: ``x`` must be a Tensor and the data type of ``x`` must to be one of float16, float32, float64, int32, int64, uint8. + TypeError: ``index`` must be a Tensor and the data type of ``index`` must be int32 or int64. + TypeError: ``axis`` must be a Tensor or int and the data type of ``index`` must be int32 or int64 when it's a Tensor. Examples: From 77a1d7e91dc902154fe997285dbd3d9cae0bd872 Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Fri, 21 Aug 2020 06:47:10 +0000 Subject: [PATCH 16/17] refine the code test=develop --- .../fluid/tests/unittests/test_gather_op.py | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_gather_op.py b/python/paddle/fluid/tests/unittests/test_gather_op.py index e60da20e88ee9..1f6e522d2668b 100644 --- a/python/paddle/fluid/tests/unittests/test_gather_op.py +++ b/python/paddle/fluid/tests/unittests/test_gather_op.py @@ -246,5 +246,57 @@ def test_out12(self): paddle.enable_static() +class TestGathertError(unittest.TestCase): + def test_error1(self): + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + + shape = [8, 9, 6] + x = paddle.data(shape=shape, dtype='int8', name='x') + axis = paddle.data(shape=[1], dtype='float32', name='axis') + index = paddle.data(shape=shape, dtype='int32', name='index') + index_float = paddle.data( + shape=shape, dtype='float32', name='index_float') + + def test_x_type(): + paddle.gather(x, index) + + self.assertRaises(TypeError, test_x_type) + + def test_index_type(): + paddle.gather(x, index_float) + + self.assertRaises(TypeError, test_index_type) + + def test_axis_dtype(): + paddle.gather(x, index, axis=1.11) + + self.assertRaises(TypeError, test_axis_dtype) + + def test_axis_dtype(): + paddle.gather(x, index, axis=axis) + + self.assertRaises(TypeError, test_axis_dtype) + + def test_error2(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + + shape = [8, 9, 6] + x = fluid.data(shape=shape, dtype='int8', name='x') + index = fluid.data(shape=shape, dtype='int32', name='mask') + index_float = fluid.data( + shape=shape, dtype='float32', name='index_float') + + def test_x_type(): + paddle.fluid.layers.gather(x, index) + + self.assertRaises(TypeError, test_x_type) + + def test_index_type(): + paddle.fluid.layers.gather(x, index_float) + + self.assertRaises(TypeError, test_index_type) + + if __name__ == "__main__": unittest.main() From 3092e2458dc1906fe00f46d321f3b5624d290291 Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Fri, 21 Aug 2020 12:27:33 +0000 Subject: [PATCH 17/17] refine the doc test=develop --- python/paddle/fluid/layers/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 4d12429843818..688a62f3b4794 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -8179,7 +8179,7 @@ def gather(input, index, overwrite=True): [5, 6]] Args: - input (Temspr): The source input tensor with rank>=1. Supported data type is + input (Tensor): The source input tensor with rank>=1. Supported data type is int32, int64, float32, float64 and uint8 (only for CPU), float16 (only for GPU). index (Tensor): The index input tensor with rank=1. Data type is int32 or int64.