In [8]:
import os
from pathlib import Path

# When the notebook is launched from inside the repo's examples/ tree, move up
# two levels so the working directory becomes the project root (the printed
# output below shows the root). Assumes the notebook sits exactly two levels
# below the root — TODO confirm.
# 'examples' is matched as a whole path component, not as a substring, so an
# unrelated directory such as 'my_examples' does not trigger the chdir.
if 'examples' in Path.cwd().parts:
    os.chdir('../../')
print(os.getcwd())

import warnings
# NOTE(review): this silences *all* warnings, including PyTorch's UserWarning
# about reading `.grad` of a non-leaf tensor, which is exactly the symptom the
# differentiability checks below could hit. Consider narrowing the filter.
warnings.filterwarnings('ignore')

/home/bernie40916/Nas/home/Project/DJtransGAN


In [9]:
import time
import torch

from djtransgan.config   import settings
from djtransgan.utils    import purify_device
from djtransgan.frontend import get_amp2db_func, get_stft_func
from djtransgan.dataset  import get_dataset, batchlize

### Helper Functions to Test Differentiability

In [10]:
def add_gard_requires(params):
    """Enable gradient tracking on a tensor, or on every tensor in a tuple/list.

    ``requires_grad_`` works in place, so the returned tensors are the same
    objects that were passed in.

    Args:
        params: A ``torch.Tensor``, or a tuple/list of tensors.

    Returns:
        The same kind of object as ``params``: a tuple, a list, or the single
        tensor, with ``requires_grad=True`` set on every tensor.
    """
    # Containers are handled element by element; each element must be a tensor.
    if isinstance(params, tuple):
        return tuple(param.requires_grad_(True) for param in params)
    if isinstance(params, list):
        return [param.requires_grad_(True) for param in params]
    return params.requires_grad_(True)
    
    
def fake_backward(data):
    """Backpropagate from ``data`` as though the loss were ``data.sum()``.

    Seeding ``backward`` with a ones tensor of the same shape gives every
    element an upstream gradient of 1. This lets non-scalar outputs be
    differentiated without reducing them first.

    Returns:
        Whatever ``Tensor.backward`` returns (``None``).
    """
    upstream = torch.ones_like(data)
    return data.backward(upstream)


def stft_diff_test(data, stft):
    """Check that the forward transform of ``stft`` propagates gradients to its input.

    Args:
        data: Waveform tensor fed to ``stft``. It is copied, not modified, and
            no gradient is accumulated into it.
        stft: Callable returning a ``(magnitude, phase)`` pair. Only the
            magnitude is backpropagated.

    Returns:
        The gradient with respect to the copied input waveform, or ``None`` if
        no gradient reached it.
    """
    # detach() before clone() so ``d`` is always a fresh leaf tensor. If
    # ``data`` already required grad, a plain clone would be a non-leaf whose
    # ``.grad`` is never populated, which would wrongly report "no gradient".
    d = data.detach().clone().requires_grad_(True)
    magnitude, _ = stft(d)
    # Ones as the upstream gradient: equivalent to backpropagating magnitude.sum().
    magnitude.backward(torch.ones_like(magnitude))
    return d.grad


def istft_diff_test(data, stft):
    """Check that ``stft.inverse`` propagates gradients back to its magnitude input.

    Args:
        data: Waveform tensor used to produce a magnitude/phase pair. It is
            copied, not modified.
        stft: Object callable as ``stft(x) -> (magnitude, phase)`` that also
            exposes ``stft.inverse(magnitude, phase)``.

    Returns:
        The gradient with respect to the magnitude fed to ``inverse``, or
        ``None`` if no gradient reached it.
    """
    magnitude, phase = stft(data.detach().clone())
    # Detach so the magnitude is a leaf. If the forward transform has trainable
    # parameters, its raw output is a non-leaf tensor whose ``.grad`` would
    # stay None and wrongly report "not differentiable".
    mg = magnitude.detach().requires_grad_(True)
    signal = stft.inverse(mg, phase)
    # Ones as the upstream gradient: equivalent to backpropagating signal.sum().
    signal.backward(torch.ones_like(signal))
    return mg.grad

