In [1]:
import numpy as np
import nussl
import os
import json
import pandas as pd
from IPython.display import Audio, display
from typing import Tuple
import librosa

  from pkg_resources import resource_filename


In [2]:
#musdb = nussl.datasets.MUSDB18(download=True)  # Download if you don't have it

In [3]:
class HPSS_evaluation_on_MUSDB18():
    def __init__(self, HPSS_function, estimates_dir="./hpss_estimates", output_dir="./hpss_scores"):
        self.HPSS_function = HPSS_function
        self.estimates_dir = estimates_dir
        self.output_dir = output_dir
        os.makedirs(estimates_dir, exist_ok=True)
        os.makedirs(output_dir, exist_ok=True)

    def _ensure_1d(self, x):
        """Return a 1D numpy array (flatten)."""
        x = np.asarray(x)
        if x.ndim == 2:
            # if shaped (channels, samples) prefer mono mix (sum/mean) — match earlier code which uses channel 0
            # We'll take the first row to preserve the old behavior.
            return x[0, :].flatten()
        return x.flatten()

    def _project_contributions(self, estimate, references):
        """
        Project estimate onto the linear span of references using least squares.

        estimate: (T,) vector
        references: list of (T,) vectors (length K)

        Returns:
            contributions: list of (T,) vectors where contributions[j] = ref_j * coeff_j
            residual: estimate - sum(contributions)
        """
        T = estimate.shape[0]
        K = len(references)
        # Build A with shape (T, K)
        A = np.stack([self._ensure_1d(r) for r in references], axis=1)  # (T, K)
        # Solve least squares A * coeffs = estimate
        # Using numpy.linalg.lstsq (fast for small K)
        coeffs, *_ = np.linalg.lstsq(A, estimate, rcond=None)
        # contributions per reference
        contributions = [A[:, j] * coeffs[j] for j in range(K)]
        reconstruction = sum(contributions)
        residual = estimate - reconstruction
        return contributions, residual

    def extract_components(self, estimate, target_ref, interfering_refs, noise_ref=None):
        """
        Decompose `estimate` into s_target, e_interf, e_noise using least squares.

        - target_ref: (T,) array for the target ground truth (e.g., harmonic)
        - interfering_refs: list of (T,) arrays (e.g., other source(s): percussive, bass, etc.)
        - noise_ref: optional (T,) array representing noise/reference for 'noise' modeling.
                     If None, noise = estimate - projection_onto_all_refs

        Returns:
            s_target, e_interf, e_noise  (all shape (T,))
        """
        # Build combined reference list for "all" sources (target + interfering)
        all_refs = [target_ref] + list(interfering_refs)
        contributions, residual = self._project_contributions(estimate, all_refs)
        # contributions[0] is contribution from target_ref
        s_target = contributions[0]
        # interference contribution = sum of contributions of all interfering refs
        e_interf = sum(contributions[1:]) if len(contributions) > 1 else np.zeros_like(estimate)
        # e_noise: if user provided a noise_ref, we can project that as well; otherwise residual is the noise
        if noise_ref is None:
            e_noise = residual
        else:
            # If provided, project residual onto the noise_ref to split deterministic 'noise' component
            noise_contribs, rest = self._project_contributions(residual, [noise_ref])
            e_noise = noise_contribs[0] if len(noise_contribs) > 0 else rest
            # any leftover is remaining residual (rare)
            # but we treat 'e_noise' as the projected noise; small leftover remains ignored or can be added to e_noise
            e_art =  rest  # capture everything in 'noise' bucket
        return s_target, e_interf, e_noise, e_art

    def evaluate_metrics(self, estimate, target_ref, interfering_refs, noise_ref=None, eps=1e-10):
        """
        Compute SDR, SIR, SNR (in dB) using component energies.

        Formulas used:
            SDR = 10 log10( ||s_target||^2 / ||e_interf + e_noise||^2 )
            SIR = 10 log10( ||s_target||^2 / ||e_interf||^2 )
            SNR = 10 log10( ||s_target + e_interf||^2 / ||e_noise||^2 )

        All inputs are 1D numpy arrays (T,).
        """
        # Ensure 1D numpy arrays
        est = self._ensure_1d(estimate)
        target_ref = self._ensure_1d(target_ref)
        interfering_refs = [self._ensure_1d(r) for r in interfering_refs]
        if noise_ref is not None:
            noise_ref = self._ensure_1d(noise_ref)

        s_target, e_interf, e_noise, e_art = self.extract_components(est, target_ref, interfering_refs, noise_ref=noise_ref)

        # energies
        s_target_e = np.linalg.norm(s_target) ** 2
        e_interf_e = np.linalg.norm(e_interf) ** 2
        e_noise_e = np.linalg.norm(e_noise) ** 2
        e_art_e = np.linalg.norm(e_art) ** 2
        
        print("Energies - s_target: {}, e_interf: {}, e_noise: {}, e_art: {}".format(s_target_e, e_interf_e, e_noise_e, e_art_e))

        denom_sdr = e_interf_e + e_noise_e
        SDR = 10.0 * np.log10((s_target_e + eps) / (denom_sdr + eps))
        SIR = 10.0 * np.log10((s_target_e + eps) / (e_interf_e + eps))
        SNR = 10.0 * np.log10(((s_target_e + e_interf_e) + eps) / (e_noise_e + eps))
        SAR = 10.0 * np.log10((s_target_e + e_interf_e + e_noise_e + eps) / (e_art_e + eps))
        return SDR, SIR, SNR, SAR

    def evaluate_algorithm(self, max_items=1):
        """
        Run over MUSDB18 (train+test), call HPSS_function to obtain harmonic/percussive estimates,
        and compute metrics per track. Prints results.

        max_items: optional int to limit number of tracks (useful for quick tests).
        """
        musdb = nussl.datasets.MUSDB18(subsets=['train', 'test'])
        # Iterate (nussl dataset is list-like)
        count = 0
        for idx in range(len(musdb)):
            if max_items is not None and count >= max_items:
                break
            item = musdb[idx]
            item['mix'].embed_audio()
            # Convert sources to mono arrays consistent with your previous code
            Percussive_ground_truth = self._ensure_1d(item['sources']['drums'].to_mono().audio_data) + self._ensure_1d(item['sources']['bass'].to_mono().audio_data)
            # The original combined harmonic group was vocals + other + bass
            Harmonic_ground_truth = (self._ensure_1d(item['sources']['vocals'].to_mono().audio_data)
                                     + self._ensure_1d(item['sources']['other'].to_mono().audio_data))
            mixture_signal = self._ensure_1d(item['mix'].to_mono().audio_data)
            
            print("Harmonic and Percussive Ground Truth Audio:")
            display(Audio(data=Harmonic_ground_truth, rate=item['mix'].sample_rate))
            display(Audio(data=Percussive_ground_truth, rate=item['mix'].sample_rate))

            # Call user-provided HPSS function. Keep API same as before (returns nussl AudioSignal objects)
            Harmonic_estimate, Percussive_estimate = self.HPSS_function(mixture_signal)
                                                                        # window_length=1024,
                                                                        # hop_length=512,
                                                                        # window_type='hamming')
            print("Harmonic and Percussive Estimated Audio:")
            display(Audio(data=Harmonic_estimate, rate=item['mix'].sample_rate))
            display(Audio(data=Percussive_estimate, rate=item['mix'].sample_rate))
            # Noise reference: mixture - (harmonic_gt + percussive_gt) — same as original code
            noise_ref = mixture_signal - (Harmonic_ground_truth + Percussive_ground_truth)

            # Evaluate harmonic estimate
            H_metrics = self.evaluate_metrics(Harmonic_estimate,
                                              Harmonic_ground_truth,
                                              interfering_refs=[Percussive_ground_truth],
                                              noise_ref=noise_ref)
            # Evaluate percussive estimate
            P_metrics = self.evaluate_metrics(Percussive_estimate,
                                              Percussive_ground_truth,
                                              interfering_refs=[Harmonic_ground_truth],
                                              noise_ref=noise_ref)

            print(f"Track Harmonic Estimation Metrics (SDR, SIR, SNR, SAR): {H_metrics}")
            print(f"Track Percussive Estimation Metrics (SDR, SIR, SNR, SAR): {P_metrics}")
            count += 1

