# Automatic Tuning of SAMPLE hyperparameters
In this notebook, we will see how to automatically tune the hyperparameters of SAMPLE using Gaussian-process-based optimization.

## Setup

### Libraries
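Uncomment and run the cell below to install the `sample` package (from the parent directory) and its dependencies.
The package's optional extras install the dependencies of helper functions, such as plotting; install them as well if those helpers are missing.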

In [None]:
# Uncomment to install the package from the parent directory (`..`).
# Using sys.executable makes pip target the Python running this kernel.
# import sys
# !$sys.executable -m pip install -qU ..

### Load audio
Download the test audio or load your own audio file. In this notebook, you can specify:

   - a filename: to load the audio from a local file
   - a URL: to download the audio file from the web (used only if `fname` is empty)
   - start time and length (in seconds): to trim the audio

In [None]:
from matplotlib import pyplot as plt
from librosa.display import waveshow, specshow
from IPython import display as ipd
from sample.ipython import WebAudio
import numpy as np
import functools
import librosa
import requests
import os

@functools.wraps(WebAudio)
def play(*args, **kwargs):
  # Build a WebAudio player from the given arguments and render it in the
  # cell output. functools.wraps copies WebAudio's metadata (name, docstring)
  # onto this helper, so introspecting `play` shows WebAudio's documentation.
  player = WebAudio(*args, **kwargs)
  ipd.display(player)

def resize(w=12, h=6):
  """Set the current matplotlib figure to ``w`` x ``h`` inches."""
  current_fig = plt.gcf()
  current_fig.set_size_inches([w, h])

# Audio source. Leave `fname` empty to download the test audio from `url`;
# set it to a path to use your own file instead.
fname = "" #@param {type: "string"}
url = "http://soundbible.com/grab.php?id=2190&type=wav" #@param {type: "string"}
start_time = 2.86 #@param {type: "number"}
time_length = 1.5 #@param {type: "number"}

if fname:
  _fname = fname
else:
  # Download to a scratch file. Fail loudly on HTTP errors rather than saving
  # an error page to disk and letting librosa fail with an obscure message.
  _fname = "_testaudio.wav"
  r = requests.get(url, timeout=30)
  r.raise_for_status()
  with open(_fname, "wb") as f:
    f.write(r.content)

try:
  # sr=None keeps the file's native sampling rate instead of resampling
  x, fs = librosa.load(_fname, sr=None)
finally:
  # Remove the downloaded scratch file even if decoding fails
  if not fname:
    os.remove(_fname)

# Trim to the excerpt [start_time, start_time + time_length) seconds
i_0 = int(start_time * fs)
i_1 = i_0 + int(time_length * fs)
duration = x.size / fs
x = x[i_0:i_1]
if not x.size:
  raise ValueError(
    f"Empty excerpt: start_time={start_time} s, time_length={time_length} s, "
    f"but the audio lasts {duration:.2f} s")

waveshow(x, sr=fs, alpha=.5, zorder=100)
plt.grid()
resize()
play(x, rate=fs)

## Define optimization problem

Define the fixed parameters, which will not be tuned by the optimizer.

In [None]:
# Hyperparameters held constant during the search (not tuned by the optimizer)
sample_opt_fixed = {
  "max_n_modes": 64,
  "sinusoidal_model__reverse": True,
  "sinusoidal_model__safe_sine_len": 2,
  "sinusoidal_model__frequency_bounds": (50, 20e3),
}

Define the search space: the range of values the optimizer may explore for each tuned parameter.

In [None]:
import skopt.space
# Search space: one skopt dimension (bounds + human-readable name) per
# tuned hyperparameter. `log_n` is searched as an integer exponent, as its
# name "log2(n)" indicates.
sample_opt_space = {
  "sinusoidal_model__log_n": skopt.space.Integer(6, 15, name="log2(n)"),
  "sinusoidal_model__max_n_sines": skopt.space.Integer(16, 128, name="n sines"),
  "sinusoidal_model__peak_threshold": skopt.space.Real(-120, -30, name="peak threshold"),
  "sinusoidal_model__min_sine_dur": skopt.space.Real(0, 0.5, name="min duration"),
  "sinusoidal_model__overlap": skopt.space.Real(0, 2/3, name="overlap"),
}

## Optimize

In [None]:
from tqdm.notebook import tqdm_notebook
import multiprocessing as mp
import sample.optimize
import warnings


# Evaluation budget: `n_initial_points` initial samples (generated by
# initial_point_generator="lhs" below) followed by `n_minimizing_points`
# optimizer-guided evaluations.
n_minimizing_points = 16 #@param {type:"integer"}
n_initial_points = 128 #@param {type:"integer"}
n_calls = n_minimizing_points + n_initial_points
# Number of worker processes in the pool handed to the loss function
n_processes = 6 #@param {type:"integer"}


# Warm start (deliberate hidden state): if a previous run of this cell left
# `opt_res` in the kernel, it is passed as `state` to gp_minimize, so a re-run
# builds on the earlier results. Restart the kernel (or `del opt_res`) to
# start a fresh optimization.
opt_res = globals().get("opt_res")
with mp.Pool(processes=n_processes) as pool:
  sample_opt = sample.optimize.SAMPLEOptimizer(
    sample_kw=sample_opt_fixed,
    loss_kw=dict(pool=pool),
    **sample_opt_space,
  )
  # Progress bar over all n_calls evaluations
  tqdm_cbk = sample.optimize.TqdmCallback(
    sample_opt=sample_opt,
    n_calls=n_calls,
    n_initial_points=n_initial_points,
    tqdm_fn=tqdm_notebook,
  )
  with warnings.catch_warnings():
    # Silence every warning raised during the optimization
    warnings.simplefilter("ignore")
    opt_model, opt_res = sample_opt.gp_minimize(
      x=x, fs=fs,
      n_calls=n_calls,
      n_initial_points=n_initial_points,
      callback=tqdm_cbk,
      initial_point_generator="lhs",
      acq_func="LCB",
      state=opt_res,
      # verbose=True,
    )

In [None]:
# Display the tuned model returned by `sample_opt.gp_minimize` in the optimization cell
opt_model

### Listen back

In [None]:
from librosa import stft, amplitude_to_db
# Resynthesize on the original time grid with seeded random phases, clipping
# the result to [-1, 1] so it is safe for playback.
x_hat = np.clip(
  opt_model.predict(np.arange(x.size) / fs, phases="random", seed=1),
  -1, +1,
)
labels = ("original", "resynthesis")

# Top: overlaid waveforms, each preceded by its own audio player.
# np.array stacks both signals with a common dtype.
ax_wave = plt.subplot(211)
for label, sig in zip(labels, np.array([x, x_hat])):
  play(sig, rate=fs, label=label)
  waveshow(sig, sr=fs, alpha=.5, zorder=100, label=label, ax=ax_wave)
plt.grid()
plt.legend()

# Bottom: side-by-side spectrograms in dB (relative to each signal's peak).
# The first shares the time axis with the waveforms; the second shares both
# axes with the first.
ax_spec = None
for position, title, sig in zip((223, 224), labels, (x, x_hat)):
  sig_db = amplitude_to_db(np.abs(stft(sig)), ref=np.max)
  ax_spec = plt.subplot(
    position,
    sharex=ax_wave if ax_spec is None else ax_spec,
    sharey=ax_spec,
  )
  specshow(sig_db, ax=ax_spec, sr=fs, x_axis="time", y_axis="hz")
  ax_spec.set_title(title)
ax_spec.set_ylim([0, 20000])

resize(12, 12)

In [None]:
import skopt.plots
import itertools

# Partial-dependence plots of the objective over the tuned dimensions
axs = skopt.plots.plot_objective(opt_res, levels=16, show_points=False)

# Add a grid to the panels labelled "Partial dependence" and drop that
# repeated y label
for ax_row in axs:
  for ax in ax_row:
    if ax.get_ylabel() != "Partial dependence":
      continue
    ax.grid(True)
    ax.set_ylabel("")
# Uncomment to export the figure:
# plt.savefig("partial_dependence_sample.svg");