## Imports

In [1]:
import os
import re

import matplotlib.pyplot as plt
import librosa as lr
import numpy as np
import numpy
import math
import decimal
import soundfile as sf

from scipy import signal
from scipy.fftpack import dct

## Utility functions

In [3]:
# Global parameters shared by the filterbank, the MFCC features and the equalizer.
NUM_FILTER = 22
SAMPLING_RATE = 16_000 # number of samples per sec
FMIN=20 # hz
FMAX=8000 # hz
WINDOW_LENGTH = 0.020 # seconds (20 ms); multiplied by the sample rate to get samples per frame
HOP_LENGTH = 0.010 # seconds (10 ms); multiplied by the sample rate to get samples per hop

def get_band_filter_coeff(samplerate, f0, Q=1.0):
    """
    Un-normalised biquad band-pass coefficients for centre frequency ``f0``.

    Follows the bilinear-transform formulae from the "Cookbook formulae for
    audio EQ biquad filter coefficients":
    https://gist.github.com/RyanMarcus/d3386baa6b4cb1ac47f4#file-gistfile1-txt

    :param samplerate: sample rate in Hz.
    :param f0: centre frequency of the band in Hz.
    :param Q: quality factor of the band.
    :returns: ``(b, a)``, two length-3 arrays (numerator, denominator).
        ``a[0]`` is ``1 + alpha``, i.e. the pair is not yet normalised.
    """
    omega = 2 * np.pi * f0 / samplerate
    alpha = np.sin(omega) / (2 * Q)
    numerator = np.array([Q * alpha, 0, -Q * alpha], dtype=float)
    denominator = np.array([1 + alpha, -2 * np.cos(omega), 1 - alpha], dtype=float)
    return numerator, denominator

def iir_design_first_order(band_frequency, samplerate=SAMPLING_RATE, normalize=True): # band_frequency holds the centre frequency of each band
    """
    Build one biquad band-pass filter per centre frequency.

    :param band_frequency: iterable of band centre frequencies in Hz.
    :param samplerate: sample rate in Hz.
    :param normalize: when true, divide every coefficient by ``a[0]`` so that
        the returned denominator starts with 1.
    :returns: ``(b, a)``, two lists holding one length-3 coefficient array per band.
    """
    b = []
    a = []
    for centre in band_frequency:
        numerator, denominator = get_band_filter_coeff(samplerate, centre)
        if normalize:
            lead = denominator[0]           # captured before it is overwritten below
            numerator = numerator / lead
            denominator[1:] = denominator[1:] / lead
            denominator[0] = 1
        b.append(numerator)
        a.append(denominator)
    return b, a
    # Ref implementation:
    # b, a = set_gains(b_in, a_in, alpha, gains[0])
    # i = 0
    # g = 0
    # for n in range(2, len(x)):
    #     y[n] = b[0] * x[n] + b[1] * x[n - 1] + b[2] * x[n - 2] - a[1]* y[n - 1] - a[2] * y[n - 2]
    #     if (n % step == 0 and i < len(gains)-1):
    #         i += 1
    #         g = gains[i] * 0.4 + g*0.6
    #         b, a = set_gains(b_in, a_in, alpha, g)
    # return y

def generate_filter_header(b, a, order, filename='equalizer_coeff.h'):
    """
    Write the filter coefficients to a C header as ``#define`` macros.

    The header defines NUM_FILTER (``len(b)``), NUM_ORDER (``order``),
    NUM_COEFF_PAIR (``order*2+1``) and two brace-delimited, flattened
    coefficient lists FILTER_COEFF_A and FILTER_COEFF_B.

    :param b: numerator coefficients, one sequence per filter.
    :param a: denominator coefficients, one sequence per filter.
    :param order: filter order written to the header.
    :param filename: path of the header file to (over)write.
    """
    # Single-pass character mapping: drop line breaks and blanks, put a blank
    # after every comma and turn the numpy brackets into C braces.
    to_c_initializer = str.maketrans({
        "\n": None,
        "\r": None,
        " ": None,
        ",": ", ",
        "[": "{",
        "]": "}",
    })

    def array2str(data):
        flat = np.array(data).flatten()
        return np.array2string(flat, separator=',').translate(to_c_initializer)

    with open(filename, 'w') as file:
        file.write(f"\n#define NUM_FILTER {len(b)}\n")
        file.write(f"\n#define NUM_ORDER {order}\n")
        file.write(f"\n#define NUM_COEFF_PAIR {order*2+1}\n")
        file.write(f"\n#define FILTER_COEFF_A {array2str(a)}\n")
        file.write(f"\n#define FILTER_COEFF_B {array2str(b)}\n")

