In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
from functools import partial

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import exciting_environments as excenvs

from exciting_exciting_systems.utils.density_estimation import build_grid_2d
from exciting_exciting_systems.models.model_utils import simulate_ahead_with_env
from exciting_exciting_systems.utils.signals import aprbs
from exciting_exciting_systems.evaluation.plotting_utils import plot_sequence

---

In [None]:
# PRNG setup: derive independent keys from one root seed (seed 21 was used previously).
# `opt_key` is split off for the optimizer but not used in the cells shown here.
data_key, opt_key, key = jax.random.split(jax.random.PRNGKey(seed=3333), 3)

# Stateful key stream: every `next(data_rng)` yields a fresh sub-key for data generation.
data_rng = PRNGSequence(data_key)

In [None]:
batch_size = 1
tau = 2e-2  # 5e-2 was tried previously; tau is also passed to plot_sequence below

env_kwargs = {"env_id": "Pendulum-v0", "batch_size": batch_size, "tau": tau}
env = excenvs.make(**env_kwargs)

In [None]:
# Initial observation/state, both cast to float32.
obs, state = (arr.astype(jnp.float32) for arr in env.reset())
n_steps = 999

# Random APRBS excitation signal drawn with a fresh key from the data stream.
actions = aprbs(n_steps, batch_size, 1, 10, next(data_rng))

In [None]:
# Vectorize the simulation over the batch axis; `env` itself is shared (in_axes=None).
batched_simulate = jax.vmap(simulate_ahead_with_env, in_axes=(None, 0, 0, 0, 0, 0, 0))
observations = batched_simulate(
    env, obs, state, actions,
    env.env_state_normalizer, env.action_normalizer, env.static_params,
)

print("actions.shape:", actions.shape)
print("observations.shape:", observations.shape)

print(" \n One of the trajectories:")
first_obs, first_actions = observations[0, ...], actions[0, ...]
fig, axs = plot_sequence(
    observations=first_obs,
    actions=first_actions,
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
)
plt.show()

---

In [None]:
import optuna

In [None]:
@jax.jit
def dummy_fitness_func(x):
    """Rosenbrock test function with a=2, b=100.

    f(x) = (a - x0)^2 + b * (x1 - x0^2)^2, whose global minimum 0 lies at
    (a, a^2) = (2, 4). Used as a toy objective for the optimizer cells below.
    """
    a, b = 2, 100
    valley_term = (a - x[0]) ** 2
    curvature_term = b * (x[1] - x[0] ** 2) ** 2
    return valley_term + curvature_term


def objective(trial):
    """Optuna objective: sample (x0, x1) from [-10, 10]^2 and evaluate the toy fitness.

    Args:
        trial: Optuna trial used to suggest the two float parameters.

    Returns:
        The fitness value as a plain Python float. Optuna's objective contract is a
        float, so the 0-d JAX array from the jitted function is converted here.
    """
    x0 = trial.suggest_float('x0', -10, 10)
    x1 = trial.suggest_float('x1', -10, 10)
    return float(dummy_fitness_func(jnp.array([x0, x1])))

In [None]:
# CMA-ES via Optuna on the toy objective; show the best parameters found.
sampler = optuna.samplers.CmaEsSampler(popsize=50)
study = optuna.create_study(sampler=sampler)
study.optimize(func=objective, n_trials=500, show_progress_bar=True)

study.best_params

In [None]:
from cmaes import CMA, CMAwM

In [None]:
# Plain CMA-ES (cmaes ask/tell interface) on the toy Rosenbrock objective.
optimizer = CMA(mean=np.zeros(2), sigma=1.3, population_size=20)

for generation in range(100):
    solutions = []
    for _ in range(optimizer.population_size):
        x = optimizer.ask()
        # `tell` expects (candidate, float) pairs; convert the 0-d JAX array returned
        # by the jitted fitness so no device arrays leak into the optimizer state.
        value = float(dummy_fitness_func(x))
        solutions.append((x, value))
    optimizer.tell(solutions)

