# Mel Spectrograms
Optimize the number of filters for mel-spectrograms

## Setup

### Libraries
Install the `sample` package and its dependencies.
The extras also install the dependencies of helper functions, such as plotting.

In [None]:
import sys
# !$sys.executable -m pip install -qU lim-sample[notebooks,plots]==1.4.0

### Generate test audio
We will synthesize a modal-like sound with three partials, two of which interfere in a beat pattern.

In [None]:
from matplotlib import pyplot as plt
from librosa.display import waveshow
from IPython import display as ipd
from sample import ipython as sip
from sample.evaluation import random
import functools
import numpy as np


@functools.wraps(sip.WebAudio)
def play(*args, **kwargs):
  # Build a WebAudio player from the given arguments and display it
  # immediately, so a call in the middle of a cell still shows a player.
  # functools.wraps copies WebAudio's metadata (name, docstring) onto play.
  widget = sip.WebAudio(*args, **kwargs)
  ipd.display(widget)


def resize(w=12, h=6):
  """Set the size of the current figure

  Args:
    w (float): Width in inches
    h (float): Height in inches"""
  size = [w, h]
  figure = plt.gcf()
  figure.set_size_inches(size)


# Seeded random beating tone: signal, sample rate and the synthesis
# features (later re-used with additive_synth)
x_, fs, feats = random.BeatsGenerator(decay=0.5, seed=123).audio()
play(x_, rate=fs)

# Waveform of the reference tone
waveshow(x_, sr=fs, zorder=100, alpha=.5)
plt.grid()
resize()

In [None]:
from sample.sample import additive_synth
from sample.utils.dsp import *
from sample import psycho

x = []
beat_mult = 0.4
# Transpose the reference tone by whole octaves, from 3 octaves down to
# 2 octaves up (i is the transposition in semitones), and concatenate the
# transposed tones into a single test signal
for i in np.arange(-3, 3) * 12:
  feats_ = np.array(feats).copy()
  # Transpose carrier and modulant independently: the mean of the two
  # beating partials (row 0, transposed by powers of 2 like a frequency)
  # is shifted by i semitones, while their distance is scaled by beat_mult
  # and NOT transposed, so the beat rate is the same for every octave
  hff = np.mean(feats_[0, :2]) * np.power(2, i / 12)
  # np.diff returns a length-1 array: take its only element so that the
  # assignments below store a scalar (assigning a size-1 array with
  # ndim > 0 into an element is deprecated since NumPy 1.25)
  lff = np.diff(feats_[0, :2])[0] * beat_mult / 2
  feats_[0, 0] = hff - lff
  feats_[0, 1] = hff + lff
  xi = additive_synth(np.arange(x_.size) / fs, *feats_)
  x.append(xi)
x = np.array(x).flatten()
# Add low-level Gaussian noise (amplitude db2a(-45), presumably -45 dB).
# The generator is seeded, like the BeatsGenerator of the reference tone,
# so that re-running the notebook produces the same signal
rng = np.random.default_rng(123)
x += rng.standard_normal(x.shape) * db2a(-45)
play(x, rate=fs)
waveshow(x, sr=fs, alpha=.5, zorder=100)
plt.grid()
resize()

## Define STFT Parameters

In [None]:
n = 2**11  # FFT size and segment length (2048 samples)
olap = 0.75  # overlap ratio between consecutive segments

# Keyword arguments for the STFT of the mel-spectrograms
stft_kws = {
    "fs": fs,
    "nfft": n,
    "nperseg": n,
    "noverlap": int(olap * n),
    "window": "blackman",
}

# Keyword arguments shared by all time-frequency plots
plot_kws = {"cmap": "afmhot"}

## Optimize

### Define helper class

In [None]:
from sample import plots


