## Simulation d'une situation réelle avec Pyroomacoustics

NOTE: 

Il y a plusieurs possibilités dans ce notebook: 

- Utilisation de différents algorithmes (celui de Louis ou celui de pyroomacoustics).
Pour utiliser l'un ou l'autre, il faut faire attention aux variables globales (définies avec des MAJUSCULES). 
- Utilisation de fichiers MuseDB ou de fichiers WAV.
- Possibilité de tracer les performances en fonction de l'avancée de l'algorithme (PLOT_PERFORMANCE), mais ATTENTION : cela n'est implémenté que pour l'algorithme de pyroomacoustics pour l'instant.
- Possibilité de sauvegarder les figures, les audios et les performances (SAVE_FIG, SAVE_AUDIO et SAVE_PERF) ; cela va créer les dossiers correspondants sous 'test'.

- Puis bien sûr, possibilité de modifier toutes les données concernant l'emplacement des microphones, dimensions de la salle etc... au fur et à mesure du notebook.

In [None]:
import IPython.display as ipd
from  IPython.display import display
import pickle
import librosa
import os
import math
import matplotlib.pyplot as plt
import stempeg
import itertools
import operator
import soundfile
import pyroomacoustics as pra
import wave
import random as rd
import numpy as np
from scipy.io import wavfile
from pyroomacoustics.directivities import (
    DirectivityPattern,
    DirectionVector,
    CardioidFamily,
)
from numpy import typing
from mir_eval.separation import bss_eval_sources

# Folder scanned recursively (os.walk) for the MUSDB18 .mp4 stem files
path_in = './data/musedb/'
# Root folder for every saved figure, audio file and metric CSV of this notebook
save_path = './test/surdeterm/'

In [None]:
from src import performance
from src.fast_nmf import fast_MNMF2

### Chargement d'un fichier MUSDB18

All files in the MUSDB18 dataset are in a multitrack format composed of 5 stereo streams, each one encoded in AAC @256kbps. These signals correspond to:

- `0` - The mixture,
- `1` - The drums,
- `2` - The bass,
- `3` - The rest of the accompaniment,
- `4` - The vocals.

```S.shape = (5, time_step , 2)```

In [None]:
# Collect every MUSDB18 stem file found under path_in.
# Paths and titles are sorted *together* so that files_title[i] stays the
# title of files_in[i]; sorting files_in alone broke that correspondence.
found_tracks = []
for r, d, f in os.walk(path_in):
    for file in f:
        if file.endswith('.mp4'):
            # (full path, "author - song" = file name without the
            # 9-character ".stem.mp4" suffix)
            found_tracks.append((os.path.join(r, file), file[:-9]))

found_tracks.sort()
files_in = [track_path for track_path, _ in found_tracks]
files_title = [title for _, title in found_tracks]

In [None]:
# Quick listening check: play the mixture stream (index 0), left channel,
# of the first three songs found.
for path in files_in[:3]:
    S, rate = stempeg.read_stems(path)
    display(ipd.Audio(S[0, :, 0], rate=rate))

### Test avec un fichier wav classique

In [None]:
# One entry per simulated source: the WAV files listed for that source are
# concatenated along time.
wav_files = [
    ['data/samples/BACH Cello Suite 1, Prelude, Violin - Kateryna Timokhina.wav'],
    ['data/samples/Bach_ Prélude, Cello suite Nr.1  Ophélie Gaillard.wav'],
    ['data/samples/Bach_ Prélude, Cello suite Nr.1  Ophélie Gaillard.wav'],
    ['data/samples/BACH Cello Suite 1, Prelude, Violin - Kateryna Timokhina.wav'],
]

# wavfile.read returns (rate, data); only the samples are kept, as float32.
# NOTE(review): the file sampling rate is discarded here; the custom-signal
# cell later assumes 44100 Hz — confirm it matches these files.
signals = [
    np.concatenate([wavfile.read(wav_path)[1].astype(np.float32) for wav_path in group])
    for group in wav_files
]

### Définition des variables globales

In [None]:
# Notebook-wide switches.
# NOTE: when PLOT_PERFORMANCE is True, the pyroomacoustics fastmnmf2 (with the
# metric callback) is run regardless of LOUIS.
PLOT_PERFORMANCE = False # plot the metrics vs. algorithm iterations => only implemented for the pyroomacoustics algorithm
MUSEDB = True # If True, use the MUSDB18 dataset; otherwise use the signals built from wav_files
LOUIS = True # If True, use Louis's algorithm (src.fast_nmf.fast_MNMF2); otherwise use the pyroomacoustics-based fastmnmf2
SAVE_FIG = True # If True, save every figure in the sub-folders of 'test'
SAVE_AUDIO = True # If True, save every audio file in the 'audios' sub-folder of 'test'
SAVE_PERF = True # If True, save the final metrics (SDR, SI-SDR, SIR, SAR) to a CSV file in save_path
    
# Microphone parameters: a single cardioid directivity (azimuth 180°,
# colatitude 60°), passed as mic_dir to the room simulation.
mic_pattern = DirectivityPattern.CARDIOID
MIC_DIR = CardioidFamily(
    orientation=DirectionVector(azimuth=180, colatitude=60, degrees=True),
    pattern_enum=mic_pattern,
)

In [None]:
# Create the output folders needed by the enabled SAVE_* switches.
# os.makedirs(..., exist_ok=True) creates missing parents and tolerates
# folders that already exist. The previous nested try/except stopped at the
# first existing folder, so re-runs skipped every sub-folder after it. It
# also silently failed to create save_path when only SAVE_AUDIO (or
# SAVE_PERF) was enabled.
os.makedirs('./test', exist_ok=True)

if SAVE_FIG or SAVE_AUDIO or SAVE_PERF:
    os.makedirs(save_path, exist_ok=True)

if SAVE_FIG:
    for figure_dir in ('activation', 'base', 'spectro', 'mix'):
        os.makedirs(save_path + figure_dir, exist_ok=True)

