## own function for fbank (to ensure I understand the process)

### torchaudio

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchaudio
import soundfile as sf


# Location of the preprocessed FLAC segment analysed in this notebook
file_path = '/home/bosfab01/SpeakerVerificationBA/data/preprocessed/0a4b5c0f-facc-4d3b-8a41-bc9148d62d95/0_segment_0.flac'
try:
    audio_signal, sample_rate = sf.read(file_path)
except Exception as e:
    # Report the problem, then let the original exception propagate
    print(f"An error occurred while reading the file: {e}")
    raise

# Time stamp (in seconds) of every sample, kept for plotting
time = np.arange(len(audio_signal)) / sample_rate

# NumPy array -> float32 PyTorch tensor in one step
audio_tensor = torch.from_numpy(audio_signal).float()
print("Data type of audio tensor:", audio_tensor.dtype)

# Prepend an axis of size 1 in front of the sample axis.
# NOTE(review): this assumes a mono recording (1-D array from sf.read) - confirm
# before using the cell with multi-channel files.
audio_tensor = torch.unsqueeze(audio_tensor, 0)
print("Shape of audio tensor:", audio_tensor.shape)

# Reference features with every option of interest spelled out explicitly
fbank_features = torchaudio.compliance.kaldi.fbank(
    audio_tensor,
    num_mel_bins=128,
    frame_shift=10,
    sample_frequency=sample_rate,
    window_type='hanning',
    dither=0.0,
    use_energy=False,
    htk_compat=True,
)

# Same call relying on the library defaults for the remaining options
fbank_features_few = torchaudio.compliance.kaldi.fbank(
    audio_tensor,
    num_mel_bins=128,
    sample_frequency=sample_rate,
    window_type='hanning',
)

# Report the feature matrix shape as a sanity check
print(f"Shape of fbank features: {fbank_features.shape}")

# Check whether the explicit and the default-based call agree
print(f"Are the two fbank features equal? {torch.allclose(fbank_features, fbank_features_few)}")

### own function for fbank

In [None]:
import math
from typing import Tuple
import torch
from torch import Tensor

# Machine epsilon of single precision (C++ numeric_limits<float>::epsilon(),
# 1.1920928955078125e-07), stored as a 0-dim tensor so it can be cast/moved.
EPSILON = torch.tensor(torch.finfo(torch.float32).eps)
# Factor that converts a duration in milliseconds into seconds
MILLISECONDS_TO_SECONDS = 1e-3


def _get_epsilon(device, dtype):
    """Return ``EPSILON`` converted to the given ``device`` and ``dtype``."""
    eps_on_target = EPSILON.to(dtype=dtype, device=device)
    return eps_on_target


def _next_power_of_2(x: int) -> int:
    r"""Return the smallest power of 2 that is greater than or equal to ``x``.

    ``x == 0`` is special-cased and yields 1.
    """
    if x == 0:
        return 1
    # (x - 1).bit_length() is the exponent of the first power of two >= x
    return 1 << (x - 1).bit_length()


def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edges: bool) -> Tensor:
    r"""Slice a 1D waveform into overlapping frames, one frame per row.

    Args:
        waveform (Tensor): Tensor of size ``num_samples``
        window_size (int): Number of samples per frame
        window_shift (int): Number of samples between the starts of consecutive frames
        snip_edges (bool): If True, only frames that fit completely inside the waveform are
            produced, so the frame count depends on ``window_size``. If False, the frame count
            depends only on ``window_shift`` and the waveform is extended with mirrored samples.

    Returns:
        Tensor: 2D tensor of size (m, ``window_size``); an empty (0, 0) tensor when
        ``snip_edges`` is True and the waveform is shorter than one window
    """
    assert waveform.dim() == 1
    num_samples = waveform.size(0)
    # Strides are read from the input before any padding is attached below
    sample_stride = waveform.stride(0)
    strides = (window_shift * sample_stride, sample_stride)

    if snip_edges:
        if num_samples < window_size:
            return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device)
        num_frames = 1 + (num_samples - window_size) // window_shift
        return waveform.as_strided((num_frames, window_size), strides)

    half_shift = window_shift // 2
    num_frames = (num_samples + half_shift) // window_shift
    pad = window_size // 2 - half_shift
    mirrored = torch.flip(waveform, [0])
    if pad > 0:
        # 'reflect' padding would give [2, 1, 0, 1, 2]; here the edge sample is
        # repeated instead: [2, 1, 0, 0, 1, 2]
        pieces = (mirrored[-pad:], waveform, mirrored)
    else:
        # pad <= 0: drop ``-pad`` samples from the front instead of padding
        pieces = (waveform[-pad:], mirrored)
    padded = torch.cat(pieces, dim=0)
    return padded.as_strided((num_frames, window_size), strides)


def _feature_window_function(
    window_size: int,
    device: torch.device,
    dtype: torch.dtype,
) -> Tensor:
    r"""Return a symmetric (non-periodic) Hann window.

    Args:
        window_size (int): Number of samples in the window
        device (torch.device): Device on which the window is created
        dtype (torch.dtype): Floating-point type of the window

    Returns:
        Tensor: 1D tensor of size ``window_size``
    """
    # periodic=False makes the window symmetric (first and last sample are equal)
    return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype)

def _get_log_energy(strided_input: Tensor, epsilon: Tensor, energy_floor: float) -> Tensor:
    r"""Return the per-frame log energy, size (m), of a ``strided_input`` of size (m, *).

    The energy of each row is floored at ``epsilon`` before the logarithm; when
    ``energy_floor`` is non-zero the result is additionally floored at ``log(energy_floor)``.
    """
    frame_energy = strided_input.pow(2).sum(1)  # size (m)
    log_energy = torch.max(frame_energy, epsilon).log()
    if energy_floor != 0.0:
        log_floor = torch.tensor(
            math.log(energy_floor), device=strided_input.device, dtype=strided_input.dtype
        )
        log_energy = torch.max(log_energy, log_floor)
    return log_energy


def _get_waveform_and_window_properties(
    waveform: Tensor,
    channel: int,
    sample_frequency: float,
    frame_shift: float,
    frame_length: float,
    round_to_power_of_two: bool,
    preemphasis_coefficient: float,
) -> Tuple[Tensor, int, int, int]:
    r"""Select one channel and derive the framing parameters in samples.

    Returns:
        (Tensor, int, int, int): the selected channel as a 1D tensor, ``window_shift``,
        ``window_size`` and ``padded_window_size`` (all three measured in samples)

    Raises:
        AssertionError: if the channel index, window size, window shift, padded window size,
        pre-emphasis coefficient or sample frequency is out of range
    """
    # Negative channel indices fall back to channel 0
    channel = max(channel, 0)
    num_channels = waveform.size(0)
    assert channel < num_channels, "Invalid channel {} for size {}".format(channel, num_channels)
    waveform = waveform[channel, :]  # size (n)

    # Durations are given in milliseconds; int() truncates to whole samples
    window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS)
    window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
    if round_to_power_of_two:
        padded_window_size = _next_power_of_2(window_size)
    else:
        padded_window_size = window_size

    num_samples = len(waveform)
    assert 2 <= window_size <= num_samples, "choose a window size {} that is [2, {}]".format(
        window_size, num_samples
    )
    assert 0 < window_shift, "`window_shift` must be greater than 0"
    assert (
        padded_window_size % 2 == 0
    ), "the padded `window_size` must be divisible by two. use `round_to_power_of_two` or change `frame_length`"
    assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]"
    assert sample_frequency > 0, "`sample_frequency` must be greater than zero"
    return waveform, window_shift, window_size, padded_window_size


