# Sparse NMF for Separation

One approach to using NMF for source separation is to learn, for each speaker in a dataset, a dictionary $W$ of $k$ basis vectors together with its loadings $H$. To separate a mixture where the speakers are known, we concatenate the dictionaries $W$ associated with them, learn new loadings $H$, and take the product of each speaker-specific $W$ with the associated components of $H$ to yield that speaker's reconstruction.

Training, for speaker $i$:
$$X_i = W_i H_i$$
$$W_i, H_i = \text{NMF}(X_i)$$

Evaluation, on a mixture $X_{ij}$ of speech from speakers $i$ and $j$, where $\text{NMF}_W$ performs NMF updates without updating the values in $W$:
$$W_{ij} = [ W_i \, W_j ]$$
$$H_{ij}' = \text{NMF}_{W_{ij}}(X_{ij})$$
$$H_{ij}' = \begin{bmatrix} H_i' \\ H_j' \end{bmatrix}$$
$$\hat{X}_i = W_i H_i'$$
$$\hat{X}_j = W_j H_j'$$

In [None]:
import sys
import time
from itertools import islice, permutations, product, chain
from collections import namedtuple
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import IPython.display as display

from magnolia.features.hdf5_iterator import Hdf5Iterator, SplitsIterator
from magnolia.features.mixer import FeatureMixer
from magnolia.features.wav_iterator import batcher
from magnolia.utils.tf_utils import scope_decorator as scope
from magnolia.utils.bss_eval import bss_eval_sources
from magnolia.factorization.nmf import snmf, nmf_separate
from magnolia.utils.postprocessing import reconstruct

# --- Experiment hyperparameters -------------------------------------------
num_srcs = 2  # not referenced again in the visible notebook; presumably sources per mixture
num_steps = 80  # used as `shape=(num_steps,)` for the "*_slice" iterators
num_freq_bins = 257  # not referenced again in the visible notebook
num_components = 20  # second positional argument to snmf() during training
sparsity = 0.1  # forwarded to snmf(sparsity=...)
num_train_exs = 50  # islice() limit: examples drawn per speaker during training
num_nmf_iters = 15  # forwarded to snmf(num_iters=...) for each training example
num_known_spkrs = 30  # half of these are taken from the female list, half from the male list
update_weight = 0.05  # forwarded to snmf(update_weight=...)
num_test_iters = 80  # NOTE: reassigned to 1 right before the out-of-sample test loop

# --- Dataset locations ------------------------------------------------------
# Shared-server copy of the dev set. NOTE: this assignment is dead, because
# `librispeech_dev` is rebound to the local path a few lines further down.
librispeech_dev = "/local_data/teams/magnolia/librispeech/processed_dev-clean.h5"
# librispeech_train = "/local_data/teams/magnolia/librispeech/processed_train-clean-100.h5"
# librispeech_test = "/local_data/teams/magnolia/librispeech/processed_test_clean.h5"

# Local (developer machine) copies; these are the paths actually in effect.
librispeech_dev = "/Users/patrickc/data/LibriSpeech/processed_dev-clean.h5"
librispeech_train = "/Users/patrickc/data/LibriSpeech/processed_train-clean-100.h5"
librispeech_test = "/Users/patrickc/data/LibriSpeech/processed_test_clean.h5"

# --- Output files (tab-separated, appended to by later cells) ----------------
train_metrics_path = "/Users/patrickc/src/magnolia/nmf-train-metrics.txt"
inset_results_path = "/Users/patrickc/src/magnolia/nmf-inset-results.txt"
outset_results_path = "/Users/patrickc/src/magnolia/nmf-outset-results.txt"

# Start every run from empty output files: opening in "w" mode truncates (or
# creates) each file, and nothing is written here.
for results_file in (train_metrics_path, inset_results_path, outset_results_path):
    with open(results_file, "w"):
        pass

def scale_spectrogram(spectrogram):
    """Split a complex spectrogram into a scaled magnitude and unwrapped phases.

    The magnitude is square-rooted and then min-max scaled over the whole
    array so that its values lie in [0, 1].

    Args:
        spectrogram: complex-valued array.

    Returns:
        A ``(scaled_magnitude, phases)`` tuple. ``phases`` is
        ``np.unwrap(np.angle(spectrogram))`` (unwrapped along the last axis,
        numpy's default). If every magnitude is identical, the scaled
        magnitude is all zeros rather than the NaNs a 0/0 division would give.

    Raises:
        ValueError: if ``spectrogram`` is empty (min/max of an empty array).
    """
    mag_spec = np.sqrt(np.abs(spectrogram))
    phases = np.unwrap(np.angle(spectrogram))

    low = mag_spec.min()
    span = mag_spec.max() - low
    if span == 0:
        # Constant-magnitude input: there is no range to scale by, and
        # dividing would fill the result with NaN.
        return np.zeros_like(mag_spec), phases

    return (mag_spec - low) / span, phases

