In [5]:
import os
from tqdm import tqdm
import librosa as rosa

path = "C:/Users/chris/Music/batchsize1_no_noise_speedperturb.zip"
sampling_rate = 8000

'''
return format:
    {"itemN": [(source, hat, hat_up),(source, hat, hat_up)] }
where: 
    ["itemN"][0] -> source1
    ["itemN"][1] -> source2

NOTE: assumes that all files are in a the same directory
'''
def read_all_audio(path):
   
    output = {}
    all_files = [f for f in os.listdir(path) if f.endswith(".wav") and "mix" not in f]
    # print(f"len(all_files)")
    for file_name in tqdm(all_files, total=len(all_files)) :
        filepath = os.path.join(path, file_name)
        split = file_name.split('_')
        item_name = split[0]
        file_info = split[1]
        arr_idx, tuple_idx = determine_idx(file_info)
        # print(split)
        if len(split) > 2:
            tuple_idx = 2
        try:
            output[item_name][arr_idx][tuple_idx], _ = rosa.load(filepath, sr=sampling_rate, mono=True)
        except KeyError:
            output[item_name] = [[-1,-1,-1],[-1,-1,-1]]
            output[item_name][arr_idx][tuple_idx], _ = rosa.load(filepath, sr=sampling_rate, mono=True)
    return output

def determine_idx(file_info):
    arr_idx = 0
    tuple_idx = 0
    if "source2" in file_info:
        arr_idx = 1
    if "hat" in file_info:
        tuple_idx = 1
    return arr_idx, tuple_idx

audios = read_all_audio(path)

100%|██████████| 50/50 [00:01<00:00, 30.47it/s]


In [None]:
# nuwave SNR implementation
# def snr(pred, target):
#     return (20 *torch.log10(torch.norm(target, dim=-1) \
#             /torch.norm(pred -target, dim =-1).clamp(min =1e-8))).mean()

In [28]:
import pandas as pd
from pesq import pesq

keys = list(audios.keys())
rate = 8000

metrics = []
for key in keys:
    audio = audios[key]
    for i,x in enumerate(audio):
        data = []
        name = f"{key}_source{i+1}"
        data.append(name)
        data.append(pesq(rate, x[0], x[1], 'nb')) # hat
        data.append(pesq(rate, x[0], x[2], 'nb')) # up
        metrics.append(data)
    break

df = pd.DataFrame(metrics, columns=["item", "hat", "up"])
hat = df["hat"].mean()
up = df["up"].mean()
df.loc[len(df)] = ["MEAN", hat, up]
save_name = path.split("/")[-1]
df.to_csv(f'./PESQ_metrics/{save_name}.csv', index=False)
df.tail()

Unnamed: 0,item,hat,up
0,item0_source1,2.20389,2.211942
1,item0_source2,1.948212,1.868015
2,MEAN,2.076051,2.039979
