Skip to content

Commit

Permalink
CompareOp's kernel device type is decided by input tensor place
Browse files Browse the repository at this point in the history
CompareOp can run on CPU even if other operators are running on GPU, since
operations like comparing control flags should be performed only on CPU
  • Loading branch information
reyoung committed Nov 8, 2017
1 parent c365c61 commit 3187451
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 14 deletions.
36 changes: 26 additions & 10 deletions paddle/operators/compare_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "paddle/operators/compare_op.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {
template <typename OpComment>
Expand Down Expand Up @@ -61,19 +62,34 @@ class CompareOpInferShape : public framework::InferShapeBase {
}
};

// Operator class used by the comparison ops registered in this file. It
// behaves like a plain OperatorWithKernel except for how the kernel's place
// is chosen.
class CompareOp : public framework::OperatorWithKernel {
 public:
  // Reuse the base-class constructors unchanged.
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // Takes the kernel type computed by the base class and overrides its place
  // with the place of input tensor "X", so the kernel's device follows that
  // tensor rather than the default choice.
  framework::OpKernelType GetKernelType(
      const framework::ExecutionContext &ctx) const override {
    framework::OpKernelType kernel_type =
        OperatorWithKernel::GetKernelType(ctx);
    const auto input_x = ctx.Input<framework::LoDTensor>("X");
    kernel_type.place_ = input_x->place();
    return kernel_type;
  }
};

} // namespace operators
} // namespace paddle

#define REGISTER_LOGICAL_OP(op_type, _equation) \
struct _##op_type##Comment { \
static char type[]; \
static char equation[]; \
}; \
char _##op_type##Comment::type[]{#op_type}; \
char _##op_type##Comment::equation[]{_equation}; \
REGISTER_OP_WITH_KERNEL( \
op_type, ::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \
::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \
// REGISTER_LOGICAL_OP(op_type, _equation)
//
// Expands to two things:
//   1. A struct named `_<op_type>Comment` with two static char arrays:
//      `type`, initialized from the stringified `op_type`, and `equation`,
//      initialized from `_equation`.
//   2. A REGISTER_OPERATOR call that registers `op_type` with CompareOp as the
//      operator class, CompareOpProtoMaker and CompareOpInferShape
//      instantiated on that struct, and EmptyGradOpMaker.
//
// NOTE(review): EmptyGradOpMaker presumably means no gradient operator is
// generated for these ops -- confirm against the framework.
// NOTE(review): the generated struct name begins with an underscore; when the
// macro is invoked at global scope such names are reserved for the
// implementation, so a different prefix may be worth considering.
#define REGISTER_LOGICAL_OP(op_type, _equation) \
struct _##op_type##Comment { \
static char type[]; \
static char equation[]; \
}; \
char _##op_type##Comment::type[]{#op_type}; \
char _##op_type##Comment::equation[]{_equation}; \
REGISTER_OPERATOR( \
op_type, ::paddle::operators::CompareOp, \
::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \
::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \
::paddle::framework::EmptyGradOpMaker);

// Registers the `less_than` operator; the string literal is passed to the
// macro as its `_equation` argument.
REGISTER_LOGICAL_OP(less_than, "Out = X < Y");
Expand Down
4 changes: 0 additions & 4 deletions paddle/platform/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ struct Transform<platform::CPUPlace> {
// Unary CPU transform: applies `op` to every element of [first, last) and
// writes the results starting at `result`. This is a thin wrapper over
// std::transform.
//
// `context` stays in the signature but is deliberately not inspected. The
// work is done by std::transform on the host and never touches the device
// context, and a CPU kernel may be invoked while the surrounding execution
// uses a non-CPU context (e.g. comparison ops whose kernel place follows the
// input tensor). Enforcing a CPU place here would reject those valid calls.
template <typename InputIter, typename OutputIter, typename UnaryOperation>
void operator()(const DeviceContext& context, InputIter first, InputIter last,
                OutputIter result, UnaryOperation op) {
  std::transform(first, last, result, op);
}

Expand All @@ -59,8 +57,6 @@ struct Transform<platform::CPUPlace> {
// Binary CPU transform: applies `op` to each pair of elements taken from
// [first1, last1) and the range starting at `first2`, writing the results
// starting at `result`. This is a thin wrapper over std::transform.
//
// `context` stays in the signature but is deliberately not inspected. The
// work is done by std::transform on the host and never touches the device
// context, and a CPU kernel may be invoked while the surrounding execution
// uses a non-CPU context. Enforcing a CPU place here would reject those
// valid calls.
void operator()(const DeviceContext& context, InputIter1 first1,
                InputIter1 last1, InputIter2 first2, OutputIter result,
                BinaryOperation op) {
  std::transform(first1, last1, first2, result, op);
}
};
Expand Down

0 comments on commit 3187451

Please sign in to comment.