## Package Imports

In [None]:
import time
import IPython
import librosa
import scipy.signal
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import wavfile
from datetime import timedelta as td
%matplotlib inline

## Librosa Wrapper Function Definitions

In [None]:
def _stft(y, n_fft, hop_length, win_length):
    """Compute the short-time Fourier transform of ``y`` via ``librosa.stft``.

    Every argument is forwarded to ``librosa.stft`` by keyword, unchanged.

    Args:
        y: Audio samples to transform.
        n_fft: FFT size handed to librosa.
        hop_length: Hop between successive frames handed to librosa.
        win_length: Analysis window length handed to librosa.

    Returns:
        Whatever ``librosa.stft`` returns for these arguments.
    """
    framing = dict(n_fft = n_fft, hop_length = hop_length, win_length = win_length)
    return librosa.stft(y = y, **framing)


def _istft(y, hop_length, win_length):
    """Invert an STFT matrix back to a time-domain signal via ``librosa.istft``.

    Args:
        y: STFT matrix to invert.
        hop_length: Hop between successive frames used for the forward STFT.
        win_length: Window length used for the forward STFT.

    Returns:
        Whatever ``librosa.istft`` returns for these arguments.
    """
    # Pass the framing parameters by keyword: older librosa releases accept
    # them either way, while newer releases declare them keyword-only and
    # reject positional use. This also mirrors how `_stft` calls librosa.
    return librosa.istft(y, hop_length = hop_length, win_length = win_length)


def _amp_to_db(x):
    """Convert amplitudes to decibels relative to 1.0 (``20 * log10(|x|)``).

    Implemented with NumPy so the result does not depend on which librosa
    release is installed (``logamplitude`` was removed from later releases).
    The conversion follows the amplitude convention of
    ``librosa.amplitude_to_db(x, ref=1.0, amin=1e-20, top_db=80.0)``.

    Args:
        x: Scalar or array of amplitudes; the sign/phase is discarded.

    Returns:
        Array of dB values with the same shape as ``x``. Magnitudes are
        floored at 1e-20 before the logarithm, and the output is clipped so
        that no value lies more than 80 dB below the maximum.
    """
    amin = 1e-20     # floor that keeps log10 away from zero
    top_db = 80.0    # dynamic range retained below the peak
    magnitude = np.abs(np.asarray(x))
    db = 20.0 * np.log10(np.maximum(amin, magnitude))
    if db.size:  # max() of an empty array would raise
        db = np.maximum(db, db.max() - top_db)
    return db


def _db_to_amp(x):
    """Convert decibels (relative to 1.0) back to amplitudes: ``10 ** (x / 20)``.

    This is the inverse of `_amp_to_db` for values that were not affected by
    its floor/clipping. The previous implementation called
    ``librosa.core.perceptual_weighting``, which applies a logarithm plus a
    frequency weighting and therefore is not a dB-to-amplitude conversion.

    Args:
        x: Scalar or array of dB values.

    Returns:
        Array of linear amplitudes with the same shape as ``x``.
    """
    return np.power(10.0, 0.05 * np.asarray(x))

## Custom Graph Plotting Function Definitions

In [None]:
def plot_spectrogram(signal, title):
    """Draw a 2-D array as a heat map whose colour scale is symmetric about zero.

    Args:
        signal: 2-D array of values to display.
        title: Title placed above the plot.
    """
    figure, axis = plt.subplots(figsize = (20, 4))

    # Largest absolute value; used for both ends of the colour range so that
    # zero sits at the centre of the diverging colour map.
    peak = np.max(np.abs(signal))
    image = axis.matshow(
        signal,
        origin = "lower",
        aspect = "auto",
        cmap = plt.cm.seismic,
        vmin = -1 * peak,
        vmax = peak,
    )

    figure.colorbar(image)
    axis.set_title(title)
    plt.tight_layout()
    plt.show()


