In [1]:
import pyroomacoustics as pra

import os
from tqdm import tqdm

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torch.autograd import profiler
import torchaudio
from torchmetrics.audio import SpeechReverberationModulationEnergyRatio, ShortTimeObjectiveIntelligibility


from einops import rearrange

from src.dataset import SignalDataset, TRUNetDataset
from src.loss import loss_tot, loss_MR, loss_MR_w
from models.fspen import FullSubPathExtension 

from IPython.display import Audio

from src.utils import model_eval, model_eval_fspen2x_ver3

import matplotlib.pyplot as plt

In [2]:
# DATA_DIR = os.path.join("data", "wav48")
# TRAIN_DIR = os.path.join(DATA_DIR, "clean_trainset_56spk_wav")

TEST_DIR = os.path.join("data", "DS_10283_2791", "clean_testset_wav")
TEST_NOISE_DIR = os.path.join("data", "DS_10283_2791", "noisy_testset_wav")
# RIR_DIR = os.path.join("data", "rirs48_medium_2", )
NOISE_DIR = os.path.join("data", "demand")

CHKP_DIR = "checkpoints"

np.set_printoptions(precision=3)
torch.set_printoptions(precision=3)

In [3]:
SEED = 1984

np.random.seed(SEED)
torch.manual_seed(SEED)

gen = torch.Generator()
gen.manual_seed(SEED)

<torch._C.Generator at 0x78c0218bd850>

In [4]:
N_FFTS = 1024
HOP_LENGTH = 512
SR = 48_000

DEVICE = "cpu" # torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"It's {DEVICE} time!!!")

It's cpu time!!!


In [5]:
rir_dict = {1: os.path.join("data", "rirs48_soft_2"), 2: os.path.join("data", "rirs48_medium_2"), 3: os.path.join("data", "rirs48_hard_2")}
dataset = TRUNetDataset(TEST_DIR, sr=48_000, noise_dir=NOISE_DIR, rir_dir=rir_dict, snr=[0, 5, 10, 15], rir_proba=0.9, noise_proba=1.0, rir_target=False, return_noise=False, return_rir=False)
dataset.set_epoch(99)

In [6]:
from src.fspen_configs import TrainConfig48kHzEnc2x_ver2

configs = TrainConfig48kHzEnc2x_ver2()
print(sum(configs.bands_num_in_groups), configs.dual_path_extension["num_modules"])
fspen = FullSubPathExtension(configs=configs)# .to(DEVICE)

state_d = torch.load(os.path.join(CHKP_DIR, "fspen_chkp", "TrainConfig48kHzEnc2x_ver2_hard#1.pt"), map_location="cpu",  weights_only=False)

64 3


In [7]:
fspen.load_state_dict(state_d["model_state_dict"])

<All keys matched successfully>

In [8]:
def vorbis_window(winlen, device="cuda"):
    sq = torch.sin(torch.pi/2*(torch.sin(torch.pi/winlen*(torch.arange(winlen)-0.5))**2)).float()
    return sq

In [9]:
import yaml

from NISQA_s.src.core.model_torch import model_init
from NISQA_s.src.utils.process_utils import process

NISQA_PATH = "NISQA_s/config/nisqa_s.yaml"

with open(NISQA_PATH, 'r') as stream:
    nisqa_args = yaml.safe_load(stream)
nisqa_args["ms_n_fft"] = 512
nisqa_args["hop_length"] = 256
nisqa_args["ms_win_length"] = 512
nisqa_args["ckp"] = nisqa_args["ckp"][3:]

nisqa, h0_nisqa, c0_nisqa = model_init(nisqa_args)



In [10]:
def pad_sequence(batch):
    if not batch:
        return torch.zeros(0), torch.zeros(0)

    input_signal, target_signal, noise, rir = zip(*batch)
        
    max_len_s = max(s.shape[-1] for s in input_signal)
    
    padded_input = torch.zeros(len(input_signal), max_len_s)
    padded_target = torch.zeros(len(target_signal), max_len_s)
    
    for i, s in enumerate(input_signal):
        padded_input[i, :s.shape[-1]] = s
        padded_target[i, :s.shape[-1]] = target_signal[i]

    return padded_input, padded_target


def collate_fn(batch):
    
    padded_input, padded_target = pad_sequence(batch)
        
    padded_input = padded_input.reshape(-1, padded_input.shape[-1])
    padded_target = padded_target.reshape(-1, padded_input.shape[-1])

    return padded_input, padded_target

In [11]:
test_dataloader = DataLoader(dataset, batch_size=1, shuffle=False, drop_last=False, collate_fn=collate_fn)

In [12]:
import time

