## Simulation d'une situation réelle avec Pyroomacoustics

NOTE: 

Il y a plusieurs possibilités dans ce notebook: 

- Utilisation de différents algorithmes (Celui de Louis, ou celui de pyroomacoustics)
Pour utiliser l'un ou l'autre, il faut faire attention aux variables globales (définies avec des MAJUSCULES). 
- Utilisation de fichiers MuseDB ou de fichiers WAV.
- Possibilité de tracer les performances en fonction de l'avancée de l'algorithme (PLOT_PERFORMANCE), mais ATTENTION, cela n'est implémenté que pour l'algo pyroom pour l'instant.
- Possibilité de sauvegarder les figures et les audios (SAVE_FIG et SAVE_AUDIO) ; cela va créer les dossiers correspondants sous 'test'.

- Puis bien sûr, possibilité de modifier toutes les données concernant l'emplacement des microphones, dimensions de la salle etc... au fur et à mesure du notebook.

In [None]:
import IPython.display as ipd
from IPython.display import display
import pickle
import librosa
import os
import math
import matplotlib.pyplot as plt
import stempeg
import itertools
import operator
import pyroomacoustics as pra
import wave
import random as rd
import numpy as np
from scipy.io import wavfile
from pyroomacoustics.directivities import (
    DirectivityPattern,
    DirectionVector,
    CardioidFamily,
)
from numpy import typing
from mir_eval.separation import bss_eval_sources

# Directory that is scanned for the MUSDB18 stem files.
path_in = "./data/musedb/"
# Prefix under which figures, audio files and performance files are saved.
save_path = "./test/surdeterm/"

In [None]:
from src import performance, data_processing, fast_nmf

### Chargement d'un fichier MuseDB

Each file in the MUSDB18 dataset uses a multitrack format composed of 5 stereo streams, each one encoded in AAC @256kbps. These signals correspond to:

- `0` - The mixture,
- `1` - The drums,
- `2` - The bass,
- `3` - The rest of the accompaniment,
- `4` - The vocals.

```S.shape = (5, time_step , 2)```

In [None]:
# Collect the ".mp4" entries found under `path_in` through the project helper.
# NOTE(review): the second return value is presumably the list of titles, and
# `files_tilte` looks like a typo for `files_title` — confirm against get_files.
files_in, files_tilte = data_processing.get_files(path_in, ".mp4")

In [None]:
# Quick listening check on the first three files: read the stems and play
# column 0 of stem 0 (stem 0 is the mixture in the MUSDB18 layout listed above).
for path in files_in[:3]:
    S, rate = stempeg.read_stems(path)
    display(ipd.Audio(S[0][:, 0], rate=rate))

### Test avec un fichier wav classique

In [None]:
# Custom WAV inputs. `wav_files` holds one inner list per source; the files of
# an inner list are concatenated into a single signal for that source.
_WAV_TIMOKHINA = "data/samples/BACH Cello Suite 1, Prelude, Violin - Kateryna Timokhina.wav"
_WAV_GAILLARD = "data/samples/Bach_ Prélude, Cello suite Nr.1  Ophélie Gaillard.wav"

wav_files = [[_WAV_TIMOKHINA], [_WAV_GAILLARD], [_WAV_GAILLARD], [_WAV_TIMOKHINA]]

signals = []
for source_files in wav_files:
    # sr=None keeps the native sampling rate of the file, mono=False keeps all
    # of its channels; only the samples (element 0 of the result) are kept.
    loaded_parts = [
        librosa.load(wav_path, sr=None, mono=False)[0] for wav_path in source_files
    ]
    signals.append(np.concatenate(loaded_parts))

### Définition des variables globales

In [None]:
PLOT_PERFORMANCE = False  # Plot the metrics against the algorithm iterations => only implemented for pyroomacoustics
MUSEDB = True  # If True, use the MUSDB18 data; otherwise use the files listed in wav_files
LOUIS = True  # If True, use Louis' algorithm; otherwise use the pyroomacoustics algorithm
SAVE_FIG = (
    False  # If True, save every figure into the sub-folders of 'test'
)
SAVE_AUDIO = False  # If True, save every audio file into the 'audios' sub-folder of 'test'
SAVE_PERF = (
    False  # If True, save the performance scores to a CSV file under save_path
)
TYPE = "stft"  # Time-frequency transform; later cells handle "stft" and "cqt"