def iir_design(band_frequency, samplerate=SAMPLING_RATE, order=1): # band_frequency holds the centre frequency of each band
    """
    Design one IIR band-pass filter per interior entry of ``band_frequency``.

    Each pass band runs from the midpoint between a centre frequency and its
    lower neighbour to the midpoint between it and its upper neighbour, so the
    first and last entries only act as neighbours and get no filter of their own.

    :param band_frequency: numpy array of frequencies in Hz (it is divided by a scalar).
    :param samplerate: sample rate in Hz; frequencies are normalised by ``samplerate/2``.
    :param order: order passed to ``scipy.signal.iirfilter``.
    :returns: ``(b, a)``, two lists with ``len(band_frequency) - 2`` coefficient arrays each.
    """
    nyquist = samplerate/2
    fre = band_frequency / nyquist
    b = []
    a = []
    # slide a (lower neighbour, centre, upper neighbour) triple over the normalised grid
    for below, centre, above in zip(fre[:-2], fre[1:-1], fre[2:]):
        low_edge = centre - (centre-below)/2
        high_edge = centre+ (above-centre)/2
        numerator, denominator = signal.iirfilter(order, [low_edge, high_edge],
                                                  btype='bandpass', output='ba')
        b.append(numerator)
        a.append(denominator)
    return b, a

def fir_design(band_frequency, samplerate=SAMPLING_RATE, order=51):
    """
    Design one FIR band-pass filter per interior entry of ``band_frequency``.

    The pass band of each filter spans the midpoints towards the two
    neighbouring frequencies; the first and last entries get no filter.

    :param band_frequency: numpy array of frequencies in Hz (it is divided by a scalar).
    :param samplerate: sample rate in Hz; frequencies are normalised by ``samplerate/2``.
    :param order: value passed to ``scipy.signal.firwin`` as its first argument.
    :returns: list with ``len(band_frequency) - 2`` arrays of FIR taps.
    """
    from scipy import signal
    fre = band_frequency / (samplerate/2)
    taps = []
    for below, centre, above in zip(fre[:-2], fre[1:-1], fre[2:]):
        low_edge = centre - (centre-below)/2
        high_edge = centre+ (above-centre)/2
        taps.append(signal.firwin(order, [low_edge, high_edge], pass_zero='bandpass'))
    return taps

def get_mel_scale(nfilt=NUM_FILTER, samplerate=SAMPLING_RATE, lowfreq=FMIN, highfreq=FMAX):
    """
    Return ``nfilt + 2`` points evenly spaced on the mel scale.

    :param nfilt: number of filters; two extra points are added for the outer edges.
    :param samplerate: sample rate in Hz, used for the ``samplerate/2`` upper bound.
    :param lowfreq: lowest frequency in Hz.
    :param highfreq: highest frequency in Hz; a falsy value means ``samplerate/2``.
    :returns: numpy array of mel values (not converted back to Hz).
    """
    nyquist = samplerate / 2
    if not highfreq:
        highfreq = nyquist
    assert highfreq <= nyquist, "highfreq is greater than samplerate/2"
    # both ends are converted to mels first so the spacing is even in mels
    mel_low = lr.hz_to_mel(lowfreq)
    mel_high = lr.hz_to_mel(highfreq)
    return np.linspace(mel_low, mel_high, nfilt + 2)

def bandpass_filter_fir(sig, b_in, a_in, step, gains):
    """
    Filter ``sig`` with an FIR filter whose gain changes every ``step`` samples.

    The gain follows ``gains`` but may drop by at most a factor of 0.8 per chunk
    (``g = max(0.8 * g, gains[n])``), see the RNNoise paper
    https://arxiv.org/pdf/1709.08243.pdf . The filter state is carried from one
    chunk to the next.

    :param sig: input signal.
    :param b_in: numpy array of FIR taps.
    :param a_in: unused; present so the signature matches ``bandpass_filter_iir``.
    :param step: number of samples processed per gain value.
    :param gains: one gain per chunk.
    :returns: array of ``len(sig)`` samples; samples beyond ``len(gains) * step`` stay zero.
    """
    from scipy import signal
    total = len(sig)
    y = np.zeros(total)
    state = np.zeros(len(b_in)-1)
    g=0
    for n, target in enumerate(gains):
        g = max(0.8*g, target)
        start = n*step
        stop = min((n+1)*step, total)
        chunk, state = signal.lfilter(b_in * g, 1, sig[start:stop], zi=state)
        y[start:stop] = chunk
    return y

def bandpass_filter_iir(sig, b_in, a_in, step, gains):
    """
    Filter ``sig`` with an IIR filter whose gain changes every ``step`` samples.

    The numerator is scaled by a gain that follows ``gains`` but may drop by at
    most a factor of 0.6 per chunk (``g = max(0.6 * g, gains[n])``), see the
    RNNoise paper https://arxiv.org/pdf/1709.08243.pdf . The filter state is
    carried from one chunk to the next.

    :param sig: input signal.
    :param b_in: numpy array of numerator coefficients.
    :param a_in: denominator coefficients, used unchanged for every chunk.
    :param step: number of samples processed per gain value.
    :param gains: one gain per chunk.
    :returns: array of ``len(sig)`` samples; samples beyond ``len(gains) * step`` stay zero.
    """
    from scipy import signal
    total = len(sig)
    y = np.zeros(total)
    # state length is derived from the numerator only
    state = np.zeros(len(b_in)-1)
    g=0
    for n, target in enumerate(gains):
        g = max(0.6*g, target)
        start = n*step
        stop = min((n+1)*step, total)
        chunk, state = signal.lfilter(b_in*g, a_in, sig[start:stop], zi=state)
        y[start:stop] = chunk
    return y


