In [None]:
import torch
import torchaudio
import matplotlib.pyplot as plt
import numpy as np
from melbanks import LogMelFilterBanks

In [None]:
# Fetch the VOiCES speech clip shipped with the torchaudio tutorial assets, then decode it.
audio_path = torchaudio.utils.download_asset(
    "tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
)
sample_speech, sr = torchaudio.load(audio_path)

# Both front-ends are compared at 16 kHz; a Resample transform is built only when the
# file's native rate differs, and `sr` is updated to match the new waveform.
if sr != 16000:
    resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
    sample_speech, sr = resampler(sample_speech), 16000

# STFT / mel settings shared verbatim by the torchaudio and custom implementations.
# At the 16 kHz rate established above, n_fft=400 samples spans 25 ms and
# hop_length=160 samples spans 10 ms.
params = dict(
    n_fft=400,       # FFT size in samples
    hop_length=160,  # step between successive frames, in samples
    n_mels=80,       # number of mel filters
    power=2.0,       # exponent applied to the STFT magnitude (2.0 -> power spectrogram)
)

In [None]:
# Build both front-ends from the same `params`. Note the differing keyword for the rate:
# torchaudio's MelSpectrogram takes `sample_rate`, the custom LogMelFilterBanks takes `samplerate`.
torch_melspec = torchaudio.transforms.MelSpectrogram(
    sample_rate=sr,
    **params
)

custom_melspec = LogMelFilterBanks(
    samplerate=sr,
    **params
)

# Pure inference: no_grad avoids building an autograd graph, so the outputs can later be
# turned into NumPy arrays even if LogMelFilterBanks holds trainable parameters
# (NOTE(review): whether it does is not visible from this notebook).
with torch.no_grad():
    # MelSpectrogram yields mel energies without a log; apply log(x + 1e-6) so the result
    # is comparable with the custom log-mel output.
    # NOTE(review): assumes LogMelFilterBanks uses the same 1e-6 offset — confirm in melbanks.
    torch_output = torch.log(torch_melspec(sample_speech) + 1e-6)
    custom_output = custom_melspec(sample_speech)

In [None]:
# The custom front-end must reproduce torchaudio's log-mel features.
assert torch_output.shape == custom_output.shape, (
    f"shape mismatch: torchaudio {tuple(torch_output.shape)} "
    f"vs custom {tuple(custom_output.shape)}"
)
# Same acceptance criterion as torch.allclose(torch_output, custom_output) with its defaults
# (|a - b| <= atol + rtol * |b|, rtol=1e-5, atol=1e-8, NaNs unequal), but on failure this
# reports how many elements mismatch and the greatest absolute/relative difference
# instead of a bare AssertionError.
torch.testing.assert_close(torch_output, custom_output, rtol=1e-5, atol=1e-8)

In [None]:
def plot_mel_panel(ax, image, title, **imshow_kwargs):
    """Draw a 2-D (mel bin x time frame) array on `ax` and attach a colorbar.

    Args:
        ax: Matplotlib Axes to draw on.
        image: 2-D array-like; row 0 is drawn at the bottom (origin='lower').
        title: Panel title.
        **imshow_kwargs: Extra options forwarded to `ax.imshow` (e.g. cmap, vmin, vmax).

    Returns:
        The AxesImage created by `imshow`.
    """
    im = ax.imshow(image, aspect='auto', origin='lower', **imshow_kwargs)
    ax.set_title(title)
    ax.set_ylabel('Mel Frequency Bin')
    plt.colorbar(im, ax=ax)
    return im


# Plot entry 0 of the leading dimension (the channel axis of the loaded waveform).
# detach()/cpu() keep .numpy() working even if a tensor carries autograd history or
# lives on a non-CPU device.
torch_image = torch_output[0].detach().cpu().numpy()
custom_image = custom_output[0].detach().cpu().numpy()
difference = torch_output[0] - custom_output[0]
diff_image = difference.detach().cpu().numpy()

# Signed error: use symmetric limits so the diverging colormap puts zero at its neutral
# midpoint. Only finite values set the range; fall back to +/-1 when there is no nonzero
# finite difference (e.g. identical outputs), which would otherwise give vmin == vmax.
finite_abs = np.abs(diff_image[np.isfinite(diff_image)])
diff_limit = float(finite_abs.max()) if finite_abs.size and finite_abs.max() > 0 else 1.0

fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(12, 12), sharex=True)
im1 = plot_mel_panel(ax1, torch_image, 'Torchaudio Implementation')
im2 = plot_mel_panel(ax2, custom_image, 'Custom Implementation')
im3 = plot_mel_panel(
    ax3, diff_image, 'Difference (Torchaudio - Custom)',
    cmap='RdBu_r', vmin=-diff_limit, vmax=diff_limit,
)
ax3.set_xlabel('Time Frame')

plt.tight_layout()
plt.show()