# Microphone parameters: every microphone shares this cardioid directivity.
mic_pattern = DirectivityPattern.CARDIOID
MIC_DIR = CardioidFamily(
    orientation=DirectionVector(azimuth=180, colatitude=60, degrees=True),
    pattern_enum=mic_pattern,
)

In [None]:
# Create the output directories required by the save flags.
#
# os.makedirs(..., exist_ok=True) replaces the previous os.mkdir calls wrapped
# in `try/except OSError: pass`:
#   * it creates missing parent folders, so SAVE_AUDIO works even when
#     save_path does not exist yet;
#   * it is a no-op for folders that already exist, so an existing save_path
#     (or first sub-folder) no longer skips the creation of the remaining ones;
#   * genuine failures (e.g. permissions) are no longer hidden.

os.makedirs("./test", exist_ok=True)

if SAVE_FIG or SAVE_AUDIO or SAVE_PERF:
    # Every save target lives under save_path, whichever flag is enabled.
    os.makedirs(save_path, exist_ok=True)

if SAVE_FIG:
    for fig_dir in ("activation", "base", "spectro", "mix"):
        os.makedirs(save_path + fig_dir, exist_ok=True)

if SAVE_AUDIO:
    # One folder per microphone name, for both the raw microphone signals
    # ("no_separation") and the separated sources ("separation").
    for audio_kind in ("separation", "no_separation"):
        for mic_folder in ("drums", "bass", "vocals", "other", "AB1", "AB2"):
            os.makedirs(
                save_path + "audios/" + audio_kind + "/micro_" + mic_folder,
                exist_ok=True,
            )

In [None]:
# Custom-WAV configuration: only runs when the MUSDB18 data is not selected.
# NOTE(review): this branch never assigns `song_path`, which later cells use to
# build titles and file names — confirm how it should be set for custom signals.
if not MUSEDB:
    print("working with custom signals")

    # Room parameters
    room_dimension = [12, 20, 5]
    abs_coef = 0.35

    # Source parameters: names and locations in the room
    source_names = ["source1", "source2", "source3", "source4"]
    source_locations = [[2, 9, 1], [2, 11, 1], [3, 12, 1], [3, 7, 1]]

    # Microphone parameters: names and locations in the room
    # (np.c_ stacks the six position triplets as columns)
    microphone_names = ["mic1", "mic2", "mic3", "mic4", "mic5", "mic6"]
    microphone_locations = np.c_[
        [2.7, 9, 1], [2.7, 11, 1], [3.5, 12, 1], [3.5, 7, 1], [6, 9.6, 1], [6, 10.4, 1]
    ]

    # STFT parameters: frame length and hop (a quarter of the frame)
    L = 2048
    hop = L // 4

    # Audio parameters (audio_length is presumably in seconds — TODO confirm)
    audio_length = 10
    rate = 44100

    # Build the room
    room = data_processing.shoebox_room(
        room_dimension,
        abs_coef,
    )

    # Add the sources and the microphones to the room
    room, separate_recordings, mics_signals = data_processing.room_sources_micro(
        signals,
        rate=rate,
        audio_length=audio_length,
        room=room,
        source_locations=source_locations,
        microphone_locations=microphone_locations,
        microphone_names=microphone_names,
        source_dir=None,
        mic_dir=MIC_DIR,
        display_room=True,
    )

    # Turn the microphone signals into the multichannel time-frequency
    # representation X consumed by the separation algorithms
    X = data_processing.spectrogram_from_mics_signal(
        mics_signals,
        microphone_names,
        rate=rate,
        L=L,
        hop=hop,
        type=TYPE,
        display_audio=True,
        display_spectrogram=False,
    )