In [48]:
#from fast_bss_eval import bss_eval_sources
from mir_eval.separation import bss_eval_sources

def evaluate_with_fast_bss(eval_obj, max_items=5):
    """
    Run over MUSDB and compute SDR/SIR/SAR using bss_eval_sources.
    eval_obj: instance of HPSS_evaluation_on_MUSDB18 (uses its HPSS_function & helpers)
    """
    musdb = nussl.datasets.MUSDB18(subsets=['train', 'test'])
    sdr_list_h, sir_list_h, sar_list_h = [], [], []
    sdr_list_p, sir_list_p, sar_list_p = [], [], []

    count = 0
    for idx in range(len(musdb)):
        if max_items is not None and count >= max_items:
            break
        item = musdb[idx]
        item['mix'].embed_audio()

        # Ground truth (mono)
        # Percussive_ground_truth = eval_obj._ensure_1d(item['sources']['drums'].to_mono().audio_data) + eval_obj._ensure_1d(item['sources']['bass'].to_mono().audio_data)
        # mixture_signal = eval_obj._ensure_1d(item['mix'].to_mono().audio_data)
        # Harmonic_ground_truth = eval_obj._ensure_1d(item['sources']['vocals'].to_mono().audio_data) + eval_obj._ensure_1d(item['sources']['other'].to_mono().audio_data) + eval_obj._ensure_1d(item['sources']['bass'].to_mono().audio_data)

        # Ground truth using librosa HPSS for consistency with estimates
        mixture_signal = eval_obj._ensure_1d(item['mix'].to_mono().audio_data)
        Harmonic_ground_truth, Percussive_ground_truth = librosa.effects.hpss(mixture_signal)

        # Call HPSS function (expects y -> returns harmonic, percussive)
        H_est, P_est = eval_obj.HPSS_function(mixture_signal)

        # Ensure 1D numpy arrays
        H_est = eval_obj._ensure_1d(H_est)
        P_est = eval_obj._ensure_1d(P_est)

        # Align lengths (trim to minimum length across all four signals)
        min_len = min(len(Harmonic_ground_truth), len(Percussive_ground_truth), len(H_est), len(P_est))
        refs = np.vstack([
            Harmonic_ground_truth[:min_len],
            Percussive_ground_truth[:min_len]
        ])
        ests = np.vstack([
            H_est[:min_len],
            P_est[:min_len]
        ])

        # bss_eval_sources expects shape (nsrc, T)
        sdr, sir, sar, perm = bss_eval_sources(refs, ests, compute_permutation=True)

        # After permutation, metrics are ordered to match references (sdr[0] -> best match to ref0)
        # We assume permuted ordering already applied in returned metrics; map metrics to reference sources:
        # ref 0 = Harmonic, ref 1 = Percussive
        sdr_h, sdr_p = sdr[0], sdr[1]
        sir_h, sir_p = sir[0], sir[1]
        sar_h, sar_p = sar[0], sar[1]

        sdr_list_h.append(sdr_h); sir_list_h.append(sir_h); sar_list_h.append(sar_h)
        sdr_list_p.append(sdr_p); sir_list_p.append(sir_p); sar_list_p.append(sar_p)

        print(f"Track {count}: Harmonic SDR/SIR/SAR = ({sdr_h:.3f}, {sir_h:.3f}, {sar_h:.3f}) | Percussive SDR/SIR/SAR = ({sdr_p:.3f}, {sir_p:.003f}, {sar_p:.3f})")
        count += 1

    # Print summary means
    def mean_or_nan(lst):
        return float(np.mean(lst)) if len(lst) > 0 else float('nan')

    print("\nSummary (mean over tracks):")
    print(f"Harmonic SDR: {mean_or_nan(sdr_list_h):.3f}, SIR: {mean_or_nan(sir_list_h):.3f}, SAR: {mean_or_nan(sar_list_h):.3f}")
    print(f"Percussive SDR: {mean_or_nan(sdr_list_p):.3f}, SIR: {mean_or_nan(sir_list_p):.3f}, SAR: {mean_or_nan(sar_list_p):.3f}")

