In [1]:
import os
from tqdm import tqdm
import librosa as rosa

# path = "C:/Users/chris/Music/batchsize1_no_noise_speedperturb.zip"
sampling_rate = 8000

'''
return format:
    {"itemN": [(source, hat, hat_up),(source, hat, hat_up)] }
where: 
    ["itemN"][0] -> source1
    ["itemN"][1] -> source2

NOTE: assumes that all files are in a the same directory
'''
def read_all_audio(path):
   
    output = {}
    all_files = [f for f in os.listdir(path) if f.endswith(".wav") and "mix" not in f]
    # print(f"len(all_files)")
    for file_name in tqdm(all_files, total=len(all_files)) :
        filepath = os.path.join(path, file_name)
        split = file_name.split('_')
        item_name = split[0]
        file_info = split[1]
        arr_idx, tuple_idx = determine_idx(file_info)
        # print(split)
        if len(split) > 2:
            tuple_idx = 2
        try:
            output[item_name][arr_idx][tuple_idx], _ = rosa.load(filepath, sr=sampling_rate, mono=True)
        except KeyError:
            output[item_name] = [[-1,-1,-1],[-1,-1,-1]]
            output[item_name][arr_idx][tuple_idx], _ = rosa.load(filepath, sr=sampling_rate, mono=True)
    return output

def determine_idx(file_info):
    arr_idx = 0
    tuple_idx = 0
    if "source2" in file_info:
        arr_idx = 1
    if "hat" in file_info:
        tuple_idx = 1
    return arr_idx, tuple_idx

# audios = read_all_audio(path)

In [2]:
# nuwave SNR implementation
# def snr(pred, target):
#     return (20 *torch.log10(torch.norm(target, dim=-1) \
#             /torch.norm(pred -target, dim =-1).clamp(min =1e-8))).mean()

In [3]:
import pandas as pd
from pesq import pesq

def get_pesq(path):
    audios = read_all_audio(path)
    keys = list(audios.keys())
    rate = 8000

    metrics = []
    for key in keys:
        audio = audios[key]
        for i,x in enumerate(audio):
            data = []
            name = f"{key}_source{i+1}"
            data.append(name)
            data.append(pesq(rate, x[0], x[1], 'nb')) # hat
            data.append(pesq(rate, x[0], x[2], 'nb')) # up
            metrics.append(data)
        break

    df = pd.DataFrame(metrics, columns=["item", "hat", "up"])
    hat = df["hat"].mean()
    up = df["up"].mean()
    df.loc[len(df)] = ["MEAN", hat, up]
    print(f'hat: {hat} | up: {up}')
    # save_name = path.split("/")[-1]
    # df.to_csv(f'./PESQ_metrics/{save_name}.csv', index=False)
    df.tail()
configs = ["epoch_30", "no_noise_speedperturb","w_noise_speedperturb", "w_noise_wavedrop"]

for config in configs:
    path = f"results/{config}/audio_results"
    print(f"Config: {config}")
    get_pesq(path)

Config: epoch_30


100%|██████████| 14768/14768 [06:03<00:00, 40.60it/s] 


hat: 3.8756214380264282 | up: 3.8525179624557495
Config: no_noise_speedperturb


100%|██████████| 18000/18000 [10:50<00:00, 27.68it/s]


hat: 2.076050877571106 | up: 2.039978504180908
Config: w_noise_speedperturb


100%|██████████| 18000/18000 [07:35<00:00, 39.50it/s]


hat: 2.401853561401367 | up: 2.37525737285614
Config: w_noise_wavedrop


100%|██████████| 18000/18000 [06:54<00:00, 43.43it/s]


hat: 1.9234905242919922 | up: 1.9087774753570557


In [4]:
# !pip install pesq