forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
LookupTable.cu
210 lines (181 loc) · 6.88 KB
/
LookupTable.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "THCUNN/generic/LookupTable.cu"
#else
// Backward pass of the lookup table: accumulates rows of gradOutput into the
// rows of gradWeight addressed by `input`. All work is issued on the current
// THC stream.
//
//   input           index tensor; must be contiguous and 1-D or 2-D.
//   gradOutput      incoming gradient; a contiguous copy is taken and freed
//                   internally (the caller's tensor is not modified).
//   gradWeight      accumulation target; must be contiguous. Its stride(0) is
//                   forwarded to the kernels as `stride`.
//   count           scratch; only resized/filled when scaleGradByFreq is true,
//                   receiving for every sorted position the total number of
//                   occurrences of that index.
//   sortedIndices   scratch; on the sort path, resized to input's shape and
//                   filled with the indices in ascending order.
//   origIndices     scratch; on the sort path, resized to input's shape and
//                   filled with each sorted index's original position.
//   scaleGradByFreq when true, always takes the sort path and passes `count`
//                   to the kernel (presumably to divide each contribution by
//                   its index frequency -- TODO confirm against the kernel).
//   paddingValue    forwarded unchanged to both kernels (presumably the index
//                   whose row is skipped -- confirm against the kernels).
//   scale_          multiplier, converted to scalar_t before launch.
//
// Raises (THError) if input or gradWeight is non-contiguous, or if input is
// neither a vector nor a matrix. An empty input is a no-op.
void THNN_(LookupTable_accGradParameters)(
           THCState *state,
           THCIndexTensor *input,
           THCTensor *gradOutput,
           THCTensor *gradWeight,
           THCIndexTensor *count,
           THCIndexTensor *sortedIndices,
           THCIndexTensor *origIndices,
           bool scaleGradByFreq,
           int paddingValue,
           accreal scale_)
{
  scalar_t scale = ScalarConvert<accreal, scalar_t>::to(scale_);
  THCUNN_assertSameGPU(state, 5, input, gradOutput, gradWeight, sortedIndices, origIndices);

  // Validate before taking the contiguous copy of gradOutput below, so a
  // failed check cannot leak that copy (THError is assumed not to return).
  if (!(THCIndexTensor_(isContiguous)(state, input) &&
        THCTensor_(isContiguous)(state, gradWeight))) {
    THError("Tensors must be contiguous");
  }
  int nDim = THCIndexTensor_(nDimensionLegacyAll)(state, input);
  if (nDim != 1 && nDim != 2) {
    THCDescBuff s1 = THCIndexTensor_(sizeDesc)(state, input);
    THError("input must be a vector or matrix, but is of shape: %s", s1.str);
  }

  ptrdiff_t numel = THCIndexTensor_(nElement)(state, input);
  int64_t stride = THCTensor_(stride)(state, gradWeight, 0);
  if (numel == 0) {
    // Nothing to accumulate. Returning here also avoids launching the
    // sort-path kernel with grid.x == 0, which is an invalid configuration.
    return;
  }

  gradOutput = THCTensor_(newContiguous)(state, gradOutput);
  cudaStream_t stream = THCState_getCurrentStream(state);

  // Small-input fast path (768 is a fixed threshold here, presumably a tuning
  // constant): the by-feature kernel reads the unsorted indices directly, so
  // no sort or scratch buffers are needed. Not usable with frequency scaling.
  if (numel <= 768 && !scaleGradByFreq) {
    const int WARP_SIZE = 32;
    const int BLOCKDIMY = 32;
    // grid.x covers `stride` in WARP_SIZE-wide chunks.
    dim3 grid(THCCeilDiv(stride, (int64_t)WARP_SIZE));
    dim3 block(WARP_SIZE, BLOCKDIMY);
    // Dynamic shared memory: one accreal and one int per thread.
    cunn_LookupTable_accGradParametersKernelByFeature<scalar_t, accreal>
      <<<grid,
         block,
         sizeof(accreal)*WARP_SIZE*BLOCKDIMY + sizeof(int)*WARP_SIZE*BLOCKDIMY,
         stream>>>
      (THCIndexTensor_(data)(state, input),
       THCTensor_(data)(state, gradOutput),
       THCTensor_(data)(state, gradWeight),
       scale,
       numel,
       stride,
       paddingValue);
    THCTensor_(free)(state, gradOutput);
    THCudaCheck(cudaGetLastError());
    return;
  }

  // NOTE(review): an exception thrown between here and the final free (e.g.
  // from resize or thrust) would still leak the gradOutput copy.
  THCIndexTensor_(resize)(state, sortedIndices, input->sizes(), {});
  THCIndexTensor_(resize)(state, origIndices, input->sizes(), {});

  THCThrustAllocator thrustAlloc(state);

  // Sort a copy of the indices, carrying each index's original position along
  // (origIndices starts as 0, 1, ..., numel-1). Neither a stable nor a
  // multidimensional sort is required, so Thrust is used directly.
  {
    THCIndexTensor_(copy)(state, sortedIndices, input);
    thrust::device_ptr<THCIndex_t>
      sortedIndicesIter(THCIndexTensor_(data)(state, sortedIndices));
    thrust::device_ptr<THCIndex_t>
      origIndicesIter(THCIndexTensor_(data)(state, origIndices));

    thrust::counting_iterator<THCIndex_t> countIter(0);
    thrust::copy(
#if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
      thrust::cuda::par(thrustAlloc).on(stream),
#endif
      countIter, countIter + numel, origIndicesIter);

    thrust::sort_by_key(
#if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
      thrust::cuda::par(thrustAlloc).on(stream),
#endif
      sortedIndicesIter, sortedIndicesIter + numel,
      origIndicesIter, ThrustLTOp<int64_t>());
  }

  THCIndex_t *sortedIndices_data = THCIndexTensor_(data)(state, sortedIndices);
  THCIndex_t *origIndices_data = THCIndexTensor_(data)(state, origIndices);
  THCIndex_t *count_data = NULL;

  if (scaleGradByFreq) {
    THCIndexTensor_(resizeAs)(state, count, input);
    count_data = THCIndexTensor_(data)(state, count);

    thrust::device_ptr<THCIndex_t> sortedIndices_ptr(sortedIndices_data);
    thrust::device_ptr<THCIndex_t> count_ptr(count_data);

    // Running occurrence number within each run of equal sorted indices:
    // sorted: 2 5 5 5 7 7 8 9 9
    // count:  1 1 2 3 1 2 1 1 2
    thrust::inclusive_scan_by_key(
#if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
      thrust::cuda::par(thrustAlloc).on(stream),
#endif
      sortedIndices_ptr,
      sortedIndices_ptr + numel,
      thrust::make_constant_iterator(1),
      count_ptr
    );

    // Propagate each run's maximum backwards (an in-place scan over reversed
    // ranges) so every position holds its index's total count:
    // sorted: 2 5 5 5 7 7 8 9 9
    // count:  1 3 3 3 2 2 1 2 2
    thrust::inclusive_scan_by_key(
#if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
      thrust::cuda::par(thrustAlloc).on(stream),
#endif
      thrust::make_reverse_iterator(sortedIndices_ptr + numel),
      thrust::make_reverse_iterator(sortedIndices_ptr),
      thrust::make_reverse_iterator(count_ptr + numel),
      thrust::make_reverse_iterator(count_ptr + numel),
      thrust::equal_to<int64_t>(),
      thrust::maximum<int64_t>()
    );
  }

  // grid.x covers the sorted indices 4 at a time (one per block row, since
  // block.y == 4); grid.y covers `stride` in chunks of 128.
  dim3 grid(THCCeilDiv(numel, (ptrdiff_t) 4), THCCeilDiv(stride, (int64_t) 128));
  dim3 block(32, 4);
  cunn_LookupTable_accGradParametersKernel<scalar_t, accreal><<<grid, block, 0, stream>>>(
    sortedIndices_data,
    origIndices_data,
    THCTensor_(data)(state, gradOutput),
    THCTensor_(data)(state, gradWeight),
    count_data,
    scale,
    numel,
    stride,
    paddingValue
  );

  THCTensor_(free)(state, gradOutput);
  THCudaCheck(cudaGetLastError());
}
// Launch helper for calculate_norms_and_renorm, used by LookupTable_renorm
// below. Each expansion launches `numel` blocks of THREADS/2 threads, with
// THREADS * sizeof(accreal) bytes of dynamic shared memory, on the current
// THC stream.
//
// RUN(NORM, IDXTYPE) is not hygienic: it reads the locals `numel`,
// `weightsRaw`, `idxRaw`, `normType`, `maxNorm`, `state` and `weight` from the
// call site. It must therefore only be expanded inside LookupTable_renorm, or
// in a function with identically named locals.
//   NORM    - kernel template specialization; the call sites pass 1, 2, or -1
//             (the general case).
//   IDXTYPE - presumably the integer type the kernel uses for index
//             arithmetic (chosen at the call site via
//             THCTensor_canUse32BitIndexMath).
// The last kernel argument is weight's stride(0).
#define THREADS 256
#define RUN(NORM, IDXTYPE) \
  calculate_norms_and_renorm<scalar_t, accreal, IDXTYPE, NORM> \
    <<<numel, THREADS/2, THREADS * sizeof(accreal), THCState_getCurrentStream(state)>>> \
    (weightsRaw, idxRaw, normType, maxNorm, THCTensor_(stride)(state, weight, 0))
