In [1]:
# Enable the autoreload extension in mode 2 so edits to imported project
# modules are picked up live while iterating in this notebook.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
from pathlib import Path
from IPython.display import Audio, display
import librosa
import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import stft
import torch
import torch.nn.functional as F

from fastai.torch_core import to_np
from fastai_audio.audio_clip import open_audio
from fastai_audio.transform import Spectrogram, FrequencyToMel

In [3]:
# Directory whose entries get_data() loads as example audio clips.
# Active setting, originally tagged "GCP" (presumably the cloud machine's
# checkout layout -- confirm).
DATA_PATH = Path('../tests/data')
# Alternative location, originally tagged "MBP" (presumably a local laptop
# layout -- confirm); uncomment to use it instead:
# DATA_PATH = Path('data/examples_16KHz')

def get_data():
    """Load every example clip under ``DATA_PATH`` into a single batch tensor.

    Files are visited in sorted order so the batch layout -- in particular
    which clip ends up at index 0 -- is the same on every run and machine
    (``Path.iterdir()`` on its own yields entries in arbitrary order).
    Sub-directories and hidden files (names starting with ``.``) are skipped
    instead of being handed to ``open_audio``.

    Returns:
        tuple: ``(batch_tensor, sample_rate)``. ``batch_tensor`` stacks the
        clips' ``data`` tensors after trimming each one to the length of the
        shortest clip along dim 0; ``sample_rate`` is read from the first clip.

    Raises:
        FileNotFoundError: if ``DATA_PATH`` contains no loadable files.
    """
    # load data from example files, in a deterministic order
    filenames = sorted(fn for fn in DATA_PATH.iterdir()
                       if fn.is_file() and not fn.name.startswith('.'))
    if not filenames:
        raise FileNotFoundError("no audio files found in {}".format(DATA_PATH))
    clips = [open_audio(fn) for fn in filenames]
    # NOTE(review): assumes every clip shares the first clip's sample rate -- confirm
    sample_rate = clips[0].sample_rate
    tensors = [clip.data for clip in clips]
    # make them all the same length so they can be combined into a batch
    min_len = min(t.size(0) for t in tensors)
    tensors = [t[:min_len] for t in tensors]
    batch_tensor = torch.stack(tensors)
    return batch_tensor, sample_rate

In [4]:
def info(arr, title=None):
    """Print a compact summary of an array-like: shape, dtype and value range.

    Args:
        arr: object exposing ``shape``, ``dtype``, ``min()`` and ``max()``
            (both NumPy arrays and torch tensors do).
        title: optional heading printed first; skipped when falsy.

    The summary always ends with one blank line.
    """
    if title:
        print(title)
    print(" shape:", arr.shape)
    print(" dtype:", arr.dtype)
    # extrema are only computed after shape/dtype have been reported
    lowest, highest = arr.min(), arr.max()
    range_line = " range: [{:.5f}, {:.5f}]".format(lowest, highest)
    print(range_line)
    print()

In [5]:
# Load the example batch and its sample rate, then keep the first clip as the
# working signal `x` (`x` and `sr` are reused by the cells further down).
xs, sr = get_data()
x = xs[0]
# Print shape / dtype / value range of the clip.
info(x)
# Last expression of the cell: show the clip as an IPython Audio object.
Audio(x, rate=sr)

 shape: torch.Size([90508])
 dtype: torch.float32
 range: [-0.63845, 0.50826]



In [6]:
TWO_PI = 2.0*np.pi