class MelSpectrogramOptimizer:
  """Helper class for optimizing the number of filters for mel-spectrograms

  For every candidate number of filters, the bandwidths of the mel
  triangular filters, assuming 50% overlap (half the distance between the
  two neighbours of each inner frequency point), are compared with the
  equivalent rectangular bandwidths (ERBs) at the inner frequency points.
  Both bandwidths are measured after mapping frequencies through
  :data:`freq_transform`. The option with the lowest root-mean-square
  error (RMSE) between the two is selected as the best one.

  Args:
    n_filters (array): Options for number of filters
    flim (float, float): Frequency limits
    freq_transform (callable): Transformation from Hertz to
      another frequency unit
    freq_unit (str): Name of the transformed frequency unit

  Attributes set by :meth:`optimize`:
    rmse_ (array): Bandwidth RMSE for each option in :data:`n_filters`
    ms_bandwidths_ (list of arrays): Mel-filter bandwidths for each option
    er_bandwidths_ (list of arrays): ERBs for each option
    i_min_ (int): Index of the option with the lowest RMSE"""

  def __init__(self,
               n_filters=tuple(range(3, 129)),
               flim=(18, 18000),
               freq_transform=lambda a: a,
               freq_unit: str = "Hz"):
    self.n_filters = n_filters
    self.flim = flim
    self.freq_transform = freq_transform
    self.freq_unit = freq_unit
    # Results, populated by optimize()
    self.rmse_ = None
    self.ms_bandwidths_ = None
    self.er_bandwidths_ = None
    self.i_min_ = None

  def _check_optimized(self):
    """Raise a :class:`RuntimeError` if :meth:`optimize` has not been run"""
    # Without this check, indexing with i_min_=None either fails with an
    # obscure TypeError (tuple options) or silently returns all the
    # options (numpy array options)
    if self.i_min_ is None:
      raise RuntimeError(f"{type(self).__name__} has no results yet: "
                         "call optimize() first")

  def optimize(self):
    """Run optimization

    Returns:
      MelSpectrogramOptimizer: self, with results stored in :data:`rmse_`,
      :data:`ms_bandwidths_`, :data:`er_bandwidths_` and :data:`i_min_`"""
    self.rmse_ = np.empty_like(self.n_filters, dtype=float)
    self.ms_bandwidths_ = []
    self.er_bandwidths_ = []
    for i, n in enumerate(self.n_filters):
      # Only the frequency points are needed, not the filterbank itself
      _, f = psycho.mel_triangular_filterbank(freqs=[],
                                              n_filters=n,
                                              flim=self.flim)
      f_ = self.freq_transform(f)
      # Inner points f[1:-1] are treated as filter centers and their
      # neighbours as filter edges: with 50% overlap, the bandwidth is
      # half the distance between the edges (in the transformed unit)
      self.ms_bandwidths_.append((f_[2:] - f_[:-2]) / 2)
      # ERB centered on each inner point, also in the transformed unit
      e = psycho.erb(f[1:-1])
      self.er_bandwidths_.append(
          self.freq_transform(f[1:-1] + e / 2) -
          self.freq_transform(f[1:-1] - e / 2))
      self.rmse_[i] = np.sqrt(
          np.mean(np.square(self.ms_bandwidths_[-1] - self.er_bandwidths_[-1])))
    self.i_min_ = np.argmin(self.rmse_)
    return self

  @property
  def n_min_(self):
    """Best number of filters

    Raises:
      RuntimeError: If :meth:`optimize` has not been run"""
    self._check_optimized()
    return self.n_filters[self.i_min_]

  def plot_bandwidth_rmse(self, ax=None):
    """Plot the RMSE of the filter bandwidths

    The best option is highlighted and annotated with its number of filters.

    Args:
      ax: Axes to draw on. If None, the current axes are used

    Raises:
      RuntimeError: If :meth:`optimize` has not been run"""
    self._check_optimized()
    if ax is None:
      ax = plt.gca()
    ax.plot(self.n_filters, self.rmse_)
    ax.scatter(self.n_min_, self.rmse_[self.i_min_])
    ax.text(self.n_min_,
            self.rmse_[self.i_min_],
            str(self.n_min_),
            verticalalignment="bottom")
    ax.set_ylabel(f"RMSE ({self.freq_unit})")
    ax.set_xlabel("#filters")
    ax.grid()
    ax.set_title("Bandwidth RMSE")

  def plot_bandwidths(self, ax=None, i=None):
    """Plot the filter bandwidths for the i-th option

    Args:
      ax: Axes to draw on. If None, the current axes are used
      i (int): Index of the option in :data:`n_filters`.
        If None, the best option is plotted

    Raises:
      RuntimeError: If :meth:`optimize` has not been run"""
    self._check_optimized()
    if ax is None:
      ax = plt.gca()
    if i is None:
      i = self.i_min_
    # x-axis: indices 1 .. n_filters[i] - 2
    filter_i = np.arange(1, self.n_filters[i] - 1)
    ax.plot(filter_i, self.ms_bandwidths_[i], label="50% overlap")
    ax.plot(filter_i, self.er_bandwidths_[i], label="ERB")
    ax.set_ylabel(f"Bandwidth ({self.freq_unit})")
    ax.set_xlabel("filter")
    # Shade the mismatch between the two bandwidth curves
    ax.fill_between(filter_i,
                    self.ms_bandwidths_[i],
                    self.er_bandwidths_[i],
                    facecolor="C3",
                    alpha=.125)
    ax.grid()
    ax.legend()
    ax.set_title(f"Bandwidths for {self.n_filters[i]} filters")

  def plot_melspectrogram(self,
                          x,
                          floor=1e-3,
                          plot_kws=None,
                          ax=None,
                          **kwargs):
    """Plot mel-spectrogram with the best number of filters

    Args:
      x (array): Input signal
      floor (float): Floor passed to :func:`complex2db`
      plot_kws (dict): Keyword arguments for :func:`plots.tf_plot`
      ax: Axes to draw on, forwarded to :func:`plots.tf_plot`
      **kwargs: Keyword arguments for :func:`psycho.mel_spectrogram`

    Raises:
      RuntimeError: If :meth:`optimize` has not been run"""
    if plot_kws is None:
      plot_kws = {}
    freqs, times, melspec = psycho.mel_spectrogram(x,
                                                   n_filters=self.n_min_,
                                                   flim=self.flim,
                                                   **kwargs)
    ax = plots.tf_plot(
        complex2db(normalize(melspec), floor=floor),
        flim=psycho.hz2mel(freqs[[0, -1]]),
        tlim=times[[0, -1]],
        ax=ax,
        **plot_kws,
    )
    ax.set_xlabel("time (s)")
    ax.set_ylabel("frequency (Mel)")