def plot_frequency_respond(b, a=None, fs=SAMPLING_RATE):
    """
    Plot the magnitude response (in dB) of every filter in a bank on one figure.

    :param b: sequence of numerator coefficient arrays, one per filter.
    :param a: sequence of denominator coefficient arrays, one per filter. When it
        is ``None`` or its length differs from ``b`` a denominator of 1 is used
        for every filter (i.e. the bank is treated as FIR).
    :param fs: sample rate in Hz, used to label the frequency axis.
    """
    # The documented default a=None used to crash on len(None); handle it explicitly.
    if a is None or len(a) != len(b):
        a = np.ones(len(b))
    for numerator, denominator in zip(b, a):
        w, h = signal.freqz(numerator, denominator)
        # freqz returns w in rad/sample: divide by 2*pi and scale by fs to get Hz.
        # The magnitude is floored at 1e-5 (-100 dB) to avoid log10(0).
        plt.plot(w / (2 * np.pi) * fs, 20 * np.log10(np.maximum(abs(h), 1e-5)), 'b')
    plt.title('Digital filter frequency response')
    plt.ylabel('Amplitude [dB]', color='b')
    plt.xlabel('Frequency [Hz]')
    plt.show()

def noise_suppressed_example(example_number, snr, plot=False):
    """
    Demonstrate noise suppression with dynamic gains in an audio equalizer [EQ].

    The square root of the clean-to-noisy energy ratio of each frequency band is
    used as the suppression gain of that band. The gains are computed per frame
    by ``fbank`` and applied to a set of parallel band-pass filters, one gain
    value every ``int(HOP_LENGTH * SAMPLING_RATE)`` samples, so the noise is
    suppressed quickly once it is detected. Unlike the dataset-generation code,
    the gains are not clipped here.

    This is also the principle used to generate the truth gains for the training data (y_train).

    :param example_number: number of the clean/noisy file pair to load.
    :param snr: noise level in dB used to build the noisy file name.
    :param plot: when true, plot the first 10 gain columns and the filter responses.

    Side effects: prints progress information and writes ``_filtered_sample.wav``,
    ``_noisy_sample.wav`` and ``_clean_sample.wav`` to the working directory.
    """
    # change here to select the file and its noise mixing level.
    test_num = example_number          # which file
    test_noise_level = snr  # noise level in dB, selected from 0, 10, 20, depending on the dataset

    # file names follow the MS-SNSD layout, e.g. noisy6_SNRdb_20.0_clnsp6.wav
    clean_file = "MS-SNSD/CleanSpeech_training/clnsp" + str(test_num) + ".wav"
    noisy_file = "MS-SNSD/NoisySpeech_training/noisy"+str(test_num)+"_SNRdb_"+str(test_noise_level)+".0_clnsp"+str(test_num) +".wav"

    (clean_sig, rate) = lr.load(clean_file, sr=SAMPLING_RATE)
    (noisy_sig, rate) = lr.load(noisy_file, sr=SAMPLING_RATE)
    # clean_sig = clean_sig/32768
    # noisy_sig = noisy_sig/32768

    # Calculate the energy of each frequency band (no pre-emphasis, default rectangular window)
    clean_band_eng, _ = fbank(clean_sig, rate, winlen=WINDOW_LENGTH, winstep=HOP_LENGTH, nfilt=NUM_FILTER, nfft=512, lowfreq=FMIN, highfreq=FMAX, preemph=0)
    noisy_band_eng, _ = fbank(noisy_sig, rate, winlen=WINDOW_LENGTH, winstep=HOP_LENGTH, nfilt=NUM_FILTER, nfft=512, lowfreq=FMIN, highfreq=FMAX, preemph=0)
    # gains: amplitude ratio per frame and band, shape (frames, NUM_FILTER)
    gains = np.sqrt(clean_band_eng / noisy_band_eng)
    if(plot):
        plt.title("Gains")
        plt.plot(gains[:, :10])
        plt.show()

    # mel-spaced frequencies in Hz used to design the equalizer bands
    mel_freqs = lr.mel_frequencies(n_mels=NUM_FILTER, fmin=FMIN, fmax=FMAX)
    band_frequency = mel_freqs[1:-1] # the middle point of each band
    print('band frequency', band_frequency)

    # The noisy audio is now passed to a set of parallel band-pass filters,
    # which perform like an audio equalizer [EQ].
    # The difference is that the gain of each band changes very quickly so that the noise is suppressed while the speech is kept.
    # Design the band-pass filter for each band in the equalizer.
    # Because the frequency bands overlap, the summed signal is scaled down to avoid overflow when converting back to int16.

    print("denoising using IIR filter")
    b, a = iir_design(mel_freqs, SAMPLING_RATE)
    if plot:
        plot_frequency_respond(b, a)
    print("b", b)
    print("a", a)
    step = int(HOP_LENGTH * SAMPLING_RATE)
    print("audio process step:", step)
    filtered_signal = np.zeros(len(noisy_sig))
    # NOTE(review): iir_design returns len(mel_freqs) - 2 filters while gains has
    # NUM_FILTER columns, so only the first len(b) gain columns are used, and the
    # fbank band edges come from a different mel grid (NUM_FILTER + 2 points) than
    # mel_freqs (NUM_FILTER points) - confirm that gain column i really matches filter i.
    for i in range(len(b)):
        filtered_signal += bandpass_filter_iir(noisy_sig, b[i].copy(), a[i].copy(), step, gains[:, i])
        print("filtering with frequency: ", band_frequency[i])
    # attenuate the sum of the overlapping bands
    filtered_signal = filtered_signal * 0.6

    # keep the filtered signal inside [-1, 1] before the int16 conversion
    filtered_signal = np.clip(filtered_signal, -1, 1)
    sf.write("_filtered_sample.wav", np.asarray(filtered_signal * 32767, dtype=np.int16), SAMPLING_RATE)
    sf.write("_noisy_sample.wav", np.asarray(noisy_sig * 32767, dtype=np.int16), SAMPLING_RATE)
    sf.write("_clean_sample.wav", np.asarray(clean_sig * 32767, dtype=np.int16), SAMPLING_RATE)
    print("filtered signal is saved to:", "_filtered_sample.wav")
    print("noisy signal is saved to:", "_noisy_sample.wav")
    print("clean signal is saved to:", "_clean_sample.wav")


