In [97]:
import torch
from torch import nn
import torch.nn.functional as F
from torchaudio import transforms
from torch.utils.data import Subset

import data
import matplotlib.pyplot as plt


# --- Runtime ---------------------------------------------------------------
# Device string handed to `.to(...)` when spectrograms are moved for inference.
DEVICE = "cpu"

# --- Filesystem locations --------------------------------------------------
# Root folder given to the test dataset constructor.
TEST_DATA_PATH = "./data/"
# Checkpoint file whose contents are passed to `model.load_state_dict`.
MODEL_PATH = "./experiments/models/crn-model-512-50-b16.pt"

# --- Signal processing -----------------------------------------------------
# FFT size shared by the pre- and post-processor.
N_FFT = 512
# Intermediate sample rate both processors resample to / from.
RESAMPLE_SAMPLERATE = 16000

In [None]:
# Load test set
test_dataset = data.TestNoisySpeech(TEST_DATA_PATH)

# A dataset item unpacks into three fields; only the last one is kept here and
# is later used as the input/output sample rate of the pre-/post-processor.
_, _, input_samplerate = test_dataset[0]

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    collate_fn=data.collate_fn,
    batch_size=1,
)

print(f"Dataset length: {len(test_dataset)} examples")

In [99]:
# Define your model here
class YourModel:
    """Placeholder to be replaced with the real network definition.

    The remaining cells of this notebook use the instance as follows, so the
    replacement has to support all of it:

    * ``model.load_state_dict(...)`` to restore the checkpoint,
    * ``model.eval()`` before inference,
    * ``model(noisy_spec)`` returning a pair whose first element is the
      estimated spectrogram.

    As written, this empty class provides none of these.
    """


# Single instance shared by every evaluation cell below.
model = YourModel()

In [None]:
# Load from state_dict
# `map_location` remaps the stored tensors onto DEVICE while loading, so a
# checkpoint written from another device (e.g. a GPU) can still be restored on
# a host that does not have that device.
# NOTE(review): torch.load unpickles the file, so only open checkpoints you
# trust; consider passing weights_only=True if the installed torch supports it.
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))