def plot_statistics_and_filter(mean_freq_noise,
                               std_freq_noise,
                               noise_thresh,
                               smoothing_filter):
    """Plot the noise statistics next to the mask-smoothing kernel.

    The left panel overlays the three per-frequency curves; the right panel
    shows the smoothing filter as an image with a colour bar.

    Args:
        mean_freq_noise: Per-frequency mean of the noise, plotted as a line.
        std_freq_noise: Per-frequency standard deviation of the noise.
        noise_thresh: Per-frequency threshold derived from the two above.
        smoothing_filter: 2-D kernel displayed with ``matshow``.
    """
    fig, ax = plt.subplots(ncols = 2, figsize = (20, 4))

    # The line handles are not needed: the legend is built from the labels.
    ax[0].plot(mean_freq_noise, label = "Mean power of noise")
    ax[0].plot(std_freq_noise, label = "Std. power of noise")
    ax[0].plot(noise_thresh, label = "Noise threshold (by frequency)")
    ax[0].set_title("Threshold for mask")
    ax[0].legend()

    cax = ax[1].matshow(smoothing_filter, origin = "lower")
    fig.colorbar(cax)
    ax[1].set_title("Filter for smoothing Mask")
    plt.show()

## Custom Noise Reduction Function Definition

In [None]:
def _build_smoothing_filter(n_grad_freq, n_grad_time):
    """Return the normalized 2-D triangular kernel used to soften mask edges."""

    def ramp(n):
        # Rises linearly to 1 and falls back again; the two zero endpoints
        # are trimmed off, leaving 2 * n + 1 taps.
        return np.concatenate(
            [
                np.linspace(0, 1, n + 1, endpoint = False),
                np.linspace(1, 0, n + 2),
            ]
        )[1:-1]

    kernel = np.outer(ramp(n_grad_freq), ramp(n_grad_time))
    return kernel / np.sum(kernel)  # taps sum to 1


