In [1]:
import torch
import torchaudio
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import os
import sys
import random
sys.path.append('../')
sys.path.append(os.path.abspath(os.path.join(os.pardir, 'pdmdns_solution')))
from models import Model
from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio

In [2]:
# ===========================
# Helper to plot spectrogram
# ===========================
def plot_mel_spectrogram(waveform, title, sample_rate, ax):
    """Draw a mel spectrogram of ``waveform`` on ``ax`` as an image.

    This helper was originally also called ``plot_spectrogram`` and was
    silently overwritten by the STFT-based definition below. It now has its
    own name so both variants stay reachable.

    Args:
        waveform: Audio tensor handed to ``torchaudio.transforms.MelSpectrogram``.
        title: Title set on the axes.
        sample_rate: Sampling rate in Hz, forwarded to the mel transform.
        ax: Matplotlib axes to draw on; its axis decorations are switched off.
    """
    spectrogram = torchaudio.transforms.MelSpectrogram(sample_rate)(waveform)
    spec_db = torchaudio.transforms.AmplitudeToDB()(spectrogram)
    ax.imshow(spec_db.squeeze().cpu(), origin="lower", aspect="auto", cmap="viridis")
    ax.set_title(title)
    ax.axis("off")


# ===========================
# Helper: STFT spectrogram plot
# ===========================
def plot_spectrogram(waveform, title, sample_rate, ax, n_fft=None, hop_length=None):
    """Draw a log-magnitude STFT spectrogram of ``waveform`` on ``ax``.

    Args:
        waveform: Audio tensor. If it has more than one dimension only the
            first entry along the leading axis is plotted.
        title: Title set on the axes.
        sample_rate: Sampling rate in Hz; used to scale both axes.
        ax: Matplotlib axes (or any object exposing ``pcolormesh``,
            ``set_title`` and ``set_ylabel``).
        n_fft: FFT size. Defaults to the notebook-level ``N_FFT``.
        hop_length: Hop between frames in samples. Defaults to the
            notebook-level ``HOP_LENGTH``.
    """
    # Fall back to the notebook configuration only when no explicit value is given.
    n_fft = N_FFT if n_fft is None else n_fft
    hop_length = HOP_LENGTH if hop_length is None else hop_length

    if waveform.ndim > 1:
        waveform = waveform[0]  # Take first channel if stereo

    # The window must live on the same device as the signal, otherwise
    # torch.stft fails for tensors that are not on the CPU.
    window = torch.hann_window(n_fft, device=waveform.device)
    spec = torch.stft(waveform, n_fft=n_fft, hop_length=hop_length,
                      return_complex=True, center=True, window=window)
    magnitude = spec.abs().cpu().numpy()

    freq_bins, time_bins = magnitude.shape[0], magnitude.shape[1]

    times = np.arange(time_bins) * hop_length / sample_rate
    # True division keeps the Nyquist frequency exact for odd sample rates.
    freqs = np.linspace(0, sample_rate / 2, freq_bins)

    # 1e-6 keeps log10 finite for silent bins.
    ax.pcolormesh(times, freqs / 1000, 20 * np.log10(magnitude + 1e-6), shading='auto', cmap='inferno')
    ax.set_title(title)
    ax.set_ylabel('Frequency (KHz)')
    # No x-label here: the caller labels only the bottom panel of a stack.

# ===========================
# Load audio
# ===========================
def load_audio(path):
    """Read an audio file and return its waveform at ``SAMPLE_RATE``.

    Args:
        path: Location of the audio file, passed to ``torchaudio.load``.

    Returns:
        The waveform returned by ``torchaudio.load``, resampled to
        ``SAMPLE_RATE`` when the file was stored at a different rate.
    """
    audio, native_rate = torchaudio.load(path)
    if native_rate == SAMPLE_RATE:
        # Already at the target rate; hand the tensor back untouched.
        return audio
    return torchaudio.functional.resample(audio, native_rate, SAMPLE_RATE)

# ===========================
# Load models
# ===========================
def load_models():
    """Load every checkpoint listed in ``CHECKPOINT_PATHS``.

    Each model is switched to eval mode and moved to ``DEVICE``, because
    ``denoise`` sends its inputs to ``DEVICE`` before calling the model.
    Without the explicit move, a checkpoint restored on another device
    would fail with a device-mismatch error at inference time.

    Returns:
        dict: Display name -> loaded ``Model`` instance, in the iteration
        order of ``CHECKPOINT_PATHS``.
    """
    # Local lower-case name so the result does not shadow the global MODELS.
    models = {}
    for name, ckpt_path in CHECKPOINT_PATHS.items():
        print(f"Loading model: {name}")
        model = Model.load_from_checkpoint(ckpt_path)
        model = model.to(DEVICE)
        model.eval()
        models[name] = model

    return models
    
