# Vocal / Accompaniment Separation on MUSDB18: Open-Unmix vs. HTDemucs vs. X-UMX

In [2]:
# Imports and setup

import os
import json
import glob

import numpy as np
import pandas as pd
import soundfile as sf
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats

# np.float_ = np.float32  # musdb, museval

import musdb
import museval
import torch, torchaudio

from openunmix.predict import separate

from demucs.pretrained import get_model
from demucs.apply import apply_model
from demucs.audio import convert_audio

from asteroid.models import XUMX

In [3]:
# Device configuration: prefer the GPU when PyTorch reports one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(torch.cuda.is_available())
# torch.cuda.get_device_name(0) raises on CPU-only machines, so only query it
# when a GPU device was actually selected.
print(torch.cuda.get_device_name(0) if device.type == "cuda" else "CPU")

True
AMD Radeon RX 9070 XT


In [5]:
# Paths (relative to the notebook's working directory)
path_to_folder = "."
musdb_root = os.path.join(path_to_folder, "musdb18")
# Separated WAV files are written under here, one sub-folder per model.
estimates_base_path = os.path.join(path_to_folder, "temp_estimates")
# museval score files are written under here, one sub-folder per model.
output_base_path = os.path.join(path_to_folder, "temp_output")

# Ensure directories exist
os.makedirs(estimates_base_path, exist_ok=True)
os.makedirs(output_base_path, exist_ok=True)

# List the working directory in pure Python rather than with the IPython-only
# `!ls`, so the cell stays cross-platform and valid when exported as a script.
sorted(os.listdir(path_to_folder))

 estimates  'music_separator copy 2.ipynb'   requirements.txt
 musdb18    'music_separator copy.ipynb'     temp_estimates
 musdb18hq   output			     temp_output


In [6]:
# Load the MUSDB18 dataset rooted at `musdb_root`; iterating `mus` yields tracks.
# NOTE(review): with download=True, musdb may fetch its short preview excerpts
# if `musdb_root` does not already hold the full dataset. Confirm the full
# tracks are present before comparing scores against published numbers.
mus = musdb.DB(root=musdb_root, download=True)

In [None]:
model_name = "openunmix"

estimates_path = os.path.join(estimates_base_path, model_name)
output_path = os.path.join(output_base_path, model_name)

os.makedirs(estimates_path, exist_ok=True)
os.makedirs(output_path, exist_ok=True)

# For every MUSDB track: separate vocals/accompaniment, score with museval,
# then save the estimates as WAV files under estimates_path/<subset>/.
for track in mus:
    print(f"[→] Separating: {track.name}")
    rate = track.rate
    subset = track.subset
    # track.audio is (samples, channels); the separator is fed (channels, samples).
    # The input is placed on `device`, as in the other model cells.
    audio = torch.tensor(track.audio.T).float().to(device)  # shape (2, samples)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        estimates = separate(
            audio=audio,
            rate=rate,
            targets=["vocals"],
            residual=True,  # also returns a "residual" estimate, renamed below
            device=device,
        )
    estimates["accompaniment"] = estimates.pop("residual")

    # Convert each estimate to a (samples, 2) numpy array exactly once; the same
    # arrays are used for evaluation and for writing the WAV files.
    cpu_estimates = {
        key: torch.squeeze(value).detach().cpu().numpy().T
        for key, value in estimates.items()
    }
    scores = museval.eval_mus_track(track, cpu_estimates, output_dir=output_path)
    print(scores)

    # Create subdirectory for the subset if it doesn't exist
    subset_path = os.path.join(estimates_path, subset)
    os.makedirs(subset_path, exist_ok=True)

    for target, audio_np in cpu_estimates.items():
        file_name = f"{track.name} - {target}.wav"
        sf.write(os.path.join(subset_path, file_name), audio_np, rate)

[→] Separating: A Classic Education - NightOwl
vocals          ==> SDR:   3.938  SIR:   5.421  ISR:  10.877  SAR:   6.708  
accompaniment   ==> SDR:  12.263  SIR:  18.199  ISR:  15.721  SAR:  14.834  

