In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell execution so edits to local source files are picked up live.
%load_ext autoreload
%autoreload 2

In [27]:
import functools
import jax
import jax.numpy as jnp
import flax
import soundfile
import audax.core.functional
from tqdm.auto import tqdm
from evosax import DES
from synthax.synth import Voice
from synthax.config import SynthConfig
from IPython.display import Audio

In [43]:
# Read the reference recording together with its sample rate.
target_audio, sr = soundfile.read("bird.wav")

# Keep at most the first 132300 entries along the leading axis and promote the
# result to at least 2-D so it can broadcast against batched synth output.
# NOTE(review): assumes a mono file (1-D samples) — TODO confirm; for a
# multi-channel file the slice would apply to frames, not channels.
target_audio = jnp.asarray(target_audio)
target_audio = jnp.atleast_2d(target_audio[:132300])

In [58]:
# Search budget: candidates evaluated per generation and number of generations.
pop_size = 100
n_iter = 1000

# Single PRNG key used to construct the synth here; later cells reuse it.
prng_key = jax.random.PRNGKey(0)

# The synth is configured with batch_size == pop_size so one call renders the
# whole population.
# NOTE(review): the meaning of the positional `False` argument is not visible
# in this notebook — confirm against Voice's signature.
synth_cfg = SynthConfig(batch_size=pop_size)
synth = Voice(prng_key, synth_cfg, False)

In [59]:
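# NOTE(review): disabled spectrogram setup kept for reference; `spec_func` is
# not referenced by any other cell in this notebook.
# NFFT = 512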
# WIN_LEN = 400
# HOP_LEN = 160
# SR = sr

# window = jnp.hanning(WIN_LEN)
# spec_func = jax.jit(
#     functools.partial(
#         audax.core.functional.spectrogram,
#         pad=0,
#         window=window,
#         n_fft=NFFT,
#         hop_length=HOP_LEN,
#         win_length=WIN_LEN,
#         power=2.,
#         normalized=False,
#         center=True,
#         onesided=True
#     )
# )

In [71]:
@jax.jit
def mse(y_true, y_pred):
    """Mean squared error between a reference and predictions, per row.

    Args:
        y_true: Reference array; it is broadcast to the shape of ``y_pred``.
        y_pred: Predicted array that defines the output batch shape.

    Returns:
        Array with the last axis of ``y_pred`` reduced away, holding the mean
        of the squared differences along that axis.
    """
    reference = jnp.broadcast_to(y_true, y_pred.shape)
    squared_error = jnp.square(reference - y_pred)
    return jnp.mean(squared_error, axis=-1)

In [72]:
# JIT-compile the synth's apply function once; it is called every generation.
make_sound = jax.jit(synth.apply)

# Initial parameter tree for the synth, produced from the shared PRNG key.
params = synth.init(prng_key)

In [73]:
# Flatten the nested parameter tree into a single-level dict so its leaves can
# be stacked into one matrix and later re-associated with their keys.
d = flax.traverse_util.flatten_dict(params)
keys, values = d.keys(), d.values()

# Stack the leaves along the first axis; the leading dimension of the result is
# what the evolution strategy uses as its number of search dimensions.
# NOTE(review): assumes every leaf contributes exactly one row — TODO confirm.
init_params = jnp.vstack([*values])

In [74]:
# Build the evosax DES strategy with one search dimension per row of
# `init_params` and one population member per synth batch element.
strategy = DES(num_dims=len(init_params), popsize=pop_size)

# Use the strategy's default hyper-parameters and create its initial state.
es_params = strategy.default_params
state = strategy.initialize(prng_key, es_params)

In [None]:
# Ask/evaluate/tell loop: each generation samples a population of synth
# parameters, renders them, scores them against the target with `mse`, and
# feeds the scores back to the strategy.
best_params = None

pbar = tqdm(range(n_iter))
for step in pbar:
    # Split off a fresh subkey so every generation samples with new randomness.
    prng_key, ask_key = jax.random.split(prng_key)

    # Sample candidates and clamp them to [0, 1] before rendering.
    synth_params, state = strategy.ask(ask_key, state, es_params)
    synth_params = jnp.clip(synth_params, 0, 1)

    # Transposing makes zip() pair each flattened key with one row that holds
    # that parameter's value for every population member.
    fparams = {key: column for key, column in zip(keys, synth_params.T)}
    updated_params = flax.traverse_util.unflatten_dict(fparams)

    # Render the whole population in one batched call and score it.
    audio = make_sound(updated_params)
    fitness = mse(target_audio, audio)

    # Report the (clipped) candidates and their fitness back to the strategy.
    state = strategy.tell(synth_params, fitness, state, es_params)

    # Track the best solution seen so far and show it on the progress bar.
    best_params = state.best_member
    best_f = state.best_fitness
    pbar.set_postfix({"best": best_f})


  0%|          | 0/1000 [00:00<?, ?it/s][A
  0%|          | 0/1000 [00:11<?, ?it/s, best=3.4028235e+38][A
  0%|          | 1/1000 [00:11<3:15:16, 11.73s/it, best=3.4028235e+38][A
  0%|          | 1/1000 [00:11<3:15:16, 11.73s/it, best=3.4028235e+38][A
  0%|          | 1/1000 [00:11<3:15:16, 11.73s/it, best=3.4028235e+38][A
  0%|          | 1/1000 [00:11<3:15:16, 11.73s/it, best=3.4028235e+38][A
  0%|          | 1/1000 [00:11<3:15:16, 11.73s/it, best=0.013687441]  [A
  0%|          | 1/1000 [00:11<3:15:16, 11.73s/it, best=0.013686053][A
  0%|          | 1/1000 [00:11<3:15:16, 11.73s/it, best=0.013685351][A
  0%|          | 1/1000 [00:11<3:15:16, 11.73s/it, best=0.013683513][A
  0%|          | 1/1000 [00:11<3:15:16, 11.73s/it, best=0.013683513][A
  1%|          | 9/1000 [00:11<15:47,  1.05it/s, best=0.013683513]  [A
  1%|          | 9/1000 [00:11<15:47,  1.05it/s, best=0.013683513][A
  1%|          | 9/1000 [00:11<15:47,  1.05it/s, best=0.013683513][A
  1%|          | 9/10

In [None]:
# Re-render the best candidate found by the search with a single-voice synth
# (batch_size=1) and play the result at the target file's sample rate.
synth_cfg1 = SynthConfig(batch_size=1)
synth1 = Voice(prng_key, synth_cfg1, False)
apply1 = jax.jit(synth1.apply)

# Clamp to [0, 1] and turn the best member into a column so that zip() yields
# one length-1 row per flattened parameter key.
synth_params1 = jnp.clip(jnp.atleast_2d(best_params), 0, 1).T
fparams1 = {key: row for key, row in zip(keys, synth_params1)}
updated_params1 = flax.traverse_util.unflatten_dict(fparams1)

audio1 = apply1(updated_params1)
Audio(audio1[0], rate=sr)

In [None]:
# Play the first element of the most recently rendered batch. Use the sample
# rate read from the target file instead of a hard-coded 48000: the rendered
# audio is compared sample-by-sample with the target, so a different playback
# rate would shift pitch and duration.
# NOTE(review): assumes the synth renders at the file's rate `sr` — confirm
# against the synth configuration.
Audio(audio[0], rate=sr)

In [None]:
%timeit -n 1000 apply(params)

In [None]:

# NOTE(review): `all_params` and `apply` are not defined in any cell visible in
# this notebook, so this cell presumably depends on stale kernel state and will
# fail on a fresh run — TODO confirm the intended sources (possibly
# `make_sound` in place of `apply`).
# Pair each flattened parameter key with an entry of `all_params`, rebuild the
# nested parameter dict, and render audio from it. This overwrites `fparams`,
# `updated_params` and `audio` from the optimisation loop.
fparams = dict(zip(keys, all_params))
updated_params = flax.traverse_util.unflatten_dict(fparams)
audio = apply(updated_params)

In [None]:
# Play the first element of the current `audio` batch at the sample rate read
# from the target file rather than a hard-coded 48000, so playback pitch and
# duration match the reference recording.
# NOTE(review): assumes the synth renders at the file's rate `sr` — confirm
# against the synth configuration.
Audio(audio[0], rate=sr)