In [22]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [23]:
import functools
import jax
import chex
import jax.numpy as jnp
import flax
import soundfile
from tqdm.auto import tqdm
from jax.scipy.signal import stft
from evosax import DES, ParameterReshaper
from synthax.synth import Voice
from synthax.config import SynthConfig
from synthax.parameter import ModuleParameter
from IPython.display import Audio

In [24]:
# Load the target recording and its sample rate, then make sure the array has a
# leading axis so it can later be broadcast across the population.
# NOTE(review): assumes bird.wav is mono, so atleast_2d yields (1, n_samples);
# a multi-channel file would come back as (n_samples, channels) — confirm.
target_audio, sr = soundfile.read("bird.wav")
target_audio = jnp.atleast_2d(jnp.asarray(target_audio))

In [25]:
# Search configuration.
n_iter = 1000  # number of evolution-strategy steps passed to run_es_loop
pop_size = 100  # ES population size; also used as the synth batch size
prng_key = jax.random.PRNGKey(0)  # fixed seed for reproducible runs

In [26]:
# Batched synth: one voice per population member, rendered at the target
# file's sample rate.
synth_cfg = SynthConfig(batch_size=pop_size, sample_rate=sr)
# NOTE(review): the meaning of the third positional argument (False) is not
# visible here — check the Voice signature and consider passing it by keyword.
synth = Voice(
    prng_key,
    synth_cfg,
    False
)

In [27]:
class InstanceFinder(flax.traverse_util.Traversal):
    """Accumulates every node passed to `node` that matches a given type.

    Matching nodes are appended to `self.instances` in the order they are
    seen; nodes are always returned unchanged.
    """

    def __init__(self, target_type):
        # Type (or tuple of types) forwarded to isinstance().
        self.instances = []
        self.target_type = target_type

    def node(self, node):
        """Record `node` if it is a `target_type` instance, then hand it back."""
        is_match = isinstance(node, self.target_type)
        if is_match:
            self.instances.append(node)
        return node

In [28]:
from synthax.io import write_synthspec, read_synthspec

In [29]:
# def spectrogram(signals, n_fft, win_length, hop_length, window_fn, sample_rate):
#     freqs, times, spec = stft(
#         signals,
#         fs=sample_rate,
#         window=window_fn,
#         nperseg=win_length,
#         noverlap=hop_length,
#         nfft=n_fft,
#         boundary=None,
#         return_onesided=True,
#         axis=-1
#     )
    
#     magnitude_spec = jnp.abs(spec)
#     return magnitude_spec

# spec_func = jax.jit(
#     functools.partial(
#         spectrogram,
#         n_fft=512,
#         win_length=400,
#         hop_length=160,
#         window_fn="hann",
#         sample_rate=sr
#     )
# )


In [30]:
@jax.jit
def mse(y_true, y_pred):
    """Mean squared error reduced over every axis except the first.

    Args:
        y_true: Reference array; its rank decides which axes are reduced.
        y_pred: Array broadcast-compatible with `y_true`.

    Returns:
        Array with one error value per entry along axis 0.
    """
    # Pass the axes as a tuple: the documented contract for `axis` is an int or
    # a tuple of ints (tuples are also hashable, unlike lists).
    reduce_axes = tuple(range(1, y_true.ndim))
    return jnp.mean(jnp.square(y_true - y_pred), axis=reduce_axes)

In [31]:
# Initialise the synth's parameter tree and build a jitted render function.
# NOTE(review): the saved output of this cell is a TypeError raised while
# tracing synthax's `fix_length` under jit (its `length` argument is traced
# rather than static), so `params` is never defined on this run — the fix
# belongs in synthax/functional.py, not in this notebook.
params = synth.init(prng_key)
make_sound = jax.jit(synth.apply)