### Optimize bandwidth in Hz

In [None]:
# Bandwidths measured in Hertz (identity frequency transform)
melspec_opt_hz = MelSpectrogramOptimizer(freq_unit="Hz").optimize()
_, axs = plt.subplots(nrows=1, ncols=2)
# Left: RMSE for every option; right: bandwidths of the best option
melspec_opt_hz.plot_bandwidth_rmse(ax=axs[0])
melspec_opt_hz.plot_bandwidths(ax=axs[1])
resize()

In [None]:
# Mel-spectrogram with the number of filters optimized in Hz
melspec_opt_hz.plot_melspectrogram(x, plot_kws=plot_kws, stft_kws=stft_kws)

### Optimize bandwidth in Cams
(Cams are the units of the ERB-rate scale)

In [None]:
# Bandwidths measured on the ERB-rate scale
melspec_opt_cams = MelSpectrogramOptimizer(freq_unit="Cams",
                                           freq_transform=psycho.hz2cams,
                                           ).optimize()
_, axs = plt.subplots(nrows=1, ncols=2)
# Left: RMSE for every option; right: bandwidths of the best option
melspec_opt_cams.plot_bandwidth_rmse(ax=axs[0])
melspec_opt_cams.plot_bandwidths(ax=axs[1])
resize()

In [None]:
# Mel-spectrogram with the number of filters optimized in Cams
melspec_opt_cams.plot_melspectrogram(x, plot_kws=plot_kws, stft_kws=stft_kws)

### Optimize bandwidth in Barks

In [None]:
# Bandwidths measured on the Bark scale
melspec_opt_bark = MelSpectrogramOptimizer(freq_unit="Bark",
                                           freq_transform=psycho.hz2bark,
                                           ).optimize()
_, axs = plt.subplots(nrows=1, ncols=2)
# Left: RMSE for every option; right: bandwidths of the best option
melspec_opt_bark.plot_bandwidth_rmse(ax=axs[0])
melspec_opt_bark.plot_bandwidths(ax=axs[1])
resize()

In [None]:
# Mel-spectrogram with the number of filters optimized in Barks
melspec_opt_bark.plot_melspectrogram(x, plot_kws=plot_kws, stft_kws=stft_kws)

### Optimize bandwidth in Mel

In [None]:
# Bandwidths measured on the mel scale
melspec_opt_mel = MelSpectrogramOptimizer(freq_unit="Mel",
                                          freq_transform=psycho.hz2mel,
                                          ).optimize()
_, axs = plt.subplots(nrows=1, ncols=2)
# Left: RMSE for every option; right: bandwidths of the best option
melspec_opt_mel.plot_bandwidth_rmse(ax=axs[0])
melspec_opt_mel.plot_bandwidths(ax=axs[1])
resize()

In [None]:
# Mel-spectrogram with the number of filters optimized in Mel
melspec_opt_mel.plot_melspectrogram(x, plot_kws=plot_kws, stft_kws=stft_kws)

## Compare
Compare the optimized mel-spectrograms with the cochleagram, the CQT, the VQT, and the IIRT.  
Also, compare them to a mel-spectrogram whose filter bandwidths equal the ERBs at the center frequencies.

In [None]:
# Cochleagram of the test signal. It is decimated by 128 along axis 1
# (presumably time), the same step as the CQT/VQT hop length below
coch_, coch_freqs = psycho.cochleagram(
    x,
    fs=fs,
    analytical="ir",
    normalize=True,
)
coch = coch_[:, ::128]

In [None]:
import librosa