if SAVE_AUDIO:
    # One folder per microphone (MUSDB18 microphone names), with and
    # without separation.
    for audio_kind in ('separation', 'no_separation'):
        for mic_name in ('drums', 'bass', 'vocals', 'other', 'AB1', 'AB2'):
            os.makedirs(save_path + 'audios/' + audio_kind + '/micro_' + mic_name, exist_ok=True)

In [None]:
def spectrogram_from_musdb(
    room_dimension,
    abs_coef,
    source_locations,
    source_names,
    microphone_locations,
    microphone_names,
    mic_dir,
    song_path,
    audio_length,
    L=2048,
    hop=512,
    display_audio=False,
    display_room=False,
    ):
    """
    Simulate a MUSDB18 song recorded in a shoebox room and return its multichannel STFT.

    Each individual stem (streams 1.. of the file; stream 0 is the mixture) is
    down-mixed to mono and truncated to `audio_length` seconds. It is then
    peak-normalised and placed at the matching entry of `source_locations`.
    The room is simulated with pyroomacoustics, and the microphone signals are
    analysed with a Hamming-windowed STFT.

    Inputs:
    ------------------------------------------------------
    room_dimension:
        room dimensions (list / np.array of 3 values)
    abs_coef:
        absorption coefficient (float, e.g. 0.35)
    source_locations:
        positions of the sources in the room (list of lists), paired in
        order with stems 1, 2, ...
    source_names:
        names of the sources (list of str); not used by this function
    microphone_locations:
        positions of the microphones, one column per microphone
        (warning: built with `np.c_`, shape (3, n_mics))
    microphone_names:
        names of the microphones (list of str), used when display_audio is True
    mic_dir:
        microphone directivity (pyroomacoustics directivity object),
        applied to the whole array
    song_path:
        path to the MUSDB18 .mp4 stem file (str)
    audio_length:
        duration of audio kept from each stem, in seconds
    L:
        STFT frame size (2048 default)
    hop:
        STFT hop length (512 default)
    display_audio:
        bool, play each microphone signal
    display_room:
        bool, plot the room geometry

    Output:
    ------------------------------------------------------
    X: (n_frames, n_frequencies, n_channels) STFT of the microphone signals
    room: pyroomacoustics ShoeBox object
    separate_recordings: (n_sources, n_channels, n_samples) image of each
        source at each microphone (premix returned by room.simulate)
    mics_signals: (n_channels, n_samples) sum of separate_recordings over sources
    """
    data, rate = stempeg.read_stems(song_path)
    # Number of samples kept from each stem (audio_length is in seconds).
    n_samples = int(rate * audio_length)

    # Shoebox room: image sources up to order 15, small additive white noise.
    room = pra.ShoeBox(room_dimension, fs=rate, max_order=15, absorption=abs_coef, sigma2_awgn=1e-8)

    # Stream 0 is the mixture; streams 1.. are the sources.
    for stem_index, source_loc in zip(range(1, len(data)), source_locations):
        signal_channel = librosa.core.to_mono(data[stem_index, :n_samples, :].T)
        # Normalise by the *absolute* peak. np.max alone ignores larger
        # negative excursions (and flips an all-negative signal). A silent
        # stem would divide by zero and fill the simulation with NaN.
        peak = np.max(np.abs(signal_channel))
        if peak > 0:
            signal_channel = signal_channel / peak
        room.add_source(source_loc, signal=signal_channel)

    # Microphone array. Design notes (translated from the original comments):
    # spot microphones are meant to face their instrument (cardioid for now);
    # the pair is meant as ORTF (17 cm, 110°, 90° useful angle).
    # NOTE(review): a single mic_dir is applied to every microphone here.
    mic_array = pra.MicrophoneArray(microphone_locations, rate)
    room.add_microphone_array(mic_array, directivity=mic_dir)

    if display_room:
        fig, ax = room.plot()
        lim = np.max(room_dimension)
        ax.set_xlim([0, lim])
        ax.set_ylim([0, lim])
        ax.set_zlim([0, lim])
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')

    # Per-source microphone images; the observed signals are their sum.
    separate_recordings = room.simulate(return_premix=True)
    mics_signals = np.sum(separate_recordings, axis=0)

    # Observation in the STFT domain (Hamming analysis window).
    win_a = pra.hamming(L)
    X = pra.transform.stft.analysis(mics_signals.T, L, hop, win=win_a)

    if display_audio:
        for microphone_n in range(microphone_locations.shape[1]):
            print(f"Microphone {microphone_names[microphone_n]}")
            display(ipd.Audio(mics_signals[microphone_n], rate=room.fs))

    return X, room, separate_recordings, mics_signals

