In [None]:
# Import libraries
import numpy as np
import matplotlib.pyplot as plt
import librosa
import IPython.display as ipd
import copy

import nn_fac.multilayer_nmf as mlnmf
import nn_fac.deep_nmf as dnmf
#from nn_fac.utils.current_plot import *

import base_audio.audio_helper as audio_helper
import base_audio.signal_to_spectrogram as signal_to_spectrogram

from tasks.mss import mss as source_separation_utils
from tasks.transcription import transcription as transcription_utils
import tasks.msa.plotting_utils as plotting_utils

In [None]:
# %% Audio params
# `sr` is the rate requested from librosa.load and handed to the feature
# extractor; `hop_length` and `n_fft` configure the STFT front-end
# (presumably expressed in samples — TODO confirm against FeatureObject).
sr = 44100
hop_length = 1024
feature_object = signal_to_spectrogram.FeatureObject(
    sr=sr,
    feature="stft",
    hop_length=hop_length,
    n_fft=2048,
)

# %% General params
eps = 1e-12
plotting = False  # Switch to True to have the data plotted

# %% Deep NMF params
all_ranks = [32, 16, 10]  # passed as `all_ranks` to both NMF routines (one entry per layer)
n_iter = 800  # total budget; also used as-is for each NMF of the multilayer run
n_iter_init_deep_nmf = 100  # iterations given to each NMF of the deep-NMF initialization
# Whatever remains of the total budget goes to the deep loop itself.
n_iter_deep = n_iter - n_iter_init_deep_nmf


In [None]:
# Read the recording as a mono waveform at the configured rate `sr`;
# the rate returned by librosa is not needed and is discarded.
audio_path = 'data/Drum+Bass.wav'
signal, _ = librosa.load(
    audio_path,
    sr=sr,
    mono=True,
)

# Feature matrix fed to the NMF routines in the following cells.
spectrogram = feature_object.get_spectrogram(signal)

In [None]:
# Helper objects for the two downstream tasks; both share the same feature configuration.
transcription_tool = transcription_utils.Transcription(feature_object)

# Per the original author's note: if the dimension of the last layer is larger
# than the number of sources, sources are clustered using the MFCC features of
# the columns of W.
source_separation_tool = source_separation_utils.MusicSourceSeparation(
    feature_object,
    nb_sources=2,
    phase_retrieval="griffin_lim",
)

In [None]:
# Multilayer beta-NMF (beta = 1) of the spectrogram, one NMF per entry of `all_ranks`.
# With return_errors=True the call also yields the errors and timings.
W_multi, H_multi, errors_multi, toc_multi = mlnmf.multilayer_beta_NMF(
    spectrogram,
    all_ranks=all_ranks,
    beta=1,
    n_iter_max_each_nmf=n_iter,
    return_errors=True,
)

In [None]:
# Transcription, using only the first layer's factors.
notes_predicted = transcription_tool.predict(W_multi[0], H_multi[0])
# A bare `len(...)` in the middle of a cell is never displayed, so print it explicitly.
print("Number of predicted notes (multilayer NMF):", len(notes_predicted))

# Source Separation
# Chain the H factors of every layer, H[-1] @ ... @ H[1] @ H[0], so that the
# result can be paired with the last layer's W.
last_level_H_multi = H_multi[0]
for H_level in H_multi[1:]:
    last_level_H_multi = H_level @ last_level_H_multi

source_separated_multi = source_separation_tool.predict(W_multi[-1], last_level_H_multi)
# Use a dedicated loop variable: reusing the name `signal` would overwrite the
# input waveform loaded earlier and break re-execution of the spectrogram cell.
# Playback uses the configured `sr` instead of a duplicated literal.
for separated_source in source_separated_multi:
    audio_helper.listen_to_this_signal(separated_source, sr=sr)

