forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
FeatureLPPooling.cu
267 lines (233 loc) · 9.19 KB
/
FeatureLPPooling.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "THCUNN/generic/FeatureLPPooling.cu"
#else
#include <THCUNN/common.h>
// Views `t` as a rank-4 device tensor so that the feature dimension always
// sits at index 1, whatever the input layout was. Accepted layouts:
//
//   non-batch mode:
//     [feature dim]
//     [feature dim][opt dim 1]
//     [feature dim][opt dim 1][opt dim 2]
//
//   batch mode:
//     [batch dim][feature dim]
//     [batch dim][feature dim][opt dim 1]
//     [batch dim][feature dim][opt dim 1][opt dim 2]
//
// upcastOuter<N>() / upcastInner<N>() presumably pad leading / trailing
// dimensions with size 1 — NOTE(review): confirm against THCDeviceTensor.
// A 4-d tensor is only valid in batch mode (asserted below).
THCDeviceTensor<scalar_t, 4>
THNN_(FeatureLPPooling_upcast)(THCState* state, THCTensor* t, bool batchMode) {
  switch (THCTensor_(nDimensionLegacyAll)(state, t)) {
    case 1:
      // [feature dim]: pad one leading dim, then pad out to rank 4.
      return toDeviceTensor<scalar_t, 1>(state, t)
          .upcastOuter<2>()
          .upcastInner<4>();

    case 2:
      if (!batchMode) {
        // [feature dim][opt dim 1]: pad one leading dim, then pad to rank 4.
        return toDeviceTensor<scalar_t, 2>(state, t)
            .upcastOuter<3>()
            .upcastInner<4>();
      }
      // [batch dim][feature dim]: feature is already at index 1.
      return toDeviceTensor<scalar_t, 2>(state, t).upcastInner<4>();

    case 3:
      if (!batchMode) {
        // [feature dim][opt dim 1][opt dim 2]: one leading pad reaches rank 4.
        return toDeviceTensor<scalar_t, 3>(state, t).upcastOuter<4>();
      }
      // [batch dim][feature dim][opt dim 1]: feature is already at index 1.
      return toDeviceTensor<scalar_t, 3>(state, t).upcastInner<4>();

    default:
      // Expected rank 4: [batch dim][feature dim][opt dim 1][opt dim 2].
      // Any other rank also lands here.
      THAssert(batchMode);
      return toDeviceTensor<scalar_t, 4>(state, t);
  }
}
// Resizes `toResize` to the output shape obtained by pooling `input` along
// its feature dimension with the given `width` and `stride`.
//
// The feature dimension is dimension 1 in batch mode and dimension 0
// otherwise. Only that dimension changes size (to
// lpPoolingOutputSize(featureSize, width, stride)); every other dimension is
// copied from `input`. Batch mode requires 2-4 dims; non-batch mode 1-3 dims.
void
THNN_(FeatureLPPooling_resizeForOutput)(THCState* state,
                                        THCTensor* toResize,
                                        THCTensor* input,
                                        bool batchMode,
                                        int width,
                                        int stride) {
  int inputDim = THCTensor_(nDimensionLegacyAll)(state, input);
  THAssert(inputDim >= 1 && inputDim <= 4);
  // Validate the layout before reading any size. Batch mode needs at least
  // [batch][feature]; non-batch mode has at most three dimensions.
  if (batchMode) {
    THAssert(inputDim > 1);
  } else {
    THAssert(inputDim < 4);
  }

  // Only the feature dimension is pooled, so compute its output size once,
  // from the correct dimension. The previous version also evaluated the
  // batch dimension and then discarded that result.
  const int featureDim = batchMode ? 1 : 0;
  int64_t outSize = lpPoolingOutputSize(
      THCTensor_(size)(state, input, featureDim), width, stride);

  if (inputDim == 1) {
    // [feature]
    THCTensor_(resize1d)(state, toResize, outSize);
  } else if (inputDim == 2) {
    if (batchMode) {
      // [batch][feature]
      THCTensor_(resize2d)(
          state, toResize, THCTensor_(size)(state, input, 0), outSize);
    } else {
      // [feature][opt 1]
      THCTensor_(resize2d)(
          state, toResize, outSize, THCTensor_(size)(state, input, 1));
    }
  } else if (inputDim == 3) {
    if (batchMode) {
      // [batch][feature][opt 1]
      THCTensor_(resize3d)(
          state,
          toResize,
          THCTensor_(size)(state, input, 0), outSize,
          THCTensor_(size)(state, input, 2));
    } else {
      // [feature][opt 1][opt 2]
      THCTensor_(resize3d)(
          state,
          toResize,
          outSize, THCTensor_(size)(state, input, 1),
          THCTensor_(size)(state, input, 2));
    }
  } else if (inputDim == 4) {
    // [batch][feature][opt 1][opt 2] (batch mode only, asserted above)
    THCTensor_(resize4d)(
        state,
        toResize,
        THCTensor_(size)(state, input, 0), outSize,
        THCTensor_(size)(state, input, 2), THCTensor_(size)(state, input, 3));
  }
}
// Gives `toResize` exactly the shape of `src`: the same number of dimensions
// (as reported by nDimensionLegacyAll) and the same size in each one.
// Only ranks 1 through 4 are supported; any other rank trips the assert.
void
THNN_(FeatureLPPooling_resize)(THCState* state,
                               THCTensor* toResize,
                               THCTensor* src) {
  const int rank = THCTensor_(nDimensionLegacyAll)(state, src);
  THAssert(rank >= 1 && rank <= 4);

  switch (rank) {
    case 1:
      THCTensor_(resize1d)(state, toResize, THCTensor_(size)(state, src, 0));
      break;
    case 2:
      THCTensor_(resize2d)(state, toResize,
                           THCTensor_(size)(state, src, 0),
                           THCTensor_(size)(state, src, 1));
      break;
    case 3:
      THCTensor_(resize3d)(state, toResize,
                           THCTensor_(size)(state, src, 0),
                           THCTensor_(size)(state, src, 1),
                           THCTensor_(size)(state, src, 2));
      break;
    case 4:
      THCTensor_(resize4d)(state, toResize,
                           THCTensor_(size)(state, src, 0),
                           THCTensor_(size)(state, src, 1),
                           THCTensor_(size)(state, src, 2),
                           THCTensor_(size)(state, src, 3));
      break;
    default:
      break;
  }
}
// Forward pass of feature LP pooling: pools `inputTH` along its feature
// dimension into `outputTH`, which is resized here to the pooled shape.
//
// Argument positions used in THArgCheck (state is 1):
//   2 inputTH, 3 outputTH, 4 power, 5 width, 6 stride.
// Constraints enforced below: batch mode takes 2-4 dims and non-batch mode
// 1-3 dims; the input must be 32-bit indexable; 2 <= width <= 16;
// 1 <= stride <= 4; the feature dimension must be at least `width` long.
void THNN_(FeatureLPPooling_updateOutput)(THCState* state,
                                          THCTensor* inputTH,
                                          THCTensor* outputTH,
                                          accreal power,
                                          int width,
                                          int stride,
                                          bool batchMode) {
  THCUNN_assertSameGPU(state, 2, inputTH, outputTH);

  const int nDims = THCTensor_(nDimensionLegacyAll)(state, inputTH);
  if (!batchMode) {
    THArgCheck(nDims >= 1 && nDims <= 3, 2,
               "input must be 1-3 dimensions for non-batch mode");
  } else {
    THArgCheck(nDims >= 2 && nDims <= 4, 2,
               "input must be 2-4 dimensions for batch mode");
  }
  THArgCheck(THCTensor_canUse32BitIndexMath(state, inputTH), 2,
             "input tensor must fit into 32-bit index math");

  // Rank-4 view of the input; index 1 is the pooled feature dimension.
  THCDeviceTensor<scalar_t, 4> input =
      THNN_(FeatureLPPooling_upcast)(state, inputTH, batchMode);

  THArgCheck(input.getSize(1) >= width, 2,
             "input: feature dimension must be >= width");
  THArgCheck(width >= 2 && width <= 16, 5,
             "width must be between 2 - 16");
  THArgCheck(stride >= 1 && stride <= 4, 6,
             "stride must be between 1 - 4");

  // Shape the output for this width/stride, then view it the same way.
  THNN_(FeatureLPPooling_resizeForOutput)(
      state, outputTH, inputTH, batchMode, width, stride);
  THCDeviceTensor<scalar_t, 4> output =
      THNN_(FeatureLPPooling_upcast)(state, outputTH, batchMode);

  // NOTE(review): presumably false when no kernel specialization matches
  // the requested width/stride; confirm in runFeatureLPPoolingUpdateOutput.
  const bool launched = runFeatureLPPoolingUpdateOutput(
      state, input, output, power, width, stride);
  THAssert(launched);
}
// Backward pass of feature LP pooling: computes `gradInputTH` from the
// forward input/output and the gradient w.r.t. the output.
// `gradInputTH` is resized to the shape of `inputTH` before the kernel runs.
//
// Argument positions used in THArgCheck (state is 1):
//   2 gradOutputTH, 3 inputTH, 4 outputTH, 5 gradInputTH, 6 power,
//   7 width, 8 stride.
// Constraints enforced below: gradOutput and input must be 32-bit indexable;
// batch mode takes 2-4 input dims and non-batch mode 1-3; 2 <= width <= 16;
// 1 <= stride <= 4; output and gradOutput have identical (upcast) sizes;
// output's feature size matches what width/stride produce from the input.
void THNN_(FeatureLPPooling_updateGradInput)(THCState* state,
                                             THCTensor* gradOutputTH,
                                             THCTensor* inputTH,
                                             THCTensor* outputTH,
                                             THCTensor* gradInputTH,
                                             accreal power,
                                             int width,
                                             int stride,
                                             bool batchMode) {
  THArgCheck(THCTensor_canUse32BitIndexMath(state, gradOutputTH), 2,
             "output gradient tensor must fit into 32-bit index math");
  THArgCheck(THCTensor_canUse32BitIndexMath(state, inputTH), 3,
             "input tensor must fit into 32-bit index math");
  THCUNN_assertSameGPU(state, 4, gradOutputTH, inputTH, outputTH, gradInputTH);

  // In this function `input` is argument 3 (argument 2 is gradOutput), so
  // dimensionality errors must point at argument 3, like the other input
  // checks here.
  int inputDim = THCTensor_(nDimensionLegacyAll)(state, inputTH);
  if (batchMode) {
    THArgCheck(inputDim >= 2 && inputDim <= 4, 3,
               "input must be 2-4 dimensions for batch mode");
  } else {
    THArgCheck(inputDim >= 1 && inputDim <= 3, 3,
               "input must be 1-3 dimensions for non-batch mode");
  }

  // Rank-4 view of the input; index 1 is the pooled feature dimension.
  THCDeviceTensor<scalar_t, 4> input =
      THNN_(FeatureLPPooling_upcast)(state, inputTH, batchMode);

  // Make sure the feature dimension is properly sized
  THArgCheck(input.getSize(1) >= width, 3,
             "input: feature dimension must be >= width");
  // Make sure that width and stride are within range
  THArgCheck(width >= 2 && width <= 16, 7,
             "width must be between 2 - 16");
  THArgCheck(stride >= 1 && stride <= 4, 8,
             "stride must be between 1 - 4");

  THCDeviceTensor<scalar_t, 4> gradOutput =
      THNN_(FeatureLPPooling_upcast)(state, gradOutputTH, batchMode);
  THCDeviceTensor<scalar_t, 4> output =
      THNN_(FeatureLPPooling_upcast)(state, outputTH, batchMode);

  for (int i = 0; i < 4; ++i) {
    THAssertMsg(output.getSize(i) == gradOutput.getSize(i),
                "output and gradOutput sizes do not match");
  }

  // Make sure that the input sizes produce the output sizes
  THArgCheck(lpPoolingOutputSize(input.getSize(1), width, stride) ==
             output.getSize(1), 3,
             "input and output sizes do not match with respect to "
             "width and stride");

  // Resize `gradInput` based on `input`
  THNN_(FeatureLPPooling_resize)(state, gradInputTH, inputTH);
  THCDeviceTensor<scalar_t, 4> gradInput =
      THNN_(FeatureLPPooling_upcast)(state, gradInputTH, batchMode);

  bool found = runFeatureLPPoolingUpdateGradInput(state,
                                                  gradOutput,
                                                  input,
                                                  output,
                                                  gradInput,
                                                  power,
                                                  width,
                                                  stride);
  THAssert(found);
}
#endif