In [None]:
# MUSDB18 configuration: only runs when the MUSDB18 data is selected.
if MUSEDB:
    print("working with MUSDB18")

    # Room parameters
    room_dimension = [12, 20, 5]
    abs_coef = 0.35

    # Source parameters: names and locations in the room
    # 1: Drums | 2: Bass | 3: Accompaniment | 4: Vocals
    source_names = ["drums", "bass", "other", "vocals"]
    source_locations = [[2, 9, 1], [2, 11, 1], [3, 12, 1], [3, 7, 1]]

    # Microphone parameters: names and locations in the room
    # 1: Drums | 2: Bass | 3: Accompaniment | 4: Vocals | 5 and 6: AB pair (80 cm apart)
    microphone_names = ["drums", "bass", "other", "vocals", "AB1", "AB2"]
    microphone_locations = np.c_[
        [2.7, 9, 1], [2.7, 11, 1], [3.5, 12, 1], [3.5, 7, 1], [6, 9.6, 1], [6, 10.4, 1]
    ]

    # Index of a randomly drawn song. randrange excludes len(files_in), whereas
    # the previous randint(0, len(files_in)) is inclusive on both ends and could
    # return an out-of-range index.
    ind = rd.randrange(len(files_in))
    # NOTE(review): the song stays pinned to index 2 as before, so `ind` is not
    # used yet; switch to files_in[ind] to actually pick a random song.
    song_path = files_in[2]

    # Drop stem 0 (the mixture) and swap the last two axes of the remaining stems.
    data, rate = stempeg.read_stems(song_path)
    audio_list = data[1:, :].transpose(0, 2, 1)

    # STFT parameters: frame length and hop (a quarter of the frame)
    L = 4096
    hop = L // 4

    # Audio parameters (audio_length is presumably in seconds — TODO confirm)
    audio_length = 10

    # Build the room
    room = data_processing.shoebox_room(
        room_dimension,
        abs_coef,
    )

    # Add the sources and the microphones to the room
    room, separate_recordings, mics_signals = data_processing.room_sources_micro(
        audio_list,
        rate=rate,
        audio_length=audio_length,
        room=room,
        source_locations=source_locations,
        microphone_locations=microphone_locations,
        microphone_names=microphone_names,
        source_dir=None,
        mic_dir=MIC_DIR,
        display_room=True,
    )

    # Turn the microphone signals into the multichannel time-frequency
    # representation X consumed by the separation algorithms
    X = data_processing.spectrogram_from_mics_signal(
        mics_signals,
        microphone_names,
        rate=rate,
        L=L,
        hop=hop,
        type=TYPE,
        display_audio=True,
        display_spectrogram=False,
    )

### Ajout de la Fast MNMF pour la séparation

In [None]:
# Author's note: separate_recordings has shape (n_sources, n_mics, n_samples)
# ref0 = separate_recordings[:, 0, :]
# Metric history filled by the convergence callback (one entry per call).
SDR, SIR, SAR, PERM = [], [], [], []

# Analysis window and the matching synthesis window used to invert the STFT.
win_a = pra.hamming(L)
win_s = pra.transform.stft.compute_synthesis_window(win_a, hop)

# Callback used to monitor the convergence of the algorithm
def convergence_callback_micro(Y):
    """Score the current estimate on every microphone and record the result.

    For each entry of ``Y`` the STFT is inverted, the first ``L - hop`` samples
    are dropped, and the result is compared with the reference recordings of
    the same microphone using ``bss_eval_sources``. One object array per metric
    (indexed by microphone) is appended to the module-level lists SDR, SIR,
    SAR and PERM.

    Args:
        Y: sequence indexed by microphone; each item is passed as-is to
            ``pra.transform.stft.synthesis``.
    """
    n_mics = len(Y)
    sdr_by_mic, sir_by_mic, sar_by_mic, perm_by_mic = (
        np.zeros(n_mics, dtype=object) for _ in range(4)
    )
    stores = (sdr_by_mic, sir_by_mic, sar_by_mic, perm_by_mic)

    for mic_idx, Y_mic in enumerate(Y):
        estimate = pra.transform.stft.synthesis(Y_mic, L, hop, win=win_s)
        estimate = estimate[L - hop :, :].T

        # Author's note: separate_recordings is (n_sources, n_mics, n_samples)
        reference = separate_recordings[:, mic_idx, :]

        # Both signals are truncated to their common length before scoring.
        n_common = min(estimate.shape[1], reference.shape[1])
        scores = bss_eval_sources(reference[:, :n_common], estimate[:, :n_common])

        for store, score in zip(stores, scores):
            store[mic_idx] = score

    SDR.append(sdr_by_mic)
    SIR.append(sir_by_mic)
    SAR.append(sar_by_mic)
    PERM.append(perm_by_mic)