def _get_window(
    waveform: Tensor,
    padded_window_size: int,
    window_size: int,
    window_shift: int,
    snip_edges: bool,
    raw_energy: bool,
    energy_floor: float,
    remove_dc_offset: bool,
    preemphasis_coefficient: float,
) -> Tuple[Tensor, Tensor]:
    r"""Frame the waveform, condition every frame and compute its log energy.

    Per frame, in order: optional mean removal, optional pre-emphasis, multiplication with the
    window from ``_feature_window_function`` and zero-padding up to ``padded_window_size``.

    Returns:
        (Tensor, Tensor): frames of size (m, ``padded_window_size``) and log energy of size (m).
        The log energy is taken right after mean removal when ``raw_energy`` is True, otherwise
        from the windowed and padded frames.
    """
    device, dtype = waveform.device, waveform.dtype
    epsilon = _get_epsilon(device, dtype)

    # size (m, window_size)
    frames = _get_strided(waveform, window_size, window_shift, snip_edges)

    if remove_dc_offset:
        # Remove the mean of every frame (row)
        frames = frames - frames.mean(dim=1, keepdim=True)

    if raw_energy:
        # Energy before pre-emphasis and windowing
        signal_log_energy = _get_log_energy(frames, epsilon, energy_floor)  # size (m)

    if preemphasis_coefficient != 0.0:
        # frames[i, j] -= preemphasis_coefficient * frames[i, max(0, j - 1)] for all i, j
        shifted = torch.nn.functional.pad(frames.unsqueeze(0), (1, 0), mode="replicate").squeeze(0)
        # shifted has size (m, window_size + 1); its last column is not needed
        frames = frames - preemphasis_coefficient * shifted[:, :-1]

    # Weight every frame with the window, broadcast from size (1, window_size)
    frames = frames * _feature_window_function(window_size, device, dtype).unsqueeze(0)

    # Zero-pad the columns on the right until size (m, padded_window_size) is reached
    padding_right = padded_window_size - window_size
    if padding_right != 0:
        frames = torch.nn.functional.pad(
            frames.unsqueeze(0), (0, padding_right), mode="constant", value=0
        ).squeeze(0)

    if not raw_energy:
        # Energy of the windowed (and padded) frames
        signal_log_energy = _get_log_energy(frames, epsilon, energy_floor)  # size (m)

    return frames, signal_log_energy


def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor:
    """Return ``tensor`` (size (m, n)) with its column means removed if ``subtract_mean`` is set.

    When ``subtract_mean`` is False the input tensor is returned unchanged.
    """
    if not subtract_mean:
        return tensor
    # keepdim=True keeps the means at size (1, n) so they broadcast over the rows
    return tensor - tensor.mean(dim=0, keepdim=True)



def inverse_mel_scale_scalar(mel_freq: float) -> float:
    """Map a single mel value back to frequency: ``700 * (exp(mel / 1127) - 1)``."""
    growth = math.exp(mel_freq / 1127.0)
    return 700.0 * (growth - 1.0)


def inverse_mel_scale(mel_freq: Tensor) -> Tensor:
    """Element-wise tensor version of ``700 * (exp(mel / 1127) - 1)``."""
    growth = torch.exp(mel_freq / 1127.0)
    return 700.0 * (growth - 1.0)


def mel_scale_scalar(freq: float) -> float:
    """Map a single frequency to the mel scale: ``1127 * ln(1 + freq / 700)``."""
    ratio = 1.0 + freq / 700.0
    return 1127.0 * math.log(ratio)


def mel_scale(freq: Tensor) -> Tensor:
    """Element-wise tensor version of ``1127 * ln(1 + freq / 700)``."""
    ratio = 1.0 + freq / 700.0
    return 1127.0 * torch.log(ratio)


def get_mel_banks(
    num_bins: int,
    window_length_padded: int,
    sample_freq: float,
    low_freq: float,
    high_freq: float,
    vtln_low: float,
    vtln_high: float,
) -> Tuple[Tensor, Tensor]:
    """Build triangular mel filters over the FFT bins.

    Args:
        num_bins (int): Number of triangular filters (at least 3)
        window_length_padded (int): FFT length in samples (must be even)
        sample_freq (float): Sample frequency of the signal
        low_freq (float): Lower edge of the first filter
        high_freq (float): Upper edge of the last filter; if <= 0 the Nyquist frequency is added
        vtln_low (float): Accepted for signature compatibility; no warping is applied here
        vtln_high (float): Accepted for signature compatibility; no warping is applied here

    Returns:
        (Tensor, Tensor): ``bins``, the filter weights of size (``num_bins``,
        ``window_length_padded // 2``), and ``center_freqs``, the center frequency of every
        filter with size (``num_bins``, 1).
    """
    # The message always said "at least 3"; the condition now agrees with it
    assert num_bins >= 3, "Must have at least 3 mel bins"
    assert window_length_padded % 2 == 0
    # Integer division is exact because evenness was asserted above
    num_fft_bins = window_length_padded // 2
    nyquist = 0.5 * sample_freq

    if high_freq <= 0.0:
        high_freq += nyquist

    assert (
        (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq)
    ), "Bad values in options: low-freq {} and high-freq {} vs. nyquist {}".format(low_freq, high_freq, nyquist)

    # fft-bin width [think of it as Nyquist-freq / half-window-length]
    fft_bin_width = sample_freq / window_length_padded
    mel_low_freq = mel_scale_scalar(low_freq)
    mel_high_freq = mel_scale_scalar(high_freq)

    # divide by num_bins + 1 because of end-effects where the bins spread out to the sides
    mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)

    # Filters are equally spaced on the mel axis; neighbouring triangles overlap by half
    bin_index = torch.arange(num_bins).unsqueeze(1)
    left_mel = mel_low_freq + bin_index * mel_freq_delta  # size (num_bins, 1)
    center_mel = mel_low_freq + (bin_index + 1.0) * mel_freq_delta  # size (num_bins, 1)
    right_mel = mel_low_freq + (bin_index + 2.0) * mel_freq_delta  # size (num_bins, 1)

    center_freqs = inverse_mel_scale(center_mel)  # size (num_bins, 1)
    # mel value of every FFT bin, size (1, num_fft_bins)
    mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins)).unsqueeze(0)

    # size (num_bins, num_fft_bins)
    up_slope = (mel - left_mel) / (center_mel - left_mel)
    down_slope = (right_mel - mel) / (right_mel - center_mel)

    # left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values
    bins = torch.max(torch.zeros(1), torch.min(up_slope, down_slope))

    return bins, center_freqs


