-
Notifications
You must be signed in to change notification settings - Fork 52
/
idea.py
290 lines (261 loc) · 12.2 KB
/
idea.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
from torch import nn
from torch.autograd import Function
import torch.nn.functional as F
import torch
from torch.nn.modules.utils import _triple, _pair, _single
import softpool_cuda
class CUDA_SOFTPOOL1d(Function):
    """Autograd ``Function`` that dispatches 1D soft pooling to ``softpool_cuda``.

    ``forward`` accepts a batched ``[B x C x D]`` tensor or an un-batched
    ``[C x D]`` tensor; the un-batched case is handled by temporarily adding
    a leading batch dimension and removing it again from the result.
    """

    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, input, kernel=2, stride=None):
        """Run the forward kernel.

        [Args]
            - ctx: autograd context; stores the input, kernel, stride and the
                   un-batched flag for ``backward``.
            - input: Tensor of rank 2 (``[C x D]``) or rank 3 (``[B x C x D]``).
            - kernel: Integer or Tuple, expanded with ``_single``. Defaults to 2.
            - stride: Integer or Tuple, expanded with ``_single``. If `None` it
                   equals ``kernel``. Defaults to `None`.
        [Returns]
            - Tensor of shape ``[B x C x oD]`` (or ``[C x oD]`` for un-batched
              input) with ``oD = (D - kernel) // stride + 1``.
        """
        no_batch = input.dim() == 2
        if no_batch:
            # Out-of-place view: the caller's tensor keeps its original shape
            # (an in-place ``unsqueeze_`` would silently reshape it).
            input = input.unsqueeze(0)
        B, C, D = input.size()
        kernel = _single(kernel)
        stride = kernel if stride is None else _single(stride)
        oD = (D - kernel[0]) // stride[0] + 1
        # Make the tensor contiguous once so the forward kernel and the
        # backward kernel (via the saved tensor) see the same memory layout.
        input = input.contiguous()
        output = input.new_zeros((B, C, oD))
        # NOTE(review): the extension presumably fills `output` in place —
        # its return value is not used here.
        softpool_cuda.forward_1d(input, kernel, stride, output)
        ctx.save_for_backward(input)
        ctx.kernel = kernel
        ctx.stride = stride
        ctx.no_batch = no_batch
        if no_batch:
            return output.squeeze_(0)
        return output

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, grad_output):
        """Run the backward kernel and return the gradient w.r.t. ``input``.

        Returns one gradient per ``forward`` argument; ``kernel`` and
        ``stride`` are not differentiable, so their slots are `None`.
        """
        (saved_input,) = ctx.saved_tensors
        if ctx.no_batch:
            # Restore the batch dimension that forward added, so the gradient
            # has the same rank as the saved (batched) input.
            grad_output = grad_output.unsqueeze(0)
        grad_input = torch.zeros_like(saved_input)
        # NOTE(review): the extension presumably writes into `grad_input`.
        softpool_cuda.backward_1d(grad_output.contiguous(), saved_input, ctx.kernel, ctx.stride, grad_input)
        # Gradient underflow: any NaN entries are replaced by zero
        grad_input[torch.isnan(grad_input)] = 0
        if ctx.no_batch:
            # The gradient must match the un-batched shape the caller passed in.
            grad_input = grad_input.squeeze(0)
        return grad_input, None, None