In [None]:
# Algorithm parameters
n_basis = 16
n_iter = 100

# Dispatch: PLOT_PERFORMANCE takes precedence and always selects the
# pyroomacoustics-based implementation (with the convergence callback);
# otherwise LOUIS selects the custom implementation.
if LOUIS and not PLOT_PERFORMANCE:
    print("Running fastmnmf2 with Louis custom algorithm")
    louis_outputs = fast_nmf.fast_MNMF2(
        X.transpose(1, 0, 2),
        n_iter=n_iter,
        n_microphones=len(microphone_names),
        n_sources=len(source_names),
        n_time_frames=X.shape[0],
        n_freq_bins=X.shape[1],
        n_basis=n_basis,
        algo="IP",
        mic_index=None,
        split=True,
        n_activations=X.shape[0] - 1,
        n_notes=X.shape[1] - 1,
    )
    (
        Y,
        W_NFK,
        H_NKT,
        g_NM,
        Q_FMM,
        Qx_FTM,
        X_tilde_FTM,
        Y_tilde_FTM,
    ) = louis_outputs
    # Reorder the axes of Y to the layout expected by the following cells.
    Y = Y.transpose(0, 3, 2, 1)

else:
    pyroom_kwargs = dict(
        n_src=len(source_names),
        n_iter=n_iter,
        n_components=n_basis,
        mic_index="all",
        W0=None,
        accelerate=True,
    )
    if PLOT_PERFORMANCE:
        print("Running fastmnmf2 with callback function")
        pyroom_kwargs["callback"] = convergence_callback_micro
    else:
        print("Running fastmnmf2 without callback function")

    Y, W_NFK, H_NKT, Y_FTM, g_NM, Q_FMM = fast_nmf.fastmnmf2_pyroom(X, **pyroom_kwargs)

#### Compute performance score

In [None]:
# Bring every separated time-frequency estimate in Y back to the time domain.
if TYPE == "stft":
    # STFT parameters: analysis window and matching synthesis window
    win_a = pra.hamming(L)
    win_s = pra.transform.stft.compute_synthesis_window(win_a, hop)

    # The first L - hop samples of each inverse STFT are dropped before the
    # transpose (presumably to compensate the STFT padding — TODO confirm).
    y = [
        pra.transform.stft.synthesis(Y_mic, L, hop, win=win_s)[L - hop :, :].T
        for Y_mic in Y
    ]

elif TYPE == "cqt":
    y = [
        librosa.icqt(
            Y_mic.transpose(2, 1, 0),
            sr=rate,
            hop_length=hop,
            fmin=None,
            bins_per_octave=12,
            tuning=0.0,
            filter_scale=1,
            norm=1,
            sparsity=0.01,
            window="hann",
            scale=True,
            length=None,
            res_type="fft",
            dtype=None,
        )
        for Y_mic in Y
    ]

else:
    # Without this branch `y` stayed undefined (NameError below) or, worse, a
    # stale `y` from a previous run of the cell was silently reused.
    raise ValueError(f"Unsupported TYPE {TYPE!r}: expected 'stft' or 'cqt'")

# shape of y = (n_mics, n_sources, n_samples)
y = np.array(y)
# shape of ref = (n_mics, n_sources, n_samples)
ref = separate_recordings.transpose(1, 0, 2)

In [None]:
# Score the separated signals `y` against the references `ref` with the project helper.
sdr, si_sdr, sir, sar, perm = performance.compute_perf(y, ref)

In [None]:
# Print the mean of each final metric, then pack the means and the raw values
# into `perf`, an object array of single-key dictionaries.
metric_table = (
    ("sdr", "SDR final : ", sdr),
    ("si_sdr", "SI_SDR final : ", si_sdr),
    ("sir", "SIR final : ", sir),
    ("sar", "SAR final : ", sar),
)