def generate_data(path, vad_active_delay=0.07, vad_threshold=1e-1, random_volume=True, winlen=WINDOW_LENGTH, winstep=HOP_LENGTH,
                  numcep=13, nfilt=NUM_FILTER, nfft=512, lowfreq=FMIN, highfreq=FMAX, winfunc=np.hanning, ceplifter=0,
                  preemph=0.97, appendEnergy=True):
    """
    Compute MFCC features, band energies and a voice-activity mask for every
    ``.wav`` file found directly inside ``path``.

    Files are processed in the order returned by ``os.listdir``, which is
    arbitrary; the returned lists are aligned with each other by index.

    :param path: directory containing the wav files.
    :param vad_active_delay: controls how long the voice-activity flag is extended.
        NOTE(review): the window length is computed as
        ``int(vad_active_delay * rate * winstep)`` frames, which is not
        ``vad_active_delay / winstep`` - confirm the intended unit.
    :param vad_threshold: total frame energy above which a frame counts as voice.
    :param random_volume: scale the signal by a random factor in [0.8, 1) before
        the MFCC computation (the band energies use the unscaled signal).
    :param winlen, winstep, numcep, nfilt, nfft, lowfreq, highfreq, winfunc,
        ceplifter, preemph, appendEnergy: forwarded to ``fbank`` and ``mfcc``.
    :returns: ``(mfcc_data, filename_label, total_energy, vad, band_energy)``,
        five lists with one entry per wav file.
    """
    mfcc_data = []
    filename_label = []
    total_energy = []
    band_energy = []
    vad = []
    for filename in os.listdir(path):
        # Match the extension, not any name that merely contains 'wav'.
        if not filename.endswith('.wav'):
            continue
        (sig, rate) = lr.load(os.path.join(path, filename), sr=SAMPLING_RATE)
        # convert file to [-1, 1)
        # sig = sig/32768

        # energy per band; this is one of the steps of the MFCC computation, taken out on its own
        band_eng, total_eng = fbank(sig, rate, winlen=winlen, winstep=winstep, nfilt=nfilt, nfft=nfft, lowfreq=lowfreq,
                                   highfreq=highfreq, preemph=preemph, winfunc=winfunc)

        # The MFCCs are not normalised, so the volume is randomised
        # to simulate real-life voice recordings.
        if random_volume:
            sig = sig * np.random.uniform(0.8, 1)

        # calculate mfcc features
        mfcc_feat = mfcc(sig, rate, winlen=winlen, winstep=winstep, numcep=numcep, nfilt=nfilt, nfft=nfft,
                         lowfreq=lowfreq, highfreq=highfreq, winfunc=winfunc, ceplifter=ceplifter, preemph=preemph,
                         appendEnergy=appendEnergy)

        # Voice activity detection, only valid with clean speech: total energy vs threshold.
        v = (total_eng > vad_threshold).astype(int)
        vad_delay = int(vad_active_delay*(rate*winstep))
        # One-sided window (zeros then ones): convolving with it keeps the flag
        # raised for vad_delay frames after the last active frame.
        conv_win = np.concatenate([np.zeros(vad_delay), np.ones(vad_delay)])
        v = np.convolve(v, conv_win, mode='same')
        v = (v > 0).astype(int)

        total_energy.append(total_eng)
        band_energy.append(band_eng)
        vad.append(v)
        mfcc_data.append(mfcc_feat.astype('float32'))
        filename_label.append(filename)
    return mfcc_data, filename_label, total_energy, vad, band_energy


