In [None]:
import torch as th
from music_diffusion.data import constants, wav_to_stft, stft_to_magnitude_phase
from torch.nn import functional as th_f

In [None]:
# Break frequency (in Hertz) of the mel formula used by `mel_to_hertz` and
# `hertz_to_mel`: mel = Q * ln(1 + f / break).
_MEL_BREAK_FREQUENCY_HERTZ = 700.0
# Scale factor Q applied to the natural logarithm in the same formula.
_MEL_HIGH_FREQUENCY_Q = 1127.0


def mel_to_hertz(mel_values: th.Tensor) -> th.Tensor:
    """Map mel-scale values back to frequencies in Hertz.

    Computes ``break * (exp(mel / Q) - 1)`` element-wise, using the
    module-level break frequency and Q constants.

    Args:
        mel_values: tensor of values on the mel scale.

    Returns:
        Tensor of frequencies in Hertz.
    """
    exp_term = th.exp(mel_values / _MEL_HIGH_FREQUENCY_Q)
    return _MEL_BREAK_FREQUENCY_HERTZ * (exp_term - 1.0)


def hertz_to_mel(frequencies_hertz: th.Tensor) -> th.Tensor:
    """Map frequencies in Hertz onto the mel scale.

    Computes ``Q * ln(1 + f / break)`` element-wise, using the module-level
    break frequency and Q constants.

    Args:
        frequencies_hertz: tensor of frequencies in Hertz.

    Returns:
        Tensor of the corresponding mel-scale values.
    """
    relative_frequency = frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ
    return _MEL_HIGH_FREQUENCY_Q * th.log(1.0 + relative_frequency)