class CUDA_SOFTPOOL2d(Function):
    """Autograd ``Function`` that dispatches 2D soft pooling to ``softpool_cuda``.

    ``forward`` accepts a batched ``[B x C x H x W]`` tensor or an un-batched
    ``[C x H x W]`` tensor; the un-batched case is handled by temporarily
    adding a leading batch dimension and removing it again from the result.
    """

    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, input, kernel=2, stride=None):
        """Run the forward kernel.

        [Args]
            - ctx: autograd context; stores the input, kernel, stride and the
                   un-batched flag for ``backward``.
            - input: Tensor of rank 3 (``[C x H x W]``) or rank 4
                   (``[B x C x H x W]``).
            - kernel: Integer or Tuple, expanded with ``_pair``. Defaults to 2.
            - stride: Integer or Tuple, expanded with ``_pair``. If `None` it
                   equals ``kernel``. Defaults to `None`.
        [Returns]
            - Tensor of shape ``[B x C x oH x oW]`` (or ``[C x oH x oW]`` for
              un-batched input), each output size being
              ``(size - kernel) // stride + 1`` along its axis.
        """
        no_batch = input.dim() == 3
        if no_batch:
            # Out-of-place view: the caller's tensor keeps its original shape
            # (an in-place ``unsqueeze_`` would silently reshape it).
            input = input.unsqueeze(0)
        B, C, H, W = input.size()
        kernel = _pair(kernel)
        stride = kernel if stride is None else _pair(stride)
        oH = (H - kernel[0]) // stride[0] + 1
        oW = (W - kernel[1]) // stride[1] + 1
        # Make the tensor contiguous once so the forward kernel and the
        # backward kernel (via the saved tensor) see the same memory layout.
        input = input.contiguous()
        output = input.new_zeros((B, C, oH, oW))
        # NOTE(review): the extension presumably fills `output` in place —
        # its return value is not used here.
        softpool_cuda.forward_2d(input, kernel, stride, output)
        ctx.save_for_backward(input)
        ctx.kernel = kernel
        ctx.stride = stride
        ctx.no_batch = no_batch
        if no_batch:
            return output.squeeze_(0)
        return output

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, grad_output):
        """Run the backward kernel and return the gradient w.r.t. ``input``.

        Returns one gradient per ``forward`` argument; ``kernel`` and
        ``stride`` are not differentiable, so their slots are `None`.
        """
        (saved_input,) = ctx.saved_tensors
        if ctx.no_batch:
            # Restore the batch dimension that forward added, so the gradient
            # has the same rank as the saved (batched) input.
            grad_output = grad_output.unsqueeze(0)
        grad_input = torch.zeros_like(saved_input)
        # NOTE(review): the extension presumably writes into `grad_input`.
        softpool_cuda.backward_2d(grad_output.contiguous(), saved_input, ctx.kernel, ctx.stride, grad_input)
        # Gradient underflow: any NaN entries are replaced by zero
        grad_input[torch.isnan(grad_input)] = 0
        if ctx.no_batch:
            # The gradient must match the un-batched shape the caller passed in.
            grad_input = grad_input.squeeze(0)
        return grad_input, None, None
class CUDA_SOFTPOOL3d(Function):
    """Autograd ``Function`` that dispatches 3D soft pooling to ``softpool_cuda``.

    ``forward`` accepts a batched ``[B x C x D x H x W]`` tensor or an
    un-batched ``[C x D x H x W]`` tensor; the un-batched case is handled by
    temporarily adding a leading batch dimension and removing it again from
    the result.
    """

    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, input, kernel=2, stride=None):
        """Run the forward kernel.

        [Args]
            - ctx: autograd context; stores the input, kernel, stride and the
                   un-batched flag for ``backward``.
            - input: Tensor of rank 4 (``[C x D x H x W]``) or rank 5
                   (``[B x C x D x H x W]``).
            - kernel: Integer or Tuple, expanded with ``_triple``. Defaults to 2.
            - stride: Integer or Tuple, expanded with ``_triple``. If `None` it
                   equals ``kernel``. Defaults to `None`.
        [Returns]
            - Tensor of shape ``[B x C x oD x oH x oW]`` (or
              ``[C x oD x oH x oW]`` for un-batched input), each output size
              being ``(size - kernel) // stride + 1`` along its axis.
        """
        # An un-batched volumetric input is [C x D x H x W], i.e. rank 4.
        # (Testing for rank 3 would make the 5-way unpack below fail for
        # every un-batched input.)
        no_batch = input.dim() == 4
        if no_batch:
            # Out-of-place view: the caller's tensor keeps its original shape
            # (an in-place ``unsqueeze_`` would silently reshape it).
            input = input.unsqueeze(0)
        B, C, D, H, W = input.size()
        kernel = _triple(kernel)
        stride = kernel if stride is None else _triple(stride)
        oD = (D - kernel[0]) // stride[0] + 1
        oH = (H - kernel[1]) // stride[1] + 1
        oW = (W - kernel[2]) // stride[2] + 1
        # Make the tensor contiguous once so the forward kernel and the
        # backward kernel (via the saved tensor) see the same memory layout.
        input = input.contiguous()
        output = input.new_zeros((B, C, oD, oH, oW))
        # NOTE(review): the extension presumably fills `output` in place —
        # its return value is not used here.
        softpool_cuda.forward_3d(input, kernel, stride, output)
        ctx.save_for_backward(input)
        ctx.kernel = kernel
        ctx.stride = stride
        ctx.no_batch = no_batch
        if no_batch:
            return output.squeeze_(0)
        return output

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, grad_output):
        """Run the backward kernel and return the gradient w.r.t. ``input``.

        Returns one gradient per ``forward`` argument; ``kernel`` and
        ``stride`` are not differentiable, so their slots are `None`.
        """
        (saved_input,) = ctx.saved_tensors
        if ctx.no_batch:
            # Restore the batch dimension that forward added, so the gradient
            # has the same rank as the saved (batched) input.
            grad_output = grad_output.unsqueeze(0)
        grad_input = torch.zeros_like(saved_input)
        # NOTE(review): the extension presumably writes into `grad_input`.
        softpool_cuda.backward_3d(grad_output.contiguous(), saved_input, ctx.kernel, ctx.stride, grad_input)
        # Gradient underflow: any NaN entries are replaced by zero
        grad_input[torch.isnan(grad_input)] = 0
        if ctx.no_batch:
            # The gradient must match the un-batched shape the caller passed in.
            grad_input = grad_input.squeeze(0)
        return grad_input, None, None
