diff --git a/paddle/fluid/operators/gather.cu.h b/paddle/fluid/operators/gather.cu.h index f59d46ec79bd0..c4bdd9e439c54 100644 --- a/paddle/fluid/operators/gather.cu.h +++ b/paddle/fluid/operators/gather.cu.h @@ -18,6 +18,7 @@ limitations under the License. */ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/memory/malloc.h" +#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/place.h" @@ -158,5 +159,133 @@ void GPUGatherNd(const framework::ExecutionContext& context, end_size); } +template +__global__ void GatherGPUKernel(const T* input, const U* index, T* out, + int outer_dim_size, int inner_dim_size, + int out_index_dim_size, + int input_index_dim_size, int size) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + for (; idx < size; idx += blockDim.x * gridDim.x) { + int inner_dim_index = idx / (outer_dim_size * out_index_dim_size); + int next_idx = idx % (outer_dim_size * out_index_dim_size); + int index_dim_index = next_idx / (outer_dim_size); + int out_dim_index = next_idx % outer_dim_size; + int input_index = + inner_dim_index * (outer_dim_size * input_index_dim_size) + + index[index_dim_index] * outer_dim_size + out_dim_index; + out[idx] = input[input_index]; + } +} + +template +__global__ void GatherGradGPUKernel(const T* input, const U* index, T* out, + int outer_dim_size, int inner_dim_size, + int input_index_dim_size, + int out_index_dim_size, int size) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + for (; idx < size; idx += blockDim.x * gridDim.x) { + int inner_dim_index = idx / (outer_dim_size * input_index_dim_size); + int next_idx = idx % (outer_dim_size * input_index_dim_size); + int index_dim_index = next_idx / (outer_dim_size); + int out_dim_index = next_idx % outer_dim_size; + int out_index = inner_dim_index * (outer_dim_size * out_index_dim_size) + + index[index_dim_index] * outer_dim_size + out_dim_index; + paddle::platform::CudaAtomicAdd(out + out_index, *(input + idx)); + } +} + +template +void GatherV2CUDAFunction(const Tensor* input, const Tensor* index, + const Tensor* axis, Tensor* out, + const paddle::platform::Place& place, + const framework::ExecutionContext& ctx) { + int axis_size = axis->numel(); + int index_size = index->numel(); + int input_size = input->numel(); + auto input_dim = input->dims(); + auto* input_data = input->data(); + auto* index_data = index->data(); + + if (input->numel() == 0) return; + PADDLE_ENFORCE_EQ(axis_size, 1, + platform::errors::InvalidArgument( + "Axis size should be 1, but received %d", axis_size)); + Tensor cpu_axis; + framework::TensorCopy(*axis, platform::CPUPlace(), &cpu_axis); + int axis_index = cpu_axis.data()[0]; + int index_dim_size = input_dim[axis_index]; + + int inner_dim_size = 1; + int outer_dim_size = 1; + std::vector out_dim_vec; + + for (int i = 0; i < axis_index; i++) { + inner_dim_size *= input_dim[i]; + out_dim_vec.push_back(input_dim[i]); + } + out_dim_vec.push_back(index_size); + for (int i = axis_index + 1; i < input_dim.size(); i++) { + outer_dim_size *= input_dim[i]; + out_dim_vec.push_back(input_dim[i]); + } + auto out_dim = framework::make_ddim(out_dim_vec); + + out->Resize(out_dim); + auto* out_data = out->mutable_data(place); + int out_size = out->numel(); + + int threads = 512; + int grid = (out_size + threads - 1) / threads; + auto stream = ctx.cuda_device_context().stream(); + GatherGPUKernel<<>>( + input_data, index_data, out_data, outer_dim_size, inner_dim_size, + index_size, index_dim_size, out_size); +} + +template +void GatherV2GradCUDAFunction(const Tensor* input, const Tensor* index, + const Tensor* axis, Tensor* out, + const paddle::platform::Place& place, + const framework::ExecutionContext& ctx) { + auto* index_data = index->data(); + + int axis_size = axis->numel(); + int index_size = index->numel(); + int input_size = input->numel(); + auto input_dim = input->dims(); + auto* input_data = input->data(); + + if (input->numel() == 0) return; + PADDLE_ENFORCE_EQ(axis_size, 1, + platform::errors::InvalidArgument( + "Axis size should be 1, but received %d", axis_size)); + Tensor cpu_axis; + framework::TensorCopy(*axis, platform::CPUPlace(), &cpu_axis); + int axis_index = cpu_axis.data()[0]; + int input_index_dim_size = input_dim[axis_index]; + + int inner_dim_size = 1; + int outer_dim_size = 1; + + for (int i = 0; i < axis_index; i++) { + inner_dim_size *= input_dim[i]; + } + for (int i = axis_index + 1; i < input_dim.size(); i++) { + outer_dim_size *= input_dim[i]; + } + + auto* out_data = out->mutable_data(place); + auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place); + auto out_dim = out->dims(); + int out_index_dim_size = out_dim[axis_index]; + operators::math::set_constant(*dev_ctx, out, 0.0); + + int threads = 512; + int grid = (input_size + threads - 1) / threads; + auto stream = ctx.cuda_device_context().stream(); + GatherGradGPUKernel<<>>( + input_data, index_data, out_data, outer_dim_size, inner_dim_size, + input_index_dim_size, out_index_dim_size, input_size); +} } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/gather.h b/paddle/fluid/operators/gather.h index f5a7bffe47453..c12a3b8adc978 100644 --- a/paddle/fluid/operators/gather.h +++ b/paddle/fluid/operators/gather.h @@ -15,10 +15,12 @@ limitations under the License. */ #pragma once #include #include +#include #include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/place.h" namespace paddle { @@ -124,5 +126,110 @@ void CPUGatherNd(const platform::DeviceContext& ctx, const Tensor& input, } } +template +void GatherV2Function(const Tensor* input, const Tensor* index, + const Tensor* axis, Tensor* out, + const paddle::platform::Place& place) { + auto* axis_data = axis->data(); + auto* index_data = index->data(); + + int axis_size = axis->numel(); + int index_size = index->numel(); + int input_size = input->numel(); + auto input_dim = input->dims(); + auto* input_data = input->data(); + + if (input->numel() == 0) return; + PADDLE_ENFORCE_EQ(axis_size, 1, + platform::errors::InvalidArgument( + "Axis size should be 1, but received %d", axis_size)); + int axis_index = axis_data[0]; + + int input_index_dim_size = input_dim[axis_index]; + for (int i = 0; i < index_size; i++) { + PADDLE_ENFORCE_LT(index_data[i], input_index_dim_size, + platform::errors::InvalidArgument( + "The element of Index must be less than the size of " + "input dim size of axis which is %d, but received " + "index element which is %d in the %d index.", + input_index_dim_size, index_data[i], i)); + } + + int inner_dim_size = 1; + int outer_dim_size = 1; + std::vector out_dim_vec; + + for (int i = 0; i < axis_index; i++) { + inner_dim_size *= input_dim[i]; + out_dim_vec.push_back(input_dim[i]); + } + out_dim_vec.push_back(index_size); + for (int i = axis_index + 1; i < input_dim.size(); i++) { + outer_dim_size *= input_dim[i]; + out_dim_vec.push_back(input_dim[i]); + } + auto out_dim = framework::make_ddim(out_dim_vec); + + out->Resize(out_dim); + auto* out_data = out->mutable_data(place); + + int out_index = 0; + for (int i = 0; i < inner_dim_size; i++) { + for (int j = 0; j < index_size; j++) { + for (int k = 0; k < outer_dim_size; k++) { + int index = k + index_data[j] * outer_dim_size + + (i * input_size / inner_dim_size); + out_data[out_index] = input_data[index]; + out_index++; + } + } + } +} + +template +void GatherV2GradFunction(const Tensor* input, const Tensor* index, + const Tensor* axis, Tensor* out, + const paddle::platform::Place& place) { + auto* axis_data = axis->data(); + auto* index_data = index->data(); + + int axis_size = axis->numel(); + auto input_dim = input->dims(); + auto* input_data = input->data(); + + if (input->numel() == 0) return; + PADDLE_ENFORCE_EQ(axis_size, 1, + platform::errors::InvalidArgument( + "Axis size should be 1, but received %d", axis_size)); + int axis_index = axis_data[0]; + int input_index_dim_size = input_dim[axis_index]; + + int inner_dim_size = 1; + int outer_dim_size = 1; + + for (int i = 0; i < axis_index; i++) { + inner_dim_size *= input_dim[i]; + } + for (int i = axis_index + 1; i < input_dim.size(); i++) { + outer_dim_size *= input_dim[i]; + } + + auto* out_data = out->mutable_data(place); + auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place); + auto out_dim = out->dims(); + int out_index_dim_size = out_dim[axis_index]; + operators::math::set_constant(*dev_ctx, out, 0.0); + + for (int i = 0; i < inner_dim_size; i++) { + for (int j = 0; j < input_index_dim_size; j++) { + for (int k = 0; k < outer_dim_size; k++) { + int index = k + index_data[j] * outer_dim_size + + i * outer_dim_size * out_index_dim_size; + out_data[index] += input_data[j * outer_dim_size + k]; + } + } + } +} + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/gather_op.cc b/paddle/fluid/operators/gather_op.cc index 6a3abaa600281..8a3450d1df97a 100644 --- a/paddle/fluid/operators/gather_op.cc +++ b/paddle/fluid/operators/gather_op.cc @@ -78,6 +78,9 @@ class GatherOpMaker : public framework::OpProtoAndCheckerMaker { void Make() override { AddInput("X", "The source input of gather op"); AddInput("Index", "The index input of gather op"); + AddInput("Axis", + "The Tensor which contains the axis that we do gather operation.") + .AsDispensable(); AddOutput("Out", "The output of gather op"); AddAttr( "overwrite", @@ -120,6 +123,8 @@ class GatherGradOpMaker : public framework::SingleGradOpMaker { void Apply(GradOpPtr op) const override { op->SetType("gather_grad"); op->SetInput("Index", this->Input("Index")); + op->SetInput("Axis", this->Input("Axis")); + op->SetInput("X", this->Input("X")); op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); diff --git a/paddle/fluid/operators/gather_op.cu b/paddle/fluid/operators/gather_op.cu index 5bef547c0542b..37fbfb21f60a0 100644 --- a/paddle/fluid/operators/gather_op.cu +++ b/paddle/fluid/operators/gather_op.cu @@ -31,6 +31,33 @@ class GatherOpCUDAKernel : public framework::OpKernel { auto *index = ctx.Input("Index"); auto *output = ctx.Output("Out"); + if (ctx.HasInput("Axis")) { + const Tensor *axis = ctx.Input("Axis"); + const auto &index_type = index->type(); + const auto &axis_type = axis->type(); + auto place = ctx.GetPlace(); + if (index_type == framework::proto::VarType::INT32 && + axis_type == framework::proto::VarType::INT32) { + GatherV2CUDAFunction(x, index, axis, output, place, + ctx); + } + if (index_type == framework::proto::VarType::INT32 && + axis_type == framework::proto::VarType::INT64) { + GatherV2CUDAFunction(x, index, axis, output, place, + ctx); + } + if (index_type == framework::proto::VarType::INT64 && + axis_type == framework::proto::VarType::INT32) { + GatherV2CUDAFunction(x, index, axis, output, place, + ctx); + } + if (index_type == framework::proto::VarType::INT64 && + axis_type == framework::proto::VarType::INT64) { + GatherV2CUDAFunction(x, index, axis, output, place, + ctx); + } + return; + } output->mutable_data(ctx.GetPlace()); if (x->numel() == 0) return; const auto &index_type = index->type(); @@ -64,6 +91,34 @@ class GatherGradOpCUDAKernel : public framework::OpKernel { auto *dX = ctx.Output(framework::GradVarName("X")); auto *dO = ctx.Input(framework::GradVarName("Out")); + if (ctx.HasInput("Axis")) { + const Tensor *axis = ctx.Input("Axis"); + const auto &index_type = index->type(); + const auto &axis_type = axis->type(); + auto place = ctx.GetPlace(); + if (index_type == framework::proto::VarType::INT32 && + axis_type == framework::proto::VarType::INT32) { + GatherV2GradCUDAFunction(dO, index, axis, dX, + place, ctx); + } + if (index_type == framework::proto::VarType::INT32 && + axis_type == framework::proto::VarType::INT64) { + GatherV2GradCUDAFunction(dO, index, axis, dX, + place, ctx); + } + if (index_type == framework::proto::VarType::INT64 && + axis_type == framework::proto::VarType::INT32) { + GatherV2GradCUDAFunction(dO, index, axis, dX, + place, ctx); + } + if (index_type == framework::proto::VarType::INT64 && + axis_type == framework::proto::VarType::INT64) { + GatherV2GradCUDAFunction(dO, index, axis, dX, + place, ctx); + } + return; + } + dX->mutable_data(ctx.GetPlace()); auto dxt = framework::EigenVector::Flatten(*dX); auto &place = *ctx.template device_context() diff --git a/paddle/fluid/operators/gather_op.h b/paddle/fluid/operators/gather_op.h index e4ce13ca8fc0b..8ec0d6ce0b69c 100644 --- a/paddle/fluid/operators/gather_op.h +++ b/paddle/fluid/operators/gather_op.h @@ -35,6 +35,30 @@ class GatherOpKernel : public framework::OpKernel { auto *index = ctx.Input("Index"); auto *output = ctx.Output("Out"); + if (ctx.HasInput("Axis")) { + const Tensor *axis = ctx.Input("Axis"); + const auto &index_type = index->type(); + const auto &axis_type = axis->type(); + auto place = ctx.GetPlace(); + if (index_type == framework::proto::VarType::INT32 && + axis_type == framework::proto::VarType::INT32) { + GatherV2Function(x, index, axis, output, place); + } + if (index_type == framework::proto::VarType::INT32 && + axis_type == framework::proto::VarType::INT64) { + GatherV2Function(x, index, axis, output, place); + } + if (index_type == framework::proto::VarType::INT64 && + axis_type == framework::proto::VarType::INT32) { + GatherV2Function(x, index, axis, output, place); + } + if (index_type == framework::proto::VarType::INT64 && + axis_type == framework::proto::VarType::INT64) { + GatherV2Function(x, index, axis, output, place); + } + return; + } + output->mutable_data(ctx.GetPlace()); if (x->numel() == 0) return; @@ -70,6 +94,30 @@ class GatherGradientOpKernel : public framework::OpKernel { auto *dX = ctx.Output(framework::GradVarName("X")); auto *dO = ctx.Input(framework::GradVarName("Out")); + if (ctx.HasInput("Axis")) { + const Tensor *axis = ctx.Input("Axis"); + const auto &index_type = index->type(); + const auto &axis_type = axis->type(); + auto place = ctx.GetPlace(); + if (index_type == framework::proto::VarType::INT32 && + axis_type == framework::proto::VarType::INT32) { + GatherV2GradFunction(dO, index, axis, dX, place); + } + if (index_type == framework::proto::VarType::INT32 && + axis_type == framework::proto::VarType::INT64) { + GatherV2GradFunction(dO, index, axis, dX, place); + } + if (index_type == framework::proto::VarType::INT64 && + axis_type == framework::proto::VarType::INT32) { + GatherV2GradFunction(dO, index, axis, dX, place); + } + if (index_type == framework::proto::VarType::INT64 && + axis_type == framework::proto::VarType::INT64) { + GatherV2GradFunction(dO, index, axis, dX, place); + } + return; + } + dX->mutable_data(ctx.GetPlace()); auto dxt = framework::EigenVector::Flatten(*dX); auto &place = *ctx.template device_context() diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index a770275bf08c0..70f6ed7fa2d5d 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -41,6 +41,7 @@ std::map> op_ins_map = { {"fake_quantize_dequantize_moving_average_abs_max", {"X", "InScale", "InAccum", "InState"}}, {"nll_loss", {"X", "Label", "Weight"}}, + {"gather", {"X", "Index", "Axis"}}, }; // NOTE(zhiqiu): Like op_ins_map. diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 1f1439aa7cd60..b244781ac82b9 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -8179,19 +8179,21 @@ def gather(input, index, overwrite=True): [5, 6]] Args: - input (Variable): The source input tensor with rank>=1. Supported data type is + input (Tensor): The source input tensor with rank>=1. Supported data type is int32, int64, float32, float64 and uint8 (only for CPU), float16 (only for GPU). - index (Variable): The index input tensor with rank=1. Data type is int32 or int64. + index (Tensor): The index input tensor with rank=1. Data type is int32 or int64. overwrite (bool, optional): The mode that updating the grad when has same index. If True, use the overwrite mode to update the grad of the same index, if False, use the accumulate mode to update the grad of the same index. Default value is True. - - Returns: - output (Variable): The output is a tensor with the same rank as input. + output (Tensor): The output is a tensor with the same rank as input. + + Raises: + TypeError: ``x`` must be a Tensor and the data type of ``x`` must to be one of float16, float32, float64, int32, int64, uint8. + TypeError: ``index`` must be a Tensor and the data type of ``index`` must be int32 or int64. Examples: @@ -8202,6 +8204,13 @@ def gather(input, index, overwrite=True): index = fluid.data(name='index', shape=[-1, 1], dtype='int32') output = fluid.layers.gather(x, index) """ + if in_dygraph_mode(): + return core.ops.gather(input, index, None) + + check_variable_and_dtype( + input, 'x', + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint8'], 'gather') + check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'gather') helper = LayerHelper('gather', **locals()) dtype = helper.input_dtype() out = helper.create_variable_for_type_inference(dtype) diff --git a/python/paddle/fluid/tests/unittests/test_gather_op.py b/python/paddle/fluid/tests/unittests/test_gather_op.py index f8763e731eeed..1f6e522d2668b 100644 --- a/python/paddle/fluid/tests/unittests/test_gather_op.py +++ b/python/paddle/fluid/tests/unittests/test_gather_op.py @@ -21,6 +21,13 @@ import paddle.fluid as fluid +def gather_numpy(x, index, axis): + x_transpose = np.swapaxes(x, 0, axis) + tmp_gather = x_transpose[index, ...] + gather = np.swapaxes(tmp_gather, 0, axis) + return gather + + class TestGatherOp(OpTest): def setUp(self): self.op_type = "gather" @@ -108,12 +115,80 @@ def config(self): self.index_type = "int32" +class TestGatherOp1(OpTest): + def setUp(self): + self.op_type = "gather" + self.config() + xnp = np.random.random(self.x_shape).astype(self.x_type) + axis_np = np.array(self.axis).astype(self.index_type) + index_np = np.array(self.index).astype(self.index_type) + out = gather_numpy(xnp, index_np, axis_np[0]) + self.inputs = {'X': xnp, 'Index': index_np, 'Axis': axis_np} + self.outputs = {'Out': out} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + def config(self): + """ + For multi-dimension input + """ + self.x_shape = (3, 88, 3) + self.x_type = "float64" + self.index = [1, 3, 5] + self.index_type = "int32" + self.axis = [1] + self.axis_type = "int32" + + +class TestGatherOp2(TestGatherOp1): + def config(self): + """ + For multi-dimension input + """ + self.x_shape = (10, 88, 10) + self.x_type = "float64" + self.index = [1, 3, 5] + self.index_type = "int64" + self.axis = [0] + self.axis_type = "int32" + + +class TestGatherOp3(TestGatherOp1): + def config(self): + """ + For multi-dimension input + """ + self.x_shape = (10, 88, 10) + self.x_type = "float64" + self.index = [1, 3, 5] + self.index_type = "int64" + self.axis = [2] + self.axis_type = "int32" + + +class TestGatherOp4(TestGatherOp1): + def config(self): + """ + For multi-dimension input + """ + self.x_shape = (3, 100, 10) + self.x_type = "float64" + self.index = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] + self.index_type = "int64" + self.axis = [0] + self.axis_type = "int32" + + class API_TestGather(unittest.TestCase): - def test_out(self): + def test_out1(self): with fluid.program_guard(fluid.Program(), fluid.Program()): data1 = fluid.layers.data('data1', shape=[-1, 2], dtype='float64') - index = fluid.layers.data('index', shape=[-1, 1], dtype='float64') - out = paddle.gather(data1, index) + index = fluid.layers.data('index', shape=[-1, 1], dtype='int32') + out = paddle.fluid.layers.gather(data1, index) place = fluid.CPUPlace() exe = fluid.Executor(place) input = np.array([[1, 2], [3, 4], [5, 6]]) @@ -124,18 +199,103 @@ def test_out(self): expected_output = np.array([[3, 4], [5, 6]]) self.assertTrue(np.allclose(result, expected_output)) + def test_out2(self): + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + x = paddle.data('x', shape=[-1, 2], dtype='float64') + index = paddle.data('index', shape=[-1, 1], dtype='int32') + axis = paddle.data('axis', shape=[1], dtype='int32') + out = paddle.gather(x, index, axis) + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + x_np = np.array([[1, 2], [3, 4], [5, 6]]).astype('float64') + index_np = np.array([1, 1]).astype('int32') + axis_np = np.array([1]).astype('int32') + result, = exe.run( + feed={"x": x_np, + "index": index_np, + 'axis': axis_np}, + fetch_list=[out]) + expected_output = gather_numpy(x_np, index_np, axis_np) + self.assertTrue(np.allclose(result, expected_output)) + class API_TestDygraphGather(unittest.TestCase): - def test_out(self): - with fluid.dygraph.guard(): - input_1 = np.array([[1, 2], [3, 4], [5, 6]]) - index_1 = np.array([1, 2]) - input = fluid.dygraph.to_variable(input_1) - index = fluid.dygraph.to_variable(index_1) - output = paddle.fluid.layers.gather(input, index) - output_np = output.numpy() - expected_output = np.array([[3, 4], [5, 6]]) + def test_out1(self): + paddle.disable_static() + input_1 = np.array([[1, 2], [3, 4], [5, 6]]) + index_1 = np.array([1, 2]) + input = paddle.to_tensor(input_1) + index = paddle.to_tensor(index_1) + output = paddle.fluid.layers.gather(input, index) + output_np = output.numpy() + expected_output = np.array([[3, 4], [5, 6]]) + self.assertTrue(np.allclose(output_np, expected_output)) + paddle.enable_static() + + def test_out12(self): + paddle.disable_static() + input_1 = np.array([[1, 2], [3, 4], [5, 6]]) + index_1 = np.array([1, 2]) + x = paddle.to_tensor(input_1) + index = paddle.to_tensor(index_1) + output = paddle.gather(x, index, axis=0) + output_np = output.numpy() + expected_output = gather_numpy(input_1, index_1, axis=0) self.assertTrue(np.allclose(output_np, expected_output)) + paddle.enable_static() + + +class TestGathertError(unittest.TestCase): + def test_error1(self): + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + + shape = [8, 9, 6] + x = paddle.data(shape=shape, dtype='int8', name='x') + axis = paddle.data(shape=[1], dtype='float32', name='axis') + index = paddle.data(shape=shape, dtype='int32', name='index') + index_float = paddle.data( + shape=shape, dtype='float32', name='index_float') + + def test_x_type(): + paddle.gather(x, index) + + self.assertRaises(TypeError, test_x_type) + + def test_index_type(): + paddle.gather(x, index_float) + + self.assertRaises(TypeError, test_index_type) + + def test_axis_dtype(): + paddle.gather(x, index, axis=1.11) + + self.assertRaises(TypeError, test_axis_dtype) + + def test_axis_dtype(): + paddle.gather(x, index, axis=axis) + + self.assertRaises(TypeError, test_axis_dtype) + + def test_error2(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + + shape = [8, 9, 6] + x = fluid.data(shape=shape, dtype='int8', name='x') + index = fluid.data(shape=shape, dtype='int32', name='mask') + index_float = fluid.data( + shape=shape, dtype='float32', name='index_float') + + def test_x_type(): + paddle.fluid.layers.gather(x, index) + + self.assertRaises(TypeError, test_x_type) + + def test_index_type(): + paddle.fluid.layers.gather(x, index_float) + + self.assertRaises(TypeError, test_index_type) if __name__ == "__main__": diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 88342c89da75c..69fe5d79c541c 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -659,10 +659,8 @@ def unsqueeze(x, axis, name=None): return layers.unsqueeze(x, axis, name) -def gather(input, index, overwrite=True): +def gather(x, index, axis=None, name=None): """ - :alias_main: paddle.gather - :alias: paddle.gather,paddle.tensor.gather,paddle.tensor.manipulation.gather **Gather Layer** @@ -679,30 +677,33 @@ def gather(input, index, overwrite=True): Given: - X = [[1, 2], + x = [[1, 2], [3, 4], [5, 6]] - Index = [1, 2] + index = [1, 2] + axis=[0] Then: - Out = [[3, 4], + out = [[3, 4], [5, 6]] Args: - input (Variable): The source input tensor with rank>=1. Supported data type is + x (Tensor): The source input tensor with rank>=1. Supported data type is int32, int64, float32, float64 and uint8 (only for CPU), float16 (only for GPU). - index (Variable): The index input tensor with rank=1. Data type is int32 or int64. - overwrite (bool, optional): The mode that updating the grad when has same index. - If True, use the overwrite mode to update the grad of the same index, - if False, use the accumulate mode to update the grad of the same index. - Default value is True. - - + index (Tensor): The index input tensor with rank=1. Data type is int32 or int64. + axis (Tensor|int, optional): The axis of input to be gathered, it's can be int or a Tensor with data type is int32 or int64. Default: if None, the axis is 0. + name (str, optional): The default value is None. Normally there is no need for user to set this property. + For more information, please refer to :ref:`api_guide_Name` . Returns: - output (Variable): The output is a tensor with the same rank as input. + output (Tensor): The output is a tensor with the same rank as ``x``. + + Raises: + TypeError: ``x`` must be a Tensor and the data type of ``x`` must to be one of float16, float32, float64, int32, int64, uint8. + TypeError: ``index`` must be a Tensor and the data type of ``index`` must be int32 or int64. + TypeError: ``axis`` must be a Tensor or int and the data type of ``index`` must be int32 or int64 when it's a Tensor. Examples: @@ -710,26 +711,41 @@ def gather(input, index, overwrite=True): import numpy as np import paddle - import paddle.fluid as fluid - - with fluid.dygraph.guard(): - input_1 = np.array([[1,2],[3,4],[5,6]]) - index_1 = np.array([0,1]) - input = fluid.dygraph.to_variable(input_1) - index = fluid.dygraph.to_variable(index_1) - output = paddle.gather(input, index) - # expected output: [[1,2],[3,4]] + paddle.disable_static() + input_1 = np.array([[1,2],[3,4],[5,6]]) + index_1 = np.array([0,1]) + input = fluid.to_tensor(input_1) + index = fluid.to_tensor(index_1) + output = paddle.gather(input, index, axis=0) + # expected output: [[1,2],[3,4]] """ + if axis is None: + axis = 0 + axis_tensor = axis + if not isinstance(axis, Variable): + axis_tensor = fill_constant(shape=[1], dtype='int64', value=axis) + if in_dygraph_mode(): + return core.ops.gather(x, index, axis_tensor) + + check_variable_and_dtype( + x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'uint8'], + 'gather') + check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'gather') + if isinstance(axis, Variable): + check_variable_and_dtype(axis, 'axis', ['int32', 'int64'], 'gather') + else: + check_type(axis, 'axis', (int), 'gather') + helper = LayerHelper('gather', **locals()) dtype = helper.input_dtype() out = helper.create_variable_for_type_inference(dtype) helper.append_op( type="gather", - inputs={"X": input, - "Index": index}, - outputs={"Out": out}, - attrs={'overwrite': overwrite}) + inputs={"X": x, + "Index": index, + "Axis": axis_tensor}, + outputs={"Out": out}) return out