-
Notifications
You must be signed in to change notification settings - Fork 170
/
fft.py
559 lines (448 loc) · 17.7 KB
/
fft.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple
import torch
import torch.fft
import torch.onnx
from torch import Tensor
from torch.autograd import Function
# Note 1: for DFT operators, the less verbose way of registering an operator is via
# `register_custom_op_symbolic`. However, it does not currently work due to
# torch.fft.rfft* functions returning Complex type which is not yet supported in ONNX.
# Note 2:
# - current ONNX Contrib implementation does not support configurable normalization, so
# "normalized" must be 0, the normalization is done outside of Contrib ops.
# See also comments in `_scale_output_backward` function for more details.
# - "onesided" is not configurable either - must be set to 1.
# - Contrib implementation requires DFT dimensions to be the last ones,
# otherwise axes permutation is required.
# See:
# https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/contrib_ops/cuda/math/fft_ops.h#L19
def rfft(
    input: Tensor,
    n: Optional[int] = None,
    dim: int = -1,
    norm: Optional[str] = None,
) -> Tensor:
    """ONNX-compatible method to compute the 1d Fourier transform of real-valued input.

    Parameters
    ----------
    input : Tensor
        Real input tensor
    n : Optional[int], optional
        Signal length along ``dim``, by default None (use the input length).
        During ONNX export padding/trimming is not supported, so ``n`` must be
        None, negative, or equal to the input size along ``dim``.
    dim : int, optional
        Dimension along which to take the real FFT, by default -1
    norm : Optional[str], optional
        Normalization mode with options "forward", "backward" and "ortho". When set
        to None, normalization will default to backward (no normalization), by
        default None

    Returns
    -------
    Tensor
        Complex one-sided spectrum in eager mode. During ONNX export (no complex
        dtype) a real tensor whose extra trailing dimension of size 2 holds the
        real / imaginary parts.

    Raises
    ------
    TypeError
        If ``dim`` is not an int during ONNX export.

    Note
    ----
    The function is equivalent to `torch.fft.rfft` when not running in ONNX export mode
    """
    if not torch.onnx.is_in_onnx_export():
        return torch.fft.rfft(input, n=n, dim=dim, norm=norm)
    # The export path shares code with rfft2, which works on tuples of dims.
    if not isinstance(dim, int):
        raise TypeError(f"dim must be an int, got {type(dim).__name__}")
    return _rfft_onnx(input, (n,), (dim,), norm)
def rfft2(
    input: Tensor,
    s: Optional[Tuple[int]] = None,
    dim: Tuple[int] = (-2, -1),
    norm: Optional[str] = None,
) -> Tensor:
    """ONNX-compatible method to compute the 2d Fourier transform of real-valued input.

    Parameters
    ----------
    input : Tensor
        Real input tensor
    s : Optional[Tuple[int]], optional
        Signal size in the transformed dimensions, by default None. During ONNX
        export padding/trimming is not supported, so each entry must be None,
        negative, or equal to the input size along the matching dimension.
    dim : Tuple[int], optional
        Dimensions along which to take the real 2D FFT, by default (-2, -1).
        A list of two ints is also accepted.
    norm : Optional[str], optional
        Normalization mode with options "forward", "backward" and "ortho". When set
        to None, normalization will default to backward (no normalization), by
        default None

    Returns
    -------
    Tensor
        Complex one-sided spectrum in eager mode. During ONNX export a real tensor
        whose extra trailing dimension of size 2 holds the real / imaginary parts.

    Raises
    ------
    ValueError
        If ``dim`` is not a pair of dimensions during ONNX export.

    Note
    ----
    The function is equivalent to `torch.fft.rfft2` when not running in ONNX export mode
    """
    if not torch.onnx.is_in_onnx_export():
        return torch.fft.rfft2(input, s=s, dim=dim, norm=norm)
    # The eager torch API accepts any int sequence for ``dim``; accept lists here too.
    if isinstance(dim, list):
        dim = tuple(dim)
    if not (isinstance(dim, tuple) and len(dim) == 2):
        raise ValueError(f"dim must be a tuple of two ints, got {dim!r}")
    return _rfft_onnx(input, s, dim, norm)