# ===========================
# Apply each model
# ===========================

def denoise(noisy):
    """Run every model in ``MODELS`` on ``noisy`` without tracking gradients.

    Args:
        noisy: A single tensor, or a list of tensors processed one by one.

    Returns:
        dict: Model name -> model output moved back to the CPU. The value
        is a list when ``noisy`` is a list, otherwise a single tensor.
    """
    def _enhance(model, signal):
        # Inputs go to DEVICE for the forward pass; results come back to CPU.
        return model(signal.to(DEVICE)).cpu()

    results = {}
    with torch.no_grad():
        for name, model in MODELS.items():
            if isinstance(noisy, list):
                results[name] = [_enhance(model, item) for item in noisy]
            else:
                results[name] = _enhance(model, noisy)

    return results

In [3]:
# ===========================
# Configuration
# ===========================
# Display name -> Lightning checkpoint file for each model to compare.
CHECKPOINT_PATHS = {
    'PDMDNS_stateless': '../saved_checkpoints/base/checkpoint_epoch=0199_val_si_snr=14.5784_0.ckpt',
    'PDMDNS_stateful': '../saved_checkpoints/stateful/checkpoint_epoch=0194_val_si_snr=14.5371_8.ckpt'
}
# Path prefixes; the file id and ".wav" are appended when a file is loaded.
NOISY_AUDIO_PATH = '../data/noisy/noisy_fileid_'
CLEAN_AUDIO_PATH = '../data/clean/clean_fileid_'
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
SAMPLE_RATE = 16000  # Make sure your models are trained on this sample rate
N_FFT = 512          # STFT size used for the spectrogram plots
HOP_LENGTH = 128     # STFT hop (samples) used for the spectrogram plots

# Count only files that follow the "<prefix><id>.wav" naming, so stray
# directory entries (hidden files, sub-folders, ...) do not inflate the range
# of ids that can be drawn at random.
# NOTE(review): ids are assumed to be contiguous in 0..NUM_FILES-1 — confirm.
_noisy_dir, _noisy_prefix = os.path.split(NOISY_AUDIO_PATH)
NUM_FILES = sum(
    1
    for entry in os.listdir(_noisy_dir)
    if entry.startswith(_noisy_prefix) and entry.endswith('.wav')
)
MODELS = load_models()