In [None]:
def spectrogram_from_wav(
    room_dimension,
    abs_coef,
    source_locations,
    source_names,
    microphone_locations,
    microphone_names,
    mic_dir,
    song_path,
    rate,
    audio_length,
    L=2048,
    hop=512,
    display_audio=False,
    display_room=False,
    ):
    """
    Simulate custom source signals (not MUSDB18) recorded in a shoebox room and return the multichannel STFT.

    Each source signal is down-mixed to mono and truncated to `audio_length`
    seconds. It is then peak-normalised and placed at the matching entry of
    `source_locations`. The room is simulated with pyroomacoustics, and the
    microphone signals are analysed with a Hamming-windowed STFT.

    Inputs:
    ------------------------------------------------------
    room_dimension:
        room dimensions (list / np.array of 3 values)
    abs_coef:
        absorption coefficient (float, e.g. 0.35)
    source_locations:
        positions of the sources in the room (list of lists)
    source_names:
        names of the sources (list of str); not used by this function
    microphone_locations:
        positions of the microphones, one column per microphone
        (warning: built with `np.c_`, shape (3, n_mics))
    microphone_names:
        names of the microphones (list of str), used when display_audio is True
    mic_dir:
        microphone directivity (pyroomacoustics directivity object),
        applied to the whole array
    song_path:
        despite its name, the list of source signals (float arrays), one
        per source. Each is either mono (n_samples,) or multichannel
        (n_samples, n_channels).
    rate:
        sampling rate of the signals, in Hz (int)
    audio_length:
        duration of audio kept from each source, in seconds
    L:
        STFT frame size (2048 default)
    hop:
        STFT hop length (512 default)
    display_audio:
        bool, play each microphone signal
    display_room:
        bool, plot the room geometry

    Output:
    ------------------------------------------------------
    X: (n_frames, n_frequencies, n_channels) STFT of the microphone signals
    room: pyroomacoustics ShoeBox object
    separate_recordings: (n_sources, n_channels, n_samples) image of each
        source at each microphone (premix returned by room.simulate)
    mics_signals: (n_channels, n_samples) sum of separate_recordings over sources
    """
    # Number of samples kept from each source (audio_length is in seconds).
    n_samples = int(rate * audio_length)

    # Shoebox room: image sources up to order 15, small additive white noise.
    room = pra.ShoeBox(room_dimension, fs=rate, max_order=15, absorption=abs_coef, sigma2_awgn=1e-8)

    for source, source_loc in zip(song_path, source_locations):
        # Slice along time, then transpose to (n_channels, n_samples) for
        # librosa. A 1-D mono signal is unaffected by .T and passes through
        # to_mono unchanged; the previous [:, :] slicing crashed on it.
        signal_channel = librosa.core.to_mono(np.asarray(source)[:n_samples].T)
        # Normalise by the *absolute* peak. np.max alone ignores larger
        # negative excursions, and a silent source would divide by zero.
        peak = np.max(np.abs(signal_channel))
        if peak > 0:
            signal_channel = signal_channel / peak
        room.add_source(source_loc, signal=signal_channel)

    # Microphone array. Design notes (translated from the original comments):
    # spot microphones are meant to face their instrument (cardioid for now);
    # the pair is meant as ORTF (17 cm, 110°, 90° useful angle).
    # NOTE(review): a single mic_dir is applied to every microphone here.
    mic_array = pra.MicrophoneArray(microphone_locations, rate)
    room.add_microphone_array(mic_array, directivity=mic_dir)

    if display_room:
        fig, ax = room.plot()
        lim = np.max(room_dimension)
        ax.set_xlim([0, lim])
        ax.set_ylim([0, lim])
        ax.set_zlim([0, lim])
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')

    # Per-source microphone images; the observed signals are their sum.
    separate_recordings = room.simulate(return_premix=True)
    mics_signals = np.sum(separate_recordings, axis=0)

    # Observation in the STFT domain (Hamming analysis window).
    win_a = pra.hamming(L)
    X = pra.transform.stft.analysis(mics_signals.T, L, hop, win=win_a)

    if display_audio:
        for microphone_n in range(microphone_locations.shape[1]):
            print(f"Microphone {microphone_names[microphone_n]}")
            display(ipd.Audio(mics_signals[microphone_n], rate=room.fs))

    return X, room, separate_recordings, mics_signals

In [None]:
if not MUSEDB:
    print("working with custom signals")

    # Room: dimensions and absorption coefficient
    room_dimension = [12, 20, 5]
    abs_coef = 0.35

    # Sources: names and positions in the room
    source_names = ["source1", "source2", "source3", "source4"]
    source_locations = [[2, 9, 1], [2, 11, 1], [3, 12, 1], [3, 7, 1]]

    # Microphones: names and positions in the room (one column per microphone)
    microphone_names = ["mic1", "mic2", "mic3", "mic4", "mic5", "mic6"]
    microphone_locations = np.c_[[2.7, 9, 1], [2.7, 11, 1], [3.5, 12, 1], [3.5, 7, 1], [6, 9.6, 1], [6, 10.4, 1]]

    # STFT: frame size and 75 % overlap
    L = 2048
    hop = L // 4

    # Audio: seconds kept per source and assumed sampling rate
    audio_length = 10
    rate = 44100

    X, room, separate_recordings, mics_signals = spectrogram_from_wav(
        room_dimension=room_dimension,
        abs_coef=abs_coef,
        source_locations=source_locations,
        source_names=source_names,
        microphone_locations=microphone_locations,
        microphone_names=microphone_names,
        mic_dir=MIC_DIR,
        song_path=signals,
        rate=rate,
        audio_length=audio_length,
        L=L,
        hop=hop,
        display_audio=True,
        display_room=True,
    )

In [None]:
if MUSEDB:
    print("working with MUSDB18")

    # Room: dimensions and absorption coefficient
    room_dimension = [12, 20, 5]
    abs_coef = 0.35

    # Sources: names and positions in the room
    # 1: Drums | 2: Bass | 3: Accompaniment | 4: Vocals
    source_names = ['drums', 'bass', 'other', 'vocals']
    source_locations = [[2, 9, 1], [2, 11, 1], [3, 12, 1], [3, 7, 1]]

    # Microphones: names and positions in the room
    # 1: Drums | 2: Bass | 3: Accompaniment | 4: Vocals | 5 and 6: AB pair (80 cm apart)
    microphone_names = ['drums', 'bass', 'other', 'vocals', 'AB1', 'AB2']
    microphone_locations = np.c_[[2.7, 9, 1], [2.7, 11, 1], [3.5, 12, 1], [3.5, 7, 1], [6, 9.6, 1], [6, 10.4, 1]]

    # Random song index in [0, len(files_in)). random.randint includes its
    # upper bound, so the previous randint(0, len(files_in)) could return an
    # out-of-range index.
    ind = rd.randrange(len(files_in))
    # NOTE: a fixed song (index 2) is used so experiments are reproducible;
    # use files_in[ind] to pick a random song instead.
    song_path = files_in[2]

    # STFT: frame size and 75 % overlap
    L = 4096
    hop = L // 4

    # Audio: seconds kept from each stem
    audio_length = 30

    # Simulate the room recording (sources + microphones) and compute the multichannel STFT
    X, room, separate_recordings, mics_signals = spectrogram_from_musdb(room_dimension, abs_coef, source_locations, source_names, microphone_locations, microphone_names, MIC_DIR, song_path, audio_length=audio_length, L=L, hop=hop, display_audio=True, display_room=True)