def fbank_own(
    waveform: Tensor,
    channel: int = -1,
    energy_floor: float = 1.0,
    frame_length: float = 25.0,
    frame_shift: float = 10.0,
    high_freq: float = 0.0,
    low_freq: float = 20.0,
    min_duration: float = 0.0,
    num_mel_bins: int = 23,
    preemphasis_coefficient: float = 0.97,
    raw_energy: bool = True,
    remove_dc_offset: bool = True,
    round_to_power_of_two: bool = True,
    sample_frequency: float = 16000.0,
    snip_edges: bool = True,
    subtract_mean: bool = False,
    use_log_fbank: bool = True,
    use_power: bool = True,
    vtln_high: float = -500.0,
    vtln_low: float = 100.0,
) -> Tensor:
    r"""Create mel filterbank (fbank) features from a raw audio signal.

    Reduced re-implementation of a Kaldi-style fbank: there is no dithering, no energy
    column and no VTLN warping, and the analysis window is always the one returned by
    ``_feature_window_function``.

    Args:
        waveform (Tensor): Tensor of audio of size (c, n)
        channel (int, optional): Channel to extract; negative values select channel 0
            (Default: ``-1``)
        energy_floor (float, optional): Floor for the per-frame log energy. The log energy is
            computed but not part of the output, so this has no effect on the result.
            (Default: ``1.0``)
        frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
        frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
        high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
            (Default: ``0.0``)
        low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
        min_duration (float, optional): Minimum duration of segments to process (in seconds).
            (Default: ``0.0``)
        num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
        preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis
            (Default: ``0.97``)
        raw_energy (bool, optional): Selects where the (discarded) log energy is computed; has no
            effect on the result. (Default: ``True``)
        remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
        round_to_power_of_two (bool, optional): If True, round window size to power of two by
            zero-padding input to FFT. (Default: ``True``)
        sample_frequency (float, optional): Waveform data sample frequency (Default: ``16000.0``)
        snip_edges (bool, optional): If True, only frames that completely fit in the waveform are
            produced; if False the number of frames depends only on the frame_shift and the data
            is mirrored at the ends. (Default: ``True``)
        subtract_mean (bool, optional): Subtract the per-column mean from the output
            (Default: ``False``)
        use_log_fbank (bool, optional): If true, produce log-filterbank, else produce linear.
            (Default: ``True``)
        use_power (bool, optional): If true, use power, else use magnitude. (Default: ``True``)
        vtln_high (float, optional): Forwarded to ``get_mel_banks``; no VTLN warping is applied,
            so it does not influence the result. (Default: ``-500.0``)
        vtln_low (float, optional): Forwarded to ``get_mel_banks``; no VTLN warping is applied,
            so it does not influence the result. (Default: ``100.0``)

    Returns:
        Tensor: fbank features of size (m, ``num_mel_bins``) where m is calculated in
        ``_get_strided``, or an empty tensor of size (0) if the waveform is shorter than
        ``min_duration``.
    """
    device, dtype = waveform.device, waveform.dtype

    waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
        waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
    )

    if len(waveform) < min_duration * sample_frequency:
        # signal is too short
        return torch.empty(0, device=device, dtype=dtype)

    # strided_input, size (m, padded_window_size); the log energy is not used for fbank output
    strided_input, _ = _get_window(
        waveform,
        padded_window_size,
        window_size,
        window_shift,
        snip_edges,
        raw_energy,
        energy_floor,
        remove_dc_offset,
        preemphasis_coefficient,
    )

    # size (m, padded_window_size // 2 + 1)
    spectrum = torch.fft.rfft(strided_input).abs()
    if use_power:
        spectrum = spectrum.pow(2.0)

    # size (num_mel_bins, padded_window_size // 2)
    mel_filters, _ = get_mel_banks(
        num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high
    )
    mel_filters = mel_filters.to(device=device, dtype=dtype)

    # append a zero column so the filters cover all padded_window_size // 2 + 1 spectrum bins
    mel_filters = torch.nn.functional.pad(mel_filters, (0, 1), mode="constant", value=0)

    # weighted sum of the spectrum with every mel filter, size (m, num_mel_bins)
    mel_energies = torch.mm(spectrum, mel_filters.T)
    if use_log_fbank:
        # floor at epsilon to avoid log of zero
        mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()

    mel_energies = _subtract_column_mean(mel_energies, subtract_mean)
    return mel_energies

In [None]:
# call fbank_own

import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchaudio
import soundfile as sf


# Location of the preprocessed FLAC segment analysed in this notebook
file_path = '/home/bosfab01/SpeakerVerificationBA/data/preprocessed/0a4b5c0f-facc-4d3b-8a41-bc9148d62d95/0_segment_0.flac'
try:
    audio_signal, sample_rate = sf.read(file_path)
except Exception as e:
    # Report the problem, then let the original exception propagate
    print(f"An error occurred while reading the file: {e}")
    raise

# Time stamp (in seconds) of every sample, kept for plotting
time = np.arange(len(audio_signal)) / sample_rate

# NumPy array -> float32 PyTorch tensor in one step
audio_tensor = torch.from_numpy(audio_signal).float()
print("Data type of audio tensor:", audio_tensor.dtype)

# Prepend an axis of size 1 in front of the sample axis.
# NOTE(review): this assumes a mono recording (1-D array from sf.read) - confirm
# before using the cell with multi-channel files.
audio_tensor = torch.unsqueeze(audio_tensor, 0)
print("Shape of audio tensor:", audio_tensor.shape)

# Features from the notebook's own implementation, same bin count and frame
# shift as the torchaudio reference call
fbank_features_own = fbank_own(
    audio_tensor,
    frame_shift=10,
    num_mel_bins=128,
    sample_frequency=sample_rate,
)

In [None]:
# Plot both feature matrices as images for a visual comparison.
# Each matrix is transposed before plotting, so its first axis (frames) runs
# along x and its second axis (mel bins) along y.

# plot the fbank features from the own implementation
plt.figure(figsize=(10, 1.5))
plt.imshow(fbank_features_own.T, aspect='auto', origin='lower', cmap='viridis')
plt.title('Own fbank Features')

# plot the fbank features returned by torchaudio.compliance.kaldi.fbank
plt.figure(figsize=(10, 1.5))
plt.imshow(fbank_features.T, aspect='auto', origin='lower', cmap='viridis')
plt.title('Kaldi fbank Features')



In [None]:
# Overlay the frame at index 90 of both feature matrices to compare them bin by bin
# (assumes both matrices contain at least 91 frames)
plt.plot(fbank_features[90, :], label='Kaldi')
plt.plot(fbank_features_own[90, :], label='Own')
plt.legend()


### remove more

In [None]:
import math
from typing import Tuple
import torch
from torch import Tensor

# Machine epsilon of single precision (C++ numeric_limits<float>::epsilon(),
# 1.1920928955078125e-07), stored as a 0-dim tensor so it can be cast/moved.
EPSILON = torch.tensor(torch.finfo(torch.float32).eps)
# Factor that converts a duration in milliseconds into seconds
MILLISECONDS_TO_SECONDS = 1e-3


def _get_epsilon(device, dtype):
    """Return ``EPSILON`` converted to the given ``device`` and ``dtype``."""
    eps_on_target = EPSILON.to(dtype=dtype, device=device)
    return eps_on_target


def _next_power_of_2(x: int) -> int:
    r"""Return the smallest power of 2 that is greater than or equal to ``x``.

    ``x == 0`` is special-cased and yields 1.
    """
    if x == 0:
        return 1
    # (x - 1).bit_length() is the exponent of the first power of two >= x
    return 1 << (x - 1).bit_length()


def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edges: bool) -> Tensor:
    r"""Slice a 1D waveform into overlapping frames, one frame per row.

    Args:
        waveform (Tensor): Tensor of size ``num_samples``
        window_size (int): Number of samples per frame
        window_shift (int): Number of samples between the starts of consecutive frames
        snip_edges (bool): If True, only frames that fit completely inside the waveform are
            produced, so the frame count depends on ``window_size``. If False, the frame count
            depends only on ``window_shift`` and the waveform is extended with mirrored samples.

    Returns:
        Tensor: 2D tensor of size (m, ``window_size``); an empty (0, 0) tensor when
        ``snip_edges`` is True and the waveform is shorter than one window
    """
    assert waveform.dim() == 1
    num_samples = waveform.size(0)
    # Strides are read from the input before any padding is attached below
    sample_stride = waveform.stride(0)
    strides = (window_shift * sample_stride, sample_stride)

    if snip_edges:
        if num_samples < window_size:
            return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device)
        num_frames = 1 + (num_samples - window_size) // window_shift
        return waveform.as_strided((num_frames, window_size), strides)

    half_shift = window_shift // 2
    num_frames = (num_samples + half_shift) // window_shift
    pad = window_size // 2 - half_shift
    mirrored = torch.flip(waveform, [0])
    if pad > 0:
        # 'reflect' padding would give [2, 1, 0, 1, 2]; here the edge sample is
        # repeated instead: [2, 1, 0, 0, 1, 2]
        pieces = (mirrored[-pad:], waveform, mirrored)
    else:
        # pad <= 0: drop ``-pad`` samples from the front instead of padding
        pieces = (waveform[-pad:], mirrored)
    padded = torch.cat(pieces, dim=0)
    return padded.as_strided((num_frames, window_size), strides)


