Skip to content

Commit

Permalink
Merge pull request #5465 from reyoung/feature/compare_op_support_cpu
Browse files Browse the repository at this point in the history
CompareOp's kernel device type is decided by input tensor place
  • Loading branch information
reyoung committed Nov 8, 2017
2 parents 2a76b42 + 3187451 commit 7f22a6d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 14 deletions.
36 changes: 26 additions & 10 deletions paddle/operators/compare_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "paddle/operators/compare_op.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {
template <typename OpComment>
Expand Down Expand Up @@ -61,19 +62,34 @@ class CompareOpInferShape : public framework::InferShapeBase {
}
};

class CompareOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetKernelType(ctx);
// CompareOp kernel's device type is decided by input tensor place
kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
return kt;
}
};

} // namespace operators
} // namespace paddle

#define REGISTER_LOGICAL_OP(op_type, _equation) \
struct _##op_type##Comment { \
static char type[]; \
static char equation[]; \
}; \
char _##op_type##Comment::type[]{#op_type}; \
char _##op_type##Comment::equation[]{_equation}; \
REGISTER_OP_WITH_KERNEL( \
op_type, ::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \
::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \
#define REGISTER_LOGICAL_OP(op_type, _equation) \
struct _##op_type##Comment { \
static char type[]; \
static char equation[]; \
}; \
char _##op_type##Comment::type[]{#op_type}; \
char _##op_type##Comment::equation[]{_equation}; \
REGISTER_OPERATOR( \
op_type, ::paddle::operators::CompareOp, \
::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \
::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \
::paddle::framework::EmptyGradOpMaker);

REGISTER_LOGICAL_OP(less_than, "Out = X < Y");
Expand Down
4 changes: 0 additions & 4 deletions paddle/platform/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ struct Transform<platform::CPUPlace> {
template <typename InputIter, typename OutputIter, typename UnaryOperation>
void operator()(const DeviceContext& context, InputIter first, InputIter last,
OutputIter result, UnaryOperation op) {
auto place = context.GetPlace();
PADDLE_ENFORCE(is_cpu_place(place), "It must use CPU place.");
std::transform(first, last, result, op);
}

Expand All @@ -59,8 +57,6 @@ struct Transform<platform::CPUPlace> {
void operator()(const DeviceContext& context, InputIter1 first1,
InputIter1 last1, InputIter2 first2, OutputIter result,
BinaryOperation op) {
auto place = context.GetPlace();
PADDLE_ENFORCE(is_cpu_place(place), "It must use CPU place.");
std::transform(first1, last1, first2, result, op);
}
};
Expand Down

0 comments on commit 7f22a6d

Please sign in to comment.