### Ajout de la Fast MNMF pour la séparation

In [None]:
"""
FastMNMF2 from pyroomacoustics
=========

Blind Source Separation using Fast Multichannel Nonnegative Matrix Factorization 2 (FastMNMF2)
"""
import numpy as np


def fastmnmf2(
    X,
    n_src=None,
    n_iter=30,
    n_components=32,
    mic_index=0,
    W0=None,
    accelerate=True,
    callback=None,
):
    """
    Implementation of FastMNMF2 algorithm presented in

    K. Sekiguchi, Y. Bando, A. A. Nugraha, K. Yoshii, T. Kawahara, *Fast Multichannel Nonnegative
    Matrix Factorization With Directivity-Aware Jointly-Diagonalizable Spatial
    Covariance Matrices for Blind Source Separation*, IEEE/ACM TASLP, 2020.
    [`IEEE <https://ieeexplore.ieee.org/abstract/document/9177266>`_]

    The code of FastMNMF2 with GPU support and more sophisticated initialization
    is available on  https://github.com/sekiguchi92/SoundSourceSeparation

    Adapted from pyroomacoustics. Unlike the library version, it also returns
    the estimated model parameters.

    Parameters
    ----------
    X: ndarray (nframes, nfrequencies, nchannels)
        STFT representation of the observed signal
    n_src: int, optional
        The number of sound sources (default None).
        If None, n_src is set to the number of microphones
    n_iter: int, optional
        The number of iterations (default 30)
    n_components: int, optional
        Number of components in the non-negative spectrum (default 32)
    mic_index: int or 'all', optional
        The index of microphone of which you want to get the source image (default 0).
        If 'all', return the source images of all microphones
    W0: ndarray (nfrequencies, nchannels, nchannels), optional
        Initial value for diagonalizer Q (default None).
        If None, identity matrices are used for all frequency bins.
        The array is copied (as X's complex dtype) and is never modified.
    accelerate: bool, optional
        If true, the basis and activation of NMF are updated simultaneously (default True).
        If false, the model is refreshed after the W update, so H is updated
        against the new basis.
    callback: func, optional
        Called with the current separated signals at iterations 0, 10, 20, ...
        Allows monitoring convergence.

    Returns
    -------
    (Y, W_NFK, H_NKT, Y_FTM, g_NM, Q_FMM)
        Y: separated signals. If mic_index is int, an (nframes, nfrequencies, nsources)
           array; if mic_index is 'all', an (nchannels, nframes, nfrequencies, nsources) array.
        W_NFK: (nsources, nfrequencies, n_components) NMF bases
        H_NKT: (nsources, n_components, nframes) NMF activations
        Y_FTM: (nfrequencies, nframes, nchannels) model of |Q_f x_ft|^2
        g_NM: (nsources, nchannels) diagonal elements of the spatial covariance matrices
        Q_FMM: (nfrequencies, nchannels, nchannels) joint diagonalizer

    Raises
    ------
    ValueError
        If mic_index is neither an int nor 'all'.
    """
    eps = 1e-10
    g_eps = 5e-2
    interval_update_Q = 1  # 2 may work as well and is faster
    interval_normalize = 10
    TYPE_FLOAT = X.real.dtype
    TYPE_COMPLEX = X.dtype

    # initialize parameter
    X_FTM = X.transpose(1, 0, 2)
    n_freq, n_frames, n_chan = X_FTM.shape
    XX_FTMM = np.matmul(X_FTM[:, :, :, None], X_FTM[:, :, None, :].conj())
    if n_src is None:
        n_src = X_FTM.shape[2]

    if W0 is not None:
        # Copy so the in-place normalisation/updates below never overwrite
        # the caller's array, and so complex updates are not truncated when
        # W0 is real-valued.
        Q_FMM = np.array(W0, dtype=TYPE_COMPLEX)
    else:
        Q_FMM = np.tile(np.eye(n_chan).astype(TYPE_COMPLEX), [n_freq, 1, 1])

    # Each microphone is initially dominated by one source (round robin).
    g_NM = np.ones([n_src, n_chan], dtype=TYPE_FLOAT) * g_eps
    for m in range(n_chan):
        g_NM[m % n_src, m] = 1

    for m in range(n_chan):
        mu_F = (Q_FMM[:, m] * Q_FMM[:, m].conj()).sum(axis=1).real
        Q_FMM[:, m] /= np.sqrt(mu_F[:, None])

    H_NKT = np.random.rand(n_src, n_components, n_frames).astype(TYPE_FLOAT)
    W_NFK = np.random.rand(n_src, n_freq, n_components).astype(TYPE_FLOAT)
    lambda_NFT = W_NFK @ H_NKT
    Qx_power_FTM = np.abs(np.einsum("fij, ftj -> fti", Q_FMM, X_FTM)) ** 2
    Y_FTM = np.einsum("nft, nm -> ftm", lambda_NFT, g_NM)

    def separate():
        # Multichannel Wiener filtering in the diagonalized domain, mapped
        # back to the microphone domain with Q^{-1}.
        Qx_FTM = np.einsum("fij, ftj -> fti", Q_FMM, X_FTM)
        Qinv_FMM = np.linalg.inv(Q_FMM)
        Y_NFTM = np.einsum("nft, nm -> nftm", lambda_NFT, g_NM)

        if mic_index == "all":
            return np.einsum(
                "fij, ftj, nftj -> itfn", Qinv_FMM, Qx_FTM / Y_NFTM.sum(axis=0), Y_NFTM
            )
        elif type(mic_index) is int:
            return np.einsum(
                "fj, ftj, nftj -> tfn",
                Qinv_FMM[:, mic_index],
                Qx_FTM / Y_NFTM.sum(axis=0),
                Y_NFTM,
            )
        else:
            raise ValueError("mic_index should be int or 'all'")

    # update parameters
    for epoch in range(n_iter):
        if callback is not None and epoch % 10 == 0:
            callback(separate())

        # update W and H (basis and activation of NMF)
        tmp1_NFT = np.einsum("nm, ftm -> nft", g_NM, Qx_power_FTM / (Y_FTM**2))
        tmp2_NFT = np.einsum("nm, ftm -> nft", g_NM, 1 / Y_FTM)

        numerator = np.einsum("nkt, nft -> nfk", H_NKT, tmp1_NFT)
        denominator = np.einsum("nkt, nft -> nfk", H_NKT, tmp2_NFT)
        W_NFK *= np.sqrt(numerator / denominator)

        if not accelerate:
            # Refresh the model with the new W *before* recomputing the
            # auxiliary terms, so H is updated against the new basis.
            # Recomputing tmp1/tmp2 first (as before) reproduced exactly the
            # values above, which made accelerate=False identical to True.
            lambda_NFT = W_NFK @ H_NKT + eps
            Y_FTM = np.einsum("nft, nm -> ftm", lambda_NFT, g_NM) + eps
            tmp1_NFT = np.einsum("nm, ftm -> nft", g_NM, Qx_power_FTM / (Y_FTM**2))
            tmp2_NFT = np.einsum("nm, ftm -> nft", g_NM, 1 / Y_FTM)

        numerator = np.einsum("nfk, nft -> nkt", W_NFK, tmp1_NFT)
        denominator = np.einsum("nfk, nft -> nkt", W_NFK, tmp2_NFT)
        H_NKT *= np.sqrt(numerator / denominator)

        lambda_NFT = W_NFK @ H_NKT + eps
        Y_FTM = np.einsum("nft, nm -> ftm", lambda_NFT, g_NM) + eps

        # update g_NM (diagonal element of spatial covariance matrices)
        numerator = np.einsum("nft, ftm -> nm", lambda_NFT, Qx_power_FTM / (Y_FTM**2))
        denominator = np.einsum("nft, ftm -> nm", lambda_NFT, 1 / Y_FTM)
        g_NM *= np.sqrt(numerator / denominator)
        Y_FTM = np.einsum("nft, nm -> ftm", lambda_NFT, g_NM) + eps

        # update Q (joint diagonalizer), one row per channel
        if (interval_update_Q <= 0) or (epoch % interval_update_Q == 0):
            for m in range(n_chan):
                V_FMM = (
                    np.einsum("ftij, ft -> fij", XX_FTMM, 1 / Y_FTM[..., m]) / n_frames
                )
                tmp_FM = np.linalg.solve(
                    np.matmul(Q_FMM, V_FMM), np.eye(n_chan)[None, m]
                )
                Q_FMM[:, m] = (
                    tmp_FM
                    / np.sqrt(
                        np.einsum("fi, fij, fj -> f", tmp_FM.conj(), V_FMM, tmp_FM)
                    )[:, None]
                ).conj()
            # |Q x|^2 is not used inside the row loop, so computing it once
            # from the final Q gives the same result as recomputing it per row.
            Qx_power_FTM = np.abs(np.einsum("fij, ftj -> fti", Q_FMM, X_FTM)) ** 2

        # normalize (scale ambiguities between Q, g, W and H)
        if (interval_normalize <= 0) or (epoch % interval_normalize == 0):
            phi_F = np.einsum("fij, fij -> f", Q_FMM, Q_FMM.conj()).real / n_chan
            Q_FMM /= np.sqrt(phi_F)[:, None, None]
            W_NFK /= phi_F[None, :, None]

            mu_N = g_NM.sum(axis=1)
            g_NM /= mu_N[:, None]
            W_NFK *= mu_N[:, None, None]

            nu_NK = W_NFK.sum(axis=1)
            W_NFK /= nu_NK[:, None]
            H_NKT *= nu_NK[:, :, None]

            lambda_NFT = W_NFK @ H_NKT + eps
            Qx_power_FTM = np.abs(np.einsum("fij, ftj -> fti", Q_FMM, X_FTM)) ** 2
            Y_FTM = np.einsum("nft, nm -> ftm", lambda_NFT, g_NM) + eps

    return separate(), W_NFK, H_NKT, Y_FTM, g_NM, Q_FMM

