Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cherry pick seq2seq api from #19820 #20555

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
58 changes: 46 additions & 12 deletions paddle/fluid/API.spec

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions paddle/fluid/operators/assign_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,12 @@ REGISTER_OPERATOR(assign, ops::AssignOp, ops::AssignGradMaker,
ops::AssignOpProtoMaker, ops::AssignOpInplaceInferer);
REGISTER_OP_CPU_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double,
ops::AssignKernel, int, ops::AssignKernel,
int64_t, ops::AssignKernel);
int64_t, ops::AssignKernel, bool,
ops::AssignKernel);

#ifdef PADDLE_WITH_CUDA
REGISTER_OP_CUDA_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double,
ops::AssignKernel, int, ops::AssignKernel,
int64_t, ops::AssignKernel);
int64_t, ops::AssignKernel, bool,
ops::AssignKernel);
#endif
5 changes: 4 additions & 1 deletion paddle/fluid/operators/expand_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,11 @@ REGISTER_OP_CPU_KERNEL(
expand, ops::ExpandKernel<paddle::platform::CPUDeviceContext, float>,
ops::ExpandKernel<paddle::platform::CPUDeviceContext, double>,
ops::ExpandKernel<paddle::platform::CPUDeviceContext, int>,
ops::ExpandKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ExpandKernel<paddle::platform::CPUDeviceContext, bool>);
REGISTER_OP_CPU_KERNEL(
expand_grad,
ops::ExpandGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ExpandGradKernel<paddle::platform::CPUDeviceContext, double>);
ops::ExpandGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ExpandGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ExpandGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
5 changes: 4 additions & 1 deletion paddle/fluid/operators/expand_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ REGISTER_OP_CUDA_KERNEL(
expand, ops::ExpandKernel<paddle::platform::CUDADeviceContext, float>,
ops::ExpandKernel<paddle::platform::CUDADeviceContext, double>,
ops::ExpandKernel<paddle::platform::CUDADeviceContext, int>,
ops::ExpandKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ExpandKernel<paddle::platform::CUDADeviceContext, bool>);
REGISTER_OP_CUDA_KERNEL(
expand_grad,
ops::ExpandGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ExpandGradKernel<paddle::platform::CUDADeviceContext, double>);
ops::ExpandGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ExpandGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ExpandGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
9 changes: 8 additions & 1 deletion paddle/fluid/operators/fill_constant_batch_size_like_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ class FillConstantBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
.SetDefault(framework::proto::VarType::FP32);
AddAttr<float>("value", "default 0. The value to be filled")
.SetDefault(0.0f);
AddAttr<bool>("force_cpu",
"(bool, default false) Force fill output variable to cpu "
"memory. Otherwise, fill output variable to the running "
"device")
.SetDefault(false);
AddComment(R"DOC(
This function creates a tensor of specified *shape*, *dtype* and batch size,
and initializes this with a constant supplied in *value*. The batch size is
Expand Down Expand Up @@ -65,4 +70,6 @@ REGISTER_OP_CPU_KERNEL(
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUDeviceContext,
int>,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUDeviceContext,
int64_t>);
int64_t>,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUDeviceContext,
bool>);
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,6 @@ REGISTER_OP_CUDA_KERNEL(
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CUDADeviceContext,
int>,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CUDADeviceContext,
int64_t>);
int64_t>,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CUDADeviceContext,
bool>);
19 changes: 14 additions & 5 deletions paddle/fluid/operators/fill_constant_batch_size_like_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ template <typename DeviceContext, typename T>
class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto data_type =
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
auto value = ctx.Attr<float>("value");
auto force_cpu = ctx.Attr<bool>("force_cpu");