for metric_key, metric_label, metric_scores in metric_table:
    print(metric_label, np.mean(metric_scores))

perf_mean = np.array(
    [{key + "_mean": np.mean(scores)} for key, _label, scores in metric_table],
    dtype=object,
)
perf_final = np.array(
    [{key + "_final": scores} for key, _label, scores in metric_table],
    dtype=object,
)

perf = np.concatenate((perf_mean, perf_final))

In [None]:
if SAVE_PERF:
    # save_path is otherwise only created when SAVE_FIG is enabled, so make
    # sure it exists before writing; exist_ok keeps re-runs harmless.
    os.makedirs(save_path, exist_ok=True)

    # Save the performance entries (one str() per row) in a csv file whose
    # name encodes the song and the main parameters.
    perf_file = (
        save_path
        + "perf_"
        + str(song_path.split("/")[2][:-9])
        + "_audio_length_"
        + str(audio_length)
        + "_n_basis_"
        + str(n_basis)
        + "_n_fft_"
        + str(L)
        + ".csv"
    )
    np.savetxt(perf_file, perf, delimiter=",", fmt="%s")

### Un peu de visualisation

In [None]:
# Index of the microphone whose separation result is displayed
mic_plot_sep = 5

fig = plt.figure()
fig.set_size_inches(10, 6)

# One row per panel: (subplot position, signals, source index, panel title).
# Top row: reference recordings; bottom row: separated estimates.
spectro_panels = (
    (1, ref, 0, "Source 0 (target)"),
    (2, ref, 1, "Source 1 (target)"),
    (3, y, 0, "Source 0 (séparé)"),
    (4, y, 1, "Source 1 (séparé)"),
)
for panel_pos, panel_signals, panel_src, panel_title in spectro_panels:
    plt.subplot(2, 2, panel_pos)
    plt.specgram(panel_signals[mic_plot_sep][panel_src, :], NFFT=1024, Fs=room.fs)
    plt.title(panel_title)

plt.tight_layout()

fig.tight_layout(pad=2.5)

# Common stem of the figure title and of the saved file name.
spectro_song = str(song_path.split("/")[2][:-9])
spectro_stem = "spectro_source01_micro_" + microphone_names[mic_plot_sep] + spectro_song
fig.suptitle(f"{spectro_stem}_n_basis_{n_basis}_n_fft_{L}")

if SAVE_FIG:
    fig.savefig(
        f"{save_path}spectro/{spectro_stem}"
        f"_audio_length_{audio_length}_n_basis_{n_basis}_n_fft_{L}.pdf"
    )

In [None]:
# Same figure as the previous cell for sources 2 and 3, on microphone mic_plot_sep

fig = plt.figure()
fig.set_size_inches(10, 6)

# One row per panel: (subplot position, signals, source index, panel title).
# Top row: reference recordings; bottom row: separated estimates.
spectro_panels = (
    (1, ref, 2, "Source 2 (clean)"),
    (2, ref, 3, "Source 3 (clean)"),
    (3, y, 2, "Source 2 (separated)"),
    (4, y, 3, "Source 3 (separated)"),
)
for panel_pos, panel_signals, panel_src, panel_title in spectro_panels:
    plt.subplot(2, 2, panel_pos)
    plt.specgram(panel_signals[mic_plot_sep][panel_src, :], NFFT=1024, Fs=room.fs)
    plt.title(panel_title)

plt.tight_layout()

fig.tight_layout(pad=2.5)

# Common stem of the figure title and of the saved file name.
spectro_song = str(song_path.split("/")[2][:-9])
spectro_stem = "spectro_source23_micro_" + microphone_names[mic_plot_sep] + spectro_song
fig.suptitle(f"{spectro_stem}_n_basis_{n_basis}_n_fft_{L}")

if SAVE_FIG:
    fig.savefig(
        f"{save_path}spectro/{spectro_stem}"
        f"_audio_length_{audio_length}_n_basis_{n_basis}_n_fft_{L}.pdf"
    )

In [None]:
# Evolution of the SDR / SIR recorded by the convergence callback, for one microphone.

