# Progressive Inference for Music Demixing
## ID: L12
Description: Using denoising diffusion approaches to train music demixing (MDX) models is
promising but requires retraining large and carefully tuned neural networks (Plaja-Roglans,
2022). Instead, we will explore a related yet different approach: can we improve separation
quality solely by scheduling the inference process using a diffusion-inspired strategy even
without retraining? By experimenting with existing MDX models (Spleeter by Deezer,Meta’s Demucs, ByteDance’s BS-Roformer, etc.), this project offers an exciting opportunity
to explore and possibly enhance the performance of state-of-the-art AI techniques.

### Resources
(Plaja-Roglans, 2022) https://ismir2022program.ismir.net/poster_262.html
Denoising Diffusion Probabilistic Models: https://arxiv.org/abs/2006.11239
MDX Challenge 2021: https://arxiv.org/abs/2108.13559
MDX Challenge 2023: https://arxiv.org/abs/2308.06979
Overview of state-of-the-art MDX models:
https://paperswithcode.com/sota/music-source-separation-on-musdb18-hq

### Preliminary step: Install and Import Libraries.

In [2]:
%pip install -r requirements.txt

Defaulting to user installation because normal site-packages is not writeable

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49m/Library/Developer/CommandLineTools/usr/bin/python3 -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [3]:
%pip install torchviz

Defaulting to user installation because normal site-packages is not writeable
Collecting torchviz
  Using cached torchviz-0.0.3-py3-none-any.whl.metadata (2.1 kB)
Using cached torchviz-0.0.3-py3-none-any.whl (5.7 kB)
Installing collected packages: torchviz
Successfully installed torchviz-0.0.3

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49m/Library/Developer/CommandLineTools/usr/bin/python3 -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [4]:
import IPython
import torch
import torchaudio
import os
import matplotlib.pyplot as plt
from IPython.display import Audio
from mir_eval import separation
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB_PLUS
from torchaudio.transforms import Fade

In [5]:
## from torchviz import make_dot

### 1. Load the Model

In [6]:
bundle = HDEMUCS_HIGH_MUSDB_PLUS
model = bundle.get_model()
# Use GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

