In [1]:
import pyroomacoustics as pra

import os
import time
from tqdm import tqdm

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torch.autograd import profiler
import torchaudio
from torchmetrics.audio import SpeechReverberationModulationEnergyRatio, ShortTimeObjectiveIntelligibility


from einops import rearrange

from src.dataset import SignalDataset, TRUNetDataset
from src.loss import loss_tot, loss_MR, loss_MR_w
from models.fspen import FullSubPathExtension 

from IPython.display import Audio

import matplotlib.pyplot as plt

In [2]:
# --- Filesystem layout (all paths are relative to the working directory) ---
# Corpora tried earlier and kept here only as a reminder:
#   MUSAN speech : ../data/musan/speech/{train_chunks,test_chunks}
#   THCHS-30     : <DATA_DIR>/data_thchs30/train_1 (same folder for train and test)
_DATA_ROOT = "data"

DATA_DIR = os.path.join(_DATA_ROOT, "wav48")
TRAIN_DIR = os.path.join(DATA_DIR, "clean_trainset_56spk_wav")
TEST_DIR = os.path.join(DATA_DIR, "clean_testset_wav")

# Passed to the dataset as `rir_dir`.
# Earlier alternatives: ../data/rirs_noises/RIRS_NOISES/real_rirs, ../data/RIRs
RIR_DIR = os.path.join(_DATA_ROOT, "rirs48")

# Passed to the dataset as `noise_dir`.
# Earlier alternative: ../data/rirs_noises/RIRS_NOISES/real_rirs_isotropic_noises
NOISE_DIR = os.path.join(_DATA_ROOT, "TAU-urban-acoustic")

# Root folder for model checkpoints.
CHKP_DIR = "checkpoints"

# Print arrays/tensors with three decimals in both libraries.
for _lib in (np, torch):
    _lib.set_printoptions(precision=3)

In [3]:
# STFT / audio settings shared by the cells below.
N_FFTS = 512        # FFT size; also used as the STFT window length
HOP_LENGTH = 256    # hop between consecutive STFT frames, in samples
SR = 48_000         # sample rate in Hz

# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA instead of failing at the first `.to(DEVICE)`.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"It's {DEVICE} time!!!")

It's cuda time!!!


In [4]:
import yaml

from NISQA_s.src.core.model_torch import model_init
from NISQA_s.src.utils.process_utils import process

NISQA_PATH = "NISQA_s/config/nisqa_s.yaml"

# Load the NISQA-s YAML configuration shipped with the NISQA_s package.
with open(NISQA_PATH, 'r') as config_file:
    nisqa_args = yaml.safe_load(config_file)

# Override the spectrogram settings with this notebook's STFT parameters and
# rewrite the checkpoint path.
nisqa_args.update({
    "ms_n_fft": N_FFTS,
    "hop_length": HOP_LENGTH,
    "ms_win_length": N_FFTS,
    # Drops the first three characters of the configured checkpoint path.
    # NOTE(review): presumably strips a leading "../" so the path resolves from
    # the notebook's directory - confirm against the YAML file.
    "ckp": nisqa_args["ckp"][3:],
})

# Build the NISQA model together with its initial recurrent states.
nisqa, h0_nisqa, c0_nisqa = model_init(nisqa_args)



In [5]:
# Dataset built from DATA_DIR with optional noise / RIR augmentation.
# NOTE(review): `snr`, `rir_proba` and `noise_proba` are presumably an SNR range
# in dB and per-item augmentation probabilities - confirm in src.dataset.
dataset = TRUNetDataset(
    DATA_DIR,
    sr=SR,
    noise_dir=NOISE_DIR,
    rir_dir=RIR_DIR,
    snr=[5, 10],
    rir_proba=0.7,
    noise_proba=0.7,
    return_noise=False,
    return_rir=False,
    max_seq_len=SR * 5,  # five seconds' worth of samples at SR
)

# Seed the split so that the same items land in train/test on every
# "Restart & Run All"; an unseeded split makes results irreproducible.
SPLIT_SEED = 42
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [0.95, 0.05],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [6]:
from src.fspen_configs import TrainConfig, TrainConfigLarge, TrainConfigLarge1, TrainConfigRNN1, TrainConfigRNN2

configs = TrainConfigLarge()
# Prints the total number of bands and the number of dual-path modules; the
# hidden-state shapes used during inference depend on these values.
print(sum(configs.bands_num_in_groups), configs.dual_path_extension["num_modules"])
fspen = FullSubPathExtension(configs=configs).to(DEVICE)

checkpoint_path = os.path.join(CHKP_DIR, "fspen_chkp", "fspen_voicebank_subband_loss#2.pt")
# map_location puts the stored tensors on the same device as the model, so a
# checkpoint saved on a GPU still loads on a CPU-only machine.
# SECURITY: weights_only=False unpickles arbitrary Python objects - only load
# checkpoint files you trust.
state_d = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)

32 2


In [7]:
# Copy the trained weights from the checkpoint into the model; the call is the
# last expression so the key-matching report is shown as the cell output.
model_weights = state_d["model_state_dict"]
fspen.load_state_dict(model_weights)