def _feature_window_function(
    window_size: int,
    device: torch.device,
    dtype: torch.dtype,
) -> Tensor:
    r"""Return a symmetric (non-periodic) Hann window.

    Args:
        window_size (int): Number of samples in the window
        device (torch.device): Device on which the window is created
        dtype (torch.dtype): Floating-point type of the window

    Returns:
        Tensor: 1D tensor of size ``window_size``
    """
    # periodic=False makes the window symmetric (first and last sample are equal)
    return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype)

def _get_log_energy(strided_input: Tensor, epsilon: Tensor, energy_floor: float) -> Tensor:
    r"""Return the per-frame log energy, size (m), of a ``strided_input`` of size (m, *).

    The energy of each row is floored at ``epsilon`` before the logarithm; when
    ``energy_floor`` is non-zero the result is additionally floored at ``log(energy_floor)``.
    """
    frame_energy = strided_input.pow(2).sum(1)  # size (m)
    log_energy = torch.max(frame_energy, epsilon).log()
    if energy_floor != 0.0:
        log_floor = torch.tensor(
            math.log(energy_floor), device=strided_input.device, dtype=strided_input.dtype
        )
        log_energy = torch.max(log_energy, log_floor)
    return log_energy


def _get_waveform_and_window_properties(
    waveform: Tensor,
    channel: int,
    sample_frequency: float,
    frame_shift: float,
    frame_length: float,
    round_to_power_of_two: bool,
    preemphasis_coefficient: float,
) -> Tuple[Tensor, int, int, int]:
    r"""Select one channel and derive the framing parameters in samples.

    Returns:
        (Tensor, int, int, int): the selected channel as a 1D tensor, ``window_shift``,
        ``window_size`` and ``padded_window_size`` (all three measured in samples)

    Raises:
        AssertionError: if the channel index, window size, window shift, padded window size,
        pre-emphasis coefficient or sample frequency is out of range
    """
    # Negative channel indices fall back to channel 0
    channel = max(channel, 0)
    num_channels = waveform.size(0)
    assert channel < num_channels, "Invalid channel {} for size {}".format(channel, num_channels)
    waveform = waveform[channel, :]  # size (n)

    # Durations are given in milliseconds; int() truncates to whole samples
    window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS)
    window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
    if round_to_power_of_two:
        padded_window_size = _next_power_of_2(window_size)
    else:
        padded_window_size = window_size

    num_samples = len(waveform)
    assert 2 <= window_size <= num_samples, "choose a window size {} that is [2, {}]".format(
        window_size, num_samples
    )
    assert 0 < window_shift, "`window_shift` must be greater than 0"
    assert (
        padded_window_size % 2 == 0
    ), "the padded `window_size` must be divisible by two. use `round_to_power_of_two` or change `frame_length`"
    assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]"
    assert sample_frequency > 0, "`sample_frequency` must be greater than zero"
    return waveform, window_shift, window_size, padded_window_size


def _get_window(
    waveform: Tensor,
    padded_window_size: int,
    window_size: int,
    window_shift: int,
    snip_edges: bool,
    raw_energy: bool,
    energy_floor: float,
    preemphasis_coefficient: float,
) -> Tuple[Tensor, Tensor]:
    r"""Frame the waveform, condition every frame and compute its log energy.

    Per frame, in order: mean removal, optional pre-emphasis, multiplication with the window
    from ``_feature_window_function`` and zero-padding up to ``padded_window_size``.

    Returns:
        (Tensor, Tensor): frames of size (m, ``padded_window_size``) and log energy of size (m).
        The log energy is taken right after mean removal when ``raw_energy`` is True, otherwise
        from the windowed and padded frames.
    """
    device, dtype = waveform.device, waveform.dtype
    epsilon = _get_epsilon(device, dtype)

    # size (m, window_size)
    frames = _get_strided(waveform, window_size, window_shift, snip_edges)

    # Remove the mean of every frame (row)
    frames = frames - frames.mean(dim=1, keepdim=True)

    if raw_energy:
        # Energy before pre-emphasis and windowing
        signal_log_energy = _get_log_energy(frames, epsilon, energy_floor)  # size (m)

    if preemphasis_coefficient != 0.0:
        # frames[i, j] -= preemphasis_coefficient * frames[i, max(0, j - 1)] for all i, j
        shifted = torch.nn.functional.pad(frames.unsqueeze(0), (1, 0), mode="replicate").squeeze(0)
        # shifted has size (m, window_size + 1); its last column is not needed
        frames = frames - preemphasis_coefficient * shifted[:, :-1]

    # Weight every frame with the window, broadcast from size (1, window_size)
    frames = frames * _feature_window_function(window_size, device, dtype).unsqueeze(0)

    # Zero-pad the columns on the right until size (m, padded_window_size) is reached
    padding_right = padded_window_size - window_size
    if padding_right != 0:
        frames = torch.nn.functional.pad(
            frames.unsqueeze(0), (0, padding_right), mode="constant", value=0
        ).squeeze(0)

    if not raw_energy:
        # Energy of the windowed (and padded) frames
        signal_log_energy = _get_log_energy(frames, epsilon, energy_floor)  # size (m)

    return frames, signal_log_energy


def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor:
    """Return ``tensor`` (size (m, n)) with its column means removed if ``subtract_mean`` is set.

    When ``subtract_mean`` is False the input tensor is returned unchanged.
    """
    if not subtract_mean:
        return tensor
    # keepdim=True keeps the means at size (1, n) so they broadcast over the rows
    return tensor - tensor.mean(dim=0, keepdim=True)



def inverse_mel_scale_scalar(mel_freq: float) -> float:
    """Map a single mel value back to frequency: ``700 * (exp(mel / 1127) - 1)``."""
    growth = math.exp(mel_freq / 1127.0)
    return 700.0 * (growth - 1.0)


def inverse_mel_scale(mel_freq: Tensor) -> Tensor:
    """Element-wise tensor version of ``700 * (exp(mel / 1127) - 1)``."""
    growth = torch.exp(mel_freq / 1127.0)
    return 700.0 * (growth - 1.0)


def mel_scale_scalar(freq: float) -> float:
    """Map a single frequency to the mel scale: ``1127 * ln(1 + freq / 700)``."""
    ratio = 1.0 + freq / 700.0
    return 1127.0 * math.log(ratio)


def mel_scale(freq: Tensor) -> Tensor:
    """Element-wise tensor version of ``1127 * ln(1 + freq / 700)``."""
    ratio = 1.0 + freq / 700.0
    return 1127.0 * torch.log(ratio)