# Index of the microphone whose metrics are plotted. The name was used below
# without ever being defined, which raised a NameError as soon as
# PLOT_PERFORMANCE was enabled; 0 matches the previous hard-coded title.
mic_plot_perf = 0

if PLOT_PERFORMANCE:

    print("Plotting performance for fastmnmf2")
    fig = plt.figure()
    fig.set_size_inches(10, 6)

    # SDR / SIR hold one object array per callback call, indexed by microphone;
    # stacking the entries of the chosen microphone gives one row per call and
    # one column per source.
    sdr_history = np.array(SDR)
    sdr_history = np.stack(sdr_history[:, mic_plot_perf])

    sir_history = np.array(SIR)
    sir_history = np.stack(sir_history[:, mic_plot_perf])

    # x axis: assumes the callback is invoked every 10 iterations — TODO confirm
    line_styles = ("-", "--", ":", "-.")

    for src_idx, line_style in zip(range(sdr_history.shape[1]), line_styles):
        plt.plot(
            np.arange(sdr_history.shape[0]) * 10,
            sdr_history[:, src_idx],
            label="SDR Source " + str(src_idx),
            c="r",
            ls=line_style,
        )

    for src_idx, line_style in zip(range(sir_history.shape[1]), line_styles):
        plt.plot(
            np.arange(sir_history.shape[0]) * 10,
            sir_history[:, src_idx],
            label="SIR Source " + str(src_idx),
            c="b",
            ls=line_style,
        )

    plt.legend(ncol=1)
    plt.xlabel("Iteration")
    plt.ylabel("dB")
    plt.grid()
    plt.title("performance microphone " + str(mic_plot_perf))

    plt.show()

#### Ecoute maintenant 

##### Pas de séparation

In [None]:
print("Listening of Audio at each microphones without separation")

# SAVE_AUDIO is no longer forced to True here: the override silently ignored
# the flag set in the configuration cell and made this cell write into folders
# that were never created on a fresh run. Set SAVE_AUDIO = True in the
# configuration cell to save the files.
# NOTE(review): the sampling rate is hard-coded to 44100 below — confirm that
# it always matches `rate`.

for micro_n in range(len(mics_signals)):
    print("microphone ", microphone_names[micro_n])
    display(ipd.Audio(mics_signals[micro_n], rate=44100))
    if SAVE_AUDIO:
        # Create the target folder on demand so that custom microphone names
        # (or a flag changed after the folder-creation cell) also work.
        no_sep_dir = save_path + "audios/no_separation/micro_" + microphone_names[micro_n]
        os.makedirs(no_sep_dir, exist_ok=True)
        wavfile.write(
            no_sep_dir
            + "/test_"
            + str(song_path.split("/")[2][:-9])
            + "_globals_micro_"
            + microphone_names[micro_n]
            + ".wav",
            44100,
            mics_signals[micro_n].astype(np.float32),
        )

##### Séparation des sources

In [None]:
def ecoute_separation_micro(mic, y, save, rate=44100):
    """Play the separated sources of one microphone and optionally save them.

    The four copy-pasted branches of the previous version (one per source) are
    replaced by a lookup table; printed messages and file names are unchanged.
    Sources beyond the fourth are ignored, as before.

    Reads the module-level names microphone_names, save_path, song_path,
    audio_length, n_basis and L.

    Args:
        mic (int): index of the microphone whose separation is played.
        y (array): separated signals, indexed as y[mic][source].
        save (bool): if True, each separated source is also written as a
            float32 WAV file under save_path + "audios/separation/"; the
            target folder is created when it does not exist.
        rate (int): sampling rate used for playback and for the WAV files
            (defaults to the previously hard-coded 44100).
    """
    # (label used in the printed message, tag used in the file name),
    # indexed by source number.
    source_labels = (
        ("Drums", "drums"),
        ("Bass", "bass"),
        ("Accompaniement", "other"),
        ("Vocals", "vocals"),
    )

    # zip stops at the shorter input, so extra sources are skipped silently.
    for source_n, (label, tag) in zip(range(len(y[0])), source_labels):
        mic_name = microphone_names[mic]
        separated = y[mic][source_n]

        print(label + " séparé microphone ", mic_name)
        display(ipd.Audio(separated, rate=rate))

        if save:
            out_dir = save_path + "audios/separation/micro_" + mic_name
            # The folders are only pre-created for the MUSDB18 microphone
            # names and only when SAVE_AUDIO was set early enough.
            os.makedirs(out_dir, exist_ok=True)
            wavfile.write(
                out_dir
                + "/"
                + str(song_path.split("/")[2][:-9])
                + "_"
                + tag
                + "_micro_"
                + mic_name
                + "_audio_length_"
                + str(audio_length)
                + "_n_basis_"
                + str(n_basis)
                + "_n_fft_"
                + str(L)
                + ".wav",
                rate,
                separated.astype(np.float32),
            )