<All keys matched successfully>

In [8]:
def vorbis_window(winlen, device="cuda"):
    """Return a ``winlen``-point Vorbis-style window as a float32 tensor.

    The samples are ``w[n] = sin(pi/2 * sin(pi * (n - 0.5) / winlen) ** 2)``
    for ``n = 0 .. winlen - 1``.

    Args:
        winlen: number of window samples.
        device: accepted for API compatibility but currently ignored; the
            tensor is created on the default device and callers move it with
            ``.to(...)``.

    NOTE(review): the usual Vorbis window uses ``n + 0.5`` rather than
    ``n - 0.5``; confirm the offset matches what the model was trained with
    before changing it.
    """
    sample_positions = torch.arange(winlen) - 0.5
    inner_sine = torch.sin((torch.pi / winlen) * sample_positions)
    window = torch.sin((torch.pi / 2) * inner_sine.pow(2))
    return window.float()

In [9]:
# Perceptual Contrast Stretching: one gain per frequency bin (257 bins).
# Each entry is (first_bin, end_bin_exclusive, gain). Bin 256 is not covered by
# any band and therefore keeps its initial gain of 1.
_PCS_BANDS = (
    (0, 3, 1),
    (3, 6, 1.070175439),
    (6, 9, 1.182456140),
    (9, 12, 1.287719298),
    (12, 138, 1.4),       # Pre Set
    (138, 166, 1.322807018),
    (166, 200, 1.238596491),
    (200, 241, 1.161403509),
    (241, 256, 1.077192982),
)

PCS = torch.ones(257, device=DEVICE)
for _first_bin, _end_bin, _gain in _PCS_BANDS:
    PCS[_first_bin:_end_bin] = _gain

In [10]:
from scipy.io.wavfile import write