In [101]:
# Pre- and post-processors
class PreProcessor(torch.nn.Module):
    """Convert a waveform batch into the spectrogram layout handed to the model.

    The waveform is resampled from ``input_samplerate`` to
    ``resample_samplerate``, run through a normalized, Hann-windowed
    torchaudio ``Spectrogram``, and finally has its axes 1 and 2 swapped.

    Args:
        input_samplerate: Source rate given to the resampler.
        resample_samplerate: Target rate given to the resampler.
        n_fft: FFT size of the spectrogram; ``output_size`` is derived from it
            as ``n_fft // 2 + 1``.
        power: Forwarded unchanged to ``Spectrogram``.
            NOTE(review): torchaudio documents ``None`` as "return the complex
            spectrum" - confirm against the installed version.
    """

    def __init__(self, input_samplerate=16000, resample_samplerate=16000, n_fft=480, power=None):
        super().__init__()
        # Number of frequency bins of a one-sided FFT of size n_fft.
        self.output_size = 1 + n_fft // 2
        self.resample = transforms.Resample(input_samplerate, resample_samplerate)
        self.transform = transforms.Spectrogram(
            n_fft=n_fft,
            power=power,
            normalized=True,
            window_fn=torch.hann_window,
        )

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Resample ``waveform``, take its spectrogram and swap the last two axes.

        ``permute(0, 2, 1)`` requires the spectrogram to be 3-D, i.e. the
        input is presumably ``(batch, samples)`` - TODO confirm with the
        collate function.
        """
        spectrogram = self.transform(self.resample(waveform))
        return spectrogram.permute(0, 2, 1)


class PostProcessor(torch.nn.Module):
    """Convert a model-layout spectrogram back into a waveform.

    Mirrors the pre-processing steps in reverse: axes 1 and 2 are swapped
    back, a normalized, Hann-windowed torchaudio ``InverseSpectrogram`` is
    applied, and the result is resampled from ``resample_samplerate`` to
    ``output_samplerate``.

    Args:
        output_samplerate: Target rate of the final resampling step.
        resample_samplerate: Rate the spectrogram was computed at (source rate
            of the final resampling step).
        n_fft: FFT size handed to ``InverseSpectrogram``.
    """

    def __init__(
        self,
        output_samplerate=16000,
        resample_samplerate=16000,
        n_fft=480,
    ):
        super().__init__()
        # The former `n_fft = n_fft` self-assignment was a no-op and is gone.
        self.resample = transforms.Resample(resample_samplerate, output_samplerate)
        self.transform = transforms.InverseSpectrogram(
            n_fft=n_fft,
            normalized=True,
            window_fn=torch.hann_window,
        )

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """Invert ``spec`` (3-D, as required by ``permute(0, 2, 1)``) to audio.

        No target ``length`` is passed to the inverse transform, so the
        returned waveform is not guaranteed to have the same number of samples
        as the signal the spectrogram was computed from.
        """
        spec = spec.permute(0, 2, 1)
        waveform = self.transform(spec)
        return self.resample(waveform)
    
    
# Both processors are built from the same FFT size and intermediate rate, and
# the post-processor resamples back to the rate read from the dataset.
preprocessor = PreProcessor(
    n_fft=N_FFT,
    power=None,
    input_samplerate=input_samplerate,
    resample_samplerate=RESAMPLE_SAMPLERATE,
)

postprocessor = PostProcessor(
    n_fft=N_FFT,
    resample_samplerate=RESAMPLE_SAMPLERATE,
    output_samplerate=input_samplerate,
)

### MAE Calculation

In [113]:
# MAE loss function
def calculate_mae(
    dataloader,
    model,
    preprocessor,
):
    """Average the per-batch magnitude-spectrogram MAE of ``model``.

    For every batch, the noisy and clean waveforms are both run through
    ``preprocessor``; the model is applied to the noisy spectrogram and the L1
    loss between the magnitudes (``abs()``) of its estimate and of the clean
    spectrogram is recorded.

    Args:
        dataloader: Iterable yielding ``(noisy_batch, clean_batch, _)``
            triples; the third element is ignored.
        model: Module called on the noisy spectrogram. It must return a pair
            whose first element is the estimated spectrogram. It is switched
            to eval mode and left there.
        preprocessor: Callable mapping a waveform batch to a spectrogram.

    Returns:
        The unweighted mean of the per-batch losses (each batch counts once,
        regardless of its size), as produced by ``sum(...) / len(...)`` over
        the collected loss tensors.

    Raises:
        ValueError: If ``dataloader`` yields no batches (previously this
            surfaced as an opaque ``ZeroDivisionError``).
    """
    batch_losses = []

    model.eval()
    with torch.no_grad():
        for noisy_batch, clean_batch, _ in dataloader:
            noisy_spec = preprocessor(noisy_batch)
            clean_spec = preprocessor(clean_batch)

            est_spec, _ = model(noisy_spec)

            # Compare magnitudes only; phase is deliberately left out.
            batch_losses.append(F.l1_loss(est_spec.abs(), clean_spec.abs()))

    if not batch_losses:
        raise ValueError("calculate_mae: the dataloader yielded no batches, nothing to average.")

    return sum(batch_losses) / len(batch_losses)

In [None]:
# Evaluate the MAE over the whole test loader; the bare name on the last line
# makes the notebook display the result.
score = calculate_mae(dataloader=test_loader, model=model, preprocessor=preprocessor)

score

### Listening Test

In [110]:
# Define IDs of examples to investigate
idx_list = [0, 100, 200]

# Restrict the test set to the chosen items and serve them one at a time.
single_example_set = Subset(test_dataset, idx_list)
single_example_loader = torch.utils.data.DataLoader(
    single_example_set,
    collate_fn=data.collate_fn,
    batch_size=1,
)

# Persistent iterator: every run of the next cell pulls the following example.
iterable_loader = iter(single_example_loader)

In [None]:
# Inference + plot spectrogram + display audio
import IPython

# Pull the next selected example; re-running this cell advances the iterator.
noisy_batch, clean_batch, sr = next(iterable_loader)
noisy_spec = preprocessor(noisy_batch).to(DEVICE)
clean_spec = preprocessor(clean_batch).to(DEVICE)

# NOTE(review): the inputs are moved to DEVICE but the model is not moved in
# this notebook - it is assumed to already live on DEVICE.
model.eval()
with torch.no_grad():
    enhanced_spec, _ = model(noisy_spec)

# Back to waveforms. `clean_audio` / `noisy_audio` are the spectrogram
# round-trips of the inputs; they are not used further in this cell.
enhanced_batch = postprocessor(enhanced_spec.to('cpu'))
clean_audio = postprocessor(clean_spec.to('cpu'))
noisy_audio = postprocessor(noisy_spec.to('cpu'))

# Log-magnitude spectrograms of the first item in the batch, side by side.
fig, axes = plt.subplots(1, 3, figsize=(8, 3))
panels = (("Noisy", noisy_spec), ("Enhanced", enhanced_spec), ("Clean", clean_spec))
for ax, (title, spec) in zip(axes, panels):
    # Floor the magnitude before the log so zero bins do not become -inf.
    log_magnitude = spec[0].detach().to('cpu').abs().clamp_min(1e-8).log().mT
    ax.imshow(log_magnitude.numpy(), origin='lower', aspect="auto")
    ax.set_title(title)
fig.tight_layout()
plt.show()

# Audio players in the same order as the panels: noisy, enhanced, clean.
for waveform in (noisy_batch, enhanced_batch, clean_batch):
    IPython.display.display(
        IPython.display.Audio(waveform[0].detach().cpu().numpy(), rate=int(sr))
    )

### PESQ Score (doesn't work)

In [106]:
from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality

# PESQ eval function
def calculate_pesq(
    dataloader,
    model,
    preprocessor,
    postprocessor,
    fs=16000,
    mode="wb",
):
    """Average the per-batch PESQ of the enhanced audio against the clean audio.

    Each noisy batch goes through ``preprocessor`` -> ``model`` ->
    ``postprocessor``; the result is scored against the clean batch.

    Two alignment steps are applied before scoring:

    * If the rate reported by the loader differs from ``fs``, both signals are
      resampled to ``fs`` so the metric sees audio at the rate it is told.
    * Both signals are cropped to their common length along the last axis,
      because the reconstructed waveform is not guaranteed to have the same
      number of samples as the reference, and the metric needs equal shapes.

    Args:
        dataloader: Iterable yielding ``(noisy_batch, clean_batch, sr)``
            triples, where ``int(sr)`` is the sample rate of the batch.
        model: Module called on the noisy spectrogram; must return a pair
            whose first element is the estimated spectrogram. It is switched
            to eval mode and left there.
        preprocessor: Callable mapping a waveform batch to a spectrogram.
        postprocessor: Callable mapping a spectrogram back to a waveform.
            NOTE(review): assumed to return audio at the same rate as the
            batches delivered by ``dataloader`` - confirm for other setups.
        fs: Sample rate passed to the metric (default matches the previously
            hard-coded 16000).
        mode: PESQ mode passed to the metric (default matches the previously
            hard-coded ``"wb"``).

    Returns:
        The unweighted mean of the per-batch scores.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    pesq_scores = []
    resamplers = {}  # one Resample module per distinct source rate

    model.eval()
    with torch.no_grad():
        for noisy_batch, clean_batch, sr in dataloader:
            est_spec, _ = model(preprocessor(noisy_batch))
            est_batch = postprocessor(est_spec)

            source_rate = int(sr)
            if source_rate != fs:
                if source_rate not in resamplers:
                    resamplers[source_rate] = transforms.Resample(source_rate, fs)
                to_fs = resamplers[source_rate]
                est_batch = to_fs(est_batch)
                clean_batch = to_fs(clean_batch)

            # Crop to the shared length so both tensors have the same shape.
            n_samples = min(est_batch.shape[-1], clean_batch.shape[-1])
            pesq_score = perceptual_evaluation_speech_quality(
                est_batch[..., :n_samples],
                clean_batch[..., :n_samples],
                fs=fs,
                mode=mode,
            )
            pesq_scores.append(pesq_score)

    if not pesq_scores:
        raise ValueError("calculate_pesq: the dataloader yielded no batches, nothing to average.")

    return sum(pesq_scores) / len(pesq_scores)

In [None]:
# Evaluate PESQ over the whole test loader; the bare name on the last line
# makes the notebook display the result.
score = calculate_pesq(
    postprocessor=postprocessor,
    preprocessor=preprocessor,
    model=model,
    dataloader=test_loader,
)

score