semitone_min = -45
# Number of CQT bins, one per semitone, starting semitone_min semitones
# below A4 (440 Hz), i.e. at C1 (~32.7 Hz)
# NOTE(review): `.astype(int)` binds to the np.log2(...) result before the
# multiplication by 12, so the count is rounded down to whole octaves
# (9 octaves -> 108 bins, top bin ~15.8 kHz rather than ~20 kHz). If the
# goal is to get as close as possible to 20 kHz, `int(12 * np.log2(...))`
# gives 111 bins -- confirm intent; the top filter must stay below Nyquist.
nbins = 12 * np.log2(20e3/(440 * (2**(semitone_min/12)))).astype(int)
# Semitone offset of each bin relative to A4, and its center frequency
cqt_semitones = np.arange(nbins) + semitone_min
cqt_freqs = np.power(2, cqt_semitones / 12) * 440
# Constant-Q transform with 12 bins per octave and a 128-sample hop
cqt = librosa.cqt(x,
                  sr=fs,
                  fmin=cqt_freqs[0],
                  hop_length=128,
                  n_bins=nbins,
                  bins_per_octave=12)

In [None]:
# Variable-Q transform on the same semitone grid and hop as the CQT
vqt = librosa.vqt(
    x,
    sr=fs,
    hop_length=128,
    fmin=cqt_freqs[0],
    bins_per_octave=12,
    n_bins=nbins,
)

In [None]:
# Semitone IIR filterbank spectrogram on the CQT center frequencies, with
# every filter running at the signal's sample rate
iirt = librosa.iirt(
    x,
    sr=fs,
    win_length=32,
    center_freqs=cqt_freqs,
    sample_rates=np.full(cqt_freqs.shape, fs),
)

In [None]:
# Reference mel-spectrogram: 64 filters between 20 Hz and 20 kHz whose
# bandwidths are given by psycho.erb at the center frequencies
mel_freqs, mel_times, melspec = psycho.mel_spectrogram(
    x,
    stft_kws=stft_kws,
    n_filters=64,
    flim=(20, 20000),
    bandwidth=psycho.erb,
)

In [None]:
c = 5
r = 2
_, axs = plt.subplots(r, c, sharex=True)
axs = np.array(axs).flatten()

# Top row, first panel: cochleagram on the ERB-rate (Cams) scale
plots.tf_plot(
    complex2db(normalize(coch), floor=1e-3),
    flim=psycho.hz2cams(coch_freqs[[0, -1]]),
    tlim=(0, x.size / fs),
    ax=axs[0],
    **plot_kws,
)
axs[0].set_ylabel("frequency (Cams)")
axs[0].set_title("Cochleagram\n")

# Top row, panels 1-3: the CQT, VQT and IIRT share the same semitone grid,
# so they are drawn with the same frequency limits
# (MIDI pitch = semitones relative to A4 + 69)
semitone_specs = (
    (cqt, "Constant-Q Spectrogram\n"),
    (vqt, "Variable-Q Spectrogram\n"),
    (iirt, "Semitone IIR Spectrogram\n"),
)
for ax, (spec, title) in zip(axs[1:4], semitone_specs):
  plots.tf_plot(
      complex2db(normalize(spec), floor=1e-3),
      flim=cqt_semitones[[0, -1]] + 69,
      tlim=(0, x.size / fs),
      ax=ax,
      **plot_kws,
  )
  ax.set_ylabel("frequency (MIDI Pitch)")
  ax.set_title(title)

# Top row, last panel is unused
axs[4].remove()

# Bottom row, first panel: mel-spectrogram with ERB bandwidths
plots.tf_plot(
    complex2db(normalize(melspec), floor=1e-3),
    flim=psycho.hz2mel(mel_freqs[[0, -1]]),
    tlim=mel_times[[0, -1]],
    ax=axs[5],
    **plot_kws,
)
axs[5].set_xlabel("time (s)")
axs[5].set_ylabel("frequency (Mel)")
axs[5].set_title("Mel-Spectrogram (ERB)\n")

# Bottom row, remaining panels: optimized mel-spectrograms
optimizers = [
    melspec_opt_hz,
    melspec_opt_cams,
    melspec_opt_bark,
    melspec_opt_mel,
]
for ax, opt in zip(axs[6:], optimizers):
  if opt is None:
    continue
  # Share the mel axis with the first panel of the bottom row (axs[c])
  ax.sharey(axs[c])
  opt.plot_melspectrogram(x, stft_kws=stft_kws, plot_kws=plot_kws, ax=ax)
  ax.set_title(f"Mel-Spectrogram ({opt.n_min_} filters)"
               "\n"
               f"optimized for bandwidth in {opt.freq_unit}")
  ax.tick_params("y", labelleft=False)
  ax.set_ylabel("")

# The VQT and IIRT panels share the MIDI-pitch axis of the CQT panel
for ax in axs[2:4]:
  ax.sharey(axs[1])
  ax.tick_params("y", labelleft=False)
  ax.set_ylabel("")

for ax in axs:
  ax.set_facecolor("k")

resize(*(5 * np.array([c, 3 / 4 * r])))