In [3]:
import os
# Move from the notebook folder up to the project root. Only the *last* path
# component is checked: testing the whole path made re-runs keep climbing
# whenever some ancestor directory name happened to contain 'note'.
if 'note' in os.path.basename(os.getcwd()):
    os.chdir('../')
print(os.getcwd())

import warnings
# NOTE(review): this silences *all* warnings, including PyTorch's warning about
# reading `.grad` on a non-leaf tensor -- the exact symptom of a false negative
# in the differentiability checks below. Consider narrowing the filter.
warnings.filterwarnings('ignore')

/home/bernie40916/Nas/home/Project/DJtransGAN


In [15]:
import time
import torch

from djtransgan.config         import settings
from djtransgan.utils.utils    import purify_device
from djtransgan.frontend.utils import get_amp2db_func, get_stft_func
from djtransgan.utils.noise    import generate_noise

### Helper Functions to Test Differentiability

In [16]:
def add_gard_requires(params):
    """Enable gradient tracking on a tensor, or on every tensor in a tuple/list.

    The historical "gard" spelling of the name is kept so existing callers work.

    Args:
        params: a ``torch.Tensor``, or a tuple/list of tensors.

    Returns:
        The same tensor object(s) with ``requires_grad`` set in place, in the
        same container type as the input (tuple -> tuple, list -> list).
    """
    if isinstance(params, tuple):
        # The previous version referenced an undefined name ``param`` here
        # (NameError) and never iterated over ``params``.
        return tuple(param.requires_grad_(True) for param in params)
    if isinstance(params, list):
        return [param.requires_grad_(True) for param in params]
    return params.requires_grad_(True)
    
    
def fake_backward(data):
    """Backpropagate from a (possibly non-scalar) tensor with an all-ones seed.

    Every element of ``data`` contributes with weight 1, so the gradients that
    reach upstream leaves match those of ``data.sum().backward()``.

    Returns:
        Whatever ``Tensor.backward`` returns (``None``).
    """
    upstream = torch.ones_like(data)
    return data.backward(gradient=upstream)


def stft_diff_test(data, stft):
    """Check that gradients flow from the first ``stft`` output back to the input.

    Makes a fresh leaf copy of ``data`` that requires grad, runs ``stft`` on it,
    and backpropagates an all-ones gradient from the first returned tensor
    (``m`` -- presumably the magnitude; TODO confirm against get_stft_func).

    Args:
        data: input waveform tensor; only a copy is handed to ``stft``.
        stft: callable returning a pair ``(m, p)``.

    Returns:
        The gradient accumulated on the input copy, or ``None`` if no gradient
        reached it (i.e. the forward transform is not differentiable).
    """
    # detach() before clone() guarantees a leaf even if ``data`` is itself part
    # of an autograd graph; a cloned non-leaf never gets ``.grad`` populated,
    # which would make this check report a false negative.
    d    = add_gard_requires(data.detach().clone())
    m, _ = stft(d)
    fake_backward(m)
    return d.grad


def istft_diff_test(data, stft):
    """Check that gradients flow through ``stft.inverse`` back to its first input.

    Runs the forward ``stft`` on a copy of ``data``, turns the first output
    (``m`` -- presumably the magnitude; TODO confirm) into a fresh leaf that
    requires grad, reconstructs a signal with ``stft.inverse(mg, p)``, and
    backpropagates an all-ones gradient from that signal.

    Args:
        data: input waveform tensor; only a copy is handed to ``stft``.
        stft: callable returning a pair ``(m, p)`` and exposing
            ``inverse(m, p)``.

    Returns:
        The gradient accumulated on ``mg``, or ``None`` if no gradient reached
        it (i.e. the inverse transform is not differentiable w.r.t. ``m``).
    """
    m, p = stft(torch.clone(data))
    # Detach so ``mg`` is always a leaf. If ``stft`` has trainable kernels, ``m``
    # is a non-leaf output: requires_grad_ would be a no-op and ``.grad`` would
    # never be populated, giving a false negative.
    mg   = add_gard_requires(m.detach())
    s    = stft.inverse(mg, p)
    fake_backward(s)
    return mg.grad