def irfft(
    input: Tensor,
    n: Optional[int] = None,
    dim: int = -1,
    norm: Optional[str] = None,
) -> Tensor:
    """ONNX-compatible method to compute the inverse of `rfft`.

    Parameters
    ----------
    input : Tensor
        Complex input tensor holding a one-sided spectrum. During ONNX export,
        where there is no complex dtype, a real tensor whose last dimension of
        size 2 holds the real / imaginary parts (what `view_as_complex` in this
        module passes through unchanged during export).
    n : Optional[int], optional
        Output signal length along ``dim``, by default None (``2 * (m - 1)``,
        where ``m`` is the input length along ``dim``). Padding/trimming is not
        supported during ONNX export.
    dim : int, optional
        Dimension along which to take the real IFFT, by default -1
    norm : Optional[str], optional
        Normalization mode with options "forward", "backward" and "ortho". When set
        to None, normalization will default to backward (normalize by 1/n), by
        default None

    Returns
    -------
    Tensor
        Real-valued signal.

    Raises
    ------
    TypeError
        If ``dim`` is not an int during ONNX export.

    Note
    ----
    The function is equivalent to `torch.fft.irfft` when not running in ONNX export mode
    """
    if not torch.onnx.is_in_onnx_export():
        return torch.fft.irfft(input, n=n, dim=dim, norm=norm)
    # The export path shares code with irfft2, which works on tuples of dims.
    if not isinstance(dim, int):
        raise TypeError(f"dim must be an int, got {type(dim).__name__}")
    return _irfft_onnx(input, (n,), (dim,), norm)
def irfft2(
    input: Tensor,
    s: Optional[Tuple[int]] = None,
    dim: Tuple[int] = (-2, -1),
    norm: Optional[str] = None,
) -> Tensor:
    """ONNX-compatible method to compute the inverse of `rfft2`.

    Parameters
    ----------
    input : Tensor
        Complex input tensor holding a one-sided 2d spectrum. During ONNX export,
        where there is no complex dtype, a real tensor whose last dimension of
        size 2 holds the real / imaginary parts.
    s : Optional[Tuple[int]], optional
        Signal size in the transformed dimensions, by default None.
        Padding/trimming is not supported during ONNX export.
    dim : Tuple[int], optional
        Dimensions along which to take the real 2D IFFT, by default (-2, -1).
        A list of two ints is also accepted.
    norm : Optional[str], optional
        Normalization mode with options "forward", "backward" and "ortho". When set
        to None, normalization will default to backward (normalize by 1/n), by
        default None

    Returns
    -------
    Tensor
        Real-valued signal.

    Raises
    ------
    ValueError
        If ``dim`` is not a pair of dimensions during ONNX export.

    Note
    ----
    The function is equivalent to `torch.fft.irfft2` when not running in ONNX export mode
    """
    if not torch.onnx.is_in_onnx_export():
        return torch.fft.irfft2(input, s=s, dim=dim, norm=norm)
    # The eager torch API accepts any int sequence for ``dim``; accept lists here too.
    if isinstance(dim, list):
        dim = tuple(dim)
    if not (isinstance(dim, tuple) and len(dim) == 2):
        raise ValueError(f"dim must be a tuple of two ints, got {dim!r}")
    return _irfft_onnx(input, s, dim, norm)
def view_as_complex(input: Tensor) -> Tensor:
    """ONNX-compatible method to view input as complex tensor

    Parameters
    ----------
    input : Tensor
        The input Tensor, whose last dimension (size 2) holds the real /
        imaginary parts

    Returns
    -------
    Tensor
        Complex view of ``input`` in eager mode. During ONNX export ``input`` is
        returned unchanged, since ONNX export has no complex type.

    Raises
    ------
    ValueError
        If input tensor shape is not [...,2] during ONNX export, where the last
        dimension denotes the real / imaginary tensors

    Note
    ----
    The function is equivalent to `torch.view_as_complex` when not running in ONNX export mode
    """
    if not torch.onnx.is_in_onnx_export():
        return torch.view_as_complex(input)
    # Just return the input unchanged - during ONNX export
    # there will be no complex type.
    if input.size(-1) != 2:
        raise ValueError(
            "Expected last dimension of size 2 (real, imag), "
            f"got {input.size(-1)}"
        )
    return input
def real(input: Tensor) -> Tensor:
    """ONNX-compatible method to take the real part of a complex tensor

    Parameters
    ----------
    input : Tensor
        The input Tensor: complex in eager mode; during ONNX export a real tensor
        whose last dimension (size 2) holds the real / imaginary parts

    Returns
    -------
    Tensor
        The real part: ``input.real`` in eager mode, ``input[..., 0]`` during
        ONNX export

    Raises
    ------
    ValueError
        If input tensor shape is not [...,2] during ONNX export, where the last
        dimension denotes the real / imaginary tensors

    Note
    ----
    The function is equivalent to `input.real` when not running in ONNX export mode
    """
    if not torch.onnx.is_in_onnx_export():
        return input.real
    # There is no complex type during ONNX export, so assuming
    # complex numbers are represented as if after `view_as_real`.
    if input.size(-1) != 2:
        raise ValueError(
            "Expected last dimension of size 2 (real, imag), "
            f"got {input.size(-1)}"
        )
    return input[..., 0]
