forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
gru_unit_op.h
240 lines (205 loc) · 6.47 KB
/
gru_unit_op.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
#ifndef CAFFE2_OPERATORS_GRU_UNIT_OP_H_
#define CAFFE2_OPERATORS_GRU_UNIT_OP_H_
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
namespace detail {
// Logistic function: 1 / (1 + e^(-x)).
// The denominator is built first so the final division reads as a single
// step; the result is converted back to T on return.
template <typename T>
inline T sigmoid(T x) {
  const auto denominator = 1.0f + exp(-x);
  return 1.0f / denominator;
}
// Hyperbolic tangent expressed through the logistic function:
//   tanh(x) = 2 * sigmoid(2 * x) - 1
template <typename T>
inline T host_tanh(T x) {
  const auto doubled_input = 2.0f * x;
  return 2.0f * sigmoid(doubled_input) - 1.0f;
}
// Forward pass of the GRU unit over N rows of width D.
//
// Each row of X holds 3 * D values; this function reads the second group of D
// values (update gate pre-activation) and the third group (output
// pre-activation). The first group of D values is not read here.
//
// A row is "past its end" when seqLengths is non-null and t >= seqLengths[n].
// For such rows H is either zeroed (drop_states) or copied from H_prev.
// Otherwise, per element:
//   u = sigmoid(update)
//   H = H_prev * u + tanh(output) * (1 - u)
//
// The context argument is accepted but unused.
template <typename T, typename Context>
void GRUUnit(
    int N,
    int D,
    int t,
    const T* H_prev,
    const T* X,
    const int32_t* seqLengths,
    bool drop_states,
    T* H,
    Context* /*context*/) {
  // All three pointers step forward by one row at the end of every iteration.
  for (int row = 0; row < N; ++row, H_prev += D, X += 3 * D, H += D) {
    const bool past_end = seqLengths != nullptr && t >= seqLengths[row];

    if (past_end) {
      // The sequence for this row has already finished at timestep t.
      for (int col = 0; col < D; ++col) {
        if (drop_states) {
          H[col] = 0;
        } else {
          H[col] = H_prev[col];
        }
      }
      continue;
    }

    const T* update_gate = X + 1 * D;
    const T* output_gate = X + 2 * D;
    for (int col = 0; col < D; ++col) {
      T keep = sigmoid(update_gate[col]);
      H[col] = H_prev[col] * keep +
          host_tanh(output_gate[col]) * (1.0f - keep);
    }
  }
}
// Backward pass of the GRU unit over N rows of width D.
//
// Each row of X / X_diff holds 3 * D values laid out as three groups of D:
// [reset | update | output]. Only the update and output groups of X are read.
//
// A row is "past its end" when seqLengths is non-null and t >= seqLengths[n].
// For such rows the gate gradients are zero and H_prev_diff is either zero
// (drop_states) or a copy of H_diff. Otherwise, with u = sigmoid(update) and
// o = tanh(output):
//   H_prev_diff = H_diff * u
//   reset_diff  = 0
//   update_diff = (H_diff * H_prev - H_diff * o) * u * (1 - u)
//   output_diff = H_diff * (1 - u) * (1 - o * o)
//
// H is advanced row by row but never dereferenced; the context argument is
// unused.
template <typename T, typename Context>
void GRUUnitGradient(
    int N,
    int D,
    int t,
    const T* H_prev,
    const T* X,
    const int32_t* seqLengths,
    const T* H,
    const T* H_diff,
    bool drop_states,
    T* H_prev_diff,
    T* X_diff,
    Context* /*context*/) {
  for (int row = 0; row < N; ++row) {
    const bool past_end = seqLengths != nullptr && t >= seqLengths[row];

    // Start of each gate's gradient slice for the current row.
    T* grad_reset = X_diff + 0 * D;
    T* grad_update = X_diff + 1 * D;
    T* grad_output = X_diff + 2 * D;

    if (past_end) {
      for (int col = 0; col < D; ++col) {
        if (drop_states) {
          H_prev_diff[col] = 0;
        } else {
          H_prev_diff[col] = H_diff[col];
        }
        grad_reset[col] = 0;
        grad_update[col] = 0;
        grad_output[col] = 0;
      }
    } else {
      for (int col = 0; col < D; ++col) {
        // Gate activations recomputed from the stored pre-activations.
        const T update_act = sigmoid(X[1 * D + col]);
        const T output_act = host_tanh(X[2 * D + col]);

        H_prev_diff[col] = H_diff[col] * update_act;
        // The reset gate does not take part in this operation.
        grad_reset[col] = 0;
        grad_update[col] =
            (H_diff[col] * H_prev[col] - H_diff[col] * output_act) *
            update_act * (1.0f - update_act);
        grad_output[col] = H_diff[col] * (1.0f - update_act) *
            (1.0f - output_act * output_act);
      }
    }

    // Move every pointer on to the next row.
    H_prev += D;
    X += 3 * D;
    H += D;
    H_diff += D;
    X_diff += 3 * D;
    H_prev_diff += D;
  }
}
} // namespace detail
// Operator wrapper around detail::GRUUnit.
//
// Arguments read at construction:
//   "drop_states"      (bool, default false)
//   "sequence_lengths" (bool, default true) - whether the SEQ_LENGTHS input
//                      is present.
//
// Inputs: HIDDEN_T_M_1, GATES, [SEQ_LENGTHS], then the timestep blob.
// Output: HIDDEN_T, resized like HIDDEN_T_M_1.
template <typename T, typename Context>
class GRUUnitOp : public Operator<Context> {
 public:
  template <class... Args>
  explicit GRUUnitOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        drop_states_(
            this->template GetSingleArgument<bool>("drop_states", false)),
        sequence_lengths_(
            this->template GetSingleArgument<bool>("sequence_lengths", true)) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override {
    // The timestep input sits in the SEQ_LENGTHS slot when the sequence
    // lengths input is absent, and one slot later when it is present.
    const size_t timestep_slot =
        sequence_lengths_ ? SEQ_LENGTHS + 1 : SEQ_LENGTHS;

    // Batch size comes from dimension 1 of the previous hidden state.
    const auto N = Input(HIDDEN_T_M_1).size(1);
    // Gates are documented as 1xNxG; G must be three times the hidden width.
    const auto G = Input(GATES).size(2);
    const auto D = Input(HIDDEN_T_M_1).size(2);
    CAFFE_ENFORCE_EQ(3 * D, G);

    const auto* prev_hidden_data = Input(HIDDEN_T_M_1).template data<T>();
    const auto* gates_data = Input(GATES).template data<T>();

    // Left null when the operator runs without per-row sequence lengths.
    const int32_t* seq_lengths_data = nullptr;
    if (sequence_lengths_) {
      CAFFE_ENFORCE_EQ(Input(SEQ_LENGTHS).numel(), N);
      seq_lengths_data = Input(SEQ_LENGTHS).template data<int32_t>();
    }

    // The timestep is fetched as a CPU tensor and its first element is used.
    const auto timestep = static_cast<OperatorBase*>(this)
                              ->Input<Tensor>(timestep_slot, CPU)
                              .template data<int32_t>()[0];

    Output(HIDDEN_T)->ResizeLike(Input(HIDDEN_T_M_1));
    auto* hidden_out = Output(HIDDEN_T)->template mutable_data<T>();

    detail::GRUUnit<T, Context>(
        N,
        D,
        timestep,
        prev_hidden_data,
        gates_data,
        seq_lengths_data,
        drop_states_,
        hidden_out,
        &context_);
    return true;
  }