In [None]:
# separate_recordings has shape (n_sources, n_mics, n_samples).
# Metric history filled by convergence_callback_micro: one entry per callback
# call, each an object array holding one result per microphone.
SDR, SIR, SAR, PERM = [], [], [], []

# Synthesis window matching the Hamming analysis window used for X.
win_a = pra.hamming(L)
win_s = pra.transform.stft.compute_synthesis_window(win_a, hop)


def convergence_callback_micro(Y):
    """Record BSS-eval metrics of the current estimate at every microphone.

    Y is the output of fastmnmf2's separate() with mic_index='all': Y[mic] is
    the STFT of the separated sources at microphone `mic`. Each estimate is
    synthesised back to the time domain. The estimate and the reference
    images separate_recordings[:, mic, :] are trimmed to their common length
    and scored with mir_eval's bss_eval_sources. Results are appended to the
    global SDR, SIR, SAR and PERM lists.
    """
    global SDR, SIR, SAR, PERM
    n_mics = len(Y)
    sdr, sir, sar, perm = (np.zeros(n_mics, dtype=object) for _ in range(4))

    for mic in range(n_mics):
        estimate = pra.transform.stft.synthesis(Y[mic], L, hop, win=win_s)
        estimate = estimate[L - hop:, :].T
        reference = separate_recordings[:, mic, :]
        length = np.minimum(estimate.shape[1], reference.shape[1])
        sdr[mic], sir[mic], sar[mic], perm[mic] = bss_eval_sources(
            reference[:, :length], estimate[:, :length]
        )

    for history, values in ((SDR, sdr), (SIR, sir), (SAR, sar), (PERM, perm)):
        history.append(values)

