In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import os, json, random, sys
sys.path.insert(0, "../")

import torch
from torch.utils.data import DataLoader
from soundfile import write, read
import numpy as np
import IPython.display as ipd

from utils import get_hparams
from functional import mel_spectrogram, stft, spec_to_mel
from models import get_wrapper
from utils.data_audio import AECDataset
from pypesq import pesq
from tqdm import tqdm
from librosa import resample
from models.dctcrn.default.losses import si_snr

# Enumerate CUDA devices in PCI bus order and expose only device "0" to this
# process. These variables are read when CUDA is initialised.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

# Device string handed to the model wrapper and to every `.to(device)` call
# in the cells below; this notebook evaluates on the CPU.
device = "cpu"

In [2]:
# Experiment to evaluate; its log directory is ../logs/<name>.
name = "dctcrn/default_tanh_aux_coslr"
# No explicit epoch is requested. In the recorded run this loaded 00100.pth
# (presumably the most recent checkpoint -- confirm against wrapper.load).
epoch = None

base_dir = os.path.join("../logs", name)

# Hyper-parameters are read from config.json; if that file does not exist
# the YAML variant is tried instead. A missing YAML file still raises.
try:
    config_path = os.path.join(base_dir, "config.json")
    hps = get_hparams(config_path, base_dir)
except FileNotFoundError:
    config_path = os.path.join(base_dir, "config.yaml")
    hps = get_hparams(config_path, base_dir)
# Optional override, currently disabled:
#hps.model_kwargs.viterbi_legacy = False

# Look up the wrapper class for the configured model, instantiate it on the
# chosen device, restore the checkpoint and switch to evaluation mode.
wrapper_cls = get_wrapper(hps.model)
wrapper = wrapper_cls(hps, device=device)
wrapper.load(epoch=epoch)
wrapper.eval()

Loading checkpoint file '../logs/dctcrn/default_tanh_aux_coslr/00100.pth'...


# Dataset

In [3]:
# Override the configured segment size for evaluation.
# NOTE(review): presumably a length in samples, i.e. 10 s at the 16 kHz rate
# passed to pesq() further down -- confirm against AECDataset.
hps.data.segment_size = 160_000

# Validation split, restricted to its first 100 files to keep the run short.
dataset = AECDataset(hps.data, mode="valid")
dataset.files = dataset.files[:100]
print(len(dataset))

# A batch size of 100 means the 100 retained items fit in a single batch.
dataloader = DataLoader(dataset=dataset, batch_size=100)

100


# Batched Inference (Fast)

In [5]:
# Mean PESQ of the model output against the near-end reference.
#
# A running sum and an explicit sample counter are kept instead of deriving
# the count from `idx * b + i + 1`: that expression is only correct when
# every batch has the same size, and it is undefined if the loader is empty.
pesq_sum = 0.0
num_scored = 0
hop_size = hps.model_kwargs.hop_size
print(name)

for batch in dataloader:
    near = batch["near"]
    far = batch["far"]
    mix = batch["mix"]

    # Trim all three signals to a whole number of hops so the reference and
    # the model inputs have the same length.
    wav_len = near.size(-1) // hop_size * hop_size
    near = near[..., :wav_len]
    far = far[..., :wav_len]
    mix = mix[..., :wav_len]

    # Size of *this* batch (the last batch may be smaller than the others).
    b = near.size(0)
    far_in = far.reshape(b, -1).to(device)
    mix_in = mix.reshape(b, -1).to(device)

    with torch.no_grad():
        # Prefer the model's autoregressive validation path when it has one.
        if hasattr(wrapper.model, "autoregressive_valid"):
            wav_out, _ = wrapper.model.autoregressive_valid(far_in, mix_in)
        else:
            wav_out, _ = wrapper.model(far_in, mix_in)

    # reshape(b, -1) instead of squeeze(): squeeze() would also drop the
    # batch dimension when b == 1, making wav_out[i] index samples rather
    # than utterances.
    wav_out = wav_out.reshape(b, -1)
    #wav_out = wav_out.clip(min=-1.0, max=1.0)

    for i in range(b):
        reference = near[i].cpu().numpy()
        estimate = wav_out[i].cpu().numpy()
        pesq_sum += pesq(reference, estimate, 16000)
        num_scored += 1
        print(f"\r{num_scored}/{len(dataset)} - {pesq_sum / num_scored}", end=" ", flush=True)

# Final mean; NaN (rather than a NameError / ZeroDivisionError) if nothing
# was scored.
pesq_mean = pesq_sum / num_scored if num_scored else float("nan")
print("")

dctcrn/default_tanh_aux_coslr
100/100 - 2.662963027358055 


# PESQ of the Noisy Mixture (Baseline)

In [8]:
# Baseline: mean PESQ of the unprocessed mixture against the near-end
# reference, for comparison with the model's score.
#
# The batch size is taken from each batch here. Previously `b` was never
# assigned in this cell and only existed if the inference cell had been run
# first, so the cell failed on a fresh kernel. An explicit counter also keeps
# the mean correct when the last batch is smaller than the others.
pesq_sum = 0.0
num_scored = 0
print(name)

for batch in dataloader:
    near = batch["near"]
    mix = batch["mix"]
    b = near.size(0)
    for i in range(b):
        reference = near[i].cpu().numpy()
        noisy = mix[i].cpu().numpy()
        pesq_sum += pesq(reference, noisy, 16000)
        num_scored += 1
        print(f"\r{num_scored}/{len(dataset)} - {pesq_sum / num_scored}", end=" ", flush=True)

# Final mean; NaN if the loader produced no samples.
pesq_mean = pesq_sum / num_scored if num_scored else float("nan")
print("")

dctcrn/default_tanh_aux_coslr
100/100 - 1.8771302208304406 