def removeNoise(audio_clip,
                noise_clip,
                n_grad_freq = 2,
                n_grad_time = 4,
                n_fft = 2048,
                win_length = 2048,
                hop_length = 512,
                n_std_thresh = 1.5,
                prop_decrease = 1.0,
                verbose = False,
                visual = False):

    """ Removes noise from audio based upon a clip containing only noise

    A per-frequency threshold is estimated from the noise clip's spectrogram
    (mean + ``n_std_thresh`` standard deviations, in dB). Time/frequency bins
    of the signal that fall below that threshold are attenuated towards the
    quietest level found in the signal, using a smoothed mask, and the result
    is transformed back to the time domain with the signal's original phase.

    Args:
        audio_clip (array): Audio samples to be de-noised.
        noise_clip (array): Audio samples containing only the noise to remove.
        n_grad_freq (int): how many frequency channels to smooth over with the mask.
        n_grad_time (int): how many time channels to smooth over with the mask.
        n_fft (int): FFT size used for every STFT.
        win_length (int): length of the analysis window applied to each frame.
        hop_length (int): number of audio samples between STFT columns.
        n_std_thresh (float): how many standard deviations louder than the mean dB of the noise (at each frequency level) to be considered signal
        prop_decrease (float): To what extent should you decrease noise (1 = all, 0 = none)
        verbose (bool): Whether to display time statistics for the noise reduction process
        visual (bool): Whether to plot the steps of the algorithm

    Returns:
        array: The recovered signal with noise subtracted

    Note:
        The noise threshold and mask gain are always printed, regardless of
        ``verbose``.
    """

    # Debugging
    if verbose:
        start = time.time()

    # Takes a STFT over the noise sample
    noise_stft = _stft(noise_clip, n_fft, hop_length, win_length)
    noise_stft_db = _amp_to_db(np.abs(noise_stft))  # Converts the sample units to dB

    # Calculates per-frequency statistics over the noise sample
    mean_freq_noise = np.mean(noise_stft_db, axis = 1)
    std_freq_noise = np.std(noise_stft_db, axis = 1)
    noise_thresh = mean_freq_noise + std_freq_noise * n_std_thresh

    # Debugging
    if verbose:
        print("STFT on noise:", td(seconds = time.time() - start))
        start = time.time()

    # Takes a STFT over the signal sample
    sig_stft = _stft(audio_clip, n_fft, hop_length, win_length)
    sig_stft_db = _amp_to_db(np.abs(sig_stft))

    # Debugging
    if verbose:
        print("STFT on signal:", td(seconds = time.time() - start))
        start = time.time()

    # Value (in dB) that fully masked bins are pushed to: the quietest bin of
    # the signal. Reuses the spectrogram computed above instead of redoing it.
    mask_gain_dB = np.min(sig_stft_db)
    print("Noise Threshold & Mask Gain in dB: ", noise_thresh, mask_gain_dB)

    # Creates a smoothing filter for the mask in time and frequency
    smoothing_filter = _build_smoothing_filter(n_grad_freq, n_grad_time)

    # Marks a bin as noise when it is BELOW the threshold of its frequency row;
    # the threshold vector is broadcast across the time axis.
    sig_mask = sig_stft_db < noise_thresh[:, np.newaxis]

    # Debugging
    if verbose:
        print("Masking:", td(seconds = time.time() - start))
        start = time.time()

    # Convolves the mask with the smoothing filter (cast to float so the FFT
    # convolution never has to deal with a boolean array)
    sig_mask = scipy.signal.fftconvolve(sig_mask.astype(float), smoothing_filter, mode = "same")
    sig_mask = sig_mask * prop_decrease

    # Debugging
    if verbose:
        print("Mask convolution:", td(seconds = time.time() - start))
        start = time.time()

    # Blends every bin between its own level (mask = 0) and the mask gain (mask = 1)
    sig_stft_db_masked = sig_stft_db * (1 - sig_mask) + mask_gain_dB * sig_mask

    # Rebuilds the complex STFT from the masked magnitude and the ORIGINAL
    # phase. Combining np.sign() of the complex STFT with a separately masked
    # imaginary part does not preserve the magnitude (even where nothing is
    # masked), and np.sign() of complex input differs between NumPy versions.
    sig_stft_amp = _db_to_amp(sig_stft_db_masked) * np.exp(1j * np.angle(sig_stft))

    # Debugging
    if verbose:
        print("Mask application:", td(seconds = time.time() - start))
        start = time.time()

    # Recovers the signal
    recovered_signal = _istft(sig_stft_amp, hop_length, win_length)

    # Debugging
    if verbose:
        print("Signal recovery:", td(seconds = time.time() - start))

    # Visual Plotting
    if visual:
        # The spectrogram of the output is only needed for the final plot, so
        # the extra STFT is skipped when nothing is being drawn.
        recovered_spec = _amp_to_db(
            np.abs(_stft(recovered_signal, n_fft, hop_length, win_length))
        )
        plot_spectrogram(noise_stft_db, title = "Noise")
        plot_statistics_and_filter(mean_freq_noise, std_freq_noise, noise_thresh, smoothing_filter)
        plot_spectrogram(sig_stft_db, title = "Signal")
        plot_spectrogram(sig_mask, title = "Mask applied")
        plot_spectrogram(sig_stft_db_masked, title = "Masked signal")
        plot_spectrogram(recovered_spec, title = "Recovered spectrogram")

    # Returns noise-reduced audio sample
    return recovered_signal

## Reading in a WAV file

In [None]:
# Bare file name, so it is resolved against the current working directory.
test_wav = "True Positive Gunshot Sample.wav"
# wavfile.read returns the sample rate first, then the sample data.
rate, data = wavfile.read(test_wav)

## Inspecting Audio Sample Before Noise-Reduction

In [None]:
# Builds an audio player for the unprocessed recording; as the last expression
# of the cell it is what the notebook displays.
IPython.display.Audio(data = data, rate = rate)

## Finding a Noise Sample

In [None]:
# Hard-coded slice (samples 14000 up to, not including, 18000) chosen by hand
# for this particular recording; re-pick the indices for any other file.
noise_clip = data[14000:18000]  # Finding a clip with just noise

## Validating Noise Sample

In [None]:
# Plays back the selected slice so it can be checked by ear that it contains
# noise only.
IPython.display.Audio(data = noise_clip, rate = rate)

## Performing Noise Reduction

In [None]:
# Both clips are cast to float32 before being handed to removeNoise; timing
# output and the diagnostic plots are switched on.
output = removeNoise(audio_clip = np.float32(data), noise_clip = np.float32(noise_clip), verbose = True, visual = True)

## Testing Audio Sample After Noise-Reduction

In [None]:
# Builds an audio player for the noise-reduced result, at the original sample rate.
IPython.display.Audio(data = output, rate = rate)