In [None]:
# Algorithm parameters
n_basis = 32
n_iter = 100

if PLOT_PERFORMANCE:
    # pyroomacoustics-based implementation, tracking metrics every 10 iterations
    print("Running fastmnmf2 with callback function")
    Y, W_NFK, H_NKT, Y_FTM, g_NM, Q_FMM = fastmnmf2(
        X,
        n_src=len(source_names),
        n_iter=n_iter,
        n_components=n_basis,
        mic_index='all',
        W0=None,
        accelerate=True,
        callback=convergence_callback_micro,
    )
elif LOUIS:
    print("Running fastmnmf2 with Louis custom algotithm")
    # X is passed transposed to (freq, frames, mics); the output axes are then
    # reordered for the evaluation cells.
    # NOTE(review): layouts assumed from these transposes — confirm against src.fast_nmf.
    Y, W_NFK, H_NKT, g_NM, Q_FMM, Qx_FTM, X_tilde_FTM, Y_tilde_FTM = fast_MNMF2(
        X.transpose(1, 0, 2),
        n_iter=n_iter,
        n_microphones=len(microphone_names),
        n_sources=len(source_names),
        n_time_frames=X.shape[0],
        n_freq_bins=X.shape[1],
        n_basis=n_basis,
        algo='IP',
        mic_index=None,
    )
    Y = Y.transpose(0, 3, 2, 1)
else:
    # pyroomacoustics-based implementation without metric tracking
    print("Running fastmnmf2 without callback function")
    Y, W_NFK, H_NKT, Y_FTM, g_NM, Q_FMM = fastmnmf2(
        X,
        n_src=len(source_names),
        n_iter=n_iter,
        n_components=n_basis,
        mic_index='all',
        W0=None,
        accelerate=True,
    )

In [None]:
# Synthesis window matching the Hamming analysis window used for X
win_a = pra.hamming(L)
win_s = pra.transform.stft.compute_synthesis_window(win_a, hop)

# One entry per microphone: separated signals, reference images, common
# length, and the metrics computed on that common length.
y, ref, m = [], [], []
sdr, si_sdr, sir, sar, perm = [], [], [], [], []

for i in range(len(Y)):
    # Back to the time domain, dropping the first L - hop samples;
    # transposed to (n_sources, n_samples).
    signal_ = pra.transform.stft.synthesis(Y[i], L, hop, win=win_s)[L - hop:, :].T
    y.append(signal_)

    # Reference source images at microphone i: (n_sources, n_samples)
    ref_ = separate_recordings[:, i, :]
    ref.append(ref_)

    m_ = np.minimum(signal_.shape[1], ref_.shape[1])
    m.append(m_)

    reference_cut, estimate_cut = ref_[:, :m_], signal_[:, :m_]
    sdr_, sir_, sar_, perm_ = bss_eval_sources(reference_cut, estimate_cut)
    si_sdr_ = performance.si_sdr(reference_cut, estimate_cut)

    for store, value in ((sdr, sdr_), (si_sdr, si_sdr_), (sir, sir_), (sar, sar_), (perm, perm_)):
        store.append(value)

In [None]:
# Final metrics averaged over every microphone and source:
# (printed label, key used in the saved perf array, per-microphone values)
metric_summary = (
    ("SDR final : ", "sdr", sdr),
    ("SI_SDR final : ", "si_sdr", si_sdr),
    ("SIR final : ", "sir", sir),
    ("SAR final : ", "sar", sar),
)
for label, _, values in metric_summary:
    print(label, np.mean(values))

perf_mean = np.array([{key + "_mean": np.mean(values)} for _, key, values in metric_summary], dtype=object)
perf_final = np.array([{key + "_final": values} for _, key, values in metric_summary], dtype=object)

perf = np.concatenate((perf_mean, perf_final))

In [None]:
if SAVE_PERF:
    # Song title = file name without its 9-character ".stem.mp4" suffix.
    # The previous song_path.split("/")[2] picked the "musedb" folder of
    # './data/musedb/...', which yielded an empty title for every song.
    song_title = os.path.basename(song_path)[:-9]
    # save_path is only pre-created by the folder cell when saving is enabled.
    os.makedirs(save_path, exist_ok=True)
    # Save the mean and per-microphone sdr, si_sdr, sir, sar in a CSV file
    np.savetxt(save_path + "perf_" + song_title + "_audio_length_" + str(audio_length) + "_n_basis_" + str(n_basis) + "_n_fft_" + str(L) + ".csv", perf, delimiter=",", fmt="%s")

### Un peu de visualisation

In [None]:
# Spectrograms of sources 0 and 1 at microphone mic_plot_sep:
# top row = reference source images, bottom row = separated estimates.
mic_plot_sep = 5

# File name without its 9-character ".stem.mp4" suffix (split("/")[2] used
# to pick the "musedb" folder instead).
song_title = os.path.basename(song_path)[:-9]
mic_name = microphone_names[mic_plot_sep]

fig = plt.figure()
fig.set_size_inches(10, 6)

panels = (
    (ref, 0, 'Source 0 (target)'),
    (ref, 1, 'Source 1 (target)'),
    (y, 0, 'Source 0 (séparé)'),
    (y, 1, 'Source 1 (séparé)'),
)
for position, (bank, source_idx, title) in enumerate(panels, start=1):
    plt.subplot(2, 2, position)
    plt.specgram(bank[mic_plot_sep][source_idx, :], NFFT=1024, Fs=room.fs)
    plt.title(title)

fig.tight_layout(pad=2.5)
fig.suptitle("spectro_source01_micro_" + mic_name + "_" + song_title + "_n_basis_" + str(n_basis) + "_n_fft_" + str(L))

if SAVE_FIG:
    fig.savefig(save_path + "spectro/" + "spectro_source01_micro_" + mic_name + "_" + song_title + "_audio_length_" + str(audio_length) + "_n_basis_" + str(n_basis) + "_n_fft_" + str(L) + ".pdf")