auto* out = ctx.Output<framework::Tensor>("Out");
auto* in = ctx.Input<framework::LoDTensor>("Input");
if (in->lod().size() && ctx.Attr<int>("input_dim_idx") == 0) {
Expand All @@ -32,12 +37,16 @@ class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel<T> {
odims[output_dim_idx] = static_cast<int>(in->lod().back().size()) - 1;
out->mutable_data<T>(odims, ctx.GetPlace());
}
out->mutable_data<T>(ctx.GetPlace());
auto value = ctx.Attr<float>("value");

math::SetConstant<DeviceContext, T> setter;
setter(ctx.template device_context<DeviceContext>(), out,
static_cast<T>(value));
if (force_cpu) {
out->mutable_data(platform::CPUPlace(), data_type);
} else {
out->mutable_data(ctx.GetPlace(), data_type);
}

platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Get(ctx.GetPlace());
math::set_constant(dev_ctx, out, value);
}
};

Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/operators/fill_constant_op.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ REGISTER_OP_CUDA_KERNEL(fill_constant, ops::FillConstantKernel<float>,
ops::FillConstantKernel<double>,
ops::FillConstantKernel<int64_t>,
ops::FillConstantKernel<int>,
ops::FillConstantKernel<bool>,
ops::FillConstantKernel<paddle::platform::float16>);
11 changes: 8 additions & 3 deletions paddle/fluid/operators/gather_nd_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,13 @@ class GatherNdOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.device_context());
auto* x = ctx.Input<Tensor>("X");
const auto& x_type = x->type();
return framework::OpKernelType(
x_type,
x_type == framework::proto::VarType::BOOL
? x->place() // to be consistent with compare and logical ops
: ctx.device_context().GetPlace());
}
};

Expand Down Expand Up @@ -173,7 +178,7 @@ REGISTER_OPERATOR(gather_nd_grad, ops::GatherNdGradOp,
REGISTER_OP_CPU_KERNEL(gather_nd, ops::GatherNdOpKernel<float>,
ops::GatherNdOpKernel<double>,
ops::GatherNdOpKernel<int64_t>,
ops::GatherNdOpKernel<int>,
ops::GatherNdOpKernel<int>, ops::GatherNdOpKernel<bool>,
ops::GatherNdOpKernel<uint8_t>);

REGISTER_OP_CPU_KERNEL(gather_nd_grad, ops::GatherNdGradOpKernel<float>,
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/operators/gather_nd_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ REGISTER_OP_CUDA_KERNEL(gather_nd, ops::GatherNdOpCUDAKernel<CUDA, float>,
ops::GatherNdOpCUDAKernel<CUDA, double>,
ops::GatherNdOpCUDAKernel<CUDA, int64_t>,
ops::GatherNdOpCUDAKernel<CUDA, int>,
ops::GatherNdOpCUDAKernel<CUDA, bool>,
ops::GatherNdOpCUDAKernel<CUDA, plat::float16>);

REGISTER_OP_CUDA_KERNEL(gather_nd_grad,
Expand Down
78 changes: 78 additions & 0 deletions paddle/fluid/operators/gather_tree_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/gather_tree_op.h"

namespace paddle {
namespace operators {

class GatherTreeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Ids"),
"Input(Ids) of GatherTreeOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Parents"),
"Input(Parents) of GatherTreeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of GatherTreeOp should not be null.");

auto ids_dims = ctx->GetInputDim("Ids");
auto parents_dims = ctx->GetInputDim("Parents");
PADDLE_ENFORCE(ids_dims == parents_dims,
"The shape of Input(Parents) must be same with the shape of "
"Input(Ids).");
ctx->SetOutputDim("Out", ids_dims);
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Ids")->type(),
ctx.device_context());
}
};

class GatherTreeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Ids",
"The Tensor with shape [length, batch_size, beam_size] containing "
"the selected ids of all time steps.");
AddInput("Parents",
"The Tensor has the same shape as Ids and contains the parents "
"corresponding to selected ids when searching among beams.");
AddOutput(
"Out",
"A Tensor with shape [length, batch_size, beam_size] containing the "
"full sequences. The sequences is collected by backtracing from the "
"last time step of Ids.");
AddComment(R"DOC(
GatherTree Operator.

Backtrace from the last time step and generate the full sequences by collecting beam search
selected ids.

)DOC");
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(gather_tree, ops::GatherTreeOp, ops::GatherTreeOpMaker);
REGISTER_OP_CPU_KERNEL(gather_tree, ops::GatherTreeOpKernel<int32_t>,
ops::GatherTreeOpKernel<int64_t>);
80 changes: 80 additions & 0 deletions paddle/fluid/operators/gather_tree_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/gather_tree_op.h"

