-
Notifications
You must be signed in to change notification settings - Fork 612
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for torch STFT & spectrogram #1824
Changes from all commits
df3f16d
f4c3972
745d105
c7003a7
9f6701a
302576a
4fd2cbe
4f4d812
ad57378
5726c94
ade0182
a54ebbc
5ee10d8
577ebb4
433ef3a
42acf20
6c3f605
e28dd08
2dfbcbc
4199483
59efc44
02a739c
42070fb
70b1559
0b386ab
d69e2dc
b7ebb5a
58e4e72
66cb1ff
81a8c2e
841efb8
f621659
e39341d
c230767
8f93669
6d81961
2fa88bb
a74b1ee
b6d7ca4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,7 @@ | |
import numpy as np | ||
import pytest | ||
import torch.nn as nn | ||
import torchaudio | ||
import torchvision | ||
|
||
import coremltools as ct | ||
|
@@ -7877,6 +7878,26 @@ def forward(self, x): | |
(2, 3, 4), ComplexModel(), backend=backend, compute_unit=compute_unit | ||
) | ||
|
||
@pytest.mark.parametrize( | ||
"compute_unit, backend", | ||
itertools.product( | ||
compute_units, | ||
backends, | ||
) | ||
) | ||
def test_abs(self, compute_unit, backend): | ||
class AbsModel(torch.nn.Module): | ||
def forward(self, x): | ||
x = torch.complex(x, x) | ||
return torch.abs(x) | ||
|
||
TorchBaseTest.run_compare_torch( | ||
(1, 16), | ||
AbsModel(), | ||
backend=backend, | ||
compute_unit=compute_unit, | ||
) | ||
|
||
|
||
class TestReal(TorchBaseTest): | ||
@pytest.mark.parametrize( | ||
|
@@ -8099,6 +8120,94 @@ def forward(self, x): | |
(2, 3, 4), FftnModel(), backend=backend, compute_unit=compute_unit | ||
) | ||
|
||
class TestSTFT(TorchBaseTest): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Great! Very nice and comprehensive tests! |
||
@pytest.mark.slow | ||
@pytest.mark.parametrize( | ||
"compute_unit, backend, input_shape, complex, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided", | ||
itertools.product( | ||
compute_units, | ||
backends, | ||
[(1, 32), (32,), (3, 32)], # input shape | ||
[False, True], # complex | ||
[16], # n_fft | ||
[None, 4, 5], # hop_length | ||
[None, 16, 9], # win_length | ||
[None, torch.hann_window], # window | ||
[None, False, True], # center | ||
["constant", "reflect", "replicate"], # pad mode | ||
[False, True], # normalized | ||
[None, False, True], # onesided | ||
) | ||
) | ||
def test_stft(self, compute_unit, backend, input_shape, complex, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided): | ||
if complex and onesided: | ||
pytest.skip("Onesided stft not possible for complex inputs") | ||
|
||
class STFTModel(torch.nn.Module): | ||
def forward(self, x): | ||
applied_window = window(win_length) if window and win_length else None | ||
x = torch.complex(x, x) if complex else x | ||
x = torch.stft( | ||
x, | ||
n_fft=n_fft, | ||
hop_length=hop_length, | ||
win_length=win_length, | ||
window=applied_window, | ||
center=center, | ||
pad_mode=pad_mode, | ||
normalized=normalized, | ||
onesided=onesided, | ||
return_complex=True) | ||
x = torch.stack([torch.real(x), torch.imag(x)], dim=0) | ||
return x | ||
|
||
TorchBaseTest.run_compare_torch( | ||
input_shape, | ||
STFTModel(), | ||
backend=backend, | ||
compute_unit=compute_unit | ||
) | ||
|
||
class TestSpectrogram(TorchBaseTest): | ||
@pytest.mark.parametrize( | ||
"compute_unit, backend, input_shape, spec, power", | ||
itertools.product( | ||
compute_units, | ||
backends, | ||
[(1, 1000), (1000,), (3, 1000)], # input shape | ||
[torchaudio.transforms.Spectrogram, torchaudio.transforms.MelSpectrogram], | ||
[None, 1, 2] # magnitude or power | ||
) | ||
) | ||
def test_spectrogram(self, compute_unit, backend, input_shape, spec, power): | ||
if platform.machine() != "arm64": | ||
pytest.xfail("rdar://108001659 ([PyTorch] Torchaudio Spectrogram Failed on Intel Machine)") | ||
|
||
if spec is torchaudio.transforms.MelSpectrogram and power is None: | ||
pytest.skip("power or magnitude required for melspec") | ||
|
||
class SpectrogramModel(torch.nn.Module): | ||
def __init__(self) -> None: | ||
super().__init__() | ||
# the other spectrogram options are passed through to stft | ||
# and are tested in TestSTFT | ||
self.spec = spec(power=power, n_fft=128) | ||
|
||
def forward(self, x): | ||
x = self.spec(x) | ||
if power is None: | ||
# complex: stack them | ||
x = torch.stack([torch.real(x), torch.imag(x)], dim=0) | ||
return x | ||
|
||
TorchBaseTest.run_compare_torch( | ||
input_shape, | ||
SpectrogramModel(), | ||
backend=backend, | ||
compute_unit=compute_unit, | ||
rtol=1e-4, | ||
atol=1e-4, | ||
) | ||
|
||
class TestNms(TorchBaseTest): | ||
@pytest.mark.parametrize( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -285,7 +285,7 @@ def type_domain(self): | |
|
||
@type_domain.setter | ||
def type_domain(self, val): | ||
msg = "type_domain must be a tuple of builtin types" | ||
msg = f"type_domain {val} must be a tuple of builtin types" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is an unrelated change but made errors during the complex lowering passes a little easier to debug. |
||
if not isinstance(val, tuple) or any(map(lambda t: t not in _SUPPORT_TYPES, val)): | ||
raise ValueError(msg) | ||
self._type_domain = val | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -729,16 +729,135 @@ class complex_shape(Operation): | |
"T": (types.complex64,), | ||
} | ||
|
||
# If type_inference or value_inference is invoked when the graph is being constructed, | ||
# x.real and x.imag may not be set since the complex lowering pass hasn't yet been invoked. | ||
# self.x should already have the shape set, so use that instead. | ||
|
||
def type_inference(self): | ||
if not isinstance(self.x, ComplexVar): | ||
raise ValueError("x must be a ComplexVar.") | ||
input_rank = self.x.real.rank | ||
input_rank = self.x.rank | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If type_inference or value_inference is invoked when the graph is being constructed, x.real and x.imag may not be set since the complex lowering pass hasn't yet been invoked. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point! Could you add this as a comment here? |
||
return types.tensor(types.int32, tuple([input_rank])) | ||
|
||
def value_inference(self): | ||
if any_symbolic(self.x.real.shape): | ||
if any_symbolic(self.x.shape): | ||
# convert elements in shape to int32 | ||
res = [x if is_symbolic(x) else np.int32(x) for x in self.x.real.shape] | ||
res = [x if is_symbolic(x) else np.int32(x) for x in self.x.shape] | ||
return np.array(res) | ||
else: | ||
return np.array(self.x.real.shape).astype(np.int32) | ||
return np.array(self.x.shape).astype(np.int32) | ||
|
||
@register_op(namespace="complex") | ||
class complex_abs(Operation): | ||
""" | ||
Returns the absolute value of a complex tensor. | ||
|
||
Parameters | ||
---------- | ||
x: tensor<[*d], T> (Required) | ||
|
||
Returns | ||
------- | ||
tensor<[*d], fp32> | ||
* A float tensor with the same shape as ``x`` | ||
|
||
Attributes | ||
---------- | ||
T: complex64 | ||
""" | ||
|
||
input_spec = InputSpec(x=TensorInputType(type_domain="T")) | ||
|
||
type_domains = { | ||
"T": (types.complex64,), | ||
} | ||
|
||
def type_inference(self): | ||
if not isinstance(self.x, ComplexVar): | ||
raise ValueError("x must be a ComplexVar.") | ||
return types.tensor(infer_fp_dtype_from_complex(self.x.dtype), self.x.shape) | ||
|
||
@register_op(namespace="complex") | ||
class complex_stft(Operation): | ||
""" | ||
Dialect op for 1-D STFT. | ||
|
||
Parameters | ||
---------- | ||
input: tensor<\*D, T> (Required) | ||
* The input tensor. | ||
n_fft: const i32 (Required) | ||
* Size of the fourier transform. | ||
hop_length: const i32 (Optional) | ||
* Stride between window frames of the input tensor. | ||
win_length: const i32 (Optional) | ||
* The size of the window frame. | ||
window: tensor<1, win_length> (Optional) | ||
* The window to apply to the input signal before performing the fourier transform. | ||
normalized: const bool (Optional, Default=``false``) | ||
* Whether to normalize the results of the STFT. | ||
onesided: const bool (optional, Default=``true``) | ||
* For real-valued inputs, whether to return the first half of the results. | ||
|
||
Returns | ||
------- | ||
tensor<\*V, complex64> | ||
* A complex tensor where real and imag parts have the same shape. | ||
|
||
Attributes | ||
---------- | ||
T: fp32, complex64 | ||
|
||
References | ||
---------- | ||
See `torch.stft <https://pytorch.org/docs/stable/generated/torch.stft.html>`_. | ||
""" | ||
|
||
input_spec = InputSpec( | ||
input=TensorInputType(type_domain="T"), | ||
n_fft=TensorInputType(const=True, type_domain=types.int32), | ||
hop_length=TensorInputType(const=True, optional=True, type_domain=types.int32), | ||
win_length=TensorInputType(const=True, optional=True, type_domain=types.int32), | ||
window=TensorInputType(const=True, optional=True, type_domain=types.fp32), | ||
normalized=TensorInputType(const=True, optional=True, type_domain=types.bool), | ||
onesided=TensorInputType(const=True, optional=True, type_domain=types.bool), | ||
) | ||
|
||
type_domains = { | ||
"T": (types.fp32, types.complex64), | ||
} | ||
|
||
def default_inputs(self): | ||
return DefaultInputs( | ||
hop_length = None, | ||
win_length = None, | ||
window = None, | ||
normalized = False, | ||
onesided = True, | ||
) | ||
|
||
def type_inference(self): | ||
output_type = (types.complex64) | ||
|
||
# STFT shape is [B x N x T], where N is the number of frequency bins | ||
# and T is the number of windows | ||
# B is 1 for a time series or 2 for a batch of time series | ||
|
||
window_length = self.n_fft.val | ||
hop = self.hop_length.val if self.hop_length else self.n_fft.val // 4 | ||
|
||
# if onesided is true, the input is real valued | ||
# because of Hermitian symmetry, we only need to calculate the FFT | ||
# for the first half of the frequencies | ||
if self.onesided and self.onesided.val: | ||
window_length = window_length // 2 + 1 | ||
|
||
frames = (self.input.shape[-1] - self.n_fft.val) // hop + 1 | ||
output_shape = [window_length, frames] | ||
|
||
# add back rank if needed | ||
if self.input.rank == 2: | ||
output_shape = [self.input.shape[0]] + output_shape | ||
|
||
return types.tensor(output_type, tuple(output_shape)) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I opted not to create complex dialect ops for reshape and pad (below) because their behavior doesn't change as a result of the inputs being complex.
I'm more than happy to create a complex dialect op for these if that's the preferred approach, but figured that this might be a better route to avoid duplicating each built-in op as a complex dialect op.
If in the future, there's support for something like a lowering pass where all non-complex dialect ops with complex support in their type domain can be duplicated across the real and imaginary components of the input, this would probably be easier to get rid of and restore to just the code in the else block.