In [1]:

%load_ext autoreload
%autoreload 2

%matplotlib inline
import numpy as np
from matplotlib import pyplot as plt
from scipy.ndimage import \
    binary_opening, binary_closing, generate_binary_structure

from madarrays import Waveform
from ltfatpy import arg_firwin, gabwin, dgtreal

In [2]:
def db(x):
    """Convert amplitude(s) to decibels, i.e. ``20 * log10(|x|)``.

    Works element-wise on scalars and arrays, real or complex (the modulus
    is taken first). A zero magnitude maps to ``-inf`` and numpy emits a
    divide-by-zero warning.
    """
    magnitude = np.abs(x)
    return np.log10(magnitude) * 20

In [4]:
def get_signal_params(sig_len, fs):
    """Pack the signal length and sampling frequency into a parameter dict.

    Parameters
    ----------
    sig_len : int
        Signal length (``get_mix`` passes the number of samples).
    fs : number
        Sampling frequency of the signal.

    Returns
    -------
    dict
        ``{'sig_len': sig_len, 'fs': fs}``.
    """
    params = {'sig_len': sig_len, 'fs': fs}
    return params



In [None]:
def get_dataset(path_dir=None):
    """Index the ``.wav`` files of the source dataset by file stem.

    Parameters
    ----------
    path_dir : str or pathlib.Path, optional
        Root directory containing the ``wide_band_sources`` and
        ``localized_sources`` sub-directories. When omitted, the
        notebook-level global ``pathDir`` is used (previous behavior).

    Returns
    -------
    dict
        ``{'wideband': {stem: path, ...}, 'localized': {stem: path, ...}}``
        where each ``path`` is a ``pathlib.Path`` to a ``.wav`` file.
        Files are listed in sorted order so the result is deterministic.
        A missing sub-directory yields an empty inner dict.
    """
    from pathlib import Path

    if path_dir is None:
        # Backward-compatible fallback to the notebook global.
        path_dir = pathDir
    # Accept plain strings as well as Path objects.
    root = Path(path_dir)

    sub_dirs = {
        'wideband': 'wide_band_sources',
        'localized': 'localized_sources',
    }
    return {
        key: {wav.stem: wav for wav in sorted((root / sub).glob('*.wav'))}
        for key, sub in sub_dirs.items()
    }


In [None]:
def _smooth_mask(mask_raw, n_iter_closing, n_iter_opening, closing_first):
    """Regularize a boolean time-frequency mask with binary closing/opening.

    Each morphological operation is applied only when its own iteration
    count is positive. ``scipy.ndimage`` interprets ``iterations < 1`` as
    "repeat until the result no longer changes". For an opening with
    ``border_value=0`` that erodes the mask away entirely, so a zero count
    must mean "skip", not "pass 0 through".

    Parameters
    ----------
    mask_raw : ndarray of bool
        2-D mask to smooth.
    n_iter_closing, n_iter_opening : int
        Number of closing / opening iterations; ``<= 0`` skips the operation.
    closing_first : bool
        If True, close then open; otherwise open then close.

    Returns
    -------
    ndarray of bool
        The smoothed mask (``mask_raw`` itself when both are skipped).
    """
    # Cross-shaped (4-connectivity) 2-D structuring element.
    struct = generate_binary_structure(2, 1)

    def _close(m):
        if n_iter_closing <= 0:
            return m
        # border_value=1: outside the array counts as "in the mask".
        return binary_closing(input=m, structure=struct,
                              iterations=n_iter_closing, border_value=1)

    def _open(m):
        if n_iter_opening <= 0:
            return m
        # border_value=0: outside the array counts as "not in the mask".
        return binary_opening(input=m, structure=struct,
                              iterations=n_iter_opening, border_value=0)

    if closing_first:
        return _open(_close(mask_raw))
    return _close(_open(mask_raw))


def _finish_figure(title, path):
    """Title the current pyplot figure, tighten its layout and save it to ``path``."""
    # Title before tight_layout so the layout accounts for it.
    plt.title(title)
    plt.tight_layout()
    plt.savefig(path)


