🚀 The feature
I have a use case for convolving two tensors of dtype torch.float16 or torch.bfloat16 that contain interleaved complex data (real and imaginary parts stored in a trailing dimension of size 2).
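For context, "interleaved complex data" means a real tensor whose last dimension of size 2 holds (real, imaginary) pairs. This is the layout torch.view_as_complex expects. The minimal illustration below uses arbitrary shapes and shows why the layout is awkward in half precision. float16 has a complex counterpart (torch.complex32), but there is no complex bfloat16 dtype, so view_as_complex rejects bfloat16 input.

```python
import torch

# Interleaved layout: the last axis holds (real, imag) pairs, e.g. shape [B, T, 2].
x = torch.randn(2, 5, 2)
xc = torch.view_as_complex(x)  # complex64 view of shape [2, 5], no copy
assert torch.equal(xc.real, x[..., 0]) and torch.equal(xc.imag, x[..., 1])

# float16 maps to torch.complex32 ...
half_dtype = torch.view_as_complex(x.half()).dtype
print("float16 ->", half_dtype)

# ... but there is no complex bfloat16 dtype, so this raises RuntimeError.
try:
    torch.view_as_complex(x.bfloat16())
    bf16_ok = True
except RuntimeError as err:
    bf16_ok = False
    print("bfloat16 ->", err)
```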
Motivation, pitch
I have a couple of DSP functions that are applied inside a training loop running in float16 or bfloat16.
Alternatives
I have the following code, which works:
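```python
import torch
import torch.nn.functional as F
from einops import rearrange


def convolve_cplx(
    x: torch.Tensor,  # [B, T, 2]
    y: torch.Tensor,  # [B, H, 2]
) -> torch.Tensor:
    """Same as torch.view_as_real(torchaudio.functional.convolve(torch.view_as_complex(x),
    torch.view_as_complex(y), mode='same')), but works with torch.float16 and torch.bfloat16."""
    B, T, m = x.shape[0], x.shape[1], y.shape[1] // 2
    y = y.flip(-2)  # [B, H, 2]; conv1d is cross-correlation, so flip the kernel
    yr, yi = y[..., 0], y[..., 1]  # [B, H] each
    x = rearrange(x, 'b t re -> 1 (b re) t')  # one group of (re, im) channels per batch element
    # Per group: out_re = yr*x_re - yi*x_im, out_im = yi*x_re + yr*x_im
    w = rearrange([yr, -yi, yi, yr], '(o re) b h -> (b o) re h', re=2)  # [2B, 2, H]
    y = F.conv1d(x, w, padding=m, groups=B)  # [1, 2B, T] (T + 1 when H is even)
    y = rearrange(y, '1 (b re) t -> b t re', re=2)
    return y[:, :T]  # drop the extra sample produced for even H
```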
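The einops call that builds `w` packs complex multiplication into a 2×2 real weight block per batch element. The equations below are my reading of the code, not something stated by torchaudio. Here $*$ is 1-D convolution, $x = x_r + i x_i$, and $h = h_r + i h_i$. The rows of $W_b$ are the conv1d output channels (Re, Im) and the columns are the input channels $(x_r, x_i)$. The kernel is flipped because F.conv1d computes cross-correlation. The padding $\lfloor H/2 \rfloor$ yields $T$ samples for odd $H$ and $T+1$ for even $H$.

```latex
\begin{aligned}
(x * h)[n] &= (x_r * h_r - x_i * h_i)[n] + i\,(x_r * h_i + x_i * h_r)[n] \\
W_b &= \begin{pmatrix} \tilde h_r & -\tilde h_i \\ \tilde h_i & \tilde h_r \end{pmatrix},
  \qquad \tilde h[k] = h[H-1-k] \\
L_{\text{out}} &= T + 2\lfloor H/2 \rfloor - H + 1
  = \begin{cases} T, & H \text{ odd} \\ T + 1, & H \text{ even} \end{cases}
\end{aligned}
```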
I have a unit test:
```python
import time
import unittest

import torch
import torchaudio.functional as aF

# assumes convolve_cplx (defined above) is in scope


class TestDSP(unittest.TestCase):
    def test_conv_cplx(self):
        perf0 = 0.0  # accumulated torchaudio time
        perf1 = 0.0  # accumulated convolve_cplx time
        first = True  # exclude the first (warm-up) iteration from the timings
        for T in [2000, 1000, 500, 200]:
            for H in [21, 5, 9, 161]:
                x = torch.randn(32, T, 2)
                h = torch.randn(32, H, 2)
                t0 = time.perf_counter()
                y1 = torch.view_as_real(aF.convolve(torch.view_as_complex(x), torch.view_as_complex(h), mode='same'))
                t1 = time.perf_counter()
                y2 = convolve_cplx(x, h)
                t2 = time.perf_counter()
                self.assertTrue(torch.allclose(y1, y2, rtol=1e-4, atol=1e-4))
                if first:
                    first = False
                else:
                    perf0 += t1 - t0
                    perf1 += t2 - t1
        print("torchaudio time {:.5f} custom {:.5f} improvement {:.2f}%".format(
            perf0, perf1, 100 * (perf0 - perf1) / perf0))


if __name__ == '__main__':
    unittest.main()
```
On my machine, this test shows that my "dummy" method is about 47% faster than torchaudio.functional.convolve, measured on the float32 CPU tensors the test creates.
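The unit test above only exercises float32 tensors and odd kernel lengths H. Below is a self-contained sketch that also covers an even H and bfloat16, which is the dtype motivating this request. It is not the author's code. The function names `convolve_cplx_plain` and `convolve_same_fft` are hypothetical, einops is replaced by `torch.stack`/`reshape`, and the reference computes the full complex convolution with `torch.fft` in float64 and then centre-crops it, which is my understanding of what `mode='same'` does. It assumes a PyTorch build whose CPU `F.conv1d` supports bfloat16.

```python
import torch
import torch.nn.functional as F


def convolve_cplx_plain(x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
    """Hypothetical einops-free variant: x [B, T, 2], h [B, H, 2] -> [B, T, 2]."""
    B, T, _ = x.shape
    H = h.shape[1]
    h = h.flip(1)                                   # conv1d is cross-correlation: flip kernel
    hr, hi = h[..., 0], h[..., 1]                   # [B, H] each
    xin = x.permute(0, 2, 1).reshape(1, 2 * B, T)   # channels ordered (b, re/im)
    # 2x2 block per batch element: rows = output (re, im), cols = input (re, im)
    w = torch.stack([torch.stack([hr, -hi], dim=1),
                     torch.stack([hi, hr], dim=1)], dim=1)   # [B, 2, 2, H]
    w = w.reshape(2 * B, 2, H)
    y = F.conv1d(xin, w, padding=H // 2, groups=B)  # [1, 2B, T] (T + 1 for even H)
    y = y[..., :T]                                  # keep the first T samples
    return y.reshape(B, 2, T).permute(0, 2, 1)      # [B, T, 2]


def convolve_same_fft(x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
    """Reference in float64: full complex convolution via FFT, centre-cropped to T."""
    xc = torch.view_as_complex(x.double())          # [B, T] complex128
    hc = torch.view_as_complex(h.double())          # [B, H] complex128
    T, H = xc.shape[-1], hc.shape[-1]
    n = T + H - 1                                   # length of the full convolution
    full = torch.fft.ifft(torch.fft.fft(xc, n=n) * torch.fft.fft(hc, n=n))
    start = (H - 1) // 2                            # centre crop
    return torch.view_as_real(full[..., start:start + T])


torch.manual_seed(0)
results = {}
for H in (5, 8):                                    # one odd, one even kernel length
    x = torch.randn(4, 64, 2)
    h = torch.randn(4, H, 2)
    err32 = (convolve_cplx_plain(x, h).double() - convolve_same_fft(x, h)).abs().max().item()
    # bfloat16: compare against the reference computed from the same rounded inputs
    xb, hb = x.bfloat16(), h.bfloat16()
    yb = convolve_cplx_plain(xb, hb)
    refb = convolve_same_fft(xb, hb)
    rel_bf16 = ((yb.double() - refb).abs().max() / refb.abs().max()).item()
    results[H] = (err32, rel_bf16, tuple(yb.shape), yb.dtype)
    print(f"H={H}: float32 max abs err {err32:.2e}, bfloat16 max rel err {rel_bf16:.2e}")
```

The float32 comparison mainly guards the indexing (flip, padding, crop). The bfloat16 error is reported relative to the largest reference magnitude, and what counts as acceptable there is a judgment call for the training setup.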
Additional context
No response