In [49]:
# Helper funcions
def median_filter(arr: np.ndarray, kernel_size: int, axis: int) -> np.ndarray:
    """Apply 1D median filter along a single axis of a 2D array.
    Parameters:
        arr: 2D numpy array (freq x time)
        kernel_size: odd positive int specifying window length
        axis: 0 for frequency axis, 1 for time axis
    Returns:
        filtered array of same shape
    """
    if kernel_size < 1:
        raise ValueError("kernel_size must be >=1")
    if kernel_size % 2 == 0:
        kernel_size += 1  # force odd for symmetric window
    pad = kernel_size // 2
    # Pad reflect to preserve border behavior
    arr_padded = np.pad(arr, ((pad, pad), (pad, pad)), mode='reflect')
    # Prepare output
    out = np.empty_like(arr)
    if axis == 0:
        # iterate frequency bins
        for f in range(arr.shape[0]):
            window = arr_padded[f:f+kernel_size, pad:pad+arr.shape[1]]
            out[f, :] = np.median(window, axis=0)
    elif axis == 1:
        # iterate time frames
        for t in range(arr.shape[1]):
            window = arr_padded[pad:pad+arr.shape[0], t:t+kernel_size]
            out[:, t] = np.median(window, axis=1)
    else:
        raise ValueError("axis must be 0 or 1")
    return out