'''
--- S T A R T O F F U N C T I O N S O F T _ P O O L 1 D ---
[About]
Function for downsampling based on the exponential proportion rate of pixels (soft pooling).
If the tensor is in CUDA the custom operation is used. Alternatively, the function uses
standard (mostly) in-place PyTorch operations for speed and reduced memory consumption.
It is also possible to use non-inplace operations in order to improve stability.
[Args]
- x: PyTorch Tensor, can be on either CPU or CUDA. If on CUDA, the custom CUDA extension is used.
- kernel_size: Integer or Tuple, for the kernel size to be used for downsampling. If an `Integer`
is used, a `Tuple` is created for the rest of the dimensions. Defaults to 2.
- stride: Integer or Tuple, for the steps taken between kernels (i.e. strides). If `None` the
strides become equal to the `kernel_size` tuple. Defaults to `None`.
- force_inplace: Bool, determines if in-place operations are to be used regardless of the CUDA
custom op. Mostly useful for time monitoring. Defaults to `False`.
[Returns]
- PyTorch Tensor, subsampled based on the specified `kernel_size` and `stride`
'''
def soft_pool1d(x, kernel_size=2, stride=None, force_inplace=False):
    """Soft-pool a 1D signal.

    [About]
        CUDA tensors are routed to `CUDA_SOFTPOOL1d` unless `force_inplace`
        is set; every other case is computed with standard PyTorch ops.
        In the PyTorch path each position is weighted by the sum of
        `exp(x)` over the channel dimension, and the output is the weighted
        average over each pooling window, clamped to be non-negative.
    [Args]
        - x: Tensor of shape [b x c x d] for the PyTorch path.
        - kernel_size: Integer or Tuple, pooling window. Defaults to 2.
        - stride: Integer or Tuple, step between windows. If `None` it
                  equals `kernel_size`. Defaults to `None`.
        - force_inplace: Bool, if true the PyTorch path is used even for
                  CUDA tensors. Defaults to `False`.
    [Returns]
        - Tensor of shape [b x c x d'].
    """
    use_custom_op = x.is_cuda and not force_inplace
    if use_custom_op:
        pooled = CUDA_SOFTPOOL1d.apply(x, kernel_size, stride)
        # `nan_to_num` (default arguments) is applied only if a NaN is present
        return torch.nan_to_num(pooled) if torch.isnan(pooled).any() else pooled

    window = _single(kernel_size)
    step = window if stride is None else _single(stride)
    # The three-way unpack doubles as a rank check: only [b x c x d] passes
    _batch, _channels, _length = x.size()

    # Exponential weights, summed over channels: [b x c x d] -> [b x 1 x d]
    weights = torch.exp(x).sum(dim=1, keepdim=True)
    weights = torch.clamp(weights, float(0), float('inf'))

    # The same factor scales numerator and denominator
    scale = sum(window)
    # Tensor: [b x c x d] -> [b x c x d']
    weighted_sum = F.avg_pool1d(x.mul(weights), window, stride=step).mul_(scale)
    weight_sum = F.avg_pool1d(weights, window, stride=step).mul_(scale)
    return torch.clamp(weighted_sum.div_(weight_sum), float(0), float('inf'))