def imag(input: Tensor) -> Tensor:
    """ONNX-compatible method to take the imaginary part of a complex tensor

    Parameters
    ----------
    input : Tensor
        The input Tensor: complex in eager mode; during ONNX export a real tensor
        whose last dimension (size 2) holds the real / imaginary parts

    Returns
    -------
    Tensor
        The imaginary part: ``input.imag`` in eager mode, ``input[..., 1]``
        during ONNX export

    Raises
    ------
    ValueError
        If input tensor shape is not [...,2] during ONNX export, where the last
        dimension denotes the real / imaginary tensors

    Note
    ----
    The function is equivalent to `input.imag` when not running in ONNX export mode
    """
    if not torch.onnx.is_in_onnx_export():
        return input.imag
    # There is no complex type during ONNX export, so assuming
    # complex numbers are represented as if after `view_as_real`.
    if input.size(-1) != 2:
        raise ValueError(
            "Expected last dimension of size 2 (real, imag), "
            f"got {input.size(-1)}"
        )
    return input[..., 1]
def _rfft_onnx(
    input: Tensor, s: Optional[Tuple[Optional[int]]], dim: Tuple[int], norm: str
) -> Tensor:
    """Export-time RFFT built on the ONNX Runtime Contrib ``Rfft`` operator.

    Validates the requested signal sizes, transposes the input so the DFT
    dimensions are innermost (the Contrib op only transforms the last dims),
    runs the 1d or 2d op, applies the requested normalization and transposes
    the result back to the caller's axis order.
    """
    if s is not None:
        _check_padding_rfft(s, dim, input.size())

    signal_ndim = len(dim)
    if signal_ndim not in [1, 2]:
        raise ValueError(signal_ndim)

    needs_transpose = not _is_last_dims(dim, input.ndim)
    if needs_transpose:
        to_inner, from_inner = _create_axes_perm(input.ndim, dim)
        # The output gains a trailing real/imag axis which must stay last.
        from_inner.append(len(from_inner))
        input = input.permute(to_inner)

    op = OnnxRfft2 if signal_ndim == 2 else OnnxRfft
    output = op.apply(input)
    # Sizes of the (possibly transposed) input, so the DFT dims are the last ones.
    output = _scale_output_forward(output, norm, input.size(), signal_ndim)

    if needs_transpose:
        output = output.permute(from_inner)
    return output
def _irfft_onnx(
    input: Tensor, s: Optional[Tuple[Optional[int]]], dim: Tuple[int], norm: str
) -> Tensor:
    """Export-time IRFFT built on the ONNX Runtime Contrib ``Irfft`` operator.

    During export the complex input is represented as a real tensor whose last
    dimension (size 2) holds the real / imaginary parts. ``dim`` (and ``s``)
    therefore refer to the complex tensor, i.e. every dimension except that
    trailing one.

    Parameters
    ----------
    input : Tensor
        Real view of the complex spectrum, shape ``[..., 2]``.
    s : Optional[Tuple[Optional[int]]]
        Requested output sizes per DFT dim (padding/trimming is unsupported).
    dim : Tuple[int]
        One or two DFT dimensions, indexed in the complex tensor's frame.
    norm : str
        "forward", "backward", "ortho" or None.

    Raises
    ------
    ValueError
        If ``len(dim)`` is not 1 or 2, or ``s`` and ``dim`` lengths differ.
    RuntimeError
        If ``s`` would require padding/trimming.
    """
    # Sizes/rank of the complex tensor: drop the trailing real/imag dimension so
    # that negative entries of ``dim`` resolve to the intended DFT dims instead
    # of landing on the size-2 real/imag axis.
    complex_sizes = input.size()[:-1]
    complex_ndim = input.ndim - 1
    if s is not None:
        _check_padding_irfft(s, dim, complex_sizes)
    ndim = len(dim)
    if ndim not in [1, 2]:
        raise ValueError(ndim)
    # Whether to permute axes when DFT axis is not the last.
    perm = not _is_last_dims(dim, complex_ndim)
    if perm:
        # Do not include last dimension (input is complex).
        perm_in, perm_out = _create_axes_perm(complex_ndim, dim)
        # Keep the trailing real/imag dimension last.
        perm_in.append(len(perm_in))
        # Transpose -> IRFFT -> Transpose (inverse).
        input = input.permute(perm_in)
    irfft_func = OnnxIrfft if ndim == 1 else OnnxIrfft2
    output = irfft_func.apply(input)
    # These sizes still include the trailing real/imag dimension, which is the
    # layout the backward scaling reads (DFT dims just before it).
    output = _scale_output_backward(output, norm, input.size(), ndim)
    if perm:
        output = output.permute(perm_out)
    return output
