# Evaluate PESQ Score and Average SNR on Test Dataset

In [5]:
import sys
import os
sys.path.append("../")

if "notebook" in os.getcwd():
    os.chdir("../")

In [6]:
import torch
from pesq import pesq
from scipy.signal import resample
from torchmetrics.functional import signal_noise_ratio
from torch.utils.data import DataLoader, random_split
import numpy as np
from tqdm import tqdm


from src.models.waveform.cicada_clean_unet_att import CicadaCleanUNetModel
from src.models.waveform.cicada_unet_att import CicadaUNetAttModel
from src.models.waveform.cicada_unet import CicadaUNetModel
from src.data.waveform_data import WaveformDataset

torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MODEL_CKPT = "ckpts/cicadence_unet_final.pt"
model = CicadaUNetModel()
model.load_state_dict(torch.load(MODEL_CKPT, map_location=device))
model.eval()

NOISY_WAVE_PATH = "data/processed/28spk/combined_noisy_waves.pt"
CLEAN_WAVE_PATH = "data/processed/28spk/combined_clean_waves.pt"
SR = 48000

# Ensure correct device
model.to(device)
batch_size = 32

total_pesq = 0.0
total_snr = 0.0
num_samples = 0


ENCODERS: 5
S and K [1, 14, 27, 45, 84, 164], [7, 7, 7, 7, 7]


In [7]:
data = WaveformDataset(NOISY_WAVE_PATH, CLEAN_WAVE_PATH)

train_size = int(0.8 * len(data))
val_size = int(0.15 * len(data))
test_size = len(data) - train_size - val_size  # Ensure all samples are used

train_set, val_set, test_set = random_split(data, [train_size, val_size, test_size])
print(f"Train: {len(train_set)}, Val: {len(val_set)}, Test: {len(test_set)}")


train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

Train: 4952, Val: 928, Test: 310


In [8]:
with torch.no_grad():
    print("Evaluating SNR and PESQ: ")
    for noisy, clean in tqdm(test_loader):
        noisy = noisy.to(device)
        clean = clean.to(device)

        output = model(noisy)
        output_np = output.cpu().numpy()
        clean_np = clean.cpu().numpy()

        for est, ref in zip(output_np, clean_np):
            est = est.squeeze()
            ref = ref.squeeze()

            min_len = min(len(est), len(ref))
            est = est[:min_len]
            ref = ref[:min_len]
            
            est_pesq = resample(est, int(len(est) * 16000 / SR))
            ref_pesq = resample(ref, int(len(ref) * 16000 / SR))

            pesq_score = pesq(16000, ref_pesq, est_pesq, 'wb') #Wide band PESQ needs 16KHz data
            snr_score = signal_noise_ratio(torch.tensor(est), torch.tensor(ref)).item()

            total_pesq += pesq_score
            total_snr += snr_score
            num_samples += 1

# Calculate averages
avg_pesq = total_pesq / num_samples
avg_snr = total_snr / num_samples

print(f"Model Type: {model.__class__.__name__}")  
print(f"Average PESQ: {avg_pesq:.3f}")
print(f"Average SNR: {avg_snr:.3f} dB")


Evaluating SNR and PESQ: 


100%|██████████| 10/10 [04:13<00:00, 25.32s/it]

Model Type: CicadaUNetModel
Average PESQ: 1.449
Average SNR: 9.633 dB



