Add support for torch STFT & spectrogram #1824

Merged (39 commits, Apr 13, 2023)

Conversation

@nikalra (Contributor) commented Apr 10, 2023

  • Adds support for torch.stft and its various options, for both complex and real inputs.
  • Adds support for torchaudio.functional.Spectrogram (and MelSpectrogram) via STFT support and complex support for pad/reshape/abs. (A usage sketch follows this list.)
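A minimal usage sketch of what this enables (illustrative only: the model, shapes, and STFT parameters here are hypothetical, not taken from this PR):

```python
import torch
import coremltools as ct

class SpectrogramModel(torch.nn.Module):
    def forward(self, x):
        # torch.stft produces a complex tensor; taking .abs() yields a real
        # magnitude spectrogram, since Core ML can't return complex outputs.
        stft = torch.stft(
            x, n_fft=256, hop_length=128,
            window=torch.hann_window(256), return_complex=True,
        )
        return stft.abs()

traced = torch.jit.trace(SpectrogramModel().eval(), torch.randn(1, 1024))
mlmodel = ct.convert(traced, inputs=[ct.TensorType(shape=(1, 1024))])
```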

-    view = mb.reshape(x=x, shape=shape, name=node.name)
+    if types.is_complex(x.dtype):
+        real, imag = (
+            mb.reshape(x=x, shape=shape, name=node.name)
+            for x in (mb.complex_real(data=x), mb.complex_imag(data=x))
+        )
@nikalra (Contributor Author):

I opted not to create complex dialect ops for reshape and pad (below) because their behavior doesn't change as a result of the inputs being complex.

I'm more than happy to create a complex dialect op for these if that's the preferred approach, but I figured this might be a better route to avoid duplicating each built-in op as a complex dialect op.

If, in the future, there's support for something like a lowering pass that duplicates every non-complex dialect op with complex support in its type domain across the real and imaginary components of the input, it would be easy to get rid of this and restore just the code in the else block.
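As a sketch of that pattern (simplified from the diff above; the rewrap via mb.complex is my reading of how the real/imag results get recombined, not verbatim PR code):

```python
# When the input is complex, run the ordinary real-valued reshape on the
# real and imaginary planes separately, then rewrap the two results into
# a single complex value. No complex dialect op is needed, because
# reshape's semantics don't depend on the element type.
if types.is_complex(x.dtype):
    real = mb.reshape(x=mb.complex_real(data=x), shape=shape)
    imag = mb.reshape(x=mb.complex_imag(data=x), shape=shape)
    view = mb.complex(real_data=real, imag_data=imag, name=node.name)
else:
    view = mb.reshape(x=x, shape=shape, name=node.name)
```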

@@ -285,7 +285,7 @@ def type_domain(self):

     @type_domain.setter
     def type_domain(self, val):
-        msg = "type_domain must be a tuple of builtin types"
+        msg = f"type_domain {val} must be a tuple of builtin types"
@nikalra (Contributor Author):

This is an unrelated change, but it made errors during the complex lowering passes a little easier to debug.

@@ -732,13 +732,128 @@ class complex_shape(Operation):

     def type_inference(self):
         if not isinstance(self.x, ComplexVar):
             raise ValueError("x must be a ComplexVar.")
-        input_rank = self.x.real.rank
+        input_rank = self.x.rank
@nikalra (Contributor Author):

If type_inference or value_inference is invoked when the graph is being constructed, x.real and x.imag may not be set since the complex lowering pass hasn't yet been invoked.

@junpeiz (Collaborator):

Good point! Could you add this as a comment here?
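For instance, the requested comment might read something like this (a sketch; the exact wording that landed in the PR may differ):

```python
def type_inference(self):
    if not isinstance(self.x, ComplexVar):
        raise ValueError("x must be a ComplexVar.")
    # Use self.x.rank rather than self.x.real.rank: if type_inference or
    # value_inference runs while the graph is still being constructed,
    # x.real and x.imag may not be set yet, because the complex lowering
    # pass hasn't been invoked at that point.
    input_rank = self.x.rank
```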

@acosmicflamingo:

:O It's happening!!! :D

@junpeiz self-requested a review Apr 10, 2023
@junpeiz self-assigned this Apr 10, 2023
@junpeiz (Collaborator) left a comment:

Really nice PR! Added several comments. After those comments are addressed, I will kick off a CI run and merge it. Thanks a lot!

@@ -8099,6 +8100,113 @@ def forward(self, x):
             (2, 3, 4), FftnModel(), backend=backend, compute_unit=compute_unit
         )

+class TestSTFT(TorchBaseTest):
@junpeiz (Collaborator):
Great! Very nice and comprehensive tests!
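For a rough idea of the shape of these tests, here is a hedged sketch (the parameter grid and shapes are illustrative, not the PR's actual parametrization; it assumes the test file's existing imports and fixtures such as itertools, pytest, compute_units, and backends):

```python
class TestSTFT(TorchBaseTest):
    @pytest.mark.parametrize(
        "compute_unit, backend, n_fft, hop_length",
        itertools.product(compute_units, backends, [16, 32], [None, 5]),
    )
    def test_stft(self, compute_unit, backend, n_fft, hop_length):
        class STFTModel(torch.nn.Module):
            def forward(self, x):
                # Take the magnitude so the traced model's output is real.
                return torch.stft(
                    x, n_fft=n_fft, hop_length=hop_length,
                    return_complex=True,
                ).abs()

        self.run_compare_torch(
            (1, 64), STFTModel(), backend=backend, compute_unit=compute_unit
        )
```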

             atol=1e-4,
         )

+class TestComplex(TorchBaseTest):
@junpeiz (Collaborator):
We already have a TestComplex test class in this file. Could you merge this test_abs into that test class?

@nikalra (Contributor Author) commented Apr 12, 2023:

> Really nice PR! Added several comments. After those comments are addressed, I will kick off a CI run and merge it. Thanks a lot!

Done! Let me know if I missed anything!

@junpeiz (Collaborator) commented Apr 12, 2023:

CI Run: https://gitlab.com/coremltools1/coremltools/-/pipelines/835422539

The test for test_spectrogram failed with a strange "Fatal Python error". However, the test passed on my local machine, so it looks like a machine-dependent issue. I will try upgrading the CI machine's system to see if that fixes it.

Commit note: As this test behavior is machine-dependent and failed on Intel Mac, we xfail it for now and will debug later. Already filed an internal tracking radar for it.
@junpeiz (Collaborator) commented Apr 13, 2023:

I also tested it on an Intel Mac, and it failed there too (though not with the same error message). So the conclusion is that test_spectrogram doesn't work on Intel Macs. I filed an internal radar to track it and marked the test as xfail for now.

Starting a new CI run after marking xfail: https://gitlab.com/coremltools1/coremltools/-/pipelines/836614946
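For reference, a platform-dependent expected failure in pytest looks roughly like this (a sketch; the condition and reason string actually used in the xfail commit may differ):

```python
import platform
import pytest

# Hypothetical guard: expect failure on Intel (x86_64) machines, where
# test_spectrogram currently crashes; an internal radar tracks the bug.
@pytest.mark.xfail(
    platform.machine() == "x86_64",
    reason="test_spectrogram fails on Intel Mac (tracked internally)",
)
def test_spectrogram():
    ...
```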

@fwcd commented Jun 7, 2023:

Very excited to see torch.stft supported! Is there any chance of getting support for torch.istft too?

@junpeiz mentioned this pull request Oct 16, 2023
@alealv mentioned this pull request Oct 24, 2023