# In [None]:
import os
import torch
from demucs import pretrained
from demucs.apply import apply_model
import soundfile as sf
import torchaudio
import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np

# Working directory captured once, when this module is executed; it is the
# base for the default "data" input/output paths used below.
CURRENT_DIRECTORY = os.getcwd()

class AudioDenoiser:
    """Isolate the voice of a noisy recording with a pretrained Demucs model.

    The usual entry point is :meth:`process_audio`, which chains loading,
    source separation, vocal extraction, saving, plotting and KPI reporting.
    The individual steps can also be called one by one, in that order.
    """

    def __init__(self, model_name='mdx_extra'):
        """Load the pretrained Demucs model and reset the processing state.

        Args:
            model_name: Name handed to ``demucs.pretrained.get_model``.
        """
        self.model = pretrained.get_model(model_name)
        self.model.eval()
        if torch.cuda.is_available():
            self.model = self.model.cuda()
        # Input mixture with a leading batch axis; set by load_audio().
        self.waveform = None
        # Sample rate of ``waveform`` (and therefore of the separated stems).
        self.sr = None
        # Raw output of apply_model(); set by separate_sources().
        self.sources = None
        # Mono vocal track; set by extract_vocal_source().
        self.vocal_source = None

    @staticmethod
    def _match_channels(waveform, channels):
        """Return a (channels, samples) tensor with exactly ``channels`` rows."""
        src_channels = waveform.shape[0]
        if src_channels == channels:
            return waveform
        if channels == 1:
            # Downmix by averaging every input channel.
            return waveform.mean(dim=0, keepdim=True)
        if src_channels == 1:
            # Duplicate the mono channel; repeat() yields a contiguous copy.
            return waveform.repeat(channels, 1)
        if src_channels > channels:
            return waveform[:channels]
        raise ValueError(
            f"Impossible de convertir {src_channels} canaux en {channels} canaux."
        )

    @staticmethod
    def _snr_db(signal_power, noise_power):
        """Return ``10 * log10(signal_power / noise_power)`` without warnings.

        Degenerate cases yield the limit values instead of triggering a
        division by zero: ``inf`` when there is no noise, ``-inf`` when there
        is no signal, ``nan`` when both powers are zero.
        """
        if noise_power <= 0:
            return float("inf") if signal_power > 0 else float("nan")
        if signal_power <= 0:
            return float("-inf")
        return float(10.0 * np.log10(signal_power / noise_power))

    def load_audio(self, input_audio):
        """Load the recording and adapt it to what the model expects.

        The audio is resampled to the model's ``samplerate`` and converted to
        its ``audio_channels`` when the model exposes those attributes, so
        mono files or files at another sample rate are handled. ``self.sr``
        ends up being the sample rate of ``self.waveform``.

        Args:
            input_audio: Path of the audio file to denoise.

        Raises:
            FileNotFoundError: If ``input_audio`` does not exist.
            ValueError: If the channel layout cannot be converted.
        """
        if not os.path.exists(input_audio):
            raise FileNotFoundError(f"Le fichier {input_audio} n'existe pas.")
        waveform, sr = torchaudio.load(input_audio)

        # Fall back to the file's own layout if the model does not advertise
        # a sample rate / channel count.
        target_sr = getattr(self.model, "samplerate", sr)
        target_channels = getattr(self.model, "audio_channels", waveform.shape[0])
        if sr != target_sr:
            waveform = torchaudio.functional.resample(waveform, sr, target_sr)
            sr = target_sr
        waveform = self._match_channels(waveform, target_channels)

        # apply_model() works on batches, hence the extra leading axis.
        self.waveform = waveform.unsqueeze(0)
        if torch.cuda.is_available():
            self.waveform = self.waveform.cuda()
        self.sr = sr

    def separate_sources(self):
        """Run Demucs on the loaded audio and store the stems in ``sources``.

        Raises:
            RuntimeError: If no audio has been loaded yet.
        """
        if self.waveform is None:
            raise RuntimeError("Aucun audio chargé : appelez load_audio() d'abord.")
        with torch.no_grad():
            self.sources = apply_model(self.model, self.waveform, split=True)

    def extract_vocal_source(self, source_index=3):
        """Pick one separated stem and downmix it to mono.

        Args:
            source_index: Index of the stem along the source axis. The
                default of 3 is presumably the position of ``vocals`` in the
                standard four-stem Demucs models; check ``model.sources``
                when using another model.

        Raises:
            RuntimeError: If the sources have not been separated yet.
        """
        if self.sources is None:
            raise RuntimeError(
                "Aucune source séparée : appelez separate_sources() d'abord."
            )
        vocal_source = self.sources[0, source_index, :, :]
        # Average the channels to obtain a mono signal.
        self.vocal_source = torch.mean(vocal_source, dim=0)

    def save_isolated_voice(self, output_path=None):
        """Write the isolated voice to disk and return the path used.

        Args:
            output_path: Destination file. Defaults to
                ``data/isolated_voice_demucs.wav`` under ``CURRENT_DIRECTORY``.

        Returns:
            The path the audio was written to.

        Raises:
            RuntimeError: If no vocal source has been extracted yet.
        """
        if self.vocal_source is None:
            raise RuntimeError(
                "Aucune voix extraite : appelez extract_vocal_source() d'abord."
            )
        if output_path is None:
            output_path = os.path.join(CURRENT_DIRECTORY, "data", "isolated_voice_demucs.wav")
        # A bare filename has an empty dirname, which os.makedirs() rejects.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        vocal_numpy = self.vocal_source.detach().cpu().numpy()
        sf.write(output_path, vocal_numpy, self.sr)
        return output_path

    def plot_waveforms(self, original_audio_path, processed_audio_path):
        """Plot the waveforms of the original and the denoised audio.

        Each file is drawn with its own sample rate so the time axes stay
        correct even when the two files were stored at different rates.
        """
        original_audio, original_sr = librosa.load(original_audio_path, sr=None)
        processed_audio, processed_sr = librosa.load(processed_audio_path, sr=None)

        plt.figure(figsize=(14, 6))
        plt.subplot(2, 1, 1)
        librosa.display.waveshow(original_audio, sr=original_sr)
        plt.title('Forme d\'onde originale')
        plt.subplot(2, 1, 2)
        librosa.display.waveshow(processed_audio, sr=processed_sr)
        plt.title('Forme d\'onde après débruitage')
        plt.tight_layout()
        plt.show()

    def plot_spectrograms(self, original_audio_path, processed_audio_path):
        """Plot log-frequency dB spectrograms of the original and denoised audio.

        Each file is drawn with its own sample rate so the time and frequency
        axes stay correct even when the two files differ in rate.
        """
        original_audio, original_sr = librosa.load(original_audio_path, sr=None)
        processed_audio, processed_sr = librosa.load(processed_audio_path, sr=None)

        plt.figure(figsize=(14, 6))
        plt.subplot(2, 1, 1)
        D_original = librosa.amplitude_to_db(np.abs(librosa.stft(original_audio)), ref=np.max)
        librosa.display.specshow(D_original, sr=original_sr, x_axis='time', y_axis='log')
        plt.title('Spectrogramme original')
        plt.colorbar(format='%+2.0f dB')
        plt.subplot(2, 1, 2)
        D_processed = librosa.amplitude_to_db(np.abs(librosa.stft(processed_audio)), ref=np.max)
        librosa.display.specshow(D_processed, sr=processed_sr, x_axis='time', y_axis='log')
        plt.title('Spectrogramme après débruitage')
        plt.colorbar(format='%+2.0f dB')
        plt.tight_layout()
        plt.show()

    def compute_kpi(self, original_audio_path, processed_audio_path):
        """Print energy figures and an SNR-style ratio for the denoising.

        The printed "SNR" is the power of the processed signal divided by the
        power of the residual (original minus processed), in dB. It is an
        estimate that treats everything removed as noise, not a measured
        improvement against a clean reference.
        """
        original_audio, sr = librosa.load(original_audio_path, sr=None)
        # Load the processed file at the original's rate so both signals are
        # on the same time grid before being subtracted.
        processed_audio, _ = librosa.load(processed_audio_path, sr=sr)

        # Signal energies over each full recording.
        original_energy = np.sum(original_audio**2)
        processed_energy = np.sum(processed_audio**2)

        # Resampling can leave the two arrays a few samples apart in length;
        # compare only the overlapping part.
        overlap = min(len(original_audio), len(processed_audio))
        aligned_processed = processed_audio[:overlap]
        noise = original_audio[:overlap] - aligned_processed
        signal_power = np.sum(aligned_processed**2)
        noise_power = np.sum(noise**2)
        snr_improvement = self._snr_db(signal_power, noise_power)

        print(f"Amélioration du SNR : {snr_improvement:.2f} dB")
        print(f"Énergie de l'audio original : {original_energy:.2f}")
        print(f"Énergie de l'audio débruité : {processed_energy:.2f}")

    def process_audio(self, input_audio):
        """Run the whole denoising pipeline on ``input_audio``.

        Loads the file, separates the sources, extracts and saves the voice
        to the default output path, then shows the plots and prints the KPIs.
        """
        self.load_audio(input_audio)
        self.separate_sources()
        self.extract_vocal_source()
        output_path = self.save_isolated_voice()
        self.plot_waveforms(input_audio, output_path)
        self.plot_spectrograms(input_audio, output_path)
        self.compute_kpi(input_audio, output_path)


if __name__ == "__main__":
    # Script entry point: denoise data/radiobruite.wav, looked up under the
    # working directory captured in CURRENT_DIRECTORY.
    input_audio_path = os.path.join(CURRENT_DIRECTORY, "data", "radiobruite.wav")
    denoiser = AudioDenoiser()
    # Runs the full pipeline: load, separate, extract the voice, save it,
    # then plot and print the KPIs.
    denoiser.process_audio(input_audio_path)
