[pt2] Add reference implementations of torch.{stft,istft} (pytorch#106400)

This allows symbolic shapes to be traced through `torch.stft` and `torch.istft`.
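
For illustration (a minimal sketch, not part of this commit): with these refs in place, make_fx can trace a call to torch.stft under symbolic shapes. The function f and the sizes below are hypothetical.

import torch
from torch.fx.experimental.proxy_tensor import make_fx

# Illustrative sketch: trace stft without baking in concrete sizes.
def f(x):
    return torch.stft(x, n_fft=16, return_complex=True)

gm = make_fx(f, tracing_mode="symbolic")(torch.randn(128))
print(gm.graph)  # sizes appear as SymInts rather than concrete constants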

Pull Request resolved: pytorch#106400
Approved by: https://github.com/lezcano
ghstack dependencies: pytorch#106319
peterbell10 authored and Cyril-Anto committed Aug 17, 2023
1 parent 2054d7c commit beef4bb
Showing 4 changed files with 294 additions and 2 deletions.
2 changes: 2 additions & 0 deletions test/test_ops.py
@@ -1858,6 +1858,7 @@ class TestRefsOpsInfo(TestCase):
'_refs.isclose',
'_refs.isfinite',
'_refs.isreal',
'_refs.istft',
'_refs.log_softmax',
'_refs.movedim',
'_refs.narrow',
@@ -1875,6 +1876,7 @@
'_refs.special.log_softmax',
'_refs.special.softmax',
'_refs.square',
'_refs.stft',
'_refs.T',
'_refs.tensor_split',
'_refs.to',
3 changes: 1 addition & 2 deletions test/test_proxy_tensor.py
@@ -1480,7 +1480,6 @@ def f(t):
# data-dependent control flow
skip('item'),
xfail('cov'),
xfail('istft'),
xfail('nn.functional.gaussian_nll_loss'),
xfail('tensor_split'),
xfail('corrcoef'),
@@ -1560,7 +1559,6 @@ def f(t):
xfail('special.modified_bessel_k1', ''), # aten.special_modified_bessel_k1.default - couldn't find symbolic meta funct...
xfail('special.scaled_modified_bessel_k0', ''), # aten.special_scaled_modified_bessel_k0.default - couldn't find symbo...
xfail('special.scaled_modified_bessel_k1', ''), # aten.special_scaled_modified_bessel_k1.default - couldn't find symbo...
xfail('stft', ''), # argument 'size' must be tuple of ints, but found element of type torch._C.SymIntNode at...
xfail('take_along_dim', ''), # dtype of indices should be Long but got Float
xfail('triangular_solve', ''), # aten.triangular_solve.default - couldn't find symbolic meta function/decomposition
xfail('unique_consecutive', ''), # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition
@@ -1582,6 +1580,7 @@ def f(t):
xfail('fft.rfft2', ''),
xfail('fft.rfft', ''),
xfail('fft.rfftn', ''),
xfail('stft', '')
}
symbolic_tensor_segfaults = {
skip('nn.functional.batch_norm') # Segfault??
266 changes: 266 additions & 0 deletions torch/_refs/__init__.py
@@ -1,6 +1,7 @@
import builtins
import collections
import inspect
import itertools
import math
import operator
import warnings
@@ -330,6 +331,8 @@
# Misc
#
"renorm",
"stft",
"istft",
]

Tensor = torch.Tensor
@@ -3148,6 +3151,269 @@ def renorm(
return (input * norm_factor).contiguous()


# CompositeImplicitAutograd - don't register decomp
@aten.stft.center.py_impl(DispatchKey.CompositeImplicitAutograd)
def stft(
input: Tensor,
n_fft: int,
hop_length: Optional[int] = None,
win_length: Optional[int] = None,
window: Optional[Tensor] = None,
center: bool = True,
pad_mode: str = "reflect",
normalized: bool = False,
onesided: Optional[bool] = None,
return_complex: Optional[bool] = None,
) -> Tensor:
torch._check(
window is None or window.device == input.device,
lambda: (
f"stft input and window must be on the same device but got self on {input.device}"
+ f" and window on {window.device}" # type: ignore[union-attr]
),
)

hop_length_ = hop_length if hop_length is not None else n_fft // 4
win_length_ = win_length if win_length is not None else n_fft

if return_complex is None:
return_complex_ = input.is_complex() or (
window is not None and utils.is_complex_dtype(window.dtype)
)
torch._check(
return_complex_,
            lambda: (
"stft requires the return_complex parameter be given for real inputs, "
+ "and will further require that return_complex=True in a future PyTorch release."
),
)
else:
return_complex_ = return_complex

torch._check(
utils.is_float_dtype(input.dtype) or utils.is_complex_dtype(input.dtype),
lambda: "stft expected a tensor of floating point or complex values",
)
torch._check(1 <= input.ndim <= 2, lambda: "stft expected a 1D or 2D tensor")

original_ndim = input.ndim
if original_ndim == 1:
input = input.unsqueeze(0)

if center:
extra_dims = 3 - input.ndim
pad_amount = n_fft // 2
extended_shape = [*itertools.repeat(1, extra_dims), *input.shape]
input = aten.pad(input.view(extended_shape), [pad_amount, pad_amount], pad_mode)
input = input.view(input.size()[extra_dims:])

batch = input.size(0)
length = input.size(1)
torch._check(
0 < n_fft <= length,
lambda: f"stft expected 0 < n_fft <= {length}, but got n_fft={n_fft}",
)
torch._check(
hop_length_ > 0,
lambda: f"stft expected hop_length > 0 but got hop_length={hop_length_}",
)
torch._check(
0 < win_length_ <= n_fft,
lambda: f"stft expected 0 < win_length <= n_fft but got win_length={win_length_}",
)
torch._check(
window is None or window.shape == (win_length_,),
lambda: (
f"expected a 1D window tensor of size equal to win_length={win_length_}, "
+ f"but got window with size {window.shape}" # type: ignore[union-attr]
),
)

if win_length_ < n_fft:
if window is None:
window = torch.ones(win_length_, dtype=input.dtype, device=input.device)
left = (n_fft - win_length_) // 2
window = aten.constant_pad_nd(window, [left, n_fft - win_length_ - left])

    # Split the (padded) signal into overlapping frames of length n_fft,
    # stepping by hop_length_, giving shape (batch, n_frames, n_fft).
    input = input.unfold(dimension=-1, size=n_fft, step=hop_length_)
if window is not None:
input = input * window

complex_fft = utils.is_complex_dtype(input.dtype)
onesided = onesided if onesided is not None else not complex_fft
norm = "ortho" if normalized else None
if onesided:
torch._check(
not complex_fft,
lambda: "Cannot have onesided output if window or input is complex",
)
out = torch.fft.rfft(input, dim=-1, norm=norm)
else:
out = torch.fft.fft(input, dim=-1, norm=norm)

out.transpose_(1, 2)

if original_ndim == 1:
out = out.squeeze_(0)

return out if return_complex_ else torch.view_as_real(out)


# CompositeImplicitAutograd - don't register decomp
@aten.istft.default.py_impl(DispatchKey.CompositeImplicitAutograd)
def istft(
input: Tensor,
n_fft: int,
hop_length: Optional[int] = None,
win_length: Optional[int] = None,
window: Optional[Tensor] = None,
center: bool = True,
normalized: bool = False,
onesided: Optional[bool] = None,
length: Optional[int] = None,
return_complex=False,
) -> Tensor:
torch._check(
window is None or window.device == input.device,
lambda: (
f"istft input and window must be on the same device but got self on {input.device}"
+ f" and window on {window.device}" # type: ignore[union-attr]
),
)

hop_length_ = hop_length if hop_length is not None else n_fft // 4
win_length_ = win_length if win_length is not None else n_fft

torch._check(
utils.is_complex_dtype(input.dtype),
lambda: (
"istft input and window must be on the same device but got self on "
+ f"{input.device} and window on {window.device}" # type: ignore[union-attr]
),
)
n_frames = input.size(-1)
fft_size = input.size(-2)

expected_output_signal_len = n_fft + hop_length_ * (n_frames - 1)
torch._check(input.numel() > 0, lambda: "istft input tensor cannot be empty")
torch._check(
2 <= input.ndim <= 3,
lambda: f"istft expected a tensor with 2 or 3 dimensions, but got {input.ndim}",
)
onesided_ = onesided if onesided is not None else fft_size != n_fft

if onesided_:
torch._check(
n_fft // 2 + 1 == fft_size,
lambda: (
"istft expected the frequency dimension (3rd to the last) of the input tensor "
+ "to match n_fft / 2 + 1 when onesided=True, but got {fft_size}"
),
)
else:
torch._check(
n_fft == fft_size,
            lambda: (
                "istft expected the frequency dimension (3rd to the last) of the input tensor "
                + f"to match n_fft when onesided=False, but got {fft_size}"
            ),
)

torch._check(
0 < hop_length_ <= win_length_,
lambda: "istft expected 0 < hop_length <= win_length",
)
torch._check(
0 < win_length_ <= n_fft, lambda: "istft expected 0 < win_length <= n_fft"
)
torch._check(
window is None or window.shape == (win_length_,),
lambda: "Invalid window shape. window has to be 1D and length of `win_length`",
)

if window is None:
real_dtype = utils.corresponding_real_dtype(input.dtype)
window_ = torch.ones(win_length_, dtype=real_dtype, device=input.device)
else:
window_ = window

if win_length_ != n_fft:
left = (n_fft - win_length_) // 2
window_ = aten.constant_pad_nd(window_, (left, n_fft - win_length_ - left), 0)

original_ndim = input.ndim
if input.ndim == 2:
input = input.unsqueeze(0)

input = input.transpose(1, 2)
norm = "ortho" if normalized else None
if return_complex:
torch._check(
not onesided_,
lambda: "cannot have onesided output if window or input is complex",
)
input = torch.fft.ifft(input, dim=-1, norm=norm)
else:
torch._check(
window is None or not utils.is_complex_dtype(window.dtype),
lambda: "Complex windows are incompatible with return_complex=False",
)
if not onesided_:
input = input.narrow(dim=-1, start=0, length=n_fft // 2 + 1)
input = torch.fft.irfft(input, dim=-1, norm=norm)

assert input.size(2) == n_fft

y_tmp = input * window_.view([1, 1, n_fft])
    # Overlap-add: unfold_backward is the adjoint of unfold, scattering the
    # windowed frames back onto the time axis and summing where they overlap.
    y = aten.unfold_backward(
y_tmp,
input_sizes=(y_tmp.size(0), expected_output_signal_len),
dim=1,
size=n_fft,
step=hop_length_,
)
window_envelop = aten.unfold_backward(
window_.pow(2).expand((1, n_frames, n_fft)),
input_sizes=(y_tmp.size(0), expected_output_signal_len),
dim=1,
size=n_fft,
step=hop_length_,
)

assert expected_output_signal_len == y.size(1)
assert expected_output_signal_len == window_envelop.size(1)

start = n_fft // 2 if center else 0
if length is not None:
end = start + length
elif center:
end = expected_output_signal_len - n_fft // 2
else:
end = expected_output_signal_len

length = max(0, end - start)
y = y.narrow(dim=1, start=start, length=length)
window_envelop = window_envelop.narrow(dim=1, start=start, length=length)

    # Nonzero overlap-add (NOLA) guard: the summed squared-window envelope
    # must stay above zero or the normalization below would divide by ~0.
    window_envelop_lowest = window_envelop.abs().min().lt(1e-11)
torch._check(
not window_envelop_lowest.item(),
lambda: "window overlap add min less than 1e-11",
)

y = y / window_envelop
if original_ndim == 2:
y = y.squeeze(0)

if end > expected_output_signal_len:
warnings.warn(
"The length of signal is shorter than the length parameter. Result is being "
+ "padded with zeros in the tail. Please check your center and hop_length settings"
)
y = aten.constant_pad_nd(y, (0, end - expected_output_signal_len), 0)
return y


# Get the new shape and stride after applying unfold to an input tensor
def _get_unfold_shape_stride(
a_shape: ShapeType, a_stride: StrideType, dimension: int, size: int, step: int
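
For reference (not part of the diff): the istft code above implements the standard overlap-add reconstruction. With frames x_m, analysis window w, and hop length h,

    y[t] = ( sum_m x_m[t - m*h] * w[t - m*h] ) / ( sum_m w[t - m*h]^2 )

which is exactly the final y / window_envelop division; it is well-defined whenever the squared-window envelope in the denominator stays away from zero, which is what the NOLA check enforces.
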
25 changes: 25 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -20779,6 +20779,31 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
torch_opinfo_name="allclose",
supports_nvfuser=False,
),
#
# Misc functions
#
PythonRefInfo(
"_refs.stft",
torch_opinfo_name="stft",
supports_nvfuser=False,
skips=[
# RuntimeError: no _refs support for aten.pad
DecorateInfo(
unittest.expectedFailure, 'TestCommon', 'test_python_ref'
),
],
),
PythonRefInfo(
"_refs.istft",
torch_opinfo_name="istft",
supports_nvfuser=False,
skips=[
# RuntimeError: no _refs support for aten.unfold_backward
DecorateInfo(
unittest.expectedFailure, 'TestCommon', 'test_python_ref'
),
],
),
]
python_ref_db += opinfo.definitions.python_ref_db

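
As a quick sanity check (a sketch, not part of this commit; tolerances are illustrative), the new refs can be compared against eager, and stft/istft should round-trip under default settings:

import torch
import torch._refs as refs

x = torch.randn(1024)
n_fft = 64

spec = refs.stft(x, n_fft, return_complex=True)
torch.testing.assert_close(spec, torch.stft(x, n_fft, return_complex=True))

# With center=True (the default) and length given, istft inverts stft.
y = refs.istft(spec, n_fft, length=x.numel())
torch.testing.assert_close(y, x, rtol=1e-4, atol=1e-5)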