In [None]:
# Spectrograms of sources 2 and 3 at microphone mic_plot_sep (set in the
# previous cell): top row = reference images, bottom row = separated estimates.

# File name without its 9-character ".stem.mp4" suffix (split("/")[2] used
# to pick the "musedb" folder instead).
song_title = os.path.basename(song_path)[:-9]
mic_name = microphone_names[mic_plot_sep]

fig = plt.figure()
fig.set_size_inches(10, 6)

panels = (
    (ref, 2, 'Source 2 (clean)'),
    (ref, 3, 'Source 3 (clean)'),
    (y, 2, 'Source 2 (separated)'),
    (y, 3, 'Source 3 (separated)'),
)
for position, (bank, source_idx, title) in enumerate(panels, start=1):
    plt.subplot(2, 2, position)
    plt.specgram(bank[mic_plot_sep][source_idx, :], NFFT=1024, Fs=room.fs)
    plt.title(title)

fig.tight_layout(pad=2.5)
fig.suptitle("spectro_source23_micro_" + mic_name + "_" + song_title + "_n_basis_" + str(n_basis) + "_n_fft_" + str(L))

if SAVE_FIG:
    fig.savefig(save_path + "spectro/" + "spectro_source23_micro_" + mic_name + "_" + song_title + "_audio_length_" + str(audio_length) + "_n_basis_" + str(n_basis) + "_n_fft_" + str(L) + ".pdf")

In [None]:
# Metric history (SDR / SIR vs. iteration) for one microphone, recorded by
# convergence_callback_micro every 10 iterations.
# mic_plot_perf was previously never defined, so this cell raised NameError.
mic_plot_perf = 0

if PLOT_PERFORMANCE:

    print("Plotting performance for fastmnmf2")
    fig = plt.figure()
    fig.set_size_inches(10, 6)

    # (n_callback_calls, n_sources) arrays for the selected microphone
    a = np.stack(np.array(SDR)[:, mic_plot_perf])
    b = np.stack(np.array(SIR)[:, mic_plot_perf])

    iterations = np.arange(a.shape[0]) * 10
    line_styles = ['-', '--', ':', '-.']

    for src in range(a.shape[1]):
        plt.plot(iterations, a[:, src], label=f'SDR Source {src}', c='r', ls=line_styles[src % len(line_styles)])
    for src in range(b.shape[1]):
        plt.plot(iterations, b[:, src], label=f'SIR Source {src}', c='b', ls=line_styles[src % len(line_styles)])

    plt.legend(ncol=1)
    plt.xlabel('Iteration')
    plt.ylabel('dB')
    plt.grid()
    plt.title(f'performance microphone {mic_plot_perf}')

    plt.show()

#### Ecoute maintenant 

##### Pas de séparation

In [None]:
print("Listening of Audio at each microphones without separation")

# Play each simulated microphone signal (sum of all source images) at the
# simulation sampling rate, and optionally save it as a WAV file.
for micro_n in range(len(mics_signals)):
    mic_name = microphone_names[micro_n]
    print('microphone ', mic_name)
    display(ipd.Audio(mics_signals[micro_n], rate=room.fs))
    if SAVE_AUDIO:
        out_dir = save_path + "audios/no_separation/micro_" + mic_name + "/"
        # The folder cell only pre-creates folders for the MUSDB18 microphone names.
        os.makedirs(out_dir, exist_ok=True)
        # File name without its 9-character ".stem.mp4" suffix
        # (split("/")[2] used to pick the "musedb" folder instead).
        song_title = os.path.basename(song_path)[:-9]
        wavfile.write(out_dir + song_title + "_globals_micro_" + mic_name + ".wav", room.fs, mics_signals[micro_n])

##### Séparation des sources

In [None]:
def ecoute_separation_micro(mic, y, save, rate=44100):
    """Play the separated sources estimated at one microphone and optionally save them.

    Args:
        mic (int): index of the microphone whose separation is played.
        y (list of arrays): separated time-domain signals per microphone;
            y[mic][source_n] is the estimate of source source_n at that microphone.
        save (bool): if True, also write each estimate to a WAV file under
            save_path + "audios/separation/micro_<microphone name>/".
        rate (int): sampling rate for playback and WAV files (default 44100,
            the rate previously hard-coded; pass room.fs to follow the simulation).
    """
    mic_name = microphone_names[mic]
    # (printed label, file-name tag) for source indices 0..3. Any further
    # source is skipped, as with the original if/elif chain.
    source_labels = (
        ("Drums", "drums"),
        ("Bass", "bass"),
        ("Accompaniement", "other"),
        ("Vocals", "vocals"),
    )

    for source_n, (label, tag) in zip(range(len(y[mic])), source_labels):
        estimate = y[mic][source_n]
        print(label + " séparé microphone ", mic_name)
        display(ipd.Audio(estimate, rate=rate))
        if save:
            out_dir = save_path + "audios/separation/micro_" + mic_name + "/"
            # The folder cell only pre-creates folders for the MUSDB18 microphone names.
            os.makedirs(out_dir, exist_ok=True)
            # File name without its 9-character ".stem.mp4" suffix
            # (split("/")[2] used to pick the "musedb" folder instead).
            song_title = os.path.basename(song_path)[:-9]
            wavfile.write(out_dir + song_title + "_" + tag + "_micro_" + mic_name + "_audio_length_" + str(audio_length) + "_n_basis_" + str(n_basis) + "_n_fft_" + str(L) + ".wav", rate, estimate)

In [None]:
# Microphone 0: play (and optionally save) each separated source
print("Listening of Audios at microphones 0 (couple) with separation")
ecoute_separation_micro(mic=0, y=y, save=SAVE_AUDIO)

In [None]:
# Microphone 1: play (and optionally save) each separated source
print("Listening of Audios at microphones 1 (couple) with separation")
ecoute_separation_micro(mic=1, y=y, save=SAVE_AUDIO)

In [None]:
# Microphone 2: play (and optionally save) each separated source
print("Listening of Audios at microphones 2 (couple) with separation")
ecoute_separation_micro(mic=2, y=y, save=SAVE_AUDIO)

