In [1]:
# Reload edited modules automatically before each cell is executed.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
from pathlib import Path
from IPython.display import Audio, display
import librosa
import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import stft
import torch
import torch.nn.functional as F

from fastai.torch_core import to_np
from fastai_audio.audio_clip import open_audio
from fastai_audio.transform import Spectrogram, FrequencyToMel

In [3]:
# Directory whose entries get_data() opens as audio clips.
DATA_PATH = Path('data/examples_16KHz')

def get_data():
    """Load every example clip under ``DATA_PATH`` into a single batch tensor.

    Files are visited in sorted order so the batch layout (and therefore
    ``xs[0]``) is reproducible; ``Path.iterdir()`` yields entries in arbitrary
    order. Sub-directories are skipped. Clips are truncated to the length of
    the shortest one so that they can be stacked.

    Returns:
        (batch_tensor, sample_rate): the truncated clips stacked along a new
        leading dimension, and the sample rate shared by all clips.

    Raises:
        FileNotFoundError: if ``DATA_PATH`` contains no files.
        ValueError: if the clips do not all share the same sample rate.
    """
    paths = sorted(p for p in DATA_PATH.iterdir() if p.is_file())
    if not paths:
        raise FileNotFoundError('no example files found in {}'.format(DATA_PATH))

    # load data from example files
    clips = [open_audio(fn) for fn in paths]
    sample_rate = clips[0].sample_rate
    # a single sample rate is returned for the whole batch, so it must be shared
    if any(clip.sample_rate != sample_rate for clip in clips):
        raise ValueError('example clips have differing sample rates')

    tensors = [clip.data for clip in clips]
    # make them all the same length so they can be combined into a batch
    min_len = min(t.size(0) for t in tensors)
    batch_tensor = torch.stack([t[:min_len] for t in tensors])
    return batch_tensor, sample_rate

In [4]:
def info(arr, title=None):
    """Print a short summary (shape, dtype, value range) of an array-like.

    Args:
        arr: object exposing ``shape``, ``dtype``, ``min()`` and ``max()``.
        title: optional heading printed before the summary when truthy.
    """
    if title:
        print(title)
    print(" shape: " + str(arr.shape))
    print(" dtype: " + str(arr.dtype))
    lowest, highest = arr.min(), arr.max()
    print(" range: [{:.5f}, {:.5f}]".format(lowest, highest))
    # trailing blank line separates consecutive summaries
    print()

In [291]:
# Load the batch of clips and keep the first one as a NumPy array; `x`, `xs`
# and `sr` are reused by the cells below.
xs, sr = get_data()
x = to_np(xs[0])
info(x)
# Last expression of the cell: displayed as the cell output.
Audio(x, rate=sr)

 shape: (90508,)
 dtype: float32
 range: [-0.56581, 0.48552]



In [390]:
# Shape of the stacked batch returned by get_data().
xs.shape

torch.Size([3, 90508])

In [292]:
# STFT parameters passed to every transform below.
n_fft = 1024
hop_length = 256

### Compare torch.stft and librosa.stft

In [293]:
# Reference STFT of the first clip computed with librosa.
D = librosa.stft(x, n_fft=n_fft, hop_length=hop_length, window='hann')
# Summarise the real and imaginary parts separately.
info(D.real)
info(D.imag)

 shape: (513, 354)
 dtype: float32
 range: [-51.81450, 54.89397]

 shape: (513, 354)
 dtype: float32
 range: [-59.15759, 53.54287]



In [294]:
# Same transform with torch.stft; the last axis holds (real, imag).
xt = torch.from_numpy(x)
Dt = torch.stft(xt, n_fft=n_fft, hop_length=hop_length, window=torch.hann_window(n_fft))
Dt_real, Dt_imag = Dt[..., 0], Dt[..., 1]
info(Dt_real)
info(Dt_imag)

 shape: torch.Size([513, 354])
 dtype: torch.float32
 range: [-51.81450, 54.89398]

 shape: torch.Size([513, 354])
 dtype: torch.float32
 range: [-59.15759, 53.54286]



In [9]:
# (real parts agree, imaginary parts agree) between librosa and torch.
tuple(np.isclose(ref, other, atol=1e-5).all()
      for ref, other in ((D.real, Dt_real), (D.imag, Dt_imag)))

