torchaudio.functional.convolve with torch.complex32 and interleaved fake complex torch.bfloat16 #3739

Open
pfeatherstone opened this issue Jan 31, 2024 · 1 comment

Comments

pfeatherstone commented Jan 31, 2024

🚀 The feature

I have a use case for convolving two tensors of dtype torch.float16 or torch.bfloat16 that contain interleaved complex data.
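
For concreteness, here is a minimal sketch (illustrative only, the variable names are mine) of what "interleaved" means here: real and imaginary parts packed in a trailing dimension of size 2, i.e. the torch.view_as_real layout.

import torch

# Interleaved layout: real/imag packed in the last dimension, i.e. [..., 2].
z  = torch.randn(4, 100, dtype=torch.complex64)
xr = torch.view_as_real(z)                       # [4, 100, 2] float32, last dim is [real, imag]
assert torch.equal(torch.view_as_complex(xr), z)

# The same layout in bfloat16 has no complex dtype to view back into,
# which is what motivates this request.
x_bf16 = xr.to(torch.bfloat16)                   # [4, 100, 2] "fake complex" bfloat16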

Motivation, pitch

I have a couple of DSP functions that are applied in a training loop running in float16 or bfloat16.

Alternatives

I have the following code, which works:

import torch
import torch.nn.functional as F
from einops import rearrange

def convolve_cplx(
    x: torch.Tensor,  # [B, T, 2] interleaved (real, imag)
    y: torch.Tensor   # [B, H, 2] interleaved (real, imag)
) -> torch.Tensor:
    """Same as torch.view_as_real(torchaudio.functional.convolve(torch.view_as_complex(x), torch.view_as_complex(y), mode='same')),
    but also works with torch.float16 and torch.bfloat16."""
    B, m    = x.shape[0], int(y.shape[1] // 2)
    y       = y.flip(-2)                            # reverse the taps: conv1d is cross-correlation, we want convolution
    yr, yi  = y[..., 0], y[..., 1]
    x       = rearrange(x, 'b t re -> 1 (b re) t')
    # Complex multiplication written as a real 2x2 matrix [[yr, -yi], [yi, yr]] acting on (re, im)
    w       = rearrange([yr, -yi, yi, yr], '(o re) b h -> (b o) re h', re=2)
    y       = F.conv1d(x, w, padding=m, groups=B)   # grouped conv: one group per batch element
    y       = rearrange(y, '1 (b re) t -> b t re', re=2)
    return y
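
For illustration, a minimal sketch of calling the workaround with the low-precision dtypes in question (the shapes are arbitrary; half-precision conv1d support on CPU depends on the PyTorch build, so CUDA is the intended target):

device = "cuda" if torch.cuda.is_available() else "cpu"
for dtype in (torch.float16, torch.bfloat16):
    x = torch.randn(8, 500, 2, dtype=dtype, device=device)
    h = torch.randn(8, 21, 2, dtype=dtype, device=device)
    y = convolve_cplx(x, h)             # defined above
    print(dtype, tuple(y.shape))        # (8, 500, 2); output dtype matches the inputs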

I have a unit test:

import unittest
import time
import torch
import torchaudio.functional as aF

class TestDSP(unittest.TestCase):
    def test_conv_cplx(self):
        perf0 = 0
        perf1 = 0
        first = True
        for T in [2000, 1000, 500, 200]:
            for H in [21, 5, 9, 161]:
                x = torch.randn(32, T, 2)
                h = torch.randn(32, H, 2)
                t0 = time.perf_counter()
                y1 = torch.view_as_real(aF.convolve(torch.view_as_complex(x), torch.view_as_complex(h), mode='same'))
                t1 = time.perf_counter()
                y2 = convolve_cplx(x, h)
                t2 = time.perf_counter()
                self.assertTrue(torch.allclose(y1,y2,rtol=1e-4,atol=1e-4))
                if first:
                    first = False
                else:
                    perf0 += (t1-t0)
                    perf1 += (t2-t1)
        print("torchaudio time {:.5f} custom {:.5f} improvement {:.2f}%".format(perf0, perf1, 100*(perf0-perf1)/perf0))

if __name__ == '__main__':
    unittest.main()

On my machine, this test shows that my "dummy" method is about 47% faster than torchaudio...

Additional context

No response

pfeatherstone (Author) commented:

Now, if you try the same thing with torch.float16 or torch.bfloat16, the torchaudio method doesn't work.
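
For reference, a minimal reproduction sketch (the exact error message depends on the PyTorch/torchaudio version): torch.view_as_complex is expected to reject bfloat16 outright, since there is no complex bfloat16 dtype, while float16 views as torch.complex32 and is the path reported above as failing.

import torch
import torchaudio.functional as aF

for dtype in (torch.bfloat16, torch.float16):
    x = torch.randn(32, 2000, 2, dtype=dtype)
    h = torch.randn(32, 21, 2, dtype=dtype)
    try:
        aF.convolve(torch.view_as_complex(x), torch.view_as_complex(h), mode='same')
        print(dtype, "ok")
    except RuntimeError as e:
        # bfloat16: view_as_complex has no complex counterpart;
        # float16: views as complex32, but the convolution is reported to fail.
        print(dtype, "failed:", e)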