In [None]:
# Deep KL-NMF of the spectrogram: `n_iter_init_deep_nmf` iterations per NMF,
# then `n_iter_deep` iterations of the deep loop.
# With return_errors=True the call also yields the errors and timings.
W_deep, H_deep, errors_deep, toc_deep = dnmf.deep_KL_NMF(
    spectrogram,
    all_ranks=all_ranks,
    n_iter_max_each_nmf=n_iter_init_deep_nmf,
    n_iter_max_deep_loop=n_iter_deep,
    return_errors=True,
)


In [None]:
# Transcription, using only the first layer's factors.
notes_predicted_deep = transcription_tool.predict(W_deep[0], H_deep[0])
# A bare `len(...)` in the middle of a cell is never displayed, so print it explicitly.
print("Number of predicted notes (deep NMF):", len(notes_predicted_deep))

# Source Separation
# Chain the H factors of every layer, H[-1] @ ... @ H[1] @ H[0], so that the
# result can be paired with the last layer's W.
last_level_H_deep = H_deep[0]
for H_level in H_deep[1:]:
    last_level_H_deep = H_level @ last_level_H_deep

source_separated_deep = source_separation_tool.predict(W_deep[-1], last_level_H_deep)
# Use a dedicated loop variable: reusing the name `signal` would overwrite the
# input waveform loaded earlier and break re-execution of the spectrogram cell.
# Playback uses the configured `sr` instead of a duplicated literal.
for separated_source in source_separated_deep:
    audio_helper.listen_to_this_signal(separated_source, sr=sr)

In [None]:
# Evolution of the errors at the different layers of deep β-NMF with β = 1.
# The deep NMF above is configured with `n_iter_init_deep_nmf` iterations per
# NMF and `n_iter_deep` iterations of the deep loop.
# NOTE(review): the y-axis label says the curves are ratios against the
# multilayer error; that normalization is not visible in this notebook and is
# presumably done inside dnmf.deep_KL_NMF — confirm.
# The first column of `errors_deep` is skipped (presumably the value at
# initialization — TODO confirm).
layer_colors = ['blue', 'red', 'black']

plt.figure(1)
# One curve per row of `errors_deep`, so the plot follows the number of layers
# instead of assuming exactly three; colors cycle if there are more layers.
for layer_idx, layer_errors in enumerate(errors_deep):
    plt.plot(
        layer_errors[1:],
        color=layer_colors[layer_idx % len(layer_colors)],
        label='Layer {}'.format(layer_idx + 1),
    )
plt.xlabel('Iterations')
plt.ylabel('Ratio deep vs. multilayer')
plt.legend()
plt.show()

In [None]:
# Check the constraints: print the row sums of H for every layer
# (presumably each row is expected to sum to a fixed value — TODO confirm
# against the normalization used by deep_KL_NMF).
# Looping keeps the labels in sync with the layer index; the hand-written
# version labelled the third layer "Layer 2".
for layer_idx, H_layer in enumerate(H_deep):
    print("Layer {}:".format(layer_idx + 1), np.sum(H_layer, axis=1))
      

In [None]:
# Display the W factor of each of the three layers with plot_permuted_factor.
# NOTE(review): the titles and axis labels ("Q matrix", "bars", "Patterns (rows in H)")
# look inherited from another notebook — confirm they describe W_deep correctly.
layer_titles = (
    "Q matrix of the top layer.",
    "Q matrix of the intermediate layer.",  # spelling fixed (was "intemediate")
    "Q matrix of the bottom layer.",
)
for layer_idx, layer_title in enumerate(layer_titles):
    plotting_utils.plot_permuted_factor(
        W_deep[layer_idx],
        layer_title,
        x_axis='bars',
        y_axis='Patterns (rows in H)',
    )

In [None]:
# Show the transposed W factor of the first three layers, one figure per layer,
# using the plotting helper's default title and axis labels.
for layer_idx in range(3):
    transposed_factor = W_deep[layer_idx].T
    plotting_utils.plot_me_this_spectrogram(transposed_factor)