[→] Separating: ANiMAL - Clinic A
vocals          ==> SDR:   3.938  SIR:   5.421  ISR:  10.877  SAR:   6.708  
accompaniment   ==> SDR:  12.263  SIR:  18.199  ISR:  15.721  SAR:  14.834  

[→] Separating: ANiMAL - Clinic A
vocals          ==> SDR:   5.252  SIR:   8.576  ISR:  14.217  SAR:   7.013  
accompaniment   ==> SDR:  13.004  SIR:  22.343  ISR:  17.363  SAR:  14.997  

[→] Separating: ANiMAL - Easy Tiger
vocals          ==> SDR:   5.252  SIR:   8.576  ISR:  14.217  SAR:   7.013  
accompaniment   ==> SDR:  13.004  SIR:  22.343  ISR:  17.363  SAR:  14.997  

[→] Separating: ANiMAL - Easy Tiger
vocals          ==> SDR:   6.201  SIR:  12.574  ISR:  11.155  SAR:   6.372  
accompaniment   ==> SDR:  14.227  SIR:  20.120  ISR:  22.557  SAR:  15.535  

[→] Separating: ANiMAL - Rockshow
voca

In [None]:
model_name = "htdemucs"

estimates_path = os.path.join(estimates_base_path, model_name)
output_path = os.path.join(output_base_path, model_name)

os.makedirs(estimates_path, exist_ok=True)
os.makedirs(output_path, exist_ok=True)

demucs_model = get_model("htdemucs")
demucs_model.to(device)

# The stem layout is the same for every track, so resolve it once:
# "vocals" is kept as-is and every other stem is summed into "accompaniment".
source_names = list(demucs_model.sources)
if "vocals" not in source_names:
    raise ValueError(f"{model_name} has no 'vocals' source: {source_names}")
vocals_index = source_names.index("vocals")
accompaniment_indices = [i for i, name in enumerate(source_names) if name != "vocals"]

for track in mus:
    print(f"[→] Separating: {track.name}")
    rate = track.rate
    subset = track.subset

    # demucs expects (batch, channels, samples); track.audio is (samples, channels).
    audio = torch.tensor(track.audio.T).float().unsqueeze(0).to(device)  # (1, 2, samples)

    # Convert audio to the model's expected sample rate / channel count if needed
    audio = convert_audio(
        audio, rate, demucs_model.samplerate, demucs_model.audio_channels
    )

    with torch.no_grad():
        sources = apply_model(demucs_model, audio, device=device)
    sources = sources.squeeze(0)  # (n_sources, channels, samples)

    # Indexing with a list copies the selected stems, so summing them never
    # mutates `sources` in place.
    estimates = {
        "vocals": sources[vocals_index],
        "accompaniment": sources[accompaniment_indices].sum(dim=0),
    }

    # Bring the estimates back to the track's sample rate for evaluation.
    if demucs_model.samplerate != rate:
        estimates = {
            key: torchaudio.functional.resample(value, demucs_model.samplerate, rate)
            for key, value in estimates.items()
        }

    # Convert each estimate to a (samples, 2) numpy array exactly once; the same
    # arrays are used for evaluation and for writing the WAV files.
    cpu_estimates = {
        key: torch.squeeze(value).detach().cpu().numpy().T
        for key, value in estimates.items()
    }

    # Evaluate with museval
    scores = museval.eval_mus_track(track, cpu_estimates, output_dir=output_path)
    print(scores)

    # Create subdirectory for the subset if it doesn't exist
    subset_path = os.path.join(estimates_path, subset)
    os.makedirs(subset_path, exist_ok=True)

    # Save separated audio files
    for target, audio_np in cpu_estimates.items():
        file_name = f"{track.name} - {target}.wav"
        sf.write(os.path.join(subset_path, file_name), audio_np, rate)

In [None]:
# Initialize the X-UMX model
x_umx_model = XUMX.from_pretrained("JorisCos/asteroid-xumx")
x_umx_model.to(device)

model_name = "x-umx-asteroid"