### Data Preparation

In [17]:
batch_size = 4
times      = 30   # passed to generate_noise; 1323000 samples / 30 = 44100, so presumably seconds at 44.1 kHz -- TODO confirm

# Build `batch_size` independent noise clips, give each a leading axis via
# unsqueeze(0), then stack them along a new batch axis.
waves = torch.stack(
    [generate_noise(times).unsqueeze(0) for _ in range(batch_size)]
)

print('size: ', waves.size())

size:  torch.Size([4, 1, 1323000])


### nnAudio

In [21]:
nnaudio_stft = get_stft_func(stype='nnaudio', length=waves.size(-1))
print('=' * 50)

# perf_counter is monotonic and high-resolution, unlike the wall clock
# time.time(), so it is the right clock for measuring elapsed durations.
start_time   = time.perf_counter()
stft_grad    = stft_diff_test(waves, nnaudio_stft)
end_time     = time.perf_counter()

print('STFT Grad: ', stft_grad is not None)
print('STFT Time: ', end_time - start_time)

# NOTE: this measurement also includes the forward STFT that istft_diff_test
# runs before calling stft.inverse.
start_time   = time.perf_counter()
istft_grad   = istft_diff_test(waves, nnaudio_stft)
end_time     = time.perf_counter()

print('ISTFT Grad: ', istft_grad is not None)
print('ISTFT Time: ', end_time - start_time)

STFT kernels created, time used = 0.1894 seconds
STFT Grad:  True
STFT Time:  0.8163197040557861
ISTFT Grad:  True
ISTFT Time:  5.596508502960205


### Asteroid STFT

In [22]:
asteroid_stft = get_stft_func(stype='asteroid', length=waves.size(-1))
print('=' * 50)

# perf_counter is monotonic and high-resolution, unlike the wall clock
# time.time(), so it is the right clock for measuring elapsed durations.
start_time   = time.perf_counter()
stft_grad    = stft_diff_test(waves, asteroid_stft)
end_time     = time.perf_counter()

print('STFT Grad: ', stft_grad is not None)
print('STFT Time: ', end_time - start_time)

# NOTE: this measurement also includes the forward STFT that istft_diff_test
# runs before calling stft.inverse.
start_time   = time.perf_counter()
istft_grad   = istft_diff_test(waves, asteroid_stft)
end_time     = time.perf_counter()

print('ISTFT Grad: ', istft_grad is not None)
print('ISTFT Time: ', end_time - start_time)

STFT kernels created, time used = 0.1729 seconds
STFT Grad:  True
STFT Time:  0.9932258129119873
ISTFT Grad:  True
ISTFT Time:  1.6749699115753174


### TorchLibrosa

In [23]:
torchlibrosa_stft = get_stft_func(stype='torchlibrosa', length=waves.size(-1))

print('=' * 50)

# perf_counter is monotonic and high-resolution, unlike the wall clock
# time.time(), so it is the right clock for measuring elapsed durations.
start_time   = time.perf_counter()
stft_grad    = stft_diff_test(waves, torchlibrosa_stft)
end_time     = time.perf_counter()

print('STFT Grad: ', stft_grad is not None)
print('STFT Time: ', end_time - start_time)

# NOTE: this measurement also includes the forward STFT that istft_diff_test
# runs before calling stft.inverse.
start_time   = time.perf_counter()
istft_grad   = istft_diff_test(waves, torchlibrosa_stft)
end_time     = time.perf_counter()

print('ISTFT Grad: ', istft_grad is not None)
print('ISTFT Time: ', end_time - start_time)

STFT kernels created, time used = 0.2103 seconds
STFT Grad:  True
STFT Time:  0.8302221298217773
ISTFT Grad:  True
ISTFT Time:  1.4559872150421143