def get_mel_banks(
    num_bins: int,
    window_length_padded: int,
    sample_freq: float,
    low_freq: float,
    high_freq: float,
) -> Tuple[Tensor, Tensor]:
    """Build triangular mel filters over the FFT bins.

    Args:
        num_bins (int): Number of triangular filters
        window_length_padded (int): FFT length in samples
        sample_freq (float): Sample frequency of the signal
        low_freq (float): Lower edge of the first filter
        high_freq (float): Upper edge of the last filter; if <= 0 the Nyquist frequency is added

    Returns:
        (Tensor, Tensor): ``bins``, the filter weights with one row per filter and one column
        per FFT bin, and ``center_freqs``, the center frequency of every filter with size
        (``num_bins``, 1).
    """
    # Kept as true division, exactly as before; torch.arange accepts the float end value
    num_fft_bins = window_length_padded / 2
    nyquist = 0.5 * sample_freq

    if high_freq <= 0.0:
        high_freq += nyquist

    # Width of one FFT bin [think of it as Nyquist-freq / half-window-length]
    fft_bin_width = sample_freq / window_length_padded
    mel_low_freq = mel_scale_scalar(low_freq)
    mel_high_freq = mel_scale_scalar(high_freq)

    # num_bins + 1 intervals, because the outermost triangles spread out to the sides
    mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)

    # Corner points of every triangle on the mel axis, each of size (num_bins, 1)
    filter_index = torch.arange(num_bins).unsqueeze(1)
    left_mel = mel_low_freq + filter_index * mel_freq_delta
    center_mel = mel_low_freq + (filter_index + 1.0) * mel_freq_delta
    right_mel = mel_low_freq + (filter_index + 2.0) * mel_freq_delta

    center_freqs = inverse_mel_scale(center_mel)
    # mel value of every FFT bin, size (1, num_fft_bins)
    fft_bin_mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins)).unsqueeze(0)

    # Rising and falling edge of every triangle, size (num_bins, num_fft_bins)
    rising = (fft_bin_mel - left_mel) / (center_mel - left_mel)
    falling = (right_mel - fft_bin_mel) / (right_mel - center_mel)

    # left_mel < center_mel < right_mel, so the lower edge clamped at zero is the triangle
    bins = torch.max(torch.zeros(1), torch.min(rising, falling))

    return bins, center_freqs


def fbank_own(
    waveform: Tensor,
    channel: int = -1,
    energy_floor: float = 1.0,
    frame_length: float = 25.0,
    frame_shift: float = 10.0,
    high_freq: float = 0.0,
    low_freq: float = 20.0,
    min_duration: float = 0.0,
    num_mel_bins: int = 128,
    preemphasis_coefficient: float = 0.97,
    raw_energy: bool = True,
    round_to_power_of_two: bool = True,
    sample_frequency: float = 16000.0,
    snip_edges: bool = True,
    subtract_mean: bool = False,
    use_log_fbank: bool = True,
    use_power: bool = True,
) -> Tensor:
    r"""Create mel filterbank (fbank) features from a raw audio signal.

    Args:
        waveform (Tensor): Tensor of audio of size (c, n)
        channel (int, optional): Channel to extract; negative values select channel 0
        energy_floor (float, optional): Floor for the per-frame log energy, which is computed but
            not part of the output
        frame_length (float, optional): Frame length in milliseconds
        frame_shift (float, optional): Frame shift in milliseconds
        high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
        low_freq (float, optional): Low cutoff frequency for mel bins
        min_duration (float, optional): Minimum duration of segments to process (in seconds)
        num_mel_bins (int, optional): Number of triangular mel-frequency bins
        preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis
        raw_energy (bool, optional): Selects where the (unused) log energy is computed
        round_to_power_of_two (bool, optional): If True, zero-pad frames to a power-of-two length
        sample_frequency (float, optional): Waveform data sample frequency
        snip_edges (bool, optional): If True, only frames that completely fit are produced
        subtract_mean (bool, optional): Subtract the per-column mean from the output
        use_log_fbank (bool, optional): If true, produce log-filterbank, else produce linear
        use_power (bool, optional): If true, use power, else use magnitude

    Returns:
        Tensor: features of size (m, ``num_mel_bins``), or an empty tensor of size (0) if the
        waveform is shorter than ``min_duration``
    """
    device, dtype = waveform.device, waveform.dtype

    waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
        waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
    )

    # Bail out early when the signal is too short
    if len(waveform) < min_duration * sample_frequency:
        return torch.empty(0, device=device, dtype=dtype)

    # frames of size (m, padded_window_size); the per-frame log energy is not needed here
    frames, _ = _get_window(
        waveform,
        padded_window_size,
        window_size,
        window_shift,
        snip_edges,
        raw_energy,
        energy_floor,
        preemphasis_coefficient,
    )

    # Magnitude (or power) spectrum, size (m, padded_window_size // 2 + 1)
    spectrum = torch.fft.rfft(frames).abs()
    if use_power:
        spectrum = spectrum.pow(2.0)

    # Triangular filters, size (num_mel_bins, padded_window_size // 2)
    mel_filters, _ = get_mel_banks(
        num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq
    )
    mel_filters = mel_filters.to(device=device, dtype=dtype)

    # Append a zero column so the filters cover all padded_window_size // 2 + 1 spectrum bins
    mel_filters = torch.nn.functional.pad(mel_filters, (0, 1), mode="constant", value=0)

    # Weighted sum of the spectrum with every mel filter, size (m, num_mel_bins)
    mel_energies = torch.mm(spectrum, mel_filters.T)
    if use_log_fbank:
        # Floor at epsilon to avoid log of zero
        mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()

    return _subtract_column_mean(mel_energies, subtract_mean)

### remove as much as possible

In [None]:
import math
from typing import Tuple
import torch
from torch import Tensor