def check_stream_inference(model, loader, window_size = 1 * SR, device="cpu"):
    model.eval()

    result_nisqa_full = []
    result_rtf_full = []
    result_nisqa_chunk = []
    result_rtf_chunk = []
    with torch.no_grad():
        for signal, target in tqdm(loader):
            signal = signal.to(device)
            target = target.to(device)
            window = vorbis_window(N_FFTS).to(device)
    
            start_time = time.time()
            spec = torch.stft(
                signal,
                n_fft=N_FFTS,
                hop_length=HOP_LENGTH,
                # onesided=True,
                win_length=N_FFTS,
                window=window,
                return_complex=True,
                normalized=True,
                center=True
            ) 
            
            output, _ = model_eval(model, spec, device)

            window = vorbis_window(N_FFTS).to(device)
            output = torch.istft(output, n_fft=N_FFTS, hop_length=HOP_LENGTH, win_length=N_FFTS,
                                   window=window,
                                   # onesided=True,
                                   return_complex=False,
                                   normalized=True,
                                   center=True)
            
            end_time = time.time()
            
            result_rtf_full.append((signal.shape[-1] / SR) / (end_time - start_time))
            nisqa_score, _, _ = process(output.detach().cpu(), SR, nisqa, h0_nisqa, c0_nisqa, nisqa_args)
            result_nisqa_full.append(nisqa_score)

            for j in range(0, signal.shape[-1], window_size):
                chunk = signal[..., j:j+window_size]
                
                if chunk.shape[-1] < 10_000:
                    continue

                start_time = time.time()
                spec = torch.stft(
                    chunk,
                    n_fft=N_FFTS,
                    hop_length=HOP_LENGTH,
                    # onesided=True,
                    win_length=N_FFTS,
                    window=window,
                    return_complex=True,
                    normalized=True,
                    center=True
                )

                output, _ = model_eval(model, spec, device)

                window = vorbis_window(N_FFTS).to(device)
                output = torch.istft(output, n_fft=N_FFTS, hop_length=HOP_LENGTH, win_length=N_FFTS,
                                    window=window,
                                    # onesided=True,
                                    return_complex=False,
                                    normalized=True,
                                    center=True)
                
                end_time = time.time()


                result_rtf_chunk.append((chunk.shape[-1] / SR) / (end_time - start_time))
                # print(output.shape)
                nisqa_score, _, _ = process(output.detach().cpu(), SR, nisqa, h0_nisqa, c0_nisqa, nisqa_args)

                result_nisqa_chunk.append(nisqa_score)
                

    print(f"Mean nisqa for full audio: ", torch.stack(result_nisqa_full).mean(dim=0))
    print(f"Mean rtf for full audio: ", torch.tensor(result_rtf_full).mean(dim=0), 1 / torch.tensor(result_rtf_full).mean(dim=0))
    print("---" * 10)
    print("Mean nisqa for \"stream\" audio: ", torch.stack(result_nisqa_chunk).mean(dim=0))
    print("Mean rtf for \"stream\" audio: ", torch.tensor(result_rtf_chunk).mean(dim=-1), 1 / torch.tensor(result_rtf_chunk).mean(dim=-1))

    return result_nisqa_full, result_rtf_full, result_nisqa_chunk, result_rtf_chunk

In [13]:
_ = check_stream_inference(fspen, test_dataloader, device="cpu")

100%|██████████| 824/824 [10:24<00:00,  1.32it/s]

Mean nisqa for full audio:  tensor([[3.152, 3.883, 3.319, 3.291, 3.419]])
Mean rtf for full audio:  tensor(10.644) tensor(0.094)
------------------------------
Mean nisqa for "stream" audio:  tensor([[2.667, 3.705, 3.253, 2.981, 2.879]])
Mean rtf for "stream" audio:  tensor(7.771) tensor(0.129)





In [14]:
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
from torch_stoi import NegSTOILoss

srmr = SpeechReverberationModulationEnergyRatio(fs=48_000, norm=True)
pesq = PerceptualEvaluationSpeechQuality(fs=16_000, mode="wb").to("cuda")
stoi = NegSTOILoss(SR, use_vad=False, do_resample=False).to("cuda")

In [15]:
from torchaudio.transforms import Resample


