In [1]:
import sys

sys.path.append("../..")

In [2]:
import torch
from torch.utils.data import DataLoader

from rat_models import GTCRN_RAT
from models.gtcrn import GTCRN
from utils import count_parameters
from train import VCTKDatasetFromList, validate_epoch, HybridLoss

In [3]:
models = {
    "gtcrn": (GTCRN, "checkpoints/gtcrn/epoch_80.pt"),
    "gtcrn_best": (GTCRN, "checkpoints/gtcrn/best.pt"),
    "gtcrn_rat_chunk_8": (GTCRN_RAT, "checkpoints/gtcrn_rat/chunk_8/epoch_40.pt"),
    "gtcrn_rat_chunk_8_best": (GTCRN_RAT, "checkpoints/gtcrn_rat/chunk_16/best.pt"),
}
for model_name, (model_cls, checkpoint_path) in models.items():
    model = model_cls()
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    models[model_name] = model
    print(f"Loaded model {model_name}, params: {count_parameters(model):,}")

Loaded model gtcrn, params: 23,669
Loaded model gtcrn_best, params: 23,669
Loaded model gtcrn_rat_chunk_8, params: 24,757
Loaded model gtcrn_rat_chunk_8_best, params: 24,757


In [4]:
base_dir = "../../VoiceBank+DEMAND/"
test_dataset = VCTKDatasetFromList(
    file_list=base_dir + "test.txt",
    clean_dir= base_dir + "wavs_clean",
    noisy_dir= base_dir + "wavs_noisy",
    segment_len=None,
)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

Loaded 824 files from ../../VoiceBank+DEMAND/test.txt


In [5]:
for model_name, model in models.items():
    val_loss, metrics = validate_epoch(model, test_loader, HybridLoss(), "cpu")
    print(f"{model_name} PESQ: {metrics['pesq']:.3f}")
    print(f"{model_name} SI-SNR: {metrics['sisnr']:.2f} dB")


Validation:   0%|          | 0/824 [00:00<?, ?it/s]

gtcrn PESQ: 2.793
gtcrn SI-SNR: 18.64 dB


Validation:   0%|          | 0/824 [00:00<?, ?it/s]

gtcrn_best PESQ: 2.804
gtcrn_best SI-SNR: 18.64 dB


Validation:   0%|          | 0/824 [00:00<?, ?it/s]

gtcrn_rat_chunk_8 PESQ: 2.858
gtcrn_rat_chunk_8 SI-SNR: 18.61 dB


Validation:   0%|          | 0/824 [00:00<?, ?it/s]

gtcrn_rat_chunk_8_best PESQ: 2.726
gtcrn_rat_chunk_8_best SI-SNR: 18.64 dB
