In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
from functools import partial

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp

jax.config.update('jax_platform_name', 'cpu')

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
from cmaes import CMAwM
from scipy.stats.qmc import LatinHypercube

import exciting_environments as excenvs

import exciting_exciting_systems
# from exciting_exciting_systems.models.model_utils import simulate_ahead_with_env

from exciting_exciting_systems.evaluation.plotting_utils import plot_sequence

from exciting_exciting_systems.related_work.np_reimpl.env_utils import simulate_ahead_with_env
from exciting_exciting_systems.related_work.excitation_utils import generate_aprbs, soft_penalty
from exciting_exciting_systems.related_work.np_reimpl.pendulum import Pendulum
from exciting_exciting_systems.related_work.np_reimpl.metrics import MC_uniform_sampling_distribution_approximation

---

In [None]:
# Step size shared by the environment and by every plot_sequence call below.
tau = 2e-2
# Number of parallel environments.
batch_size = 1

env = Pendulum(tau=tau, batch_size=batch_size, max_torque=5)

In [None]:
# Number of amplitude levels (and matching hold durations) of the example excitation.
h = 100
ACTION_SEED = 0

# Seeded generator so the example excitation is identical on every Restart & Run All.
action_rng = np.random.default_rng(ACTION_SEED)

# First h entries: amplitudes drawn from [-1, 1).
# Last h entries: integer hold durations drawn from [2, 50).
action_parameters = np.concatenate([
    action_rng.uniform(low=-1, high=1, size=(h,)).astype(np.float32),
    action_rng.integers(low=2, high=50, size=(h,), dtype=np.int32),
])

# [None, :, None] adds a leading and a trailing axis
# (presumably batch and action dimension; confirm against generate_aprbs).
actions = generate_aprbs(
    amplitudes=action_parameters[:h],
    durations=np.abs(action_parameters[h:].astype(np.int32)),
)[None, :, None]

fig, ax = plt.subplots()
ax.plot(np.squeeze(actions))
ax.set(title="Random APRBS excitation", xlabel="step", ylabel=r"$u$")
plt.show()

In [None]:
# Reset the environment and simulate the example excitation open-loop.
obs, env_state = env.reset()
# env is the NumPy re-implementation (np_reimpl.pendulum), so cast both arrays
# with the same NumPy dtype instead of mixing np.float32 and jnp.float32.
obs = obs.astype(np.float32)
env_state = env_state.astype(np.float32)

observations, _ = simulate_ahead_with_env(
    env,
    obs,
    env_state,
    actions,
)

print("actions.shape:", actions.shape)
print("observations.shape:", observations.shape)

print(" \n One of the trajectories:")
# Plot the first entry along the leading (batch) axis.
fig, axs = plot_sequence(
    observations=observations[0, ...],
    actions=actions[0, ...],
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
)
plt.show()

In [None]:
n_amplitude_levels = 1000
AMPLITUDE_SEED = 0

# Latin hypercube sampling in 1-D splits [0, 1) into n equal strata and draws one
# sample per stratum, so the amplitude levels cover the range evenly.
# The sampler is seeded so the levels are reproducible.
amplitude_sampler = LatinHypercube(d=1, seed=AMPLITUDE_SEED)

# Map the unit-interval samples to [-1, 1); shape (n_amplitude_levels, 1).
amplitudes = amplitude_sampler.random(n=n_amplitude_levels) * 2 - 1

In [None]:
# Visualize the sampled amplitude levels in sampling order.
fig, ax = plt.subplots()
ax.plot(amplitudes)
ax.set(title="Latin hypercube amplitude levels", xlabel="sample index", ylabel="amplitude")
plt.show()

## GOATs

Open questions:
- What is the influence of the Lehmer coding on the result of the algorithm?
- Which genetic algorithm is best suited here?

In [None]:
from exciting_exciting_systems.related_work.algorithms import excite_with_GOATs

In [None]:
def featurize_theta(obs):
    """Replace the angle feature by its sine and cosine for use in the loss.

    Raw angles are hard to compare directly, because values such as 1.99 * pi
    and 0 describe almost the same position. Encoding the angle as
    (sin(theta), cos(theta)) removes that discontinuity.

    Args:
        obs: array whose last axis holds [theta / pi, omega]. The first entry is
            in units of pi, since it is multiplied by pi before sin/cos.

    Returns:
        Array with the same leading axes and last axis
        [sin(theta), cos(theta), omega].
    """
    theta = obs[..., 0] * np.pi
    omega = obs[..., 1]
    return np.stack((np.sin(theta), np.cos(theta), omega), axis=-1)

In [None]:
# Run GOATs once over 100 amplitude levels on the pendulum, comparing states in
# the sin/cos-featurized space.
observations, actions = excite_with_GOATs(
    env=env,
    featurize=featurize_theta,
    # problem size
    n_amplitudes=100,
    bounds_duration=(1, 50),
    n_support_points=1600,
    # optimizer budget
    population_size=20,
    n_generations=10,
    # reproducibility / logging
    seed=0,
    verbose=True,
)

In [None]:
# Plot the first entry along the leading axis of the GOATs result.
fig, axs = plot_sequence(
    tau=tau,
    observations=observations[0, ...],
    actions=actions[0, ...],
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
)
plt.show()

## sGOATs

- Essentially repeats GOATs on subsets of the data.

In [None]:
from exciting_exciting_systems.related_work.algorithms import excite_with_sGOATs

#### Without reuse
Perform `n_amplitude_groups` independent optimizations.

In [None]:
# sGOATs without reuse: n_amplitude_groups independent optimizations.
# The (presumably per-group) results are returned in the two lists.
all_observations, all_actions = [], []

all_observations, all_actions = excite_with_sGOATs(
    env=env,
    featurize=featurize_theta,
    # problem size and grouping
    n_amplitudes=600,
    n_amplitude_groups=6,
    reuse_observations=False,
    all_observations=all_observations,
    all_actions=all_actions,
    bounds_duration=(1, 50),
    n_support_points=1600,
    # optimizer budget
    population_size=20,
    n_generations=20,
    # reproducibility / logging
    seed=0,
    verbose=True,
)

# Join the returned segments along axis 0.
observations = np.concatenate(all_observations)
actions = np.concatenate(all_actions)

In [None]:
# Plot the concatenated sGOATs excitation.
fig, axs = plot_sequence(
    tau=tau,
    observations=observations,
    actions=actions,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
)
plt.show()

#### With reuse
Reuse the observations from a previous amplitude group in the current iteration.

In [None]:
# sGOATs with reuse: observations from a previous amplitude group are reused
# in the current iteration. The results are returned in the two lists.
all_observations, all_actions = [], []

all_observations, all_actions = excite_with_sGOATs(
    env=env,
    featurize=featurize_theta,
    # problem size and grouping
    n_amplitudes=600,
    n_amplitude_groups=6,
    reuse_observations=True,
    all_observations=all_observations,
    all_actions=all_actions,
    bounds_duration=(1, 50),
    n_support_points=1600,
    # optimizer budget
    population_size=20,
    n_generations=20,
    # reproducibility / logging
    seed=0,
    verbose=True,
)

# Join the returned segments along axis 0.
observations = np.concatenate(all_observations)
actions = np.concatenate(all_actions)

In [None]:
# Plot the concatenated sGOATs excitation obtained with observation reuse.
fig, axs = plot_sequence(
    tau=tau,
    observations=observations,
    actions=actions,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
)
plt.show()