def linear_to_mel_weight_matrix(
    num_mel_bins: int = constants.N_FFT // 2,
    num_spectrogram_bins: int = constants.N_FFT // 2,
    sample_rate: int = constants.SAMPLE_RATE,
    lower_edge_hertz: float = 0.,
    upper_edge_hertz: float = constants.SAMPLE_RATE // 2,
) -> th.Tensor:
    """Build a matrix of triangular filters that warps a linear spectrogram to mel bands.

    Band centers are evenly spaced on the mel scale between
    ``lower_edge_hertz`` and ``upper_edge_hertz``. Bands whose width in Hertz
    is below 1.5 spectrogram-bin spacings are widened symmetrically (in mel)
    around their center so that they span exactly that threshold.

    Args:
        num_mel_bins: number of mel bands (columns of the result).
        num_spectrogram_bins: number of linear-frequency bins (rows of the
            result), spread evenly over ``[0, sample_rate / 2]``.
        sample_rate: sampling rate in Hertz; defines the Nyquist frequency.
        lower_edge_hertz: lower bound of the lowest mel band, in Hertz.
        upper_edge_hertz: upper bound of the highest mel band, in Hertz.

    Returns:
        Tensor of shape ``[num_spectrogram_bins, num_mel_bins]`` with
        non-negative weights. The first row (DC bin) is all zeros.
    """
    # HTK excludes the spectrogram DC bin.
    bands_to_zero = 1
    nyquist_hertz = sample_rate / 2.0
    # Column vector of bin frequencies without the DC bin: [freq - 1, 1].
    linear_frequencies = th.linspace(0.0, nyquist_hertz, num_spectrogram_bins)[
        bands_to_zero:, None
    ]

    # Compute num_mel_bins triples of (lower_edge, center, upper_edge). The
    # center of each band is the lower and upper edge of the adjacent bands.
    # Accordingly, we divide [lower_edge_hertz, upper_edge_hertz] into
    # num_mel_bins + 2 pieces.
    band_edges_mel = th.linspace(
        hertz_to_mel(th.tensor(lower_edge_hertz)).item(),
        hertz_to_mel(th.tensor(upper_edge_hertz)).item(),
        num_mel_bins + 2,
    )

    # These three slices are overlapping views of `band_edges_mel`. They must
    # never be written in place: doing so would silently move the centers and
    # edges of neighbouring bands. The widening below therefore builds new
    # tensors with `th.where` instead of assigning into the views.
    lower_edge_mel = band_edges_mel[0:-2]
    center_mel = band_edges_mel[1:-1]
    upper_edge_mel = band_edges_mel[2:]

    # Minimum band width in Hertz: 1.5 spectrogram-bin spacings.
    freq_res = nyquist_hertz / float(num_spectrogram_bins)
    freq_th = 1.5 * freq_res

    band_center_hz = mel_to_hertz(center_mel)
    band_width_hz = mel_to_hertz(upper_edge_mel) - mel_to_hertz(lower_edge_mel)
    too_narrow = band_width_hz < freq_th

    # Half-width `dm` (in mel) such that the band [center - dm, center + dm]
    # is exactly `freq_th` Hertz wide: dm = Q * asinh(rhs).
    rhs = 0.5 * freq_th / (band_center_hz + _MEL_BREAK_FREQUENCY_HERTZ)
    dm = _MEL_HIGH_FREQUENCY_Q * th.log(rhs + th.sqrt(1.0 + rhs**2))
    lower_edge_mel = th.where(too_narrow, center_mel - dm, lower_edge_mel)
    upper_edge_mel = th.where(too_narrow, center_mel + dm, upper_edge_mel)

    # Row vectors of band edges in Hertz: [1, mel].
    lower_edge_hz = mel_to_hertz(lower_edge_mel)[None, :]
    center_hz = band_center_hz[None, :]
    upper_edge_hz = mel_to_hertz(upper_edge_mel)[None, :]

    # Calculate lower and upper slopes for every spectrogram bin.
    # The line segments are evaluated on frequencies in Hertz.
    lower_slopes = (linear_frequencies - lower_edge_hz) / (
        center_hz - lower_edge_hz
    )
    upper_slopes = (upper_edge_hz - linear_frequencies) / (
        upper_edge_hz - center_hz
    )

    # Intersect the line segments with each other and zero.
    mel_weights_matrix = th.clamp(
        th.minimum(lower_slopes, upper_slopes), min=0.0
    )

    # Re-add the zeroed lower bins we sliced out above.
    # [freq, mel]
    mel_weights_matrix = th_f.pad(
        mel_weights_matrix, [0, 0, bands_to_zero, 0], "constant"
    )
    return mel_weights_matrix

In [None]:
# Mel weight matrix for 513 spectrogram bins and 513 mel bands: [freq, mel].
mat = linear_to_mel_weight_matrix(num_mel_bins=513, num_spectrogram_bins=513)

In [None]:
# NOTE(review): hard-coded absolute path to a local file; the notebook cannot
# be re-run on another machine without editing this line.
wav_p = "/media/samuel/M2_Sam/mozart_all_musics_16000Hz/01 - 6 German Dances, K509 - I. C Major - II. G Major - III. B Flat Major - IV. D Major - V. F Major - VI. F Major.flac"

# Load the audio file and compute its STFT with the project helper.
# Presumably returns a complex tensor laid out as [freq, time] — TODO confirm.
stft = wav_to_stft(wav_p)


In [None]:
# Split the STFT into its element-wise magnitude and phase angle.
magn = stft.abs()
phase = stft.angle()

In [None]:
# Shape and value range of the raw magnitude.
(magn.shape, magn.max(), magn.min())

In [None]:
# Min-max rescale the magnitude (rebinds `magn`).
magn = magn.sub(magn.min()).div(magn.max() - magn.min())

# Sanity check: shape is unchanged and the range is now bounded by min and max.
print(magn.shape, magn.min(), magn.max())

In [None]:
# Shape of the mel weight matrix.
mat.shape

In [None]:
# Project the magnitude through the weight matrix, then transpose the result
# back. Assumes `magn` is [freq, time] so the result is [mel, time] —
# TODO confirm against `wav_to_stft`.
magn_u = th.matmul(magn.T, mat).T