'''
--- E N D O F F U N C T I O N S O F T _ P O O L 1 D ---
'''
'''
--- S T A R T O F F U N C T I O N S O F T _ P O O L 2 D ---
[About]
Function for downsampling based on the exponential proportion rate of pixels (soft pooling).
If the tensor is in CUDA the custom operation is used. Alternatively, the function uses
standard (mostly) in-place PyTorch operations for speed and reduced memory consumption.
It is also possible to use non-inplace operations in order to improve stability.
[Args]
- x: PyTorch Tensor, can be on either CPU or CUDA. If on CUDA, the custom CUDA extension is used.
- kernel_size: Integer or Tuple, for the kernel size to be used for downsampling. If an `Integer`
is used, a `Tuple` is created for the rest of the dimensions. Defaults to 2.
- stride: Integer or Tuple, for the steps taken between kernels (i.e. strides). If `None` the
strides become equal to the `kernel_size` tuple. Defaults to `None`.
- force_inplace: Bool, determines if in-place operations are to be used regardless of the CUDA
custom op. Mostly useful for time monitoring. Defaults to `False`.
[Returns]
- PyTorch Tensor, subsampled based on the specified `kernel_size` and `stride`
'''
def soft_pool2d(x, kernel_size=2, stride=None, force_inplace=False):
    """Soft-pool a 2D feature map.

    [About]
        CUDA tensors are routed to `CUDA_SOFTPOOL2d` unless `force_inplace`
        is set; every other case is computed with standard PyTorch ops.
        In the PyTorch path each position is weighted by the sum of
        `exp(x)` over the channel dimension, and the output is the weighted
        average over each pooling window, clamped to be non-negative.
    [Args]
        - x: Tensor of shape [b x c x h x w] for the PyTorch path.
        - kernel_size: Integer or Tuple, pooling window. Defaults to 2.
        - stride: Integer or Tuple, step between windows. If `None` it
                  equals `kernel_size`. Defaults to `None`.
        - force_inplace: Bool, if true the PyTorch path is used even for
                  CUDA tensors. Defaults to `False`.
    [Returns]
        - Tensor of shape [b x c x h' x w'].
    """
    use_custom_op = x.is_cuda and not force_inplace
    if use_custom_op:
        pooled = CUDA_SOFTPOOL2d.apply(x, kernel_size, stride)
        # `nan_to_num` (default arguments) is applied only if a NaN is present
        return torch.nan_to_num(pooled) if torch.isnan(pooled).any() else pooled

    window = _pair(kernel_size)
    step = window if stride is None else _pair(stride)
    # The four-way unpack doubles as a rank check: only [b x c x h x w] passes
    _batch, _channels, _height, _width = x.size()

    # Exponential weights, summed over channels: [b x c x h x w] -> [b x 1 x h x w]
    weights = torch.exp(x).sum(dim=1, keepdim=True)
    weights = torch.clamp(weights, float(0), float('inf'))

    # The same factor scales numerator and denominator
    scale = sum(window)
    # Tensor: [b x c x h x w] -> [b x c x h' x w']
    weighted_sum = F.avg_pool2d(x.mul(weights), window, stride=step).mul_(scale)
    weight_sum = F.avg_pool2d(weights, window, stride=step).mul_(scale)
    return torch.clamp(weighted_sum.div_(weight_sum), float(0), float('inf'))