### Data Preparation

In [17]:
# One batch of synthetic 'noise' clips serves as input for every STFT backend
# below. With n_time = 30, each clip has 1,323,000 samples in the output shown
# (30 × 44,100), so n_time is presumably seconds at 44.1 kHz — TODO confirm.
batch_size = 4
n_time     = 30

dataset    = get_dataset('noise', n_time=n_time)
dataloader = iter(batchlize(dataset, batch_size))
waves      = next(dataloader)  # only the first batch is used

print('size: ', waves.shape)

size:  torch.Size([4, 1, 1323000])


### NNAudio

In [12]:
nnaudio_stft = get_stft_func(stft_type='nnaudio', length=waves.size(-1))
print('=' * 50)

# perf_counter() is a monotonic, high-resolution clock intended for measuring
# elapsed time; time.time() follows the system clock and can jump.
# These are single-run timings that include any first-call overhead, so treat
# them as rough. NOTE(review): if `waves` lives on a GPU, a
# torch.cuda.synchronize() is needed before each reading.
start_time   = time.perf_counter()
stft_grad    = stft_diff_test(waves, nnaudio_stft)
end_time     = time.perf_counter()

print('STFT Grad: ', stft_grad is not None)
print('STFT Time: ', end_time - start_time)

start_time   = time.perf_counter()
istft_grad   = istft_diff_test(waves, nnaudio_stft)
end_time     = time.perf_counter()

print('ISTFT Grad: ', istft_grad is not None)
print('ISTFT Time: ', end_time - start_time)

STFT kernels created, time used = 0.1857 seconds
STFT Grad:  True
STFT Time:  0.2763242721557617
ISTFT Grad:  True
ISTFT Time:  1.5830333232879639


### AsteroidSTFT

In [13]:
asteroid_stft = get_stft_func(stft_type='asteroid', length=waves.size(-1))
print('=' * 50)

# perf_counter() is a monotonic, high-resolution clock intended for measuring
# elapsed time; time.time() follows the system clock and can jump.
# These are single-run timings that include any first-call overhead, so treat
# them as rough. NOTE(review): if `waves` lives on a GPU, a
# torch.cuda.synchronize() is needed before each reading.
start_time   = time.perf_counter()
stft_grad    = stft_diff_test(waves, asteroid_stft)
end_time     = time.perf_counter()

print('STFT Grad: ', stft_grad is not None)
print('STFT Time: ', end_time - start_time)

start_time   = time.perf_counter()
istft_grad   = istft_diff_test(waves, asteroid_stft)
end_time     = time.perf_counter()

print('ISTFT Grad: ', istft_grad is not None)
print('ISTFT Time: ', end_time - start_time)

STFT kernels created, time used = 0.1263 seconds
STFT Grad:  True
STFT Time:  0.2017836570739746
ISTFT Grad:  True
ISTFT Time:  0.5179605484008789


### TorchLibrosa

In [15]:
torchlibrosa_stft = get_stft_func(stft_type='torchlibrosa', length=waves.size(-1))

print('=' * 50)

# perf_counter() is a monotonic, high-resolution clock intended for measuring
# elapsed time; time.time() follows the system clock and can jump.
# These are single-run timings that include any first-call overhead, so treat
# them as rough. NOTE(review): if `waves` lives on a GPU, a
# torch.cuda.synchronize() is needed before each reading.
start_time   = time.perf_counter()
stft_grad    = stft_diff_test(waves, torchlibrosa_stft)
end_time     = time.perf_counter()

print('STFT Grad: ', stft_grad is not None)
print('STFT Time: ', end_time - start_time)

start_time   = time.perf_counter()
istft_grad   = istft_diff_test(waves, torchlibrosa_stft)
end_time     = time.perf_counter()

print('ISTFT Grad: ', istft_grad is not None)
print('ISTFT Time: ', end_time - start_time)

STFT kernels created, time used = 0.1272 seconds
STFT Grad:  True
STFT Time:  0.22011661529541016
ISTFT Grad:  True
ISTFT Time:  0.37281274795532227