def fbank_own(
    waveform: Tensor,
) -> Tensor:
    """Compute Kaldi-style log mel filter-bank features with all options frozen.

    Fixed configuration: 16 kHz audio, 400-sample frames shifted by 160 samples,
    per-frame DC removal, pre-emphasis 0.97, symmetric Hann window, zero-padding
    to 512 samples, power spectrum, 128 triangular mel filters between 20 Hz and
    8 kHz, natural log floored at float32 epsilon.

    Args:
        waveform: mono floating-point signal of shape [n] or [1, n].

    Returns:
        Tensor of shape [num_frames, 128] with
        num_frames = 1 + (n - 400) // 160 (998 for n = 160000). A signal shorter
        than one frame yields an empty tensor of shape [0, 128].

    Raises:
        ValueError: if the input is not one-dimensional after squeezing.
    """
    frame_length = 400      # 25 ms at 16 kHz
    frame_shift = 160       # 10 ms at 16 kHz
    padded_length = 512     # next power of two >= frame_length
    num_mel_bins = 128
    sample_rate = 16000.0
    low_freq = 20.0
    preemphasis = 0.97

    device, dtype = waveform.device, waveform.dtype

    # shape is [c, n] (=[1, n] in case of mono); drop the channel axis -> [n]
    waveform = torch.squeeze(waveform)
    if waveform.dim() != 1:
        raise ValueError(
            f"expected a mono waveform of shape [n] or [1, n], got {tuple(waveform.shape)}"
        )

    if waveform.shape[0] < frame_length:
        # not even one complete frame available
        return torch.empty((0, num_mel_bins), device=device, dtype=dtype)

    def get_window(signal: Tensor) -> Tensor:
        """Cut the signal into frames and condition each one; size (m, padded_length)."""
        # unfold derives the frame count from the signal length and honours the
        # tensor's real strides, unlike a hand-written as_strided call
        frames = signal.unfold(0, frame_length, frame_shift)  # size (m, 400)

        # Subtract each row/frame by its mean
        frames = frames - frames.mean(dim=1, keepdim=True)  # size (m, 400)

        # frames[i, j] -= preemphasis * frames[i, max(0, j - 1)] for all i, j
        previous = torch.nn.functional.pad(
            frames.unsqueeze(0), (1, 0), mode="replicate"
        ).squeeze(0)  # size (m, 400 + 1)
        frames = frames - preemphasis * previous[:, :-1]  # size (m, 400)

        # Apply the window function to each row/frame
        window_function = torch.hann_window(
            frame_length, periodic=False, device=device, dtype=dtype
        ).unsqueeze(0)  # size (1, 400)
        frames = frames * window_function  # size (m, 400)

        # Zero-pad every frame on the right up to the FFT length (512 - 400 = 112)
        return torch.nn.functional.pad(
            frames, (0, padded_length - frame_length), mode="constant", value=0
        )

    def get_mel_banks(num_bins: int) -> Tensor:
        """Triangular mel filter weights, size (num_bins, padded_length // 2)."""

        def mel_scale_scalar(freq: float) -> float:
            return 1127.0 * math.log(1.0 + freq / 700.0)

        def mel_scale(freq: Tensor) -> Tensor:
            return 1127.0 * (1.0 + freq / 700.0).log()

        num_fft_bins = padded_length // 2  # 256
        high_freq = sample_rate / 2.0  # Nyquist frequency, 8000 Hz

        # fft-bin width [think of it as Nyquist-freq / half-window-length]
        fft_bin_width = sample_rate / padded_length  # 31.25
        mel_low_freq = mel_scale_scalar(low_freq)
        mel_high_freq = mel_scale_scalar(high_freq)

        # divide by num_bins + 1 because of end-effects where the bins spread out to the sides
        mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)

        bin_index = torch.arange(num_bins).unsqueeze(1)
        left_mel = mel_low_freq + bin_index * mel_freq_delta  # size (num_bins, 1)
        center_mel = mel_low_freq + (bin_index + 1.0) * mel_freq_delta  # size (num_bins, 1)
        right_mel = mel_low_freq + (bin_index + 2.0) * mel_freq_delta  # size (num_bins, 1)

        mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins)).unsqueeze(0)  # size (1, num_fft_bins)

        up_slope = (mel - left_mel) / (center_mel - left_mel)  # size (num_bins, num_fft_bins)
        down_slope = (right_mel - mel) / (right_mel - center_mel)  # size (num_bins, num_fft_bins)

        # left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values
        return torch.max(torch.zeros(1), torch.min(up_slope, down_slope))

    strided_input = get_window(waveform)  # size (m, 512)

    # power spectrum, size (m, padded_length // 2 + 1)
    spectrum = torch.fft.rfft(strided_input).abs().pow(2.0)

    # filter weights, size (num_mel_bins, padded_length // 2)
    mel_filters = get_mel_banks(num_mel_bins).to(device=device, dtype=dtype)

    # add a zero column for the Nyquist bin, size (num_mel_bins, padded_length // 2 + 1)
    mel_filters = torch.nn.functional.pad(mel_filters, (0, 1), mode="constant", value=0)

    # sum with mel filterbanks over the power spectrum, size (m, num_mel_bins)
    mel_energies = torch.mm(spectrum, mel_filters.T)

    # avoid log of zero (no dithering is applied here, so silent frames would hit it)
    floor = torch.tensor(torch.finfo(torch.float).eps, device=device, dtype=dtype)
    return torch.max(mel_energies, floor).log()

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchaudio
import soundfile as sf


# Read the example segment; report and re-raise if it cannot be loaded
file_path = '/home/bosfab01/SpeakerVerificationBA/data/preprocessed/0a4b5c0f-facc-4d3b-8a41-bc9148d62d95/0_segment_0.flac'
try:
    audio_signal, sample_rate = sf.read(file_path)
except Exception as read_error:
    print(f"An error occurred while reading the file: {read_error}")
    raise

# Time stamp (in seconds) of every sample
time = np.arange(audio_signal.shape[0]) / sample_rate

# float32 torch copy of the signal -- expected shape [160000]
audio_tensor = torch.from_numpy(audio_signal).float()

# Leading channel axis for the single-channel, unbatched signal -- expected shape [1, 160000]
audio_tensor_batch = audio_tensor.unsqueeze(dim=0)

# Reference features from torchaudio's Kaldi-compatible implementation
fbank_features_torch = torchaudio.compliance.kaldi.fbank(
    audio_tensor_batch,
    num_mel_bins=128,
    frame_shift=10,
    window_type='hanning',
    dither=0.0,
    htk_compat=True,
    use_energy=False,
    sample_frequency=sample_rate,
)

# Features from the hand-written implementation -- expected shape [998, 128]
fbank_features_own = fbank_own(audio_tensor_batch)

# Show both feature maps (mel bins on the vertical axis, frames on the horizontal axis)
plt.figure(figsize=(10, 1.5))
plt.imshow(fbank_features_own.T, cmap='viridis', origin='lower', aspect='auto')
plt.title('Own fbank Features')

plt.figure(figsize=(10, 1.5))
plt.imshow(fbank_features_torch.T, cmap='viridis', origin='lower', aspect='auto')
plt.title('Kaldi fbank Features')

plt.show()

# Overlay the feature vectors at frame index 90 from both implementations
plt.plot(fbank_features_torch[90], label='Kaldi')
plt.plot(fbank_features_own[90], label='Own')
plt.legend()

### convert to NumPy (no torch dependency)

In [None]:
import numpy as np

def fbank_own(waveform, sample_rate=16000, num_mel_bins=128, frame_shift=10, frame_length=25):
    """Compute Kaldi-style log mel filter-bank features using NumPy only.

    Per frame: DC-offset removal, pre-emphasis (0.97), symmetric Hann window,
    zero-padding to the next power of two, power spectrum, triangular mel
    filters between 20 Hz and the Nyquist frequency, natural log floored at
    float32 epsilon.

    Args:
        waveform: 1D array-like of samples; any real dtype, processed as float64.
        sample_rate: sampling rate in Hz.
        num_mel_bins: number of triangular mel filters.
        frame_shift: hop between consecutive frames in milliseconds.
        frame_length: length of one frame in milliseconds.

    Returns:
        Array of shape (num_frames, num_mel_bins) with
        num_frames = 1 + (len(waveform) - window_size) // stride
        (998 x 128 for 160000 samples and the defaults). A signal shorter than
        one frame yields an empty array of shape (0, num_mel_bins).

    Raises:
        ValueError: if waveform is not 1D, or if the frame shift is below one
            sample or the frame length below two samples.
    """
    preemphasis_coefficient = 0.97
    low_freq = 20.0
    log_floor = float(np.finfo(np.float32).eps)  # 1.1920929e-07

    # Work on a float64 copy/view so integer input does not break the arithmetic
    waveform = np.asarray(waveform, dtype=np.float64)
    if waveform.ndim != 1:
        raise ValueError(f"waveform must be 1D, got shape {waveform.shape}")

    stride = int(sample_rate * frame_shift) // 1000  # 160 with the defaults
    window_size = int(sample_rate * frame_length) // 1000  # 400 with the defaults
    if stride < 1 or window_size < 2:
        raise ValueError(
            "frame_shift must cover at least one sample and frame_length at least two"
        )

    # Next power of two >= window_size (512 with the defaults)
    padded_window_size = 2 ** (window_size - 1).bit_length()

    number_of_frames = (waveform.shape[0] - window_size) // stride + 1
    if number_of_frames < 1:
        # signal is shorter than a single frame
        return np.empty((0, num_mel_bins))

    def get_window(signal):
        """Slice the signal into conditioned, zero-padded frames."""
        # Explicit index arithmetic: independent of the dtype/strides of `signal`
        frame_starts = stride * np.arange(number_of_frames)
        sample_index = frame_starts[:, None] + np.arange(window_size)[None, :]
        frames = signal[sample_index]  # shape = [frames, window_size], always a copy

        # Subtract each row/frame by its mean
        frames = frames - frames.mean(axis=1, keepdims=True)

        # Pre-emphasis: x[j] -= c * x[max(0, j - 1)], using the un-filtered neighbours
        previous = np.concatenate((frames[:, :1], frames[:, :-1]), axis=1)
        frames = frames - preemphasis_coefficient * previous

        # Apply the (symmetric) Hanning window to each row/frame
        frames = frames * np.hanning(window_size)

        # Zero-pad each frame on the right up to the FFT length
        return np.pad(frames, ((0, 0), (0, padded_window_size - window_size)), 'constant')

    def get_mel_banks(num_bins):
        """Triangular mel filter weights, shape = [num_bins, padded_window_size // 2 + 1]."""
        num_fft_bins = padded_window_size // 2
        nyquist_freq = sample_rate / 2.0
        fft_bin_width = nyquist_freq / num_fft_bins

        def mel_scale(freq):
            return 1127.0 * np.log(1.0 + freq / 700.0)

        mel_low_freq = mel_scale(low_freq)
        mel_high_freq = mel_scale(nyquist_freq)
        # num_bins + 1 intervals: neighbouring triangles overlap by half their width
        mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)

        left_mel = mel_low_freq + np.arange(num_bins)[:, None] * mel_freq_delta
        center_mel = left_mel + mel_freq_delta
        right_mel = center_mel + mel_freq_delta

        # Mel value of every FFT bin, computed once for all filters
        mel_freq = mel_scale(np.arange(num_fft_bins + 1) * fft_bin_width)[None, :]

        rising = (mel_freq - left_mel) / (center_mel - left_mel)
        falling = (right_mel - mel_freq) / (right_mel - center_mel)
        inside = (left_mel < mel_freq) & (mel_freq < right_mel)
        return np.where(inside, np.where(mel_freq <= center_mel, rising, falling), 0.0)

    strided_input = get_window(waveform)  # shape = [frames, padded_window_size]

    # Power spectrum, shape = [frames, padded_window_size // 2 + 1]
    spectrum = np.abs(np.fft.rfft(strided_input, n=padded_window_size)) ** 2

    mel_filters = get_mel_banks(num_mel_bins)

    # Filter bank energies, shape = [frames, num_mel_bins]
    filter_bank_energies = np.dot(spectrum, mel_filters.T)

    # Log energies, floored to avoid log(0)
    return np.log(np.maximum(filter_bank_energies, log_floor))

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchaudio
import soundfile as sf


