In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import os, json, random, sys
sys.path.insert(0, "../")
import soundfile as sf
import torch
from torch.utils.data import DataLoader
from soundfile import write, read
import numpy as np
import IPython.display as ipd
import math
from utils import get_hparams
from functional import mel_spectrogram, stft, spec_to_mel
from models import get_wrapper
from pypesq import pesq
from tqdm import tqdm
import librosa
from librosa import resample
from models.dctcrn.default.losses import si_snr
import glob

#os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   
#os.environ["CUDA_VISIBLE_DEVICES"]="5"


#device = 'cuda:0'

In [2]:
class AECDataset(torch.utils.data.Dataset):
    def __init__(self, path=None):
        super().__init__()
        self.path = path
        self.files = glob.glob(path + '/**/*_mic.wav')
        self.sampling_rate = 16000

    def shuffle(self, seed: int):
        random.seed(seed)
        random.shuffle(self.files)
    
    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, idx):
        _id = self.files[idx]

        far, sr = librosa.core.load(
            _id.replace('mic', 'lpb'),
            sr=None
        )
        assert sr == self.sampling_rate
        mix, sr = librosa.core.load(
            _id,
            sr=None
        )
        assert sr == self.sampling_rate

        res = {"far": far, "mix": mix}
        
        return res

In [7]:
name = "dctcrn_64000"
epoch = None

base_dir = os.path.join("../logs", name)
try:
    hps = get_hparams(os.path.join(base_dir, "config.json"), base_dir)
except FileNotFoundError:
    hps = get_hparams(os.path.join(base_dir, "config.yaml"), base_dir)
#hps.model_kwargs.viterbi_legacy = False

wrapper = get_wrapper(hps.model)(hps, device=device)
wrapper.load(epoch=800)
wrapper.eval()

Loading checkpoint file '../logs/dctcrn_64000/00800.pth'...


# Dataset

In [8]:
path = '/home/jhkim21/Data/AEC-Challenge/datasets/blind_test_set_interspeech2021'
dataset = AECDataset(path)
#dataset.files = dataset.files[:100]
#print(len(dataset))
dataloader = DataLoader(dataset, batch_size=1)

# Utils

In [9]:
def vad(arr, threshold=1e-4):
    start, stop = 0, 0
    for i in range(arr.size):
        if arr[i] > threshold:
            start = i
            break
    for i in reversed(range(arr.size)):
        if arr[i] > threshold:
            stop = i
            break
    
    return start, stop


def cal_ERLE(src_mix, src_est, start, stop):
    #pow_mix = np.sum(np.power(src_mix[:start], 2)) + np.sum(np.power(src_mix[stop:], 2))
    #pow_est = np.sum(np.power(src_est[:start], 2)) + np.sum(np.power(src_est[stop:], 2))
    pow_mix = np.sum(np.power(src_mix, 2))
    pow_est = np.sum(np.power(src_est, 2))
    erle = 10 * math.log(pow_mix/pow_est)

    return erle

# Batched Inference (Fast)

In [11]:
pesq_mean = 0.
ERLE_mean = 0.
print(name)

for idx, batch in enumerate(dataloader):
    near = batch["near"]
    far = batch["far"]
    mix = batch["mix"]
    #echo = batch['echo']
    #wav_len = near.size(-1) // hps.model_kwargs.hop_size * hps.model_kwargs.hop_size
    #near = near[..., :wav_len]
    #far = far[..., :wav_len]
    #mix = mix[..., :wav_len]
    b = near.size(0)
    print(b)
    with torch.no_grad():
        if hasattr(wrapper.model, "autoregressive_valid"):
            wav_out = wrapper.model.autoregressive_valid(far.view(b, -1).to(device), mix.view(b, -1).to(device))
        else:
            wav_out, _= wrapper.model(far.view(b, -1).to(device), mix.view(b, -1).to(device))
        wav_out = wav_out.squeeze()
        
    #wav_out = wav_out.clip(min=-1.0, max=1.0)
    for i in range(b):
        #sf.write('./dctcrn_sample_{}.wav'.format(i), wav_out[i].cpu().numpy(), samplerate=16000)
        #print('save')

        n = near[i]
        wo = wav_out[i]
        m = mix[i]
        n = n.cpu().numpy()
        m = m.cpu().numpy()
        wo = wo.cpu().numpy()
        l, r = vad(n)
        avg_erle = cal_ERLE(m[:160000], wo[:160000], l, r)
        pesq_mean += pesq(n[160000:], wo[160000:], 16000)
        ERLE_mean += avg_erle
        #print(f"\r{idx*b + i+1}/{len(dataset)} - {pesq_mean / (idx*b + i+1)}", end=" ", flush=True)
        print(f"\r{idx*b + i+1}/{len(dataset)} - {ERLE_mean / (idx*b + i+1)}", end=" ", flush=True)

pesq_mean /= (idx*b + i+1)
ERLE_mean /= (idx*b + i+1)
print("pesq_NEST = {}, erle_FEST = {}".format(pesq_mean, ERLE_mean))

dctcrn_64000
50
50/500 - 111.3869625094808  50
100/500 - 113.52666197321088 50
150/500 - 114.04197105133021 50
200/500 - 113.57572878223606 50
250/500 - 112.256661506488 7 50
300/500 - 112.12311683519523 50
350/500 - 111.73133904588842 50
400/500 - 111.7153748020256  50
450/500 - 111.81078295404313 50
500/500 - 111.43865097102537 pesq_NEST = 4.3976437516212465, erle_FEST = 111.43865097102537


# Noisy's PESQ

In [38]:
pesq_mean = 0.
print(name)

for idx, batch in enumerate(dataloader):
    near = batch["near"]
    mix = batch["mix"]
    for i in range(b):
        n = near[i]
        m = mix[i]
        pesq_mean += pesq(n.cpu().numpy(), m.cpu().numpy(), 16000)
        print(f"\r{idx*b + i+1}/{len(dataset)} - {pesq_mean / (idx*b + i+1)}", end=" ", flush=True)
pesq_mean /= (idx*b + i+1)
print("")

echofilter_like_full
81/500 - nan 68844826462902 



500/500 - nan 
