# Automatic Tuning of SAMPLE hyperparameters
In this notebook we will see how to automatically tune the hyperparameters of SAMPLE.

## Setup

### Libraries
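Install the `sample` package and its dependencies (the install command in the cell below is commented out: uncomment it if the package is not installed yet).
The optional extras install the dependencies for helper functions, such as plots.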

In [None]:
# import sys
# !$sys.executable -m pip install -qU ..
from sample import __version__
from sample.vid import logo
# Print the installed package version and draw the logo
print(f"SAMPLE version: {__version__}")
logo(size_inches=6)

### Load audio
Download the test audio or load your own audio file. In this notebook, you can specify:

   - a filename: to load the audio from file
   - a URL: to download the audio file from the web (only if fname is empty)
   - start time and length (in seconds): to cut the audio file

In [None]:
from matplotlib import pyplot as plt
from librosa.display import waveshow, specshow
from IPython import display as ipd
from sample.ipython import WebAudio
import numpy as np
import functools
import librosa
import requests
import os

@functools.wraps(WebAudio)
def play(*args, **kwargs):
  # Build the WebAudio object from the forwarded arguments,
  # then hand it to IPython for display in the cell output
  player = WebAudio(*args, **kwargs)
  ipd.display(player)

def resize(w=12, h=6):
  """Set the size of the current matplotlib figure.

  Args:
    w: Figure width in inches
    h: Figure height in inches
  """
  fig = plt.gcf()
  fig.set_size_inches([w, h])

# Form parameters: local file name, fallback URL and excerpt boundaries (seconds)
fname = "" #@param {type: "string"}
url = "http://soundbible.com/grab.php?id=2190&type=wav" #@param {type: "string"}
start_time = 1.298 #@param {type: "number"}
time_length = 3 #@param {type: "number"}

if not fname:
  # No local file specified: download the audio to a temporary file
  _fname = "_testaudio.wav"
  # The timeout prevents the cell from hanging forever on a stalled connection
  r = requests.get(url, timeout=30)
  # Fail here on HTTP errors, instead of saving an error page as if it were audio
  r.raise_for_status()
  with open(_fname, "wb") as f:
    f.write(r.content)
else:
  _fname = fname

try:
  # sr=None: keep the sample rate of the file (no resampling)
  x, fs = librosa.load(_fname, sr=None)
finally:
  # Remove the downloaded file even if loading failed.
  # A user-provided file is never deleted
  if not fname:
    os.remove(_fname)

# Cut the excerpt: convert start time and length from seconds to samples
i_0 = int(start_time * fs)
i_1 = i_0 + int(time_length * fs)

x = x[i_0:i_1]

# Plot the waveform of the excerpt and show an audio player for it
waveshow(x, sr=fs, alpha=.5, zorder=100)
plt.grid()
resize()
play(x, rate=fs)

## Define optimization problem

Define the fixed parameters, which will not be tuned by the optimizer

In [None]:
# Parameters with fixed values: they are not part of the search space
sample_opt_fixed = {
  "max_n_modes": 64,
  "sinusoidal_model__reverse": True,
  "sinusoidal_model__safe_sine_len": 2,
  "sinusoidal_model__overlap": 0.5,
  "sinusoidal_model__frequency_bounds": (50, 20e3),
  # "sinusoidal_model__max_n_sines": 48,
}

Define the space of parameters to tune

In [None]:
import skopt.space
# Search space: one skopt dimension per tunable parameter.
# The `name` of each dimension is a human-readable label
sample_opt_space = {
  "sinusoidal_model__log_n": skopt.space.Integer(6, 14, name="log2(n)"),
  "sinusoidal_model__max_n_sines": skopt.space.Integer(16, 128, name="n sines"),
  "sinusoidal_model__peak_threshold": skopt.space.Real(
    -120, -30, name="peak threshold"),
  "sinusoidal_model__min_sine_dur": skopt.space.Real(
    0, 0.5, name="min duration"),
  # "sinusoidal_model__overlap": skopt.space.Real(
  #   0, 2/3, name="overlap"),
}

Define the cochleagram loss function

In [None]:
from sample.evaluation.metrics import CochleagramLoss
from sample.utils.dsp import complex2db
from functools import partial

# Loss function to minimize, configured for the sample rate of the loaded audio
cochleagram_loss = CochleagramLoss(
  fs=fs,
  normalize=True,
  analytical="ir",
  # 0.008 * fs samples (i.e. 8 ms when fs is in Hz), truncated to an integer
  stride=int(fs * 0.008),
  # complex2db with its floor arguments pre-bound
  # NOTE(review): presumably converts to dB with a -60 dB floor -- confirm
  postprocessing=partial(complex2db, floor=-60, floor_db=True),
)

## Optimize

In [None]:
from tqdm.notebook import tqdm_notebook
import sample.optimize


# Form parameters: the total number of calls is the sum of
# the initial points and the minimizing points
reset = True #@param {type:"boolean"}
n_minimizing_points = 32 #@param {type:"integer"}
n_initial_points = 64 - n_minimizing_points #@param {type:"integer"}
n_calls = n_minimizing_points + n_initial_points


# Discard the previous result when a reset is requested.
# On the first run there is no previous result to keep
opt_res = None if (reset or "opt_res" not in locals()) else opt_res

# Optimizer: fixed parameters go in `sample_kw`,
# search-space dimensions are passed as keyword arguments
sample_opt = sample.optimize.SAMPLEOptimizer(
  sample_kw=sample_opt_fixed,
  loss_fn=cochleagram_loss,
  **sample_opt_space,
)

# Progress-bar callback for the notebook
tqdm_cbk = sample.optimize.TqdmCallback(
  sample_opt=sample_opt,
  n_calls=n_calls,
  n_initial_points=n_initial_points,
  tqdm_fn=tqdm_notebook,
)

# Run the minimization. `opt_res` is passed back in as `state`
# NOTE(review): presumably this resumes from a previous run -- confirm
opt_model, opt_res = sample_opt.gp_minimize(
  x=x,
  fs=fs,
  n_calls=n_calls,
  n_initial_points=n_initial_points,
  callback=tqdm_cbk,
  initial_point_generator="lhs",
  acq_func="LCB",
  state=opt_res,
  random_state=42,
  # verbose=True,
)

In [None]:
# Display the model returned by `gp_minimize` as the output of this cell
opt_model

### Listen back

In [None]:
from sample.psycho import cochleagram, hz2cams
from sample.plots import tf_plot

# Resynthesize the sound from the optimized model on the time axis of the
# original audio, then clip the result to the [-1, +1] range
x_hat = np.clip(
    opt_model.predict(np.arange(x.size) / fs, phases="random", seed=1), -1, +1)

# Top row: audio players and overlaid waveforms of original and resynthesis
ax = plt.subplot(211)
x_dual = np.array([x, x_hat])
for l, xi in zip(("original", "resynthesis"), x_dual):
  play(xi, rate=fs, label=l)
  waveshow(xi, sr=fs, alpha=.5, zorder=100, label=l, ax=ax)
plt.grid()
plt.legend()

# Bottom row: cochleagrams computed with the same filterbank and
# postprocessing as the loss function.
# The postprocessing was given to the loss as a callable (functools.partial),
# so it must be passed by keyword: a callable is not a mapping
# and unpacking it with `**` raises a TypeError
coch_x, cfreq = cochleagram(x,
                            filterbank=cochleagram_loss.filterbank,
                            postprocessing=cochleagram_loss.postprocessing)
ax = plt.subplot(223, sharex=ax)
tf_plot(coch_x,
        ax=ax,
        tlim=(0, x.size / fs),
        flim=hz2cams(cfreq[[0, -1]]),
        cmap="Blues")
ax.set_title("original")

# Same filterbank as above: the center frequencies are reused from `cfreq`
coch_x_hat, _ = cochleagram(x_hat,
                            filterbank=cochleagram_loss.filterbank,
                            postprocessing=cochleagram_loss.postprocessing)
ax = plt.subplot(224, sharex=ax)
tf_plot(coch_x_hat,
        ax=ax,
        tlim=(0, x.size / fs),
        flim=hz2cams(cfreq[[0, -1]]),
        cmap="Oranges")
ax.set_title("resynthesis")

resize(12, 12)

In [None]:
import matplotlib.collections
import skopt.plots
import itertools

with plt.style.context("../scripts/beatsdrop/figures.mplstyle",
                       after_reset=True) as ctx:
  axs = skopt.plots.plot_objective(opt_res, levels=16, show_points=False)
  axs[0][0].set_xticks((7, 9, 11, 13))

  # Tidy up the labels of every subplot in the grid
  for axs_row in axs:
    for ax in axs_row:
      if ax.get_ylabel() == "Partial dependence":
        ax.grid(True)
        ax.set_ylabel("")
      if ax.get_xlabel() == "log2(n)":
        ax.set_xlabel(r"$\log_2{n}$")

  # Rasterize the path collections of the subplots below the diagonal,
  # except for the fully red ones (the minimum star)
  for i, axs_row in enumerate(axs):
    for j, ax_ij in enumerate(axs_row):
      if i <= j:
        # Diagonal or upper triangle: leave untouched
        continue
      for c in ax_ij.get_children():
        if not isinstance(c, matplotlib.collections.PathCollection):
          continue
        if (c.get_facecolor() == (1, 0, 0, 1)).all():
          continue
        c.set_rasterized(True)

  plt.savefig("partialdependence.pdf")