Add support for torch STFT & spectrogram #1824

Merged Apr 13, 2023 (39 commits)

Commits
df3f16d refactor out dft matrix creation (nikalra, Apr 5, 2023)
f4c3972 add real stft impl (not tested) (nikalra, Apr 6, 2023)
745d105 shim dialect op + basic conversion (nikalra, Apr 6, 2023)
c7003a7 better error msg (nikalra, Apr 6, 2023)
9f6701a fix shape inference (nikalra, Apr 6, 2023)
302576a other small changes (nikalra, Apr 6, 2023)
4fd2cbe torch op def and basic test (nikalra, Apr 6, 2023)
4f4d812 added test for dft matrix calculation (nikalra, Apr 6, 2023)
ad57378 update _calulate_dft_matrix for onesided flag (nikalra, Apr 7, 2023)
5726c94 basic tests pass (nikalra, Apr 7, 2023)
ade0182 support hop_length (nikalra, Apr 7, 2023)
a54ebbc add support for window length (nikalra, Apr 7, 2023)
5ee10d8 add support for custom windows (nikalra, Apr 7, 2023)
577ebb4 fix 1D support (nikalra, Apr 7, 2023)
433ef3a tests for centering and pad mode (nikalra, Apr 7, 2023)
42acf20 add normalization (nikalra, Apr 7, 2023)
6c3f605 refactor out support for real and imag (nikalra, Apr 7, 2023)
e28dd08 add the onesided flag (nikalra, Apr 7, 2023)
2dfbcbc basic complex support (nikalra, Apr 7, 2023)
4199483 add reshape and pad complex support for stft windowing (nikalra, Apr 7, 2023)
59efc44 full test suite -- everything works! (nikalra, Apr 7, 2023)
02a739c add batched test (nikalra, Apr 7, 2023)
42070fb cleanup + input checking (nikalra, Apr 7, 2023)
70b1559 fix shape for dependent ops (nikalra, Apr 8, 2023)
0b386ab fix reshape and pad for dependent ops (nikalra, Apr 8, 2023)
d69e2dc add complex abs (nikalra, Apr 8, 2023)
b7ebb5a add spectrogram test (basic) (nikalra, Apr 8, 2023)
58e4e72 add magnitude and power spec (nikalra, Apr 8, 2023)
66cb1ff add melspec tests (nikalra, Apr 8, 2023)
81a8c2e fix complex + onesided (nikalra, Apr 8, 2023)
841efb8 add test for abs (nikalra, Apr 8, 2023)
f621659 add docstring for complex_abs (nikalra, Apr 8, 2023)
e39341d stft documentation (nikalra, Apr 8, 2023)
c230767 remove commented out code (nikalra, Apr 8, 2023)
8f93669 mark stft test as slow (nikalra, Apr 8, 2023)
6d81961 fixed some comments (nikalra, Apr 10, 2023)
2fa88bb move complex tests (nikalra, Apr 12, 2023)
a74b1ee added comments to complex dialect ops (nikalra, Apr 12, 2023)
b6d7ca4 Mark test_spectrogram as xfail on Intel Machine (junpeiz, Apr 13, 2023)

coremltools/converters/mil/frontend/torch/ops.py (34 additions, 4 deletions)

@@ -1538,7 +1538,13 @@ def view(context, node):
shape = mb.concat(values=shape, axis=0)

shape = mb.cast(x=shape, dtype="int32")
view = mb.reshape(x=x, shape=shape, name=node.name)

if types.is_complex(x.dtype):
real, imag = (mb.reshape(x=x, shape=shape, name=node.name) for x in (mb.complex_real(data=x), mb.complex_imag(data=x)))

Contributor Author:

I opted not to create complex dialect ops for reshape and pad (below) because their behavior doesn't change as a result of the inputs being complex.

I'm more than happy to create a complex dialect op for these if that's the preferred approach, but figured that this might be a better route to avoid duplicating each built-in op as a complex dialect op.

If in the future, there's support for something like a lowering pass where all non-complex dialect ops with complex support in their type domain can be duplicated across the real and imaginary components of the input, this would probably be easier to get rid of and restore to just the code in the else block.

view = mb.complex(real_data=real, imag_data=imag, name=node.name)
else:
view = mb.reshape(x=x, shape=shape, name=node.name)

context.add(view)
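
The comment above describes the pattern used here and in pad() below: apply the real-valued op to each component, then recombine. That pattern could be factored into a helper; the sketch below is hypothetical and not code from this PR:

# Hypothetical helper illustrating the lowering pattern: run a real-valued op
# over the real and imaginary parts of a complex input, then recombine them.
def _apply_to_complex_components(real_op, x, name, **kwargs):
    real = real_op(x=mb.complex_real(data=x), **kwargs)
    imag = real_op(x=mb.complex_imag(data=x), **kwargs)
    return mb.complex(real_data=real, imag_data=imag, name=name)

# e.g. the complex branch of view() would reduce to:
#   view = _apply_to_complex_components(mb.reshape, x, node.name, shape=shape)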


@@ -1565,7 +1571,11 @@ def pad(context, node):
if inputs[val_index] and inputs[val_index].op.op_type == "const":
scalar_val = float(scalar_val.val)

res = mb.pad(x=x, pad=pad, mode=mode, constant_val=scalar_val, name=node.name)
if types.is_complex(x.dtype):
real, imag = (mb.pad(x=x, pad=pad, mode=mode, constant_val=scalar_val, name=node.name) for x in (mb.complex_real(data=x), mb.complex_imag(data=x)))
res = mb.complex(real_data=real, imag_data=imag, name=node.name)
else:
res = mb.pad(x=x, pad=pad, mode=mode, constant_val=scalar_val, name=node.name)
context.add(res)


@@ -4427,8 +4437,11 @@ def index_select(context, node):

@register_torch_op(torch_alias=["abs"])
def _abs(context, node):
inputs = _get_inputs(context, node, expected=1)
context.add(mb.abs(x=inputs[0], name=node.name))
x = _get_inputs(context, node, expected=1)[0]
if types.is_complex(x.dtype):
context.add(mb.complex_abs(x=x, name=node.name))
else:
context.add(mb.abs(x=x, name=node.name))


@register_torch_op
@@ -5676,6 +5689,23 @@ def fft_irfftn(context, node):
irfftn_res = mb.complex_irfftn(data=input_data, shapes=shapes, dims=dims, norm=norm)
context.add(irfftn_res, node.name)

@register_torch_op
def stft(context, node):
"""
Lowers torch.stft with the dialect op `complex_stft` from complex_dialect_ops.py
"""
input_data, n_fft, hop_length, win_length, window, normalized, onesided, _ = _get_inputs(context, node, min_expected=2)
if types.is_complex(input_data.dtype):
onesided = False # pytorch defaults onesided to False for complex inputs
stft_res = mb.complex_stft(
input=input_data,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
window=window,
normalized=normalized,
onesided=onesided)
context.add(stft_res, node.name)

@register_torch_op(torch_alias=["torchvision::nms"])
def torchvision_nms(context, node):
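
With the stft registration above, a traced model that calls torch.stft can be converted end to end. A minimal usage sketch (illustrative only, assuming a coremltools build that includes this PR): Core ML has no complex output type, so the model reduces the complex result with torch.abs, which lowers to complex_abs.

import torch
import coremltools as ct

class STFTMagnitude(torch.nn.Module):
    def forward(self, x):
        # reduce the complex STFT result to a real-valued magnitude so the
        # converted model has a real output
        return torch.stft(x, n_fft=16, return_complex=True).abs()

traced = torch.jit.trace(STFTMagnitude(), torch.randn(1, 64))
mlmodel = ct.convert(traced, inputs=[ct.TensorType(shape=(1, 64))])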

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py (109 additions, 0 deletions)

@@ -11,6 +11,7 @@
import numpy as np
import pytest
import torch.nn as nn
import torchaudio
import torchvision

import coremltools as ct
@@ -7877,6 +7878,26 @@ def forward(self, x):
(2, 3, 4), ComplexModel(), backend=backend, compute_unit=compute_unit
)

@pytest.mark.parametrize(
"compute_unit, backend",
itertools.product(
compute_units,
backends,
)
)
def test_abs(self, compute_unit, backend):
class AbsModel(torch.nn.Module):
def forward(self, x):
x = torch.complex(x, x)
return torch.abs(x)

TorchBaseTest.run_compare_torch(
(1, 16),
AbsModel(),
backend=backend,
compute_unit=compute_unit,
)


class TestReal(TorchBaseTest):
@pytest.mark.parametrize(
@@ -8099,6 +8120,94 @@ def forward(self, x):
(2, 3, 4), FftnModel(), backend=backend, compute_unit=compute_unit
)

class TestSTFT(TorchBaseTest):

Collaborator:

Great! Very nice and comprehensive tests!

@pytest.mark.slow
@pytest.mark.parametrize(
"compute_unit, backend, input_shape, complex, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided",
itertools.product(
compute_units,
backends,
[(1, 32), (32,), (3, 32)], # input shape
[False, True], # complex
[16], # n_fft
[None, 4, 5], # hop_length
[None, 16, 9], # win_length
[None, torch.hann_window], # window
[None, False, True], # center
["constant", "reflect", "replicate"], # pad mode
[False, True], # normalized
[None, False, True], # onesided
)
)
def test_stft(self, compute_unit, backend, input_shape, complex, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided):
if complex and onesided:
pytest.skip("Onesided stft not possible for complex inputs")

class STFTModel(torch.nn.Module):
def forward(self, x):
applied_window = window(win_length) if window and win_length else None
x = torch.complex(x, x) if complex else x
x = torch.stft(
x,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
window=applied_window,
center=center,
pad_mode=pad_mode,
normalized=normalized,
onesided=onesided,
return_complex=True)
x = torch.stack([torch.real(x), torch.imag(x)], dim=0)
return x

TorchBaseTest.run_compare_torch(
input_shape,
STFTModel(),
backend=backend,
compute_unit=compute_unit
)

class TestSpectrogram(TorchBaseTest):
@pytest.mark.parametrize(
"compute_unit, backend, input_shape, spec, power",
itertools.product(
compute_units,
backends,
[(1, 1000), (1000,), (3, 1000)], # input shape
[torchaudio.transforms.Spectrogram, torchaudio.transforms.MelSpectrogram],
[None, 1, 2] # magnitude or power
)
)
def test_spectrogram(self, compute_unit, backend, input_shape, spec, power):
if platform.machine() != "arm64":
pytest.xfail("rdar://108001659 ([PyTorch] Torchaudio Spectrogram Failed on Intel Machine)")

if spec is torchaudio.transforms.MelSpectrogram and power is None:
pytest.skip("power or magnitude required for melspec")

class SpectrogramModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
# the other spectrogram options are passed through to stft
# and are tested in TestSTFT
self.spec = spec(power=power, n_fft=128)

def forward(self, x):
x = self.spec(x)
if power is None:
# complex: stack them
x = torch.stack([torch.real(x), torch.imag(x)], dim=0)
return x

TorchBaseTest.run_compare_torch(
input_shape,
SpectrogramModel(),
backend=backend,
compute_unit=compute_unit,
rtol=1e-4,
atol=1e-4,
)

class TestNms(TorchBaseTest):
@pytest.mark.parametrize(

coremltools/converters/mil/mil/input_type.py (1 addition, 1 deletion)

@@ -285,7 +285,7 @@ def type_domain(self):

@type_domain.setter
def type_domain(self, val):
msg = "type_domain must be a tuple of builtin types"
msg = f"type_domain {val} must be a tuple of builtin types"

Contributor Author:

This is an unrelated change but made errors during the complex lowering passes a little easier to debug.

if not isinstance(val, tuple) or any(map(lambda t: t not in _SUPPORT_TYPES, val)):
raise ValueError(msg)
self._type_domain = val

coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py (123 additions, 4 deletions)

@@ -729,16 +729,135 @@ class complex_shape(Operation):
"T": (types.complex64,),
}

# If type_inference or value_inference is invoked when the graph is being constructed,
# x.real and x.imag may not be set since the complex lowering pass hasn't yet been invoked.
# self.x should already have the shape set, so use that instead.

def type_inference(self):
if not isinstance(self.x, ComplexVar):
raise ValueError("x must be a ComplexVar.")
input_rank = self.x.real.rank
input_rank = self.x.rank

Contributor Author:

If type_inference or value_inference is invoked when the graph is being constructed, x.real and x.imag may not be set since the complex lowering pass hasn't yet been invoked.

Collaborator:

Good point! Could you add this as a comment here?

return types.tensor(types.int32, tuple([input_rank]))

def value_inference(self):
if any_symbolic(self.x.real.shape):
if any_symbolic(self.x.shape):
# convert elements in shape to int32
res = [x if is_symbolic(x) else np.int32(x) for x in self.x.real.shape]
res = [x if is_symbolic(x) else np.int32(x) for x in self.x.shape]
return np.array(res)
else:
return np.array(self.x.real.shape).astype(np.int32)
return np.array(self.x.shape).astype(np.int32)

@register_op(namespace="complex")
class complex_abs(Operation):
"""
Returns the absolute value of a complex tensor.

Parameters
----------
x: tensor<[*d], T> (Required)

Returns
-------
tensor<[*d], fp32>
* A float tensor with the same shape as ``x``.

Attributes
----------
T: complex64
"""

input_spec = InputSpec(x=TensorInputType(type_domain="T"))

type_domains = {
"T": (types.complex64,),
}

def type_inference(self):
if not isinstance(self.x, ComplexVar):
raise ValueError("x must be a ComplexVar.")
return types.tensor(infer_fp_dtype_from_complex(self.x.dtype), self.x.shape)
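
The op mirrors torch.abs on complex tensors, i.e. |a + bi| = sqrt(a^2 + b^2). A quick PyTorch check of the expected numerics:

import torch

# |3 + 4i| = sqrt(3^2 + 4^2) = 5
z = torch.complex(torch.tensor([3.0]), torch.tensor([4.0]))
assert torch.abs(z).item() == 5.0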

@register_op(namespace="complex")
class complex_stft(Operation):
"""
Dialect op for 1-D STFT.

Parameters
----------
input: tensor<\*D, T> (Required)
* The input tensor.
n_fft: const i32 (Required)
* Size of the Fourier transform.
hop_length: const i32 (Optional)
* Stride between window frames of the input tensor.
win_length: const i32 (Optional)
* The size of the window frame.
window: tensor<1, win_length> (Optional)
* The window to apply to the input signal before performing the Fourier transform.
normalized: const bool (Optional, Default=``false``)
* Whether to normalize the results of the STFT.
onesided: const bool (Optional, Default=``true``)
* For real-valued inputs, whether to return the first half of the results.

Returns
-------
tensor<\*V, complex64>
* A complex tensor where real and imag parts have the same shape.

Attributes
----------
T: fp32, complex64

References
----------
See `torch.stft <https://pytorch.org/docs/stable/generated/torch.stft.html>`_.
"""

input_spec = InputSpec(
input=TensorInputType(type_domain="T"),
n_fft=TensorInputType(const=True, type_domain=types.int32),
hop_length=TensorInputType(const=True, optional=True, type_domain=types.int32),
win_length=TensorInputType(const=True, optional=True, type_domain=types.int32),
window=TensorInputType(const=True, optional=True, type_domain=types.fp32),
normalized=TensorInputType(const=True, optional=True, type_domain=types.bool),
onesided=TensorInputType(const=True, optional=True, type_domain=types.bool),
)

type_domains = {
"T": (types.fp32, types.complex64),
}

def default_inputs(self):
return DefaultInputs(
hop_length = None,
win_length = None,
window = None,
normalized = False,
onesided = True,
)

def type_inference(self):
output_type = (types.complex64)

# STFT shape is [B x N x T], where N is the number of frequency bins
# and T is the number of windows
# B is 1 for a time series or 2 for a batch of time series

window_length = self.n_fft.val
hop = self.hop_length.val if self.hop_length else self.n_fft.val // 4

# if onesided is true, the input is real-valued
# because of Hermitian symmetry, we only need to calculate the FFT
# for the first half of the frequencies
if self.onesided and self.onesided.val:
window_length = window_length // 2 + 1

frames = (self.input.shape[-1] - self.n_fft.val) // hop + 1
output_shape = [window_length, frames]

# add back rank if needed
if self.input.rank == 2:
output_shape = [self.input.shape[0]] + output_shape

return types.tensor(output_type, tuple(output_shape))
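
A quick numerical check of the shape logic above against PyTorch (the formula assumes center=False; with center=True, torch.stft pads the input before the underlying op is called, so this dialect op sees the already-padded length):

import torch

x = torch.randn(3, 32)                       # B=3 time series of length 32
n_fft, hop = 16, 4
out = torch.stft(x, n_fft=n_fft, hop_length=hop, center=False,
                 onesided=True, return_complex=True)

frames = (x.shape[-1] - n_fft) // hop + 1    # (32 - 16) // 4 + 1 = 5
bins = n_fft // 2 + 1                        # onesided: 16 // 2 + 1 = 9
assert out.shape == (3, bins, frames)        # torch.Size([3, 9, 5])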