estimates_path = os.path.join(estimates_base_path, model_name)
output_path = os.path.join(output_base_path, model_name)

os.makedirs(estimates_path, exist_ok=True)
os.makedirs(output_path, exist_ok=True)

for track in mus:
    print(f"[→] Separating: {track.name}")
    audio = torch.tensor(track.audio.T).float().unsqueeze(0)  # shape (1, 2, samples)
    audio = audio.to(device)
    rate = track.rate
    subset = track.subset

    # Perform separation using X-UMX from Asteroid
    with torch.no_grad():
        estimates = x_umx_model.separate(audio)

    # Evaluation and saving below need a {target_name: tensor} mapping.
    # NOTE(review): Asteroid's `separate` may return a tensor (or tuple) rather
    # than a dict for this model. Confirm its output format and map it to
    # target names; until then, fail with a clear message.
    if not isinstance(estimates, dict):
        raise TypeError(
            "Expected x_umx_model.separate to return a dict of target -> tensor, "
            f"got {type(estimates).__name__}"
        )

    # Convert each estimate to a (samples, 2) numpy array exactly once; the same
    # arrays are used for evaluation and for writing the WAV files.
    cpu_estimates = {
        key: torch.squeeze(value).detach().cpu().numpy().T
        for key, value in estimates.items()
    }

    # Evaluate with museval
    scores = museval.eval_mus_track(track, cpu_estimates, output_dir=output_path)
    print(scores)

    # Create subdirectory for the subset if it doesn't exist
    subset_path = os.path.join(estimates_path, subset)
    os.makedirs(subset_path, exist_ok=True)

    # Save separated audio files
    for target, audio_np in cpu_estimates.items():
        file_name = f"{track.name} - {target}.wav"
        sf.write(os.path.join(subset_path, file_name), audio_np, rate)

In [None]:
def load_museval_results(base_path, model_names, targets=("vocals", "accompaniment")):
    """
    Load frame-wise museval scores from JSON files for several models.

    For each model, every ``*.json`` file found recursively under
    ``base_path/<model_name>`` is parsed. One record is produced per
    (target, frame) for the requested targets.

    Args:
        base_path: Directory containing one sub-directory per model.
        model_names: Iterable of model sub-directory names to load.
        targets: Target names to keep (default: vocals and accompaniment).

    Returns:
        dict mapping each model name to a list of dicts with keys ``model``,
        ``track``, ``target``, ``sdr``, ``sir``, ``sar`` and ``isr``.
        ``isr`` is NaN when a frame has no ISR value. Files that cannot be
        read or parsed are reported and skipped entirely.
    """
    wanted_targets = set(targets)
    all_results = {}

    for model_name in model_names:
        model_path = os.path.join(base_path, model_name)
        results = []

        # Sorted so records come out in a deterministic order on every filesystem.
        json_files = sorted(
            glob.glob(os.path.join(model_path, "**", "*.json"), recursive=True)
        )

        for json_file in json_files:
            # Track name is the file name without its extension.
            track_name = os.path.splitext(os.path.basename(json_file))[0]
            try:
                with open(json_file, "r") as f:
                    data = json.load(f)

                # Collect into a per-file list first, so a file that fails halfway
                # through parsing contributes no partial rows.
                file_records = []
                for target in data.get("targets", []):
                    if target["name"] not in wanted_targets:
                        continue
                    for frame in target["frames"]:
                        metrics = frame["metrics"]
                        file_records.append(
                            {
                                "model": model_name,
                                "track": track_name,
                                "target": target["name"],
                                "sdr": metrics["SDR"],
                                "sir": metrics["SIR"],
                                "sar": metrics["SAR"],
                                "isr": metrics.get("ISR", np.nan),
                            }
                        )
            except Exception as e:
                # Best-effort loading: report the bad file and keep going.
                print(f"Error loading {json_file}: {e}")
                continue

            results.extend(file_records)

        all_results[model_name] = results

    return all_results


