In [None]:
import pandas as pd
import torchaudio
from pathlib import Path
import sys
import numpy as np
import random
import torch
import numpy as np

# Reproducibility: push one fixed seed into every RNG this notebook relies on
# (PyTorch, the stdlib `random` module, and NumPy's legacy global generator).
SEED = 0
for _seed_rng in (torch.manual_seed, random.seed, np.random.seed):
    _seed_rng(SEED)

# The parent of the current working directory is used as the repository root;
# its `src` folder is put first on sys.path so the project-local imports resolve.
repo_root = Path.cwd().parent
sys.path.insert(0, str(repo_root.joinpath("src")))

# Location of the TinyGRU VAD checkpoint inside the repository.
model_path = repo_root.joinpath('models', 'GRU_VAD', 'tiny_vad_best.pth')

from utils.audio_dataset_loader import (
    load_ears_dataset,
    load_noizeus_dataset,
    create_audio_pairs,
    preprocess_audio
)
from utils.generate_and_save_spectrogram import generate_and_save_spectrogram
from utils.compute_and_save_speech_metrics import compute_and_save_speech_metrics
from utils.parse_and_merge_csvs import merge_csvs
from utils.delete_csvs import delete_csvs_in_directory as delete_csvs
from dsp_algorithms.wiener_as import wiener_filter
from dsp_algorithms.wiener_GTCRN import wiener_filter as wiener_gtcrn

In [None]:
# Load both test corpora, pair them up, and build one noisy mixture that the
# rest of the notebook filters, plots and plays back.
print("Loading EARS test dataset...")
ears_files = load_ears_dataset(repo_root, mode="test")
print(f"Loaded {len(ears_files)} EARS files for test mode")

print("Loading NOIZEUS test dataset...")
noizeus_files = load_noizeus_dataset(repo_root)
print(f"Loaded {len(noizeus_files)} NOIZEUS files for test mode")

# Create audio pairs
paired_files = create_audio_pairs(noizeus_files, ears_files)
print(f"Created {len(paired_files)} audio pairs for processing")

# Which pair to demo. Kept as a named constant so it is easy to change, and
# checked up front so a short pair list fails with a readable message instead
# of a bare "list index out of range".
PAIR_INDEX = 16
if not 0 <= PAIR_INDEX < len(paired_files):
    raise IndexError(
        f"PAIR_INDEX={PAIR_INDEX} is out of range: only {len(paired_files)} audio pairs were created"
    )
noise_path, clean_path = paired_files[PAIR_INDEX]

# Value forwarded to preprocess_audio as `snr_db`; also shown in plot titles.
snr_dB = 0

# NOTE(review): the noise file path is passed as `noisy_audio` — presumably
# preprocess_audio mixes it with the clean speech at `snr_db`; confirm in
# utils.audio_dataset_loader.
clean_waveform, noise_waveform, noisy_speech, clean_sr = preprocess_audio(
    clean_speech=clean_path,
    noisy_audio=noise_path,
    snr_db=snr_dB
)

In [None]:
# First Wiener pass: only the frame duration is set; every other tuning
# argument is left at the function's defaults.
# NOTE(review): the result is named `..._gru_vad`, but this is the same
# `wiener_filter` (imported from dsp_algorithms.wiener_as) that the next cell
# calls, and `threshold` below is never handed to it — confirm whether a
# VAD-based filter was meant to be called here.
print("\n2. Applying causal Wiener filtering...")
threshold = 0.5  # VAD decision threshold (currently only defined, not passed on)
gru_vad_wiener_kwargs = dict(
    noisy_audio=noisy_speech,
    fs=16000,
    frame_dur_ms=8,
)
enhanced_speech_wf_gru_vad, enhanced_fs = wiener_filter(**gru_vad_wiener_kwargs)

In [None]:
# Second Wiener pass on the same noisy mixture, this time with explicit tuning
# values (mu, a_dd, eta) rather than the function defaults.
# NOTE(review): the meaning of these three parameters is defined inside
# dsp_algorithms.wiener_as — check there before retuning.
print("\n2. Applying causal Wiener filtering...")
as_wiener_kwargs = dict(
    noisy_audio=noisy_speech,
    fs=16000,
    frame_dur_ms=8,
    mu=0.98,
    a_dd=0.98,
    eta=0.15,
)
# `enhanced_fs` is overwritten here with this call's second return value.
enhanced_speech_as, enhanced_fs = wiener_filter(**as_wiener_kwargs)

In [None]:
import IPython.display as ipd
import matplotlib.pyplot as plt

# Side-by-side comparison of the noisy input, both filter outputs and the
# clean reference: waveform plots first, then audio players.
print("="*80)
print("SIDE-BY-SIDE COMPARISON")
print("="*80)

# Sample rate used for the time axis and playback; same value that was passed
# as `fs` to wiener_filter in the cells above.
fs = 16000

