forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
gru_unit_op_gpu.cu
140 lines (130 loc) · 3.49 KB
/
gru_unit_op_gpu.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#include <algorithm>
#include <cmath>
#include <vector>
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/gru_unit_op.h"
namespace caffe2 {
namespace detail {
// Device-side logistic function: returns 1 / (1 + exp(-x)) computed in Dtype.
template <typename Dtype>
__device__ Dtype cuda_sigmoid(const Dtype x) {
  const Dtype one = Dtype(1);
  return one / (one + exp(-x));
}
// Forward step of the GRU unit; each loop iteration produces one element of H.
//
// A flat index in [0, ND) is split into a row (index / dim) and a column
// (index % dim). Row `row` of X starts at X + 3 * dim * row and is read as
// three consecutive slices of length dim; only the slices at offsets 1 * dim
// (update pre-activation) and 2 * dim (output pre-activation) are read here.
//
//   ND          number of elements to process
//   dim         row width used to decompose the flat index
//   t           value compared against seqLengths[row]
//   H_prev      input, read at the flat index
//   X           packed pre-activations, three slices of length dim per row
//   seqLengths  may be nullptr, in which case every row is treated as active;
//               otherwise a row is active while t < seqLengths[row]
//   drop_states for inactive rows, H_prev is multiplied by !drop_states
//               (0 when true, 1 when false)
//   H           output, written at the flat index
template <typename T>
__global__ void GRUUnitKernel(
    const int ND,
    const int dim,
    const int t,
    const T* H_prev,
    const T* X,
    const int32_t* seqLengths,
    bool drop_states,
    T* H) {
  // index is the virtual thread ID in range [0, ND)
  CUDA_1D_KERNEL_LOOP(index, ND) {
    const int row = index / dim;
    const int col = index % dim;
    const bool active = seqLengths == nullptr || t < seqLengths[row];
    if (active) {
      const T* gates = X + 3 * dim * row;
      const T update_pre = gates[1 * dim + col];
      const T output_pre = gates[2 * dim + col];
      T z = cuda_sigmoid(update_pre);
      // Blend the previous state with the tanh of the output pre-activation,
      // weighted by z and (1 - z).
      H[index] = H_prev[index] * z +
          tanh(output_pre) * (1.0f - z);
    } else {
      // Inactive row: pass the previous state through, or zero it via the
      // multiplication when drop_states is set.
      H[index] = H_prev[index] * !drop_states;
    }
  }
}
// Backward step of the GRU unit; each loop iteration handles one flat index
// in [0, ND) and writes one element of H_prev_diff plus the three matching
// elements of X_diff.
//
// The flat index is split into n = index / dim and d = index % dim. Row n of
// X and X_diff starts at base + 3 * dim * n and holds three slices of length
// dim: reset (offset 0), update (offset 1 * dim) and output (offset 2 * dim).
//
//   ND          number of elements to process
//   dim         row width used to decompose the flat index
//   t           value compared against seqLengths[n]
//   H_prev      forward input, read at the flat index (active rows only)
//   X           packed pre-activations from the forward pass
//   seqLengths  may be nullptr, in which case every row is treated as active;
//               otherwise a row is active while t < seqLengths[n]
//   H           not read by this kernel; kept for interface compatibility
//   H_diff      incoming gradient, read at the flat index
//   drop_states for inactive rows, H_diff is multiplied by !drop_states
//   H_prev_diff output gradient, written at the flat index
//   X_diff      output gradient, laid out like X
//
// All inputs needed for an element are loaded into locals before the first
// store. The pointers are not __restrict__, so a store through an output
// pointer could otherwise change the value seen by a later load if a caller
// passes an output buffer that aliases an input (in-place gradient).
template <typename T>
__global__ void GRUUnitGradientKernel(
    const int ND,
    const int dim,
    const int t,
    const T* H_prev,
    const T* X,
    const int32_t* seqLengths,
    const T* H,
    const T* H_diff,
    bool drop_states,
    T* H_prev_diff,
    T* X_diff) {
  CUDA_1D_KERNEL_LOOP(index, ND) {
    const int n = index / dim;
    const int d = index % dim;
    const bool valid = seqLengths == nullptr || t < seqLengths[n];

    T* h_prev_diff = H_prev_diff + index;
    T* X_diff_offset = X_diff + 3 * dim * n;
    T* reset_diff = X_diff_offset + 0 * dim + d;
    T* update_diff = X_diff_offset + 1 * dim + d;
    T* output_diff = X_diff_offset + 2 * dim + d;

    // Read the incoming gradient exactly once, before any store.
    const T h_diff = H_diff[index];

    if (!valid) {
      // Inactive row: forward was a pass-through (or zeroing) of H_prev, so
      // only H_prev receives gradient and all gate gradients are zero.
      *h_prev_diff = h_diff * !drop_states;
      *reset_diff = 0;
      *update_diff = 0;
      *output_diff = 0;
    } else {
      const T* X_offset = X + 3 * dim * n;
      // Load every remaining input before writing any output.
      const T h_prev = H_prev[index];
      const T u = cuda_sigmoid(X_offset[1 * dim + d]);
      const T o = tanh(X_offset[2 * dim + d]);

      *h_prev_diff = h_diff * u;
      *reset_diff = 0; // 0 contribution to gradient from this operation
      // u * (1 - u) is the sigmoid derivative; (1 - o * o) is the tanh
      // derivative.
      *update_diff = (h_diff * h_prev - h_diff * o) * u * (1.0f - u);
      *output_diff = h_diff * (1.0f - u) * (1.0f - o * o);
    }
  }
}
// Launches GRUUnitKernel<float> over N * D elements on the context's CUDA
// stream.
//
//   N, D        their product is the number of elements processed; D is also
//               passed to the kernel as the row width
//   t           forwarded to the kernel for the seqLengths comparison
//   H_prev, X   kernel inputs
//   seqLengths  forwarded as-is; the kernel accepts nullptr
//   drop_states forwarded to the kernel
//   H           kernel output
//   context     supplies the stream used for the launch
//
// The launch is asynchronous; this function does not synchronize the stream.
template <>
void GRUUnit<float, CUDAContext>(
    int N,
    int D,
    int t,
    const float* H_prev,
    const float* X,
    const int32_t* seqLengths,
    bool drop_states,
    float* H,
    CUDAContext* context) {
  const int total = N * D;
  // Nothing to compute for an empty input. Returning here also avoids relying
  // on what CAFFE_GET_BLOCKS yields for zero elements (defined outside this
  // file); a zero-block grid would be an invalid launch configuration.
  if (total == 0) {
    return;
  }
  GRUUnitKernel<float>
      <<<CAFFE_GET_BLOCKS(total),
         CAFFE_CUDA_NUM_THREADS,
         0,
         context->cuda_stream()>>>(
          total, D, t, H_prev, X, seqLengths, drop_states, H);
}
// Launches GRUUnitGradientKernel<float> over N * D elements on the context's
// CUDA stream.
//
//   N, D         their product is the number of elements processed; D is also
//                passed to the kernel as the row width
//   t            forwarded to the kernel for the seqLengths comparison
//   H_prev, X    forward-pass inputs
//   seqLengths   forwarded as-is; the kernel accepts nullptr
//   H            forwarded to the kernel
//   H_diff       incoming gradient
//   drop_states  forwarded to the kernel
//   H_prev_diff  output gradient
//   X_diff       output gradient
//   context      supplies the stream used for the launch
//
// The launch is asynchronous; this function does not synchronize the stream.
template <>
void GRUUnitGradient<float, CUDAContext>(
    int N,
    int D,
    int t,
    const float* H_prev,
    const float* X,
    const int32_t* seqLengths,
    const float* H,
    const float* H_diff,
    bool drop_states,
    float* H_prev_diff,
    float* X_diff,
    CUDAContext* context) {
  const int total = N * D;
  // Nothing to compute for an empty input. Returning here also avoids relying
  // on what CAFFE_GET_BLOCKS yields for zero elements (defined outside this
  // file); a zero-block grid would be an invalid launch configuration.
  if (total == 0) {
    return;
  }
  GRUUnitGradientKernel<float>
      <<<CAFFE_GET_BLOCKS(total),
         CAFFE_CUDA_NUM_THREADS,
         0,
         context->cuda_stream()>>>(
          total,
          D,
          t,
          H_prev,
          X,
          seqLengths,
          H,
          H_diff,
          drop_states,
          H_prev_diff,
          X_diff);
}
}
// Associate the float/CUDAContext operator classes with the GRUUnit and
// GRUUnitGradient operator names. NOTE(review): the macro and both operator
// classes are declared outside this file; presumably this makes the CUDA
// implementations selectable at runtime -- confirm against the registry.
REGISTER_CUDA_OPERATOR(GRUUnit, GRUUnitOp<float, CUDAContext>);
REGISTER_CUDA_OPERATOR(GRUUnitGradient, GRUUnitGradientOp<float, CUDAContext>);
}