def create_comparison_dataframe(results_dict):
    """
    Flatten per-model record lists into one long-format DataFrame.

    Args:
        results_dict: mapping of model name -> list of record dicts
            (the shape returned by ``load_museval_results``).

    Returns:
        A DataFrame with one row per record. Models are concatenated in the
        dict's iteration order.
    """
    records = [
        record
        for model_records in results_dict.values()
        for record in model_records
    ]
    return pd.DataFrame(records)


def calculate_statistics(df):
    """
    Summarise every metric per (model, target) pair.

    Args:
        df: long-format frame with ``model``, ``target`` and metric columns.

    Returns:
        Nested dict ``{model: {target: {metric: summary}}}``. Each ``summary``
        holds mean, median, std, min, max, q25, q75 and count over the
        non-NaN values. A metric that is absent from ``df``, or entirely NaN
        for a (model, target) pair, is omitted.
    """
    metric_names = ("sdr", "sir", "sar", "isr")
    summary = {}

    for model in df["model"].unique():
        rows_for_model = df[df["model"] == model]
        per_target = summary.setdefault(model, {})

        for target in rows_for_model["target"].unique():
            rows = rows_for_model[rows_for_model["target"] == target]
            per_metric = per_target.setdefault(target, {})

            for metric in metric_names:
                if metric not in rows.columns:
                    continue
                values = rows[metric].dropna()
                if values.empty:
                    continue
                per_metric[metric] = {
                    "mean": values.mean(),
                    "median": values.median(),
                    "std": values.std(),
                    "min": values.min(),
                    "max": values.max(),
                    "q25": values.quantile(0.25),
                    "q75": values.quantile(0.75),
                    "count": len(values),
                }

    return summary


def plot_comparison_boxplots(df, figsize=(16, 12)):
    """
    Draw one box-plot panel per metric (SDR, SIR, SAR, ISR) comparing models.

    Each panel puts targets on the x axis and colours boxes by model. Panels
    for metrics that are missing from ``df`` or contain only NaN are hidden
    rather than drawn empty.

    Args:
        df: long-format frame with ``target``, ``model`` and metric columns.
        figsize: matplotlib figure size.

    Returns:
        The matplotlib Figure.
    """
    metrics = ["sdr", "sir", "sar", "isr"]

    fig, axes = plt.subplots(2, 2, figsize=figsize)

    for ax, metric in zip(axes.flatten(), metrics):
        if metric not in df.columns or df[metric].isna().all():
            ax.set_visible(False)
            continue
        sns.boxplot(data=df, x="target", y=metric, hue="model", ax=ax)
        ax.set_title(f"{metric.upper()} Comparison")
        ax.set_ylabel(f"{metric.upper()} (dB)")
        ax.grid(True, alpha=0.3)
        ax.legend(title="Model")

    fig.tight_layout()
    return fig


def plot_distribution_comparison(df, figsize=(16, 10), bins=20):
    """
    Overlay one histogram per (model, target) for each metric.

    There is one panel per metric (SDR, SIR, SAR, ISR). Within a panel, every
    histogram uses the same bin edges, computed from all finite values of that
    metric, so bar heights are directly comparable between series.

    Args:
        df: long-format frame with ``model``, ``target`` and metric columns.
        figsize: matplotlib figure size.
        bins: number of bins per panel (default 20).

    Returns:
        The matplotlib Figure.
    """
    metrics = ["sdr", "sir", "sar", "isr"]
    targets = df["target"].unique()
    models = df["model"].unique()

    fig, axes = plt.subplots(2, 2, figsize=figsize)

    for ax, metric in zip(axes.flatten(), metrics):
        if metric not in df.columns:
            continue

        # Shared edges per metric; non-finite values would break edge detection.
        values = df[metric].to_numpy(dtype=float)
        finite = values[np.isfinite(values)]
        edges = np.histogram_bin_edges(finite, bins=bins) if finite.size else bins

        has_series = False
        for target in targets:
            for model in models:
                data = df.loc[
                    (df["target"] == target) & (df["model"] == model), metric
                ].dropna()
                if len(data) > 0:
                    ax.hist(data, bins=edges, alpha=0.6, label=f"{model} - {target}")
                    has_series = True

        ax.set_title(f"{metric.upper()} Distribution")
        ax.set_xlabel(f"{metric.upper()} (dB)")
        ax.set_ylabel("Frequency")
        # Only add a legend when something was drawn (avoids an empty-legend warning).
        if has_series:
            ax.legend()
        ax.grid(True, alpha=0.3)

    fig.tight_layout()
    return fig