# Read the example segment; report and re-raise if it cannot be loaded
file_path = '/home/bosfab01/SpeakerVerificationBA/data/preprocessed/0a4b5c0f-facc-4d3b-8a41-bc9148d62d95/0_segment_0.flac'
try:
    audio_signal, sample_rate = sf.read(file_path)
except Exception as read_error:
    print(f"An error occurred while reading the file: {read_error}")
    raise

# Time stamp (in seconds) of every sample
time = np.arange(audio_signal.shape[0]) / sample_rate

# float32 torch copy of the signal -- expected shape [160000]
audio_tensor = torch.from_numpy(audio_signal).float()

# Leading channel axis for the single-channel, unbatched signal -- expected shape [1, 160000]
audio_tensor_batch = audio_tensor.unsqueeze(dim=0)

# Reference features from torchaudio: 10 mel bins, 250 ms frames, 100 ms shift
fbank_features_torch = torchaudio.compliance.kaldi.fbank(
    audio_tensor_batch,
    num_mel_bins=10,
    frame_length=250,
    frame_shift=100,
    window_type='hanning',
    dither=0.0,
    htk_compat=True,
    use_energy=False,
    sample_frequency=sample_rate,
)

# Same settings through the NumPy implementation, fed the raw array from soundfile.
# With 160000 samples this should give 98 frames of 10 bins each.
fbank_features_own = fbank_own(
    audio_signal,
    frame_length=250,
    frame_shift=100,
    num_mel_bins=10,
)

# Show both feature maps (mel bins on the vertical axis, frames on the horizontal axis)
plt.figure(figsize=(10, 1.5))
plt.imshow(fbank_features_own.T, cmap='viridis', origin='lower', aspect='auto')
plt.title('Own fbank Features')

plt.figure(figsize=(10, 1.5))
plt.imshow(fbank_features_torch.T, cmap='viridis', origin='lower', aspect='auto')
plt.title('Kaldi fbank Features')

plt.show()

# Overlay the feature vectors at frame index 90 from both implementations
plt.plot(fbank_features_torch[90], label='Kaldi')
plt.plot(fbank_features_own[90], label='Own')
plt.legend()

In [None]:
# Display the float32 machine limits; the `eps` entry is the value the torch
# fbank code above uses as the floor before taking the log.
torch.finfo(torch.float)

