In [3]:
import sys
import torch
import demucs
import museval
import librosa
import subprocess
import numpy as np
import pandas as pd
from pathlib import Path
from scipy.io import wavfile
from torch.utils.data import DataLoader

sys.path.append("../")
from asteroid.data import MUSDB18Dataset

In [4]:
DATA_DIR = Path("../../../data/musdb18")

TEST_AUDIO_DIR = Path("./audio_files")
SEPARATED_DIR = Path("./separated")

SEGMENT_SIZE = 1
RANDOM_TRACK_MIX = False
TARGETS = ["drums", "bass", "other", "vocals"]
N_SRC = len(TARGETS)
SAMPLE_RATE = 22050
SIZE = -1
TRAIN_BATCH_SIZE = 4
TEST_BATCH_SIZE = 4
NUM_WORKERS = 0
DEVICE = "cpu"

In [None]:
RESULT_DIR = Path("./results")
TMP_DIR = Path("./tmp")
INPUT_DIR = TMP_DIR/"input"
OUTPUT_DIR = TMP_DIR/"output"

RESULT_DIR.mkdir(parents=True, exist_ok=True)
INPUT_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

test_dataset = MUSDB18Dataset(
        root=DATA_DIR.__str__(),
        targets=TARGETS,
        suffix=".mp4",
        split="test",
        subset=None,
        segment=4,
        samples_per_track=1,
        random_segments=True,
        random_track_mix=RANDOM_TRACK_MIX,
        sample_rate=SAMPLE_RATE,
        size=SIZE
    )
test_loader = DataLoader(test_dataset, batch_size=1, num_workers=0)

df_sdr = pd.DataFrame(data=[], columns=TARGETS)
df_isr = pd.DataFrame(data=[], columns=TARGETS)
df_sir = pd.DataFrame(data=[], columns=TARGETS)
df_sar = pd.DataFrame(data=[], columns=TARGETS)

for i, (x, y) in enumerate(test_loader):
    mix_p = INPUT_DIR/f"mix.wav"
    
    wavfile.write(str(mix_p), SAMPLE_RATE, x.view(-1).detach().numpy())
    
    command = f"python3 -m demucs.separate -o {str(OUTPUT_DIR)} -n mdx_extra '{str(mix_p.resolve())}'"
    try:
        subprocess.check_call(command, shell=True)
    except subprocess.CalledProcessError as e:
        print(f"command {e.cmd} exited with error code {e.returncode}.")
    
    
    pred = []
    for f in OUTPUT_DIR.rglob("*.wav"):
        s, _ = librosa.load(str(f), sr=SAMPLE_RATE, mono=True)
        pred.append(s)
    pred = np.array(pred)
    
    try:
        SDR, ISR, SIR, SAR, _ = museval.metrics.bss_eval(
                y[0],
                pred,
                compute_permutation=True,
                window=1*SAMPLE_RATE,
                hop=1*SAMPLE_RATE,
                framewise_filters=False,
                bsseval_sources_version=False
        )
        
        df_sdr = pd.concat([df_sdr, pd.DataFrame(data=[SDR.mean(axis=0).tolist()], columns=TARGETS)])
        df_isr = pd.concat([df_sdr, pd.DataFrame(data=[ISR.mean(axis=0).tolist()], columns=TARGETS)])
        df_sir = pd.concat([df_sdr, pd.DataFrame(data=[SIR.mean(axis=0).tolist()], columns=TARGETS)])
        df_sar = pd.concat([df_sdr, pd.DataFrame(data=[SAR.mean(axis=0).tolist()], columns=TARGETS)])
            
        print(SDR.mean(), ISR.mean(), SIR.mean(), SAR.mean())
    except:
        pass


df_sdr.to_csv(str(RESULT_DIR/f"df_sdr.csv"))
df_isr.to_csv(str(RESULT_DIR/f"df_isr.csv"))
df_sar.to_csv(str(RESULT_DIR/f"df_sar.csv"))
df_sir.to_csv(str(RESULT_DIR/f"df_sir.csv"))

res = []
for metric in ('sdr', 'isr', 'sar', 'sir'):
    df = eval(f"df_{metric}")
    r = [metric, *df.mean().values.tolist()]
    res.append(r)

df_results = pd.DataFrame(data=res, columns=["Metric", *TARGETS])
df_results.to_csv(str(RESULT_DIR/f"results.csv"))

In [11]:
df_sdr = pd.read_csv("results/df_sdr.csv", index_col=0)

In [12]:
df_sdr.mean()

drums     7.264383
bass      7.835729
other     7.389458
vocals    7.290056
dtype: float64