In [50]:
def hpss_median(
    y: np.ndarray,
    sr: int = 44100,
    n_fft: int = 2048,
    hop_length: int = 1024,
    med_filter_time: int = 17,
    med_filter_freq: int = 17,
    mask_power: float = 2.0,
    use_power: bool = True,
    soft_mask: bool = True
) -> Tuple[np.ndarray, np.ndarray, dict]:
    """Perform HPSS with median filtering.
    Returns harmonic and percussive time-domain signals plus diagnostics.
    """
    # STFT
    X = librosa.stft(y, n_fft=n_fft, hop_length=hop_length)
    mag = np.abs(X)
    rep = mag**2 if use_power else mag

    # Median filtering to get structure estimates
    H_hat = median_filter(rep, med_filter_time, axis=1)  # along time frames
    P_hat = median_filter(rep, med_filter_freq, axis=0)  # along frequency bins

    eps = 1e-10
    if soft_mask:
        H_num = H_hat**mask_power
        P_num = P_hat**mask_power
        denom = H_num + P_num + eps
        M_H = H_num / denom
        M_P = P_num / denom
    else:
        M_H = (H_hat > P_hat).astype(float)
        M_P = 1.0 - M_H

    # Apply masks to original complex STFT
    X_H = M_H * X
    X_P = M_P * X

    # Inverse STFT
    y_h = librosa.istft(X_H, hop_length=hop_length)
    y_p = librosa.istft(X_P, hop_length=hop_length)

    info = {
        'X': X,
        'mag': mag,
        'H_hat': H_hat,
        'P_hat': P_hat,
        'M_H': M_H,
        'M_P': M_P,
        'params': {
            'n_fft': n_fft,
            'hop_length': hop_length,
            'med_filter_time': med_filter_time,
            'med_filter_freq': med_filter_freq,
            'mask_power': mask_power,
            'use_power': use_power,
            'soft_mask': soft_mask
        }
    }
    return y_h, y_p


