In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import IPython
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import time
from tqdm import tqdm

from nara_wpe.wpe import wpe
from nara_wpe.utils import stft, istft, get_stft_center_frequencies
from nara_wpe import project_root

In [None]:
# Keyword arguments forwarded to nara_wpe's `stft` via `**stft_options`;
# the `istft` calls below reuse the `size` and `shift` entries.
stft_options = {
    'size': 512,
    'shift': 128,
    'window_length': None,
    'fading': True,
    'pad': True,
    'symmetric_window': False,
}

# Minimal example with random data

In [None]:
def aquire_audio_data(D=4, T=10000, seed=None):
    """Return a synthetic multichannel signal of standard-normal white noise.

    The (misspelled) name is kept so existing calls keep working.

    Args:
        D: Number of channels.
        T: Number of time-domain samples per channel.
        seed: Optional seed for ``np.random.default_rng`` to make the example
            reproducible. ``None`` (the default) draws from NumPy's global
            legacy RNG, exactly as the original implementation did.

    Returns:
        ``np.ndarray`` of shape ``(D, T)``.
    """
    if seed is None:
        return np.random.normal(size=(D, T))
    return np.random.default_rng(seed).normal(size=(D, T))

In [None]:
# Run WPE with its default parameters on the synthetic signal and time the call.
y = aquire_audio_data()
# Reorder the STFT to (frequency bins, channels, frames) before calling wpe.
Y = stft(y, **stft_options).transpose(2, 0, 1)

start = time.perf_counter()
Z = wpe(Y)
end = time.perf_counter()

# Undo the reordering and return to the time domain.
z_np = istft(
    Z.transpose(1, 2, 0),
    size=stft_options['size'],
    shift=stft_options['shift'],
)
print(f"Time: {end-start}")

# Example with real audio recordings

WPE estimates a filter that predicts the reverberation tail of the current frame from K (`taps`) past frames, the most recent of which lies 3 (`delay`) frames in the past. This predicted reverberation tail is then subtracted from the observed signal.

### Setup

In [None]:
# Recording setup: number of microphone channels and their sampling rate.
channels, sampling_rate = 8, 16000
# WPE parameters: prediction delay (in frames), iterations and filter taps.
delay, iterations, taps = 3, 5, 10

### Audio data
Shape: (channels, samples)

In [None]:
# Load one mono WAV file per channel and stack them channels-first.
file_template = 'AMI_WSJ20-Array1-{}_T10c0201.wav'
signal_list = []
for d in range(channels):
    samples, file_rate = sf.read(f"{project_root}/data/{file_template.format(d + 1)}")
    # Playback and the STFT settings assume `sampling_rate`; fail loudly on a mismatch
    # instead of silently playing or processing audio at the wrong rate.
    assert file_rate == sampling_rate, (
        f"channel {d + 1}: file sample rate {file_rate} != expected {sampling_rate}"
    )
    signal_list.append(samples)
y = np.stack(signal_list, axis=0)  # -> (channels, samples)
IPython.display.Audio(y[0], rate=sampling_rate)

### STFT
An STFT is performed to obtain a NumPy array with shape (frequency bins, channels, frames).

In [None]:
# Channels-first STFT reordered to (frequency bins, channels, frames) for wpe.
Y = np.transpose(stft(y, **stft_options), (2, 0, 1))

### Iterative WPE
Y is passed to the `wpe` function. Finally, an inverse STFT is performed to obtain the dereverberated signal in the time domain.

In [None]:
# Dereverberate with the parameters from the setup cell. `taps` and `delay` are
# passed explicitly so that changing them in the setup cell actually takes effect.
Z = wpe(
    Y,
    taps=taps,
    delay=delay,
    iterations=iterations,
    statistics_mode='full',
).transpose(1, 2, 0)  # -> (channels, frames, frequency bins), assuming wpe keeps Y's shape
z = istft(Z, size=stft_options['size'], shift=stft_options['shift'])
IPython.display.Audio(z[0], rate=sampling_rate)

## Log-power spectrogram
Before and after applying WPE (first channel, frames 200–399).

In [None]:
# Log-power (dB) of the first channel, frames 200-399, before and after WPE.
# Y is (frequency bins, channels, frames) while Z was transposed to
# (channels, frames, frequency bins) above, hence the different indexing and `.T`.
frame_range = slice(200, 400)
eps = np.finfo(np.float64).eps  # floor |X| so exact zeros give a finite dB value, not -inf
power_reverberated = 20 * np.log10(np.maximum(np.abs(Y[:, 0, frame_range]), eps))
power_dereverberated = 20 * np.log10(np.maximum(np.abs(Z[0, frame_range, :]), eps)).T

# Share one colour scale so the single colour bar is valid for both panels.
vmin = min(power_reverberated.min(), power_dereverberated.min())
vmax = max(power_reverberated.max(), power_dereverberated.max())

fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(20, 10))
im1 = ax1.imshow(power_reverberated, origin='lower', vmin=vmin, vmax=vmax)
ax1.set_xlabel('frames')
ax1.set_ylabel('frequency bins')
_ = ax1.set_title('reverberated')
im2 = ax2.imshow(power_dereverberated, origin='lower', vmin=vmin, vmax=vmax)
ax2.set_xlabel('frames')
ax2.set_ylabel('frequency bins')
_ = ax2.set_title('dereverberated')
cb = fig.colorbar(im1, ax=[ax1, ax2], label='dB')