def fbank(signal,samplerate=SAMPLING_RATE,winlen=WINDOW_LENGTH,winstep=HOP_LENGTH,
          nfilt=NUM_FILTER,nfft=512,lowfreq=FMIN,highfreq=FMAX,preemph=0.97,
          winfunc=lambda x:numpy.ones((x,))):
    """Compute Mel-filterbank energy features from an audio signal.

    :param signal: the audio signal from which to compute features.
    :param samplerate: the sample rate of the signal, in Hz. Default is SAMPLING_RATE.
    :param winlen: the length of the analysis window in seconds. Default is WINDOW_LENGTH.
    :param winstep: the step between successive windows in seconds. Default is HOP_LENGTH.
    :param nfilt: the number of filters in the filterbank. Default is NUM_FILTER.
    :param nfft: the FFT size. Default is 512.
    :param lowfreq: lowest band edge of the mel filters, in Hz. Default is FMIN.
    :param highfreq: highest band edge of the mel filters, in Hz. Default is FMAX;
        a falsy value means samplerate/2.
    :param preemph: pre-emphasis coefficient. 0 is no filter. Default is 0.97.
    :param winfunc: the analysis window applied to each frame. By default no window
        is applied (all ones). Numpy window functions can be used, e.g. winfunc=numpy.hamming
    :returns: 2 values. The first is an array of size (NUMFRAMES by nfilt) with one
        feature vector per row. The second is the total energy of each frame.
        Exact zeros in either are replaced by the float machine epsilon.
    """
    # substituted for exact zeros so that a later log() stays finite
    tiny = numpy.finfo(float).eps
    if not highfreq:
        highfreq = samplerate/2

    emphasised = preemphasis(signal, preemph)
    frames = framesig(emphasised, winlen*samplerate, winstep*samplerate, winfunc)
    power = powspec(frames, nfft)

    # total energy of each frame
    energy = numpy.sum(power,1)
    energy = numpy.where(energy == 0, tiny, energy)

    # project the power spectrum onto the mel filters
    filters = get_filterbanks(nfilt,nfft,samplerate,lowfreq,highfreq)
    feat = numpy.dot(power, filters.T)
    feat = numpy.where(feat == 0, tiny, feat)
    return feat,energy