def moving_average(a, n=3):
    """Return the length-``n`` sliding-window mean of ``a``.

    ``a`` is flattened by ``np.cumsum`` and accumulated as floats. The result
    has ``len(a) - n + 1`` entries, and is empty when ``a`` is shorter than ``n``.
    """
    # Prefix sums with a leading zero, so that window i is prefix[i + n] - prefix[i].
    prefix = np.concatenate(([0.0], np.cumsum(a, dtype=float)))
    window_sums = prefix[n:] - prefix[:-n]
    return window_sums / n

%matplotlib inline

## Data
### Get speaker-specific iterators

In [None]:
%pdb off
def _read_key_list(list_path):
    """Return the newline-separated entries of the text file at ``list_path``.

    Surrounding whitespace of the whole file is stripped before splitting.
    The entries are later passed as ``speaker_keys`` to the HDF5 iterators.
    """
    with open(list_path) as handle:
        return handle.read().strip().split('\n')


# One list per (dataset, gender) combination.
female_dev = _read_key_list("../../data/librispeech/authors/dev-clean-F.txt")
male_dev = _read_key_list("../../data/librispeech/authors/dev-clean-M.txt")
female_train = _read_key_list("../../data/librispeech/authors/train-clean-100-F.txt")
male_train = _read_key_list("../../data/librispeech/authors/train-clean-100-M.txt")
female_test = _read_key_list("../../data/librispeech/authors/test-clean-F.txt")
male_test = _read_key_list("../../data/librispeech/authors/test-clean-M.txt")

In [None]:
def _split_iterators(speaker_keys, shape):
    """Build one SplitsIterator over the training HDF5 file per speaker key.

    The split proportions were originally written ``[0.8, 0.1, 0.,1]``, a
    comma typo that produced four entries (0.8, 0.1, 0.0, 1). The intended
    three-way ``[0.8, 0.1, 0.1]`` is used here. A fresh list is created for
    every iterator so that no iterator shares it with another.
    NOTE(review): presumably train/validation/test proportions; only split 0
    is selected in this notebook. Confirm against SplitsIterator.
    """
    return [SplitsIterator([0.8, 0.1, 0.1], hdf5_path=librispeech_train,
                           speaker_keys=[key], shape=shape)
            for key in speaker_keys]


def _plain_iterators(hdf5_path, speaker_keys, shape):
    """Build one Hdf5Iterator over ``hdf5_path`` per speaker key."""
    return [Hdf5Iterator(hdf5_path=hdf5_path, speaker_keys=[key], shape=shape)
            for key in speaker_keys]


# Training set: `shape=(None,)` iterators and fixed `num_steps` slices.
female_spkrs = _split_iterators(female_train, (None,))
female_spkrs_slice = _split_iterators(female_train, (num_steps,))
male_spkrs = _split_iterators(male_train, (None,))
male_spkrs_slice = _split_iterators(male_train, (num_steps,))

# Dev set.
female_spkrs_dev = _plain_iterators(librispeech_dev, female_dev, (None,))
female_spkrs_dev_slice = _plain_iterators(librispeech_dev, female_dev, (num_steps,))
male_spkrs_dev = _plain_iterators(librispeech_dev, male_dev, (None,))
male_spkrs_dev_slice = _plain_iterators(librispeech_dev, male_dev, (num_steps,))

# Test set.
female_spkrs_test = _plain_iterators(librispeech_test, female_test, (None,))
female_spkrs_test_slice = _plain_iterators(librispeech_test, female_test, (num_steps,))
male_spkrs_test = _plain_iterators(librispeech_test, male_test, (None,))
male_spkrs_test_slice = _plain_iterators(librispeech_test, male_test, (num_steps,))

# Test-set slices with shape (150,), used by the out-of-sample conditions.
female_spkrs_test_slice150 = _plain_iterators(librispeech_test, female_test, (150,))
male_spkrs_test_slice150 = _plain_iterators(librispeech_test, male_test, (150,))