In [None]:
# Microphone index 0 is "drums" in the MUSDB18 setup ("mic1" with custom WAVs).
# NOTE(review): the "(couple)" wording only matches the AB pair (indices 4 and 5).
print("Listening of Audios at microphones 0 (couple) with separation")
ecoute_separation_micro(0, y, SAVE_AUDIO)

In [None]:
# Microphone index 1 is "bass" in the MUSDB18 setup ("mic2" with custom WAVs).
# NOTE(review): the "(couple)" wording only matches the AB pair (indices 4 and 5).
print("Listening of Audios at microphones 1 (couple) with separation")
ecoute_separation_micro(1, y, SAVE_AUDIO)

In [None]:
# Microphone index 2 is "other" in the MUSDB18 setup ("mic3" with custom WAVs).
# NOTE(review): the "(couple)" wording only matches the AB pair (indices 4 and 5).
print("Listening of Audios at microphones 2 (couple) with separation")
ecoute_separation_micro(2, y, SAVE_AUDIO)

In [None]:
# Microphone index 3 is "vocals" in the MUSDB18 setup ("mic4" with custom WAVs).
# NOTE(review): the "(couple)" wording only matches the AB pair (indices 4 and 5).
print("Listening of Audios at microphones 3 (couple) with separation")
ecoute_separation_micro(3, y, SAVE_AUDIO)

In [None]:
# Microphone index 4 is "AB1", first microphone of the AB pair in the MUSDB18
# setup ("mic5" with custom WAVs).
print("Listening of Audios at microphones 4 (couple) with separation")
ecoute_separation_micro(4, y, SAVE_AUDIO)

In [None]:
# Microphone index 5 is "AB2", second microphone of the AB pair in the MUSDB18
# setup ("mic6" with custom WAVs).
print("Listening of Audios at microphones 5 (couple) with separation")
ecoute_separation_micro(5, y, SAVE_AUDIO)

### Visualisation des matrices intermédiaires

#### Représentation de g et de G (covariance et matrice spatiale)

In [None]:
# Image of g_NM. Both algorithms draw the same image; only the title and the
# name of the saved file differ. song_path is only read where the original
# branches read it, so the pyroomacoustics branch does not need it unless
# SAVE_FIG is enabled.
plt.imshow(g_NM[:, :], cmap="inferno", aspect="auto")

if LOUIS:
    g_params = f"_n_basis_{n_basis}_n_fft_{L}"
    plt.title("g_mn_" + str(song_path.split("/")[2][:-9]) + g_params)
else:
    plt.title("g_NM")

plt.tight_layout()

if SAVE_FIG:
    if LOUIS:
        g_fig_path = (
            f"{save_path}mix/g_mn_"
            + str(song_path.split("/")[2][:-9])
            + f"_audio_length_{audio_length}_n_basis_{n_basis}_n_fft_{L}.pdf"
        )
    else:
        # This branch keeps the full last path component (no [:-9] trimming)
        # and has no audio_length field.
        g_fig_path = (
            f"{save_path}mix/g_mn_"
            + str(song_path.split("/")[2])
            + f"_n_basis_{n_basis}_n_fft_{L}.pdf"
        )
    plt.savefig(g_fig_path)