// Renormalizes, in place, the rows of `weight` selected by `idx`. The norm
// arithmetic lives in calculate_norms_and_renorm, which is outside this view.
// It receives normType (p) and maxNorm, and presumably rescales rows whose
// p-norm exceeds maxNorm -- confirm against the kernel.
//
//   idx      1-D contiguous index tensor. It is MODIFIED: thrust::unique
//            compacts it in place. unique only collapses *adjacent*
//            duplicates, so idx is presumably expected to arrive sorted
//            (TODO confirm at the callers).
//   weight   contiguous tensor whose rows are renormalized; its stride(0) is
//            passed to the kernel.
//   maxNorm  forwarded to the kernel.
//   normType p of the p-norm; must be > 0. Exactly 1 and 2 use specialized
//            kernels, and every other value uses the general (-1) kernel.
//
// Raises (THError) on non-contiguous inputs, non-vector idx, or normType <= 0.
// An empty idx is a no-op. Launches one block per unique index on the current
// THC stream.
void THNN_(LookupTable_renorm)(
           THCState *state,
           THCIndexTensor *idx,
           THCTensor *weight,
           accreal maxNorm,
           accreal normType)
{
  THCUNN_assertSameGPU(state, 2, idx, weight);
  if (!(THCIndexTensor_(isContiguous)(state, idx) &&
        THCTensor_(isContiguous)(state, weight))) {
    THError("Tensors must be contiguous");
  }
  if (THCIndexTensor_(nDimensionLegacyAll)(state, idx) != 1) {
    THError("idx must be a vector");
  }
  if (normType <= 0) {
    THError("non-positive-norm not supported");
  }

  THCIndex_t numel = THCIndexTensor_(nElement)(state, idx);
  if (numel == 0) {
    // Nothing to renormalize. Returning here also avoids a zero-block launch,
    // which is an invalid configuration.
    return;
  }
  scalar_t * weightsRaw = THCTensor_(data)(state, weight);
  THCIndex_t * idxRaw = THCIndexTensor_(data)(state, idx);

  // Collapse adjacent duplicate indices in place, keeping only the unique
  // prefix. Like the other thrust calls in this file, this runs on the
  // current THC stream with the THC allocator, so it is ordered with the
  // kernel launch below and with prior work that produced idx.
  {
    THCThrustAllocator thrustAlloc(state);
    thrust::device_ptr<THCIndex_t> idxThrust(idxRaw);
    thrust::device_ptr<THCIndex_t> endIdxThrust = thrust::unique(
#if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
      thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
#endif
      idxThrust, idxThrust + numel);
    numel = endIdxThrust - idxThrust;
  }

  // Select the kernel specialization by exact comparison. Truncating p to int
  // would send, e.g., p = 1.5 to the p == 1 kernel and p = 2.5 to the p == 2
  // kernel, which is wrong for any non-integral p.
  const accreal one = ScalarConvert<int, accreal>::to(1);
  const accreal two = ScalarConvert<int, accreal>::to(2);
  const int Norm = (normType == one) ? 1 : (normType == two) ? 2 : -1;

  // The kernel is given weightsRaw and weight's row stride, so its index
  // arithmetic presumably spans weight's extent, not just idx's. Only take
  // the 32-bit path when both tensors fit.
  if (THCTensor_canUse32BitIndexMath(state, idx) &&
      THCTensor_canUse32BitIndexMath(state, weight)) {
    if (Norm == 1) {
      RUN(1, unsigned int);
    } else if (Norm == 2) {
      RUN(2, unsigned int);
    } else {
      RUN(-1, unsigned int);
    }
  } else {
    // uint64_t rather than unsigned long: the latter is only 32 bits on LLP64
    // platforms (e.g. Windows), which would defeat the 64-bit path.
    if (Norm == 1) {
      RUN(1, uint64_t);
    } else if (Norm == 2) {
      RUN(2, uint64_t);
    } else {
      RUN(-1, uint64_t);
    }
  }
  THCudaCheck(cudaGetLastError());
}
#endif