 protected:
  INPUT_TAGS(HIDDEN_T_M_1, GATES, SEQ_LENGTHS);
  // The index of the timestep input is not a fixed tag: it is computed in
  // RunOnDevice depending on whether sequence_lengths is present.
  OUTPUT_TAGS(HIDDEN_T);

 private:
  bool drop_states_;
  bool sequence_lengths_;
};
// Operator wrapper around detail::GRUUnitGradient.
//
// Arguments read at construction:
//   "drop_states"      (bool, default false)
//   "sequence_lengths" (bool, default true) - whether the SEQ_LENGTHS input
//                      is present.
//
// Inputs: HIDDEN_T_M_1, GATES, [SEQ_LENGTHS], timestep, hidden state at t,
//         gradient of the hidden state at t.
// Outputs: HIDDEN_T_M_1_GRAD (resized like HIDDEN_T_M_1) and GATES_GRAD
//          (resized like GATES).
template <typename T, typename Context>
class GRUUnitGradientOp : public Operator<Context> {
 public:
  template <class... Args>
  explicit GRUUnitGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        drop_states_(
            this->template GetSingleArgument<bool>("drop_states", false)),
        sequence_lengths_(
            this->template GetSingleArgument<bool>("sequence_lengths", true)) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override {
    // Inputs after GATES shift by one slot when sequence lengths are present,
    // so the trailing input indices are derived from a common base.
    const size_t trailing_base = SEQ_LENGTHS + (sequence_lengths_ ? 1 : 0);
    const size_t timestep_slot = trailing_base;
    const size_t hidden_slot = trailing_base + 1;
    const size_t hidden_grad_slot = trailing_base + 2;

    // Batch size comes from dimension 1 of the previous hidden state.
    const auto N = Input(HIDDEN_T_M_1).size(1);
    // Gates are documented as 1xNxG; G must be three times the hidden width.
    const auto G = Input(GATES).size(2);
    const auto D = Input(HIDDEN_T_M_1).size(2);
    CAFFE_ENFORCE_EQ(3 * D, G);

    const auto* prev_hidden_data = Input(HIDDEN_T_M_1).template data<T>();
    const auto* gates_data = Input(GATES).template data<T>();

    // The timestep is fetched as a CPU tensor and its first element is used.
    const auto timestep = static_cast<OperatorBase*>(this)
                              ->Input<Tensor>(timestep_slot, CPU)
                              .template data<int32_t>()[0];

    const auto* hidden_data = Input(hidden_slot).template data<T>();
    const auto* hidden_grad_data = Input(hidden_grad_slot).template data<T>();

    // Left null when the operator runs without per-row sequence lengths.
    const int32_t* seq_lengths_data = nullptr;
    if (sequence_lengths_) {
      CAFFE_ENFORCE_EQ(Input(SEQ_LENGTHS).numel(), N);
      seq_lengths_data = Input(SEQ_LENGTHS).template data<int32_t>();
    }

    Output(HIDDEN_T_M_1_GRAD)->ResizeLike(Input(HIDDEN_T_M_1));
    auto* prev_hidden_grad_out =
        Output(HIDDEN_T_M_1_GRAD)->template mutable_data<T>();

    Output(GATES_GRAD)->ResizeLike(Input(GATES));
    auto* gates_grad_out = Output(GATES_GRAD)->template mutable_data<T>();

    detail::GRUUnitGradient<T, Context>(
        N,
        D,
        timestep,
        prev_hidden_data,
        gates_data,
        seq_lengths_data,
        hidden_data,
        hidden_grad_data,
        drop_states_,
        prev_hidden_grad_out,
        gates_grad_out,
        &context_);
    return true;
  }

 protected:
  INPUT_TAGS(HIDDEN_T_M_1, GATES, SEQ_LENGTHS);
  OUTPUT_TAGS(HIDDEN_T_M_1_GRAD, GATES_GRAD);

 private:
  bool drop_states_;
  bool sequence_lengths_;
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_GRU_UNIT_OP_H_