def plot_per_track_comparison(
    df, metric="sdr", figsize=(16, 8), model_x="openunmix", model_y="htdemucs"
):
    """
    Scatter per-track scores of two models against each other.

    For every (track, model, target), the frame-wise ``metric`` is averaged.
    One panel per target ("vocals", "accompaniment") then plots ``model_x``
    on the x axis against ``model_y`` on the y axis, with one point per track.
    The dashed red line is y = x, so points above it are tracks where
    ``model_y`` scored higher. A panel stays empty when either model has no
    data for that target.

    Args:
        df: long-format frame with ``track``, ``model``, ``target`` and ``metric``.
        metric: metric column to compare (e.g. "sdr").
        figsize: matplotlib figure size.
        model_x: model plotted on the x axis (default "openunmix").
        model_y: model plotted on the y axis (default "htdemucs").

    Returns:
        The matplotlib Figure.
    """
    # Mean over frames; one column per (model, target).
    track_means = (
        df.groupby(["track", "model", "target"])[metric]
        .mean()
        .unstack(["model", "target"])
    )

    # Pretty axis labels for the known models; other names are shown verbatim.
    display_names = {"openunmix": "OpenUnmix", "htdemucs": "HTDemucs"}
    label_x = display_names.get(model_x, model_x)
    label_y = display_names.get(model_y, model_y)

    fig, axes = plt.subplots(1, 2, figsize=figsize)
    panels = [
        ("vocals", "Vocals", {}),
        ("accompaniment", "Accompaniment", {"color": "orange"}),
    ]

    for ax, (target, target_label, scatter_style) in zip(axes, panels):
        col_x, col_y = (model_x, target), (model_y, target)
        if col_x not in track_means.columns or col_y not in track_means.columns:
            continue

        x = track_means[col_x]
        y = track_means[col_y]
        ax.scatter(x, y, alpha=0.7, **scatter_style)

        # y = x reference line spanning this panel's data only.
        lo = min(x.min(), y.min())
        hi = max(x.max(), y.max())
        ax.plot([lo, hi], [lo, hi], "r--", alpha=0.8)

        ax.set_xlabel(f"{label_x} {target_label} {metric.upper()} (dB)")
        ax.set_ylabel(f"{label_y} {target_label} {metric.upper()} (dB)")
        ax.set_title(f"Per-Track {target_label} {metric.upper()} Comparison")
        ax.grid(True, alpha=0.3)

    fig.tight_layout()
    return fig


