From fb0c6c2ddf6ab0113598a0c5b7b9e1f16213e46a Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Wed, 21 Mar 2018 16:17:40 +0000 Subject: [PATCH] fix for amalgamation build failure --- src/operator/leaky_relu-inl.h | 32 +++++++++++++++++++++++++------- src/operator/mshadow_op.h | 4 ++-- src/operator/operator_tune.cc | 2 ++ 3 files changed, 29 insertions(+), 9 deletions(-) diff --git a/src/operator/leaky_relu-inl.h b/src/operator/leaky_relu-inl.h index b5f534fb73bd..6158432844e4 100644 --- a/src/operator/leaky_relu-inl.h +++ b/src/operator/leaky_relu-inl.h @@ -106,7 +106,10 @@ class LeakyReLUOp : public Operator { out = out_data[leakyrelu::kOut].get_with_shape(dshape, s); switch (param_.act_type) { case leakyrelu::kLeakyReLU: { - Assign(out, req[leakyrelu::kOut], F(data, DType(param_.slope))); + MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kOut], Req, { + mxnet_op::Kernel, xpu>::Launch( + s, out.size(0) * out.size(1) * out.size(2), out.dptr_, data.dptr_, DType(param_.slope)); + }); break; } case leakyrelu::kPReLU: { @@ -146,12 +149,19 @@ class LeakyReLUOp : public Operator { }); } else { const float slope = (param_.lower_bound + param_.upper_bound) / 2.0f; - Assign(out, req[leakyrelu::kOut], F(data, DType(slope))); + MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kOut], Req, { + mxnet_op::Kernel, xpu>::Launch( + s, out.size(0) * out.size(1) * out.size(2), out.dptr_, data.dptr_, DType(slope)); + }); } break; } case leakyrelu::kELU: { - Assign(out, req[leakyrelu::kOut], F(data, DType(param_.slope))); + MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kOut], Req, { + mxnet_op::Kernel, xpu>::Launch( + s, out.size(0) * out.size(1) * out.size(2), out.dptr_, data.dptr_, + DType(param_.slope)); + }); break; } default: @@ -194,8 +204,12 @@ class LeakyReLUOp : public Operator { } switch (param_.act_type) { case leakyrelu::kLeakyReLU: { - Assign(gdata, req[leakyrelu::kData], - F(output, DType(param_.slope)) * grad); + MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kData], Req, { + mxnet_op::Kernel, Req>, xpu>::Launch( + s, gdata.size(0) * gdata.size(1) * gdata.size(2), gdata.dptr_, grad.dptr_, + output.dptr_, DType(param_.slope)); + }); break; } case leakyrelu::kPReLU: { @@ -223,8 +237,12 @@ class LeakyReLUOp : public Operator { break; } case leakyrelu::kELU: { - Assign(gdata, req[leakyrelu::kData], - F(output, DType(param_.slope)) * grad); + MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kData], Req, { + mxnet_op::Kernel, Req>, xpu>::Launch( + s, gdata.size(0) * gdata.size(1) * gdata.size(2), gdata.dptr_, grad.dptr_, + output.dptr_, DType(param_.slope)); + }); break; } default: diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index bdd15e0f357e..5606c64369ad 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -131,8 +131,8 @@ MXNET_BINARY_MATH_OP_NC(xelu, a > DType(0) ? a : MXNET_BINARY_MATH_OP_NC(xelu_grad, a > DType(0) ? DType(1) : b); -MXNET_BINARY_MATH_OP(elu, a > DType(0) ? math::id(a) : - math::id(b) * math::expm1(a)); +MXNET_BINARY_MATH_OP_NC(elu, a > DType(0) ? a : + DType(math::id(b) * math::expm1(a))); MXNET_BINARY_MATH_OP_NC(elu_grad, a > DType(0) ? DType(1) : DType(b + a)); diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc index ef7c94b67238..c48d83a3be87 100644 --- a/src/operator/operator_tune.cc +++ b/src/operator/operator_tune.cc @@ -315,10 +315,12 @@ IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::right); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::power); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rpower); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::xelu); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::elu); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_grad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rpower_grad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_rgrad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::xelu_grad); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::elu_grad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::maximum); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::minimum); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot); // NOLINT()