In [13]:
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import os, json, random, sys
sys.path.insert(0, "../")
import soundfile as sf
import torch
from torch.utils.data import DataLoader
from soundfile import write, read
import numpy as np
import IPython.display as ipd
import math
from utils import get_hparams
from functional import mel_spectrogram, stft, spec_to_mel
from models import get_wrapper
from utils.data_audio import AECDataset
from pypesq import pesq
from tqdm import tqdm
from librosa import resample
from models.dctcrn.default.losses import si_snr

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   
os.environ["CUDA_VISIBLE_DEVICES"]="3"


device = 'cuda:0'

In [14]:
name = "with_adam_block2"
epoch = None

base_dir = os.path.join("../logs", name)
try:
    hps = get_hparams(os.path.join(base_dir, "config.json"), base_dir)
except FileNotFoundError:
    hps = get_hparams(os.path.join(base_dir, "config.yaml"), base_dir)
#hps.model_kwargs.viterbi_legacy = False

wrapper = get_wrapper(hps.model)(hps, device=device)
wrapper.load(epoch=300)
wrapper.eval()

Loading checkpoint file '../logs/with_adam_block2/00300.pth'...


# Dataset

In [15]:
hps.data.segment_size = 320000
dataset = AECDataset(hps.data, mode="valid")
print(dataset[3]['mix'].shape)
#dataset.files = dataset.files[:100]
#print(len(dataset))
dataloader = DataLoader(dataset, batch_size=5)

(320000,)


# Utils

In [16]:
def vad(arr, threshold=1e-4):
    start, stop = 0, 0
    for i in range(arr.size):
        if arr[i] > threshold:
            start = i
            break
    for i in reversed(range(arr.size)):
        if arr[i] > threshold:
            stop = i
            break
    
    return start, stop


def cal_ERLE(src_mix, src_est, start, stop):
    #pow_mix = np.sum(np.power(src_mix[:start], 2)) + np.sum(np.power(src_mix[stop:], 2))
    #pow_est = np.sum(np.power(src_est[:start], 2)) + np.sum(np.power(src_est[stop:], 2))
    pow_mix = np.sum(np.power(src_mix, 2))
    pow_est = np.sum(np.power(src_est, 2))
    erle = 10 * math.log(pow_mix/pow_est)

    return erle

# Batched Inference (Fast)

In [18]:
pesq_mean = 0.
ERLE_mean = 0.
print(name)
os.makedirs(os.path.join("../samples", name), exist_ok=True)
os.makedirs(os.path.join("../samples", name, 'single_talk/est'), exist_ok=True)
os.makedirs(os.path.join("../samples", name, 'single_talk/echo'), exist_ok=True)
for idx, batch in enumerate(dataloader):
    near = batch["near"]
    far = batch["far"]
    mix = batch["mix"]
    #echo = batch['echo']
    #wav_len = near.size(-1) // hps.model_kwargs.hop_size * hps.model_kwargs.hop_size
    #near = near[..., :wav_len]
    #far = far[..., :wav_len]
    #mix = mix[..., :wav_len]
    b = near.size(0)
    with torch.no_grad():    
        if hasattr(wrapper.model, "autoregressive_valid"):
            wav_out = wrapper.model.autoregressive_valid(far.view(b, -1).to(device), mix.view(b, -1).to(device))
        else:
            wav_out, echo_out = wrapper.model(far.view(b, -1).to(device), mix.view(b, -1).to(device))
        wav_out = wav_out.squeeze()
        echo_out = echo_out.squeeze()

    #wav_out = wav_out.clip(min=-1.0, max=1.0)
    for i in range(b):
        #sf.write('./dctcrn_sample_{}.wav'.format(i), wav_out[i].cpu().numpy(), samplerate=16000)
        #print('save')

        n = near[i]
        wo = wav_out[i]
        eo = echo_out[i]
        m = mix[i]
        n = n.cpu().numpy()
        m = m.cpu().numpy()
        wo = wo.cpu().numpy()
        eo = eo.cpu().numpy()
        l, r = vad(n)
        ERLE_mean += cal_ERLE(m[:160000], wo[:160000], l, r)
        pesq_mean += pesq(n[160000:], wo[160000:], 16000)
        #ERLE_mean += avg_erle
        sf.write(os.path.join('../samples/', name, 'est', 'near_est_'+str(idx*b + i+1) +'.wav'), wo, 16000)
        sf.write(os.path.join('../samples/', name, 'echo', 'echo_est_' + str(idx*b +i+1) + '.wav'), eo, 16000)
        print(f"\r{idx*b + i+1}/{len(dataset)} - {pesq_mean / (idx*b + i+1)}", end=" ", flush=True)
        #print(f"\r{idx*b + i+1}/{len(dataset)} - {ERLE_mean / (idx*b + i+1)}", end=" ", flush=True)

pesq_mean /= (idx*b + i+1)
ERLE_mean /= (idx*b + i+1)
print("pesq_NEST = {}, erle_FEST = {}".format(pesq_mean, ERLE_mean))

with_adam_block2
500/500 - 4.339033418178558  pesq_NEST = 4.339033418178558, erle_FEST = 67.61634806453888


# Noisy's PESQ

In [38]:
pesq_mean = 0.
print(name)

for idx, batch in enumerate(dataloader):
    near = batch["near"]
    mix = batch["mix"]
    for i in range(b):
        n = near[i]
        m = mix[i]
        pesq_mean += pesq(n.cpu().numpy(), m.cpu().numpy(), 16000)
        print(f"\r{idx*b + i+1}/{len(dataset)} - {pesq_mean / (idx*b + i+1)}", end=" ", flush=True)
pesq_mean /= (idx*b + i+1)
print("")

echofilter_like_full
81/500 - nan 68844826462902 



500/500 - nan 