def _contrib_rfft(g: torch.Graph, input: torch.Value, ndim: int) -> torch.Value:
if ndim not in [1, 2]:
raise ValueError(ndim)
# See https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Rfft
output = g.op(
"com.microsoft::Rfft",
input,
normalized_i=0,
onesided_i=1,
signal_ndim_i=ndim,
)
return output
def _contrib_irfft(g: torch.Graph, input: torch.Value, ndim: int) -> torch.Value:
if ndim not in [1, 2]:
raise ValueError(ndim)
# See https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Irfft
output = g.op(
"com.microsoft::Irfft",
input,
normalized_i=0,
onesided_i=1,
signal_ndim_i=ndim,
)
return output
def _is_last_dims(dim: Tuple[int], inp_ndim: int) -> bool:
ndim = len(dim)
for i, idim in enumerate(dim):
# This takes care of both positive and negative axis indices.
if idim % inp_ndim != inp_ndim - ndim + i:
return False
return True
def _check_padding_rfft(
sizes: Tuple[Optional[int]], dim: Tuple[int], inp_sizes: Tuple[int]
) -> None:
if len(sizes) != len(dim):
raise ValueError(f"{sizes}, {dim}")
for i, s in enumerate(sizes):
if s is None or s < 0:
continue
# Current Contrib RFFT does not support pad/trim yet.
if s != inp_sizes[dim[i]]:
raise RuntimeError(
f"Padding/trimming is not yet supported, "
f"got sizes {sizes}, DFT dims {dim}, "
f"input dims {inp_sizes}."
)
def _check_padding_irfft(
sizes: Tuple[Optional[int]], dim: Tuple[int], inp_sizes: Tuple[int]
) -> None:
if len(sizes) != len(dim):
raise ValueError(f"{sizes}, {dim}")
# All but last dims must be equal to input dims.
for i, s in enumerate(sizes[:-1]):
if s is None or s < 0:
continue
# Current Contrib RFFT does not support pad/trim yet.
if s != inp_sizes[dim[i]]:
raise RuntimeError(
f"Padding/trimming is not yet supported, "
f"got sizes {sizes}, DFT dims {dim}, "
f"input dims {inp_sizes}."
)
# Check last dim.
s = sizes[-1]
if s is not None and s > 0:
expected_size = 2 * (inp_sizes[dim[-1]] - 1)
if s != expected_size:
raise RuntimeError(
f"Padding/trimming is not yet supported, got sizes {sizes}"
f", DFT dims {dim}, input dims {inp_sizes}"
f", expected last size {expected_size}."
)
def _create_axes_perm(ndim: int, dims: Tuple[int]) -> Tuple[List[int], List[int]]:
"""Creates permuted axes indices for RFFT/IRFFT operators."""
perm_in = list(range(ndim))
perm_out = list(perm_in)
# Move indices to the right to make 'dims' as innermost dimensions.
for i in range(-1, -(len(dims) + 1), -1):
perm_in[dims[i]], perm_in[i] = perm_in[i], perm_in[dims[i]]
# Move indices to the left to restore original shape.
for i in range(-len(dims), 0):
perm_out[dims[i]], perm_out[i] = perm_out[i], perm_out[dims[i]]
return perm_in, perm_out
def _scale_output_forward(
output: Tensor, norm: str, sizes: torch.Size, ndim: int
) -> Tensor:
"""Scales the RFFT output according to norm parameter."""
norm = "backward" if norm is None else norm
if norm not in ["forward", "backward", "ortho"]:
raise ValueError(norm)
# No normalization for "backward" in RFFT ops.
if norm in ["forward", "ortho"]:
# Assuming DFT dimensions are the last. This is required by the current Contrib ops,
# so the axes permutation of the input is done accordingly.
dft_size = math.prod(sizes[-ndim:]).float()
denom = torch.sqrt(dft_size) if norm == "ortho" else dft_size
output = output / denom
return output
def _scale_output_backward(
output: Tensor, norm: str, sizes: torch.Size, ndim: int
) -> Tensor:
"""Scales the IRFFT output according to norm parameter."""
norm = "backward" if norm is None else norm
if norm not in ["forward", "backward", "ortho"]:
raise ValueError(norm)
# Things get interesting here: Contrib IRFFT op uses cuFFT cufftXtExec
# followed by a custom CUDA kernel (`_Normalize`) which always performs
# normalization (division by N) which means "norm" is essentially
# always "backward" here. So we need to cancel this normalization
# when norm is "forward" or "ortho".
if norm in ["forward", "ortho"]:
# Last dimension is complex numbers representation.
# Second-to-last dim corresponds to last dim in RFFT transform.
# This is required by the current Contrib ops,
# so the axes permutation of the input is done previously.
if not len(sizes) >= ndim + 1:
raise ValueError
dft_size = math.prod(sizes[-(ndim + 1) : -2])
dft_size *= 2 * (sizes[-2] - 1)
dft_size = dft_size.float()
# Since cuFFT scales by 1/dft_size, replace this scale with appropriate one.
scale = dft_size if norm == "forward" else torch.sqrt(dft_size)
output = scale * output
return output
class OnnxRfft(Function):
    """Auto-grad function to mimic rfft for ONNX exporting

    ``forward`` reproduces the Contrib ``Rfft`` op with ``torch.fft.rfft`` over
    the last dim, without normalization, and returns the real view (extra
    trailing dim of size 2). ``symbolic`` emits the Contrib node with
    ``signal_ndim=1``.

    Note
    ----
    Should only be called during an ONNX export
    """

    @staticmethod
    def forward(ctx, input: Tensor) -> Tensor:
        if torch.onnx.is_in_onnx_export():
            spectrum = torch.fft.rfft(input, dim=-1, norm="backward")
            return torch.view_as_real(spectrum)
        raise ValueError("Must be called only during ONNX export.")

    @staticmethod
    def symbolic(g: torch.Graph, input: torch.Value) -> torch.Value:
        """Symbolic representation for onnx graph"""
        return _contrib_rfft(g, input, ndim=1)
