In [34]:
import os
import pandas as pd
from torch.utils.data import Dataset
import torchaudio
import librosa
import random
import numpy as np
from pesq import pesq
from tqdm import tqdm

In [14]:
def audio_loading(path,sampling_rate=16000):

    audio, fs = librosa.load(path, sr=None)
    if len(audio.shape) > 1:
        audio = librosa.to_mono(audio)

    if fs != sampling_rate:
        audio = librosa.resample(audio,fs,sampling_rate)

    return audio

def clip_audio(audio,clip_sec=3,sampling_rate=16000):
    window_size = clip_sec*sampling_rate
    start = random.randrange(0, len(audio)-window_size)
    return audio[start:start+window_size]

def clip_2_audios(audio_1,audio_2,clip_sec=3,sampling_rate=16000):
    window_size = clip_sec*sampling_rate
    start = random.randrange(0, len(audio_1)-window_size)
    return audio_1[start:start+window_size],audio_2[start:start+window_size]

def si_sdr(deg_audio,reference_audio):
    eps = np.finfo(deg_audio.dtype).eps
    reference = reference_audio.reshape(reference_audio.size, 1)
    estimate = deg_audio.reshape(deg_audio.size, 1)
    Rss = np.dot(reference.T, reference)

    a = (eps + np.dot(reference.T, estimate)) / (Rss + eps)

    e_true = a * reference
    e_res = estimate - e_true

    Sss = (e_true**2).sum()
    Snn = (e_res**2).sum()

    return 10 * np.log10((eps+ Sss)/(eps + Snn))

In [15]:
NISQA_TRAIN_SIM_file = pd.read_csv("/work/data/speech_metrics_eval/NISQA_Corpus/NISQA_TRAIN_SIM/NISQA_TRAIN_SIM_file.csv")

In [47]:
NISQA_TRAIN_SIM_file = pd.read_csv("/work/data/speech_metrics_eval/NISQA_Corpus/NISQA_VAL_SIM/NISQA_VAL_SIM_file.csv")

In [48]:
NISQA_TRAIN_SIM_file

Unnamed: 0,db,con,file,con_description,filename_deg,filename_ref,source,lang,votes,mos,...,bp_high,p50_q,bMode1,bMode2,bMode3,FER1,FER2,FER3,asl_in_level,asl_out_level
0,NISQA_VAL_SIM,1.0,1,simulated,c0001_3_1026_2_7_001-ch6-speaker_seg58.wav,3_1026_2_7_001-ch6-speaker_seg58.wav,AusTalk,en,5.0,2.800,...,,,,,,,,,,
1,NISQA_VAL_SIM,2.0,2,simulated,c0002_book_02113_chp_0007_reader_01566_6_seg.wav,book_02113_chp_0007_reader_01566_6_seg.wav,DNS,en,5.0,3.600,...,,,,,,,,,,
2,NISQA_VAL_SIM,3.0,3,simulated,c0003_book_09928_chp_0007_reader_03595_42_seg.wav,book_09928_chp_0007_reader_03595_42_seg.wav,DNS,en,5.0,3.800,...,,,,,,,,,,
3,NISQA_VAL_SIM,4.0,4,simulated,c0004_1_301_2_7_001-ch6-speaker_seg36.wav,1_301_2_7_001-ch6-speaker_seg36.wav,AusTalk,en,5.0,3.800,...,,,,,,,,,,
4,NISQA_VAL_SIM,5.0,5,simulated,c0005_3_997_2_7_001-ch6-speaker_seg18.wav,3_997_2_7_001-ch6-speaker_seg18.wav,AusTalk,en,8.0,1.625,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2495,NISQA_VAL_SIM,2496.0,2496,simulated,c2496_nom_07049_02002667939_seg.wav,nom_07049_02002667939_seg.wav,UKIRE,en,5.0,2.800,...,,,10.0,,,,,,,-31.312625
2496,NISQA_VAL_SIM,2497.0,2497,simulated,c2497_book_04189_chp_0005_reader_10125_60_seg.wav,book_04189_chp_0005_reader_10125_60_seg.wav,DNS,en,5.0,1.600,...,,,1.0,,,0.084359,,,,-44.649299
2497,NISQA_VAL_SIM,2498.0,2498,simulated,c2498_book_03028_chp_0016_reader_09104_3_seg.wav,book_03028_chp_0016_reader_09104_3_seg.wav,DNS,en,5.0,2.000,...,,,2.0,,,0.158718,,,,-26.903808
2498,NISQA_VAL_SIM,2499.0,2499,simulated,c2499_nom_06136_00396141351_seg.wav,nom_06136_00396141351_seg.wav,UKIRE,en,7.0,1.000,...,,,2.0,,,0.158718,,,,-66.803607