In [None]:
def G_mix(Q_FMM, g_NM):
    """Combine the inverse of Q_FMM with the diagonal matrices built from g_NM.

    Each row of ``g_NM`` is turned into a diagonal matrix; the stack of these
    matrices is contracted with the batched inverse of ``Q_FMM`` and with the
    conjugate of that inverse with all axes reversed, using the einsum
    signature ``"fij, nij, ijf -> ni"``.

    Args:
        Q_FMM: array of square matrices stacked along the first axis.
        g_NM: 2-D array; one row per output row.

    Returns:
        Array with one row per row of ``g_NM`` and one column per einsum
        index ``i``.

    NOTE(review): callers store the result as ``G_NF``, but the einsum sums
    over ``f`` and keeps ``i`` — confirm that this is the intended quantity.
    """
    diag_gains_NMM = np.array([np.diag(gains_M) for gains_M in g_NM])
    Q_inv_FMM = np.linalg.inv(Q_FMM)
    # .T reverses every axis: (F, M, M) -> (M, M, F)
    Q_inv_conj_T = Q_inv_FMM.conj().T

    return np.einsum(
        "fij, nij, ijf -> ni",
        Q_inv_FMM,
        diag_gains_NMM,
        Q_inv_conj_T,
    )

In [None]:
# G_NF: result of G_mix applied to the Q_FMM and g_NM returned by the algorithm
G_NF = G_mix(Q_FMM, g_NM)

In [None]:
# Log-magnitude image of G_NF, titled (and optionally saved) with the song
# name and the main parameters.
g_nf_log = np.log(np.abs(G_NF))
plt.imshow(g_nf_log, cmap="inferno", aspect="auto")
plt.colorbar()

g_nf_song = str(song_path.split("/")[2][:-9])
g_nf_params = f"_n_basis_{n_basis}_n_fft_{L}"
plt.title("G_NF_" + g_nf_song + g_nf_params)
plt.tight_layout()

if SAVE_FIG:
    plt.savefig(
        f"{save_path}mix/G_NF_{g_nf_song}_audio_length_{audio_length}{g_nf_params}.pdf"
    )

#### Représentation de W

In [None]:
# 2x2 grid with the log of the basis matrix W of each of the first four sources.
fig = plt.figure()
fig.set_size_inches(10, 6)

for basis_idx in range(4):
    plt.subplot(2, 2, basis_idx + 1)
    plt.imshow(np.log(W_NFK[basis_idx, :, :]), cmap="inferno", aspect="auto")
    # Flip the vertical axis so that row 0 is drawn at the bottom.
    plt.gca().invert_yaxis()
    plt.title(f"Matrice de base W{basis_idx}")

plt.tight_layout()

fig.tight_layout(pad=2.5)

w_song = str(song_path.split("/")[2][:-9])
w_params = f"_n_basis_{n_basis}_n_fft_{L}"
fig.suptitle("W_NFK_" + w_song + w_params)

if SAVE_FIG:
    fig.savefig(
        f"{save_path}base/W_NFK_{w_song}_audio_length_{audio_length}{w_params}.pdf"
    )

#### Représentation des activations

In [None]:
# 2x2 grid with the activation matrix H of each of the first four sources.
fig = plt.figure()
fig.set_size_inches(10, 6)

for activation_idx in range(4):
    plt.subplot(2, 2, activation_idx + 1)
    plt.imshow(H_NKT[activation_idx, :, :], cmap="inferno", aspect="auto")
    # Flip the vertical axis so that row 0 is drawn at the bottom.
    plt.gca().invert_yaxis()
    # The panels were titled "Matrice de base H…" (copied from the W figure)
    # although they show the activations.
    plt.title(f"Matrice d'activation H{activation_idx}")

plt.tight_layout()

fig.tight_layout(pad=2.5)
fig.suptitle(
    "H_NKT_"
    + str(song_path.split("/")[2][:-9])
    + "_n_basis_"
    + str(n_basis)
    + "_n_fft_"
    + str(L)
)

if SAVE_FIG:
    fig.savefig(
        save_path
        + "activation/"
        + "H_NKT_"
        + str(song_path.split("/")[2][:-9])
        + "_audio_length_"
        + str(audio_length)
        + "_n_basis_"
        + str(n_basis)
        + "_n_fft_"
        + str(L)
        + ".pdf"
    )