Support logical functors and refine tensor packing function #33089

Merged · 18 commits · Jun 4, 2021
72 changes: 64 additions & 8 deletions paddle/fluid/operators/controlflow/logical_op.cu
@@ -13,12 +13,68 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/controlflow/logical_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;

namespace paddle {
namespace operators {

#define LOGICAL_BINARY_FUNCTOR(func_name, op)          \
  template <typename T>                                 \
  struct func_name {                                    \
    using ELEMENT_TYPE = T;                             \
    HOSTDEVICE bool operator()(const T* args) const {   \
      return args[0] op args[1];                        \
    }                                                   \
  };

LOGICAL_BINARY_FUNCTOR(CudaOrFunctor, ||)
LOGICAL_BINARY_FUNCTOR(CudaAndFunctor, &&)
LOGICAL_BINARY_FUNCTOR(CudaXorFunctor, ^)
#undef LOGICAL_BINARY_FUNCTOR

template <typename T>
struct CudaNotFunctor {
  using ELEMENT_TYPE = T;
  HOSTDEVICE bool operator()(const T* args) const { return !args[0]; }
};

template <typename Functor>
class BinaryLogicalOpKernel<platform::CUDADeviceContext, Functor>
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using InT = typename Functor::ELEMENT_TYPE;
  using OutT = bool;
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto functor = Functor();
    std::vector<const framework::Tensor*> ins;
    std::vector<framework::Tensor*> outs;
    const auto& cuda_ctx =
        ctx.template device_context<platform::CUDADeviceContext>();
    int axis = PackTensorsIntoVector<OutT>(ctx, &ins, &outs);

    if (ins.size() == 1) {
      LaunchElementwiseCudaKernel<ElementwiseType::kUnary, InT, OutT>(
          cuda_ctx, ins, &outs, axis, functor);
    } else {
      LaunchElementwiseCudaKernel<ElementwiseType::kBinary, InT, OutT>(
          cuda_ctx, ins, &outs, axis, functor);
    }
  }
};

} // namespace operators
} // namespace paddle

#define REGISTER_LOGICAL_CUDA_KERNEL(op_name, func)                         \
  REGISTER_OP_CUDA_KERNEL(                                                  \
      op_name,                                                              \
      ops::BinaryLogicalOpKernel<plat::CUDADeviceContext, ops::func<bool>>);

REGISTER_LOGICAL_CUDA_KERNEL(logical_or, CudaOrFunctor)
REGISTER_LOGICAL_CUDA_KERNEL(logical_and, CudaAndFunctor)
REGISTER_LOGICAL_CUDA_KERNEL(logical_xor, CudaXorFunctor)
REGISTER_LOGICAL_CUDA_KERNEL(logical_not, CudaNotFunctor)
#undef REGISTER_LOGICAL_CUDA_KERNEL
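
As a reading aid for the functor convention above: LaunchElementwiseCudaKernel gathers one element from each input tensor into a small packed array and invokes the functor on it, which is why a single kernel template can serve both the unary and the binary logical ops. The sketch below is my own minimal host-side illustration of that calling convention, not code from this PR; `AndFunctor` and `ApplyBinary` are hypothetical names.

```cpp
// Host-side sketch of the packed-argument functor convention (illustration
// only). The launcher collects one element per input into `args` and calls
// the functor, mirroring the `const T* args` signature used above.
#include <cstdio>
#include <vector>

#ifndef HOSTDEVICE
#define HOSTDEVICE  // under NVCC this would expand to __host__ __device__
#endif

template <typename T>
struct AndFunctor {
  using ELEMENT_TYPE = T;
  HOSTDEVICE bool operator()(const T* args) const { return args[0] && args[1]; }
};

template <typename Functor, typename T>
std::vector<bool> ApplyBinary(const std::vector<T>& x, const std::vector<T>& y,
                              Functor func) {
  std::vector<bool> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    T args[2] = {x[i], y[i]};  // the "packed" per-element arguments
    out[i] = func(args);
  }
  return out;
}

int main() {
  std::vector<int> x{1, 1, 0}, y{1, 0, 0};
  for (bool v : ApplyBinary(x, y, AndFunctor<int>{})) std::printf("%d ", v);
  std::printf("\n");  // prints: 1 0 0
  return 0;
}
```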
1 change: 0 additions & 1 deletion paddle/fluid/operators/elementwise/elementwise_max_op.cu
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_max_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"

namespace ops = paddle::operators;

1 change: 0 additions & 1 deletion paddle/fluid/operators/elementwise/elementwise_min_op.cu
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_min_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"

namespace ops = paddle::operators;

43 changes: 2 additions & 41 deletions paddle/fluid/operators/elementwise/elementwise_mul_op.cu
@@ -36,52 +36,13 @@ class ElementwiseMulKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    framework::Tensor x_for_selectedrows;
    std::vector<const framework::Tensor*> ins;
    std::vector<framework::Tensor*> outs;
    const auto& cuda_ctx =
        ctx.template device_context<platform::CUDADeviceContext>();

    int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs, &x_for_selectedrows);
    LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
        cuda_ctx, ins, &outs, axis, CudaMulFunctor<T>());
  }
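
The 41 deleted lines in this file were the LoDTensor/SelectedRows dispatch, which this PR moves into the shared PackTensorsIntoVector helper (see elementwise_op_function.h below). That logic enforces a constraint: when X is a SelectedRows, Y must be a one-element tensor, and the kernel runs over the dense value() rows of X while rows and height carry over unchanged. The self-contained mock below illustrates this; `MockSelectedRows` and `ScalarMul` are my own illustrative names, not Paddle types.

```cpp
// Mock of elementwise mul on a SelectedRows-like type (illustration only;
// Paddle's SelectedRows has the same rows/height/value structure but a
// different API).
#include <cstdio>
#include <vector>

struct MockSelectedRows {
  std::vector<int> rows;     // indices of the rows actually stored
  std::vector<float> value;  // flattened dense values of those rows
  int height = 0;            // logical row count of the full tensor
};

MockSelectedRows ScalarMul(const MockSelectedRows& x, float y_scalar) {
  MockSelectedRows out = x;  // rows and height carry over (set_rows/set_height)
  for (auto& v : out.value) v *= y_scalar;  // elementwise kernel over value()
  return out;
}

int main() {
  MockSelectedRows x{{0, 3}, {1.f, 2.f, 3.f, 4.f}, 8};
  for (float v : ScalarMul(x, 2.f).value) std::printf("%.1f ", v);
  std::printf("\n");  // prints: 2.0 4.0 6.0 8.0
  return 0;
}
```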
@@ -509,15 +509,21 @@ void LaunchElementwiseCudaKernel(
    const platform::CUDADeviceContext &cuda_ctx,
    const std::vector<const framework::Tensor *> &ins,
    std::vector<framework::Tensor *> *outs, int axis, Functor func) {
  std::vector<int> dims_size;
  bool no_broadcast_flag = true;
  for (auto *in : ins) {
    no_broadcast_flag = ins[0]->dims() == in->dims();
    dims_size.emplace_back(in->dims().size());
  }

  if (no_broadcast_flag) {
    LaunchSameDimsElementwiseCudaKernel<ET, InT, OutT>(cuda_ctx, ins, outs,
                                                       func);
  } else {
    axis = axis == -1
               ? *std::max_element(dims_size.begin(), dims_size.end()) -
                     *std::min_element(dims_size.begin(), dims_size.end())
               : axis;
    LaunchBroadcastElementwiseCudaKernel<ET, InT, OutT>(cuda_ctx, ins, outs,
                                                        axis, func);
  }
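
The axis inference added above defaults a missing axis attribute to the difference between the largest and smallest input rank, matching Paddle's convention that the lower-rank tensor aligns with the trailing dimensions of the higher-rank one. Below is a self-contained sketch of just that rule, under the assumption that the dispatch behaves as shown above; `InferDefaultAxis` is an illustrative name, not a Paddle API.

```cpp
// Sketch of the default-axis rule from the broadcast branch above: only the
// max-rank minus min-rank logic is taken from the PR.
#include <algorithm>
#include <cassert>
#include <vector>

int InferDefaultAxis(const std::vector<int>& dims_size, int axis) {
  if (axis != -1) return axis;  // an explicit axis attribute wins
  return *std::max_element(dims_size.begin(), dims_size.end()) -
         *std::min_element(dims_size.begin(), dims_size.end());
}

int main() {
  // X with shape (2, 3, 4) has rank 3; Y with shape (3, 4) has rank 2, so Y
  // is broadcast starting at dimension 1 of X.
  assert(InferDefaultAxis({3, 2}, -1) == 1);
  assert(InferDefaultAxis({3, 2}, 0) == 0);
  return 0;
}
```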
55 changes: 48 additions & 7 deletions paddle/fluid/operators/elementwise/elementwise_op_function.h
@@ -61,25 +61,66 @@ namespace paddle {
namespace operators {

/*
 * Pack the input and output tensors into the vectors expected by
 * LaunchElementwiseCudaKernel, taking the class type of variable X into
 * account. X may be either a LoDTensor or a SelectedRows. When X is a
 * SelectedRows, a valid pointer x_for_selectedrows is expected to be
 * passed in from the op kernel; it provides the storage for the dense
 * LoDTensor that this function extracts from X.
 */
template <typename OutT>
int PackTensorsIntoVector(const framework::ExecutionContext &ctx,
                          std::vector<const framework::Tensor *> *ins,
                          std::vector<framework::Tensor *> *outs,
                          framework::Tensor *x_for_selectedrows = nullptr) {
  int axis = -1;
  auto x_var = ctx.InputVar("X");
  PADDLE_ENFORCE_NOT_NULL(
      x_var, platform::errors::InvalidArgument(
                 "Unable to get input Variable X, Variable name is %s.\n",
                 ctx.InputName("X")));
  auto *y = ctx.Input<framework::LoDTensor>("Y");
  framework::Tensor *z;

  if (x_var->IsType<framework::LoDTensor>()) {
    auto *x = ctx.Input<framework::LoDTensor>("X");
    z = ctx.Output<framework::LoDTensor>("Out");
    ins->emplace_back(x);
  } else if (x_var->IsType<framework::SelectedRows>()) {
    PADDLE_ENFORCE_EQ(y->dims().size() == 1 && y->dims()[0] == 1, true,
                      platform::errors::InvalidArgument(
                          "For elementwise_op, if X is Sparse, Y must be "
                          "scalar. But received the size of Y = %d.",
                          y->dims().size()));

Review comment (Contributor): A PADDLE_ENFORCE check is needed here to verify that x_ptr is not null.

Reply (Contributor, Author): Updated as suggested.


    PADDLE_ENFORCE_NOT_NULL(
        x_for_selectedrows,
        platform::errors::InvalidArgument(
            "The parameter x_for_selectedrows is expected to "
            "be valid when input variable X's class type is "
            "SelectedRows.\n"));
    auto &x_sele = x_var->Get<framework::SelectedRows>();
    auto out_sele = ctx.Output<framework::SelectedRows>("Out");
    *x_for_selectedrows = x_sele.value();
    out_sele->set_rows(x_sele.rows());
    out_sele->set_height(x_sele.height());
    out_sele->mutable_value()->Resize(x_sele.value().dims());
    out_sele->mutable_value()->mutable_data(ctx.GetPlace(),
                                            x_for_selectedrows->type());
    z = ctx.Output<framework::SelectedRows>("Out")->mutable_value();
    ins->emplace_back(x_for_selectedrows);
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "X's type[%s] is not supported by elementwise_op. X's type should be "
        "LoDTensor or SelectedRows.",
        framework::ToTypeName(x_var->Type())));
  }
  z->mutable_data<OutT>(ctx.GetPlace());
  outs->emplace_back(z);

  if (y != nullptr) {
    ins->emplace_back(y);
    axis = ctx.HasAttr("axis") ? ctx.Attr<int>("axis") : -1;
  }
  return axis;
}
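
Taken together with the launcher changes above, a binary CUDA op kernel now reduces to a pack-then-launch sequence. The fragment below restates the caller side as it appears in the ElementwiseMulKernel change earlier in this diff; it assumes the usual op-kernel scaffolding (a template parameter T and an ExecutionContext named ctx) and is not a standalone program.

```cpp
// Fragment only: `ctx` is the framework::ExecutionContext of the op kernel.
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
framework::Tensor x_for_selectedrows;  // storage used only if X is SelectedRows
const auto& cuda_ctx =
    ctx.template device_context<platform::CUDADeviceContext>();

// Packs X (LoDTensor or SelectedRows), Y, and Out, and returns the axis.
int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs, &x_for_selectedrows);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
    cuda_ctx, ins, &outs, axis, CudaMulFunctor<T>());
```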