In [49]:
csv_path="/work/data/speech_metrics_eval/NISQA_Corpus/NISQA_VAL_SIM/NISQA_VAL_SIM_file.csv"
base_path="/work/data/speech_metrics_eval/NISQA_Corpus/"

In [50]:
len(NISQA_TRAIN_SIM_file)

2500

In [51]:
sampling_rate = 16000
clip_sec = 3
pesq_val_all = []
si_sdr_val_all =  []
for i in range(len(NISQA_TRAIN_SIM_file)):
    entry = NISQA_TRAIN_SIM_file.iloc[i]

    deg_file = os.path.join(base_path, entry.filepath_deg)
    ref_file = os.path.join(base_path, entry.filepath_ref)

    deg_audio = audio_loading(deg_file,sampling_rate)
    ref_audio = audio_loading(ref_file,sampling_rate)
    try:
        pesq_val = pesq(sampling_rate, ref_audio, deg_audio, 'wb')
    except:
        pesq_val = None

    try:
        si_sdr_val = si_sdr(deg_audio,ref_audio)
    except:
        si_sdr_val = None

    pesq_val_all.append(pesq_val)
    si_sdr_val_all.append(si_sdr_val)

In [36]:
len(si_sdr_val_all)

10000

In [None]:
NISQA_TRAIN_SIM_file["si_sdr"] = si_sdr_val_all

In [None]:
NISQA_TRAIN_SIM_file["pesq"] = pesq_val_all

In [None]:
NISQA_TRAIN_SIM_file.to_csv("/work/data/speech_metrics_eval/NISQA_Corpus/NISQA_VAL_SIM/NISQA_VAL_SIM_file_pesq_si_sdr.csv")

In [1]:
import pandas as pd

In [13]:
NISQA_TRAIN_SIM_file = pd.read_csv("/work/data/speech_metrics_eval/NISQA_Corpus/NISQA_TRAIN_SIM/NISQA_TRAIN_SIM_file_pesq_si_sdr.csv")

In [11]:
NISQA_TRAIN_SIM_file["pesq"].notna()

0        True
1        True
2        True
3        True
4        True
        ...  
2495     True
2496    False
2497     True
2498     True
2499     True
Name: pesq, Length: 2500, dtype: bool

In [16]:
NISQA_TRAIN_SIM_file[NISQA_TRAIN_SIM_file["mos"].notna()]