'''
--- E N D O F F U N C T I O N S O F T _ P O O L 2 D ---
'''
'''
--- S T A R T O F F U N C T I O N S O F T _ P O O L 3 D ---
[About]
Function for downsampling based on the exponential proportion rate of pixels (soft pooling).
If the tensor is in CUDA the custom operation is used. Alternatively, the function uses
standard (mostly) in-place PyTorch operations for speed and reduced memory consumption.
It is also possible to use non-inplace operations in order to improve stability.
[Args]
- x: PyTorch Tensor, can be on either CPU or CUDA. If on CUDA, the custom CUDA extension is used.
- kernel_size: Integer or Tuple, for the kernel size to be used for downsampling. If an `Integer`
is used, a `Tuple` is created for the rest of the dimensions. Defaults to 2.
- stride: Integer or Tuple, for the steps taken between kernels (i.e. strides). If `None` the
strides become equal to the `kernel_size` tuple. Defaults to `None`.
- force_inplace: Bool, determines if in-place operations are to be used regardless of the CUDA
custom op. Mostly useful for time monitoring. Defaults to `False`.
[Returns]
- PyTorch Tensor, subsampled based on the specified `kernel_size` and `stride`
'''
def soft_pool3d(x, kernel_size=2, stride=None, force_inplace=False):
    """Soft-pool a 3D volume.

    [About]
        CUDA tensors are routed to `CUDA_SOFTPOOL3d` unless `force_inplace`
        is set; every other case is computed with standard PyTorch ops.
        In the PyTorch path each position is weighted by the sum of
        `exp(x)` over the channel dimension, and the output is the weighted
        average over each pooling window, clamped to be non-negative.
    [Args]
        - x: Tensor of shape [b x c x d x h x w] for the PyTorch path.
        - kernel_size: Integer or Tuple, pooling window. Defaults to 2.
        - stride: Integer or Tuple, step between windows. If `None` it
                  equals `kernel_size`. Defaults to `None`.
        - force_inplace: Bool, if true the PyTorch path is used even for
                  CUDA tensors. Defaults to `False`.
    [Returns]
        - Tensor of shape [b x c x d' x h' x w'].
    """
    use_custom_op = x.is_cuda and not force_inplace
    if use_custom_op:
        pooled = CUDA_SOFTPOOL3d.apply(x, kernel_size, stride)
        # `nan_to_num` (default arguments) is applied only if a NaN is present
        return torch.nan_to_num(pooled) if torch.isnan(pooled).any() else pooled

    window = _triple(kernel_size)
    step = window if stride is None else _triple(stride)
    # The five-way unpack doubles as a rank check: only [b x c x d x h x w] passes
    _batch, _channels, _depth, _height, _width = x.size()

    # Exponential weights, summed over channels:
    # [b x c x d x h x w] -> [b x 1 x d x h x w]
    weights = torch.exp(x).sum(dim=1, keepdim=True)
    weights = torch.clamp(weights, float(0), float('inf'))

    # The same factor scales numerator and denominator
    scale = sum(window)
    # Tensor: [b x c x d x h x w] -> [b x c x d' x h' x w']
    weighted_sum = F.avg_pool3d(x.mul(weights), window, stride=step).mul_(scale)
    weight_sum = F.avg_pool3d(weights, window, stride=step).mul_(scale)
    return torch.clamp(weighted_sum.div_(weight_sum), float(0), float('inf'))
'''
--- E N D O F F U N C T I O N S O F T _ P O O L 3 D ---
'''
class SoftPool1d(torch.nn.Module):
    """Module wrapper that applies `soft_pool1d` with stored settings."""

    def __init__(self, kernel_size=2, stride=None, force_inplace=False):
        """Store the pooling settings; they are forwarded unchanged on every call."""
        super().__init__()
        self.kernel_size, self.stride = kernel_size, stride
        self.force_inplace = force_inplace

    def forward(self, x):
        """Return `soft_pool1d(x, ...)` using the stored settings."""
        settings = dict(
            kernel_size=self.kernel_size,
            stride=self.stride,
            force_inplace=self.force_inplace,
        )
        return soft_pool1d(x, **settings)
class SoftPool2d(torch.nn.Module):
    """Module wrapper that applies `soft_pool2d` with stored settings."""

    def __init__(self, kernel_size=2, stride=None, force_inplace=False):
        """Store the pooling settings; they are forwarded unchanged on every call."""
        super().__init__()
        self.kernel_size, self.stride = kernel_size, stride
        self.force_inplace = force_inplace

    def forward(self, x):
        """Return `soft_pool2d(x, ...)` using the stored settings."""
        settings = dict(
            kernel_size=self.kernel_size,
            stride=self.stride,
            force_inplace=self.force_inplace,
        )
        return soft_pool2d(x, **settings)
class SoftPool3d(torch.nn.Module):
    """Module wrapper that applies `soft_pool3d` with stored settings."""

    def __init__(self, kernel_size=2, stride=None, force_inplace=False):
        """Store the pooling settings; they are forwarded unchanged on every call."""
        super().__init__()
        self.kernel_size, self.stride = kernel_size, stride
        self.force_inplace = force_inplace

    def forward(self, x):
        """Return `soft_pool3d(x, ...)` using the stored settings."""
        settings = dict(
            kernel_size=self.kernel_size,
            stride=self.stride,
            force_inplace=self.force_inplace,
        )
        return soft_pool3d(x, **settings)