In [None]:
# NOTE: `x` is the last candidate sampled in the final generation of the loop above,
# not the best solution; the best candidate of that generation is found via np.argmin below.
print(x)

In [None]:
# Split the final generation's (candidate, fitness) pairs into two stacked arrays.
xs = np.stack([candidate for candidate, _ in solutions])
values = np.stack([fitness for _, fitness in solutions])

In [None]:
# Best candidate within the final generation only.
min_idx = values.argmin()

for label, item in (("min location:", xs[min_idx]), ("min value:", values[min_idx])):
    print(label, item)

---

In [None]:
# from exciting_exciting_systems.utils.metrics import audze_eglais

In [None]:
def generate_aprbs(amplitudes, durations):
    """Build an APRBS action signal from its segment parameters.

    Segment k holds the value ``amplitudes[k]`` for ``durations[k]`` consecutive
    steps; the segments are concatenated in order.

    This is plain NumPy on purpose: the output length is ``sum(durations)``, i.e. it
    depends on the parameter values, so it cannot be traced under ``jax.jit``.

    Args:
        amplitudes: 1-D sequence of segment amplitudes.
        durations: 1-D sequence of non-negative integer segment lengths, same length
            as ``amplitudes``.

    Returns:
        1-D float64 NumPy array of length ``sum(durations)`` (empty for empty input).

    Raises:
        ValueError: If ``amplitudes`` is not 1-D, the two inputs differ in length,
            or a duration is negative.
        TypeError: If ``durations`` does not have an integer dtype.
    """
    amplitudes = np.asarray(amplitudes, dtype=np.float64)
    durations = np.asarray(durations)

    # Explicit check instead of zip(), which would silently drop unmatched segments.
    if amplitudes.ndim != 1 or amplitudes.shape != durations.shape:
        raise ValueError(
            f"amplitudes and durations must be 1-D of equal length, "
            f"got shapes {amplitudes.shape} and {durations.shape}"
        )
    if durations.size == 0:
        return np.zeros(0)
    if not np.issubdtype(durations.dtype, np.integer):
        raise TypeError(f"durations must be integers, got dtype {durations.dtype}")
    if np.any(durations < 0):
        raise ValueError("durations must be non-negative")

    return np.repeat(amplitudes, durations)

In [None]:
h = 20

# First h entries: amplitudes drawn uniformly from [-1, 1).
amplitude_params = jax.random.uniform(next(data_rng), shape=(h,), minval=-1, maxval=1)
# Last h entries: integer durations from randint(2, 50), i.e. 2..49 (maxval is exclusive).
duration_params = jax.random.randint(next(data_rng), shape=(h,), minval=2, maxval=50)
# Concatenation promotes the integer durations to the float dtype of the amplitudes.
action_parameters = jnp.concatenate([amplitude_params, duration_params])

actions = generate_aprbs(
    amplitudes=action_parameters[:h],
    durations=action_parameters[h:].astype(jnp.int32),
)[None, :, None]

plt.plot(jnp.squeeze(actions))

In [None]:
def soft_penalty(a, a_max=1):
    """Total amount by which entries of ``a`` exceed the symmetric bound [-a_max, a_max].

    Each entry contributes max(|a| - a_max, 0). The contributions are summed over
    the last two axes and the result is squeezed.

    Implemented in pure NumPy so that calls from the NumPy-based fitness pipeline
    neither dispatch to JAX nor silently downcast float64 input to float32.

    Args:
        a: Array with at least two dimensions, e.g. (batch, time, features).
        a_max: Bound applied to the absolute value of every entry.

    Returns:
        NumPy array of the per-leading-index penalties, squeezed (0-d for a
        single leading element).
    """
    excess = np.maximum(np.abs(np.asarray(a)) - a_max, 0)
    return np.squeeze(np.sum(excess, axis=(-2, -1)))


