**DSP DEMIXING**


In [None]:
# Install the required packages
%pip install -r requirements.txt

Let's start by loading the dataset. The folder "musdb18hq_trimmed" contains a 30-second excerpt of every track. We noticed that not all the stems were non-silent in the first 30 seconds of each track, so we trimmed the dataset to a 30-second window in which every stem is non-silent. This makes the per-stem evaluation more accurate.
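The trimming script itself is not part of this notebook. The hypothetical helper below is a hedged sketch of the idea, not the authors' code. It splits each stem into 1-second blocks and marks a block as active when its RMS exceeds an assumed threshold of -40 dBFS. It then returns the first 30-second window in which every stem is active.

```python
import numpy as np

def find_non_silent_window(stems, sr, window_sec=30.0, block_sec=1.0, threshold_db=-40.0):
    """Return the start sample of the first window where every stem is non-silent, else None.

    Hypothetical helper: a block counts as non-silent for a stem when the stem's RMS
    in that block, in dB relative to a full-scale amplitude of 1.0, exceeds threshold_db.
    `stems` is a list of mono 1-D arrays sampled at `sr`.
    """
    block = int(block_sec * sr)
    n_needed = int(round(window_sec / block_sec))
    n_blocks = min(len(s) for s in stems) // block
    active = np.ones(n_blocks, dtype=bool)
    for s in stems:
        blocks = np.asarray(s[: n_blocks * block], dtype=float).reshape(n_blocks, block)
        rms_db = 20 * np.log10(np.sqrt(np.mean(blocks ** 2, axis=1)) + 1e-12)
        active &= rms_db > threshold_db   # a block must be active in every stem
    run = 0  # length of the current run of consecutive active blocks
    for i, ok in enumerate(active):
        run = run + 1 if ok else 0
        if run == n_needed:
            return (i - n_needed + 1) * block
    return None

# Toy check at sr=100: the second stem is silent for its first 10 s
sr_demo = 100
stems_demo = [np.full(100 * sr_demo, 0.5),
              np.concatenate([np.zeros(10 * sr_demo), np.full(90 * sr_demo, 0.5)])]
print(find_non_silent_window(stems_demo, sr_demo))  # 1000 (i.e. 10 s)
```

On real MUSDB18-HQ stems you would first average the two channels. The threshold and block size are free choices.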

In [None]:
import tqdm, os, torchaudio

def load_dataset(dataset_dir="/Users/alessandromanattini/Desktop/MAE/SELECTED TOPIC/PROJECT STMAE/musdb18hq_trimmed"):
    """
    Load the dataset from the musdb18hq_trimmed folder.
    Each subfolder in the dataset corresponds to a song.
    Each song contains multiple stems (e.g., mixture, drums, bass, etc.).
    Returns:
        dataset_dict (dict): A dictionary where keys are track folders and values are
            dictionaries mapping stem names to waveform tensors of shape (channels, samples).
    """
    dataset_dict = {}

    # Sort the folder names: os.listdir returns them in arbitrary order, and later
    # cells pick tracks by index (e.g. mixture_files[49])
    for track_folder in tqdm.tqdm(sorted(os.listdir(dataset_dir))):
        track_path = os.path.join(dataset_dir, track_folder)
        if not os.path.isdir(track_path):
            continue

        # Prepare a sub-dictionary for this song
        stems_dict = {}
        
        for stem_name in ["mixture", "drums", "bass", "vocals", "other", "new_mixture"]:
            file_path = os.path.abspath(os.path.join(track_path, f"{stem_name}.wav"))
            
            if not os.path.isfile(file_path):
                print(f"Warning: file not found {file_path}")
                continue

            # Load the full audio at its native sample rate
            waveform, sr = torchaudio.load(file_path)

            stems_dict[stem_name] = waveform
            
        dataset_dict[track_folder] = stems_dict
        
    return dataset_dict

In [None]:
# Load the dataset
dataset_dict = load_dataset()  

print("Number of keys in dataset_dict:", len(dataset_dict))

# Check the first track folder and its contents
first_track_folder = list(dataset_dict.keys())[0]
print("First track folder:", first_track_folder)
print("Contents of the first track folder:")
for stem_name in dataset_dict[first_track_folder].keys():
    print(f" - {stem_name}: {dataset_dict[first_track_folder][stem_name].shape}")

Let's collect the paths of all the `new_mixture.wav` files in a list ***mixture_files***.


In [None]:
# Collect the paths of all new_mixture.wav files
mixture_files = []
for track_folder in dataset_dict.keys():
    new_mixture_path = os.path.join("/Users/alessandromanattini/Desktop/MAE/SELECTED TOPIC/PROJECT STMAE/musdb18hq_trimmed", track_folder, "new_mixture.wav")
    if os.path.isfile(new_mixture_path):
        mixture_files.append(new_mixture_path)
    else:
        print(f"Warning: file not found {new_mixture_path}")


Define the STFT parameters and compute the magnitude and phase of every mixture:

In [None]:
import librosa
# STFT parameters
n_fft = 2048
hop_length = 512
win = 'hann'

# Initialize lists to store STFT results
S_full_list = []
phase_list = []

# Loop through each mixture file and compute STFT
for mixture_path in mixture_files:
    # Load the audio file (librosa downmixes to mono by default)
    audio, sr = librosa.load(mixture_path, sr=None)
    # Compute the STFT
    D = librosa.stft(audio, n_fft=n_fft, hop_length=hop_length, window=win)
    # Split into magnitude and (unit-magnitude, complex) phase
    mag, phase = librosa.magphase(D)
    
    S_full_list.append(mag)
    phase_list.append(phase)

In [None]:
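# Import required modules for SDR computation
import numpy as np

try:
    from mir_eval.separation import bss_eval_sources
except ImportError:
    print("Installing mir_eval...")
    %pip install mir_eval
    from mir_eval.separation import bss_eval_sources

def compute_sdr(original, separated, sr):
    """
    Compute the Signal-to-Distortion Ratio (SDR, in dB) of a separated signal
    against its ground-truth reference (e.g. the original drums stem).
    Torch tensors and (channels, samples) arrays are converted to mono numpy,
    and both signals are trimmed to their common length, because
    bss_eval_sources requires equal shapes. `sr` is not used by bss_eval_sources.
    """
    def to_mono(x):
        x = x.numpy() if hasattr(x, 'numpy') else np.asarray(x)
        return x.mean(axis=0) if x.ndim == 2 else x

    original, separated = to_mono(original), to_mono(separated)
    n = min(len(original), len(separated))
    sdr, _, _, _ = bss_eval_sources(
        original[np.newaxis, :n],
        separated[np.newaxis, :n],
        compute_permutation=False  # a single source: nothing to permute
    )
    return sdr[0]  # SDR of the (only) source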
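To build intuition for the number `compute_sdr` returns, here is the plain (simplified) SDR definition, 10 log10(||s||^2 / ||s - s_hat||^2). Treat it as a sketch only. BSS Eval, which `bss_eval_sources` implements, additionally lets the reference pass through a short time-invariant filter before measuring the error, so its values differ. In both cases the reference must be the ground-truth stem being evaluated, not the mixture. The name `simple_sdr` is hypothetical.

```python
import numpy as np

def simple_sdr(reference, estimate, eps=1e-12):
    """Plain SDR in dB: 10 * log10(||s||^2 / ||s - s_hat||^2).

    Simplified definition: unlike BSS Eval it allows no gain or filtering of
    the reference, so a mere scaling error also counts as distortion.
    """
    reference = np.asarray(reference, dtype=float)
    estimate = np.asarray(estimate, dtype=float)
    n = min(len(reference), len(estimate))       # compare over the common length
    reference, estimate = reference[:n], estimate[:n]
    signal_energy = np.sum(reference ** 2)
    error_energy = np.sum((reference - estimate) ** 2)
    return 10 * np.log10((signal_energy + eps) / (error_energy + eps))

# ||s||^2 = 9 + 16 = 25 and ||s - s_hat||^2 = 1, so SDR = 10 * log10(25)
print(round(simple_sdr([3.0, 4.0], [3.0, 3.0]), 2))  # 13.98
```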

**DRUMS EXTRACTION** (using HPSS):

- STFT Input: The function receives the magnitude (mix_mag) and phase (mix_phase) of the mixture’s STFT.

- HPSS Decomposition: It utilizes the **Harmonic-Percussive Source Separation (HPSS)** algorithm to split the mixture’s magnitude into two components:
    1) A ***harmonic component*** that captures the tonal content.
    2) A ***percussive component*** that emphasizes transient, drum-like features.

- Drums Reconstruction: The function then reconstructs the time-domain drums signal by combining the percussive component with the original phase information using the iSTFT.

- Output: The result is a time-domain signal (drums) that represents the extracted percussive (drum) elements from the mixture.

**VOCALS EXTRACTION** (using REPET-SIM):

- STFT Input: Similar to the drums extraction, this function takes the mixture’s magnitude (mix_mag) and phase (mix_phase) as input.

- NN Filter Processing: The function applies a ***nearest-neighbor (nn) filter*** to the magnitude. This filter:
    - Replaces each frame with the per-frequency median of its most similar frames, which gives a smooth estimate of the repeating background.
    - Measures similarity with the cosine metric, and only links frames that are at least 2.0 seconds apart (the `width` parameter, converted from seconds to frames). A frame is therefore matched with repetitions elsewhere in the track rather than with its immediate neighbors.

- Filter Application: The filtered background (S_filter) is then constrained by taking the element-wise minimum with the original magnitude, so that the background estimate never exceeds the mixture in any time-frequency bin.

- Soft Mask Creation: A soft mask is computed using librosa.util.softmask that emphasizes differences between the original magnitude and the filtered background. This mask is tuned (with a factor of 100 and power 2) to highlight vocal components.

- Vocals Reconstruction: The function applies this mask to the original magnitude to produce a modified magnitude focused on the vocal content. It then reconstructs the time-domain vocal signal by combining this modified magnitude with the original phase via the iSTFT.

- Output: The final output is a time-domain signal representing the extracted vocals from the mixture.

**BASS EXTRACTION**

For the bass, we simply apply a low-pass filter to keep only the low-frequency components.
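One caveat on this step: a causal IIR filter such as `scipy.signal.lfilter`, used in `separate_sources_v2`, slightly delays the bass. Subtracting it from the mixture (as done for the other stem) then does not cancel it cleanly. The sketch below shows an alternative that is not what this notebook uses. It applies a Butterworth low-pass forward and backward with `scipy.signal.sosfiltfilt`, which has zero phase. The helper name `lowpass_bass` is hypothetical.

```python
import numpy as np
import scipy.signal

def lowpass_bass(y, sr, cutoff=200.0, order=4):
    """Zero-phase Butterworth low-pass: the filter runs forward then backward,
    so there is no phase delay, and the magnitude response is applied twice
    (about -6 dB at the cutoff instead of -3 dB)."""
    sos = scipy.signal.butter(order, cutoff, btype='low', fs=sr, output='sos')
    return scipy.signal.sosfiltfilt(sos, y)

# Toy check: keep a 60 Hz "bass" note, reject a 2 kHz tone
sr_demo = 44100
t = np.arange(sr_demo) / sr_demo
bass_note = np.sin(2 * np.pi * 60.0 * t)
y_low = lowpass_bass(bass_note + np.sin(2 * np.pi * 2000.0 * t), sr_demo)
max_err = np.max(np.abs(y_low - bass_note)[1000:-1000])  # ignore edge transients
print(f"max deviation from the 60 Hz note: {max_err:.2e}")
```

The second-order-sections form (`output='sos'`) is numerically safer than `(b, a)` coefficients for low cutoffs relative to the sample rate.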

**OTHER EXTRACTION**
Finally, we use a subtractive approach to retrieve the other component. We define:

***y_other = y - y_drums - y_vocals - y_bass***
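Because of this definition, the four estimates always add back up to the mixture exactly. Any error in the drums, vocals or bass estimate therefore shows up, with opposite sign, in the other stem. Here is a minimal check with random arrays standing in for the stems. The helper `subtractive_other` is hypothetical.

```python
import numpy as np

def subtractive_other(y, y_drums, y_vocals, y_bass):
    """'Other' stem as the residual of the mixture after removing the three estimates."""
    return y - y_drums - y_vocals - y_bass

rng = np.random.default_rng(0)
y_mix_demo = rng.standard_normal(1000)
est_drums, est_vocals, est_bass = (rng.standard_normal(1000) for _ in range(3))
est_other = subtractive_other(y_mix_demo, est_drums, est_vocals, est_bass)
# By construction the four stems add back up to the mixture
print(np.allclose(est_drums + est_vocals + est_bass + est_other, y_mix_demo))  # True
```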

Now we run the separation on a single track of the dataset (index 49 of mixture_files) to inspect the results.

Single Track

In [None]:
import IPython.display as ipd
# Pick one mixture (index 49 of mixture_files) and listen to it
mixture_path = mixture_files[49]
mixture, sr = librosa.load(mixture_path, sr=None)

# Play the mixture
ipd.display(ipd.Audio(mixture, rate=sr))

# DSP-Based Source Separation

## Overview
A classical digital signal processing approach that combines multiple techniques for audio source separation without machine learning.

## Function Overview: `separate_sources_v2`

### 1. **Initial HPSS Decomposition**
- Uses **Harmonic-Percussive Source Separation** to split the mixture into:
  - **Harmonic component**: Captures tonal/pitched content (vocals, bass, other instruments)
  - **Percussive component**: Emphasizes transient, drum-like features

### 2. **Drums Extraction (Enhanced HPSS)**
- Starts with percussive component from HPSS
- **Vocal frequency cleaning**: Scales the 300-3000 Hz bins of the percussive magnitude by an attenuation factor of 0.4 (about -8 dB) to reduce vocal bleed in the drum track
- Reconstructs cleaned drum signal using original phase
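The bin-by-bin loop used for this cleaning step can also be written with a boolean frequency mask. This sketch assumes a magnitude array of shape (bins, frames). The helper name `attenuate_band` is hypothetical. `np.fft.rfftfreq` yields the same bin centres as `librosa.fft_frequencies`. The code also confirms that a factor of 0.4 is roughly -8 dB.

```python
import numpy as np

def attenuate_band(mag, sr, n_fft, f_min=300.0, f_max=3000.0, factor=0.4):
    """Return a copy of `mag` (bins x frames) with the rows in [f_min, f_max] Hz scaled by `factor`."""
    freqs = np.fft.rfftfreq(n_fft, d=1.0 / sr)  # bin centres, same as librosa.fft_frequencies
    band = (freqs >= f_min) & (freqs <= f_max)
    cleaned = mag.copy()
    cleaned[band, :] *= factor
    return cleaned, band

mag_demo = np.ones((1025, 10))  # 1 + 2048 // 2 bins, 10 frames
cleaned_demo, band_demo = attenuate_band(mag_demo, sr=44100, n_fft=2048)
print(int(band_demo.sum()), "bins attenuated,", f"{20 * np.log10(0.4):.1f} dB")
```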

### 3. **Bass Extraction (Low-pass Filtering)**
- Applies Butterworth low-pass filter to harmonic component
- **Cutoff frequency**: 200 Hz (captures fundamental bass frequencies)
- **Filter order**: 4th order, giving a rolloff of about 24 dB per octave above the cutoff
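The rolloff of the 4th-order filter can be checked numerically. This sketch evaluates the digital filter's response with `scipy.signal.freqz`, using the hypothetical helper `butter_gain_db` and assuming a 44.1 kHz sample rate. Expect about -3 dB at the 200 Hz cutoff and roughly 24 dB more attenuation per octave above it.

```python
import numpy as np
import scipy.signal

def butter_gain_db(freqs_hz, sr=44100, cutoff=200.0, order=4):
    """Magnitude response (dB) of a digital Butterworth low-pass at the given frequencies."""
    b, a = scipy.signal.butter(order, cutoff, btype='low', fs=sr)
    _, h = scipy.signal.freqz(b, a, worN=np.asarray(freqs_hz, dtype=float), fs=sr)
    return 20 * np.log10(np.abs(h))

gains_demo = butter_gain_db([100, 200, 400, 800])
for f, g in zip([100, 200, 400, 800], gains_demo):
    print(f"{f:4d} Hz: {g:6.1f} dB")
```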

### 4. **Vocals Extraction (REPET-SIM)**
- Uses **nearest-neighbor filtering** on harmonic component
- **Repeating pattern removal**: Identifies and removes instrumental patterns
- **Cosine similarity**: Compares spectral frames, matching each frame only with frames at least 2 seconds away from it
- **Vocal isolation**: Extracts non-repetitive content (vocals) from repetitive background

### 5. **Other Component (Subtractive)**
- Calculates residual: `y_other = y_mixture - y_drums - y_bass - y_vocals`
- Captures remaining instruments and ambient sounds

## Key Features
- **No training required**: Uses classical DSP techniques
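- **Offline processing**: The whole track is processed at once, and the nearest-neighbor filter compares each frame with frames across the entire track, so the method as written is not real-time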
- **Frequency-aware**: Each method targets appropriate frequency ranges
- **Artifact reduction**: Includes vocal cleaning step for drums
- **Visualization**: Includes a helper that plots the spectrogram of each separated component (the plotting calls are commented out by default)

In [None]:
import librosa
import numpy as np
import scipy.signal

def separate_sources_v2(  # Function renamed for clarity
    mixture_path,
    sr=None,
    hpss_margin=1.0,  # librosa's default; values > 1 give a stricter split (and leave a residual)
    hpss_power=2.0,
    # Parameters for cleaning vocal residue out of the drums
    drum_clean_vocal_freq_min=300.0,  # Hz, start of the vocal range to attenuate
    drum_clean_vocal_freq_max=3000.0,  # Hz, end of the vocal range to attenuate
    drum_clean_vocal_atten_factor=0.4,  # Attenuation factor (0.0 = mute, 1.0 = no change)
    # Bass parameters
    bass_cutoff=200.0,
    bass_order=4,
    # Main vocal separation parameters
    nn_width_vocals_sec=2.0  # Minimum distance (s) between frames matched by nn_filter; increased for better vocal separation
):
    """
    Load an audio file and separate it into drums, bass, vocals and other,
    with a dedicated step that removes vocal residue from the drums.
    Returns a dictionary of mono numpy arrays, plus the sample rate under 'sr'.
    """
    # 1) Load the audio and convert it to mono
    y, sr_loaded = librosa.load(mixture_path, sr=sr)
    if sr is None:  # If sr was not specified, use the file's native rate
        sr = sr_loaded
        
    if y.ndim > 1:
        y = librosa.to_mono(y)

    # 2) HPSS for the initial extraction of the drums and the harmonic component
    # Compute the STFT of the original mix
    D_mixture = librosa.stft(y)
    n_fft = (D_mixture.shape[0] - 1) * 2  # Infer n_fft from the spectrogram
    
    D_harmonic_mixture, D_percussive_mixture = librosa.decompose.hpss(
        D_mixture, 
        margin=hpss_margin, 
        power=hpss_power
    )
    
    # Overall harmonic component (used below for bass and vocals)
    y_harmonic_overall = librosa.istft(D_harmonic_mixture, length=len(y))

    # --- 3) Clean vocal residue out of the drums ---
    # Start from D_percussive_mixture (the drum spectrogram from HPSS)
    D_perc_mag, D_perc_phase = librosa.magphase(D_percussive_mixture)
    
    # Copy the magnitude so it can be modified
    D_perc_mag_cleaned = np.copy(D_perc_mag)
    
    # Frequencies corresponding to the STFT bins
    frequencies = librosa.fft_frequencies(sr=sr, n_fft=n_fft)
    
    # Attenuate the bins inside the defined vocal frequency range
    for i, freq_bin in enumerate(frequencies):
        if drum_clean_vocal_freq_min <= freq_bin <= drum_clean_vocal_freq_max:
            D_perc_mag_cleaned[i, :] *= drum_clean_vocal_atten_factor
            
    # Rebuild the cleaned drum spectrogram
    D_drums_cleaned_stft = D_perc_mag_cleaned * D_perc_phase
    y_drums = librosa.istft(D_drums_cleaned_stft, length=len(y))

    # --- 4) Bass extraction from the overall harmonic component ---
    # Apply a low-pass filter to y_harmonic_overall
    nyquist = 0.5 * sr
    # Make sure bass_cutoff is below the Nyquist frequency
    actual_bass_cutoff = min(bass_cutoff, nyquist - 1)  # -1 Hz as a safety margin
    if actual_bass_cutoff <= 0:
        print(f"Warning: bass_cutoff ({bass_cutoff} Hz) is not valid with sr={sr} Hz. The bass stem will be left silent.")
        y_bass = np.zeros_like(y_harmonic_overall)  # Fall back to a silent bass stem
    else:
        b, a = scipy.signal.butter(bass_order, actual_bass_cutoff / nyquist, btype='low', analog=False)
        # Note: lfilter is causal and slightly delays the bass, so the subtraction in
        # step 6 does not cancel it exactly (scipy.signal.filtfilt would be zero-phase)
        y_bass = scipy.signal.lfilter(b, a, y_harmonic_overall)

    # --- 5) Vocal extraction from the overall harmonic component ---
    # Note: for a better separation, the estimated bass could be subtracted from
    # y_harmonic_overall before looking for vocals; for simplicity we use y_harmonic_overall.
    # y_harmonic_minus_bass = y_harmonic_overall - y_bass  # Optional, might help
    
    S_harmonic_overall, phase_harmonic_overall = librosa.magphase(librosa.stft(y_harmonic_overall))  # or of y_harmonic_minus_bass
    
    width_vocals_frames = int(librosa.time_to_frames(nn_width_vocals_sec, sr=sr, n_fft=n_fft))
    
    S_instrumental_repeating = librosa.decompose.nn_filter(
        S_harmonic_overall, 
        aggregate=np.median, 
        metric='cosine', 
        width=width_vocals_frames
    )
    
    # The vocals are what is NOT repetitive in the harmonic component
    S_vocals_mag = np.clip(S_harmonic_overall - S_instrumental_repeating, 0, None)
    
    y_vocals = librosa.istft(S_vocals_mag * phase_harmonic_overall, length=len(y))

    # --- 6) Compute the "other" residual ---
    # Subtract the estimated components from the original mix
    y_other = y - y_drums - y_bass - y_vocals

    # Plot the spectrograms for visualization
    import matplotlib.pyplot as plt
    import librosa.display  # provides specshow
    def plot_spectrogram(y, sr, title):
        D = librosa.amplitude_to_db(np.abs(librosa.stft(y)), ref=np.max)
        plt.figure(figsize=(10, 4))
        librosa.display.specshow(D, sr=sr, x_axis='time', y_axis='log')
        plt.colorbar(format='%+2.0f dB')
        plt.title(title)
        plt.tight_layout()
        plt.show()
    # Plot spectrograms for each component
    # plot_spectrogram(y_drums, sr, "Drums Spectrogram")
    # plot_spectrogram(y_bass, sr, "Bass Spectrogram")
    # plot_spectrogram(y_vocals, sr, "Vocals Spectrogram")
    # plot_spectrogram(y_other, sr, "Other Spectrogram")
    # # Plot the original mixture for comparison
    # plot_spectrogram(y, sr, "Original Mixture Spectrogram")

    return {
        'drums': y_drums,
        'bass': y_bass,
        'vocals': y_vocals,
        'other': y_other,
        'sr': sr
    }

In [None]:
sources = separate_sources_v2(mixture_path)
drums  = sources['drums']
bass   = sources['bass']
vocals = sources['vocals']
other  = sources['other']

# play mixture
print("Mixture:")
display(ipd.Audio(mixture, rate=sr))

# Play the separated components
print("Drums Component:")
display(ipd.Audio(drums, rate=sources['sr']))
print("Bass Component:")
display(ipd.Audio(bass, rate=sources['sr']))
print("Vocals Component:")
display(ipd.Audio(vocals, rate=sources['sr']))
print("Other Component:")
display(ipd.Audio(other, rate=sources['sr']))

# Plot the spectrogram of the original components and the separated components
import matplotlib.pyplot as plt
import librosa.display  # provides specshow

# Get the track name from the mixture path
track_name = os.path.basename(os.path.dirname(mixture_path))  # Extract the track folder name from the path

# Load original stems from dataset_dict
original_stems = dataset_dict[track_name]

def plot_spectrogram_comparison(original_audio, separated_audio, stem_name, sr, 
                              original_title="Original", separated_title="Separated"):
    """
    Plot side-by-side spectrograms for comparison
    """
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    
    # Convert tensors to numpy if needed and handle dimensions
    if hasattr(original_audio, 'numpy'):
        original_np = original_audio.numpy()
    else:
        original_np = original_audio
    
    if original_np.ndim == 2:
        original_np = np.mean(original_np, axis=0)  # Convert to mono
    
    if separated_audio.ndim == 2:
        separated_np = np.mean(separated_audio, axis=0)  # Convert to mono
    else:
        separated_np = separated_audio
    
    # Ensure same length for comparison
    min_len = min(len(original_np), len(separated_np))
    original_np = original_np[:min_len]
    separated_np = separated_np[:min_len]
    
    # Compute spectrograms
    D_original = librosa.stft(original_np)
    D_separated = librosa.stft(separated_np)
    
    # Convert to dB scale
    S_original_db = librosa.amplitude_to_db(np.abs(D_original), ref=np.max)
    S_separated_db = librosa.amplitude_to_db(np.abs(D_separated), ref=np.max)
    
    # Plot original spectrogram
    img1 = librosa.display.specshow(S_original_db, sr=sr, x_axis='time', y_axis='hz', 
                                    ax=axes[0], cmap='viridis')
    axes[0].set_title(f'{original_title} {stem_name.capitalize()}')
    axes[0].set_ylabel('Frequency (Hz)')
    plt.colorbar(img1, ax=axes[0], format='%+2.0f dB')
    
    # Plot separated spectrogram
    img2 = librosa.display.specshow(S_separated_db, sr=sr, x_axis='time', y_axis='hz', 
                                    ax=axes[1], cmap='viridis')
    axes[1].set_title(f'{separated_title} {stem_name.capitalize()}')
    axes[1].set_ylabel('Frequency (Hz)')
    plt.colorbar(img2, ax=axes[1], format='%+2.0f dB')
    
    plt.tight_layout()
    plt.show()
    
    # Compute and display similarity metrics
    mse = np.mean((S_original_db - S_separated_db)**2)
    correlation = np.corrcoef(S_original_db.flatten(), S_separated_db.flatten())[0, 1]
    
    print(f"{stem_name.capitalize()} Spectrogram Analysis:")
    print(f"  MSE: {mse:.4f}")
    print(f"  Correlation: {correlation:.4f}")
    print()

# Create comparison plots for each stem
stems_info = [
    ('drums', drums, 'Drums'),
    ('bass', bass, 'Bass'), 
    ('vocals', vocals, 'Vocals'),
    ('other', other, 'Other')
]

print("=== SPECTROGRAM COMPARISONS ===\n")

for stem_key, separated_audio, stem_display_name in stems_info:
    if stem_key in original_stems:
        print(f"Comparing {stem_display_name}...")
        plot_spectrogram_comparison(
            original_stems[stem_key], 
            separated_audio, 
            stem_display_name.lower(),
            sources['sr'],
            "Ground Truth",
            "DSP Separated"
        )
    else:
        print(f"Warning: {stem_key} not found in original stems")

# Plot mixture spectrogram for reference
print("=== MIXTURE SPECTROGRAM ===")
plt.figure(figsize=(12, 6))
D_mixture = librosa.stft(mixture)
S_mixture_db = librosa.amplitude_to_db(np.abs(D_mixture), ref=np.max)
librosa.display.specshow(S_mixture_db, sr=sr, x_axis='time', y_axis='hz', cmap='viridis')
plt.colorbar(format='%+2.0f dB')
plt.title('Original Mixture Spectrogram')
plt.ylabel('Frequency (Hz)')
plt.xlabel('Time (s)')
plt.tight_layout()
plt.show()

# Calculate SDR values
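# Compare each estimate with its ground-truth stem (not with the mixture).
# Stems are loaded as (channels, samples) tensors: downmix them to mono numpy and
# trim both signals to a common length, since bss_eval_sources needs equal shapes.
sdr_dsp = {}
for stem_key in ['drums', 'bass', 'vocals', 'other']:
    reference = original_stems[stem_key]
    reference = reference.numpy() if hasattr(reference, 'numpy') else np.asarray(reference)
    reference = reference.mean(axis=0) if reference.ndim == 2 else reference
    n = min(len(reference), len(sources[stem_key]))
    sdr_dsp[stem_key] = compute_sdr(reference[:n], sources[stem_key][:n], sources['sr'])
sdr_drums, sdr_bass, sdr_vocals, sdr_other = (
    sdr_dsp[k] for k in ['drums', 'bass', 'vocals', 'other'])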

print("=== SDR RESULTS ===")
print(f"SDR for Drums: {sdr_drums:.2f} dB")
print(f"SDR for Bass: {sdr_bass:.2f} dB")
print(f"SDR for Vocals: {sdr_vocals:.2f} dB")
print(f"SDR for Other: {sdr_other:.2f} dB")

# Create a summary comparison plot
fig, axes = plt.subplots(2, 4, figsize=(20, 10))

# Row 1: Original stems
# Row 2: Separated stems
for i, (stem_key, separated_audio, stem_display_name) in enumerate(stems_info):
    if stem_key in original_stems:
        # Original stem (top row)
        original_audio = original_stems[stem_key]
        if hasattr(original_audio, 'numpy'):
            original_np = original_audio.numpy()
        else:
            original_np = original_audio
        
        if original_np.ndim == 2:
            original_np = np.mean(original_np, axis=0)
            
        D_orig = librosa.stft(original_np)
        S_orig_db = librosa.amplitude_to_db(np.abs(D_orig), ref=np.max)
        
        img1 = librosa.display.specshow(S_orig_db, sr=sources['sr'], x_axis='time', y_axis='hz', 
                                        ax=axes[0, i], cmap='viridis')
        axes[0, i].set_title(f'Original {stem_display_name}')
        
        # Separated stem (bottom row)
        if separated_audio.ndim == 2:
            separated_np = np.mean(separated_audio, axis=0)
        else:
            separated_np = separated_audio
            
        min_len = min(len(original_np), len(separated_np))
        separated_np = separated_np[:min_len]
        
        D_sep = librosa.stft(separated_np)
        S_sep_db = librosa.amplitude_to_db(np.abs(D_sep), ref=np.max)
        
        img2 = librosa.display.specshow(S_sep_db, sr=sources['sr'], x_axis='time', y_axis='hz', 
                                        ax=axes[1, i], cmap='viridis')
        axes[1, i].set_title(f'Separated {stem_display_name}')
        
        # Add colorbars
        plt.colorbar(img1, ax=axes[0, i], format='%+2.0f dB')
        plt.colorbar(img2, ax=axes[1, i], format='%+2.0f dB')

plt.suptitle('Spectrogram Comparison: Original vs Separated Stems', fontsize=16)
plt.tight_layout()
plt.show()

# Non-Negative Matrix Factorization (NMF) for Source Separation

## What is Non-Negative Matrix Factorization?

**Non-Negative Matrix Factorization (NMF)** is a dimensionality reduction technique that decomposes a non-negative matrix **V** into two non-negative matrices **W** and **H**:

**V ≈ W × H**

Where:
- **V**: Original magnitude spectrogram (frequency × time)
- **W**: Basis matrix containing spectral templates (frequency × components)  
- **H**: Activation matrix showing when each template is active (components × time)

In audio source separation, each column of **W** represents a spectral pattern (e.g., a drum timbre or a vocal formant), and the corresponding row of **H** shows when and how strongly that pattern is active over time.
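Here is a toy example of the factorization with scikit-learn's `NMF`, the same class the notebook uses below. A synthetic non-negative "spectrogram" is built from two known templates and then factorized back. The sizes (64 bins, 100 frames, 2 components) and the helper name `nmf_demo` are illustrative assumptions.

```python
import numpy as np
from sklearn.decomposition import NMF

def nmf_demo(n_bins=64, n_frames=100, n_components=2, seed=0):
    """Factorize a synthetic rank-2 non-negative matrix V (bins x frames) as W @ H."""
    rng = np.random.default_rng(seed)
    W_true = rng.random((n_bins, n_components))    # spectral templates
    H_true = rng.random((n_components, n_frames))  # their activations over time
    V = W_true @ H_true

    model = NMF(n_components=n_components, init='nndsvda', max_iter=1000, random_state=seed)
    W = model.fit_transform(V)   # frequency x components
    H = model.components_        # components x time
    rel_err = np.linalg.norm(V - W @ H) / np.linalg.norm(V)
    return W, H, rel_err

W_demo, H_demo, rel_err_demo = nmf_demo()
print(W_demo.shape, H_demo.shape, f"relative error: {rel_err_demo:.4f}")
```

NMF is only unique up to scaling and permutation of the components, so the check compares the reconstruction W @ H with V, not W with the true templates.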

## Function Overview: `separate_sources_v3`

### 1. **Audio Loading & STFT**
- Loads audio file and converts to mono
- Computes Short-Time Fourier Transform (STFT) to get magnitude and phase
- Extracts frequency bins for masking

### 2. **Frequency-Informed Separation**
- Creates frequency masks for each source type:
  - **Drums**: 20-8000 Hz (wide range for transients)
  - **Vocals**: 80-8000 Hz (human voice range)
  - **Bass**: 20-250 Hz (low frequencies)
  - **Other**: 200-16000 Hz (mid-high frequencies)
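These bands are coarse at low frequencies. Assume the native 44.1 kHz rate of MUSDB18-HQ and `n_fft=2048`. Bins are then 44100 / 2048 ≈ 21.5 Hz apart, so the bass band contains only 11 bins, while a semitone around 100 Hz is only about 6 Hz wide. The sketch below counts the bins per band (hypothetical helper `bins_in_band`).

```python
import numpy as np

def bins_in_band(f_lo, f_hi, sr=44100, n_fft=2048):
    """Number of STFT bins whose centre frequency lies in [f_lo, f_hi] Hz."""
    freqs = np.fft.rfftfreq(n_fft, d=1.0 / sr)  # same bin centres as librosa.fft_frequencies
    return int(np.sum((freqs >= f_lo) & (freqs <= f_hi)))

bands = {'drums': (20, 8000), 'vocals': (80, 8000), 'bass': (20, 250), 'other': (200, 16000)}
print(f"bin spacing: {44100 / 2048:.2f} Hz")
for name, (lo, hi) in bands.items():
    print(f"{name}: {bins_in_band(lo, hi)} bins")
```

With `n_components_bass=3`, the bass NMF therefore factorizes a matrix with only 11 rows.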

### 3. **Hierarchical NMF Processing**
- **Initial HPSS**: Separates harmonic and percussive components
- **Drums**: Applies NMF to percussive component in drum frequency range
- **Bass**: Applies NMF to harmonic component in low frequencies
- **Vocals**: Applies NMF to remaining harmonic content (after bass removal)
- **Other**: Applies NMF to residual harmonic content

### 4. **NMF Algorithm** (`nmf_separate` function)
- Extracts frequency band using mask
- Runs NMF with specified number of components
- Reconstructs separated source: **S_reconstructed = W @ H**
- Places result back in full frequency spectrum

### 5. **Post-Processing**
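- **Soft Masking**: Turns each NMF estimate into a ratio mask (estimate² / mixture², clipped to [0, 1]) and applies it to the mixture magnitude to reduce artifacts
- **Audio Reconstruction**: Combines magnitude with original phase via inverse STFT
- **Normalization**: Peak-normalizes each stem to 0.95 to prevent clipping (this also changes each stem's level relative to the mixture)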
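The soft masks in `separate_sources_v3` are computed for each source independently against the mixture, so they need not sum to 1 across sources. A common alternative normalizes across all estimates: each mask is that source's estimate raised to a power p, divided by the sum of all estimates raised to p. With p = 2 this is a Wiener-style ratio mask. By construction, the masked stems then add up to the mixture magnitude. This alternative is sketched here and is not used in the notebook. The helper name `ratio_masks` is hypothetical.

```python
import numpy as np

def ratio_masks(estimates, power=2.0, eps=1e-10):
    """Masks M_j = S_j**p / sum_k S_k**p, computed jointly over all source estimates.

    Returns an array of shape (n_sources, bins, frames) whose masks sum to 1
    in every bin where at least one estimate is non-zero.
    """
    P = np.stack([np.asarray(S, dtype=float) ** power for S in estimates])
    return P / (P.sum(axis=0, keepdims=True) + eps)

rng = np.random.default_rng(0)
S_mix_demo = rng.random((5, 4))                       # toy mixture magnitude
estimates_demo = [rng.random((5, 4)) for _ in range(4)]
masks_demo = ratio_masks(estimates_demo)
separated_demo = masks_demo * S_mix_demo              # broadcasts to (4, 5, 4)
print(np.allclose(separated_demo.sum(axis=0), S_mix_demo))  # True
```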

### Key Advantages
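- **Learned Patterns**: NMF learns spectral templates directly from the mixture (unsupervised) within each frequency band; the templates are not guaranteed to correspond to a single instrument
- **Frequency Awareness**: Focuses each model on appropriate frequency ranges
- **Flexibility**: Adjustable number of components per source
- **Artifact Reduction**: Soft masking keeps each estimate bounded by the mixture magnitude, which limits artifacts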

In [None]:
# %pip install mir-eval
import librosa
import numpy as np
from sklearn.decomposition import NMF
import warnings
try:
    from mir_eval.separation import bss_eval_sources
except ImportError:
    print("Installing mir_eval...")
    %pip install mir_eval
    from mir_eval.separation import bss_eval_sources

def separate_sources_v3(
    mixture_path,
    sr=None,
    n_components_drums=4,      # NMF components for drums
    n_components_vocals=6,     # NMF components for vocals  
    n_components_bass=3,       # NMF components for bass
    n_components_other=5,      # NMF components for other
    n_fft=2048,
    hop_length=512,
    # Frequency ranges for each instrument (Hz)
    drums_freq_range=(20, 8000),    # Drums: wide range with emphasis on transients
    vocals_freq_range=(80, 8000),   # Vocals: human voice range
    bass_freq_range=(20, 250),      # Bass: low frequencies
    other_freq_range=(200, 16000),  # Other: mid-high frequencies
    max_iter=200,
    random_state=42
):
    """
    Separate audio sources using Non-negative Matrix Factorization (NMF).
    Uses frequency-informed NMF to separate drums, vocals, bass, and other components.
    
    Returns a dictionary with separated sources.
    """
    
    # 1) Load audio
    y, sr_loaded = librosa.load(mixture_path, sr=sr)
    if sr is None:
        sr = sr_loaded
        
    if y.ndim > 1:
        y = librosa.to_mono(y)

    # 2) Compute STFT
    D = librosa.stft(y, n_fft=n_fft, hop_length=hop_length)
    S_full = np.abs(D)  # Magnitude spectrogram
    phase = np.angle(D)  # Phase information
    
    # Get frequency bins
    freqs = librosa.fft_frequencies(sr=sr, n_fft=n_fft)
    
    # 3) Create frequency masks for each source type
    def create_freq_mask(freq_range):
        """Create a frequency mask for the given range"""
        mask = (freqs >= freq_range[0]) & (freqs <= freq_range[1])
        return mask
    
    drums_mask = create_freq_mask(drums_freq_range)
    vocals_mask = create_freq_mask(vocals_freq_range)
    bass_mask = create_freq_mask(bass_freq_range)
    other_mask = create_freq_mask(other_freq_range)
    
    # 4) Apply frequency masking and run NMF on each frequency band
    def nmf_separate(S_masked, n_components, mask):
        """Run NMF on a frequency-masked spectrogram"""
        if not np.any(mask):
            return np.zeros_like(S_masked)
            
        # Extract the relevant frequency band
        S_band = S_masked[mask, :]
        
        if S_band.size == 0 or np.all(S_band == 0):
            return np.zeros_like(S_masked)
        
        # Suppress warnings for convergence
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            
            # Run NMF
            model = NMF(
                n_components=n_components, 
                init='random', 
                random_state=random_state,
                max_iter=max_iter,
                alpha_W=0.1,  # Regularization strength for W (L2 penalty, since l1_ratio defaults to 0)
                alpha_H=0.1   # Regularization strength for H (same L2 penalty)
            )
            
            try:
                W = model.fit_transform(S_band)  # Basis spectra
                H = model.components_            # Activations
                
                # Reconstruct the separated component
                S_reconstructed = W @ H
                
                # Create full spectrogram with zeros outside the frequency band
                S_separated = np.zeros_like(S_masked)
                S_separated[mask, :] = S_reconstructed
                
                return S_separated
                
            except Exception as e:
                print(f"NMF failed: {e}")
                return np.zeros_like(S_masked)
    
    # 5) Separate each source using NMF
    print("Separating sources with NMF...")
    
    # Initial HPSS to help guide separation
    S_harmonic, S_percussive = librosa.decompose.hpss(S_full, margin=1.0)
    
    # Drums: focus on percussive component
    S_drums = nmf_separate(S_percussive, n_components_drums, drums_mask)
    
    # Bass: focus on low frequencies in harmonic component  
    S_bass = nmf_separate(S_harmonic, n_components_bass, bass_mask)
    
    # Vocals: focus on mid frequencies in harmonic component after removing bass
    S_harmonic_no_bass = np.maximum(S_harmonic - S_bass, 0.1 * S_harmonic)
    S_vocals = nmf_separate(S_harmonic_no_bass, n_components_vocals, vocals_mask)
    
    # Other: remaining harmonic content
    S_remaining = np.maximum(S_harmonic - S_bass - S_vocals, 0.1 * S_harmonic)
    S_other = nmf_separate(S_remaining, n_components_other, other_mask)
    
    # 6) Post-processing: ensure non-negativity and apply soft masks
    def apply_soft_mask(S_target, S_mixture, power=2, eps=1e-10):
        """Apply soft masking to reduce artifacts"""
        mask = (S_target ** power) / (S_mixture ** power + eps)
        mask = np.clip(mask, 0, 1)
        return mask * S_mixture
    
    # Apply soft masking to reduce artifacts
    S_drums = apply_soft_mask(S_drums, S_full)
    S_vocals = apply_soft_mask(S_vocals, S_full) 
    S_bass = apply_soft_mask(S_bass, S_full)
    S_other = apply_soft_mask(S_other, S_full)
    
    # 7) Reconstruct time-domain signals
    def reconstruct_audio(S_mag, phase_orig):
        """Reconstruct audio from magnitude and phase"""
        S_complex = S_mag * np.exp(1j * phase_orig)
        return librosa.istft(S_complex, hop_length=hop_length, length=len(y))
    
    y_drums = reconstruct_audio(S_drums, phase)
    y_vocals = reconstruct_audio(S_vocals, phase)
    y_bass = reconstruct_audio(S_bass, phase)
    y_other = reconstruct_audio(S_other, phase)
    
    # 8) Normalize to prevent clipping
    def safe_normalize(signal, max_val=0.95):
        """Normalize signal to prevent clipping"""
        if np.max(np.abs(signal)) > 0:
            return max_val * signal / np.max(np.abs(signal))
        return signal
    
    y_drums = safe_normalize(y_drums)
    y_vocals = safe_normalize(y_vocals)  
    y_bass = safe_normalize(y_bass)
    y_other = safe_normalize(y_other)
    
    return {
        'drums': y_drums,
        'bass': y_bass,
        'vocals': y_vocals,
        'other': y_other,
        'sr': sr
    }


In [None]:
# Test the new function
print("Testing NMF-based separation...")
sources_nmf = separate_sources_v3(mixture_path)

drums_nmf = sources_nmf['drums']
bass_nmf = sources_nmf['bass'] 
vocals_nmf = sources_nmf['vocals']
other_nmf = sources_nmf['other']

# Play the separated components
print("NMF Drums Component:")
display(ipd.Audio(drums_nmf, rate=sources_nmf['sr']))
print("NMF Bass Component:")
display(ipd.Audio(bass_nmf, rate=sources_nmf['sr']))
print("NMF Vocals Component:")
display(ipd.Audio(vocals_nmf, rate=sources_nmf['sr']))
print("NMF Other Component:")
display(ipd.Audio(other_nmf, rate=sources_nmf['sr']))

# Create comparison plots for each stem - NMF version
stems_info_nmf = [
    ('drums', drums_nmf, 'Drums'),
    ('bass', bass_nmf, 'Bass'), 
    ('vocals', vocals_nmf, 'Vocals'),
    ('other', other_nmf, 'Other')
]

print("\n=== NMF SPECTROGRAM COMPARISONS ===\n")

for stem_key, separated_audio, stem_display_name in stems_info_nmf:
    if stem_key in original_stems:
        print(f"Comparing {stem_display_name} (NMF)...")
        plot_spectrogram_comparison(
            original_stems[stem_key], 
            separated_audio, 
            stem_display_name.lower(),
            sources_nmf['sr'],
            "Ground Truth",
            "NMF Separated"
        )
    else:
        print(f"Warning: {stem_key} not found in original stems")

# Create a summary comparison plot for NMF
print("\n=== NMF SUMMARY COMPARISON ===")
fig, axes = plt.subplots(2, 4, figsize=(20, 10))

# Row 1: Original stems
# Row 2: NMF separated stems
for i, (stem_key, separated_audio, stem_display_name) in enumerate(stems_info_nmf):
    if stem_key in original_stems:
        # Original stem (top row)
        original_audio = original_stems[stem_key]
        if hasattr(original_audio, 'numpy'):
            original_np = original_audio.numpy()
        else:
            original_np = original_audio
        
        if original_np.ndim == 2:
            original_np = np.mean(original_np, axis=0)
            
        D_orig = librosa.stft(original_np)
        S_orig_db = librosa.amplitude_to_db(np.abs(D_orig), ref=np.max)
        
        img1 = librosa.display.specshow(S_orig_db, sr=sources_nmf['sr'], x_axis='time', y_axis='hz', 
                                        ax=axes[0, i], cmap='viridis')
        axes[0, i].set_title(f'Original {stem_display_name}')
        
        # NMF separated stem (bottom row)
        if separated_audio.ndim == 2:
            separated_np = np.mean(separated_audio, axis=0)
        else:
            separated_np = separated_audio
            
        min_len = min(len(original_np), len(separated_np))
        separated_np = separated_np[:min_len]
        
        D_sep = librosa.stft(separated_np)
        S_sep_db = librosa.amplitude_to_db(np.abs(D_sep), ref=np.max)
        
        img2 = librosa.display.specshow(S_sep_db, sr=sources_nmf['sr'], x_axis='time', y_axis='hz', 
                                        ax=axes[1, i], cmap='viridis')
        axes[1, i].set_title(f'NMF Separated {stem_display_name}')
        
        # Add colorbars
        plt.colorbar(img1, ax=axes[0, i], format='%+2.0f dB')
        plt.colorbar(img2, ax=axes[1, i], format='%+2.0f dB')

plt.suptitle('NMF Spectrogram Comparison: Original vs Separated Stems', fontsize=16)
plt.tight_layout()
plt.show()

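# Compute SDR for each separated source. compute_sdr receives the full mixture as its first
# argument; assuming that argument is the reference, these values measure similarity to the
# mix rather than to the ground-truth stems in original_stems.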
sdr_drums_nmf = compute_sdr(mixture, sources_nmf['drums'], sources_nmf['sr'])
sdr_bass_nmf = compute_sdr(mixture, sources_nmf['bass'], sources_nmf['sr'])
sdr_vocals_nmf = compute_sdr(mixture, sources_nmf['vocals'], sources_nmf['sr'])
sdr_other_nmf = compute_sdr(mixture, sources_nmf['other'], sources_nmf['sr'])

print("\n=== NMF SDR RESULTS ===")
print(f"SDR for Drums: {sdr_drums_nmf:.2f} dB")
print(f"SDR for Bass: {sdr_bass_nmf:.2f} dB")
print(f"SDR for Vocals: {sdr_vocals_nmf:.2f} dB")
print(f"SDR for Other: {sdr_other_nmf:.2f} dB")

# Optional: Create a side-by-side comparison between DSP and NMF methods
print("\n=== DSP vs NMF COMPARISON ===")

# Get DSP results for comparison (requires the DSP separation cell to have been run earlier)
if 'sources' in locals():
    # Create the figure only when there is something to plot
    fig, axes = plt.subplots(3, 4, figsize=(24, 15))
    
    stems_comparison = [
        ('drums', drums, drums_nmf, 'Drums'),
        ('bass', bass, bass_nmf, 'Bass'),
        ('vocals', vocals, vocals_nmf, 'Vocals'),
        ('other', other, other_nmf, 'Other')
    ]
    
    for i, (stem_key, dsp_audio, nmf_audio, stem_display_name) in enumerate(stems_comparison):
        if stem_key in original_stems:
            # Original stem (top row)
            original_audio = original_stems[stem_key]
            if hasattr(original_audio, 'numpy'):
                original_np = original_audio.numpy()
            else:
                original_np = original_audio
            
            if original_np.ndim == 2:
                original_np = np.mean(original_np, axis=0)
                
            D_orig = librosa.stft(original_np)
            S_orig_db = librosa.amplitude_to_db(np.abs(D_orig), ref=np.max)
            
            img1 = librosa.display.specshow(S_orig_db, sr=sources_nmf['sr'], x_axis='time', y_axis='hz', 
                                            ax=axes[0, i], cmap='viridis')
            axes[0, i].set_title(f'Original {stem_display_name}')
            
            # Downmix both estimates to mono before measuring lengths:
            # len() of a (channels, samples) array returns the channel count, not the sample count
            dsp_np = np.mean(dsp_audio, axis=0) if dsp_audio.ndim == 2 else dsp_audio
            nmf_np = np.mean(nmf_audio, axis=0) if nmf_audio.ndim == 2 else nmf_audio
            
            # Trim both estimates to the shortest common length
            min_len_all = min(len(original_np), len(dsp_np), len(nmf_np))
            dsp_np = dsp_np[:min_len_all]
            nmf_np = nmf_np[:min_len_all]
            
            # DSP separated stem (middle row)
            D_dsp = librosa.stft(dsp_np)
            S_dsp_db = librosa.amplitude_to_db(np.abs(D_dsp), ref=np.max)
            
            img2 = librosa.display.specshow(S_dsp_db, sr=sources_nmf['sr'], x_axis='time', y_axis='hz', 
                                            ax=axes[1, i], cmap='viridis')
            axes[1, i].set_title(f'DSP Separated {stem_display_name}')
            
            # NMF separated stem (bottom row)
            D_nmf = librosa.stft(nmf_np)
            S_nmf_db = librosa.amplitude_to_db(np.abs(D_nmf), ref=np.max)
            
            img3 = librosa.display.specshow(S_nmf_db, sr=sources_nmf['sr'], x_axis='time', y_axis='hz', 
                                            ax=axes[2, i], cmap='viridis')
            axes[2, i].set_title(f'NMF Separated {stem_display_name}')
            
            # Add colorbars
            plt.colorbar(img1, ax=axes[0, i], format='%+2.0f dB')
            plt.colorbar(img2, ax=axes[1, i], format='%+2.0f dB')
            plt.colorbar(img3, ax=axes[2, i], format='%+2.0f dB')

    plt.suptitle('Method Comparison: Original vs DSP vs NMF Separation', fontsize=16)
    plt.tight_layout()
    plt.show()
    
    # Compare SDR values
    print("\n=== SDR COMPARISON: DSP vs NMF ===")
    print(f"{'Stem':<8} {'DSP SDR (dB)':<12} {'NMF SDR (dB)':<12} {'Difference':<12}")
    print("-" * 50)
    
    dsp_sdrs = [sdr_drums, sdr_bass, sdr_vocals, sdr_other]
    nmf_sdrs = [sdr_drums_nmf, sdr_bass_nmf, sdr_vocals_nmf, sdr_other_nmf]
    stem_names = ['Drums', 'Bass', 'Vocals', 'Other']
    
    for stem, dsp_sdr, nmf_sdr in zip(stem_names, dsp_sdrs, nmf_sdrs):
        diff = nmf_sdr - dsp_sdr
        print(f"{stem:<8} {dsp_sdr:>8.2f}    {nmf_sdr:>8.2f}    {diff:>+8.2f}")
else:
    print("DSP separation results not found. Run DSP separation first for comparison.")
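`compute_sdr` is defined earlier in the notebook and is not shown in this section, so as a reference point for reading the numbers above, here is a minimal sketch of the plain energy-ratio SDR, SDR = 10·log10(‖s‖² / ‖s − ŝ‖²). It assumes `compute_sdr` follows this definition (it may instead wrap BSS Eval or torchmetrics), and `plain_sdr` is a hypothetical name. The sketch downmixes and trims the same way the evaluation cells do, and the usage lines show why the choice of reference matters: scoring a stem against the full mixture caps the result even when the estimate is the exact ground-truth stem.

```python
import numpy as np


def plain_sdr(reference, estimate, eps=1e-12):
    """Plain SDR in dB: 10*log10(||s||^2 / ||s - s_hat||^2).

    Hypothetical helper. Unlike BSS Eval, no distortion filter is fitted,
    so any gain change or filtering of the estimate counts as error.
    """
    reference = np.asarray(reference, dtype=np.float64)
    estimate = np.asarray(estimate, dtype=np.float64)

    # Downmix (channels, samples) arrays to mono
    if reference.ndim == 2:
        reference = reference.mean(axis=0)
    if estimate.ndim == 2:
        estimate = estimate.mean(axis=0)

    # Trim both signals to their common length
    n = min(len(reference), len(estimate))
    reference, estimate = reference[:n], estimate[:n]

    signal_energy = np.sum(reference ** 2)
    error_energy = np.sum((reference - estimate) ** 2)
    # eps keeps the ratio finite for a perfect estimate or a silent reference
    return 10.0 * np.log10((signal_energy + eps) / (error_energy + eps))


# Toy signals: one second of a 440 Hz "stem" and a quieter 220 Hz "other" part
sr = 8000
t = np.arange(sr) / sr
stem = np.sin(2 * np.pi * 440 * t)
other = 0.5 * np.sin(2 * np.pi * 220 * t)

# A 10 % amplitude error gives an error energy 1/100 of the signal energy
print(round(plain_sdr(stem, 1.1 * stem), 2))  # 20.0

# The exact stem scored against the mixture: error energy is the "other" energy
print(round(plain_sdr(stem + other, stem), 2))  # 6.99
```

With the mixture as reference, the score of a perfect stem estimate depends only on how loud the remaining instruments are, which is why the whole-dataset evaluation compares each estimate with its own ground-truth stem.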

Whole Dataset

In [None]:
# Import the libraries used by the whole-dataset evaluation
import os
import numpy as np
import torch
import librosa
from mir_eval.separation import bss_eval_sources
from torchmetrics import SignalDistortionRatio

In [None]:
import warnings
import matplotlib.pyplot as plt
warnings.filterwarnings('ignore')
import tqdm
import time

# Evaluate separate_sources_v2 on every track: average SDR, SIR, SAR and their standard deviations per stem
stems = ['drums', 'vocals', 'bass', 'other']
average_sdr = {stem: [] for stem in stems}
average_sir = {stem: [] for stem in stems}
average_sar = {stem: [] for stem in stems}

track_folders = list(dataset_dict.keys())
print(f"Processing all stems simultaneously - Total tracks: {len(track_folders)}")

for idx, track_folder in enumerate(tqdm.tqdm(track_folders, desc="Processing tracks")):
    start_time = time.time()
    
    try:
        print(f"\n[{idx+1}/{len(track_folders)}] Processing {track_folder}...")
        
        mixture_path = os.path.join("/Users/alessandromanattini/Desktop/MAE/SELECTED TOPIC/PROJECT STMAE/musdb18hq_trimmed", track_folder, "new_mixture.wav")
        
        # Check if file exists
        if not os.path.exists(mixture_path):
            print(f"Skipping {track_folder}: mixture file not found")
            continue
        
        # Perform source separation once per track
        print(f"Starting source separation for {track_folder}...")
        separation_start = time.time()
        
        try:
            separated_sources = separate_sources_v2(mixture_path)
            separation_time = time.time() - separation_start
            print(f"Source separation completed in {separation_time:.2f}s")
        except Exception as e:
            print(f"ERROR in source separation for {track_folder}: {e}")
            continue
        
        # Prepare all reference and estimated sources
        ref_sources = []
        est_sources = []
        available_stems = []
        
        for stem in stems:
            if stem not in dataset_dict[track_folder]:
                print(f"Warning: {stem} not found in ground truth for {track_folder}")
                continue
            
            if stem not in separated_sources:
                print(f"Warning: {stem} not found in separated sources for {track_folder}")
                continue
                
            ref_stem = dataset_dict[track_folder][stem]
            est_stem = separated_sources[stem]
            
            # Convert to float32 tensors; torch.as_tensor accepts both tensors and NumPy
            # arrays without the copy warning torch.tensor emits for tensor inputs
            ref_tensor = torch.as_tensor(ref_stem, dtype=torch.float32)
            est_tensor = torch.as_tensor(est_stem, dtype=torch.float32)
            
            # Downmix (channels, samples) to mono
            if ref_tensor.dim() == 2:
                ref_tensor = torch.mean(ref_tensor, dim=0)
            if est_tensor.dim() == 2:
                est_tensor = torch.mean(est_tensor, dim=0)
            
            # Ensure same length
            min_len = min(len(ref_tensor), len(est_tensor))
            ref_tensor = ref_tensor[:min_len]
            est_tensor = est_tensor[:min_len]
            
            # Convert to numpy
            ref_np = ref_tensor.numpy().astype(np.float64)
            est_np = est_tensor.numpy().astype(np.float64)
            
            # Skip near-silent stems: mir_eval raises a ValueError on all-zero sources
            energy_threshold = 1e-6
            ref_energy = np.mean(ref_np**2)
            est_energy = np.mean(est_np**2)
            
            if ref_energy < energy_threshold or est_energy < energy_threshold:
                print(f"Skipping {stem} in {track_folder}: insufficient energy")
                continue
            
            ref_sources.append(ref_np)
            est_sources.append(est_np)
            available_stems.append(stem)
        
        # Only proceed if we have multiple sources (needed for SIR calculation)
        if len(ref_sources) < 2:
            print(f"Skipping {track_folder}: need at least 2 sources for proper BSS evaluation, got {len(ref_sources)}")
            continue
        
        # Convert to numpy arrays for mir_eval
        ref_sources = np.array(ref_sources)
        est_sources = np.array(est_sources)
        
        print(f"Evaluating BSS metrics for {len(available_stems)} sources...")
        print(f"Reference shape: {ref_sources.shape}, Estimated shape: {est_sources.shape}")
        
        try:
            # Compute BSS metrics for all sources simultaneously.
            # The estimates are already labelled by stem, so keep the fixed
            # assignment (estimate i <-> reference i) instead of letting
            # mir_eval reorder them by best mean SIR.
            sdr, sir, sar, perm = bss_eval_sources(
                ref_sources, 
                est_sources, 
                compute_permutation=False
            )
            
            print(f"BSS evaluation successful, permutation: {perm}")
            
            # Store results for each available stem; index i matches available_stems[i]
            for i, stem in enumerate(available_stems):
                if np.isfinite(sdr[i]):
                    average_sdr[stem].append(float(sdr[i]))
                if np.isfinite(sir[i]):
                    average_sir[stem].append(float(sir[i]))
                if np.isfinite(sar[i]):
                    average_sar[stem].append(float(sar[i]))
                
                print(f"{stem}: SDR={sdr[i]:.3f}, SIR={sir[i]:.3f}, SAR={sar[i]:.3f}")
                
        except Exception as e:
            print(f"ERROR in BSS evaluation for {track_folder}: {e}")
            import traceback
            traceback.print_exc()
            continue
            
    except KeyboardInterrupt:
        print(f"Interrupted by user at {track_folder}")
        break
    except Exception as e:
        print(f"UNEXPECTED ERROR processing {track_folder}: {e}")
        import traceback
        traceback.print_exc()
        continue
    
    elapsed_time = time.time() - start_time
    print(f"Track {track_folder} completed in {elapsed_time:.2f}s")

# Calculate final statistics
final_sdr = {}
final_sir = {}
final_sar = {}
std_sdr = {}
std_sir = {}
std_sar = {}

for stem in stems:
    final_sdr[stem] = np.mean(average_sdr[stem]) if average_sdr[stem] else np.nan
    std_sdr[stem] = np.std(average_sdr[stem]) if len(average_sdr[stem]) > 1 else 0
    
    final_sir[stem] = np.mean(average_sir[stem]) if average_sir[stem] else np.nan
    std_sir[stem] = np.std(average_sir[stem]) if len(average_sir[stem]) > 1 else 0
    
    final_sar[stem] = np.mean(average_sar[stem]) if average_sar[stem] else np.nan
    std_sar[stem] = np.std(average_sar[stem]) if len(average_sar[stem]) > 1 else 0
    
    print(f"\n{stem.upper()}: {len(average_sdr[stem])} valid measurements")
    print(f"  SDR: {final_sdr[stem]:.3f} ± {std_sdr[stem]:.3f}")
    print(f"  SIR: {final_sir[stem]:.3f} ± {std_sir[stem]:.3f}")
    print(f"  SAR: {final_sar[stem]:.3f} ± {std_sar[stem]:.3f}")

# Verify metrics are different
print("\nSanity check - Are metrics different?")
for stem in stems:
    if not np.isnan(final_sdr[stem]) and not np.isnan(final_sir[stem]) and not np.isnan(final_sar[stem]):
        sdr_sir_diff = abs(final_sdr[stem] - final_sir[stem]) > 0.001
        sdr_sar_diff = abs(final_sdr[stem] - final_sar[stem]) > 0.001
        sir_sar_diff = abs(final_sir[stem] - final_sar[stem]) > 0.001
        print(f"{stem}: SDR≠SIR: {sdr_sir_diff}, SDR≠SAR: {sdr_sar_diff}, SIR≠SAR: {sir_sar_diff}")

print("Evaluation completed! Generating plots...")

# PLOTTING CODE
# Create subplot for all three metrics with error bars
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# Colors for each stem
colors = ['blue', 'green', 'red', 'orange']

# Helper function to handle NaN values in plotting
def safe_plot_bars(ax, keys, values, errors, title, ylabel, color_list):
    # Replace NaN with 0 for plotting
    plot_values = [v if not np.isnan(v) else 0 for v in values]
    plot_errors = [e if not np.isnan(e) else 0 for e in errors]
    
    bars = ax.bar(keys, plot_values, yerr=plot_errors, 
                  capsize=5, color=color_list, alpha=0.7)
    ax.set_xlabel('Stem Category')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    
    # Add value labels
    for bar, val, err in zip(bars, values, errors):
        if not np.isnan(val):
            ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + err + 0.1, 
                   f'{val:.2f}±{err:.2f}', ha='center', va='bottom', fontsize=9)
        else:
            ax.text(bar.get_x() + bar.get_width()/2, 0.1, 
                   'NaN', ha='center', va='bottom', fontsize=9, color='red')

# Plot all metrics
safe_plot_bars(axes[0], list(final_sdr.keys()), list(final_sdr.values()), 
               list(std_sdr.values()), 'Average SDR Performance by Stem Category', 
               'Average SDR (dB)', colors)

safe_plot_bars(axes[1], list(final_sir.keys()), list(final_sir.values()), 
               list(std_sir.values()), 'Average SIR Performance by Stem Category', 
               'Average SIR (dB)', colors)

safe_plot_bars(axes[2], list(final_sar.keys()), list(final_sar.values()), 
               list(std_sar.values()), 'Average SAR Performance by Stem Category', 
               'Average SAR (dB)', colors)

plt.tight_layout()
plt.show()

# Print detailed results with standard deviations
print("Average metrics with standard deviations by stem:")
print("=" * 60)
for stem in stems:
    print(f"{stem.upper()}:")
    sdr_str = f"{final_sdr[stem]:.4f} ± {std_sdr[stem]:.4f}" if not np.isnan(final_sdr[stem]) else "NaN"
    sir_str = f"{final_sir[stem]:.4f} ± {std_sir[stem]:.4f}" if not np.isnan(final_sir[stem]) else "NaN"
    sar_str = f"{final_sar[stem]:.4f} ± {std_sar[stem]:.4f}" if not np.isnan(final_sar[stem]) else "NaN"
    
    print(f"  SDR: {sdr_str} dB")
    print(f"  SIR: {sir_str} dB") 
    print(f"  SAR: {sar_str} dB")
    print()

# Create a comprehensive summary table
print("Summary Table with Standard Deviations:")
print("=" * 80)
print(f"{'Stem':<8} {'SDR (dB)':<18} {'SIR (dB)':<18} {'SAR (dB)':<18}")
print(f"{'':8} {'Mean ± Std':<18} {'Mean ± Std':<18} {'Mean ± Std':<18}")
print("-" * 80)
for stem in stems:
    sdr_cell = f"{final_sdr[stem]:>6.2f}±{std_sdr[stem]:<6.2f}" if not np.isnan(final_sdr[stem]) else "    NaN±  NaN"
    sir_cell = f"{final_sir[stem]:>6.2f}±{std_sir[stem]:<6.2f}" if not np.isnan(final_sir[stem]) else "    NaN±  NaN"
    sar_cell = f"{final_sar[stem]:>6.2f}±{std_sar[stem]:<6.2f}" if not np.isnan(final_sar[stem]) else "    NaN±  NaN"
    
    print(f"{stem:<8} {sdr_cell:<18} {sir_cell:<18} {sar_cell:<18}")
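Both whole-dataset cells repeat the same per-stem preparation: downmix to mono, trim to a common length, and drop near-silent signals before calling `bss_eval_sources`. A minimal sketch of that step as one shared function follows; the name `prepare_stem_pair` is hypothetical, and this version stays in NumPy throughout, relying on `np.asarray` accepting CPU torch tensors such as those returned by `torchaudio.load`.

```python
import numpy as np


def prepare_stem_pair(ref_stem, est_stem, energy_threshold=1e-6):
    """Return (reference, estimate) as equal-length mono float64 arrays, or None.

    Hypothetical helper mirroring the per-stem steps of the evaluation loop.
    np.asarray also accepts CPU torch tensors (e.g. from torchaudio.load).
    """
    ref = np.asarray(ref_stem, dtype=np.float64)
    est = np.asarray(est_stem, dtype=np.float64)

    # Downmix (channels, samples) to mono
    if ref.ndim == 2:
        ref = ref.mean(axis=0)
    if est.ndim == 2:
        est = est.mean(axis=0)

    # Trim to the shorter of the two signals
    n = min(len(ref), len(est))
    ref, est = ref[:n], est[:n]

    # mir_eval rejects all-zero sources, so near-silent pairs are dropped
    if np.mean(ref ** 2) < energy_threshold or np.mean(est ** 2) < energy_threshold:
        return None
    return ref, est


# Usage inside the stem loop (sketch):
#     pair = prepare_stem_pair(dataset_dict[track_folder][stem], separated_sources[stem])
#     if pair is None:
#         continue
#     ref_sources.append(pair[0]); est_sources.append(pair[1]); available_stems.append(stem)
```

Factoring the step out this way would let the v2 and v3 cells differ only in the separation function they call.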

In [None]:
import warnings
import matplotlib.pyplot as plt
warnings.filterwarnings('ignore')
import tqdm
import time

# Evaluate separate_sources_v3 on every track with the same metrics as the v2 cell
# (this overwrites the v2 results stored in average_sdr, final_sdr, etc.)
stems = ['drums', 'vocals', 'bass', 'other']
average_sdr = {stem: [] for stem in stems}
average_sir = {stem: [] for stem in stems}
average_sar = {stem: [] for stem in stems}

track_folders = list(dataset_dict.keys())
print(f"Processing all stems simultaneously - Total tracks: {len(track_folders)}")

for idx, track_folder in enumerate(tqdm.tqdm(track_folders, desc="Processing tracks")):
    start_time = time.time()
    
    try:
        print(f"\n[{idx+1}/{len(track_folders)}] Processing {track_folder}...")
        
        mixture_path = os.path.join("/Users/alessandromanattini/Desktop/MAE/SELECTED TOPIC/PROJECT STMAE/musdb18hq_trimmed", track_folder, "new_mixture.wav")
        
        # Check if file exists
        if not os.path.exists(mixture_path):
            print(f"Skipping {track_folder}: mixture file not found")
            continue
        
        # Perform source separation once per track
        print(f"Starting source separation for {track_folder}...")
        separation_start = time.time()
        
        try:
            separated_sources = separate_sources_v3(mixture_path)
            separation_time = time.time() - separation_start
            print(f"Source separation completed in {separation_time:.2f}s")
        except Exception as e:
            print(f"ERROR in source separation for {track_folder}: {e}")
            continue
        
        # Prepare all reference and estimated sources
        ref_sources = []
        est_sources = []
        available_stems = []
        
        for stem in stems:
            if stem not in dataset_dict[track_folder]:
                print(f"Warning: {stem} not found in ground truth for {track_folder}")
                continue
            
            if stem not in separated_sources:
                print(f"Warning: {stem} not found in separated sources for {track_folder}")
                continue
                
            ref_stem = dataset_dict[track_folder][stem]
            est_stem = separated_sources[stem]
            
            # Convert to float32 tensors; torch.as_tensor accepts both tensors and NumPy
            # arrays without the copy warning torch.tensor emits for tensor inputs
            ref_tensor = torch.as_tensor(ref_stem, dtype=torch.float32)
            est_tensor = torch.as_tensor(est_stem, dtype=torch.float32)
            
            # Downmix (channels, samples) to mono
            if ref_tensor.dim() == 2:
                ref_tensor = torch.mean(ref_tensor, dim=0)
            if est_tensor.dim() == 2:
                est_tensor = torch.mean(est_tensor, dim=0)
            
            # Ensure same length
            min_len = min(len(ref_tensor), len(est_tensor))
            ref_tensor = ref_tensor[:min_len]
            est_tensor = est_tensor[:min_len]
            
            # Convert to numpy
            ref_np = ref_tensor.numpy().astype(np.float64)
            est_np = est_tensor.numpy().astype(np.float64)
            
            # Skip near-silent stems: mir_eval raises a ValueError on all-zero sources
            energy_threshold = 1e-6
            ref_energy = np.mean(ref_np**2)
            est_energy = np.mean(est_np**2)
            
            if ref_energy < energy_threshold or est_energy < energy_threshold:
                print(f"Skipping {stem} in {track_folder}: insufficient energy")
                continue
            
            ref_sources.append(ref_np)
            est_sources.append(est_np)
            available_stems.append(stem)
        
        # Only proceed if we have multiple sources (needed for SIR calculation)
        if len(ref_sources) < 2:
            print(f"Skipping {track_folder}: need at least 2 sources for proper BSS evaluation, got {len(ref_sources)}")
            continue
        
        # Convert to numpy arrays for mir_eval
        ref_sources = np.array(ref_sources)
        est_sources = np.array(est_sources)
        
        print(f"Evaluating BSS metrics for {len(available_stems)} sources...")
        print(f"Reference shape: {ref_sources.shape}, Estimated shape: {est_sources.shape}")
        
        try:
            # Compute BSS metrics for all sources simultaneously.
            # The estimates are already labelled by stem, so keep the fixed
            # assignment (estimate i <-> reference i) instead of letting
            # mir_eval reorder them by best mean SIR.
            sdr, sir, sar, perm = bss_eval_sources(
                ref_sources, 
                est_sources, 
                compute_permutation=False
            )
            
            print(f"BSS evaluation successful, permutation: {perm}")
            
            # Store results for each available stem; index i matches available_stems[i]
            for i, stem in enumerate(available_stems):
                if np.isfinite(sdr[i]):
                    average_sdr[stem].append(float(sdr[i]))
                if np.isfinite(sir[i]):
                    average_sir[stem].append(float(sir[i]))
                if np.isfinite(sar[i]):
                    average_sar[stem].append(float(sar[i]))
                
                print(f"{stem}: SDR={sdr[i]:.3f}, SIR={sir[i]:.3f}, SAR={sar[i]:.3f}")
                
        except Exception as e:
            print(f"ERROR in BSS evaluation for {track_folder}: {e}")
            import traceback
            traceback.print_exc()
            continue
            
    except KeyboardInterrupt:
        print(f"Interrupted by user at {track_folder}")
        break
    except Exception as e:
        print(f"UNEXPECTED ERROR processing {track_folder}: {e}")
        import traceback
        traceback.print_exc()
        continue
    
    elapsed_time = time.time() - start_time
    print(f"Track {track_folder} completed in {elapsed_time:.2f}s")

# Calculate final statistics
final_sdr = {}
final_sir = {}
final_sar = {}
std_sdr = {}
std_sir = {}
std_sar = {}

for stem in stems:
    final_sdr[stem] = np.mean(average_sdr[stem]) if average_sdr[stem] else np.nan
    std_sdr[stem] = np.std(average_sdr[stem]) if len(average_sdr[stem]) > 1 else 0
    
    final_sir[stem] = np.mean(average_sir[stem]) if average_sir[stem] else np.nan
    std_sir[stem] = np.std(average_sir[stem]) if len(average_sir[stem]) > 1 else 0
    
    final_sar[stem] = np.mean(average_sar[stem]) if average_sar[stem] else np.nan
    std_sar[stem] = np.std(average_sar[stem]) if len(average_sar[stem]) > 1 else 0
    
    print(f"\n{stem.upper()}: {len(average_sdr[stem])} valid measurements")
    print(f"  SDR: {final_sdr[stem]:.3f} ± {std_sdr[stem]:.3f}")
    print(f"  SIR: {final_sir[stem]:.3f} ± {std_sir[stem]:.3f}")
    print(f"  SAR: {final_sar[stem]:.3f} ± {std_sar[stem]:.3f}")

# Verify metrics are different
print("\nSanity check - Are metrics different?")
for stem in stems:
    if not np.isnan(final_sdr[stem]) and not np.isnan(final_sir[stem]) and not np.isnan(final_sar[stem]):
        sdr_sir_diff = abs(final_sdr[stem] - final_sir[stem]) > 0.001
        sdr_sar_diff = abs(final_sdr[stem] - final_sar[stem]) > 0.001
        sir_sar_diff = abs(final_sir[stem] - final_sar[stem]) > 0.001
        print(f"{stem}: SDR≠SIR: {sdr_sir_diff}, SDR≠SAR: {sdr_sar_diff}, SIR≠SAR: {sir_sar_diff}")

print("Evaluation completed! Generating plots...")

# PLOTTING CODE
# Create subplot for all three metrics with error bars
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# Colors for each stem
colors = ['blue', 'green', 'red', 'orange']

# Helper function to handle NaN values in plotting
def safe_plot_bars(ax, keys, values, errors, title, ylabel, color_list):
    # Replace NaN with 0 for plotting
    plot_values = [v if not np.isnan(v) else 0 for v in values]
    plot_errors = [e if not np.isnan(e) else 0 for e in errors]
    
    bars = ax.bar(keys, plot_values, yerr=plot_errors, 
                  capsize=5, color=color_list, alpha=0.7)
    ax.set_xlabel('Stem Category')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    
    # Add value labels
    for bar, val, err in zip(bars, values, errors):
        if not np.isnan(val):
            ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + err + 0.1, 
                   f'{val:.2f}±{err:.2f}', ha='center', va='bottom', fontsize=9)
        else:
            ax.text(bar.get_x() + bar.get_width()/2, 0.1, 
                   'NaN', ha='center', va='bottom', fontsize=9, color='red')

# Plot all metrics
safe_plot_bars(axes[0], list(final_sdr.keys()), list(final_sdr.values()), 
               list(std_sdr.values()), 'Average SDR Performance by Stem Category', 
               'Average SDR (dB)', colors)

safe_plot_bars(axes[1], list(final_sir.keys()), list(final_sir.values()), 
               list(std_sir.values()), 'Average SIR Performance by Stem Category', 
               'Average SIR (dB)', colors)

safe_plot_bars(axes[2], list(final_sar.keys()), list(final_sar.values()), 
               list(std_sar.values()), 'Average SAR Performance by Stem Category', 
               'Average SAR (dB)', colors)

plt.tight_layout()
plt.show()

# Print detailed results with standard deviations
print("Average metrics with standard deviations by stem:")
print("=" * 60)
for stem in stems:
    print(f"{stem.upper()}:")
    sdr_str = f"{final_sdr[stem]:.4f} ± {std_sdr[stem]:.4f}" if not np.isnan(final_sdr[stem]) else "NaN"
    sir_str = f"{final_sir[stem]:.4f} ± {std_sir[stem]:.4f}" if not np.isnan(final_sir[stem]) else "NaN"
    sar_str = f"{final_sar[stem]:.4f} ± {std_sar[stem]:.4f}" if not np.isnan(final_sar[stem]) else "NaN"
    
    print(f"  SDR: {sdr_str} dB")
    print(f"  SIR: {sir_str} dB") 
    print(f"  SAR: {sar_str} dB")
    print()

# Create a comprehensive summary table
print("Summary Table with Standard Deviations:")
print("=" * 80)
print(f"{'Stem':<8} {'SDR (dB)':<18} {'SIR (dB)':<18} {'SAR (dB)':<18}")
print(f"{'':8} {'Mean ± Std':<18} {'Mean ± Std':<18} {'Mean ± Std':<18}")
print("-" * 80)
for stem in stems:
    sdr_cell = f"{final_sdr[stem]:>6.2f}±{std_sdr[stem]:<6.2f}" if not np.isnan(final_sdr[stem]) else "    NaN±  NaN"
    sir_cell = f"{final_sir[stem]:>6.2f}±{std_sir[stem]:<6.2f}" if not np.isnan(final_sir[stem]) else "    NaN±  NaN"
    sar_cell = f"{final_sar[stem]:>6.2f}±{std_sar[stem]:<6.2f}" if not np.isnan(final_sar[stem]) else "    NaN±  NaN"
    
    print(f"{stem:<8} {sdr_cell:<18} {sir_cell:<18} {sar_cell:<18}")
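The statistics blocks in both evaluation cells report `np.std` with its default `ddof=0` (population standard deviation) and substitute 0 when fewer than two values exist. A minimal sketch of a single summary helper that makes the estimator explicit and reports the number of valid measurements alongside it; `summarize_metric` is a hypothetical name, and `ddof=1` (sample standard deviation) is shown only as an option, not as what the cells above use. The values in the usage lines are toy numbers, not results.

```python
import numpy as np


def summarize_metric(values, ddof=0):
    """Return (mean, std, count) over the finite entries of `values`.

    Hypothetical helper. ddof=0 matches np.std's default used in the
    evaluation cells; ddof=1 gives the sample standard deviation instead.
    """
    arr = np.asarray(values, dtype=np.float64)
    arr = arr[np.isfinite(arr)]  # drop NaN/inf entries
    n = arr.size
    if n == 0:
        return np.nan, np.nan, 0
    # Too few values for the chosen ddof: report 0 spread, as the cells do for a single value
    std = float(np.std(arr, ddof=ddof)) if n > ddof else 0.0
    return float(np.mean(arr)), std, n


# Toy values: the NaN entry is ignored
print(summarize_metric([4.0, 6.0, float("nan")]))          # (5.0, 1.0, 2)
print(round(summarize_metric([4.0, 6.0], ddof=1)[1], 3))   # 1.414
```

Reporting the count next to mean ± std makes it visible when a stem's average rests on fewer tracks than the others because of skipped or non-finite measurements.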