class OnnxRfft2(Function):
    """Auto-grad function to mimic rfft2 for ONNX exporting

    ``forward`` reproduces the Contrib ``Rfft`` op with ``torch.fft.rfft2`` over
    the last two dims, without normalization, and returns the real view (extra
    trailing dim of size 2). ``symbolic`` emits the Contrib node with
    ``signal_ndim=2``.

    Note
    ----
    Should only be called during an ONNX export
    """

    @staticmethod
    def forward(ctx, input: Tensor) -> Tensor:
        if torch.onnx.is_in_onnx_export():
            spectrum = torch.fft.rfft2(input, dim=(-2, -1), norm="backward")
            return torch.view_as_real(spectrum)
        raise AssertionError("Must be called only during ONNX export.")

    @staticmethod
    def symbolic(g: torch.Graph, input: torch.Value) -> torch.Value:
        """Symbolic representation for onnx graph"""
        return _contrib_rfft(g, input, ndim=2)
class OnnxIrfft(Function):
    """Auto-grad function to mimic irfft for ONNX exporting

    ``forward`` takes the real view of a complex spectrum (trailing dim of size
    2) and reproduces the Contrib ``Irfft`` op with ``torch.fft.irfft`` over the
    last complex dim using 1/n ("backward") normalization. ``symbolic`` emits
    the Contrib node with ``signal_ndim=1``.

    Note
    ----
    Should only be called during an ONNX export
    """

    @staticmethod
    def forward(ctx, input: Tensor) -> Tensor:
        if torch.onnx.is_in_onnx_export():
            spectrum = torch.view_as_complex(input)
            return torch.fft.irfft(spectrum, dim=-1, norm="backward")
        raise ValueError("Must be called only during ONNX export.")

    @staticmethod
    def symbolic(g: torch.Graph, input: torch.Value) -> torch.Value:
        """Symbolic representation for onnx graph"""
        return _contrib_irfft(g, input, ndim=1)
class OnnxIrfft2(Function):
    """Auto-grad function to mimic irfft2 for ONNX exporting.

    ``forward`` takes the real view of a complex spectrum (trailing dim of size
    2) and reproduces the Contrib ``Irfft`` op with ``torch.fft.irfft2`` over the
    last two complex dims using 1/n ("backward") normalization. ``symbolic``
    emits the Contrib node with ``signal_ndim=2``.

    Note
    ----
    Should only be called during an ONNX export
    """

    @staticmethod
    def forward(ctx, input: Tensor) -> Tensor:
        if torch.onnx.is_in_onnx_export():
            spectrum = torch.view_as_complex(input)
            return torch.fft.irfft2(spectrum, dim=(-2, -1), norm="backward")
        raise AssertionError("Must be called only during ONNX export.")

    @staticmethod
    def symbolic(g: torch.Graph, input: torch.Value) -> torch.Value:
        """Symbolic representation for onnx graph"""
        return _contrib_irfft(g, input, ndim=2)