# Spectral Analysis for Modal Parameter Linear Estimate

## Setup

### Libraries
Install the `sample` package (published on PyPI as `lim-sample`) and its dependencies.
The `notebooks` and `plots` extras install the dependencies of helper functions, such as the plotting utilities.

In [None]:
# Install the pinned SAMPLE release (with the optional "notebooks" and
# "plots" extras) into the interpreter that runs this kernel.
# NOTE: the leading "!" and the "$sys.executable" expansion are IPython
# syntax; this cell does not run as plain Python.
import sys
!$sys.executable -m pip install -qU lim-sample[notebooks,plots]==2.0.0
from sample import __version__
from sample.vid import logo
# Report the installed version, then call the package's logo helper
# (presumably renders the SAMPLE logo, 6 inches wide)
print("SAMPLE version:", __version__)
logo(size_inches=6)

### Generate test audio
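We will synthesize a modal-like sound with three modal frequencies using simple additive synthesis. Two of them (1097 Hz and 1103 Hz) are so close that they produce beating at about 6 Hz.  
Also, we will add Gaussian noise at -45 dB (relative to the unit-sum partial amplitudes, i.e. roughly a 45 dB signal-to-noise ratio) to mimic a poor recording environment.  
Sampling frequency is 44100 Hz and the duration is 2 seconds.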

In [None]:
from matplotlib import pyplot as plt
from librosa.display import waveshow, specshow
from IPython import display as ipd
from sample.sample import additive_synth
from sample.utils import dsp as dsp_utils
import functools
import numpy as np

@functools.wraps(ipd.Audio)
def play(*args, **kwargs):
  # Build an IPython audio widget from the given arguments and show it
  # right away. functools.wraps copies the metadata (name, docstring,
  # __wrapped__) of IPython.display.Audio, so help(play) shows Audio's docs.
  audio_widget = ipd.Audio(*args, **kwargs)
  ipd.display(audio_widget)

def resize(w=12, h=6):
  """Set the size of the current matplotlib figure to ``w`` x ``h`` inches."""
  figure = plt.gcf()
  figure.set_size_inches([w, h])

# Ground-truth modal parameters. The 1103 Hz and 1097 Hz partials are only
# 6 Hz apart, so they beat against each other.
ground_truth = dict(
  freqs=[440, 1103, 1097],
  decays=[1, 0.75, 2],
  amps=[1, 0.8, 0.2],
)
# Normalize the amplitudes so that they sum to one
ground_truth["amps"] = np.array(ground_truth["amps"]) / sum(ground_truth["amps"])

fs = 44100  # sampling frequency (Hz)
duration = 2  # seconds
x = additive_synth(np.arange(int(duration * fs)) / fs, **ground_truth)

# Add white Gaussian noise scaled by -45 dB (via dsp_utils.db2a), with a
# fixed seed for reproducibility, then peak-normalize the signal
np.random.seed(42)
noise = np.random.randn(np.size(x)) * dsp_utils.db2a(-45)
x += noise
x /= np.max(np.abs(x))

play(x, rate=fs)

waveshow(x, sr=fs, alpha=.5, zorder=100)
plt.grid()
resize()

## Interface
The SAMPLE model is easy to use thanks to its scikit-learn-like API (`fit`/`predict`).

In [None]:
from sample import SAMPLE
# Options for the sinusoidal-model sub-component. They are forwarded with the
# scikit-learn-style ``<component>__<param>`` prefix. save_intermediate=True
# keeps the per-frame STFT and peaks, which the cells below inspect.
# NOTE: this rebinds the name ``sample`` (the package) to the fitted model.
# Later ``from sample import ...`` statements still work because they go
# through the import system, not through this name.
sinusoidal_options = {
    "max_n_sines": 10,
    "peak_threshold": -30,
    "save_intermediate": True,
}
sample = SAMPLE(**{
    "sinusoidal_model__" + key: value
    for key, value in sinusoidal_options.items()
}).fit(x)

## Sinusoidal Model
SAMPLE is based on Serra's *Spectral Modeling Synthesis* (SMS),
an analysis and synthesis system for musical sounds based
on the decomposition of the sound into a deterministic
sinusoidal component and a stochastic component.

The main components of the sinusoidal analysis are the peak detection
and the peak continuation algorithms.

### STFT
The peak detection/continuation algorithm is based on an analysis of the Short-Time Fourier Transform. Zero-phase windowing is employed.

In [None]:
# Stack the magnitude spectrum of every analysis frame (the first element of
# each saved (mx, px) pair) into a (frequency bin, frame) matrix.
# Named ``stft_mag`` so it does not clash with librosa's ``stft`` function,
# which is imported later in this notebook.
stft_mag = np.array([
  mx
  for mx, _ in sample.sinusoidal_model.intermediate_["stft"]
]).T

# The frames are spaced by the model's hop size ``h`` (in samples), not by
# librosa's default of 512 samples. Pass it explicitly so the time axis is
# correct.
specshow(stft_mag, sr=fs, hop_length=sample.sinusoidal_model.h, x_axis="time", y_axis="hz");
plt.ylim([0, 2000])
resize()

### Peak detection
The peak detection algorithm detects peaks in each STFT frame of the analysed
sound as local maxima of the magnitude spectrum.

In [None]:
# Magnitude/phase spectra and detected peaks of the first analysis frame
mx, px = sample.sinusoidal_model.intermediate_["stft"][0]
ploc, pmag, pph = sample.sinusoidal_model.intermediate_["peaks"][0]

# Bin k of an n_fft-point DFT sits at k * fs / n_fft. The FFT size can be
# larger than the window (zero-padding), so do not use ``w_.size`` here.
# Recover n_fft from the one-sided spectrum length instead.
# Assumes mx holds n_fft // 2 + 1 bins with an even n_fft.
n_fft = 2 * (mx.size - 1)
f = fs * np.arange(mx.size) / n_fft
# Peak locations are (possibly interpolated) bin indices: convert to Hz
ploc_hz = ploc * fs / n_fft
f_max = 2000  # upper limit of the displayed band (Hz)

ax = plt.subplot(121)
plt.fill_between(f, np.full(mx.shape, -120), mx, alpha=.1)
plt.plot(f, mx)
plt.scatter(ploc_hz, pmag, c="C0")
plt.ylim([-60, plt.ylim()[1]])
plt.grid()
plt.title("magnitude")

plt.subplot(122, sharex=ax)
plt.plot(f, px)
plt.scatter(ploc_hz, pph)
# Fit the phase axis to the values inside the displayed band only
in_band = f < f_max
plt.ylim([np.min(px[in_band]), np.max(px[in_band])])
plt.grid()
plt.title("phase")
plt.xlim([0, f_max])
resize()

### Peak continuation
The peak continuation algorithm organizes the peaks into temporal tracks,
with every track representing the time-varying behaviour of a partial.
For every peak in a trajectory, the instantaneous frequency, magnitude
and phase are stored to allow further manipulation and resynthesis.

The general-purpose SMS method allows recycling of peak-track data structures: if a trajectory
becomes inactive, it can later be picked up by a newly detected partial.
Our implementation doesn't allow this.

Moreover, two tracks that do not overlap in time but have approximately the same
average frequency can be considered as belonging to the same partial and merged into the same track.

In [None]:
from sample import plots
# 2D view of the sinusoidal tracks of the fitted model, on a 12x6 inch figure
plots.sine_tracking_2d(sample.sinusoidal_model)
resize(w=12, h=6)

In [None]:
from sample import plots
# 3D view of the same sinusoidal tracks, on a square 6x6 inch figure
plots.sine_tracking_3d(sample.sinusoidal_model)
resize(w=6, h=6)

## Regression
Partials of a modal impact sound are characterized by exponentially decaying amplitudes.
Our model for modal partials is
$$x(t) = m\cdot e^{-2\frac{t}{d}}\cdot \sin{\left(2\pi f t + \phi\right)}$$

The magnitude in decibels is a linear function of time
$$m_{dB}(t) = 20\log_{10}{\left(m\cdot e^{-2\frac{t}{d}}\right)} = 20\log_{10}{m} - 40\frac{\log_{10}{e}}{d} \cdot t$$

$$k = - 40\frac{\log_{10}{e}}{d}$$
$$q = 20\log_{10}{m}$$

$$m_{dB}(t) = kt + q$$

We use linear regression to find an initial estimate of the parameters $k$ and $q$ from the magnitude tracks. Then, we refine the estimate by fitting a semi-linear *hinge* function. Finally, the amplitude is doubled (+6.02 dB) to compensate for the fact that we are looking at only half of the spectrum.

In [None]:
# Overlay each measured magnitude track with its fitted exponential decay
t_x = np.arange(x.size) / fs
for i, (params, sine_track) in enumerate(
    zip(sample.param_matrix_.T, sample.sinusoidal_model.tracks_)
):
    _, decay, amp = params  # (frequency, decay, amplitude)
    color = f"C{i}"
    frame_idx = sine_track["start_frame"] + np.arange(sine_track["freq"].size)
    track_time = frame_idx * sample.sinusoidal_model.h / sample.sinusoidal_model.fs
    # +6.02 dB doubles the amplitude, compensating for spectral halving
    plt.plot(track_time, sine_track["mag"] + 6.02, c=color, alpha=.33, linewidth=3)
    # Fitted model m * exp(-2 t / d), in dB, over the whole signal duration
    plt.plot(t_x, 20*np.log10(amp * np.exp(-2*t_x / decay)), "--", c=color)

plt.title("fitted curves")
plt.grid()
plt.ylabel("magnitude (dB)")
plt.xlabel("time (s)")
# Legend entries refer to the first (C0) pair of curves
plt.legend(["track", "fitted"])
resize(6, 6)

Frequency is simply estimated as the mean frequency of the peak track.

## Resynthesize
Let's resynthesize the sound from the estimated parameters (via additive synthesis).

In [None]:
from librosa import stft, amplitude_to_db

# Resynthesize on the same time grid as the original signal
t_resynth = np.arange(x.size) / fs
x_hat = sample.predict(t_resynth)

# Top: the two waveforms overlaid
ax = plt.subplot(211)
x_dual = np.array([x, x_hat])
for label, signal in zip(("Original", "Resynthesis"), x_dual):
    waveshow(signal, sr=fs, alpha=.5, zorder=100, label=label, ax=ax)
plt.grid()
plt.legend()

# Bottom: spectrograms in dB (each relative to its own maximum). They share
# the time axis with the waveforms and the frequency axis with each other.
X_db = amplitude_to_db(np.abs(stft(x)), ref=np.max)
X_hat_db = amplitude_to_db(np.abs(stft(x_hat)), ref=np.max)

ax_orig = plt.subplot(223, sharex=ax)
specshow(X_db, ax=ax_orig, sr=fs, x_axis="time", y_axis="hz")
ax_orig.set_title("Original")

ax = plt.subplot(224, sharex=ax_orig, sharey=ax_orig)
specshow(X_hat_db, ax=ax, sr=fs, x_axis="time", y_axis="hz")
ax.set_title("Resynthesis")
ax.set_ylim([0, 2000])  # y-axis is shared: this limits both spectrograms

resize(12, 12)
for title, signal in (("Original", x), ("Resynthesis", x_hat)):
    ipd.display(ipd.HTML(title))
    play(signal, rate=fs)

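## BeatsDROP
We can also apply a regression algorithm to disentangle beating partials: the 1097 Hz and 1103 Hz partials are tracked as a single beating partial, and a beat regression estimates the two underlying partials!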

In [None]:
from sample.beatsdrop import regression as beatsdrop_regression
import itertools

# Extract one sinusoidal track (the beating one). The highest-frequency track
# is taken as the beating one: presumably the two close partials (1097 Hz and
# 1103 Hz) are tracked together as a single partial with a beating envelope.
track_i = np.argmax(sample.freqs_)
track = sample.sinusoidal_model.tracks_[track_i]
# Frame times in seconds, counted from the first frame of the track.
# NOTE(review): unlike the regression plot above, this ignores
# track["start_frame"]. It assumes the track starts at t = 0; confirm
# before reusing on other signals.
track_t = (np.arange(len(track["mag"])) * sample.sinusoidal_model.h
           / sample.sinusoidal_model.fs)
# Track magnitude in dB; +6.02 dB (x2) compensates for spectral halving
track_a = track["mag"] + 6.02
track_f = track["freq"]
if sample.sinusoidal_model.reverse:
  # Presumably, a reversed analysis stores tracks backwards in time:
  # flip them to follow the forward time axis
  track_a = np.flip(track_a)
  track_f = np.flip(track_f)
# Only frames with a finite magnitude are used for fitting
iok = np.isfinite(track_a)

# Linear-amplitude copy of the track, used for plotting only
# (non-finite dB values are copied unchanged)
track_alin = track_a.copy()
track_alin[iok] = dsp_utils.db2a(track_a[iok])

# Apply both regression variants: the baseline beat regression and the
# dual-beat regression (BeatsDROP)
br = beatsdrop_regression.BeatRegression(fs=fs, lpf=20)
dbr = beatsdrop_regression.DualBeatRegression(fs=fs, lpf=20)
for b in (br, dbr):
  b.fit(t=track_t[iok], a=track_a[iok], f=track_f[iok])

# Plot: one column per model (BeatsDROP, Baseline). Rows show the linear
# amplitude, the amplitude in dB, and the frequency.
_, axs = plt.subplots(3, 2, sharex=True, sharey="row", figsize=np.array((16/9 * 2, 1 * 3)) * 6)
axs[0][0].set_title("BeatsDROP")
axs[0][1].set_title("Baseline")

for i, b in enumerate((dbr, br)):
  am_, a0_, a1_, fm_ = b.predict(track_t, "am", "a0", "a1", "fm")
  # Divide the predicted frequency modulation by 2*pi (angular frequency to
  # Hz) so it can be compared with the track frequency
  np.true_divide(fm_, 2 * np.pi, out=fm_)

  # Amplitude modulation
  axs[0][i].plot(np.arange(x.size) / fs, x, c="C0", alpha=0.25, label="Signal")
  for a, kw in (
      (track_alin, dict(c="C0", label="Sinusoidal Track", zorder=102)),
      (am_, dict(linestyle="--", c="C1", label="Prediction", zorder=102)),
      (a0_, dict(c="C3", label="$A_1$", zorder=101)),
      (a1_, dict(c="C4", label="$A_2$", zorder=101)),
  ):
    # Hide values at or below -60 dB in the logarithmic plot
    a_ = np.copy(a)
    a_[np.less_equal(a, dsp_utils.db2a(-60))] = np.nan
    axs[0][i].plot(track_t, a, **kw)
    axs[1][i].plot(track_t, dsp_utils.a2db(a_), **kw)

  if i == 0:
    axs[0][i].set_ylabel("amplitude")
    axs[1][i].set_ylabel("amplitude (dB)")

  # Frequency modulation, with the two estimated partial frequencies
  axs[2][i].plot(track_t, track_f, c="C0", zorder=3, label="Sinusoidal Track")
  axs[2][i].plot(track_t, fm_, "--", c="C1", zorder=5, label="Prediction")
  axs[2][i].plot(track_t,
                 np.full_like(track_t, b.params_[2]),
                 c="C3",
                 label=r"$\nu_1$",
                 zorder=4)
  axs[2][i].plot(track_t,
                 np.full_like(track_t, b.params_[3]),
                 c="C4",
                 label=r"$\nu_2$",
                 zorder=4)
  if i == 0:
    axs[2][i].set_ylabel("frequency (Hz)")

for ax in itertools.chain.from_iterable(axs):
  ax.legend(loc="upper right")
  ax.grid()
  # Move existing y-labels above the axis, left-aligned and horizontal
  yl = ax.get_ylabel()
  if yl:
    yl = ax.set_ylabel(yl)
    yl.set_rotation(0)
    yl.set_horizontalalignment("left")
    yl.set_verticalalignment("bottom")
    ax.yaxis.set_label_coords(-0.05, 1.01)
# The x-axis is shared, so label the time axis on the bottom row of every
# column (previously the rows of the last column were labelled instead)
for c in range(axs.shape[1]):
  axs[-1, c].set_xlabel("time (s)")

Let's now listen to the refined result, where the beating track is replaced by the two partials estimated by each regression.

In [None]:
from sample.beatsdrop.regression import BeatRegression

# Indices of every track except the beating one. The beating track is
# replaced below by the two partials estimated by each beat regression.
idxs = np.delete(np.arange(len(sample.freqs_)), track_i)

ipd.display(ipd.HTML("Original"))
play(x, rate=fs)
_, axs = plt.subplots(3, 1, sharex=True, figsize=np.array((16 / 9, 1 * 3)) * 6)
waveshow(x, ax=axs[0], sr=fs, alpha=.5, zorder=100)
axs[0].set_title("Original")

# Regression parameters are laid out as
# [amp_1, amp_2, freq_1, freq_2, decay_1, decay_2, phase_1, phase_2]
br_params = {}
for i, (name, model, ax) in enumerate(zip(("Baseline", "BeatsDROP"), (br, dbr), axs[1:])):
  p = model.params_
  synth_params = {
      "freqs": [*sample.freqs_[idxs], *p[2:4]],
      "decays": [*sample.decays_[idxs], *p[4:6]],
      "amps": [*sample.amps_[idxs], *p[:2]],
      "phases": [*np.zeros_like(idxs), *p[6:8]],
  }
  br_params[name] = synth_params

  x_b = additive_synth(np.arange(x.size) / fs, **synth_params)
  ipd.display(ipd.HTML(name))
  play(x_b, rate=fs)
  waveshow(x_b, sr=fs, ax=ax, alpha=.5, zorder=100, color=f"C{i + 1}")
  ax.set_title(name)

for ax in axs:
  ax.grid()