Implemented the hardShrink activation (#4653)
* Implemented the hardShrink activation

* Fixing the unit test
kavyasrinet committed Oct 11, 2017
1 parent 6604d7c commit 1397e17
Showing 3 changed files with 76 additions and 3 deletions.
21 changes: 21 additions & 0 deletions paddle/operators/activation_op.cc
@@ -137,6 +137,24 @@ class TanhShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
}
};

template <typename AttrType>
class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
public:
HardShrinkOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of HardShrink operator");
AddOutput("Y", "Output of HardShrink operator");
AddComment(
"HardShrink activation operator, "
"hard_shrink(x) = x if x > lambda; "
"hard_shrink(x) = x if x < -lambda; "
"hard_shrink(x) = 0 otherwise");
AddAttr<AttrType>("threshold", "The value of threshold for HardShrink")
.SetDefault(static_cast<AttrType>(0.5));
}
};
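
For reference, with the default threshold of 0.5 the rule described in AddComment maps 0.8 to 0.8, -0.7 to -0.7, and 0.3 to 0: inputs whose magnitude exceeds the threshold pass through unchanged, and everything in [-0.5, 0.5] is zeroed.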

class SqrtOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SqrtOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
@@ -357,6 +375,9 @@ REGISTER_OP(pow, ops::ActivationOp, ops::PowOpMaker<float>, pow_grad,
REGISTER_OP(stanh, ops::ActivationOp, ops::STanhOpMaker<float>, stanh_grad,
ops::ActivationOpGrad);

REGISTER_OP(hard_shrink, ops::ActivationOp, ops::HardShrinkOpMaker<float>,
hard_shrink_grad, ops::ActivationOpGrad);

#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_CPU_KERNEL( \
act_type, \
38 changes: 35 additions & 3 deletions paddle/operators/activation_op.h
@@ -199,6 +199,39 @@ struct TanhShrinkGradFunctor : public BaseActivationFunctor<T> {
}
};

// hard_shrink(x) = x, if x > threshold or x < -threshold
// hard_shrink(x) = 0, otherwise
template <typename T>
struct HardShrinkFunctor : public BaseActivationFunctor<T> {
float threshold;

typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
auto temp1 = (x < (threshold * -1)).template cast<T>().eval();
auto temp2 = (x > threshold).template cast<T>().eval();
y.device(d) = x * (temp1 + temp2);
}
};
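
A rough NumPy sketch of what the Eigen expression above computes (the function name and the use of NumPy are illustrative only, not part of this commit):

import numpy as np

def hard_shrink_forward(x, threshold=0.5):
    # temp1/temp2 mirror the two 0/1 masks above: x < -threshold and x > threshold.
    temp1 = (x < -threshold).astype(x.dtype)
    temp2 = (x > threshold).astype(x.dtype)
    return x * (temp1 + temp2)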

template <typename T>
struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
float threshold;

typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}

template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
auto temp1 = (x < (threshold * -1)).template cast<T>().eval();
auto temp2 = (x > threshold).template cast<T>().eval();
dx.device(d) = dy * (temp1 + temp2).template cast<T>();
}
};
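
The backward rule can be sketched the same way (again an illustrative NumPy equivalent, not the commit's code): the incoming gradient dy passes through wherever the forward pass passed x through, and is zero elsewhere, including at the non-differentiable points +/-threshold.

import numpy as np

def hard_shrink_backward(x, dy, threshold=0.5):
    # dx = dy where |x| > threshold, 0 otherwise.
    mask = ((x < -threshold) | (x > threshold)).astype(x.dtype)
    return dy * mask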

// softshrink(x) = x - lambda, if x > lambda; x + lambda, if x < -lambda; 0
// otherwise
template <typename T>
@@ -351,8 +384,6 @@ template <typename T>
struct Relu6Functor : public BaseActivationFunctor<T> {
float threshold;

// NOTE: Explicitly hides `BaseActivationFunctor<T>::GetAttrs`
// rather than using polymorphism, for speed.
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
@@ -555,4 +586,5 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
__macro(relu6, Relu6Functor, Relu6GradFunctor); \
__macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor); \
__macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \
__macro(elu, ELUFunctor, ELUGradFunctor)
__macro(elu, ELUFunctor, ELUGradFunctor); \
__macro(hard_shrink, HardShrinkFunctor, HardShrinkGradFunctor)
20 changes: 20 additions & 0 deletions python/paddle/v2/framework/tests/test_activation_op.py
@@ -78,6 +78,26 @@ def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.008)


class TestHardShrink(OpTest):
def setUp(self):
self.op_type = "hard_shrink"
x = np.random.uniform(-1, 1, [4, 4]).astype("float32")
threshold = 0.5

self.inputs = {'X': x}
self.attrs = {'threshold': threshold}

t = np.copy(x)
t[(t >= -threshold) & (t <= threshold)] = 0
self.outputs = {'Y': t}

def test_check_output(self):
self.check_output()

def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.005)
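
One caveat: hard_shrink is non-differentiable at +/-threshold, so uniform samples that land very close to +/-0.5 can make the finite-difference gradient check flaky. A common mitigation (shown here only as a hypothetical tweak to setUp, not something this commit does) is to nudge such samples away from the threshold before computing the reference output:

# Hypothetical addition to setUp (not part of this commit): keep samples
# away from the non-differentiable points +/-threshold.
near = np.abs(np.abs(x) - threshold) < 0.005
x[near] = np.sign(x[near]) * (threshold + 0.02)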


class TestSoftShrink(OpTest):
def setUp(self):
self.op_type = "softshrink"
