Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix for amalgamation build failure
Browse files Browse the repository at this point in the history
  • Loading branch information
Hao Jin committed Mar 21, 2018
1 parent 8b952a4 commit fb0c6c2
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 9 deletions.
32 changes: 25 additions & 7 deletions src/operator/leaky_relu-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,10 @@ class LeakyReLUOp : public Operator {
out = out_data[leakyrelu::kOut].get_with_shape<xpu, 3, DType>(dshape, s);
switch (param_.act_type) {
case leakyrelu::kLeakyReLU: {
Assign(out, req[leakyrelu::kOut], F<mshadow_op::xelu>(data, DType(param_.slope)));
MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kOut], Req, {
mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::xelu, Req>, xpu>::Launch(
s, out.size(0) * out.size(1) * out.size(2), out.dptr_, data.dptr_, DType(param_.slope));
});
break;
}
case leakyrelu::kPReLU: {
Expand Down Expand Up @@ -146,12 +149,19 @@ class LeakyReLUOp : public Operator {
});
} else {
const float slope = (param_.lower_bound + param_.upper_bound) / 2.0f;
Assign(out, req[leakyrelu::kOut], F<mshadow_op::xelu>(data, DType(slope)));
MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kOut], Req, {
mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::xelu, Req>, xpu>::Launch(
s, out.size(0) * out.size(1) * out.size(2), out.dptr_, data.dptr_, DType(slope));
});
}
break;
}
case leakyrelu::kELU: {
Assign(out, req[leakyrelu::kOut], F<mshadow_op::elu>(data, DType(param_.slope)));
MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kOut], Req, {
mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::elu, Req>, xpu>::Launch(
s, out.size(0) * out.size(1) * out.size(2), out.dptr_, data.dptr_,
DType(param_.slope));
});
break;
}
default:
Expand Down Expand Up @@ -194,8 +204,12 @@ class LeakyReLUOp : public Operator {
}
switch (param_.act_type) {
case leakyrelu::kLeakyReLU: {
Assign(gdata, req[leakyrelu::kData],
F<mshadow_op::xelu_grad>(output, DType(param_.slope)) * grad);
MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kData], Req, {
mxnet_op::Kernel<mxnet_op::op_with_req<
mxnet_op::backward_grad_tuned<mxnet::op::mshadow_op::xelu_grad>, Req>, xpu>::Launch(
s, gdata.size(0) * gdata.size(1) * gdata.size(2), gdata.dptr_, grad.dptr_,
output.dptr_, DType(param_.slope));
});
break;
}
case leakyrelu::kPReLU: {
Expand Down Expand Up @@ -223,8 +237,12 @@ class LeakyReLUOp : public Operator {
break;
}
case leakyrelu::kELU: {
Assign(gdata, req[leakyrelu::kData],
F<mshadow_op::elu_grad>(output, DType(param_.slope)) * grad);
MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kData], Req, {
mxnet_op::Kernel<mxnet_op::op_with_req<
mxnet_op::backward_grad_tuned<mxnet::op::mshadow_op::elu_grad>, Req>, xpu>::Launch(
s, gdata.size(0) * gdata.size(1) * gdata.size(2), gdata.dptr_, grad.dptr_,
output.dptr_, DType(param_.slope));
});
break;
}
default:
Expand Down
4 changes: 2 additions & 2 deletions src/operator/mshadow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ MXNET_BINARY_MATH_OP_NC(xelu, a > DType(0) ? a :

MXNET_BINARY_MATH_OP_NC(xelu_grad, a > DType(0) ? DType(1) : b);

MXNET_BINARY_MATH_OP(elu, a > DType(0) ? math::id(a) :
math::id(b) * math::expm1(a));
MXNET_BINARY_MATH_OP_NC(elu, a > DType(0) ? a :
DType(math::id(b) * math::expm1(a)));

MXNET_BINARY_MATH_OP_NC(elu_grad, a > DType(0) ? DType(1) : DType(b + a));

Expand Down
2 changes: 2 additions & 0 deletions src/operator/operator_tune.cc
Original file line number Diff line number Diff line change
Expand Up @@ -315,10 +315,12 @@ IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::right); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::power); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rpower); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::xelu); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::elu); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rpower_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_rgrad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::xelu_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::elu_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::maximum); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::minimum); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot); // NOLINT()
Expand Down

0 comments on commit fb0c6c2

Please sign in to comment.