### Caution: we recommend running this notebook on a GPU.

In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import IPython.display as ipd
import os
import sys
import concurrent
import random
import math
from pathlib import Path
import scipy.stats as st

os.environ["MKL_NUM_THREADS"]='1'
os.environ["NUMEXPR_NUM_THREADS"]='1'
os.environ["OMP_NUM_THREADS"]='1'

import torch
from torch.utils.data import DataLoader
from torchaudio.transforms import Resample
import numpy as np
from tqdm.notebook import tqdm
import librosa
import soundfile as sf

sys.path.insert(0, "../")
from utils import get_hparams
from models import get_wrapper
from functional import stft, spec_to_mel
from utils.data import get_dataset_dataloader
from utils.data.audio import Dataset, collate
from utils import HParams

os.environ["CUDA_VISIBLE_DEVICES"]='2'
device = 'cuda'     # 'cpu' or 'cuda'

# Load model

In [8]:
# Load the trained model (config + checkpoint from `base_dir`) and switch it to eval mode.
base_dir = "../logs/hil_audio"
hps = get_hparams(f"{base_dir}/config.yaml", base_dir)
wrapper = get_wrapper(hps.model)(hps, device=device)
wrapper.load()
wrapper.eval()

sr = hps.data.sampling_rate      # model sampling rate; resamplers and `metric` convert from this
hop_size = wrapper.hop_size      # evaluation crops waveforms to a multiple of this many samples

# Alignment offset, in samples, between model output and input:
# `hps.train.delay` if present, else `hps.train.lookahead`, else 0.
lookahead = getattr(hps.train, "delay", getattr(hps.train, "lookahead", 0))
# Filelist paths in the config are presumably relative to the repo root, while this
# notebook runs one directory below it (cf. sys.path.insert(0, "../")).
hps.data.filelists["infer"] = f'../{hps.data.filelists["infer"]}'

# Parameter counts (in millions) for the whole model and for the decoder alone.
n_params = sum(p.numel() for p in wrapper.model.parameters())
print(f"#params: {n_params/1_000_000} M")
n_params = sum(p.numel() for p in wrapper.model.decoder.parameters())
print(f"Decoder #params: {n_params/1_000_000} M")

Loading checkpoint file '../logs/encodec_disc_ablation/grad_none_1/00050.pth'...
#params: 9.577019 M
Decoder #params: 6.587246 M


# Speech Dataset

In [7]:
# Speech evaluation set: VCTK validation filelist, loaded at `data_sr`.
# NOTE(review): the dataset is configured at 48 kHz here, but the resamplers
# below (and `metric`) treat incoming audio as being at the model rate `sr`.
# If `sr` != 48_000 this looks inconsistent — confirm what Dataset does with
# `sampling_rate` before trusting speech-set numbers.
data_sr = 48_000
hp = HParams(**{
    "data": {
        "dataset": "Dataset",
        "wav_dir": "/home/shahn/Datasets/VCTK-Corpus/wav48",
        "data_dir": "",
        "extension": "",
        "filelists": {
            "pesq": "/home/shahn/Documents/trainer/filelists/etc/VCTK_valid.txt",
        },
        "filter": {"pesq": True},
        "normalize_method": None,
        "channel": 1,
        "sampling_rate": data_sr,
    },
    "train": {},
    "pesq": {"batch_size": 20, "num_workers": 0},
})

# One resampler per metric rate, each converting from the model rate `sr`.
resampler48khz, resampler16khz, resampler10khz = (
    Resample(sr, target_rate).to(device) for target_rate in (48000, 16000, 10000)
)
dataset, dataloader = get_dataset_dataloader(hp, mode="pesq", keys=["wav", "wav_len"])

                                                                                                                                                

pesq dataset filtered: 587/587




# Audio Dataset (choose either the Speech or the Audio dataset)

In [9]:
# General-audio evaluation set, loaded directly at the model rate `sr`.
hp = HParams(**{
    "data": {
        "dataset": "Dataset",
        "wav_dir": "/home/shahn/Datasets",
        "data_dir": "",
        "extension": "",
        "filelists": {
            "pesq": "/home/shahn/Documents/trainer/filelists/DNS/DNS_VCTK_jamendo_pesq_24khz.txt",
        },
        "filter": {"pesq": True},
        "normalize_method": None,
        "channel": 1,
        "sampling_rate": sr,
    },
    "train": {},
    "pesq": {"batch_size": 5, "num_workers": 0},
})

# One resampler per metric rate, each converting from the model rate `sr`.
resampler48khz, resampler16khz, resampler10khz = (
    Resample(sr, target_rate).to(device) for target_rate in (48000, 16000, 10000)
)
dataset, dataloader = get_dataset_dataloader(hp, mode="pesq", keys=["wav", "wav_len"])

                                                                                               

pesq dataset filtered: 300/300




# Calculate Metrics Using Multiprocessing (fast, but may crash your server!)