def torch_phase_vocoder(mags, phases, rate, hop_length=None):
    """Resample a polar STFT along time with a phase vocoder.

    Args:
        mags: 2-D tensor ``(n_steps, n_fft_half)`` of bin magnitudes,
            frames along dim 0.
        phases: tensor with the same layout holding the bin phases in radians.
        rate: spacing, measured in input frames, between consecutive output
            frames. ``rate > 1`` yields fewer frames than the input,
            ``rate < 1`` yields more. Must be positive.
        hop_length: STFT hop in samples; defaults to ``n_fft // 4`` with
            ``n_fft = 2 * (n_fft_half - 1)``.

    Returns:
        tuple: ``(new_mags, new_phases)``, each with one row per entry of
        ``arange(0, n_steps, rate)`` and ``n_fft_half`` columns.

    Raises:
        ValueError: if ``rate`` is not positive.
    """
    if rate <= 0:
        raise ValueError("rate must be positive, got {}".format(rate))

    n_steps, n_fft_half = mags.size()
    n_fft = 2 * (n_fft_half - 1)

    if hop_length is None:
        hop_length = int(n_fft // 4)

    # fractional positions (in input frames) of every output frame; created
    # with the input's dtype/device so integer rates and non-CPU tensors work
    time_steps = torch.arange(0, n_steps, rate,
                              dtype=mags.dtype, device=mags.device)
    time_steps_idx = time_steps.long()
    alpha = torch.remainder(time_steps, 1.0).unsqueeze_(1)

    # one trailing zero frame so that index+1 is always valid
    mags = F.pad(mags, [0,0,0,1])
    # linear interpolation between the two neighbouring input frames
    new_mags = ((1.0 - alpha) * mags[time_steps_idx, :]
                     + alpha  * mags[time_steps_idx+1, :])

    initial_phase = phases[:1,:]
    # Pad the phases with one zero frame as well: this gives the last input
    # frame a phase difference too. Without it, any rate < 1 selects index
    # n_steps-1 below and runs past the end of phase_delta.
    phases = F.pad(phases, [0,0,0,1])
    phase_delta = phases[1:,:] - phases[:-1,:]
    phi_advance = torch.linspace(0, np.pi * hop_length, n_fft_half,
                                 dtype=phases.dtype,
                                 device=phases.device).unsqueeze_(0)
    # remove expected phase change due to hop length
    phase_delta.sub_(phi_advance)
    # wrap to [-pi, pi] range
    phase_delta.sub_(TWO_PI * torch.round(phase_delta / TWO_PI))
    # reindex the phases (the last output frame needs no outgoing difference)
    new_phase_delta = phase_delta[time_steps_idx[:-1]]
    # add back expected phase change over hops
    new_phase_delta.add_(phi_advance)
    # accumulate the per-frame advances starting from the first input phase
    new_phases = torch.cumsum(torch.cat([initial_phase,
                                         new_phase_delta], dim=0), dim=0)
    return new_mags, new_phases

In [7]:
def torch_stft(signal, n_fft=1024, hop_length=512, window=None, center=True, normalized=False):
    """Short-time Fourier transform of a 1-D signal, frames first.

    Args:
        signal: 1-D tensor of samples.
        n_fft: frame (and FFT) length in samples.
        hop_length: number of samples between the starts of adjacent frames.
        window: 1-D tensor of length ``n_fft`` multiplied into every frame.
            ``None`` means a rectangular window (all ones).
        center: if true, zero-pad ``n_fft // 2`` samples on both ends before
            framing.
        normalized: if true, scale the spectrum by ``n_fft ** -0.5``.

    Returns:
        Real tensor of shape ``(n_frames, n_fft // 2 + 1, 2)`` whose last
        dimension holds (real, imag). The layout is frames-first; it is NOT
        transposed to the (freq, frames) layout of ``torch.stft``.
    """
    if window is None:
        # the default used to crash on `window.size`; treat "no window" as a
        # rectangular one
        window = torch.ones(n_fft, dtype=signal.dtype, device=signal.device)
    # pad signal
    if center:
        p = n_fft // 2
        signal = F.pad(signal, (p, p), 'constant')
    # overlap frames
    frames = signal.unfold(0, n_fft, hop_length) * window.unsqueeze(0)
    # compute ffts; torch.rfft was removed in PyTorch 1.8, so prefer the
    # torch.fft module when it is available and fall back on older versions
    if hasattr(getattr(torch, 'fft', None), 'rfft'):
        spectrum = torch.view_as_real(torch.fft.rfft(frames, dim=-1))
    else:
        spectrum = torch.rfft(frames, 1)
    if normalized:
        spectrum.mul_(np.power(n_fft, -0.5))
    return spectrum

In [8]:
def torch_istft(spectrum, hop_length=512, window=None, center=True, length=None):
    """Inverse STFT by windowed overlap-add of frames-first spectra.

    Args:
        spectrum: real tensor ``(n_frames, n_fft // 2 + 1, 2)`` with
            (real, imag) in the last dimension.
        hop_length: number of samples between the starts of adjacent frames.
        window: 1-D synthesis window of length ``n_fft``; ``None`` means a
            rectangular window (all ones).
        center: if true, drop ``n_fft // 2`` samples from both ends of the
            overlap-added signal.
        length: if given, the result ends ``length`` samples after its start
            (it is not padded when fewer samples are available).

    Returns:
        1-D tensor with the reconstructed samples.
    """
    n_frames, n_bins = spectrum.size(0), spectrum.size(1)
    n_fft = (n_bins - 1) * 2
    n_samples = (n_frames - 1) * hop_length + n_fft

    if window is None:
        # the default used to crash on `window.view`; treat "no window" as a
        # rectangular one
        window = torch.ones(n_fft, dtype=spectrum.dtype, device=spectrum.device)
    w = window.view(1, -1, 1)

    # torch.irfft was removed in PyTorch 1.8, so prefer the torch.fft module
    # when it is available and fall back on older versions
    if hasattr(getattr(torch, 'fft', None), 'irfft'):
        segments = torch.fft.irfft(torch.view_as_complex(spectrum.contiguous()),
                                   n=n_fft, dim=-1)
    else:
        segments = torch.irfft(spectrum, 1, signal_sizes=(n_fft,))
    # (n_frames, n_fft) -> (1, n_fft, n_frames): one column per frame for F.fold
    segments = segments.transpose(0, 1).unsqueeze(0) * w

    # overlap-add the windowed frames
    signal = F.fold(segments,
                    output_size=(1, n_samples),
                    kernel_size=(1, n_fft),
                    stride=(1, hop_length))

    # overlap-add the squared window to get the per-sample normalisation
    norm = torch.ones_like(segments) * w**2.0
    norm = F.fold(norm,
                  output_size=(1, n_samples),
                  kernel_size=(1, n_fft),
                  stride=(1, hop_length))

    signal = signal.reshape(n_samples)
    norm = norm.reshape(n_samples)
    # Only normalise where the window envelope is non-zero; dividing elsewhere
    # produced 0/0 = NaN (e.g. the first sample of a Hann window).
    covered = norm > torch.finfo(norm.dtype).tiny
    signal = signal / torch.where(covered, norm, torch.ones_like(norm))

    # remove padding and trim to length
    start = 0
    end = n_samples
    if center:
        p = n_fft // 2
        start, end = p, -p
    if length is not None:
        end = start + length
    return signal[start:end]

In [9]:
def torch_time_stretch(x, rate, n_fft=2048, n_hop=512):
    """Time-stretch a 1-D signal with an STFT phase vocoder.

    Pipeline: Hann-windowed ``torch_stft`` -> magnitude/phase ->
    ``torch_phase_vocoder`` at ``rate`` -> back to (real, imag) ->
    ``torch_istft`` with the same window and hop.

    Args:
        x: 1-D tensor of samples.
        rate: frame step passed to the phase vocoder.
        n_fft: STFT frame length in samples.
        n_hop: STFT hop length in samples.

    Returns:
        1-D tensor with the stretched signal.
    """
    window = torch.hann_window(n_fft)
    spec = torch_stft(x, n_fft=n_fft, hop_length=n_hop, window=window)

    # rectangular -> polar
    real_part, imag_part = spec[..., 0], spec[..., 1]
    power = spec.pow(2.0)
    magnitude = torch.sqrt(power[..., 0] + power[..., 1])
    phase = torch.atan2(imag_part, real_part)

    # resample the frames in time
    stretched_mag, stretched_phase = torch_phase_vocoder(
        magnitude, phase, rate, hop_length=n_hop)

    # polar -> rectangular, written into a fresh (frames, bins, 2) buffer
    stretched = torch.empty(*stretched_mag.size(), 2)
    stretched[:, :, 0] = stretched_mag * stretched_phase.cos()
    stretched[:, :, 1] = stretched_mag * stretched_phase.sin()

    return torch_istft(stretched, hop_length=n_hop, window=window)

In [10]:
# STFT settings and stretch rate for the listening comparison in this cell.
n_fft = 2048
n_hop =  512
rate = 1.3
# Convert the torch clip with fastai's to_np for the librosa call.
xn = to_np(x)
# Reference: librosa's time stretch (only `rate` is passed, so librosa uses
# its own default STFT settings).
xnh = librosa.effects.time_stretch(xn, rate=rate)
# Candidate: the torch implementation defined in this notebook.
xh = torch_time_stretch(x, rate=rate, n_fft=n_fft, n_hop=n_hop)
# Show both results as audio, librosa first, for comparison by ear.
display(Audio(xnh, rate=sr))
display(Audio(xh, rate=sr))

In [11]:
def test():
    """Compare the torch time stretch with librosa's on the global clip ``x``.

    Both implementations stretch ``x`` at rate 1.2; the printed number is the
    fraction of samples that agree within an absolute tolerance of 1e-3.
    """
    stretch_rate = 1.2
    fft_size, hop = 2048, 512

    reference = librosa.effects.time_stretch(to_np(x), rate=stretch_rate)
    candidate = torch_time_stretch(x, rate=stretch_rate, n_fft=fft_size, n_hop=hop)

    within_tol = np.isclose(reference, candidate, atol=1e-3)
    print(within_tol.mean())

In [12]:
# Run the comparison; it prints the fraction of samples within atol=1e-3 of
# librosa's output.
test()

0.9958944515306123


In [14]:
# Settings shared by the two timing cells below; rate 1.0 keeps the frame
# spacing unchanged.
n_fft = 2048
n_hop = 512
rate = 1.0
# Input for the librosa timing cell.
xn = to_np(x)
# NOTE(review): xnh is recomputed here but the timing cells do not read it.
xnh = librosa.effects.time_stretch(xn, rate=rate)

In [15]:
%%timeit 
# Time the torch implementation on the clip `x` with the settings above.
torch_time_stretch(x, rate=rate, n_fft=n_fft, n_hop=n_hop)

14.4 ms ± 794 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [16]:
%%timeit
# Time librosa's implementation on the same clip (as `xn`) and rate.
librosa.effects.time_stretch(xn, rate=rate)

38.8 ms ± 89.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