HDemucs(
  (freq_encoder): ModuleList(
    (0): _HEncLayer(
      (conv): Conv2d(4, 48, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
      (norm1): Identity()
      (rewrite): Conv2d(48, 96, kernel_size=(1, 1), stride=(1, 1))
      (norm2): Identity()
      (dconv): _DConv(
        (layers): ModuleList(
          (0): Sequential(
            (0): Conv1d(48, 12, kernel_size=(3,), stride=(1,), padding=(1,))
            (1): GroupNorm(1, 12, eps=1e-05, affine=True)
            (2): GELU(approximate='none')
            (3): Conv1d(12, 96, kernel_size=(1,), stride=(1,))
            (4): GroupNorm(1, 96, eps=1e-05, affine=True)
            (5): GLU(dim=1)
            (6): _LayerScale()
          )
          (1): Sequential(
            (0): Conv1d(48, 12, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
            (1): GroupNorm(1, 12, eps=1e-05, affine=True)
            (2): GELU(approximate='none')
            (3): Conv1d(12, 96, kernel_size=(1,), stride=(1,))
           

### 2. Dataset Loading

Download the dataset at: https://zenodo.org/records/3338373

Set the path of the downloaded dataset in DATASET_PATH

In [10]:
# Prompt the user for the absolute path of the dataset
DATASET_PATH = input("Please enter the absolute path for the dataset: ")
print(f"Dataset path set to: {DATASET_PATH}")

Dataset path set to: /Users/giorgiomagalini/Documents/_DocumentiUtente/Uni/PoliMi/CAPSTONE/Code


In [8]:
import shlex
# Prompt the user for the absolute path of the dataset
DATASET_PATH = input("Please enter the absolute path for the dataset: ")
quoted_path = shlex.quote(DATASET_PATH)  # Safely quote the path
print(f"Dataset path set to: {quoted_path}")

KeyboardInterrupt: Interrupted by user

In [11]:
# Load dataset and choose the length of the tracks for each song (in seconds)
DATASET_FOLDER = DATASET_PATH + "/musdb18hq/test"
SEGMENT = 30  # We'll keep exactly 30 seconds from each track

# !!! --> PAY ATTENTION TO THE SILENT PORTIONS OF THE TRACKS, THEY CAN CAUSE ERRORS IN THE SEPARATION PROCESS <-- !!!
# You should choose a portion of the track that contains all the instrument without the channel sums being zero
# NB: if you encounter a silent portion of the track, you will skip the evaluation for that stem!

track_folders = sorted(
    folder for folder in os.listdir(DATASET_FOLDER)
    if os.path.isdir(os.path.join(DATASET_FOLDER, folder))
)

# Dictionary to store {track_folder -> {stem_name -> waveform}}
dataset_dict = {}

# Each subfolder in musdb18hq/test corresponds to a track
for track_folder in track_folders:
    track_path = os.path.join(DATASET_FOLDER, track_folder)
    if not os.path.isdir(track_path):
        continue

    # Prepare a sub-dictionary for this track
    stems_dict = {}
    stem_names = ["mixture", "drums", "bass", "vocals", "other"]

    for stem_name in stem_names:
        file_path = os.path.join(track_path, f"{stem_name}.wav")
        
        if not os.path.isfile(file_path):
            print(f"Warning: file not found {file_path}")
            continue

        # Load full audio
        print(f"Loading {file_path}...")
        file_path = shlex.quote(file_path)
        waveform, sr = torchaudio.load(file_path)

        # Keep only the first 30s
        segment_samples = SEGMENT * sr
        waveform_segment = waveform[:, :segment_samples]

        stems_dict[stem_name] = waveform_segment

    dataset_dict[track_folder] = stems_dict

print("Loaded tracks:", list(dataset_dict.keys()))

Loading /Users/giorgiomagalini/Documents/_DocumentiUtente/Uni/PoliMi/CAPSTONE/Code/musdb18hq/test/AM Contra - Heart Peripheral/mixture.wav...


RuntimeError: Couldn't find appropriate backend to handle uri '/Users/giorgiomagalini/Documents/_DocumentiUtente/Uni/PoliMi/CAPSTONE/Code/musdb18hq/test/AM Contra - Heart Peripheral/mixture.wav' and format None.

### 2. Processing a Single Track

#### 2.1 Choose a Track to Process
Try analyze different tracks by changing the index of ***track_names[index]***

In [None]:
# Now you have a dictionary with track_folder as the key,
# and a sub-dict with "mixture", "drums", "bass", "vocals", "other" waveforms
track_names = list(dataset_dict.keys())


track_chosen = track_names[25]
print("Chosen track name:", track_chosen)

stems_available = list(dataset_dict[track_chosen].keys())
print("Stems:", stems_available)  # e.g. ['mixture', 'drums', 'bass', 'vocals', 'other']

# Check duration
mixture_waveform = dataset_dict[track_chosen]["mixture"]
duration_seconds = mixture_waveform.shape[1] / sample_rate
print(f"Duration (seconds): {duration_seconds}")

# Ensure we have all 5 stems
if len(stems_available) < 5:
    print("Warning: Not all stems found. This track might be incomplete.")

#### 2.2 Prepare the Data for Separation

In [None]:
mixture_waveform = mixture_waveform.to(device)

# Normalization across the channels
ref = mixture_waveform.mean(0)  # (2, samples) -> (samples,)

# Singal with zero mean and unit variance
mixture_norm = (mixture_waveform - ref.mean()) / ref.std()

In [None]:
def separate_sources(
    model,
    mix,
    segment=30,
    overlap=0.0,  # set to 0.0 to avoid chunk repetition
    device=None
):
    """
    Apply model to a given mixture. Use fade, and add segments together in order to add model segment by segment.

    Args:
        segment (int): segment length in seconds
        device (torch.device, str, or None): if provided, device on which to
            execute the computation, otherwise `mix.device` is assumed.
            When `device` is different from `mix.device`, only local computations will
            be on `device`, while the entire tracks will be stored on `mix.device`.
    """

    if device is None:
        device = mix.device
    else:
        device = torch.device(device)

    batch, channels, length = mix.shape

    # chunk_len for entire 30s, no overlap
    chunk_len = int(sample_rate * segment * (1 + overlap))  # effectively 30s if overlap=0
    start = 0
    end = chunk_len

    overlap_frames = int(overlap * sample_rate)
    fade = Fade(fade_in_len=0, fade_out_len=overlap_frames, fade_shape="linear")

    # Prepare final buffer
    final = torch.zeros(batch, len(model.sources), channels, length, device=device)

    while start < length - overlap_frames:
        chunk = mix[:, :, start:end]
        with torch.no_grad():
            out = model(chunk)
        out = fade(out)
        final[:, :, :, start:end] += out

        if start == 0:
            fade.fade_in_len = overlap_frames
            start += chunk_len - overlap_frames
        else:
            start += chunk_len
        end += chunk_len
        if end >= length:
            fade.fade_out_len = 0

    return final

#### 2.3 Run Separation

In [None]:
print("Separating 30-second track with no overlap...")
sources_tensor = separate_sources(
    model,
    mixture_norm[None],  # shape (1, channels, samples)
    segment=30,
    overlap=0.0,
    device=device
)[0]  # shape (4, channels, samples)

# Undo normalization
sources_tensor = sources_tensor * ref.std() + ref.mean()

# Build a dict {stem_name -> predicted_stem}
stem_names = model.sources  # ['drums', 'bass', 'other', 'vocals'] typically
predicted_stems = dict(zip(stem_names, list(sources_tensor)))


#### 2.4 Evaluations

In [None]:
def output_results(original_source: torch.Tensor, predicted_source: torch.Tensor, source: str):
    # Move to CPU
    original_np = original_source.detach().cpu().numpy()
    predicted_np = predicted_source.detach().cpu().numpy()

    # If shape is (C, T), that's fine for mir_eval if C=2
    # but let's ensure it's (2, T) not (T, 2)
    # Usually PyTorch waveforms are (channels, samples),
    # which is correct for bss_eval_sources.

    # Verify the energy of the reference(sum of the absolutes for each channel).
    energy = original_source.abs().sum(dim=1)
    print(f"{source} - Energy per channel: {energy}")
    
    # If one of the cheannel has an energy below the energy threshold (1e-3), skip the evaluation
    if (energy < 1e-3).any():
        print(f"Warning: {source} reference appears silent or nearly silent. Skipping evaluation for this stem.")
        return None  # oppure ritorna un valore di default o una stringa informativa
    sdr, sir, sar, _ = separation.bss_eval_sources(
        reference_sources=original_np,
        estimated_sources=predicted_np
    )

    print(f"--- {source} ---")
    print("SDR:", sdr.mean())
    print("SIR:", sir.mean())
    print("SAR:", sar.mean())
    print("----------------")

    return Audio(predicted_source, rate=sample_rate)

# Retrieve references from dataset_dict
drums_ref = dataset_dict[track_chosen]["drums"].to(device)
bass_ref = dataset_dict[track_chosen]["bass"].to(device)
vocals_ref = dataset_dict[track_chosen]["vocals"].to(device)
other_ref = dataset_dict[track_chosen]["other"].to(device)

# Predicted
drums_pred = predicted_stems["drums"]
bass_pred  = predicted_stems["bass"]
vocals_pred = predicted_stems["vocals"]
other_pred = predicted_stems["other"]

# Evaluate each stem
output_results(drums_ref, drums_pred, "Drums")
output_results(bass_ref, bass_pred, "Bass")
output_results(vocals_ref, vocals_pred, "Vocals")
output_results(other_ref, other_pred, "Other")

In [None]:
import numpy as np

# Collect evaluation metrics
metrics = {
    "Drums": output_results(drums_ref, drums_pred, "Drums"),
    "Bass": output_results(bass_ref, bass_pred, "Bass"),
    "Vocals": output_results(vocals_ref, vocals_pred, "Vocals"),
    "Other": output_results(other_ref, other_pred, "Other"),
}

# Prepare data for plotting
stems = list(metrics.keys())
sdr_values = [separation.bss_eval_sources(drums_ref.cpu().numpy(), drums_pred.cpu().numpy())[0].mean(),
              separation.bss_eval_sources(bass_ref.cpu().numpy(), bass_pred.cpu().numpy())[0].mean(),
              separation.bss_eval_sources(vocals_ref.cpu().numpy(), vocals_pred.cpu().numpy())[0].mean(),
              separation.bss_eval_sources(other_ref.cpu().numpy(), other_pred.cpu().numpy())[0].mean()]

# Plot the results
plt.figure(figsize=(10, 6))
plt.bar(stems, sdr_values, color='skyblue')
plt.xlabel('Stems')
plt.ylabel('SDR (Signal-to-Distortion Ratio)')
plt.title('Evaluation Results: SDR for Each Stem')
plt.show()