In [10]:
from pesq import pesq
from pystoi import stoi
from utils.measure_visqol import measure_visqol


# Sampling rate (Hz) each metric expects its input waveforms to be at.
SAMPLING_RATE = {
    "pesq": 16_000,
    "stoi": 10_000,
    "visqol": 16_000,
    "visqol_audio": 48_000,
}
# Alternative metric names accepted by `metric`, mapped to their SAMPLING_RATE key.
MODE_ALIASES = {"pystoi": "stoi"}


def metric(ref, deg, wav_len, mode, idx=0) -> float:
    """Score one degraded waveform against its reference.

    Args:
        ref: Reference waveform, already resampled to ``SAMPLING_RATE[mode]``.
        deg: Degraded (model output) waveform at the same rate as ``ref``.
        wav_len: Valid length of the item in samples at the model rate ``sr``
            (module-level global); both signals are trimmed to the
            corresponding length at the metric rate to drop batch padding.
        mode: "pesq", "stoi" (alias "pystoi"), "visqol" or "visqol_audio";
            case-insensitive.
        idx: Item index forwarded unchanged to ``measure_visqol``.

    Returns:
        The score produced by the underlying metric implementation.

    Raises:
        ValueError: if ``mode`` is not a supported metric.
    """
    mode = mode.lower()
    mode = MODE_ALIASES.get(mode, mode)
    if mode not in SAMPLING_RATE:
        raise ValueError(
            f"Unknown metric mode {mode!r}; expected one of "
            f"{sorted(SAMPLING_RATE) + sorted(MODE_ALIASES)}"
        )
    mode_sr = SAMPLING_RATE[mode]
    # Convert the valid length from the model rate to the metric rate.
    wav_len = int(wav_len * mode_sr / sr)
    ref = ref[:wav_len]
    deg = deg[:wav_len]
    if mode == "pesq":
        return pesq(mode_sr, ref, deg, "wb")
    if mode == "stoi":
        return stoi(ref, deg, mode_sr)
    if mode == "visqol":
        return measure_visqol(ref, deg, idx, "speech")
    return measure_visqol(ref, deg, idx, "audio")

In [None]:
import concurrent.futures  # plain `import concurrent` does not load the `futures` submodule

N = 4   # number of quantizers in the RVQ
pesq_list, stoi_list, visqol_list, va_list = [], [], [], []
max_items = 0
pesq_futures, stoi_futures, visqol_futures, va_futures = [], [], [], []

calc_pesq, calc_stoi, calc_visqol, calc_va = False, False, False, True


def _collect(futures, desc):
    """Wait for every future (with a progress bar); return results in completion order."""
    done = concurrent.futures.as_completed(futures)
    return [future.result() for future in tqdm(done, desc=desc, total=len(futures), leave=False)]


def _mean_ci(values, confidence=0.95):
    """Return (mean, half-width of the two-sided Student-t confidence interval) of `values`."""
    mean = sum(values) / len(values)
    low, _ = st.t.interval(confidence=confidence, df=len(values) - 1, loc=mean, scale=st.sem(values))
    return mean, mean - low


with concurrent.futures.ProcessPoolExecutor(max_workers=32) as executor:
    for batch in tqdm(dataloader, desc="PESQ", leave=False, dynamic_ncols=True):
        wav_r = batch["wav"].to(device).unsqueeze(1)
        wav_lens = batch["wav_len"]
        batch_size = wav_r.size(0)

        # Crop to a whole multiple of hop_size samples before running the model.
        batch_wav_len = wav_r.size(-1) // hop_size * hop_size
        wav_r = wav_r[..., :batch_wav_len]
        with torch.no_grad():
            wav_g, *_ = wrapper.model(wav_r, n=N)
            if lookahead > 0:
                # Drop the first `lookahead` output samples and the last `lookahead`
                # input samples so the two line up (assumes output lags input).
                wav_r = wav_r[..., :-lookahead]
                wav_g = wav_g[..., lookahead:]

        # Resample once per batch on the device; the metrics run on CPU numpy arrays.
        if calc_pesq or calc_visqol:
            wav_r_pesq = resampler16khz(wav_r).cpu().numpy()
            wav_g_pesq = resampler16khz(wav_g).cpu().numpy()
        if calc_stoi:
            wav_r_stoi = resampler10khz(wav_r).cpu().numpy()
            wav_g_stoi = resampler10khz(wav_g).cpu().numpy()
        if calc_va:
            wav_r_va = resampler48khz(wav_r).cpu().numpy()
            wav_g_va = resampler48khz(wav_g).cpu().numpy()

        # Fan out one job per item and metric; the global item index is forwarded to ViSQOL.
        for i in range(batch_size):
            file_idx = max_items + i
            if calc_pesq:
                pesq_futures.append(executor.submit(metric, wav_r_pesq[i, 0], wav_g_pesq[i, 0], wav_lens[i], "pesq"))
            if calc_visqol:
                visqol_futures.append(executor.submit(metric, wav_r_pesq[i, 0], wav_g_pesq[i, 0], wav_lens[i], "visqol", file_idx))
            if calc_stoi:
                stoi_futures.append(executor.submit(metric, wav_r_stoi[i, 0], wav_g_stoi[i, 0], wav_lens[i], "stoi"))
            if calc_va:
                va_futures.append(executor.submit(metric, wav_r_va[i, 0], wav_g_va[i, 0], wav_lens[i], "visqol_audio", file_idx))
        max_items += batch_size

    # Gather results while the pool is still alive.
    if calc_pesq:
        pesq_list = _collect(pesq_futures, 'pesq')
    if calc_stoi:
        stoi_list = _collect(stoi_futures, 'stoi')
    if calc_visqol:
        visqol_list = _collect(visqol_futures, 'visqol')
    if calc_va:
        va_list = _collect(va_futures, 'visqol_audio')