def featurize_theta(obs):
    """Replace the angle feature with its sine and cosine.

    Raw angles compare poorly in a distance-based loss (e.g. 1.99 * pi and 0 are
    essentially the same angle), so the angle in ``obs[..., 0]`` is mapped to
    (sin, cos) while ``obs[..., 1]`` is passed through unchanged.

    Assumes ``obs[..., 0]`` is the angle scaled so that multiplying by pi gives
    radians (i.e. a normalized angle) -- TODO confirm against the env normalizer.

    Returns:
        Array with the same leading shape as ``obs`` and last axis
        (sin(angle), cos(angle), obs[..., 1]).
    """
    angle = obs[..., 0] * np.pi
    second_feature = obs[..., 1]
    return np.stack((np.sin(angle), np.cos(angle), second_feature), axis=-1)


def audze_eglais(data_points: np.ndarray) -> np.floating:
    """Audze-Eglais space-filling criterion (from [Smits+Nelles2024]).

    Computes the mean of 1 / d_ij**2 over all unordered pairs of points (i < j),
    i.e. 2 / (N * (N - 1)) * sum_{i<j} 1 / d_ij**2. Point sets that contain points
    close to each other get a large value, so lower is better. Coincident points
    give d_ij = 0 and therefore an infinite score (NumPy warns about the division).

    Args:
        data_points: Array of shape (N, D) holding N points in D dimensions.

    Returns:
        The criterion value as a NumPy floating scalar.

    Raises:
        ValueError: If fewer than two points are given (no pairs exist).
    """
    points = np.asarray(data_points)
    n_points = points.shape[0]
    if n_points < 2:
        raise ValueError(f"audze_eglais needs at least 2 points, got {n_points}")

    # Only the pairs i < j are needed: compute those differences directly instead
    # of building the full N x N distance matrix and discarding half of it.
    i, j = np.triu_indices(n_points, k=1)
    distances = np.linalg.norm(points[i] - points[j], axis=-1)
    return np.mean(1.0 / distances**2)

In [None]:
def fitness_function(
    env,
    obs,
    state,
    action_parameters,
    h,
    max_duration=50,
    rho_obs=1e8,
    rho_act=1e8,
):
    """Cost of an APRBS parameter vector (lower is better), used as the CMA-ES objective.

    The parameters are turned into an action signal with ``generate_aprbs``. The
    environment is simulated under that signal, and the trajectory is scored with
    the Audze-Eglais criterion on the featurized observations. Heavily weighted
    soft penalties are added for observations/actions leaving [-1, 1].

    Args:
        env: Environment passed to ``simulate_ahead_with_env``; its
            ``env_state_normalizer``, ``action_normalizer`` and ``static_params``
            are used.
        obs: Initial observation with a leading batch axis (the action signal is
            built with a batch size of 1).
        state: Initial environment state with a leading batch axis.
        action_parameters: 1-D array of length 2*h. The first ``h`` entries are
            segment amplitudes; the last ``h`` are segment durations (cast to int32).
        h: Number of APRBS segments.
        max_duration: Largest allowed segment duration. The signal is zero-padded
            to ``h * max_duration`` steps.
        rho_obs: Weight of the observation soft penalty.
        rho_act: Weight of the action soft penalty.

    Returns:
        float: Audze-Eglais score plus the weighted penalty terms.

    Raises:
        ValueError: If the generated signal is longer than ``h * max_duration``.
    """
    actions = generate_aprbs(
        amplitudes=action_parameters[:h],
        durations=action_parameters[h:].astype(np.int32),
    )[None, :, None]

    # Pad to a fixed length so the simulation always receives the same input shape
    # (presumably to avoid JAX re-tracing for every distinct signal length -- verify).
    n_pad = h * max_duration - actions.shape[1]
    if n_pad < 0:
        raise ValueError(
            f"signal has {actions.shape[1]} steps, more than h * max_duration = {h * max_duration}"
        )
    padded_actions = jnp.concatenate([actions, jnp.zeros((1, n_pad, 1))], axis=1)

    padded_observations = jax.vmap(simulate_ahead_with_env, in_axes=(None, 0, 0, 0, 0, 0, 0))(
        env,
        obs,
        state,
        padded_actions,
        env.env_state_normalizer,
        env.action_normalizer,
        env.static_params,
    )
    padded_observations = np.array(padded_observations)

    # Drop the entries produced by the zero padding. Use an explicit end index:
    # slicing with `[:-n_pad]` yields an EMPTY trajectory when n_pad == 0.
    obs_end = padded_observations.shape[1] - n_pad
    observations = padded_observations[:, :obs_end, :]
    feat_observations = featurize_theta(observations[0])

    ae_score = audze_eglais(data_points=feat_observations)
    penalty_terms = (
        rho_obs * soft_penalty(a=observations, a_max=1)
        + rho_act * soft_penalty(a=actions, a_max=1)
    )
    return float(np.squeeze(ae_score)) + float(penalty_terms)

