In [1]:
import os
import sys
import warnings

ffmpeg_bin_path = r"C:\Users\arezk\AppData\Local\Microsoft\WinGet\Packages\Gyan.FFmpeg_Microsoft.Winget.Source_8wekyb3d8bbwe\ffmpeg-8.0.1-full_build\bin"
os.environ["PATH"] += os.pathsep + ffmpeg_bin_path

warnings.filterwarnings("ignore")
sys.path.append('..')

In [2]:
import torch
import numpy as np
from tqdm import tqdm
import musdb
import museval
import os
from glob import glob
import soundfile as sf
import IPython.display as ipd
from src.utils import separate_vocals
from src.model import Unet
from scipy.io import wavfile

In [3]:
# Paths
test_data_path = '../data/musdb18/'
output_folder = '../outputs/'

# Load the model
model = Unet()
model_path = '../checkpoints/musdb18_V2/model_last.pth'
model.load(model_path)

# Create output folder
os.makedirs(output_folder, exist_ok=True)


Loading model from ../checkpoints/musdb18_V2/model_last.pth
Model loaded successfully!


In [4]:
# Load MUSDB18 test set
mus_test = musdb.DB(root=test_data_path, is_wav=False, subsets='test')
print(f"Found {len(mus_test.tracks)} test tracks to process")

Found 50 test tracks to process


In [None]:
# Process each track
for i, track in enumerate(tqdm(mus_test.tracks, desc="Processing tracks")):
    track_name = track.name.replace(' ', '_').replace('/', '_')
    output_path = os.path.join(output_folder, f'{track_name}_vocal.wav')
    
    # Skip if already processed
    if os.path.exists(output_path):
        continue
    
    
    # Get audio from track
    audio = track.audio  # (n_samples, 2) stereo
    sr = track.rate      # Usually 44100
    
    
    # Separate vocals
    vocal_audio_predicted = separate_vocals(model, audio, sr=sr)
    
    
    # Save audio
    sf.write(output_path, vocal_audio_predicted, sr)
    
    del vocal_audio_predicted

print(f"\ All {len(mus_test.tracks)} tracks processed and saved to {output_folder}")

Processing tracks: 100%|██████████| 50/50 [10:53<00:00, 13.06s/it]

\ All 50 tracks processed and saved to ../outputs/





In [22]:
# display an example 
print(output_path)
# display(ipd.Audio(output_path, rate=44100))

../outputs/Zeno_-_Signs_vocal.wav


In [5]:
# Load MUSDB18 test set
mus_test = musdb.DB(root='../data/musdb18', is_wav=False, subsets='test')
separated_vocals_folder = '../outputs/'

all_scores = []

# Loop through all test tracks
for i, track in enumerate(mus_test.tracks):
    track_name = track.name.replace(' ', '_').replace('/', '_')
    
        
    # Load separated vocal
    vocal_file = os.path.join(separated_vocals_folder, f'{track_name}_vocal.wav')
    
    if not os.path.exists(vocal_file):
        print(f"   Warning: File not found - {vocal_file}")
        continue
    
    # Load separated vocal
    sr, vocal_estimate = wavfile.read(vocal_file)
    mixture_shape = track.audio.shape
    
 
    
    # Make stereo if mono
    if vocal_estimate.ndim == 1:
        vocal_estimate = np.stack([vocal_estimate, vocal_estimate], axis=1)
    
    
    # Match lengths - crop to shorter
    min_length = min(len(vocal_estimate), len(track.audio))
    vocal_estimate = vocal_estimate[:min_length]
    track_audio = track.audio[:min_length]
        
    accompaniment_estimate = track_audio - vocal_estimate
    
    # Prepare estimates dict
    estimates = {
        'vocals': vocal_estimate,
        'accompaniment': accompaniment_estimate
    }
    
    # Evaluate using museval
    scores = museval.eval_mus_track(track, estimates, output_dir='../eval_2')

    all_scores.append(scores)
    


    # Print results for this track

    vocals_scores = scores.df[scores.df['target'] == 'vocals']

    median_sdr = vocals_scores[vocals_scores['metric'] == 'SDR']['score'].median()
    median_sir = vocals_scores[vocals_scores['metric'] == 'SIR']['score'].median()
    median_sar = vocals_scores[vocals_scores['metric'] == 'SAR']['score'].median()

    print(f"  SDR: {median_sdr:.2f} dB")
    print(f"  SIR: {median_sir:.2f} dB")
    print(f"  SAR: {median_sar:.2f} dB")
    print()

    if i == 6 : 
        break; 


  SDR: -87.61 dB
  SIR: 11.86 dB
  SAR: 1.14 dB
  SDR: -87.90 dB
  SIR: 14.47 dB
  SAR: 0.26 dB
  SDR: -88.81 dB
  SIR: 14.24 dB
  SAR: 2.41 dB
  SDR: -88.54 dB
  SIR: 4.91 dB
  SAR: 2.13 dB
  SDR: -80.76 dB
  SIR: -6.88 dB
  SAR: -0.55 dB
  SDR: -89.34 dB
  SIR: 16.01 dB
  SAR: 1.63 dB
  SDR: -83.06 dB
  SIR: 15.58 dB
  SAR: 0.44 dB


In [21]:
# Calculate average scores
all_sdr = []
all_sir = []
all_sar = []

for score in all_scores:
    vocals_df = score.df[score.df['target'] == 'vocals']
    all_sdr.append(vocals_df[vocals_df['metric'] == 'SDR']['score'].median())
    all_sir.append(vocals_df[vocals_df['metric'] == 'SIR']['score'].median())
    all_sar.append(vocals_df[vocals_df['metric'] == 'SAR']['score'].median())

print("AVERAGE RESULTS")
print(f"Average SDR: {np.mean(all_sdr):.2f} ± {np.std(all_sdr):.2f} dB")
print(f"Average SIR: {np.mean(all_sir):.2f} ± {np.std(all_sir):.2f} dB")
print(f"Average SAR: {np.mean(all_sar):.2f} ± {np.std(all_sar):.2f} dB")

AVERAGE RESULTS
Average SDR: -86.90 ± 2.99 dB
Average SIR: 10.17 ± 7.25 dB
Average SAR: 1.17 ± 0.97 dB