Loading model: PDMDNS_stateless
DenoisingNet(
  (net_in): DELAYAvg(n_input=16, delay_max=1024, rand=False, avgpooling=AvgPool2d(kernel_size=(1, 128), stride=(1, 128), padding=0)
  (net): ModuleList(
    (0): ModuleList(
      (0): Sequential(
        (0): SubBlock(
          (conv): Conv2d(16, 8, kernel_size=(1, 3), stride=(1, 1), padding=same, bias=False)
          (activation): NeuronClass(
            (neuron): HRNeuron(
              (neurons): ParaLIF(
                spike_mode=T, n_neuron=8, recurrent=False, fire=True, recurrent_fire=True, learn_threshold=True, learn_tau=False
                (spike_fn): SpikingFunction(spike_mode: Threshold)
              )
            )
          )
        )
        (1): SubBlock(
          (conv): Conv2d(24, 8, kernel_size=(1, 3), stride=(1, 1), padding=same, bias=False)
          (activation): NeuronClass(
            (neuron): HRNeuron(
              (neurons): ParaLIF(
                spike_mode=T, n_neuron=8, recurrent=False, fire=True, r

C:\Users\yars2201\AppData\Local\anaconda3\envs\pdmse\Lib\site-packages\pytorch_lightning\utilities\migration\utils.py:56: The loaded checkpoint was produced with Lightning v2.5.1, which is newer than your current Lightning version: v2.4.0


In [4]:
import os, random
import torch
import torchaudio
import soundfile as sf
import IPython.display as ipd
import ipywidgets as widgets
from IPython.display import display, clear_output
import matplotlib.pyplot as plt

# Shared store for the audio and scores of the file currently on screen
state = dict()

# Single output area that all rendering is directed into
output = widgets.Output()

# === Duration selector ===
# Crop lengths in seconds; None is shown as "Full"
duration_options = [5, 10, 15, 20, None]
duration_dropdown = widgets.Dropdown(
    description="Max Duration:",
    value=None,
    options=[
        (f"{seconds}s" if seconds else "Full", seconds)
        for seconds in duration_options
    ],
)

def save_audio(waveform, filename, save_dir):
    """Write ``waveform`` to ``save_dir/filename`` at ``SAMPLE_RATE``.

    Args:
        waveform: Audio tensor; a 1-D tensor gets a leading axis added
            before it is handed to ``torchaudio.save``.
        filename: Name of the file to create inside ``save_dir``.
        save_dir: Target directory.
    """
    # Promote a 1-D signal to 2-D; anything else is passed through as is.
    as_2d = waveform.unsqueeze(0) if waveform.ndim == 1 else waveform
    target_path = os.path.join(save_dir, filename)
    torchaudio.save(target_path, as_2d.cpu(), SAMPLE_RATE)

def save_current_audios(index, save_dir="exported_audio"):
    """Export the noisy, clean and denoised signals held in ``state``.

    File names embed ``index`` and, for noisy and denoised signals, the
    SI-SNR formatted with two decimals.

    Args:
        index: File id used as the prefix of every exported file name.
        save_dir: Directory to write into; created if it does not exist.
    """
    os.makedirs(save_dir, exist_ok=True)

    noisy_signal = state["noisy"]
    noisy_score = state['noisy_sisnr']
    save_audio(noisy_signal, f"{index}_noisy___sisnr_{noisy_score:.2f}.wav", save_dir)
    save_audio(state["clean"], f"{index}_clean.wav", save_dir)

    for model_name, denoised in state["denoised_outputs"].items():
        # state['denoised_sisnr'] is added to the noisy score to form the
        # figure written into the file name.
        total_score = noisy_score + state['denoised_sisnr'][model_name]
        save_audio(denoised, f"{index}_denoised_{model_name}___sisnr_{total_score:.2f}.wav", save_dir)

def crop_random(audios, duration, sample_rate):
    """Cut the same randomly placed window out of every signal in ``audios``.

    Args:
        audios: Sequence of array-like signals with time on the last axis.
        duration: Window length in seconds.
        sample_rate: Sampling rate in Hz, used to convert ``duration`` to samples.

    Returns:
        list: One cropped signal per input, in the same order. If the
        shortest signal is not longer than the requested window, the inputs
        are returned uncropped. An empty input yields an empty list.
    """
    if not audios:
        return []

    max_len = int(duration * sample_rate)
    # Use the shortest signal so the window fits inside every input.
    shortest = min(audio.shape[-1] for audio in audios)
    if shortest <= max_len:
        # Nothing to crop. Without this guard random.randint would be called
        # with a negative upper bound and raise ValueError.
        return list(audios)

    start = random.randint(0, shortest - max_len)
    return [audio[..., start:start + max_len] for audio in audios]

def _sisnr_summary(clean, noisy, denoised_outputs):
    """Return the noisy SI-SNR and, per model, the SI-SNR gain over it."""
    noisy_sisnr = scale_invariant_signal_noise_ratio(clean, noisy).item()
    denoised_sisnr = {
        model_name: scale_invariant_signal_noise_ratio(clean, denoised).item() - noisy_sisnr
        for model_name, denoised in denoised_outputs.items()
    }
    return noisy_sisnr, denoised_sisnr


def _audio_panel(noisy, clean, denoised_outputs, noisy_sisnr, denoised_sisnr):
    """Build the Output widget holding the labelled audio players."""
    panel = widgets.Output()
    with panel:
        display(widgets.HTML(value=f"<b>Noisy Audio</b> SI-SNR: {noisy_sisnr:.2f} dB"))
        display(ipd.Audio(noisy, rate=SAMPLE_RATE))

        display(widgets.HTML(value="<b>Clean Audio</b>"))
        display(ipd.Audio(clean, rate=SAMPLE_RATE))

        for model_name, denoised in denoised_outputs.items():
            display(widgets.HTML(
                value=f"<b>Denoised ({model_name})</b> SI-SNR Improvement: {denoised_sisnr[model_name]:.2f} dB"))
            display(ipd.Audio(denoised, rate=SAMPLE_RATE))
    return panel


def _spectrogram_panel(noisy, clean, denoised_outputs):
    """Build the Output widget holding one stacked spectrogram per signal."""
    panel = widgets.Output()
    with panel:
        n_rows = 2 + len(denoised_outputs)
        fig, axs = plt.subplots(n_rows, 1, figsize=(6, 1.5 * n_rows))

        plot_spectrogram(clean[0], 'a) Clean Audio', SAMPLE_RATE, axs[0])
        axs[0].set_xticks([])

        plot_spectrogram(noisy[0], 'b) Noisy Audio', SAMPLE_RATE, axs[1])
        axs[1].set_xticks([])

        for row, (name, denoised) in enumerate(denoised_outputs.items(), start=2):
            # Derive the panel letter from the row index. A fixed 'abcd'
            # lookup raised IndexError once more than two models were shown.
            letter = chr(ord('a') + row)
            plot_spectrogram(denoised[0], f"{letter}) Denoised by {name}", SAMPLE_RATE, axs[row])
            # Only the bottom panel keeps its time ticks.
            if row < n_rows - 1:
                axs[row].set_xticks([])

        axs[-1].set_xlabel("Time (s)")
        plt.tight_layout()
        plt.show()
    return panel


def show_audio(index):
    """Load file ``index``, denoise it with every model and render the result.

    The noisy/clean pair is optionally cropped to the duration chosen in
    ``duration_dropdown``. The signals and their SI-SNR figures are stored in
    ``state``. Audio players (left) and spectrograms (right) are then drawn
    side by side inside ``output``, replacing its previous content.

    Args:
        index: File id appended to ``NOISY_AUDIO_PATH`` / ``CLEAN_AUDIO_PATH``.
    """
    with output:
        clear_output(wait=True)

        # Load raw audio
        noisy = load_audio(NOISY_AUDIO_PATH + f"{index}.wav")
        clean = load_audio(CLEAN_AUDIO_PATH + f"{index}.wav")

        # Crop if needed (same window for both so they stay aligned)
        duration_limit = duration_dropdown.value
        if duration_limit is not None:
            noisy, clean = crop_random([noisy, clean], duration_limit, SAMPLE_RATE)

        state["noisy"] = noisy
        state["clean"] = clean

        # Denoise the (possibly cropped) noisy signal with every model
        denoised_outputs = denoise(noisy)
        state["denoised_outputs"] = denoised_outputs

        # Compute SI-SNR and improvements
        noisy_sisnr, denoised_sisnr = _sisnr_summary(clean, noisy, denoised_outputs)
        state["noisy_sisnr"] = noisy_sisnr
        state["denoised_sisnr"] = denoised_sisnr

        # === LEFT SIDE: Info and audio ===
        left_box = _audio_panel(noisy, clean, denoised_outputs, noisy_sisnr, denoised_sisnr)

        # === RIGHT SIDE: Spectrograms ===
        right_box = _spectrogram_panel(noisy, clean, denoised_outputs)

        # Combine
        display(widgets.HBox([left_box, right_box]))

# === Navigation and interaction ===
# BoundedIntText is used because plain IntText has no min/max traits and
# silently ignores them; the bounds are only enforced by the bounded variant.
# max() keeps the range valid (min <= max) even when no files were found.
index_widget = widgets.BoundedIntText(
    value=0, min=0, max=max(NUM_FILES - 1, 0), description='Audio ID:', disabled=True
)

button_next = widgets.Button(description="Next", icon="arrow-right")
def on_next(b):
    """Button callback: pick a random file id and render it.

    Args:
        b: The clicked button (unused; required by the ``on_click`` signature).
    """
    index_widget.value = random.randint(0, NUM_FILES - 1)
    show_audio(index_widget.value)
button_next.on_click(on_next)

download_button = widgets.Button(
    icon="arrow-down",
    button_style='success',
    description="Download Current Audio",
)
def on_download(b):
    """Button callback: export the audio currently shown and confirm it.

    Args:
        b: The clicked button (unused; required by the ``on_click`` signature).
    """
    current_id = index_widget.value
    save_current_audios(current_id)
    # Report inside the shared output area, below the rendered panels.
    with output:
        print(f"✅ Exported audio files for ID {current_id} to ./exported_audio/")
download_button.on_click(on_download)

# === Final layout ===
# Controls stacked on top, shared output area underneath
nav_buttons = widgets.HBox(children=[button_next, download_button])
display(
    widgets.VBox(children=[duration_dropdown, nav_buttons, index_widget, output])
)
# Render the initially selected file so the output area starts populated
show_audio(index_widget.value)

VBox(children=(Dropdown(description='Max Duration:', options=(('5s', 5), ('10s', 10), ('15s', 15), ('20s', 20)…