## Training

In [None]:
%pdb off
# Per-speaker results, in training order (female speakers first, then male).
spkr_models = []   # (w, h) as returned by the last successful snmf() call
errors = []        # concatenated `err` sequences returned by snmf()
train_times = []   # duration of the last successful snmf() call (NaN if none)

# Set split to train
for spkr in chain(female_spkrs, male_spkrs, female_spkrs_slice, male_spkrs_slice):
    spkr.set_split(0)

TrainRecord = namedtuple('TrainRecord', ['i', 'loss', 'time_delta', 'timestamp', 'batch_size', 'spkr'])
for i, spkr in enumerate(chain(female_spkrs_slice[:num_known_spkrs//2], male_spkrs_slice[:num_known_spkrs//2])):
    print("Speaker", i)
    w = None
    h = None
    spkr_errors = []
    # Reset per speaker so a skipped example can never leave this undefined
    # (NameError) or carrying over the previous speaker's timing.
    last_train_time = float('nan')

    for j, example in enumerate(islice(spkr, num_train_exs)):
        mag, phases = scale_spectrogram(example)
        train_start = time.time()
        try:
            # Warm start: the factors from the previous example are passed
            # back in as the initial values.
            w, h, w_err, h_size, err = snmf(mag.T, num_components, sparsity=sparsity,
                                            num_iters=num_nmf_iters, W_init=w, H_init=h, return_errors=True,
                                            update_weight=update_weight)
        except ValueError as e:
            print("ValueError encountered", file=sys.stderr)
            print(e, file=sys.stderr)
            # A ValueError whose text mentions "operands" is re-raised; any
            # other ValueError just skips this example.
            if "operands" not in repr(e):
                continue
            raise
        train_end = time.time()
        last_train_time = train_end - train_start

        spkr_errors.extend(err)
        # One row per successfully trained example, holding the last entry of
        # the error sequence returned for that example.
        train_metrics = TrainRecord(
            i*num_train_exs + j,
            err[-1],
            last_train_time,
            train_start,
            1,
            i
        )
        with open(train_metrics_path, "a") as f:
            print('\t'.join(map(str, train_metrics)), file=f)

    train_times.append(last_train_time)
    errors.append(spkr_errors)
    plt.figure(figsize=(6,1))
    plt.plot(moving_average(errors[-1],30))
    plt.show()
    # NOTE(review): (None, None) is stored if every example was skipped; it is
    # kept so model indices stay aligned with the speaker order.
    spkr_models.append((w,h))

    
#     plt.figure(figsize=(14,5))
#     plt.subplot(1,3,1)
#     plt.imshow((w @ h)[:,:100], cmap='bone', origin='lower', aspect=1/4)
#     plt.subplot(1,3,2)
#     plt.imshow(mag.T[:,:100], cmap='bone', origin='lower', aspect=1/4)
#     plt.subplot(1,3,3)
#     plt.imshow(w.T, cmap='bone', origin='lower', aspect=6)
#     plt.show()


### Split a file

In [None]:
from scipy.io import wavfile
from magnolia.utils.clustering_utils import preprocess_signal
# Hard-coded, machine-specific path of the mixture that the out-of-sample
# cell separates.
test_file = "/Users/patrickc/Downloads/mixed_signal.wav"
# scipy.io.wavfile.read returns (sample rate, sample array).
fs, wav = wavfile.read(test_file)
# `wav_spectrogram` is later passed to reconstruct() as `mix`; `x_in` is later
# used as `mix_scl_mag`, the input handed to nmf_separate().
# NOTE(review): exact contents/shapes depend on preprocess_signal — confirm there.
wav_spectrogram, x_in = preprocess_signal(wav, fs)



## Inference (in-set)

Inference re-estimates only the loadings matrix $H$ for a given $W$ and $X$. The resulting reconstructions are qualitatively quite poor unless they are used as masks on the original input, in which case the result is about what we expect (0-2 dB improvement).

## Out-of-sample test

The technique above only works when you know which set of basis vectors to select for each speaker. For unseen speakers this is hard. One approach is simply to pick the combination of trained dictionaries that minimizes the reconstruction error. Unfortunately, this search is quadratic in the number of dictionaries.

(Getting the cross-correlation of each basis with the mixture and picking the top two is another idea.)

(PCA on the $W$'s also makes a lot of sense.)

In [None]:
# Inference on out of set speakers
# spkrs_slices = {'mf': ((female_spkrs_test_slice150[:num_known_spkrs//2],
#                         male_spkrs_test_slice150[num_known_spkrs//2:num_known_spkrs]), 
#                         (spkr_models[:num_known_spkrs//2],
#                          spkr_models[num_known_spkrs//2:])),
#                 'mm': ((male_spkrs_test_slice150[:num_known_spkrs//2:],
#                        male_spkrs_test_slice150[:num_known_spkrs//2:]), 
#                        (spkr_models[num_known_spkrs//2:],
#                         spkr_models[num_known_spkrs//2:])),
#                 'ff': ((female_spkrs_test_slice150[num_known_spkrs//2:], 
#                         female_spkrs_test_slice150[num_known_spkrs//2:]),
#                       (spkr_models[:num_known_spkrs//2],
#                        spkr_models[:num_known_spkrs//2]))}

# Index separating the two halves of the speaker/model lists. Training
# enumerates female speakers first and male speakers second, so the first
# half of `spkr_models` is female and the second half is male.
_half = num_known_spkrs // 2


def _head(seq):
    """Return a fresh list holding the first `_half` items of `seq`."""
    return seq[:_half]


def _tail(seq):
    """Return a fresh list holding everything after the first `_half` items."""
    return seq[_half:]


# condition -> ((iterators for source i, iterators for source j),
#               (candidate models for source i, candidate models for source j))
# NOTE(review): 'ff' draws from the tail of the female test speakers, while
# every other condition draws from the head — confirm this is intended.
spkrs_slices = {
    'mf': (
        (_head(female_spkrs_test_slice150), _head(male_spkrs_test_slice150)),
        (_head(spkr_models), _tail(spkr_models)),
    ),
    'mm': (
        (_head(male_spkrs_test_slice150), _head(male_spkrs_test_slice150)),
        (_tail(spkr_models), _tail(spkr_models)),
    ),
    'ff': (
        (_tail(female_spkrs_test_slice150), _tail(female_spkrs_test_slice150)),
        (_head(spkr_models), _head(spkr_models)),
    ),
    'all': (
        (_head(female_spkrs_test_slice150) + _head(male_spkrs_test_slice150),
         _head(female_spkrs_test_slice150) + _head(male_spkrs_test_slice150)),
        (spkr_models, spkr_models),
    ),
}

# Exhaustive search for the pair of trained speaker models that best explains
# the loaded mixture. The per-condition loop over `spkrs_slices` and the
# BSS-eval bookkeeping are currently disabled (kept below as comments); the
# mixture comes from the wav file loaded earlier instead of test iterators.

# for condition, ((spkrs_slice_i, spkrs_slice_j), (models_i, models_j)) in spkrs_slices.items(): 
num_test_iters = 1

# Best candidate found so far: model indices, error, dictionaries, reconstructions.
NmfSearchResult = namedtuple("NMFSearchResult", ['i', 'j', 'error', 'models', 'reconstructions'])

for i in range(num_test_iters):
    print("Test {}".format(i))
#     spkr_i = np.random.randint(len(spkrs_slice_i))
#     spkr_j = np.random.randint(len(spkrs_slice_j))
#     if spkr_i == spkr_j:
#         continue

#         example_i = next(spkrs_slice_i[spkr_i])
#         example_j = next(spkrs_slice_j[spkr_j])

#         mix = example_i + example_j
#         mix_scl_mag, mix_scl_phs = scale_spectrogram(mix)

    # Both sources are searched over the full set of trained models.
    models_i = spkr_models
    models_j = spkr_models
    mix_scl_mag = x_in
    mix = wav_spectrogram

    # Try every ordered (model_i, model_j) combination, in the same order as
    # two nested index loops; the strict `<` keeps the first minimum found.
    optimal_pair = NmfSearchResult(0, 0, np.inf, [], [])
    search_time = 0
    for (model_i, cand_i), (model_j, cand_j) in product(enumerate(models_i), enumerate(models_j)):
        nmf_start = time.time()
        reco_i, reco_j = nmf_separate(mix_scl_mag.T, [cand_i, cand_j], mask=True)
        nmf_end = time.time()
        search_time += nmf_end - nmf_start

        # Mean absolute error between the mixture and the summed reconstructions.
        err = np.mean(np.abs(mix_scl_mag.T - (reco_i + reco_j)))
        if err < optimal_pair.error:
            optimal_pair = NmfSearchResult(model_i, model_j, err,
                                           [cand_i[0], cand_j[0]],
                                           [reco_i, reco_j])

    # Display results
    print("Used weight matrices {} and {}".format(optimal_pair.i, optimal_pair.j))

    # Left to right: mixture, reconstruction of source i, reconstruction of source j.
    panels = (mix_scl_mag.T, optimal_pair.reconstructions[0], optimal_pair.reconstructions[1])
    plt.figure(figsize=(14,5))
    for slot, panel in enumerate(panels, start=1):
        plt.subplot(1, 3, slot)
        plt.imshow(panel, cmap='bone', origin='lower', aspect=1/4)
    plt.show()

    # Evaluate
    opt_spec_i, opt_spec_j = optimal_pair.reconstructions
    mix_audio = reconstruct(mix, mix, 10000, None, 0.0256)
#     ref_i = reconstruct(example_i, example_i, 10000, None, 0.0256)
#     ref_j = reconstruct(example_j, example_j, 10000, None, 0.0256)
    # The reconstructions are transposed back and squared before being handed
    # to reconstruct() together with the original mixture spectrogram.
    opt_reco_i = reconstruct(opt_spec_i.T**2, mix, 10000, None, 0.0256)
    opt_reco_j = reconstruct(opt_spec_j.T**2, mix, 10000, None, 0.0256)

#     base_metrics =  bss_eval_sources(np.stack([ref_i, ref_j]), np.stack([mix_audio, mix_audio]))
#     predicted_metrics = bss_eval_sources(np.stack([ref_i, ref_j]), np.stack([opt_reco_i, opt_reco_j]))
#     base_metrics_mean = np.apply_along_axis(np.mean, 1, base_metrics)
#     predicted_metrics_mean = np.apply_along_axis(np.mean, 1, predicted_metrics)
#     ms_error = (np.sum(np.square(opt_spec_j - np.abs(example_j.T))) + 
#                np.sum(np.square(opt_spec_i - np.abs(example_i.T))))
#     avg_diff_metrics = [y-x for x, y in zip(base_metrics_mean[:3], predicted_metrics_mean[:3])]
#     test_record = TestRecord(
#         condition,
#         ms_error,
#         *avg_diff_metrics,
#         search_time,
#         nmf_end,
#         1,
#         "oos_{}_{}_{}".format(condition, spkr_i, spkr_j)
#     )

#     with open(outset_results_path, "a") as f:
#         print('\t'.join(map(str, test_record)), file=f)
            

In [None]:
from magnolia.features.data_preprocessing import undo_preemphasis
def _peak_normalize(signal):
    """Cast ``signal`` to float16 and scale it so its largest magnitude is 1.

    The peak is ``max(abs(signal))``. The previous inline expression used
    ``abs(max(signal))``, which under-estimates the peak whenever the largest
    excursion is negative. A signal whose peak is zero (or NaN) is returned
    unscaled instead of being divided by zero.
    """
    samples = signal.astype(np.float16)
    peak = np.abs(samples).max()
    if not peak > 0:
        return samples
    return samples / peak


# Play both separated sources. 10000 is the same rate value that was passed
# to reconstruct() when these signals were produced.
display.display(display.Audio(undo_preemphasis(_peak_normalize(opt_reco_i)), rate=10000))
display.display(display.Audio(undo_preemphasis(_peak_normalize(opt_reco_j)), rate=10000))

In [None]:
import pandas as pd
def _load_results(path, columns):
    """Read a header-less, tab-separated results file into a DataFrame.

    The results files are truncated at the start of the notebook and may never
    be written to, in which case ``pd.read_csv`` raises ``EmptyDataError``.
    An empty float-typed frame with the expected columns is returned instead,
    so the group-by summaries below yield an empty result rather than failing.
    """
    try:
        return pd.read_csv(path, sep='\t', header=None, names=columns)
    except pd.errors.EmptyDataError:
        return pd.DataFrame(columns=list(columns), dtype=float)


# NOTE(review): `TestRecord` is not defined anywhere in the visible notebook
# (its only other mention is in commented-out code); it must be defined
# before this cell can run.
ins_df = _load_results(inset_results_path, TestRecord._fields)
print(ins_df.groupby('condition').sdr.mean())

out_df = _load_results(outset_results_path, TestRecord._fields)
out_df.groupby('condition').sdr.mean()