In [None]:
# Microphone 3: play (and optionally save) each separated source
print("Listening of Audios at microphones 3 (couple) with separation")
ecoute_separation_micro(mic=3, y=y, save=SAVE_AUDIO)

In [None]:
# Microphone 4: play (and optionally save) each separated source
print("Listening of Audios at microphones 4 (couple) with separation")
ecoute_separation_micro(mic=4, y=y, save=SAVE_AUDIO)

In [None]:
# Microphone 5: play (and optionally save) each separated source
print("Listening of Audios at microphones 5 (couple) with separation")
ecoute_separation_micro(mic=5, y=y, save=SAVE_AUDIO)

### Visualisation des matrices intermédiaires

#### Représentation de g et G (covariance et matrice spatiale)

In [None]:
# Heat map of g_NM (for the pyroomacoustics algorithm: one row per source,
# one column per microphone). Both algorithms now share the same title and
# file name. Previously, the pyroomacoustics branch built its file name from
# the wrong path component and without audio_length.
# File name without its 9-character ".stem.mp4" suffix.
song_title = os.path.basename(song_path)[:-9]

plt.figure()
plt.imshow(g_NM[:, :], cmap='inferno', aspect='auto')
plt.title("g_mn_" + song_title + "_n_basis_" + str(n_basis) + "_n_fft_" + str(L))
plt.tight_layout()

if SAVE_FIG:
    plt.savefig(save_path + "mix/" + "g_mn_" + song_title + "_audio_length_" + str(audio_length) + "_n_basis_" + str(n_basis) + "_n_fft_" + str(L) + ".pdf")

In [None]:
def G_mix(Q_FMM, g_NM):
    """Combine the joint diagonalizer Q and the diagonal weights g into one (sources x mics) map.

    Q_FMM: (n_freq, n_mics, n_mics) joint diagonalizer.
    g_NM: (n_sources, n_mics) diagonal spatial covariance weights.

    Returns an (n_sources, n_mics) complex array with
        G[n, m] = g[n, m] * sum_f |inv(Q_f)[m, m]|^2
    (real up to rounding).

    NOTE(review): despite the G_NF name used by the caller, the frequency axis
    is summed out, and only the diagonal of inv(Q_f) contributes. The full
    spatial covariance inv(Q_f) diag(g_n) inv(Q_f)^H would have shape
    (n_sources, n_freq, n_mics, n_mics). Confirm this is the intended quantity.
    """
    diag_g = np.array([np.diag(weights) for weights in g_NM])
    Q_inv = np.linalg.inv(Q_FMM)
    return np.einsum("fij, nij, ijf -> ni", Q_inv, diag_g, Q_inv.conj().T)

In [None]:
# Combined (sources x mics) map built from the joint diagonalizer and g weights
G_NF = G_mix(Q_FMM=Q_FMM, g_NM=g_NM)

In [None]:
# Log-magnitude heat map of G_NF.
# File name without its 9-character ".stem.mp4" suffix (split("/")[2] used
# to pick the "musedb" folder instead).
song_title = os.path.basename(song_path)[:-9]

plt.figure()
plt.imshow(np.log(np.abs(G_NF)), cmap='inferno', aspect='auto')
plt.colorbar()
plt.title("G_NF_" + song_title + "_n_basis_" + str(n_basis) + "_n_fft_" + str(L))
plt.tight_layout()

if SAVE_FIG:
    plt.savefig(save_path + "mix/" + "G_NF_" + song_title + "_audio_length_" + str(audio_length) + "_n_basis_" + str(n_basis) + "_n_fft_" + str(L) + ".pdf")

#### Représentation de W

In [None]:
# Log of the NMF bases W (frequency x components) of the first four sources.
# File name without its 9-character ".stem.mp4" suffix (split("/")[2] used
# to pick the "musedb" folder instead).
song_title = os.path.basename(song_path)[:-9]

fig = plt.figure()
fig.set_size_inches(10, 6)

# At most 4 panels (2x2 grid); fewer sources no longer raise IndexError.
for src in range(min(4, W_NFK.shape[0])):
    plt.subplot(2, 2, src + 1)
    plt.imshow(np.log(W_NFK[src, :, :]), cmap='inferno', aspect='auto')
    plt.gca().invert_yaxis()
    plt.title(f'Matrice de base W{src}')

fig.tight_layout(pad=2.5)
fig.suptitle("W_NFK_" + song_title + "_n_basis_" + str(n_basis) + "_n_fft_" + str(L))

if SAVE_FIG:
    fig.savefig(save_path + "base/" + "W_NFK_" + song_title + "_audio_length_" + str(audio_length) + "_n_basis_" + str(n_basis) + "_n_fft_" + str(L) + ".pdf")

#### Représentation des activations

In [None]:
# NMF activations H (components x frames) of the first four sources.
# File name without its 9-character ".stem.mp4" suffix (split("/")[2] used
# to pick the "musedb" folder instead).
song_title = os.path.basename(song_path)[:-9]

fig = plt.figure()
fig.set_size_inches(10, 6)

# At most 4 panels (2x2 grid); fewer sources no longer raise IndexError.
for src in range(min(4, H_NKT.shape[0])):
    plt.subplot(2, 2, src + 1)
    plt.imshow(H_NKT[src, :, :], cmap='inferno', aspect='auto')
    plt.gca().invert_yaxis()
    # H holds activations, not bases (the panels were mislabelled "de base").
    plt.title(f"Matrice d'activation H{src}")

fig.tight_layout(pad=2.5)
fig.suptitle("H_NKT_" + song_title + "_n_basis_" + str(n_basis) + "_n_fft_" + str(L))

if SAVE_FIG:
    fig.savefig(save_path + "activation/" + "H_NKT_" + song_title + "_audio_length_" + str(audio_length) + "_n_basis_" + str(n_basis) + "_n_fft_" + str(L) + ".pdf")