def get_filterbanks(nfilt=NUM_FILTER,nfft=512,samplerate=SAMPLING_RATE,lowfreq=FMIN,highfreq=FMAX):
    """Compute a Mel-filterbank of triangular filters.

    The filters are stored in the rows, the columns correspond to FFT bins.

    :param nfilt: the number of filters in the filterbank. Default is NUM_FILTER.
    :param nfft: the FFT size. Default is 512.
    :param samplerate: the sample rate of the signal, in Hz. Default is SAMPLING_RATE.
    :param lowfreq: lowest band edge of the mel filters, in Hz. Default is FMIN.
    :param highfreq: highest band edge of the mel filters, in Hz. Default is FMAX;
        a falsy value means samplerate/2.
    :returns: A numpy array of size nfilt * (nfft//2 + 1). Each row holds 1 filter.
    """
    if not highfreq:
        highfreq = samplerate/2
    assert highfreq <= samplerate/2, "highfreq is greater than samplerate/2"

    # nfilt + 2 edge points, evenly spaced in mels
    mel_edges = numpy.linspace(lr.hz_to_mel(lowfreq), lr.hz_to_mel(highfreq), nfilt+2)
    # convert the edges from mel to Hz and then to (floored) FFT bin numbers
    edges = numpy.floor((nfft+1)*lr.mel_to_hz(mel_edges)/samplerate)

    bank = numpy.zeros([nfilt,nfft//2+1])
    for j in range(nfilt):
        left, centre, right = edges[j], edges[j+1], edges[j+2]
        # rising slope from the left edge up to the centre bin
        for i in range(int(left), int(centre)):
            bank[j,i] = (i - left) / (centre-left)
        # falling slope from the centre bin down to the right edge
        for i in range(int(centre), int(right)):
            bank[j,i] = (right-i) / (right-centre)
    return bank


def lifter(cepstra, L=22):
    """Apply a sinusoidal cepstral lifter to the matrix of cepstra.

    Column ``n`` is multiplied by ``1 + (L/2) * sin(pi * n / L)``.

    :param cepstra: the matrix of mel-cepstra, numframes * numcep in size.
    :param L: the liftering coefficient to use. Default is 22. L <= 0 disables the lifter.
    :returns: the liftered matrix, or ``cepstra`` itself when the lifter is disabled.
    """
    if not L > 0:
        # lifter disabled: hand the input back untouched
        return cepstra
    _, ncoeff = numpy.shape(cepstra)
    index = numpy.arange(ncoeff)
    weights = 1 + (L/2.)*numpy.sin(numpy.pi*index/L)
    return weights*cepstra


#signal processing
def powspec(frames, NFFT):
    """Compute the power spectrum of each frame in frames.

    :param frames: the array of frames. Each row is a frame.
    :param NFFT: the FFT length to use; it is passed on to ``magspec``.
    :returns: the squared magnitude spectrum of every frame, scaled by 1/NFFT.
        If frames is an NxD matrix, output will be Nx(NFFT/2+1).
    """
    scale = 1.0 / NFFT
    magnitudes = magspec(frames, NFFT)
    return scale * numpy.square(magnitudes)


def preemphasis(signal, coeff=0.95):
    """Apply a first-order pre-emphasis filter: y[n] = x[n] - coeff * x[n-1].

    :param signal: the signal to filter (indexable array with at least one sample).
    :param coeff: the pre-emphasis coefficient. 0 is no filter, default is 0.95.
    :returns: the filtered signal as a flat array; the first sample is passed through unchanged.
    """
    first = numpy.ravel(signal[0])
    differenced = numpy.ravel(signal[1:] - coeff * signal[:-1])
    return numpy.concatenate((first, differenced))


def framesig(sig, frame_len, frame_step, winfunc=lambda x: numpy.ones((x,)), stride_trick=True):
    """Split a signal into overlapping, windowed frames.

    The signal is zero-padded at the end so that the last frame is complete.

    :param sig: the audio signal to frame.
    :param frame_len: length of each frame in samples (rounded half up to an integer).
    :param frame_step: number of samples between the starts of consecutive frames
        (rounded half up to an integer).
    :param winfunc: the analysis window to apply to each frame. By default no window is applied.
    :param stride_trick: build the frames as a strided view instead of by fancy indexing.
    :returns: an array of frames. Size is NUMFRAMES by frame_len.
    """
    n_samples = len(sig)
    frame_len = int(round_half_up(frame_len))
    frame_step = int(round_half_up(frame_step))

    # a signal no longer than one frame still yields exactly one frame
    numframes = 1
    if n_samples > frame_len:
        numframes += int(math.ceil((1.0 * n_samples - frame_len) / frame_step))

    padded_len = int((numframes - 1) * frame_step + frame_len)
    padding = numpy.zeros((padded_len - n_samples,))
    padsignal = numpy.concatenate((sig, padding))

    if stride_trick:
        win = winfunc(frame_len)
        frames = rolling_window(padsignal, window=frame_len, step=frame_step)
        return frames * win

    # explicit index matrix: row r holds the sample positions of frame r
    within_frame = numpy.arange(0, frame_len)
    frame_starts = numpy.arange(0, numframes * frame_step, frame_step)
    indices = numpy.tile(within_frame, (numframes, 1)) + numpy.tile(frame_starts, (frame_len, 1)).T
    indices = numpy.array(indices, dtype=numpy.int32)
    frames = padsignal[indices]
    win = numpy.tile(winfunc(frame_len), (numframes, 1))
    return frames * win


def round_half_up(number):
    """Round ``number`` to the nearest integer, with ties going away from zero."""
    exact = decimal.Decimal(number)
    whole = exact.quantize(decimal.Decimal('1'), rounding=decimal.ROUND_HALF_UP)
    return int(whole)


def rolling_window(a, window, step=1):
    """Return every ``step``-th length-``window`` sliding window over the last axis of ``a``.

    The windows are built with ``as_strided``, see
    http://ellisvalentiner.com/post/2017-03-21-np-strides-trick
    """
    n_windows = a.shape[-1] - window + 1
    view_shape = a.shape[:-1] + (n_windows, window)
    # the added window axis advances by the same stride as the last axis
    view_strides = a.strides + (a.strides[-1],)
    windows = numpy.lib.stride_tricks.as_strided(a, shape=view_shape, strides=view_strides)
    return windows[::step]


def magspec(frames, NFFT):
    """Compute the magnitude spectrum of each frame in frames. If frames is an NxD matrix, output will be Nx(NFFT/2+1).

    :param frames: the array of frames. Each row is a frame.
    :param NFFT: the FFT length to use. If NFFT > frame_len, the frames are zero-padded;
        if NFFT < frame_len, the frames are truncated and a warning is logged.
    :returns: If frames is an NxD matrix, output will be Nx(NFFT/2+1). Each row will be the magnitude spectrum of the corresponding frame.
    """
    # Imported here because the module-level imports do not include logging;
    # without it the warning below raised NameError instead of warning.
    import logging

    frame_len = numpy.shape(frames)[1]
    if frame_len > NFFT:
        # logging.warn is a deprecated alias; logging.warning is the supported call
        logging.warning(
            'frame length (%d) is greater than FFT size (%d), frame will be truncated. Increase NFFT to avoid.',
            frame_len, NFFT)
    complex_spec = numpy.fft.rfft(frames, NFFT)
    return numpy.absolute(complex_spec)


def calculate_nfft(samplerate=SAMPLING_RATE, winlen=WINDOW_LENGTH):
    """Return the smallest power of two that is greater than or equal to the
    number of samples in one analysis window.

    An FFT shorter than the window drops samples and loses precision; a longer
    FFT only zero-pads the buffer, which is neutral for the frequency-domain
    conversion.

    :param samplerate: The sample rate of the signal we are working with, in Hz.
    :param winlen: The length of the analysis window in seconds.
    """
    samples_per_window = winlen * samplerate
    nfft = 1
    while nfft < samples_per_window:
        nfft <<= 1
    return nfft

def mfcc(signal,samplerate=SAMPLING_RATE,winlen=WINDOW_LENGTH,winstep=HOP_LENGTH,numcep=13,
         nfilt=NUM_FILTER,nfft=None,lowfreq=FMIN,highfreq=FMAX,preemph=0.97,ceplifter=22,appendEnergy=True,
         winfunc=lambda x:numpy.ones((x,))):
    """Compute MFCC features from an audio signal.

    :param signal: the audio signal from which to compute features.
    :param samplerate: the sample rate of the signal, in Hz. Default is SAMPLING_RATE.
    :param winlen: the length of the analysis window in seconds. Default is WINDOW_LENGTH.
    :param winstep: the step between successive windows in seconds. Default is HOP_LENGTH.
    :param numcep: the number of cepstral coefficients to return, default 13.
    :param nfilt: the number of filters in the filterbank. Default is NUM_FILTER.
    :param nfft: the FFT size. A falsy value (default None) lets calculate_nfft choose
        the smallest power of two that does not drop sample data.
    :param lowfreq: lowest band edge of the mel filters, in Hz. Default is FMIN.
    :param highfreq: highest band edge of the mel filters, in Hz. Default is FMAX.
    :param preemph: pre-emphasis coefficient. 0 is no filter. Default is 0.97.
    :param ceplifter: liftering coefficient applied to the final cepstra. 0 is no lifter. Default is 22.
    :param appendEnergy: if true, the zeroth cepstral coefficient is replaced with the log of the total frame energy.
    :param winfunc: the analysis window applied to each frame. By default no window
        is applied (all ones). Numpy window functions can be used, e.g. winfunc=numpy.hamming
    :returns: A numpy array of size (NUMFRAMES by numcep) containing features. Each row holds 1 feature vector.
    """
    if not nfft:
        nfft = calculate_nfft(samplerate, winlen)
    band_energy, frame_energy = fbank(signal,samplerate,winlen,winstep,nfilt,nfft,lowfreq,highfreq,preemph,winfunc)

    # log filterbank energies -> orthonormal DCT-II -> keep the first numcep coefficients
    log_bands = numpy.log(band_energy)
    cepstra = dct(log_bands, type=2, axis=1, norm='ortho')[:,:numcep]
    cepstra = lifter(cepstra, ceplifter)
    if appendEnergy:
        # replace the first cepstral coefficient with the log of the frame energy
        cepstra[:,0] = numpy.log(frame_energy)
    return cepstra

## Dataset generation

In [4]:
# This example writes 3 files: the noise-suppressed, the noisy and the clean speech sample.
# Open them with an audio player to get a feeling of what they sound like.
# It gives an idea of how this energy-based noise suppression works.
noise_suppressed_example(example_number=6, snr=20)

# Changing these changes the whole system, including the equalizer and the RNN:
# they set the number of filters in the equalizer, the number of MFCC features and the number of RNN outputs.
# choose from 10 ~ 30.
# num_filter = 20
# sampling_rate = 16_000
# fmin=20
# fmax=8000


# generate filter coefficients
mel_freqs = lr.mel_frequencies(n_mels=NUM_FILTER, fmin=FMIN, fmax=FMAX)
b, a = iir_design(mel_freqs, SAMPLING_RATE, order=1) # order > 2 will not be stable with only float32 accuracy in C.
# plot frequency response
#plot_frequency_respond(b, a)

print('Reading noisy and clean speech files...')
# dataset generation starts from here:
# energy threshold for voice activity detection in clean speech.
vad_energy_threashold = 0.1

noisy_speech_dir = 'MS-SNSD/NoisySpeech_training'
clean_speech_dir = 'MS-SNSD/CleanSpeech_training'
noise_dir = 'MS-SNSD/Noise_training'

# clean speech: MFCC, total energy, VAD mask and band energy
print('generating clean speech MFCC...')
clean_speech_mfcc, clean_file_label, total_energy, vad, clnsp_band_energy = \
    generate_data(clean_speech_dir, nfilt=NUM_FILTER, numcep=NUM_FILTER, appendEnergy=True, preemph=0, vad_threshold=vad_energy_threashold)

# noisy speech (clean speech already mixed with noise): MFCC and band energy
print('generating noisy speech MFCC...')
noisy_speech_mfcc, noisy_file_label, _, _ , noisy_band_energy= \
    generate_data(noisy_speech_dir, nfilt=NUM_FILTER, numcep=NUM_FILTER, appendEnergy=True, preemph=0, vad_threshold=vad_energy_threashold)

# MFCC for noise only (no random volume scaling)
print('generating noisy MFCC...')
noise_only_mfcc, noise_only_label, _, _ , noise_band_energy= \
    generate_data(noise_dir, random_volume=False, nfilt=NUM_FILTER, numcep=NUM_FILTER, appendEnergy=True, preemph=0)

# plt.plot(vad[5], label='voice active')
# plt.plot(total_energy[5], label='energy')
# plt.legend()
# plt.show()

# combine them together, one entry per noisy file
clnsp_mfcc = []
noisy_mfcc = []
noise_mfcc = []
voice_active = []
gains_array = []

print('Processing training data')
for idx_nosiy, label in enumerate(noisy_file_label):
    # get the file code from the file name e.g. "noisy614_SNRdb_30.0_clnsp614.wav"
    nums = re.findall(r'\d+', label)
    file_code = nums[0]
    db_code = nums[1]

    # index of the matching clean file; .index raises ValueError when it is missing
    idx_clnsp = clean_file_label.index('clnsp'+str(file_code)+'.wav')

    # truth gains y_train: amplitude ratio of clean to noisy band energy, clipped to [0, 1]
    gains = np.sqrt(clnsp_band_energy[idx_clnsp]/ noisy_band_energy[idx_nosiy])
    #gains = clnsp_band_energy[idx_clnsp] / noisy_band_energy[idx_nosiy]
    gains = np.clip(gains, 0, 1)

    # experimental, suppress the gains when there is no voice detected
    #gains[vad[idx_clnsp] < 1] = gains[vad[idx_clnsp] < 1] / 10
    # g = np.swapaxes(gains, 0, 1)
    # plt.imshow(g, interpolation='nearest', origin='lower', aspect='auto')
    # plt.show()

    # get all data needed
    voice_active.append(vad[idx_clnsp])
    clnsp_mfcc.append(clean_speech_mfcc[idx_clnsp])
    noisy_mfcc.append(noisy_speech_mfcc[idx_nosiy])
    # NOTE(review): assumes the noise files are listed in the same order as the
    # noisy speech files; both orders come from os.listdir in generate_data - confirm.
    noise_mfcc.append(noise_only_mfcc[idx_nosiy]) # noise has the same index as noisy speech
    gains_array.append(gains)

    #>>> Uncomment to plot the MFCC image
    # mfcc_feat1 = np.swapaxes(clean_speech_mfcc[idx_clnsp], 0, 1)
    # mfcc_feat2 = np.swapaxes(noisy_speech_mfcc[idx_nosiy], 0, 1)
    # fig, ax = plt.subplots(2)
    # ax[0].set_title('MFCC Audio:' + str(idx_clnsp))
    # ax[0].imshow(mfcc_feat1, origin='lower', aspect='auto', vmin=-8, vmax=8)
    # ax[1].imshow(mfcc_feat2, origin='lower', aspect='auto', vmin=-8, vmax=8)
    # plt.show()

# save the dataset.
# NOTE(review): each keyword is a Python list with one array per file; presumably the
# arrays differ in length between files, in which case recent NumPy releases may
# refuse to build the ragged array implicitly - confirm against the NumPy version in use.
np.savez("dataset_neighbor.npz", clnsp_mfcc=clnsp_mfcc, noisy_mfcc=noisy_mfcc, noise_mfcc=noise_mfcc, vad=voice_active, gains=gains_array)
print("Dataset generation has been saved to:", "dataset_neighbor.npz")

band frequency [ 162.68457293  305.36914585  448.05371878  590.73829171  733.42286463
  876.10743756 1019.56876841 1181.19546531 1368.44396424 1585.37595027
 1836.69698531 2127.85857846 2465.1764369  2855.96746258 3308.70846614
 3833.22004097 4440.87958576 5144.86809638 5960.45608037 6905.33479586]
denoising using IIR filter
b [array([ 0.02725948,  0.        , -0.02725948]), array([ 0.02725948,  0.        , -0.02725948]), array([ 0.02725948,  0.        , -0.02725948]), array([ 0.02725948,  0.        , -0.02725948]), array([ 0.02725948,  0.        , -0.02725948]), array([ 0.02733169,  0.        , -0.02733169]), array([ 0.02908936,  0.        , -0.02908936]), array([ 0.033129,  0.      , -0.033129]), array([ 0.03818517,  0.        , -0.03818517]), array([ 0.04397984,  0.        , -0.04397984]), array([ 0.05061055,  0.        , -0.05061055]), array([ 0.05818464,  0.        , -0.05818464]), array([ 0.06681939,  0.        , -0.06681939]), array([ 0.07664186,  0.        , -0.07664186]), arra