In [None]:
# Search-space layout: h continuous amplitudes followed by h integer durations.
h = 20
continuous_dim = discrete_dim = h

population_size = 21
n_generations = 100

bounds = np.vstack(
    (
        np.repeat([[-1, 1]], continuous_dim, axis=0),  # amplitude range
        np.repeat([[1, 50]], discrete_dim, axis=0),    # duration range
    )
)

# Step size 0 for the continuous (amplitude) dims, 1 for the discrete (duration) dims.
steps = np.r_[np.zeros(continuous_dim), np.ones(discrete_dim)]

In [None]:
# Mixed continuous/integer CMA-ES: amplitudes start at 0, durations at 25.
optimizer = CMAwM(
    mean=np.hstack([np.zeros(h), np.ones(h) * 25]),
    sigma=2.0,
    population_size=population_size,
    bounds=bounds,
    steps=steps,
)

# Fixed initial condition shared by every fitness evaluation.
obs, env_state = env.reset()
obs = obs.astype(jnp.float32)
# Cast the state returned by THIS reset, not a leftover `state` from an earlier cell.
env_state = env_state.astype(jnp.float32)

for generation in tqdm(range(n_generations)):
    solutions = []
    for _ in range(optimizer.population_size):
        # `x_for_eval` is what gets evaluated; `x_for_tell` is what is reported back to the optimizer.
        x_for_eval, x_for_tell = optimizer.ask()
        value = fitness_function(env, obs, env_state, x_for_eval, h)
        solutions.append((x_for_tell, value))
    optimizer.tell(solutions)

In [None]:
# Split the final generation's (x_for_tell, fitness) pairs into two stacked arrays.
xs = np.stack([candidate for candidate, _ in solutions])
values = np.stack([fitness for _, fitness in solutions])

In [None]:
# Best candidate within the final generation only.
min_idx = values.argmin()

for label, item in (("min location:", xs[min_idx]), ("min value:", values[min_idx])):
    print(label, item)

In [None]:
# NOTE(review): this asks the optimizer for a NEW candidate from its final search
# distribution; it is not the best solution found during the loop. That one is only
# stored as its `x_for_tell` value in `xs[min_idx]`, and only for the last generation.
x_for_eval, _ = optimizer.ask()

In [None]:
# Turn the candidate into an action signal of shape (batch=1, time, 1).
actions = generate_aprbs(
    amplitudes=x_for_eval[:h],
    durations=x_for_eval[h:].astype(np.int32),
)[None, :, None]

# Fresh initial condition, both parts cast to float32.
obs, env_state = (arr.astype(jnp.float32) for arr in env.reset())

batched_simulate = jax.vmap(simulate_ahead_with_env, in_axes=(None, 0, 0, 0, 0, 0, 0))
observations = batched_simulate(
    env, obs, env_state, actions,
    env.env_state_normalizer, env.action_normalizer, env.static_params,
)

first_obs, first_actions = observations[0, ...], actions[0, ...]
fig, axs = plot_sequence(
    observations=first_obs,
    actions=first_actions,
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
)
plt.show()

- Good enough for this week.
- I suspect JIT compilation accounts for a large share of the runtime.
- **TODO: test without any jitting.** JAX without jitting is very slow — what about pure NumPy?