def statistical_significance_test(df, model_pair=None, alpha=0.05):
    """
    Compare two models per target and metric with a two-sided Mann-Whitney U test.

    Args:
        df: long-format frame with ``model``, ``target`` and metric columns.
        model_pair: optional ``(model1, model2)`` to compare. When omitted, the
            test runs only if ``df`` contains exactly two models (taken in order
            of first appearance); otherwise an empty dict is returned.
        alpha: threshold for the ``significant`` flag (default 0.05).

    Returns:
        ``{target: {metric: result}}``. Each ``result`` holds:
        - the U ``statistic``, the two-sided ``p_value`` and the ``significant`` flag;
        - ``effect_size``: Cohen's d with a pooled (ddof=1) std, or 0 when that
          std is not positive;
        - ``"<model>_mean"`` and ``"<model>_median"`` for both models.

    Note:
        Every non-NaN row for a model/target is pooled and treated as an
        independent sample. When rows are per-frame museval scores, frames of
        the same track are likely correlated, so p-values may be optimistic.
    """
    if model_pair is None:
        models = df["model"].unique()
        if len(models) != 2:
            return {}
        model_pair = tuple(models)
    model1, model2 = model_pair

    metrics = ["sdr", "sir", "sar", "isr"]
    results = {}

    for target in df["target"].unique():
        results[target] = {}
        target_rows = df[df["target"] == target]

        for metric in metrics:
            if metric not in df.columns:
                continue

            data1 = target_rows.loc[target_rows["model"] == model1, metric].dropna()
            data2 = target_rows.loc[target_rows["model"] == model2, metric].dropna()
            if len(data1) == 0 or len(data2) == 0:
                continue

            # Wilcoxon rank-sum test (Mann-Whitney U test)
            statistic, p_value = stats.mannwhitneyu(
                data1, data2, alternative="two-sided"
            )

            # Cohen's d; guard the zero-degrees-of-freedom case explicitly.
            dof = len(data1) + len(data2) - 2
            pooled_std = (
                np.sqrt(
                    ((len(data1) - 1) * data1.var() + (len(data2) - 1) * data2.var())
                    / dof
                )
                if dof > 0
                else 0.0
            )
            effect_size = (
                (data1.mean() - data2.mean()) / pooled_std if pooled_std > 0 else 0
            )

            results[target][metric] = {
                "statistic": statistic,
                "p_value": p_value,
                "significant": p_value < alpha,
                "effect_size": effect_size,
                f"{model1}_mean": data1.mean(),
                f"{model2}_mean": data2.mean(),
                f"{model1}_median": data1.median(),
                f"{model2}_median": data2.median(),
            }

    return results


def _print_summary_statistics(stats_results):
    """Print "mean ± std dB (median, n)" for every model / target / metric."""
    print("\n" + "=" * 80)
    print("SUMMARY STATISTICS")
    print("=" * 80)

    for model, per_target in stats_results.items():
        print(f"\n{model.upper()} Results:")
        print("-" * 40)

        for target, per_metric in per_target.items():
            print(f"\n  {target.capitalize()}:")
            for metric in ["sdr", "sir", "sar", "isr"]:
                if metric in per_metric:
                    data = per_metric[metric]
                    print(
                        f"    {metric.upper()}: {data['mean']:.2f} ± {data['std']:.2f} dB "
                        f"(median: {data['median']:.2f}, n={data['count']})"
                    )


def _significance_stars(p_value):
    """Map a p-value to star notation: '***', '**', '*' or 'ns'."""
    if p_value < 0.001:
        return "***"
    if p_value < 0.01:
        return "**"
    if p_value < 0.05:
        return "*"
    return "ns"


def _print_significance_results(sig_results):
    """Print the per-target, per-metric results of the significance tests."""
    print("\n" + "=" * 80)
    print("STATISTICAL SIGNIFICANCE TESTS")
    print("=" * 80)

    for target, per_metric in sig_results.items():
        print(f"\n{target.capitalize()} Comparison:")
        print("-" * 30)

        for metric, data in per_metric.items():
            # Model names are recovered from the "<model>_mean" keys. Slicing off
            # the suffix (rather than str.replace) keeps names that contain "_mean".
            model1_name, model2_name = [
                key[: -len("_mean")] for key in data if key.endswith("_mean")
            ][:2]
            print(
                f"  {metric.upper()}: {model1_name}={data[f'{model1_name}_mean']:.2f} vs "
                f"{model2_name}={data[f'{model2_name}_mean']:.2f} dB, "
                f"p={data['p_value']:.4f} {_significance_stars(data['p_value'])}, "
                f"effect_size={data['effect_size']:.2f}"
            )


def _show_figure(fig, title):
    """Add a figure-level title, display the figure, then close it to free memory."""
    fig.suptitle(title, y=1.02, fontsize=16)
    plt.show()
    plt.close(fig)


