# Tutorial: Source Separation using NMF

## Introduction
This tutorial introduces source separation with Non-negative Matrix Factorization (NMF). Source separation consists of isolating individual audio sources from a mixture signal. NMF is well suited to this task because it decomposes a signal into its constituent parts in an unsupervised fashion, that is, without labeled training data.

## Objective
By the end of this tutorial, you will understand how to apply NMF to separate sources from a mixed audio signal.

## Prerequisites
- Basic understanding of Python programming.
- Familiarity with linear algebra concepts.
- A Python environment with Jupyter Notebook installed.

## Conceptual Overview
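NMF decomposes a non-negative data matrix into the product of two lower-rank non-negative matrices, often interpreted as the features and the coefficients. For audio, the data matrix is typically a magnitude spectrogram: the columns of the first factor (W) act as frequency components, and the rows of the second factor (H) give their activations over time. Each component can then be read as an estimate of one source in the mixture.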
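
To make the factorization concrete, here is a minimal NumPy sketch of NMF with multiplicative updates for the Kullback-Leibler divergence. It is an illustration only, not the `nn_fac` implementation used later in this tutorial; the function names `kl_divergence` and `nmf_kl` and the toy matrix are made up for the example.

```python
import numpy as np


def kl_divergence(V, V_hat, eps=1e-12):
    """Generalized Kullback-Leibler divergence D(V || V_hat)."""
    return float(np.sum(V * np.log((V + eps) / (V_hat + eps)) - V + V_hat))


def nmf_kl(V, rank, n_iter=200, seed=0, eps=1e-12):
    """Approximate a non-negative matrix V (freq x time) as W @ H.

    Uses the classical multiplicative updates for the KL divergence.
    """
    rng = np.random.default_rng(seed)
    n_freq, n_time = V.shape
    W = rng.random((n_freq, rank)) + eps  # Frequency components
    H = rng.random((rank, n_time)) + eps  # Time activations
    errors = []
    for _ in range(n_iter):
        # Update H with W fixed
        V_hat = W @ H + eps
        H *= (W.T @ (V / V_hat)) / (W.sum(axis=0)[:, None] + eps)
        # Update W with H fixed
        V_hat = W @ H + eps
        W *= ((V / V_hat) @ H.T) / (H.sum(axis=1)[None, :] + eps)
        errors.append(kl_divergence(V, W @ H))
    return W, H, errors


# Toy "spectrogram" built from 4 known non-negative components
rng = np.random.default_rng(42)
V = rng.random((64, 4)) @ rng.random((4, 100))

W, H, errors = nmf_kl(V, rank=4)
print("W shape:", W.shape, "H shape:", H.shape)
print(f"KL divergence: first iteration {errors[0]:.4f}, last iteration {errors[-1]:.4f}")
```

Because each update only multiplies non-negative quantities, W and H stay non-negative throughout, and the divergence should shrink as the iterations proceed.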

# Let's dive into the code!
First, imports and parameters.

In [None]:
## Code to set up the virtual environment
# !pip install -r ../requirements.txt
# !pip install -e ..

## Download the AnuraSet dataset: visit https://zenodo.org/records/8342596
## Set up the path to the dataset here
anuraset_path = "/home/a23marmo/datasets/anuraset" # TO CHANGE

In [None]:
## Imports
from nmf_audio_benchmark.dataloaders.ecoacoustics.source_count_dataloader import AnuraSet # Loader for the AnuraSet dataset

import nmf_audio_benchmark.tasks.ecoacoustics.bss as bss # The scripts for evaluating the estimated_sources

import nmf_audio_benchmark.algorithms.nn_fac_algos as nn_fac_algos # Code for NMF

import base_audio.audio_helper as audio_helper # High-level functions to listen to audio signals from spectrograms

# Plotting functions
from librosa.display import specshow
import matplotlib.pyplot as plt

In [None]:
## Signal parameters
sr = 16000
hop_length = 2048
n_fft = hop_length * 2
feature = "stft_complex" # Complex STFT: returns the magnitude spectrogram together with the phase, which is reused for listening below. In our experiments, nn_log_mel (Nonnegative Log Mel Spectrogram, i.e. log(mel + 1)) was the best-performing feature.

# Loading the dataset
dataset = AnuraSet(audio_path=f"{anuraset_path}/raw_data/", subfolder="INCT17", annotations_file=f"{anuraset_path}/weak_labels.csv", 
                   feature=feature, sr=sr, hop_length=hop_length, n_fft=n_fft)

# Loading a specific file
file_name = "INCT17_20191113_040000.wav"
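
The signal parameters above determine the time and frequency resolution of the spectrogram. The short computation below works them out, assuming a standard one-sided STFT with centered framing; whether the dataloader follows exactly these conventions is an assumption, and the 60-second duration is a hypothetical value chosen for the example.

```python
# Signal parameters copied from the cell above
sr = 16000
hop_length = 2048
n_fft = hop_length * 2

# Time resolution
window_duration = n_fft / sr    # Seconds covered by one analysis window
hop_duration = hop_length / sr  # Seconds between two consecutive frames

# Frequency resolution (assumption: one-sided STFT)
n_freq_bins = n_fft // 2 + 1
freq_resolution = sr / n_fft    # Hz between two adjacent frequency bins

# Number of frames for a hypothetical 60-second recording
# (assumption: centered framing, i.e. the signal is padded on both sides)
duration_s = 60
n_frames = 1 + (duration_s * sr) // hop_length

print(f"Window duration: {window_duration} s, hop duration: {hop_duration} s")
print(f"Frequency bins: {n_freq_bins}, bin spacing: {freq_resolution} Hz")
print(f"Frames for {duration_s} s of audio: {n_frames}")
```

With these settings each frame spans about a quarter of a second, which favors frequency resolution over time resolution.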

Now, let's compute the spectrogram.

In [None]:
## Computing the spectrogram
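if feature == "stft_complex":
    (spec, phase), annotations = dataset.get_item_of_id(file_name) # Magnitude spectrogram and phase
else:
    spec, annotations = dataset.get_item_of_id(file_name)
    phase = None # No phase is returned for the other features; the "original_phase" calls below assume "stft_complex"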

## Listen and plot the spectrogram
print("Original spectrogram")
audio_helper.listen_to_this_spectrogram(spec, dataset.feature_object, phase_retrieval = "original_phase", original_phase = phase)

fig, ax = plt.subplots()
ax.set_title("Original spectrogram")
img = specshow(spec, sr=sr, hop_length=hop_length, y_axis="log", x_axis="time", vmax=10)
# Save this figure as a png, with transparent background
# plt.savefig("imgs/original_spectrogram.png", transparent = True)
plt.show()

Now, let's compute NMF.

In [None]:
# NMF parameters
n_nmf = 4 # Number of components. Set to 4 here; in our experiments it was fixed to 10, which corresponded to the maximum number of species in the annotations.
beta = 1 # The best performing beta in our experiments was 1

# NMF object
nmf = nn_fac_algos.nn_fac_NMF(n_nmf, beta=beta, init = "nndsvd", nmf_type="unconstrained", normalize=[False, True])

import time # Time the computation
start = time.time()

# Actually compute NMF
W, H = nmf.run(data=spec)

print(f"NMF computation done in {time.time() - start:.2f} seconds.")
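
The `beta` parameter selects which beta-divergence measures the mismatch between the spectrogram and its approximation `W @ H`. As a rough sketch of what the three most common values mean, the function below evaluates the divergence for any `beta`; `beta_divergence` is a hypothetical helper written for this example and is not part of `nmf_audio_benchmark` or `nn_fac`.

```python
import numpy as np


def beta_divergence(V, V_hat, beta, eps=1e-12):
    """Sum of the element-wise beta-divergence between V and V_hat."""
    V = np.asarray(V, dtype=float) + eps
    V_hat = np.asarray(V_hat, dtype=float) + eps
    if beta == 0:  # Itakura-Saito divergence
        ratio = V / V_hat
        return float(np.sum(ratio - np.log(ratio) - 1))
    if beta == 1:  # Generalized Kullback-Leibler divergence
        return float(np.sum(V * np.log(V / V_hat) - V + V_hat))
    # General case; beta = 2 gives half the squared Euclidean distance
    num = V**beta + (beta - 1) * V_hat**beta - beta * V * V_hat ** (beta - 1)
    return float(np.sum(num / (beta * (beta - 1))))


# Small toy matrices, chosen by hand for illustration
V = np.array([[1.0, 2.0], [3.0, 4.0]])
V_hat = np.array([[1.5, 2.0], [2.5, 4.5]])

for beta in (0, 1, 2):
    print(f"beta = {beta}: divergence = {beta_divergence(V, V_hat, beta):.4f}")
```

In all three cases the divergence is zero when the approximation is exact and positive otherwise; the choice of `beta` changes how errors on low-energy and high-energy bins are weighted.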

In [None]:
## Listen to the separated sources
bss_object = bss.BlindSourceSeparation(feature_object=dataset.feature_object, nb_sources=n_nmf, phase_retrieval="original_phase")
source_list = bss_object.qualitatively_evaluate_source_separation(W, H, time_limit=None, phase=phase)
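
To give an intuition of how `W` and `H` can be turned into one spectrogram per source, the sketch below treats each NMF component as one source and applies soft (Wiener-like) masks to the mixture. This is one common recipe, shown under that assumption; it is not necessarily what `qualitatively_evaluate_source_separation` does internally, and `separate_sources` and the toy data are hypothetical names and values used only here.

```python
import numpy as np


def separate_sources(spec, W, H, eps=1e-12):
    """Return one magnitude spectrogram per NMF component using soft masks."""
    V_hat = W @ H + eps  # Full NMF approximation of the mixture
    sources = []
    for k in range(W.shape[1]):
        component = np.outer(W[:, k], H[k, :])  # Rank-one contribution of component k
        mask = component / V_hat                # Share of each bin attributed to component k
        sources.append(mask * spec)             # Masked mixture spectrogram
    return sources


# Toy non-negative data standing in for a spectrogram and its factors
rng = np.random.default_rng(0)
spec_toy = rng.random((32, 50))
W_toy = rng.random((32, 4)) + 0.1
H_toy = rng.random((4, 50)) + 0.1

sources_toy = separate_sources(spec_toy, W_toy, H_toy)
print("Number of sources:", len(sources_toy))
print("Sources sum back to the mixture:", np.allclose(sum(sources_toy), spec_toy))
```

Since the masks sum to one in every time-frequency bin, the separated spectrograms add up to the original mixture. To listen to a source, each magnitude spectrogram still has to be combined with a phase, for instance the phase of the mixture.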

Plotting zone!

Below you will find the spectrograms of the NMF outputs and of the separated sources.

In [None]:
# Plot the spectrograms of the W and H matrices
fig, ax = plt.subplots(figsize=(0.5, 6))
img = specshow(W, sr=sr, hop_length=hop_length, y_axis="log", vmax=5)
ax.set_title("W matrix (frequency components)")
# plt.savefig("imgs/W.png", transparent = True)
plt.show()

fig, ax = plt.subplots(figsize=(10,0.5))
img = specshow(H[::-1], sr=sr, hop_length=hop_length, x_axis="time", vmax=5)
ax.set_title("H matrix (time activations)")
# plt.savefig("imgs/H.png", transparent = True)
plt.show()

In [None]:
## Plot the spectrogram of each source
for idx, a_source in enumerate(source_list):
    fig, ax = plt.subplots()
    img = specshow(a_source, sr=sr, hop_length=hop_length, y_axis="log", x_axis="time", vmax=10)
    ax.set_title(f"Source {idx}")
    # plt.savefig(f"imgs/source_{idx}.png", transparent = True)
    plt.show()