In [51]:
def Complementary_Diffusion(
    audio_signal,
    window_length,
    hop_length,
    window_type,
    gamma=0.3,
    alpha=0.3,
    num_iters=50,
):
    """
    Complementary Diffusion HPSS using librosa.
    --------------------------------------------------
    audio_signal : np.ndarray, shape (T,) mono audio
    window_length : STFT window size
    hop_length : STFT hop size
    window_type : type of window, e.g. 'hann', 'hamming'
    gamma : power for magnitude deformation
    alpha : diffusion tradeoff
    num_iters : number of diffusion iterations

    Returns:
        harmonic_audio : np.ndarray, time-domain harmonic component
        percussive_audio : np.ndarray, time-domain percussive component
    """

    # ---------- STFT ----------
    stft = librosa.stft(
        audio_signal,
        n_fft=window_length,
        hop_length=hop_length,
        window=window_type,
        center=True,
    )
    magnitude = np.abs(stft)
    phase = np.exp(1j * np.angle(stft))

    # W = |X|^(2*gamma)
    W = magnitude ** (2 * gamma)

    # Initialize H and P
    H = 0.5 * W.copy()
    P = 0.5 * W.copy()

    num_rows, num_cols = W.shape  # (freq, time)

    for _ in range(num_iters):

        # -------- Horizontal diffusion (for H) --------
        # Shift left (time axis)
        H_L = np.concatenate([H[:, 1:], np.zeros((num_rows, 1))], axis=1)
        # Shift right
        H_R = np.concatenate([np.zeros((num_rows, 1)), H[:, :-1]], axis=1)

        # -------- Vertical diffusion (for P) --------
        # Shift up (freq axis)
        P_U = np.concatenate([P[1:, :], np.zeros((1, num_cols))], axis=0)
        # Shift down
        P_D = np.concatenate([np.zeros((1, num_cols)), P[:-1, :]], axis=0)

        # The diffusion update
        delta = alpha * (H_L + H_R - 2 * H) / 4 - (1 - alpha) * (P_U + P_D - 2 * P) / 4

        # Constrain H between 0 and W
        H = np.clip(H + delta, 0, W)

        # Complement P
        P = W - H

    # Final masks
    H_mask = (H >= P).astype(float)
    P_mask = (H < P).astype(float)

    H_mag = W * H_mask
    P_mag = W * P_mask

    # Reconstruct STFTs using original phase
    H_stft = H_mag * phase
    P_stft = P_mag * phase

    # ---------- ISTFT ----------
    harmonic_audio = librosa.istft(H_stft, hop_length=hop_length, window=window_type)
    percussive_audio = librosa.istft(P_stft, hop_length=hop_length, window=window_type)

    return harmonic_audio, percussive_audio


In [52]:
evaluation_class = HPSS_evaluation_on_MUSDB18(HPSS_function=hpss_median)

In [53]:
# Run evaluation (uses the existing evaluation_class)
evaluate_with_fast_bss(evaluation_class, max_items=4)

	Deprecated as of mir_eval version 0.8.
	It will be removed in mir_eval version 0.9.
  sdr, sir, sar, perm = bss_eval_sources(refs, ests, compute_permutation=True)


Track 0: Harmonic SDR/SIR/SAR = (19.488, 29.548, 19.944) | Percussive SDR/SIR/SAR = (6.791, 14.735, 7.695)


Track 1: Harmonic SDR/SIR/SAR = (19.202, 28.068, 19.813) | Percussive SDR/SIR/SAR = (10.113, 22.064, 10.426)


Track 2: Harmonic SDR/SIR/SAR = (13.239, 21.653, 13.944) | Percussive SDR/SIR/SAR = (10.576, 24.406, 10.776)


Track 3: Harmonic SDR/SIR/SAR = (9.628, 15.985, 10.879) | Percussive SDR/SIR/SAR = (11.054, 22.734, 11.383)

Summary (mean over tracks):
Harmonic SDR: 15.389, SIR: 23.813, SAR: 16.145
Percussive SDR: 9.634, SIR: 20.985, SAR: 10.070


In [None]:
evaluation_class.evaluate_algorithm()

Harmonic and Percussive Ground Truth Audio:


Harmonic and Percussive Estimated Audio:


Energies - s_target: 1213.8512884875806, e_interf: 4888.651905160019, e_noise: 2.7456110331622456, e_art: 281.3339112467621
Energies - s_target: 18.394899782529734, e_interf: 140.88011489213378, e_noise: 0.34622483026834855, e_art: 281.3210103665562
Track Harmonic Estimation Metrics (SDR, SIR, SNR, SAR): (-6.052674750317205, -6.050236309157845, 33.468690039820366, 13.36481283234411)
Track Percussive Estimation Metrics (SDR, SIR, SNR, SAR): (-8.85218277206341, -8.841522710493198, 26.627894426105875, -2.4611148960862983)