TypeError: Shapes must be 1D sequences of concrete values of integer type, got (100, Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function fix_length at /home/gridsan/nsingh1/PapayaResearch/synthax/synthax/functional.py:42 for jit. This concrete value was not available in Python because it depends on the value of the argument length.

In [21]:
# Flatten the nested parameter tree into {path: value}; `keys` is reused later
# to rebuild the tree from a flat parameter matrix.
d = flax.traverse_util.flatten_dict(params)
keys = d.keys()
values = d.values()
# One row per flattened parameter.
# NOTE(review): assumes every leaf has the same shape (presumably (pop_size,)) — confirm.
init_params = jnp.vstack(list(values))

NameError: name 'params' is not defined

In [None]:
# Evolution strategy over the flat parameter vector: one search dimension per
# row of init_params, one candidate per population member.
strategy = DES(popsize=pop_size, num_dims=init_params.shape[0])
es_params = strategy.default_params
state = strategy.initialize(prng_key, es_params)

In [None]:
# target_spec = spec_func(target_audio)
# target_spec = jnp.broadcast_to(target_spec, (pop_size, *target_spec.shape[1:]))

In [None]:
# Scaffolding for the Python-level ES loop that is commented out further down
# in this cell.
best_params = None

# NOTE(review): with the loop commented out this progress bar is created but
# never advanced.
pbar = tqdm(range(n_iter))

# Jitted ask/tell steps of the strategy, used only by the commented-out loop.
ask = jax.jit(strategy.ask)
tell = jax.jit(strategy.tell)


@jax.jit
def make_params(synth_params):
    """Turn a flat candidate matrix back into the synth's nested parameter dict.

    Values are clipped to [0, 1], the matrix is transposed, and each resulting
    row is paired with the corresponding flattened-parameter path in the
    module-level `keys`.
    """
    clipped = jnp.clip(synth_params, 0, 1)
    flat = {path: column for path, column in zip(keys, clipped.T)}
    return flax.traverse_util.unflatten_dict(flat)

# for i in pbar:
#     prng_key, rng_gen = jax.random.split(prng_key, 2)
#     synth_params, state = ask(rng_gen, state, es_params)
#     updated_params = make_params(synth_params)
#     audio = make_sound(updated_params)
#     # spec = spec_func(audio)
    
#     f = mse(target_audio, audio)
#     state = tell(synth_params, f, state, es_params)
    
#     best_f = state.best_fitness
#     best_params = state.best_member
#     pbar.set_postfix({"best": best_f})

In [None]:
def make_batched(x, batch_size):
    """Broadcast `x` so that its leading axis has length `batch_size`.

    All axes after the first keep their original size.
    """
    trailing_dims = tuple(x.shape[1:])
    batched_shape = (batch_size,) + trailing_dims
    return jnp.broadcast_to(x, batched_shape)

In [None]:
# Target audio repeated along the leading axis so it lines up with the
# pop_size candidates rendered at each ES step.
target_audiob = make_batched(target_audio, pop_size)

@functools.partial(jax.jit, static_argnums=(1,))
def run_es_loop(rng, num_steps):
    """Run the full evolution-strategy search inside a single jitted scan.

    Each step asks the strategy for a population, clips it to [0, 1], renders
    audio for every candidate, scores it with `mse` against `target_audiob`,
    and tells the strategy the resulting fitness values.

    Args:
        rng: JAX PRNG key used to initialise the strategy and drive sampling.
        num_steps: Number of ES steps; static, so each new value recompiles.

    Returns:
        `(best_fitness, best_member)` taken from the step whose recorded
        `state.best_fitness` is lowest.
    """
    es_params = strategy.default_params
    state = strategy.initialize(rng, es_params)

    def es_step(carry, _):
        rng, state = carry
        rng, rng_iter = jax.random.split(rng)
        synth_params, state = strategy.ask(rng_iter, state, es_params)
        synth_params = jnp.clip(synth_params, 0, 1)

        # Rebuild the nested parameter dict: one row per flattened parameter,
        # one column per population member.
        fparams = dict(zip(keys, synth_params.T))
        updated_params = flax.traverse_util.unflatten_dict(fparams)
        audio = make_sound(updated_params)

        f = mse(target_audiob, audio)

        state = strategy.tell(synth_params, f, state, es_params)
        # Per-step record layout: [best_fitness, *best_member].
        return (rng, state), jnp.hstack([state.best_fitness, state.best_member])

    # No per-step inputs are needed, so scan over a length instead of a dummy array.
    _, scan_f = jax.lax.scan(es_step, (rng, state), xs=None, length=num_steps)

    # Column 0 holds the fitness; the remaining columns hold the parameters.
    # (Reading column 1 here would select on the first parameter value instead.)
    f = scan_f[:, 0]
    v = scan_f[:, 1:]
    i = jnp.argmin(f)
    return f[i], v[i]

In [None]:
# %timeit run_es_loop(prng_key, n_iter)

In [None]:
# Run the jitted search for n_iter steps and keep the best fitness and the
# matching flat parameter vector.
best_f, best_params = run_es_loop(prng_key, n_iter)

In [None]:
# Re-render the best parameter vector with a single-voice synth.
# The sample rate must match the synth the parameters were optimised on (and
# the rate used for playback); otherwise the render uses SynthConfig's default.
synth_cfg1 = SynthConfig(batch_size=1, sample_rate=sr)
synth1 = Voice(
    prng_key,
    synth_cfg1,
    False
)

# Column layout: one row per flattened parameter, a single batch entry each,
# clipped to the same [0, 1] range used during the search.
synth_params1 = jnp.clip(jnp.atleast_2d(best_params).T, 0, 1)
fparams1 = dict(zip(keys, synth_params1))
updated_params1 = flax.traverse_util.unflatten_dict(fparams1)
apply1 = jax.jit(synth1.apply)

audio1 = apply1(updated_params1)

In [None]:
# Play the first rendered item, scaled by its peak absolute value, at the
# target file's sample rate.
Audio(audio1[0] / jnp.abs(audio1[0]).max(), rate=sr)

In [None]:
# Show the flat parameter vector returned by run_es_loop (before clipping).
best_params