(True, True)

### Confirm librosa.stft invertibility

In [10]:
# Round trip: forward STFT followed by the inverse, trimmed to the input length.
D = librosa.stft(x, n_fft=n_fft, hop_length=hop_length, window='hann')
xh = librosa.istft(D, hop_length=hop_length, window='hann', length=x.size)
info(xh)

# Displayed result: whether every reconstructed sample is within 1e-7 of x.
np.isclose(x, xh, atol=1e-7).all()

 shape: (90508,)
 dtype: float32
 range: [-0.56581, 0.48552]



True

### Numpy stft to match librosa.stft

In [17]:
from scipy.signal import get_window
from scipy.fftpack import fft

def np_stft(signal, n_fft=1024, hop_length=512, 
            win_length=1024, window='hann', 
            center=True, pad_mode='constant'):
    """Short-time Fourier transform of a 1-D signal, mirroring ``librosa.stft``.

    Args:
        signal: 1-D array of samples.
        n_fft: frame length, i.e. the FFT size.
        hop_length: number of samples between the starts of successive frames.
        win_length: length of the analysis window; a window shorter than
            ``n_fft`` is zero-padded symmetrically up to ``n_fft``.
        window: window specification passed to ``scipy.signal.get_window``.
        center: when true, pad ``n_fft // 2`` samples on both sides so frames
            are centred on multiples of ``hop_length``.
        pad_mode: ``np.pad`` mode used when ``center`` is true.

    Returns:
        Complex array of shape ``(n_fft // 2 + 1, n_frames)``.

    Raises:
        ValueError: if ``win_length > n_fft`` or the (padded) signal is
            shorter than one frame.
    """
    # todo: implement the normalize param
    if win_length > n_fft:
        raise ValueError('win_length must not exceed n_fft')
    w = get_window(window, win_length)
    if win_length < n_fft:
        # centre the short window inside the frame
        left = (n_fft - win_length) // 2
        w = np.pad(w, (left, n_fft - win_length - left), 'constant')

    signal = np.asarray(signal)
    if center:
        # the padding depends on the frame size, not on the window length
        signal = np.pad(signal, n_fft // 2, pad_mode)

    n_samples = signal.size
    if n_samples < n_fft:
        raise ValueError('signal is shorter than a single frame')
    # Every frame that fits completely is kept. The previous ceil() based
    # count lost the final frame whenever (n_samples - n_fft) was an exact
    # multiple of hop_length.
    n_frames = 1 + (n_samples - n_fft) // hop_length

    # sample_idx[k, i] is the index of sample k of frame i
    frame_starts = hop_length * np.arange(n_frames)
    sample_idx = np.arange(n_fft)[:, None] + frame_starts[None, :]

    # gather the frames and window each of them
    frames = signal[sample_idx] * w[:, None]

    return np.fft.rfft(frames, n=n_fft, axis=0)

# librosa reference (constant padding) versus the NumPy implementation.
D = librosa.stft(x, n_fft=n_fft, hop_length=hop_length, window='hann', pad_mode='constant')
Dn = np_stft(x, n_fft=n_fft, hop_length=hop_length)

# Displayed result: (real parts match, imaginary parts match).
np.isclose(D.real, Dn.real).all(), np.isclose(D.imag, Dn.imag).all()

(True, True)

In [None]:
# NOTE(review): this unexecuted cell repeats the preceding comparison cell
# verbatim; it is a candidate for removal.
D = librosa.stft(x, n_fft=n_fft, hop_length=hop_length, window='hann', pad_mode='constant')
Dn = np_stft(x, n_fft=n_fft, hop_length=hop_length)

np.isclose(D.real, Dn.real).all(), np.isclose(D.imag, Dn.imag).all()

### Custom PyTorch STFT

In [462]:
def torch_stft(signal, n_fft=1024, hop_length=512, window=None, center=True, normalized=False):
    """STFT of a single 1-D tensor, laid out like the legacy ``torch.stft``.

    Args:
        signal: 1-D tensor of samples.
        n_fft: frame length / FFT size.
        hop_length: number of samples between successive frames.
        window: 1-D tensor of length ``n_fft``; ``None`` selects a rectangular
            window (all ones) instead of failing.
        center: when true, zero-pad ``n_fft // 2`` samples on both sides.
        normalized: when true, scale the result by ``n_fft ** -0.5``.

    Returns:
        Real tensor of shape ``(n_fft // 2 + 1, n_frames, 2)`` whose last axis
        holds the real and imaginary parts.
    """
    if window is None:
        window = signal.new_ones(n_fft)
    # pad signal (the amount depends on the frame size, not the window)
    if center:
        p = n_fft // 2
        signal = F.pad(signal, (p, p), 'constant')
    # overlap frames
    frames = signal.unfold(0, n_fft, hop_length) * window.unsqueeze(0)
    # compute ffts
    if hasattr(torch.fft, 'rfft'):
        # current PyTorch: torch.fft is a module and torch.rfft no longer exists
        spectrum = torch.view_as_real(torch.fft.rfft(frames, dim=-1))
    else:
        # legacy PyTorch: torch.fft is a function, use the old real FFT
        spectrum = torch.rfft(frames, 1)
    if normalized:
        spectrum = spectrum * n_fft ** -0.5
    # transpose to match torch.stft
    return spectrum.transpose(0, 1)

xt = torch.from_numpy(x)
w = torch.hann_window(n_fft)

# Built-in torch.stft (constant padding) versus the custom implementation.
Dt1 = torch.stft(xt, n_fft=n_fft, hop_length=hop_length, window=w, pad_mode='constant')
Dt2 = torch_stft(xt, n_fft=n_fft, hop_length=hop_length, window=w)
# Displayed result: exact element-wise equality (no tolerance).
bool((Dt1 == Dt2).all())

True

In [465]:
# Same comparison with normalized=True on both implementations.
Dt1 = torch.stft(xt, n_fft=n_fft, hop_length=hop_length, window=w, pad_mode='constant', normalized=True)
Dt2 = torch_stft(xt, n_fft=n_fft, hop_length=hop_length, window=w, normalized=True)

info(Dt1)
info(Dt2)

# Displayed result: exact element-wise equality (no tolerance).
bool((Dt1 == Dt2).all())

 shape: torch.Size([513, 354, 2])
 dtype: torch.float32
 range: [-1.84867, 1.71544]

 shape: torch.Size([513, 354, 2])
 dtype: torch.float32
 range: [-1.84867, 1.71544]



True

In [455]:
# Ratio of the unnormalized spectrum minimum to the normalized one printed
# above; it comes out at about 32 = sqrt(1024), i.e. sqrt(n_fft).
-59.15759 / -1.84867

32.00008113941374

In [53]:
# Time the custom single-signal STFT.
%timeit torch_stft(xt, n_fft=n_fft, hop_length=hop_length, window=w)

592 µs ± 55 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [54]:
# Time the built-in torch.stft with the same signal and window for comparison.
%timeit torch.stft(xt, n_fft=n_fft, hop_length=hop_length, window=w, pad_mode='constant')

626 µs ± 23.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


### Numpy istft to match librosa.istft

In [482]:
def np_istft(spectrum, n_fft=1024, hop_length=512, win_length=1024, window='hann', output_length=None):
    """Inverse STFT by windowed overlap-add, mirroring ``librosa.istft``.

    The spectrum is assumed to come from a centred STFT, so ``n_fft // 2``
    samples of padding are removed from the start of the reconstruction.

    Args:
        spectrum: complex array of shape ``(n_fft // 2 + 1, n_frames)``.
        n_fft: ignored; the FFT size is derived from ``spectrum``. Kept so
            existing calls keep working.
        hop_length: number of samples between successive frames.
        win_length: length of the synthesis window; a window shorter than the
            FFT size is zero-padded symmetrically.
        window: window specification passed to ``scipy.signal.get_window``.
        output_length: when given, the result is truncated or zero-padded to
            exactly this many samples; otherwise ``n_fft // 2`` samples are
            trimmed from both ends.

    Returns:
        1-D float array with the reconstructed signal.

    Raises:
        ValueError: if ``win_length`` exceeds the FFT size of ``spectrum``.
    """
    n_freq, n_frames = spectrum.shape
    n_fft = (n_freq - 1) * 2

    if win_length > n_fft:
        raise ValueError('win_length must not exceed the FFT size of the spectrum')
    w = get_window(window, win_length)
    if win_length < n_fft:
        # centre the short window inside the frame
        left = (n_fft - win_length) // 2
        w = np.pad(w, (left, n_fft - win_length - left), 'constant')

    # exact extent covered by the frames; the last frame starts at
    # (n_frames - 1) * hop_length
    n_samples = (n_frames - 1) * hop_length + n_fft

    signal = np.zeros(n_samples)
    norm = np.zeros(n_samples)
    signal_segments = np.fft.irfft(spectrum, n=n_fft, axis=0)
    w_squared = w ** 2.0  # loop invariant
    for i in range(n_frames):
        offset = i * hop_length
        signal[offset:offset + n_fft] += signal_segments[:, i] * w
        norm[offset:offset + n_fft] += w_squared

    # divide by norm, leaving samples with a vanishing envelope untouched
    signal /= np.where(norm > 1e-10, norm, 1.0)

    # remove boundary padding
    p = n_fft // 2
    if output_length is None:
        return signal[p:n_samples - p]

    # only the leading padding is dropped, then the length is fixed
    signal = signal[p:p + output_length]
    if signal.size < output_length:
        signal = np.pad(signal, (0, output_length - signal.size), 'constant')
    return signal

# Reconstruct the clip with librosa and with the NumPy inverse, both trimmed
# to the input length.
D = librosa.stft(x, n_fft=n_fft, hop_length=hop_length, window='hann', pad_mode='constant')
xh = librosa.istft(D, hop_length=hop_length, window='hann', length=x.size)
xhn = np_istft(D, hop_length=hop_length, window='hann', output_length=x.size)

# librosa must reproduce the input, and the NumPy inverse must agree with librosa.
assert np.isclose(x, xh, atol=1e-7).all()
assert np.isclose(xh, xhn, atol=1e-7).all()

info(xh)
info(xhn)

(1024, 354)
 shape: (90508,)
 dtype: float32
 range: [-0.56581, 0.48552]

 shape: (90508,)
 dtype: float64
 range: [-0.56581, 0.48552]



### Custom PyTorch istft

In [490]:
def torch_istft(spectrum, hop_length=512, window=None, center=True, eps=1e-7, length=None, normalized=False):
    """Inverse STFT of a single signal using an explicit overlap-add loop.

    Args:
        spectrum: real tensor of shape ``(n_fft // 2 + 1, n_frames, 2)`` with
            real and imaginary parts on the last axis (legacy ``torch.stft``
            layout).
        hop_length: number of samples between successive frames.
        window: 1-D synthesis window of length ``n_fft``; ``None`` selects a
            rectangular window (all ones) instead of failing.
        center: when true, drop the ``n_fft // 2`` samples of padding that a
            centred STFT added on each side.
        eps: samples whose summed squared window is not above this value are
            left unnormalized to avoid dividing by (almost) zero.
        length: when given, the result ends ``length`` samples after its start.
        normalized: when true, undo the ``n_fft ** -0.5`` scaling of a
            normalized forward transform.

    Returns:
        1-D tensor with the reconstructed signal.
    """
    n_freq, n_frames = spectrum.size(0), spectrum.size(1)
    n_fft = (n_freq - 1) * 2
    # exact extent covered by the frames; the previous n_frames * hop_length
    # allocation was one hop too long and leaked into the centred output
    n_samples = (n_frames - 1) * hop_length + n_fft
    if window is None:
        window = spectrum.new_ones(n_fft)

    frames_first = spectrum.transpose(0, 1)
    if hasattr(torch.fft, 'irfft'):
        # current PyTorch: torch.fft is a module and torch.irfft no longer exists
        segments = torch.fft.irfft(torch.view_as_complex(frames_first.contiguous()),
                                   n=n_fft, dim=-1)
    else:
        # legacy PyTorch
        segments = torch.irfft(frames_first, 1, signal_sizes=(n_fft,))

    # allocate on the same device / dtype as the data
    signal = segments.new_zeros(n_samples)
    norm = segments.new_zeros(n_samples)
    window_sq = window ** 2.0  # loop invariant
    for i in range(n_frames):
        offset = i * hop_length
        signal[offset:offset + n_fft] += segments[i] * window
        norm[offset:offset + n_fft] += window_sq

    nonzero_norm = norm > eps
    signal[nonzero_norm] = signal[nonzero_norm] / norm[nonzero_norm]

    if normalized:
        signal.mul_(n_fft ** 0.5)

    # remove padding and trim to length
    start = n_fft // 2 if center else 0
    end = n_samples - start
    if length is not None:
        end = start + length
    return signal[start:end]

    
# Round trip through torch.stft and the loop-based inverse.
Dt = torch.stft(xt, n_fft=n_fft, hop_length=hop_length, window=w, pad_mode='constant')
xth = torch_istft(Dt, hop_length=hop_length, window=w, length=xt.size(0))

info(Dt)
info(xth)

# Displayed result: whether the reconstruction is within 1e-7 of the input.
np.isclose(xt, xth,atol=1e-7).all()

 shape: torch.Size([513, 354, 2])
 dtype: torch.float32
 range: [-59.15759, 54.89398]

 shape: torch.Size([90508])
 dtype: torch.float32
 range: [-0.56581, 0.48552]



True

In [491]:
# Same round trip with normalized=True on both the forward and inverse transform.
Dt = torch.stft(xt, n_fft=n_fft, hop_length=hop_length, window=w, pad_mode='constant', normalized=True)
xth = torch_istft(Dt, hop_length=hop_length, window=w, length=xt.size(0), normalized=True)

info(Dt)
info(xth)

# Displayed result: whether the reconstruction is within 1e-7 of the input.
np.isclose(xt, xth,atol=1e-7).all()

 shape: torch.Size([513, 354, 2])
 dtype: torch.float32
 range: [-1.84867, 1.71544]

 shape: torch.Size([90508])
 dtype: torch.float32
 range: [-0.56581, 0.48552]



True

In [492]:
# Time the loop-based inverse on a single clip.
%timeit torch_istft(Dt, hop_length=hop_length, window=w, length=xt.size(0))

49.3 ms ± 3.51 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


### Custom PyTorch istft *FAST*

In [475]:
def torch_istft(spectrum, hop_length=512, window=None, center=True, eps=1e-7, length=None, normalized=False):
    """Inverse STFT of a single signal, with overlap-add done by ``F.fold``.

    NOTE: this shares its name with the loop-based ``torch_istft`` defined
    earlier in the notebook and shadows it once this cell runs. Unlike that
    version it expects the frame axis first.

    Args:
        spectrum: real tensor of shape ``(n_frames, n_fft // 2 + 1, 2)`` with
            real and imaginary parts on the last axis.
        hop_length: number of samples between successive frames.
        window: 1-D synthesis window of length ``n_fft``; ``None`` selects a
            rectangular window (all ones) instead of failing.
        center: when true, drop the ``n_fft // 2`` samples of padding that a
            centred STFT added on each side.
        eps: samples whose summed squared window is not above this value are
            left unnormalized; previously this argument was ignored and such
            samples became NaN/inf.
        length: when given, the result ends ``length`` samples after its start.
        normalized: when true, undo the ``n_fft ** -0.5`` scaling of a
            normalized forward transform.

    Returns:
        1-D tensor with the reconstructed signal.
    """
    n_frames, n_freq = spectrum.size(0), spectrum.size(1)
    n_fft = (n_freq - 1) * 2
    n_samples = (n_frames - 1) * hop_length + n_fft
    if window is None:
        window = spectrum.new_ones(n_fft)

    w = window.view(1, -1, 1)

    if hasattr(torch.fft, 'irfft'):
        # current PyTorch: torch.fft is a module and torch.irfft no longer exists
        segments = torch.fft.irfft(torch.view_as_complex(spectrum.contiguous()),
                                   n=n_fft, dim=-1)
    else:
        # legacy PyTorch
        segments = torch.irfft(spectrum, 1, signal_sizes=(n_fft,))
    # (n_frames, n_fft) -> (1, n_fft, n_frames): the column layout F.fold expects
    segments = segments.transpose(0, 1).unsqueeze(0) * w

    fold_kwargs = dict(output_size=(1, n_samples),
                       kernel_size=(1, n_fft),
                       stride=(1, hop_length))
    # overlap-add of the windowed segments
    signal = F.fold(segments, **fold_kwargs).reshape(n_samples)
    # overlap-add of the squared window gives the normalisation envelope
    norm = F.fold((w ** 2.0).repeat(1, 1, n_frames), **fold_kwargs).reshape(n_samples)

    # honour eps: never divide by a vanishing envelope
    norm = torch.where(norm > eps, norm, torch.ones_like(norm))
    signal = signal / norm

    if normalized:
        signal = signal * n_fft ** 0.5

    # remove padding and trim to length
    start = n_fft // 2 if center else 0
    end = n_samples - start
    if length is not None:
        end = start + length
    return signal[start:end]
    
# The fold-based inverse reads the frame axis first, so swap the first two axes.
Dt = torch.stft(xt, n_fft=n_fft, hop_length=hop_length, window=w, pad_mode='constant').transpose(0,1)
xth = torch_istft(Dt, hop_length=hop_length, window=w, length=xt.size(0))

# Displayed result: (matches the input signal, matches the NumPy inverse xhn).
(np.isclose(xt, xth, atol=1e-7).all(),
 np.isclose(xth, xhn, atol=1e-7).all())

(True, True)

In [476]:
# Time the fold-based inverse on a single clip.
%timeit torch_istft(Dt, hop_length=hop_length, window=w, length=xt.size(0))

4.71 ms ± 374 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### Batch PyTorch STFT

In [469]:
def torch_stft_b(signal, n_fft=1024, hop_length=512, window=None, center=True, normalized=False):
    """Batched STFT, laid out like the legacy ``torch.stft`` on 2-D input.

    Args:
        signal: tensor of shape ``(batch, n_samples)``.
        n_fft: frame length / FFT size.
        hop_length: number of samples between successive frames.
        window: 1-D tensor of length ``n_fft``; ``None`` selects a rectangular
            window (all ones) instead of failing.
        center: when true, zero-pad ``n_fft // 2`` samples on both sides of
            every signal.
        normalized: when true, scale the result by ``n_fft ** -0.5``.

    Returns:
        Real tensor of shape ``(batch, n_fft // 2 + 1, n_frames, 2)`` whose
        last axis holds the real and imaginary parts.
    """
    if window is None:
        window = signal.new_ones(n_fft)
    # pad signal (the amount depends on the frame size, not the window)
    if center:
        p = n_fft // 2
        signal = F.pad(signal, (p, p), 'constant')
    # overlap frames: (batch, n_frames, n_fft)
    frames = signal.unfold(1, n_fft, hop_length) * window.unsqueeze(0)
    # compute ffts
    if hasattr(torch.fft, 'rfft'):
        # current PyTorch: torch.fft is a module and torch.rfft no longer exists
        spectrum = torch.view_as_real(torch.fft.rfft(frames, dim=-1))
    else:
        # legacy PyTorch
        spectrum = torch.rfft(frames, 1)
    if normalized:
        spectrum = spectrum * n_fft ** -0.5
    # transpose to match torch.stft
    return spectrum.transpose(1, 2)

w = torch.hann_window(n_fft)
info(xs)

# Built-in torch.stft on the whole batch versus the custom batched version.
Dt1 = torch.stft(xs, n_fft=n_fft, hop_length=hop_length, window=w, pad_mode='constant')
Dt2 = torch_stft_b(xs, n_fft=n_fft, hop_length=hop_length, window=w)
info(Dt1)
# Displayed result: exact element-wise equality (no tolerance).
bool((Dt1 == Dt2).all())

 shape: torch.Size([3, 90508])
 dtype: torch.float32
 range: [-0.63845, 0.50826]

 shape: torch.Size([3, 513, 354, 2])
 dtype: torch.float32
 range: [-99.35357, 91.80214]



True

In [470]:
# Same batched comparison with normalized=True on both implementations.
Dt1 = torch.stft(xs, n_fft=n_fft, hop_length=hop_length, window=w, pad_mode='constant', normalized=True)
Dt2 = torch_stft_b(xs, n_fft=n_fft, hop_length=hop_length, window=w, normalized=True)
info(Dt1)
# Displayed result: exact element-wise equality (no tolerance).
bool((Dt1 == Dt2).all())

 shape: torch.Size([3, 513, 354, 2])
 dtype: torch.float32
 range: [-3.10480, 2.86882]



True

### Batch PyTorch ISTFT

In [477]:
def torch_istft_b(spectrum, hop_length=512, window=None, center=True, eps=1e-7, length=None, normalized=False):
    """Batched inverse STFT, with overlap-add done by ``F.fold``.

    Args:
        spectrum: real tensor of shape ``(batch, n_frames, n_fft // 2 + 1, 2)``
            with real and imaginary parts on the last axis.
        hop_length: number of samples between successive frames.
        window: 1-D synthesis window of length ``n_fft``; ``None`` selects a
            rectangular window (all ones) instead of failing.
        center: when true, drop the ``n_fft // 2`` samples of padding that a
            centred STFT added on each side.
        eps: samples whose summed squared window is not above this value are
            left unnormalized; previously this argument was ignored and such
            samples became NaN/inf.
        length: when given, each result ends ``length`` samples after its start.
        normalized: when true, undo the ``n_fft ** -0.5`` scaling of a
            normalized forward transform.

    Returns:
        Tensor of shape ``(batch, n_out)`` with the reconstructed signals.
    """
    batch_size, n_frames, n_freq = spectrum.size()[:3]
    n_fft = (n_freq - 1) * 2
    n_samples = (n_frames - 1) * hop_length + n_fft
    if window is None:
        window = spectrum.new_ones(n_fft)

    w = window.view(1, -1, 1)

    if hasattr(torch.fft, 'irfft'):
        # current PyTorch: torch.fft is a module and torch.irfft no longer exists
        segments = torch.fft.irfft(torch.view_as_complex(spectrum.contiguous()),
                                   n=n_fft, dim=-1)
    else:
        # legacy PyTorch
        segments = torch.irfft(spectrum, 1, signal_sizes=(n_fft,))
    # (batch, n_frames, n_fft) -> (batch, n_fft, n_frames): the layout F.fold expects
    segments = segments.transpose(1, 2) * w

    fold_kwargs = dict(output_size=(1, n_samples),
                       kernel_size=(1, n_fft),
                       stride=(1, hop_length))
    # overlap-add of the windowed segments
    signal = F.fold(segments, **fold_kwargs).reshape(batch_size, n_samples)
    # the envelope is identical for every batch item, so build it once and
    # let it broadcast over the batch axis
    norm = F.fold((w ** 2.0).repeat(1, 1, n_frames), **fold_kwargs).reshape(1, n_samples)

    # honour eps: never divide by a vanishing envelope
    norm = torch.where(norm > eps, norm, torch.ones_like(norm))
    signal = signal / norm

    if normalized:
        signal = signal * n_fft ** 0.5

    # remove padding and trim to length
    start = n_fft // 2 if center else 0
    end = n_samples - start
    if length is not None:
        end = start + length
    return signal[:, start:end]

    
# torch_istft_b reads the frame axis before the frequency axis, so swap them.
Dt = torch.stft(xs, n_fft=n_fft, hop_length=hop_length, window=w, pad_mode='constant').transpose(1,2)
# Take the target length from the batch itself instead of relying on the
# single-clip tensor `xt` left over from an earlier cell.
xsh = torch_istft_b(Dt, hop_length=hop_length, window=w, length=xs.size(1))

# Displayed result: whether every reconstructed sample is within 1e-7 of the batch.
np.isclose(xs, xsh, atol=1e-7).all()

True

In [478]:
%timeit torch_istft_b(Dt, hop_length=hop_length, window=w, length=xt.size(0))

13.8 ms ± 714 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [479]:
# Factor applied by the normalized forward transforms above: n_fft ** -0.5 with n_fft = 1024.
np.power(1024, -0.5)

0.03125

In [480]:
# Reciprocal of the forward normalization factor.
1/0.03125

32.0

In [481]:
# Factor applied by the normalized inverse transform: n_fft ** 0.5 with n_fft = 1024.
np.power(1024, 0.5)

32.0

### Time stfts

In [None]:
# Time the NumPy STFT on the single clip.
%timeit np_stft(x, n_fft=n_fft, hop_length=hop_length)

In [None]:
# Time librosa's STFT on the same clip.
%timeit librosa.stft(x, n_fft=n_fft, hop_length=hop_length, window='hann')

In [None]:
# Time torch.stft on the same clip.
# NOTE(review): the Hann window is constructed inside the timed expression,
# so its creation cost is included in this measurement.
%timeit torch.stft(xt, n_fft=n_fft, hop_length=hop_length, window=torch.hann_window(n_fft))