def eval(model, input_sig, gt=None, hidden_state=None):
    """Enhance ``input_sig`` with ``model`` and score the output with NISQA.

    NOTE: the name shadows the built-in ``eval``; it is kept because other
    cells call it under this name.

    Pipeline: STFT -> log1p magnitude scaled per bin by ``PCS`` (input phase
    kept) -> model -> inverse STFT -> NISQA on the resulting waveform.

    Args:
        model: network called as ``model(input_spec, abs_spectrum, hidden_state)``
            that returns ``(output, new_hidden_state)``.
        input_sig: 2-D waveform tensor ``(batch, samples)`` located on ``DEVICE``
            with at least ``N_FFTS`` samples (the STFT is not padded).
        gt: unused; accepted so callers may pass the clean reference signal.
        hidden_state: recurrent state returned by a previous call, or ``None``
            to start from an all-zero state.

    Returns:
        ``(nisqa_score, speed, new_hidden_state)`` where ``speed`` is seconds of
        audio processed per second of wall-clock time for STFT + model + iSTFT
        (NISQA scoring is excluded). This is the inverse of the usual
        "real-time factor": values above 1 mean faster than real time.
    """
    model.eval()

    # CUDA kernels run asynchronously; without synchronising, the timer would
    # only measure kernel launch time and overstate the speed.
    on_gpu = input_sig.is_cuda
    if on_gpu:
        torch.cuda.synchronize(input_sig.device)
    start_time = time.perf_counter()

    # The same window is used for analysis and synthesis.
    window = vorbis_window(N_FFTS).to(DEVICE)
    spec = torch.stft(
        input_sig,
        n_fft=N_FFTS,
        hop_length=HOP_LENGTH,
        win_length=N_FFTS,
        window=window,
        return_complex=True,
        normalized=True,
        center=False
    )

    # spec is (batch, frequency, frames): compress the magnitude with log1p,
    # apply the per-bin PCS gain along the frequency axis, keep the phase.
    compressed_mag = PCS[None, :, None] * torch.log1p(torch.abs(spec))
    spec_pcs = torch.polar(compressed_mag, spec.angle())

    # (batch, frequency, frames, 2) -> (batch, frames, 2, frequency)
    input_spec = torch.permute(torch.view_as_real(spec_pcs), dims=(0, 2, 3, 1))
    batch, frames, channels, frequency = input_spec.shape

    # Magnitude input: (batch, frequency, frames) -> (batch, frames, 1, frequency)
    abs_spectrum = torch.permute(spec_pcs.abs(), dims=(0, 2, 1))
    abs_spectrum = torch.reshape(abs_spectrum, shape=(batch, frames, 1, frequency))

    if hidden_state is None:
        # Zero initial state: 2 lists of 8 tensors shaped (1, batch * 32, 2).
        # NOTE(review): 32 and 2 match the values printed from `configs`
        # (number of bands, number of modules); 8 and 16 // 8 are presumably
        # the group count and per-group hidden size - confirm in the model.
        hidden_state = [
            [torch.zeros(1, batch * 32, 16 // 8, device=input_spec.device) for _ in range(8)]
            for _ in range(2)
        ]

    output, new_hidden_state = model(input_spec, abs_spectrum, hidden_state)

    # Rearrange so the real/imaginary pair is the last axis, then go complex.
    output = torch.permute(output, dims=(0, 3, 1, 2))
    output = torch.view_as_complex(output)

    out_wave = torch.istft(
        output,
        n_fft=N_FFTS,
        hop_length=HOP_LENGTH,
        win_length=N_FFTS,
        window=window,
        return_complex=False,
        normalized=True,
        center=False
    )

    if on_gpu:
        torch.cuda.synchronize(input_sig.device)
    end_time = time.perf_counter()

    nisqa_score, _, _ = process(out_wave.detach().cpu(), 48_000, nisqa, h0_nisqa, c0_nisqa, nisqa_args)
    return nisqa_score, (input_sig.shape[-1] / SR) / (end_time - start_time), new_hidden_state


def check_stream_inference(model, dataset, window_size = 1 * SR):
    """Compare whole-utterance inference with chunked ("streaming") inference.

    Every dataset item is enhanced once as a whole and once in consecutive
    chunks of ``window_size`` samples, carrying the model's hidden state from
    one chunk to the next. NISQA scores and processing speed (as returned by
    ``eval``) are collected for both modes and their means are printed.

    Signals, and trailing chunks, shorter than ``N_FFTS`` samples are skipped:
    the un-padded STFT in ``eval`` cannot produce a frame for them.

    Args:
        model: network passed through to ``eval``.
        dataset: indexable dataset whose items unpack into four values, the
            first two being the input signal and the reference signal.
        window_size: chunk length in samples for the streaming pass; must be
            at least ``N_FFTS``.

    Returns:
        ``(result_nisqa_full, result_rtf_full, result_nisqa_chunk,
        result_rtf_chunk)`` - one entry per evaluated signal in each list.

    Raises:
        ValueError: if ``window_size`` is smaller than ``N_FFTS``.
    """
    if window_size < N_FFTS:
        raise ValueError(
            f"window_size ({window_size}) must be at least N_FFTS ({N_FFTS}) samples"
        )

    model.eval()
    result_nisqa_full = []
    result_rtf_full = []
    result_nisqa_chunk = []
    result_rtf_chunk = []
    with torch.no_grad():
        for i in tqdm(range(len(dataset))):
            input_sig, gt, _, _ = dataset[i]
            input_sig = input_sig.to(DEVICE)
            if input_sig.shape[-1] < N_FFTS:
                # Too short for even a single STFT frame.
                continue

            # Whole-utterance pass.
            nisqa_score_full, rtf_full, _ = eval(model, input_sig, gt)
            result_nisqa_full.append(nisqa_score_full)
            result_rtf_full.append(rtf_full)

            # Streaming pass: start from a zero state (batch size 1) and feed
            # each chunk's returned state into the next chunk.
            hidden_state = [[torch.zeros(1, 1 * 32, 16 // 8, device=input_sig.device) for _ in range(8)] for _ in range(2)]
            nisqa_score_chunk_mean = []
            rtf_chunk_mean = []
            for j in range(0, input_sig.shape[-1], window_size):
                chunk = input_sig[..., j:j+window_size]
                if chunk.shape[-1] < N_FFTS:
                    # Only the trailing remainder can be this short; the STFT
                    # would raise on it, so it is dropped.
                    continue

                nisqa_score_chunk, rtf_chunk, hidden_state = eval(model, chunk, hidden_state=hidden_state)
                nisqa_score_chunk_mean.append(nisqa_score_chunk)
                rtf_chunk_mean.append(rtf_chunk)

            result_nisqa_chunk.append(torch.stack(nisqa_score_chunk_mean).mean(dim=0))
            result_rtf_chunk.append(torch.tensor(rtf_chunk_mean).mean())

    if not result_nisqa_full:
        # torch.stack() on an empty list would raise; report and return early.
        print("No signal was long enough to evaluate.")
        return result_nisqa_full, result_rtf_full, result_nisqa_chunk, result_rtf_chunk

    print("Mean nisqa for full audio: ", torch.stack(result_nisqa_full).mean(dim=0))
    print("Mean rtf for full audio: ", torch.tensor(result_rtf_full).mean(dim=0))
    print("---" * 10)
    print("Mean nisqa for \"stream\" audio: ", torch.stack(result_nisqa_chunk).mean(dim=0))
    print("Mean rtf for \"stream\" audio: ", torch.tensor(result_rtf_chunk).mean(dim=-1))

    return result_nisqa_full, result_rtf_full, result_nisqa_chunk, result_rtf_chunk

In [14]:
# Keep the four per-signal result lists in variables instead of letting the
# cell echo them; the aggregate means are already printed by the function.
# NOTE(review): this evaluates the complete `dataset`; the `test_dataset`
# split created earlier is never used - consider passing it here instead.
nisqa_full, rtf_full, nisqa_stream, rtf_stream = check_stream_inference(fspen, dataset)

  0%|          | 16/44242 [00:03<2:43:48,  4.50it/s]


KeyboardInterrupt: 

In [12]:
# Number of samples in a 20 ms chunk at the configured sample rate.
samples_per_ms = SR // 1000
samples_per_ms * 20

960

In [13]:
# _ = check_stream_inference(fspen, dataset, window_size=SR // 1000 * 20)