In [1]:
import sys

# Make the repository root (two levels above this notebook) importable so the
# project packages used below resolve. The membership check keeps the cell
# idempotent: re-running it does not pile up duplicate sys.path entries.
# NOTE(review): the path is relative, so it depends on the kernel's working
# directory being the notebook's own folder - confirm for your setup.
if "../.." not in sys.path:
    sys.path.append("../..")

In [2]:
import torch
from torch.utils.data import DataLoader

from rat_models import GTCRN_RAT, MPNet_RAT
from models.gtcrn import GTCRN
from models import MPNet
from experiments.min_gru import MinMPNet
from utils import count_parameters, VCTKDatasetFromList, load_config, mag_pha_stft, mag_pha_istft, compute_pesq, compute_sisnr
from train.gtcrn import validate_epoch, HybridLoss

In [3]:
# Configuration read from the MPNet JSON file; handed to every MPNet-family
# constructor below as the `h` keyword argument.
mpnet_config = load_config("../../models/mpnet/config.json")

# Chunk sizes for which a RAT variant (and its checkpoint path) is registered.
rat_chunk_sizes = (4, 8, 16, 32, 64)

# Display name -> (model class, constructor kwargs, checkpoint path).
# Insertion order matters: the dict is iterated in this order when the models
# are instantiated.
models_config = {
    "GTCRN": (GTCRN, {}, "checkpoints/gtcrn/best.pt"),
    **{
        f"GTCRN_RAT (chunk={size})": (
            GTCRN_RAT,
            {"chunk_size": size},
            f"checkpoints/gtcrn_rat/chunk_{size}/best.pt",
        )
        for size in rat_chunk_sizes
    },
    "MPNet": (MPNet, {"h": mpnet_config}, "checkpoints/mpnet/best.pt"),
    "MinMPNet": (MinMPNet, {"h": mpnet_config}, "checkpoints/min_mpnet/best_pesq.pt"),
    **{
        f"MPNet_RAT (chunk={size})": (
            MPNet_RAT,
            {"h": mpnet_config, "chunk_size": size},
            f"checkpoints/mpnet_rat/chunk_{size}/best.pt",
        )
        for size in rat_chunk_sizes
    },
}

import os
import warnings

# Build every configured model, restore its weights when the checkpoint file
# exists, and put it in eval mode. Results are keyed by display name.
models = {}
for name, (model_cls, kwargs, ckpt_path) in models_config.items():
    model = model_cls(**kwargs)
    if os.path.isfile(ckpt_path):
        # Load onto CPU regardless of where the checkpoint was saved.
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
        model.load_state_dict(checkpoint["model_state_dict"])
        status = "Loaded"
    else:
        # Without a checkpoint the model keeps its freshly constructed weights.
        # Say so loudly instead of reporting it as "Loaded", otherwise any
        # metric computed later would be silently attributed to a trained model.
        warnings.warn(
            f"Checkpoint not found for {name}: {ckpt_path!r}; "
            "model keeps its initial (untrained) weights."
        )
        status = "Initialised WITHOUT checkpoint"
    model.eval()
    models[name] = model
    print(f"{status} {name}, params: {count_parameters(model):,}")

Loaded GTCRN, params: 23,669
Loaded GTCRN_RAT (chunk=4), params: 24,757
Loaded GTCRN_RAT (chunk=8), params: 24,757
Loaded GTCRN_RAT (chunk=16), params: 24,757
Loaded GTCRN_RAT (chunk=32), params: 24,757
Loaded GTCRN_RAT (chunk=64), params: 24,757
Loaded MPNet, params: 2,263,372
Loaded MinMPNet, params: 1,333,580
Loaded MPNet_RAT (chunk=4), params: 1,062,732
Loaded MPNet_RAT (chunk=8), params: 1,062,732
Loaded MPNet_RAT (chunk=16), params: 1,062,732
Loaded MPNet_RAT (chunk=32), params: 1,062,732
Loaded MPNet_RAT (chunk=64), params: 1,062,732


In [4]:
base_dir = "../../VoiceBank+DEMAND/"

# Arguments common to both views of the test split: the same file list and the
# same clean/noisy directories.
shared_dataset_args = dict(
    file_list=base_dir + "test.txt",
    clean_dir=base_dir + "wavs_clean",
    noisy_dir=base_dir + "wavs_noisy",
    segment_len=None,
)

# Default view of the test split; consumed by `validate_epoch` further down.
test_dataset = VCTKDatasetFromList(**shared_dataset_args)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Same files with return_audio=True; the evaluation loop unpacks each item of
# this loader as a (noisy_audio, clean_audio) pair.
test_dataset_audio = VCTKDatasetFromList(**shared_dataset_args, return_audio=True)
test_loader_audio = DataLoader(test_dataset_audio, batch_size=1, shuffle=False)

Loaded 824 files from ../../VoiceBank+DEMAND/test.txt
Loaded 824 files from ../../VoiceBank+DEMAND/test.txt


In [None]:
# Score every loaded model on the test split (CPU only) and print its mean
# PESQ and SI-SNR. Models are dispatched on their display name.
for model_name, model in models.items():
    is_mpnet_family = "mpnet" in model_name.lower()

    if not is_mpnet_family:
        # Everything else goes through the GTCRN training script's validation
        # routine, which returns a loss and a metrics dict.
        val_loss, metrics = validate_epoch(model, test_loader, HybridLoss(), "cpu", verbose=False)
        print(f"{model_name} PESQ: {metrics['pesq']:.3f}, SI-SNR: {metrics['sisnr']:.2f} dB")
        continue

    # MPNet-family models: STFT -> model -> iSTFT, scored against clean audio.
    h = mpnet_config
    model = model.to("cpu")
    # The same four parameters drive both the forward and the inverse
    # transform; compress_factor falls back to 1.0 when the config lacks it.
    stft_args = (h.n_fft, h.hop_size, h.win_size, getattr(h, "compress_factor", 1.0))

    pesq_scores = []
    sisnr_scores = []
    with torch.no_grad():
        for noisy_audio, clean_audio in test_loader_audio:
            # Add a leading dimension to both waveforms.
            # NOTE(review): the resulting layout expected by mag_pha_stft is
            # not visible here - confirm.
            noisy_audio = noisy_audio.unsqueeze(0)
            clean_audio = clean_audio.unsqueeze(0)

            noisy_mag, noisy_pha, _ = mag_pha_stft(noisy_audio, *stft_args)
            mag_g, pha_g, _ = model(noisy_mag, noisy_pha)
            audio_g = mag_pha_istft(mag_g, pha_g, *stft_args)

            pesq_scores.append(compute_pesq(clean_audio[0].cpu(), audio_g[0].cpu()))
            sisnr_scores.append(compute_sisnr(clean_audio, audio_g).item())

    # Arithmetic means; an empty loader yields 0.0 for both metrics.
    if pesq_scores:
        pesq_mean = sum(pesq_scores) / len(pesq_scores)
    else:
        pesq_mean = 0.0
    if sisnr_scores:
        sisnr_mean = sum(sisnr_scores) / len(sisnr_scores)
    else:
        sisnr_mean = 0.0
    print(f"{model_name} PESQ: {pesq_mean:.3f}, SI-SNR: {sisnr_mean:.2f} dB")