Unnamed: 0.1,Unnamed: 0,db,con,file,con_description,filename_deg,filename_ref,source,lang,votes,...,bMode1,bMode2,bMode3,FER1,FER2,FER3,asl_in_level,asl_out_level,si_sdr,pesq
0,0,NISQA_TRAIN_SIM,1.0,1,simulated,c00001_3_640_2_7_001-ch6-speaker_seg49.wav,3_640_2_7_001-ch6-speaker_seg49.wav,AusTalk,en,5.0,...,,,,,,,,,3.201192,1.052332
1,1,NISQA_TRAIN_SIM,2.0,2,simulated,c00002_1_319_2_7_001-ch6-speaker_seg63.wav,1_319_2_7_001-ch6-speaker_seg63.wav,AusTalk,en,5.0,...,,,,,,,,,8.855734,1.145391
2,2,NISQA_TRAIN_SIM,3.0,3,simulated,c00003_3_864_2_7_001-ch6-speaker_seg91.wav,3_864_2_7_001-ch6-speaker_seg91.wav,AusTalk,en,5.0,...,,,,,,,,,12.327173,1.204837
3,3,NISQA_TRAIN_SIM,4.0,4,simulated,c00004_2_1011_2_7_001-ch6-speaker_seg41.wav,2_1011_2_7_001-ch6-speaker_seg41.wav,AusTalk,en,5.0,...,,,,,,,,,10.273271,1.300005
4,4,NISQA_TRAIN_SIM,5.0,5,simulated,c00005_book_01337_chp_0002_reader_06379_5_seg.wav,book_01337_chp_0002_reader_06379_5_seg.wav,DNS,en,5.0,...,,,,,,,,,16.010752,1.472057
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,9995,NISQA_TRAIN_SIM,9996.0,9996,simulated,c09996_mif_03397_00323124543_seg.wav,mif_03397_00323124543_seg.wav,UKIRE,en,8.0,...,1.0,,,0.270256,,,,-60.410822,-1.069621,1.128866
9996,9996,NISQA_TRAIN_SIM,9997.0,9997,simulated,c09997_sof_05223_01829932752_seg.wav,sof_05223_01829932752_seg.wav,UKIRE,en,5.0,...,1.0,,,0.114103,,,,-26.573146,2.950267,1.172490
9997,9997,NISQA_TRAIN_SIM,9998.0,9998,simulated,c09998_book_01746_chp_0010_reader_07461_22_seg...,book_01746_chp_0010_reader_07461_22_seg.wav,DNS,en,5.0,...,1.0,,,,,,,-38.366733,-8.409427,1.117868
9998,9998,NISQA_TRAIN_SIM,9999.0,9999,simulated,c09999_2_518_2_7_001-ch6-speaker_seg46.wav,2_518_2_7_001-ch6-speaker_seg46.wav,AusTalk,en,5.0,...,3.0,,,,,,,-63.386774,-4.228194,1.435990


In [42]:
# NISQA_TRAIN_SIM_file.to_csv("/work/data/speech_metrics_eval/NISQA_Corpus/NISQA_TRAIN_SIM/NISQA_TRAIN_SIM_file_pesq_si_sdr.csv")

In [31]:
print(pesq_val,si_sdr_val)

1.0523324012756348 3.20119172334671


In [25]:
np.concatenate((cliped_ref_audio,cliped_ref_audio))

array([-0.00202031, -0.00186867, -0.00242918, ...,  0.00154356,
        0.00142138,  0.00132616], dtype=float32)

In [26]:
pesq(sampling_rate, np.concatenate((cliped_ref_audio,cliped_ref_audio)), np.concatenate((cliped_ref_audio,cliped_ref_audio)), 'wb')

4.643888473510742

In [1]:
from models.multihead import Multihead_Wav2vec

In [2]:
import torch

In [4]:
chckpt = torch.load("/home/filip/speech_metrics_eval/training_checkpoints/multihead_wav2vec_2/checkpoint_1000.pt")

In [6]:
model = Multihead_Wav2vec()
model.load_state_dict(chckpt["model_state_dict"])
model.eval()

	Error importing 'hydra_plugins.hydra_colorlog'.
	Plugin is incompatible with this Hydra version or buggy.
	Recommended to uninstall or upgrade plugin.
		ImportError : cannot import name 'SearchPathPlugin' from 'hydra.plugins' (/work/miniconda3/envs/sayso_dev/lib/python3.9/site-packages/hydra/plugins/__init__.py)


Multihead_Wav2vec(
  (input_layer): Wav2Vec2Model(
    (feature_extractor): ConvFeatureExtractionModel(
      (conv_layers): ModuleList(
        (0): Sequential(
          (0): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
          (1): Dropout(p=0.0, inplace=False)
          (2): Fp32GroupNorm(512, 512, eps=1e-05, affine=True)
          (3): GELU()
        )
        (1): Sequential(
          (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
          (1): Dropout(p=0.0, inplace=False)
          (2): GELU()
        )
        (2): Sequential(
          (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
          (1): Dropout(p=0.0, inplace=False)
          (2): GELU()
        )
        (3): Sequential(
          (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
          (1): Dropout(p=0.0, inplace=False)
          (2): GELU()
        )
        (4): Sequential(
          (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=

: 