def generate_comparison_report(output_base_path, model_names=("openunmix", "htdemucs")):
    """
    Load museval results, print statistics and significance tests, and show plots.

    Args:
        output_base_path: directory containing one museval output folder per model.
        model_names: model folder names to include (default: openunmix, htdemucs).

    Returns:
        ``(df, stats_results, sig_results)``. When no results are found, a
        message is printed and ``(empty DataFrame, {}, {})`` is returned, so
        callers can always unpack three values.
    """
    print("Loading museval results...")
    results_dict = load_museval_results(output_base_path, model_names)

    if not any(results_dict.values()):
        print(
            "No results found. Please check your paths and ensure museval has been run."
        )
        return pd.DataFrame(), {}, {}

    df = create_comparison_dataframe(results_dict)
    print(f"Loaded {len(df)} data points from {df['track'].nunique()} tracks")

    print("\nCalculating statistics...")
    stats_results = calculate_statistics(df)
    _print_summary_statistics(stats_results)

    sig_results = statistical_significance_test(df)
    _print_significance_results(sig_results)

    print("\nGenerating plots...")
    _show_figure(plot_comparison_boxplots(df), "Model Comparison - Box Plots")
    _show_figure(plot_distribution_comparison(df), "Model Comparison - Distributions")

    # Per-track comparisons for each metric
    for metric in ["sdr", "sir", "sar"]:
        if metric in df.columns:
            _show_figure(
                plot_per_track_comparison(df, metric=metric),
                f"Per-Track {metric.upper()} Comparison",
            )

    print("\nComparison analysis complete!")
    return df, stats_results, sig_results

In [None]:
# Build the comparison report for the Open-Unmix and HTDemucs runs.
report = generate_comparison_report(
    output_base_path=output_base_path, model_names=["openunmix", "htdemucs"]
)
# generate_comparison_report may return None (e.g. when no museval results are
# found); fall back to Nones instead of failing on the tuple unpacking.
df, stats_results, significance_results = (
    report if report is not None else (None, None, None)
)

In [None]:
# NOTE(review): this cell repeats the earlier X-UMX cell (model load, separation,
# museval scoring and WAV export) line for line, plus an import already present
# in the imports cell. It uses the same model_name and paths, so running it
# redoes the whole evaluation and overwrites the same output files. Consider
# deleting one of the two cells.
from asteroid.models import XUMX

# Initialize the X-UMX model
x_umx_model = XUMX.from_pretrained("JorisCos/asteroid-xumx")
x_umx_model.to(device)

model_name = "x-umx-asteroid"

estimates_path = os.path.join(estimates_base_path, model_name)
output_path = os.path.join(output_base_path, model_name)

os.makedirs(estimates_path, exist_ok=True)
os.makedirs(output_path, exist_ok=True)

for track in mus:
    print(f"[→] Separating: {track.name}")
    audio = torch.tensor(track.audio.T).float().unsqueeze(0)  # shape (1, 2, samples)
    audio = audio.to(device)
    rate = track.rate
    subset = track.subset

    # Perform separation using X-UMX from Asteroid
    # NOTE(review): the code below calls `estimates.items()`, i.e. it assumes
    # `separate` returns a {target_name: tensor} dict. Confirm Asteroid's
    # output format for this model.
    with torch.no_grad():
        estimates = x_umx_model.separate(audio)

    # Convert estimates to CPU and numpy format
    cpu_estimates = {
        key: torch.squeeze(value).detach().cpu().numpy().T  # shape (samples, 2)
        for key, value in estimates.items()
    }

    # Evaluate with museval; scores are also written under output_path, which is
    # where the comparison report later reads them from.
    scores = museval.eval_mus_track(track, cpu_estimates, output_dir=output_path)
    print(scores)

    # Create subdirectory for the subset if it doesn't exist
    subset_path = os.path.join(estimates_path, subset)
    os.makedirs(subset_path, exist_ok=True)

    # Save separated audio files
    # (re-converts each tensor; identical to the arrays in cpu_estimates)
    for target, audio_tensor in estimates.items():
        audio_np = (
            torch.squeeze(audio_tensor).detach().cpu().numpy().T
        )  # shape (samples, 2)
        file_name = f"{track.name} - {target}.wav"
        out_path = os.path.join(subset_path, file_name)
        sf.write(out_path, audio_np, rate)