In [None]:
import os
import csv
import librosa
import mir_eval
import librosa
import numpy as np
import argparse
import subprocess

In [None]:

# Go through each song in root_dir and separate its mixture files

root_dir = "/Users/kaimikkelsen/Downloads/test"
model_path = "Wave-U-Net-Pytorch/checkpoints/waveunet/model"


for dirpath, dirnames, filenames in os.walk(root_dir):
    # Check if mixture.wav exists in the current directory
    if "separated" in dirnames and "wav-u-net" in os.listdir(os.path.join(dirpath, "separated")):
        print(f"Skipping directory {dirpath} as separated/wav-u-net already exists.")
        continue
    
    if "mixture.wav" in filenames:
        print(dirpath)

        mixture_path = os.path.join(dirpath, "mixture.wav")
        
        output_dir = os.path.join(dirpath, "separated", "wav-u-net")

        #print(output_dir)
        os.makedirs(output_dir, exist_ok=True)
        
        # Define the command to run
        command = [
            "python3",
            "Wave-U-Net-Pytorch/predict.py",
            "--load_model",
            model_path,
            "--input",
            mixture_path,
            "--output",
            output_dir
        ]
        
        subprocess.run(command)


In [None]:

# Go through all reference and separated tracks and calculate SDR values

directory_path = "/Users/kaimikkelsen/Downloads/test"

def get_audio_files_wav_u_net(current_dir, stems=("other", "vocals", "bass", "drums")):
    """Collect matching (reference, Wave-U-Net estimate) file paths for one song folder.

    References are expected at ``<current_dir>/<stem>.wav`` and estimates at
    ``<current_dir>/separated/wav-u-net/mixture.wav_<stem>.wav``.

    Args:
        current_dir: the song folder.
        stems: stem names to look for, in output order.

    Returns:
        ``(reference_paths, separated_paths)``: two equally long lists in ``stems``
        order, where index ``i`` of both lists refers to the same stem. A stem is
        included only when both of its files exist, so the lists never go out of step.
    """
    separated_dir = os.path.join(current_dir, 'separated', 'wav-u-net')

    reference_paths = []
    separated_paths = []
    for stem in stems:
        reference_path = os.path.join(current_dir, f'{stem}.wav')
        separated_path = os.path.join(separated_dir, f'mixture.wav_{stem}.wav')
        # Plain existence checks instead of glob: song folder names may contain
        # '[' / ']', which glob would interpret as a character class.
        if os.path.isfile(reference_path) and os.path.isfile(separated_path):
            reference_paths.append(reference_path)
            separated_paths.append(separated_path)

    return reference_paths, separated_paths

def getSeparationMetrics(audio_ref_list, audio_list):
    """Compute mean BSS-eval SDR, SIR and SAR for a set of estimated sources.

    Args:
        audio_ref_list: sequence of reference signals, one per source (samples on the
            last axis).
        audio_list: sequence of estimated signals in the same source order.

    Returns:
        Tuple ``(sdr, sir, sar)``, each averaged over the sources.

    Raises:
        ValueError: if either list is empty.
    """
    if len(audio_ref_list) == 0 or len(audio_list) == 0:
        raise ValueError("getSeparationMetrics needs at least one reference and one estimate")

    # bss_eval_sources requires identically shaped arrays; a separated track can be a few
    # samples longer/shorter than its reference, so trim every signal to the shortest one.
    all_signals = [np.asarray(signal) for signal in (*audio_ref_list, *audio_list)]
    n_samples = min(signal.shape[-1] for signal in all_signals)
    reference_sources = np.array([np.asarray(s)[..., :n_samples] for s in audio_ref_list])
    estimated_sources = np.array([np.asarray(s)[..., :n_samples] for s in audio_list])

    sdr, sir, sar, _perm = mir_eval.separation.bss_eval_sources(
        reference_sources, estimated_sources, compute_permutation=False)
    return np.mean(sdr), np.mean(sir), np.mean(sar)

def write_to_csv(data, file_path):
    """Write per-song SDR rows to ``file_path`` as CSV, preceded by a header row.

    Args:
        data: iterable of rows, each ``[directory, other, vocals, bass, drums]``.
        file_path: destination path; an existing file is overwritten.
    """
    header = ["Directory", "Other", "Vocals", "Bass", "Drums"]
    with open(file_path, mode='w', newline='') as csv_file:
        csv_writer = csv.writer(csv_file)
        csv_writer.writerow(header)
        csv_writer.writerows(data)


# Get every song folder under directory_path, sorted so CSV rows come out in a stable order.
directories = sorted(
    d for d in os.listdir(directory_path)
    if os.path.isdir(os.path.join(directory_path, d))
)

n_stems = 4  # CSV columns after "Directory": Other, Vocals, Bass, Drums
data_to_write = []
for index, directory in enumerate(directories, start=1):
    song_dir = os.path.join(directory_path, directory)
    print(f"{index} of {len(directories)}: {directory}")
    reference_files, separated_files = get_audio_files_wav_u_net(song_dir)

    # Every CSV column needs a matching (reference, estimate) pair; with a partial set
    # the fixed-size loop would raise IndexError or pair the wrong stems together.
    if len(reference_files) != n_stems or len(separated_files) != n_stems:
        print(f"  Skipping {directory}: found {len(reference_files)} reference and "
              f"{len(separated_files)} separated stems, expected {n_stems} of each.")
        continue

    row_data = [directory]
    for reference_path, separated_path in zip(reference_files, separated_files):
        # Both files go through librosa.load with the same default settings, so the
        # reference and the estimate end up at the same sample rate.
        audio_ref, _ = librosa.load(reference_path)
        audio_sep, _ = librosa.load(separated_path)

        sdr, _sir, _sar = getSeparationMetrics([audio_ref], [audio_sep])
        print(f"  {os.path.basename(reference_path)}: SDR = {sdr:.3f}")
        row_data.append(sdr)
    data_to_write.append(row_data)


csv_file_path = "/Users/kaimikkelsen/sdr_eval/wav-u-net-pytorch_output.csv"
# Make sure the output folder exists; open() would otherwise fail after the whole evaluation.
os.makedirs(os.path.dirname(csv_file_path), exist_ok=True)
write_to_csv(data_to_write, csv_file_path)
print(f"CSV file written successfully ({len(data_to_write)} songs): {csv_file_path}")