# Independent copies with the leading dimension squeezed away, so trimming
# below never touches the tensors other cells use.
# (assumes each tensor has a leading singleton dimension — TODO confirm)
noisy_compare = noisy_speech.clone().squeeze(0)
wf_as_compare = enhanced_speech_as.clone().squeeze(0)
# Untrimmed copy of the VAD-labelled result; not used further in this cell,
# kept only in case another cell reads it.
enhanced_output = enhanced_speech_wf_gru_vad.clone().squeeze(0)
wf_gru_compare = enhanced_speech_wf_gru_vad.clone().squeeze(0)
clean_compare = clean_waveform.clone().squeeze(0)

# The threshold label is taken from the `threshold` variable so every label in
# this cell agrees with it (the old printout said 0.05 while the rest said 0.5).
print("\nAudio lengths before trimming:")
print(f"  noisy:              {len(noisy_compare)}")
print(f"  WF as:              {len(wf_as_compare)}")
print(f"  WF TinyGRU (thr={threshold}): {len(wf_gru_compare)}")
print(f"  clean:              {len(clean_compare)}")

# Trim to shortest length so all four signals share one time axis
min_len = min(
    len(noisy_compare),
    len(wf_as_compare),
    len(wf_gru_compare),
    len(clean_compare)
)
noisy_compare, wf_as_compare, wf_gru_compare, clean_compare = (
    signal[:min_len]
    for signal in (noisy_compare, wf_as_compare, wf_gru_compare, clean_compare)
)

print(f"\nTrimmed to: {min_len} samples ({min_len/fs:.2f}s)")

# One (signal, line colour, title) entry per subplot, top to bottom.
panels = [
    (noisy_compare, 'r', f'1. Noisy Speech (SNR = {snr_dB}dB)'),
    (wf_as_compare, 'orange', '2. Wiener Filter (AS method)'),
    (wf_gru_compare, 'b', f'3. Wiener + TinyGRU VAD (threshold={threshold})'),
    (clean_compare, 'g', '4. Clean Speech (Reference)'),
]

fig, axes = plt.subplots(len(panels), 1, figsize=(16, 12))
time_axis = np.arange(min_len) / fs

for ax, (signal, color, title) in zip(axes, panels):
    ax.plot(time_axis, signal.numpy(), color=color, alpha=0.7, linewidth=0.8)
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.set_ylabel('Amplitude')
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0, min_len/fs)

# Only the bottom panel carries the shared time label.
axes[-1].set_xlabel('Time (s)')

plt.tight_layout()
plt.show()

# Audio playback comparison
print("\n" + "="*80)
print("AUDIO PLAYBACK COMPARISON")
print("="*80)

# Captions are derived from snr_dB / threshold so they cannot drift from the
# values actually used above.
playback = [
    (f"NOISY SPEECH ({snr_dB}dB SNR):", noisy_compare),
    ("\nWIENER FILTER (AS method):", wf_as_compare),
    (f"\nWIENER + TinyGRUVAD (threshold={threshold}):", wf_gru_compare),
    ("\nCLEAN SPEECH (Reference):", clean_compare),
]

for caption, signal in playback:
    print(caption)
    ipd.display(ipd.Audio(signal.numpy(), rate=fs))

In [None]:
import IPython.display as ipd

# `wiener_gtcrn` and `generate_and_save_spectrogram` come from the import cell
# at the top of the notebook; they are not re-imported here.

# Use the audio waveforms from earlier
noisy_waveform = noisy_speech.squeeze(0)
# Clean reference cut to the noisy signal's length.
clean_waveform_trimmed = clean_waveform.squeeze(0)[:len(noisy_waveform)]
sample_rate = 16000

# Apply Wiener filter to noisy
print("Applying Wiener GTCRN filter...")
# The second return value is not needed; it gets its own throwaway name rather
# than `_`, which IPython uses for the last cell output.
enhanced_audio, _gtcrn_extra = wiener_gtcrn(noisy_waveform, sample_rate)

# Generate and save spectrograms using the utility function
print("Generating and saving spectrograms...")
filter_type = "wiener_gtcrn"
output_path = f"spectrograms/{filter_type}"

# One (waveform, output file stem, plot title) entry per spectrogram.
# The first title says "Noisy Speech" because the waveform is the noisy
# mixture, not the noise alone.
spectrogram_jobs = [
    (noisy_waveform, "noisy_spectrogram", "Noisy Speech Mel Spectrogram"),
    (clean_waveform_trimmed, "clean_spectrogram", "Clean Speech Mel Spectrogram"),
    (enhanced_audio, "enhanced_spectrogram", "Enhanced Speech Mel Spectrogram (Wiener GTCRN)"),
]

for waveform, file_name, title in spectrogram_jobs:
    generate_and_save_spectrogram(
        waveform=waveform,
        sample_rate=sample_rate,
        output_image_path=output_path,
        output_file_name=file_name,
        title=title,
        n_mels=128,
        hop_length=512,
        n_fft=1024,
        colormap='plasma',
        include_metadata_in_filename=False
    )
    # Show the image just written; the path mirrors the arguments passed above.
    ipd.display(ipd.Image(filename=f'{output_path}/{file_name}.png'))

print("Spectrograms generated, saved, and displayed.")