def get_metrics(model, loader, device="cpu"):
    model.eval()
    
    model = model.to(device)
    
    nisqa_scores = []
    pesq_scores = []
    stoi_scores = []
    srmr_scores = []
    with torch.no_grad():
        for signal, target in tqdm(loader):
            signal = signal.to(device)
            target = target.to(device)
            window = vorbis_window(N_FFTS).to(device)
    
            spec = torch.stft(
                signal,
                n_fft=N_FFTS,
                hop_length=HOP_LENGTH,
                # onesided=True,
                win_length=N_FFTS,
                window=window,
                return_complex=True,
                normalized=True,
                center=True
            ) 
            
            output, _ = model_eval(model, spec, device)

            window = vorbis_window(N_FFTS).to(device)
            output = torch.istft(output, n_fft=N_FFTS, hop_length=HOP_LENGTH, win_length=N_FFTS,
                                   window=window,
                                   # onesided=True,
                                   return_complex=False,
                                   normalized=True,
                                   center=True)
            
            min_l = min(output.shape[-1], signal.shape[-1])
            nisqa_score, _, _ = process(output.detach().cpu(), SR, nisqa, h0_nisqa, c0_nisqa, nisqa_args)
            stoi_score = stoi(output[..., :min_l], target[..., :min_l])
            srmr_score = srmr(output.detach().cpu())
            
            resampler = Resample(SR, 16_000)
            output = resampler(output.cpu()).cuda()
            target = resampler(target.cpu()).cuda()
            min_l = min(output.shape[-1], target.shape[-1])

            try:
                pesq_score = pesq(output[..., :min_l], target[..., :min_l])
            except Exception as e:
                # print(min_l)
                # out_wave_ = output.reshape(-1)
                # target_ = target.reshape(-1)
                # write('exception_out.wav', SR, out_wave_.cpu().detach().numpy())
                # write('exception_in.wav', SR, target_.cpu().detach().numpy())
                continue

            nisqa_scores.append(nisqa_score[0])
            srmr_scores.append(srmr_score)
            stoi_scores.append(stoi_score.cpu())
            pesq_scores.append(pesq_score.cpu())

    result = {"nisqa": nisqa_scores, "stoi": stoi_scores, "srmr": srmr_scores, "pesq": pesq_scores}
        
    return result

In [16]:
metrics = get_metrics(fspen, test_dataloader, device="cuda")

100%|██████████| 824/824 [11:00<00:00,  1.25it/s]


In [17]:
print("NISQA:", torch.vstack(metrics["nisqa"]).mean(dim=0))
print("PESQ:", torch.vstack(metrics["pesq"]).mean(dim=0))
print("SRMR:", torch.vstack(metrics["srmr"]).mean(dim=0))
print("STOI:", -torch.vstack(metrics["stoi"]).mean(dim=0))

NISQA: tensor([3.117, 3.867, 3.311, 3.282, 3.380])
PESQ: tensor([1.772])
SRMR: tensor([1.879])
STOI: tensor([0.837])


In [18]:
input_sig, gt, gt_noise, gt_rir = dataset[0]

In [19]:
from thop import profile

window = vorbis_window(N_FFTS)

input_spec = torch.stft(
            input_sig[..., :SR],
            n_fft=N_FFTS,
            hop_length=HOP_LENGTH,
            # onesided=True,
            win_length=N_FFTS,
            window=window,
            return_complex=True,
            normalized=True,
            center=True
        )

input_spec = input_spec.to("cpu")

abs_spectrum = input_spec.abs()
input_spec_ = torch.permute(torch.view_as_real(input_spec), dims=(0, 2, 3, 1))
batch, frames, channels, frequency = input_spec_.shape
abs_spectrum = torch.permute(abs_spectrum, dims=(0, 2, 1))
abs_spectrum = torch.reshape(abs_spectrum, shape=(batch, frames, 1, frequency))
h0 = [[torch.zeros(1, batch * 64, 16 // 8, device=input_spec.device) for _ in range(8)] for _ in range(3)]

# output, hid_out = fspen(input_spec_, abs_spectrum, h0)

flops, params = profile(fspen.cpu(), inputs=(input_spec_, abs_spectrum, h0))

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv1d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm1d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_gru() for <class 'torch.nn.modules.rnn.GRU'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose1d'>.


In [20]:
print("Flops: ", flops)
print("Params: ", params)

Flops:  252744192.0
Params:  49600.0


In [21]:
fspen = FullSubPathExtension(configs=configs)
fspen.load_state_dict(state_d["model_state_dict"])

<All keys matched successfully>

In [22]:
input_sig, gt, gt_noise, gt_rir = dataset[8]

In [23]:
window = vorbis_window(N_FFTS)

spec = torch.stft(
            input_sig,
            n_fft=N_FFTS,
            hop_length=HOP_LENGTH,
            # onesided=True,
            win_length=N_FFTS,
            window=window,
            return_complex=True,
            normalized=True,
            center=True
        )

output, _ = model_eval(fspen, spec, DEVICE)

window = vorbis_window(N_FFTS)
# print(output_d.shape, input_spec.shape)
out_wave = torch.istft(output, n_fft=N_FFTS, hop_length=HOP_LENGTH, win_length=N_FFTS,
                       window=window,
                       # onesided=True,
                       return_complex=False,
                       normalized=True,
                       center=True)

out_wave = out_wave.reshape(-1)

In [24]:
from scipy.io.wavfile import write

write('input_sig_full.wav', SR, input_sig.cpu().detach().numpy()[0])
write('gt_full.wav', SR, gt.cpu().detach().numpy()[0])
write('output_full.wav', SR, out_wave.cpu().detach().numpy())