In [None]:
# Shape and value range after the mel projection.
(magn_u.shape, magn_u.min(), magn_u.max())

In [None]:
# Map the mel-domain magnitude back through `mat` (no pseudo-inverse here),
# then transpose the product.
magn_uu = th.matmul(mat, magn_u).T

In [None]:
# Shape and value range after mapping back through `mat`.
(magn_uu.shape, magn_uu.min(), magn_uu.max())

In [None]:
from math import log10

def mel_filter_bank(fft_size, sample_rate):
    """Build rectangular (0/1) band filters with mel-spaced edges.

    ``fft_size // 2 + 2`` edge frequencies are spread evenly on the mel scale
    between 0 Hz and the Nyquist frequency. Filter ``i`` is 1 on every FFT bin
    whose frequency lies in ``[edge[i], edge[i + 2]]`` (inclusive) and 0
    elsewhere, so neighbouring filters overlap.

    Args:
        fft_size: FFT length; bin ``k`` sits at ``k * sample_rate / fft_size``.
        sample_rate: sampling rate in Hertz.

    Returns:
        Float tensor of shape ``[fft_size // 2, fft_size // 2 + 1]``
        (filters x frequency bins) containing only zeros and ones.
    """
    num_filters = fft_size // 2
    num_bins = fft_size // 2 + 1
    nyquist_hertz = sample_rate / 2

    # Edges evenly spaced in m = log10(1 + f / 700). This is the mel scale
    # with its constant 2595 factor divided out, so the matching inverse is
    # f = 700 * (10**m - 1).
    max_scaled_mel = log10(1 + nyquist_hertz / 700)
    mel_points = th.linspace(0, max_scaled_mel, num_filters + 2)
    hz_points = 700 * (th.pow(10.0, mel_points) - 1)
    # Pin the last edge so rounding can never exclude the Nyquist bin.
    hz_points[-1] = nyquist_hertz

    # Center frequency of every FFT bin, from DC up to Nyquist.
    bin_frequencies = th.arange(num_bins) * sample_rate / fft_size

    # Broadcast [1, bins] against [filters, 1] to test band membership for
    # all filters at once.
    lower_edges = hz_points[:-2, None]
    upper_edges = hz_points[2:, None]
    in_band = (bin_frequencies[None, :] >= lower_edges) & (
        bin_frequencies[None, :] <= upper_edges
    )
    return in_band.to(th.get_default_dtype())

In [None]:
# Rebinds `mat`: from here on it is the rectangular filter bank, no longer
# the matrix returned by `linear_to_mel_weight_matrix`.
mat = mel_filter_bank(fft_size=1024, sample_rate=16000)

In [None]:
# Shape of the filter bank: [filters, frequency bins].
mat.shape

In [None]:
# Drop the last frequency-bin column of the 2-D filter bank.
mat_2 = mat[..., :-1]

In [None]:
# Apply the filter bank to the magnitude. Assumes `magn` has as many rows as
# `mat_2` has columns — TODO confirm.
magn_s = (magn.T @ mat_2.T).T

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Value range of the filtered magnitude.
(magn_s.min(), magn_s.max())

In [None]:
# Min-max rescale the filtered magnitude (rebinds `magn_s`).
magn_s = magn_s.sub(magn_s.min()).div(magn_s.max() - magn_s.min())

In [None]:
# Least-squares reconstruction of the linear-frequency magnitude.
# The forward step is magn_s = (magn.T @ mat_2.T).T = mat_2 @ magn, so the
# inverse is pinverse(mat_2) @ magn_s. The previous expression multiplied by
# pinverse(mat_2).T instead, which only ran because `mat_2` is square.
# NOTE(review): `magn_s` was min-max rescaled before this cell, so the result
# approximates the rescaled signal, not the original `magn`. This also rebinds
# the `magn_u` name used earlier in the notebook.
magn_u = th.pinverse(mat_2) @ magn_s

In [None]:
# Value range of the reconstructed magnitude.
(magn_u.max(), magn_u.min())