// caffe2/operators/hard_sigmoid_op.cc
#include "caffe2/operators/hard_sigmoid_op.h"
#include <algorithm>
#include <functional>
#include <string>
#include "caffe2/utils/eigen_utils.h"
namespace caffe2 {
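
// Forward pass: Y = max(0, min(1, alpha * X + beta)), evaluated element-wise
// over the N input values with Eigen's vectorized array operations.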
template <>
template <typename T>
bool HardSigmoidFunctor<CPUContext>::
operator()(const int N, const T* X, T* Y, CPUContext* /* context */) const {
  EigenVectorArrayMap<T>(Y, N) =
      (ConstEigenVectorArrayMap<T>(X, N) * T(alpha) + T(beta))
          .cwiseMin(T(1))
          .cwiseMax(T(0));
  return true;
}
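
// Backward pass: the hard sigmoid is linear with slope alpha wherever its
// output lies strictly between 0 and 1, and constant where it saturates, so
// dX = alpha * dY where 0 < Y < 1 and 0 elsewhere. Note the gradient is
// expressed in terms of the forward output Y, so X is not needed.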
template <>
template <typename T>
bool HardSigmoidGradientFunctor<CPUContext>::Forward(
    const std::vector<int>& Y_dims,
    const std::vector<int>& /* dY_dims */,
    const T* Y,
    const T* dY,
    T* dX,
    CPUContext* /* context */) const {
  const int size = std::accumulate(
      Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>());
  ConstEigenVectorArrayMap<T> Y_arr(Y, size);
  EigenVectorArrayMap<T>(dX, size) =
      (Y_arr > T(0) && Y_arr < T(1))
          .select(ConstEigenVectorArrayMap<T>(dY, size) * alpha, T(0));
  return true;
}

namespace {
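// Cost inference: treat HardSigmoid as a point-wise op at roughly 4 FLOPs per
// element (multiply, add, min, max) that reads no parameter bytes.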
OpSchema::Cost CostInferenceForHardSigmoid(
    const OperatorDef& def,
    const vector<TensorShape>& in) {
  struct OpSchema::Cost cost = PointwiseCostInference<4>(def, in);
  cost.params_bytes = 0;
  return cost;
}
} // namespace
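
// Register the CPU kernels: the forward op is a unary element-wise op
// (X -> Y); the gradient op is a binary element-wise op (Y, dY -> dX).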
REGISTER_CPU_OPERATOR(
    HardSigmoid,
    UnaryElementwiseWithArgsOp<
        TensorTypes<float>,
        CPUContext,
        HardSigmoidFunctor<CPUContext>>);

REGISTER_CPU_OPERATOR(
    HardSigmoidGradient,
    BinaryElementwiseWithArgsOp<
        TensorTypes<float>,
        CPUContext,
        HardSigmoidGradientFunctor<CPUContext>>);

// Input: X, output: Y
OPERATOR_SCHEMA(HardSigmoid)
    .NumInputs(1)
    .NumOutputs(1)
    .AllowInplace({{0, 0}})
    .CostInferenceFunction(CostInferenceForHardSigmoid)
    .IdenticalTypeAndShape()
    .SetDoc(R"DOC(
Applies the hard sigmoid operation to the input data element-wise.
The HardSigmoid operation takes one input $X$, produces one output $Y$, and is defined as:

$$Y = max(0, min(1, X * alpha + beta))$$

GitHub Links:

- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/hard_sigmoid_op.h
- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/hard_sigmoid_op.cc
<details>
<summary> <b>Example</b> </summary>
**Code**
```
from caffe2.python import core, workspace
import numpy as np

workspace.ResetWorkspace()

op = core.CreateOperator(
    "HardSigmoid",
    ["X"],
    ["Y"],
    alpha = 0.2,
    beta = 0.5,
)

workspace.FeedBlob("X", np.random.randn(5).astype(np.float32))
print("input:", workspace.FetchBlob("X"))
workspace.RunOperatorOnce(op)
print("hard_sigmoid:", workspace.FetchBlob("Y"))
```
**Result**
```
input: [ 1.5744036 0.31632107 1.7842269 1.4450722 -2.1726978 ]
hard_sigmoid: [ 0.81488073, 0.56326419, 0.85684538, 0.78901446, 0.06546044]
```
</details>
)DOC")
.Arg("alpha", "float: the slope of the function. Defaults to 0.2")
.Arg("beta", "float: the bias value of the function. Defaults to 0.5")
.Input(0, "X", "1D input tensor")
.Output(0, "Y", "1D output tensor with same shape as input")
.InheritOnnxSchema();
// Input: Y, dY, output: dX
OPERATOR_SCHEMA(HardSigmoidGradient)
    .NumInputs(2)
    .NumOutputs(1)
    .AllowInplace({{1, 0}})
    .SetDoc(R"DOC(
HardSigmoidGradient takes both Y and dY, as well as the argument alpha, and
uses them to compute dX according to the chain rule and the derivative of the
hard sigmoid function. Since Y = max(0, min(1, alpha * X + beta)), the
derivative is alpha on the non-saturated region, so dX = alpha * dY where
0 < Y < 1 and 0 elsewhere.
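
<details>
<summary> <b>Example</b> </summary>

A minimal sketch of invoking the gradient op directly with hand-picked blobs
(in practice this op is created automatically by the gradient maker registered
below); the blob names `Y`, `dY`, and `dX` are illustrative.

**Code**

```
from caffe2.python import core, workspace
import numpy as np

op = core.CreateOperator(
    "HardSigmoidGradient",
    ["Y", "dY"],
    ["dX"],
    alpha = 0.2,
)

# Y values at and between the saturation points of the forward op.
workspace.FeedBlob("Y", np.array([0.0, 0.25, 0.5, 0.75, 1.0], dtype=np.float32))
workspace.FeedBlob("dY", np.ones(5, dtype=np.float32))
workspace.RunOperatorOnce(op)
# dX is alpha * dY where 0 < Y < 1 and 0 at the saturated endpoints,
# i.e. [0, 0.2, 0.2, 0.2, 0] here.
print("dX:", workspace.FetchBlob("dX"))
```

</details>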
)DOC");
namespace {
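// The gradient maker emits a single HardSigmoidGradient op that reads the
// forward output O(0) = Y and the incoming gradient GO(0) = dY, writes
// GI(0) = dX, and copies the forward op's arguments (such as alpha).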
class GetHardSigmoidGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  std::vector<OperatorDef> GetGradientDefs() override {
    return SingleGradientDef(
        def_.type() + "Gradient",
        "",
        std::vector<std::string>{O(0), GO(0)},
        std::vector<std::string>{GI(0)});
  }
};
} // namespace
REGISTER_GRADIENT(HardSigmoid, GetHardSigmoidGradient);
} // namespace caffe2