### visualizing the mel frequency filter bank

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def get_mel_banks(M, W_pad, f_s):
    """Build the triangular mel filter-bank matrix H and print its shape.

    The filters are equally spaced on the mel scale between 20 Hz and the
    Nyquist frequency; neighbouring triangles overlap by half their width.

    Args:
        M: number of mel filters (rows of H).
        W_pad: padded FFT window length in samples.
        f_s: sampling rate in Hz.

    Returns:
        Array H of shape (M, W_pad // 2 + 1); H[i, j] is the weight of FFT
        bin j in mel filter i, between 0 and 1.

    Raises:
        ValueError: if W_pad is smaller than 2 (no positive-frequency FFT bin).
    """
    if W_pad < 2:
        raise ValueError("W_pad must be at least 2")

    f_nyq = f_s / 2.0

    f_low = 20.0
    f_high = f_nyq

    fft_bin_width = f_nyq / (W_pad // 2)

    # Mel scale conversion
    def m_fun(f):
        return 1127.0 * np.log(1.0 + f / 700.0)

    m_low = m_fun(f_low)
    m_high = m_fun(f_high)
    # M + 1 intervals because each triangle spans two of them
    m_delta = (m_high - m_low) / (M + 1)

    # Edges and centre of every filter on the mel axis, shape (M, 1)
    m_left = m_low + np.arange(M)[:, None] * m_delta
    m_center = m_left + m_delta
    m_right = m_center + m_delta

    # Mel value of every FFT bin, computed once for all filters, shape (1, W_pad // 2 + 1)
    m = m_fun(np.arange(W_pad // 2 + 1) * fft_bin_width)[None, :]

    rising = (m - m_left) / (m_center - m_left)
    falling = (m_right - m) / (m_right - m_center)
    inside = (m_left < m) & (m < m_right)

    # Rising edge up to the centre, falling edge after it, zero outside the triangle
    H = np.where(inside, np.where(m <= m_center, rising, falling), 0.0)
    print("shape of the H matrix:", H.shape)
    return H

# Visualization setup: 8 mel filters, 512-sample padded window, 16 kHz sampling rate
M, W_pad, f_s = 8, 512, 16000

# Mel filter bank matrix (one row per mel filter, one column per FFT bin)
H = get_mel_banks(M=M, W_pad=W_pad, f_s=f_s)

# Heat map of the filter weights
plt.figure(figsize=(10, 4))
plt.imshow(H, cmap='hot', interpolation='nearest', origin='lower', aspect='auto')
plt.colorbar(label='Filter bank coefficient')
plt.title('Mel Filter Bank')
plt.xlabel('FFT bins')
plt.ylabel('Mel bins')
plt.show()

In [None]:

import numpy as np
import matplotlib.pyplot as plt


def get_mel_banks(num_bins, padded_window_size, sample_rate):
    """Build a matrix of triangular mel filters.

    The filters are equally spaced on the mel scale between 20 Hz and the
    Nyquist frequency; neighbouring triangles overlap by half their width.

    Args:
        num_bins: number of mel filters (rows of the result).
        padded_window_size: padded FFT window length in samples.
        sample_rate: sampling rate in Hz.

    Returns:
        Array of shape (num_bins, padded_window_size // 2 + 1); entry [i, j]
        is the weight of FFT bin j in mel filter i, between 0 and 1.

    Raises:
        ValueError: if padded_window_size is smaller than 2.
    """
    num_fft_bins = padded_window_size // 2
    if num_fft_bins < 1:
        raise ValueError("padded_window_size must be at least 2")

    nyquist_freq = sample_rate / 2.0

    low_freq = 20.0
    high_freq = nyquist_freq

    fft_bin_width = nyquist_freq / num_fft_bins

    # Mel scale conversion
    def mel_scale(freq):
        return 1127.0 * np.log(1.0 + freq / 700.0)

    mel_low_freq = mel_scale(low_freq)
    mel_high_freq = mel_scale(high_freq)
    # num_bins + 1 intervals because each triangle spans two of them
    mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)

    # Edges and centre of every filter on the mel axis, shape (num_bins, 1)
    left_mel = mel_low_freq + np.arange(num_bins)[:, None] * mel_freq_delta
    center_mel = left_mel + mel_freq_delta
    right_mel = center_mel + mel_freq_delta

    # Mel value of every FFT bin, computed once for all filters, shape (1, num_fft_bins + 1)
    mel_freq = mel_scale(np.arange(num_fft_bins + 1) * fft_bin_width)[None, :]

    rising = (mel_freq - left_mel) / (center_mel - left_mel)
    falling = (right_mel - mel_freq) / (right_mel - center_mel)
    inside = (left_mel < mel_freq) & (mel_freq < right_mel)

    # Rising edge up to the centre, falling edge after it, zero outside the triangle
    mel_bins = np.where(inside, np.where(mel_freq <= center_mel, rising, falling), 0.0)

    return mel_bins  # shape = [num_bins, padded_window_size // 2 + 1]


# Example configuration: 16 kHz sample rate, 512-sample FFT window (power of 2), 8 mel bins
sample_rate, padded_window_size, num_bins = 16000, 512, 8

mel_filters = get_mel_banks(
    num_bins=num_bins,
    padded_window_size=padded_window_size,
    sample_rate=sample_rate,
)
print("shape of mel_filters:", mel_filters.shape)


import math
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import PolyCollection

def polygon_under_graph(x, y):
    """
    Return the vertices of the polygon enclosed between the (x, y) line graph
    and the zero baseline. x is expected to be sorted in ascending order.
    """
    # Start on the baseline below the first point ...
    vertices = [(x[0], 0.)]
    # ... follow the curve ...
    vertices.extend(zip(x, y))
    # ... and drop back to the baseline below the last point.
    vertices.append((x[-1], 0.))
    return vertices

# 3D axes on which every mel filter is drawn as a filled polygon
ax = plt.figure(figsize=(6.2, 3.5)).add_subplot(111, projection='3d')

# One x position per FFT bin: 0 .. padded_window_size // 2
x = np.linspace(0, padded_window_size // 2, num=padded_window_size // 2 + 1)
verts = [polygon_under_graph(x, mel_filters[row]) for row in range(num_bins)]

# Depth coordinate of each polygon = index of its mel filter
mel_bins = np.arange(num_bins)
print("mel_bins:", mel_bins)

# One fill colour per filter, sampled from the 'hot' colour map
facecolors = plt.get_cmap('hot')(np.linspace(0.2, 0.65, len(verts)))

poly = PolyCollection(verts, alpha=0.7, facecolors=facecolors)
ax.add_collection3d(poly, zdir='y', zs=mel_bins)

# Axis ranges and labels
ax.set_xlim((0, padded_window_size // 2))
ax.set_ylim((-0.5, num_bins - 0.5))
ax.set_zlim((0, 1))
ax.set_xlabel('Frequency Bin')
ax.set_ylabel('Mel Bin')


plt.show()

In [None]:

import numpy as np
import matplotlib.pyplot as plt


def get_mel_banks(num_bins, padded_window_size, sample_rate):
    """Return the weights of `num_bins` triangular mel filters.

    Filter centres are equally spaced on the mel scale between 20 Hz and the
    Nyquist frequency, and each triangle reaches from the previous centre to
    the next one.

    Args:
        num_bins: number of mel filters (rows of the result).
        padded_window_size: padded FFT window length in samples.
        sample_rate: sampling rate in Hz.

    Returns:
        Array of shape (num_bins, padded_window_size // 2 + 1) with weights
        between 0 and 1; column j belongs to FFT bin j.

    Raises:
        ValueError: if padded_window_size is smaller than 2.
    """
    num_fft_bins = padded_window_size // 2
    if num_fft_bins < 1:
        raise ValueError("padded_window_size must be at least 2")

    nyquist_freq = sample_rate / 2.0

    low_freq = 20.0
    high_freq = nyquist_freq

    fft_bin_width = nyquist_freq / num_fft_bins

    # Mel scale conversion
    def mel_scale(freq):
        return 1127.0 * np.log(1.0 + freq / 700.0)

    mel_low_freq = mel_scale(low_freq)
    mel_high_freq = mel_scale(high_freq)
    # num_bins + 1 intervals because each triangle spans two of them
    mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)

    # Column vectors (num_bins, 1): left edge, centre and right edge of each filter
    left_mel = mel_low_freq + np.arange(num_bins)[:, None] * mel_freq_delta
    center_mel = left_mel + mel_freq_delta
    right_mel = center_mel + mel_freq_delta

    # Row vector (1, num_fft_bins + 1): mel value of each FFT bin, evaluated once
    mel_freq = mel_scale(np.arange(num_fft_bins + 1) * fft_bin_width)[None, :]

    rising = (mel_freq - left_mel) / (center_mel - left_mel)
    falling = (right_mel - mel_freq) / (right_mel - center_mel)
    inside = (left_mel < mel_freq) & (mel_freq < right_mel)

    # Rising edge up to the centre, falling edge after it, zero outside the triangle
    mel_bins = np.where(inside, np.where(mel_freq <= center_mel, rising, falling), 0.0)

    return mel_bins  # shape = [num_bins, padded_window_size // 2 + 1]


# Example configuration: 8 mel bins, 512-sample FFT window (power of 2), 16 kHz sample rate
num_bins, padded_window_size, sample_rate = 8, 512, 16000

mel_filters = get_mel_banks(
    num_bins=num_bins,
    padded_window_size=padded_window_size,
    sample_rate=sample_rate,
)
print("shape of mel_filters:", mel_filters.shape)


import math
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import PolyCollection

def polygon_under_graph(x, y):
    """
    Build the list of vertices outlining the area between the (x, y) line
    graph and the zero baseline. x must be given in ascending order.
    """
    baseline_start = (x[0], 0.)
    curve = list(zip(x, y))
    baseline_end = (x[-1], 0.)
    return [baseline_start] + curve + [baseline_end]

# 3D axes on which every mel filter is drawn as an outlined polygon
ax = plt.figure(figsize=(6.2, 3.5)).add_subplot(111, projection='3d')

# One x position per FFT bin: 0 .. padded_window_size // 2
frequency_bin = np.linspace(0, padded_window_size // 2, num=padded_window_size // 2 + 1)

verts = [polygon_under_graph(frequency_bin, mel_filters[row]) for row in range(num_bins)]

# Depth coordinate of each polygon = index of its mel filter
mel_bins = np.arange(num_bins)
print("mel_bins:", mel_bins)

# One outline colour per filter, sampled from the 'hot' colour map
edgecolors = plt.get_cmap('hot')(np.linspace(0.2, 0.65, len(verts)))

# White faces so that only the coloured outlines stand out
poly = PolyCollection(verts, facecolors='white', edgecolors=edgecolors, alpha=0.8)

ax.add_collection3d(poly, zdir='y', zs=mel_bins)

# Axis ranges and labels
ax.set_xlim((0, padded_window_size // 2))
ax.set_ylim((-0.5, num_bins - 0.5))
ax.set_zlim((0, 1))
ax.set_xlabel('Frequency Bin')
ax.set_ylabel('Mel Bin')


plt.show()