namespace paddle {
namespace operators {

#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)

template <typename T>
__global__ void GatherTree(const T *ids_data, const T *parents_data,
T *out_data, const int64_t max_length,
const int64_t batch_size, const int64_t beam_size) {
CUDA_1D_KERNEL_LOOP(i, batch_size * beam_size) {
int batch = i / beam_size;
int beam = i % beam_size;
auto idx =
(max_length - 1) * batch_size * beam_size + batch * beam_size + beam;
out_data[idx] = ids_data[idx];
auto parent = parents_data[idx];
for (int step = max_length - 2; step >= 0; step--) {
idx = step * batch_size * beam_size + batch * beam_size;
out_data[idx + beam] = ids_data[idx + parent];
parent = parents_data[idx + parent];
}
}
}

template <typename T>
class GatherTreeOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *ids = ctx.Input<Tensor>("Ids");
auto *parents = ctx.Input<Tensor>("Parents");
auto *out = ctx.Output<Tensor>("Out");

const auto *ids_data = ids->data<T>();
const auto *parents_data = parents->data<T>();
auto *out_data = out->mutable_data<T>(ctx.GetPlace());

auto &ids_dims = ids->dims();
int64_t max_length = ids_dims[0];
int64_t batch_size = ids_dims[1];
int64_t beam_size = ids_dims[2];

auto &dev_ctx = ctx.cuda_device_context();

const int block = 512;
int max_threads =
std::min(static_cast<int64_t>(dev_ctx.GetMaxPhysicalThreadCount()),
batch_size * beam_size);
const int grid = std::max(max_threads / block, 1);
GatherTree<<<grid, block>>>(ids_data, parents_data, out_data, max_length,
batch_size, beam_size);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_CUDA_KERNEL(gather_tree, ops::GatherTreeOpCUDAKernel<int32_t>,
ops::GatherTreeOpCUDAKernel<int64_t>);
58 changes: 58 additions & 0 deletions paddle/fluid/operators/gather_tree_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T>
class GatherTreeOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *ids = ctx.Input<Tensor>("Ids");
auto *parents = ctx.Input<Tensor>("Parents");
auto *out = ctx.Output<Tensor>("Out");

const auto *ids_data = ids->data<T>();
const auto *parents_data = parents->data<T>();
auto *out_data = out->mutable_data<T>(ctx.GetPlace());

auto &ids_dims = ids->dims();
auto max_length = ids_dims[0];
auto batch_size = ids_dims[1];
auto beam_size = ids_dims[2];

for (int batch = 0; batch < batch_size; batch++) {
for (int beam = 0; beam < beam_size; beam++) {
auto idx = (max_length - 1) * batch_size * beam_size +
batch * beam_size + beam;
out_data[idx] = ids_data[idx];
auto parent = parents_data[idx];
for (int step = max_length - 2; step >= 0; step--) {
idx = step * batch_size * beam_size + batch * beam_size;
out_data[idx + beam] = ids_data[idx + parent];
parent = parents_data[idx + parent];
}
}
}
}
};

} // namespace operators
} // namespace paddle
4 changes: 3 additions & 1 deletion paddle/fluid/operators/reduce_ops/reduce_all_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

#include "paddle/fluid/operators/reduce_ops/reduce_all_op.h"

REGISTER_REDUCE_OP_WITHOUT_GRAD(reduce_all);
// kernel's device type is decided by input tensor place, to be consistent with
// compare and logical ops
REGISTER_REDUCE_OP_WITHOUT_GRAD(reduce_all, UseInputPlace);
REGISTER_OP_CPU_KERNEL(reduce_all,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
bool, ops::AllFunctor>);
4 changes: 3 additions & 1 deletion paddle/fluid/operators/reduce_ops/reduce_any_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

#include "paddle/fluid/operators/reduce_ops/reduce_any_op.h"

REGISTER_REDUCE_OP_WITHOUT_GRAD(reduce_any);
// kernel's device type is decided by input tensor place, to be consistent with
// compare and logical ops
REGISTER_REDUCE_OP_WITHOUT_GRAD(reduce_any, UseInputPlace);
REGISTER_OP_CPU_KERNEL(reduce_any,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
bool, ops::AnyFunctor>);