RReLU.cu
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "THCUNN/generic/RReLU.cu"
#else

#include <THCUNN/common.h>
#include <ATen/CUDAGenerator.h>
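
// RReLU (randomized leaky ReLU) forward. In training mode the negative slope
// is sampled per element from U(lower, upper) by the rreluUpdateOutputTrain
// kernel and recorded in `noise` so the backward pass can reuse it; in
// evaluation mode the fixed slope (lower + upper) / 2 is applied instead.
// Note that the trailing `generator` argument is unused below; the default
// CUDA generator is what actually supplies the Philox seed and offset.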
void THNN_(RReLU_updateOutput)(
           THCState *state,
           THCTensor *input,
           THCTensor *output,
           THCTensor *noise,
           double lower,
           double upper,
           bool train,
           bool inplace,
           void *generator)
{
  THCUNN_assertSameGPU(state, 3, input, output, noise);
  auto gen = at::cuda::detail::getDefaultCUDAGenerator();
  if (train)
  {
    input = THCTensor_(newContiguous)(state, input);
    THCTensor_(resizeAs)(state, noise, input);
    scalar_t *input_data = THCTensor_(data)(state, input);
    scalar_t *noise_data = THCTensor_(data)(state, noise);
    ptrdiff_t n = THCTensor_(nElement)(state, input);

    // philox offset calculation for grid-stride loop utilizing curand4
    const uint32_t curand4_engine_calls = 4;
    dim3 grid = NUM_BLOCKS(n);
    uint64_t counter_offset = ((n - 1) / (BLOCK_SIZE * grid.x) + 1) * curand4_engine_calls;
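    // Each thread of the grid-stride loop runs ceil(n / (BLOCK_SIZE * grid.x))
    // iterations, and each iteration is assumed to draw one curand4-style
    // sample, i.e. 4 Philox engine outputs, hence the factor of 4 above.
    // For example, with 65536 threads in total and n = 1048576, every thread
    // iterates 16 times and the engine must be advanced by 16 * 4 = 64.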
    std::pair<uint64_t, uint64_t> rng_engine_inputs;
    {
      // See Note [Acquire lock when using random generators]
      std::lock_guard<std::mutex> lock(gen->mutex_);
      rng_engine_inputs = gen->philox_engine_inputs(counter_offset);
    }
    if (inplace)
    {
      rreluUpdateOutputTrain<<<grid, BLOCK_SIZE, 0, c10::cuda::getCurrentCUDAStream()>>>(
        n, rng_engine_inputs, input_data, noise_data, input_data, lower, upper);
      THCTensor_(set)(state, output, input);
    }
    else
    {
      THCTensor_(resizeAs)(state, output, input);
      scalar_t *output_data = THCTensor_(data)(state, output);
      rreluUpdateOutputTrain<<<grid, BLOCK_SIZE, 0, c10::cuda::getCurrentCUDAStream()>>>(
        n, rng_engine_inputs, input_data, noise_data, output_data, lower, upper);
    }
    THCudaCheck(cudaGetLastError());
    THCTensor_(free)(state, input);
  }
  else
  {
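    // Evaluation mode: no sampling; apply the deterministic slope
    // (lower + upper) / 2, the mean of the training-time distribution.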
    const scalar_t negSlope = ScalarConvert<double, scalar_t>::to((lower + upper) / 2);
    if (inplace)
    {
      THC_pointwiseApply1<scalar_t>(state, input, RReLUUpdateOutputEvalIP_functor<scalar_t>(negSlope));
      THCTensor_(set)(state, output, input);
    }
    else
    {
      THCTensor_(resizeAs)(state, output, input);
      THC_pointwiseApply2<scalar_t, scalar_t>(state, output, input, RReLUUpdateOutputEval_functor<scalar_t>(negSlope));
    }
  }
}

void THNN_(RReLU_updateGradInput)(
           THCState *state,
           THCTensor *input,
           THCTensor *gradOutput,
           THCTensor *gradInput,
           THCTensor *noise,
           double lower,
           double upper,
           bool train,
           bool inplace)
{
  THCUNN_check_nElement(state, input, gradOutput);
  THCUNN_assertSameGPU(state, 4, input, gradOutput, gradInput, noise);
  gradOutput = THCTensor_(newContiguous)(state, gradOutput);
  if (train && upper - lower > 1E-6)    // e.g. if upper == lower, RReLU behaves like LeakyReLU
  {
    // multiply the gradient by the noise tensor
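    // (training forward computed output = input * noise elementwise, with
    // noise holding the sampled slope for negative inputs and 1 elsewhere,
    // so gradInput = gradOutput * noise)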
    if (inplace)
    {
      THCTensor_(cmul)(state, gradOutput, gradOutput, noise);
      THCTensor_(set)(state, gradInput, gradOutput);
    }
    else
    {
      THCTensor_(resizeAs)(state, gradInput, input);
      THCTensor_(cmul)(state, gradInput, gradOutput, noise);
    }
  }
  else
  {
    // use constant factor for negative input values
    const scalar_t negSlope = ScalarConvert<double, scalar_t>::to((lower + upper) / 2);
    if (inplace)
    {
      THC_pointwiseApply2<scalar_t, scalar_t>(state, gradOutput, input, RReLUupdateGradInputEvalIP_functor<scalar_t>(negSlope));
      THCTensor_(set)(state, gradInput, gradOutput);
    }
    else
    {
      THCTensor_(resizeAs)(state, gradInput, input);
      THC_pointwiseApply3<scalar_t, scalar_t, scalar_t>(state, gradInput, gradOutput, input, RReLUupdateGradInputEval_functor<scalar_t>(negSlope));
    }
  }
  THCTensor_(free)(state, gradOutput);
}
#endif
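
// ---------------------------------------------------------------------------
// Illustrative sketch (an editorial addition, not part of the build): a
// minimal standalone kernel showing the training-mode math that
// rreluUpdateOutputTrain implements with Philox/curand4. The kernel name,
// the bare `seed` parameter, and the one-draw-per-element setup below are
// simplifications; the real kernel batches four uniform draws per
// curand_uniform4 call and threads the (seed, offset) pair obtained from
// philox_engine_inputs through rng_engine_inputs. Kept commented out so this
// generic include file still compiles unchanged.
//
//   #include <curand_kernel.h>
//
//   __global__ void rrelu_train_sketch(int n, uint64_t seed,
//                                      const float *in, float *noise,
//                                      float *out, float lower, float upper)
//   {
//     int i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i >= n) return;
//     curandStatePhilox4_32_10_t state;
//     curand_init(seed, i, 0, &state);
//     float x = in[i];
//     if (x <= 0.f) {
//       // sample the negative slope a ~ U(lower, upper); record it so the
//       // backward pass can reuse exactly the same factor
//       float a = lower + curand_uniform(&state) * (upper - lower);
//       noise[i] = a;
//       out[i] = x * a;
//     } else {
//       noise[i] = 1.f;   // positive inputs pass through; gradient factor 1
//       out[i] = x;
//     }
//   }
// ---------------------------------------------------------------------------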