def get_mix(loc_source, wideband_src, crop=None,
            wb_to_loc_ratio_db=8, win_dur=256 / 8000, win_type='gauss',
            hop_ratio=1/4, n_bins_ratio=4, n_iter_closing=3,
            n_iter_opening=3, delta_mix_db=0, delta_loc_db=40,
            closing_first=True, or_mask=True,
            fig_dir=None, prefix=''):
    """Build a localized + wideband mixture and a TF mask of the localized source.

    Parameters
    ----------
    loc_source : str
        Key (file stem) in ``get_dataset()['localized']``.
    wideband_src : str
        Key (file stem) in ``get_dataset()['wideband']``.
    crop : int, optional
        If given, keep only ``crop`` samples centered in both signals.
        Must satisfy ``0 < crop <= signal length``.
    wb_to_loc_ratio_db : float
        Level of the wideband source relative to the localized one, in dB:
        after scaling, ``||x_wb|| / ||x_loc|| = 10 ** (wb_to_loc_ratio_db / 20)``
        and the two gains sum to 1.
    win_dur : float
        Approximate window duration in seconds. The window length in samples
        is ``win_dur * fs`` rounded to the nearest power of two.
    win_type : str
        Window type forwarded to ``get_dgt_params``.
    hop_ratio, n_bins_ratio : float
        Hop size and number of frequency bins as fractions/multiples of the
        window length.
    n_iter_closing, n_iter_opening : int
        Iterations of binary closing / opening used to smooth the mask.
        A value ``<= 0`` skips that operation.
    delta_mix_db : float
        A TF bin enters ``mask_mix`` when the localized level exceeds the
        wideband level by more than this margin (dB).
    delta_loc_db : float
        A TF bin enters ``mask_loc`` when the localized level is within this
        many dB of its own maximum.
    closing_first : bool
        Order of the morphological operations (closing then opening if True).
    or_mask : bool
        Combine ``mask_mix`` and ``mask_loc`` with OR (True) or AND (False).
    fig_dir : str or pathlib.Path, optional
        If given, diagnostic figures are saved there as PDF files.
    prefix : str
        Prepended (joined with ``'_'``) to every saved figure file name.

    Returns
    -------
    tuple
        ``(x_mix, dgt_params, signal_params, mask, x_loc, x_wb)`` where
        ``x_loc`` and ``x_wb`` are the scaled sources and ``x_mix`` their sum.

    Raises
    ------
    KeyError
        If a source name is not in the dataset.
    AssertionError
        If the two sources do not have the same shape.
    ValueError
        If the sampling rates differ or ``crop`` is out of range.
    """
    dataset = get_dataset()

    x_loc = Waveform.from_wavfile(dataset['localized'][loc_source])
    x_wb = Waveform.from_wavfile(dataset['wideband'][wideband_src])
    np.testing.assert_array_equal(x_loc.shape, x_wb.shape)
    if x_loc.fs != x_wb.fs:
        raise ValueError(
            'Sampling rates differ: {} (localized) vs {} (wideband)'.format(
                x_loc.fs, x_wb.fs))

    if crop is not None:
        full_len = x_loc.shape[0]
        # A crop longer than the signal would give a negative start index
        # and silently select the wrong samples.
        if not 0 < crop <= full_len:
            raise ValueError(
                'crop must be in (0, {}], got {}'.format(full_len, crop))
        # Keep a window of `crop` samples centered in the signal.
        i_start = (full_len - crop) // 2
        x_loc = x_loc[i_start:i_start + crop]
        x_wb = x_wb[i_start:i_start + crop]

    signal_params = get_signal_params(sig_len=x_loc.shape[0], fs=x_loc.fs)
    fs = signal_params['fs']
    sig_len = signal_params['sig_len']

    # Unit energy, then split a unit total gain between the two sources so
    # that the wideband/localized norm ratio is wb_to_loc_ratio_db in dB.
    x_loc /= np.linalg.norm(x_loc)
    x_wb /= np.linalg.norm(x_wb)
    gain_wb = 1 / (1 + 10 ** (-wb_to_loc_ratio_db / 20))
    x_loc *= (1 - gain_wb)
    x_wb *= gain_wb

    x_mix = x_loc + x_wb

    # DGT parameters derived from the window length (a power of two).
    approx_win_len = int(2 ** np.round(np.log2(win_dur * fs)))
    hop = int(approx_win_len * hop_ratio)
    n_bins = int(approx_win_len * n_bins_ratio)
    dgt_params = get_dgt_params(win_type=win_type,
                                approx_win_len=approx_win_len,
                                hop=hop, n_bins=n_bins, sig_len=sig_len)

    # The localized DGT is computed once and reused for the masked-loc figure.
    tf_mat_loc = dgt(x_loc, dgt_params=dgt_params)
    tf_mat_loc_db = db(tf_mat_loc)  # db() already takes the modulus
    tf_mat_wb_db = db(dgt(x_wb, dgt_params=dgt_params))

    # Raw mask: localized dominates the wideband source and/or is loud
    # relative to its own peak.
    mask_mix = tf_mat_loc_db > tf_mat_wb_db + delta_mix_db
    mask_loc = tf_mat_loc_db > tf_mat_loc_db.max() - delta_loc_db
    combine = np.logical_or if or_mask else np.logical_and
    mask_raw = combine(mask_mix, mask_loc)

    mask = _smooth_mask(mask_raw,
                        n_iter_closing=n_iter_closing,
                        n_iter_opening=n_iter_opening,
                        closing_first=closing_first)

    if fig_dir is not None:
        from pathlib import Path

        fig_dir = Path(fig_dir)
        fig_dir.mkdir(exist_ok=True, parents=True)
        file_prefix = (prefix + '_') if prefix else ''

        def _fig_path(name):
            return fig_dir / (file_prefix + name)

        tf_hop = dgt_params['hop']
        tf_n_bins = dgt_params['n_bins']

        plt.figure()
        plot_mask(mask=mask_mix, hop=tf_hop, n_bins=tf_n_bins, fs=fs)
        _finish_figure('Mask Mix - Area: {} ({:.1%})'.format(
            mask_mix.sum(), np.average(mask_mix)), _fig_path('mask_mix.pdf'))

        plt.figure()
        plot_mask(mask=mask_loc, hop=tf_hop, n_bins=tf_n_bins, fs=fs)
        _finish_figure('Mask Loc - Area: {} ({:.1%})'.format(
            mask_loc.sum(), np.average(mask_loc)), _fig_path('mask_loc.pdf'))

        plt.figure()
        plot_spectrogram(x=x_mix, dgt_params=dgt_params, fs=fs)
        _finish_figure('Mix', _fig_path('mix_spectrogram.pdf'))

        plt.figure()
        plot_mask(mask=mask_raw, hop=tf_hop, n_bins=tf_n_bins, fs=fs)
        _finish_figure('Raw mask', _fig_path('raw_mask.pdf'))

        plt.figure()
        plot_mask(mask=mask, hop=tf_hop, n_bins=tf_n_bins, fs=fs)
        _finish_figure('Smoothed mask', _fig_path('smoothed_mask.pdf'))

        plt.figure()
        plot_spectrogram(x=x_loc, dgt_params=dgt_params, fs=fs)
        _finish_figure('Loc', _fig_path('loc_source.pdf'))

        plt.figure()
        plotdgtreal(coef=tf_mat_loc * mask, a=tf_hop, M=tf_n_bins,
                    fs=fs, dynrange=100)
        _finish_figure('Masked loc', _fig_path('masked_loc.pdf'))

        plt.figure()
        # Keep only the TF region outside the mask, applied to the wideband source.
        gabmul = GaborMultiplier(mask=~mask,
                                 dgt_params=dgt_params,
                                 signal_params=signal_params)
        x_est = gabmul @ x_wb
        plot_spectrogram(x=x_est, dgt_params=dgt_params, fs=fs)
        _finish_figure('Filtered wb', _fig_path('zerofill_spectrogram.pdf'))

    return x_mix, dgt_params, signal_params, mask, x_loc, x_wb