if calc_pesq:
    pesq_mean, pesq_ci = _mean_ci(pesq_list)
    print(f"\rPESQ: {pesq_mean} +- {pesq_ci}", flush=True)
if calc_stoi:
    stoi_mean, stoi_ci = _mean_ci(stoi_list)
    print(f"\rSTOI: {stoi_mean} +- {stoi_ci}", flush=True)
if calc_visqol:
    visqol_mean, visqol_ci = _mean_ci(visqol_list)
    print(f"\rViSQOL: {visqol_mean} +- {visqol_ci}", flush=True)
if calc_va:
    va_mean, va_ci = _mean_ci(va_list)
    print(f"ViSQOL Audio: {va_mean} +- {va_ci}")

PESQ:   0%|                                                               | 0/60 [00:00<?, ?it/s]

visqol_audio:   0%|          | 0/300 [00:00<?, ?it/s]

# Calculate Metrics Using a Single Process (very slow, but we have not experienced any crashes)

In [12]:
N = 4   # number of quantizers in the RVQ
for name in ["grad_none_1"]:
    base_dir = f"../logs/encodec_disc_ablation/{name}"
    hps = get_hparams(f"{base_dir}/config.yaml", base_dir)
    # Replace SyncBatchNorm with plain BatchNorm1d for this single-process evaluation
    # (presumably SyncBatchNorm expects a distributed process group).
    if getattr(hps.model_kwargs, "act_norm", None) == "SyncBatchNorm":
        hps.model_kwargs.act_norm = "BatchNorm1d"
    # The dataset, resamplers and `metric` were all set up for the notebook's `sr`.
    if hps.data.sampling_rate != sr:
        raise ValueError(
            f"{name}: sampling rate {hps.data.sampling_rate} differs from the notebook's sr={sr}"
        )
    wrapper = get_wrapper(hps.model)(hps, device=device)
    wrapper.load()
    wrapper.eval()
    # hop_size and lookahead belong to the checkpoint being evaluated; refresh them
    # instead of reusing the values from the model loaded in an earlier cell.
    hop_size = wrapper.hop_size
    lookahead = getattr(hps.train, "delay", getattr(hps.train, "lookahead", 0))

    va_list = []
    max_items = 0

    for batch in tqdm(dataloader, desc="PESQ", leave=False, dynamic_ncols=True):
        wav_r = batch["wav"].to(device).unsqueeze(1)
        wav_lens = batch["wav_len"]
        batch_size = wav_r.size(0)

        # Crop to a whole multiple of hop_size samples before running the model.
        batch_wav_len = wav_r.size(-1) // hop_size * hop_size
        wav_r = wav_r[..., :batch_wav_len]
        with torch.no_grad():
            wav_g, *_ = wrapper.model(wav_r, n=N)
            if lookahead > 0:
                # Trim so input and output line up (assumes output lags input by `lookahead`).
                wav_r = wav_r[..., :-lookahead]
                wav_g = wav_g[..., lookahead:]

        wav_r_va = resampler48khz(wav_r).cpu().numpy()
        wav_g_va = resampler48khz(wav_g).cpu().numpy()

        for i in range(batch_size):
            file_idx = max_items + i
            va_list.append(metric(wav_r_va[i, 0], wav_g_va[i, 0], wav_lens[i], "visqol_audio", file_idx))
        max_items += batch_size

    # Mean and half-width of the 95% Student-t confidence interval.
    va_mean = sum(va_list) / max_items
    va_ci = va_mean - st.t.interval(confidence=0.95, df=len(va_list)-1, loc=va_mean, scale=st.sem(va_list))[0]
    print(f"{name}: {va_mean} +- {va_ci}")

Loading checkpoint file '../logs/encodec_disc_ablation/grad_none_1/00050.pth'...


PESQ:   0%|                                                             | 0/60 [00:00<?, ?it/s]

grad_none_1: 4.148787836073998 +- 0.021267010126748254
