optimize elementwise_max_grad using new interfaces (#37906)
* init elem_max_grad op

* optimize code and reply to review comments

* ternary functors

* apply new reduce func

* move functor to .h

* multi-outputs init

* rearrange code

* modified functors

* optimize code

* pass nullptr

* revert the last change as a seg fault occurs

* optimize code

* remove inplace

* remove comments
AshburnLee committed Jan 12, 2022
1 parent 5f5f626 commit 4a64ca1
Showing 3 changed files with 80 additions and 6 deletions.
27 changes: 27 additions & 0 deletions paddle/fluid/operators/elementwise/elementwise_functor.h
@@ -301,5 +301,32 @@ struct MulGradXYFunctor<Complex<InT>, Complex<OutT>> {
  }
};

// Ternary compare
template <typename T>
struct MaxGradXFunctor {
  inline HOSTDEVICE T operator()(const T& x, const T& y, const T& dout) const {
    return dout * static_cast<T>(x > y);
  }
};
template <typename T>
struct MaxGradYFunctor {
  inline HOSTDEVICE T operator()(const T& x, const T& y, const T& dout) const {
    return dout * static_cast<T>(x <= y);
  }
};

template <typename InT, typename OutT>
struct MaxGradXYFunctor {
  inline HOSTDEVICE paddle::framework::Array<OutT, 2> operator()(
      const InT& x, const InT& y, const InT& dout) {
    paddle::framework::Array<OutT, 2> outs;
    // dx = dout * (x > y)
    outs[0] = static_cast<OutT>(dout * static_cast<InT>(x > y));
    // dy = dout * (x <= y)
    outs[1] = static_cast<OutT>(dout * static_cast<InT>(x <= y));
    return outs;
  }
};

} // namespace operators
} // namespace paddle
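The three functors above encode the elementwise (sub)gradient of max(x, y): the upstream gradient dout flows to x wherever x > y and to y wherever x <= y, so dx + dy always equals dout and ties route the gradient to y. A minimal standalone sketch of that rule, with the Paddle-specific pieces (HOSTDEVICE, paddle::framework::Array) replaced by plain C++ so it compiles on its own:

// Scalar sketch of the max-gradient rule used by the functors above.
// Not Paddle code: illustrative re-implementation only.
#include <cassert>

template <typename T>
T MaxGradX(const T& x, const T& y, const T& dout) {
  return dout * static_cast<T>(x > y);  // dx = dout * (x > y)
}

template <typename T>
T MaxGradY(const T& x, const T& y, const T& dout) {
  return dout * static_cast<T>(x <= y);  // dy = dout * (x <= y)
}

int main() {
  // x > y: the whole upstream gradient goes to x.
  assert(MaxGradX(3.0, 1.0, 0.5) == 0.5 && MaxGradY(3.0, 1.0, 0.5) == 0.0);
  // x == y (tie): the gradient is routed to y, so dx + dy still equals dout.
  assert(MaxGradX(2.0, 2.0, 0.5) == 0.0 && MaxGradY(2.0, 2.0, 0.5) == 0.5);
  return 0;
}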
30 changes: 28 additions & 2 deletions paddle/fluid/operators/elementwise/elementwise_max_op.cu
@@ -24,15 +24,41 @@ class ElementwiseMaxKernel<platform::CUDADeviceContext, T>
  void Compute(const framework::ExecutionContext& ctx) const override {
    std::vector<const framework::Tensor*> ins;
    std::vector<framework::Tensor*> outs;
-   const auto& cuda_ctx =
+   const auto& dev_ctx =
        ctx.template device_context<platform::CUDADeviceContext>();

    int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
    LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
-       cuda_ctx, ins, &outs, axis, MaxFunctor<T>());
+       dev_ctx, ins, &outs, axis, MaxFunctor<T>());
  }
};

template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
ElementwiseMaxGrad(const framework::ExecutionContext& ctx,
                   const framework::Tensor* x, const framework::Tensor* y,
                   const framework::Tensor* out, const framework::Tensor* dout,
                   framework::Tensor* dx, framework::Tensor* dy) {
  int axis = ctx.Attr<int>("axis");
  const auto& dev_ctx =
      ctx.template device_context<platform::CUDADeviceContext>();
  const auto place = ctx.GetPlace();
  if (dx != nullptr && dy != nullptr) {
    std::vector<const framework::Tensor*> ins = {x, y, dout};
    GetGradXAndYOut<ElementwiseType::kTernary, T>(
        dev_ctx, place, axis, ins, dout, dx, dy, MaxGradXYFunctor<T, T>());
  } else if (dx != nullptr && dy == nullptr) {
    std::vector<const framework::Tensor*> ins = {x, y, dout};
    GetGradXOrYOut<ElementwiseType::kTernary, T>(
        dev_ctx, place, axis, ins, dout, dx, MaxGradXFunctor<T>());
  } else if (dx == nullptr && dy != nullptr) {
    std::vector<const framework::Tensor*> ins = {x, y, dout};
    GetGradXOrYOut<ElementwiseType::kTernary, T>(
        dev_ctx, place, axis, ins, dout, dy, MaxGradYFunctor<T>());
  }
}

} // namespace operators
} // namespace paddle
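The new GPU path relies on std::enable_if so that ElementwiseMaxGrad has one overload per device context under the same name: the CPU overload lives in elementwise_max_op.h, and the CUDA overload above is selected whenever DeviceContext is platform::CUDADeviceContext. A minimal sketch of that SFINAE dispatch pattern, using simplified stand-in context types rather than the real Paddle classes:

// SFINAE device dispatch sketch (illustrative types, not Paddle's).
#include <iostream>
#include <type_traits>

struct CPUDeviceContext {};
struct CUDADeviceContext {};

// Overload enabled only when DeviceContext is the CPU context.
template <typename DeviceContext>
typename std::enable_if<
    std::is_same<DeviceContext, CPUDeviceContext>::value>::type
ElementwiseGradImpl() {
  std::cout << "CPU grad path\n";
}

// Overload enabled only when DeviceContext is the CUDA context.
template <typename DeviceContext>
typename std::enable_if<
    std::is_same<DeviceContext, CUDADeviceContext>::value>::type
ElementwiseGradImpl() {
  std::cout << "CUDA grad path\n";
}

int main() {
  ElementwiseGradImpl<CPUDeviceContext>();   // selects the CPU overload
  ElementwiseGradImpl<CUDADeviceContext>();  // selects the CUDA overload
  return 0;
}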

29 changes: 25 additions & 4 deletions paddle/fluid/operators/elementwise/elementwise_max_op.h
@@ -64,6 +64,28 @@ struct MaxGradDy {
  }
};

template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
ElementwiseMaxGrad(const framework::ExecutionContext& ctx,
                   const framework::Tensor* x, const framework::Tensor* y,
                   const framework::Tensor* out, const framework::Tensor* dout,
                   framework::Tensor* dx, framework::Tensor* dy) {
  int axis = ctx.Attr<int>("axis");
  ElemwiseGradCompute<DeviceContext, T, MaxGradDx<T>, MaxGradDy<T>>(
      ctx, *x, *y, *out, *dout, axis, dx, dy, MaxGradDx<T>(), MaxGradDy<T>());
}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
ElementwiseMaxGrad(const framework::ExecutionContext& ctx,
                   const framework::Tensor* x, const framework::Tensor* y,
                   const framework::Tensor* out, const framework::Tensor* dout,
                   framework::Tensor* dx, framework::Tensor* dy);
#endif

template <typename DeviceContext, typename T>
class ElementwiseMaxGradKernel : public ElemwiseGradKernel<T> {
 public:
@@ -74,12 +96,11 @@ class ElementwiseMaxGradKernel : public ElemwiseGradKernel<T> {
    auto* x = ctx.Input<Tensor>("X");
    auto* y = ctx.Input<Tensor>("Y");
    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
-   auto* out = dout;  // out is not necessary
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
+   auto* out = dout;  // Fake out, not used
-   int axis = ctx.Attr<int>("axis");
-   ElemwiseGradCompute<DeviceContext, T, MaxGradDx<T>, MaxGradDy<T>>(
-       ctx, *x, *y, *out, *dout, axis, dx, dy, MaxGradDx<T>(), MaxGradDy<T>());
+
+   ElementwiseMaxGrad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
  }
};
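After this change, Compute only gathers the tensors and forwards them to ElementwiseMaxGrad<DeviceContext, T>: on CPU the overload defined earlier in this header is chosen, while on CUDA/HIP builds the overload is merely declared here (inside the #if block above) and defined in elementwise_max_op.cu. A single-file sketch of that declare-in-header / define-in-.cu split, with hypothetical names and an ordinary loop standing in for the real kernel launch:

// Declaration/definition split sketch (hypothetical names, not the Paddle build).
#include <cstdio>

// --- conceptually the shared header: declaration only ---
void MaxGradOnDevice(const float* x, const float* y, const float* dout,
                     float* dx, float* dy, int n);

// The kernel-side caller compiles against the declaration alone.
void RunMaxGrad(const float* x, const float* y, const float* dout,
                float* dx, float* dy, int n) {
  MaxGradOnDevice(x, y, dout, dx, dy, n);
}

// --- conceptually the .cu translation unit: the definition ---
void MaxGradOnDevice(const float* x, const float* y, const float* dout,
                     float* dx, float* dy, int n) {
  for (int i = 0; i < n; ++i) {  // stand-in for the CUDA kernel launch
    dx[i] = dout[i] * static_cast<float>(x[i] > y[i]);
    dy[i] = dout[i] * static_cast<float>(x[i] <= y[i]);
  }
}

int main() {
  const float x[3] = {1.f, 3.f, 2.f}, y[3] = {2.f, 2.f, 2.f},
              dout[3] = {1.f, 1.f, 1.f};
  float dx[3], dy[3];
  RunMaxGrad(x, y, dout, dx, dy, 3);
  std::printf("dx: %g %g %g  dy: %g %g %g\n",
              dx[0], dx[1], dx[2], dy[0], dy[1], dy[2]);
  return 0;
}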

