In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
from operator import itemgetter
from functools import partial

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['text.usetex'] = True
mpl.rcParams['text.latex.preamble']=r"\usepackage{bm}"
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp
jax.config.update('jax_platform_name', 'cpu')
# jax.config.update("jax_enable_x64", True)
import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
from exciting_exciting_systems.utils.signals import aprbs
from exciting_exciting_systems.evaluation.plotting_utils import (
    plot_sequence, append_predictions_to_sequence_plot, plot_sequence_and_prediction, plot_model_performance
)

In [None]:
import exciting_environments as excenvs

import exciting_exciting_systems
from exciting_exciting_systems.related_work.algorithms import excite_with_sGOATS, excite_with_GOATS, excite_with_iGOATS

In [None]:
# Settings for the fluid tank environment. Only the entries that are explicitly
# referenced in the `excenvs.make` call below are used; `batch_size` is not
# forwarded there.
env_params = dict(
    batch_size=1,
    tau=5e-1,
    max_height=3,
    max_inflow=0.2,
    base_area=jnp.pi,
    orifice_area=jnp.pi * 0.1**2,
    c_d=0.6,
    g=9.81,
    env_solver=diffrax.Euler(),
)
# Build the "FluidTank-v0" environment from the settings above.
env = excenvs.make(
    "FluidTank-v0",
    physical_constraints=dict(height=env_params["max_height"]),
    action_constraints=dict(inflow=env_params["max_inflow"]),
    static_params=dict(
        base_area=env_params["base_area"],
        orifice_area=env_params["orifice_area"],
        c_d=env_params["c_d"],
        g=env_params["g"],
    ),
    tau=env_params["tau"],
    solver=env_params["env_solver"],
)

In [None]:
# Excite the fluid tank environment with iGOATS. The returned observation and
# action sequences are plotted in a later cell.
prediction_horizon = 4
application_horizon = 4

igoats_observations, igoats_actions = excite_with_iGOATS(
    n_timesteps=15000,
    env=env,
    prediction_horizon=prediction_horizon,
    application_horizon=application_horizon,
    bounds_amplitude=[-1, 1],
    bounds_duration=[1, 100],
    population_size=50,
    n_generations=50,
    featurize=lambda x: x,  # no feature transformation of the observations
    rng=np.random.default_rng(0),
    compress_data=True,
    compression_target_N=500,
    rho_obs=1e3,
    rho_act=1e3,
    compression_feat_dim=-2,
    compression_dist_th=0.1,
    plot_subsequences=True,
)


In [None]:
# %debug  # leftover interactive post-mortem debugging; uncomment only while debugging

In [None]:
plot_sequence(igoats_observations, igoats_actions, env.tau, env.obs_description, env.action_description)

## work on MixedGA based iGOATS?

In [None]:
def featurize_theta(obs_action):
    """Replace the leading angle entry by its sine and cosine.

    The first entry along the last axis is multiplied by pi and mapped to
    (sin, cos), so that values describing (almost) the same angle -- such as
    1.99 * pi and 0 -- also end up close to each other when compared in the
    loss. All remaining entries are passed through unchanged, i.e. the last
    axis grows by one.
    """
    angle = obs_action[..., 0] * np.pi
    remainder = obs_action[..., 1:]

    sin_cos = np.stack([np.sin(angle), np.cos(angle)], axis=-1)
    return np.concatenate([sin_cos, remainder], axis=-1)

In [None]:
def identity(x):
    """No-op featurizer: hand the input back untouched."""
    return x

---

In [None]:
from exciting_exciting_systems.related_work.excitation_utils import compress_datapoints
from exciting_exciting_systems.related_work.np_reimpl.env_utils import simulate_ahead_with_env
from exciting_exciting_systems.related_work.np_reimpl.pendulum import Pendulum

In [None]:
# Key stream; `next(data_rng)` is used further below to draw the random APRBS input.
data_rng = PRNGSequence(jax.random.PRNGKey(0))

batch_size = 1
tau = 2e-2

# NOTE(review): `env` is reassigned to the "FluidTank-v0" environment two cells
# further down, before this Pendulum instance is used -- confirm which
# environment is intended for the following cells.
env = Pendulum(
    tau=tau,
    max_torque=8
)

In [None]:
import exciting_environments as excenvs

In [None]:
env = excenvs.make("FluidTank-v0")

In [None]:
obs, state = env.reset()

n_steps = 1000
actions = aprbs(n_steps, batch_size, 20, 100, next(data_rng))[0]

observations, state = simulate_ahead_with_env(env, obs, state, actions)

# `env` is the FluidTank environment created in the previous cell, so the time
# axis has to be scaled with its own sampling time and not with the module-level
# `tau`, which was chosen for the Pendulum.
plot_sequence(observations, actions, env.tau, env.obs_description, env.action_description)
plt.show()

In [None]:
actions.shape

In [None]:
rng = np.random.default_rng(seed=0)

In [None]:
all_observations = []
all_actions = []

# The import cell provides `excite_with_sGOATS` (capital "S"). The spelling
# `excite_with_sGOATs` used here before is not defined in this notebook and
# raised a NameError.
all_observations, all_actions = excite_with_sGOATS(
    n_amplitudes=600,
    n_amplitude_groups=12,
    reuse_observations=True,
    all_observations=all_observations,
    all_actions=all_actions,
    env=env,
    bounds_duration=(1,50),
    population_size=50,
    n_generations=50,
    featurize=identity,
    rng=np.random.default_rng(seed=12),
    verbose=True
)

# join the per-group results into one continuous sequence each
sgoats_observations = np.concatenate(all_observations)
sgoats_actions = np.concatenate(all_actions)

In [None]:
compressed_data, indices = compress_datapoints(sgoats_observations, N_c=500, feature_dimension=1)

In [None]:
# sample index of every sGOATS observation (x-axis of the plot)
sample_idx = np.arange(sgoats_observations.shape[0])

plt.plot(sample_idx, sgoats_observations[..., 1])
# red dots: the datapoints returned by `compress_datapoints`
plt.plot(sample_idx[indices], compressed_data[..., 1], 'r.')
plt.xlim(200, 1300)
plt.show()


In [None]:
print("sgoats actions.shape:", sgoats_actions.shape)
print("sgoats observations.shape:", sgoats_observations.shape)

fig, axs = plot_sequence(
    observations=sgoats_observations,
    actions=sgoats_actions[:-1, ...],
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
);
plt.show()

In [None]:
from exciting_exciting_systems.utils.density_estimation import update_density_estimate_multiple_observations, DensityEstimate, build_grid_2d

points_per_dim = 50
bandwidth = 0.05


density_estimate = DensityEstimate(
    p=jnp.zeros([points_per_dim**2, 1]),
    x_g=build_grid_2d(low=-1, high=1, points_per_dim=points_per_dim),
    bandwidth=jnp.array([bandwidth]),
    n_observations=jnp.array([0])
)

sgoats_density_estimate = update_density_estimate_multiple_observations(
    density_estimate, sgoats_observations,
)
fig, axs, cax = exciting_exciting_systems.evaluation.plotting_utils.plot_2d_kde_as_contourf(
    sgoats_density_estimate.p, sgoats_density_estimate.x_g, [r"$\theta$", r"$\omega$"]
)

In [None]:
from exciting_exciting_systems.related_work.excitation_utils import latin_hypercube_sampling

In [None]:
all_amplitudes = latin_hypercube_sampling(d=env.action_space.shape[-1], n=600, seed=1)
amplitude_groups = np.split(all_amplitudes, 12, axis=0)

In [None]:
n_amplitude_groups = 12
n_amplitudes = 600

from exciting_exciting_systems.related_work.algorithms import generate_amplitude_groups

In [None]:
amplitude_groups = generate_amplitude_groups(n_amplitudes, n_amplitude_groups)
for amplitude_group in amplitude_groups:
    np.random.shuffle(amplitude_group)

In [None]:
for amplitude_group in amplitude_groups:
    plt.plot(amplitude_group, 'r.')
    plt.show()

In [None]:
for amplitude_group in amplitude_groups:
    plt.hist(amplitude_group, bins=20, range=(-1,1))
    plt.show()

## GOATS optimization problem as a mixed-integer permutation problem:

In [None]:
batch_size = 1
tau = 2e-2

env = Pendulum(
    batch_size=batch_size,
    tau=tau,
    max_torque=8
)

In [None]:
# The import cell provides `excite_with_GOATS` (capital "S"). The spelling
# `excite_with_GOATs` used here before is not defined in this notebook and
# raised a NameError.
goats_observations, goats_actions = excite_with_GOATS(
    n_amplitudes=100,
    env=env,
    bounds_duration=(5,50),
    population_size=50,
    n_generations=300,
    featurize=featurize_theta,
    rng=np.random.default_rng(seed=120),
    verbose=True
)

In [None]:
# NOTE(review): this cell repeats the preceding GOATS call with identical
# arguments and seed -- consider deleting it to save the second optimization run.
# The import cell provides `excite_with_GOATS` (capital "S"). The spelling
# `excite_with_GOATs` used here before is not defined in this notebook and
# raised a NameError.
goats_observations, goats_actions = excite_with_GOATS(
    n_amplitudes=100,
    env=env,
    bounds_duration=(5,50),
    population_size=50,
    n_generations=300,
    featurize=featurize_theta,
    rng=np.random.default_rng(seed=120),
    verbose=True
)

In [None]:
print("goats actions.shape:", goats_actions.shape)
print("goats observations.shape:", goats_observations.shape)

fig, axs = plot_sequence(
    observations=goats_observations,
    actions=goats_actions,
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
);

plt.show()

- inspect compressed datapoints

In [None]:
from exciting_exciting_systems.related_work.excitation_utils import compress_datapoints

In [None]:
compressed_data, indices = compress_datapoints(goats_observations, N_c=500, feature_dimension=1)

In [None]:
# sample index of every GOATS observation (x-axis of the plot)
goats_sample_idx = np.arange(goats_observations.shape[0])

plt.plot(goats_sample_idx, goats_observations[:, 1])
# red dots: the datapoints returned by `compress_datapoints`
plt.plot(goats_sample_idx[indices], compressed_data[..., 1], 'r.')
plt.show()

## Test the genetic algorithm with permutation encoding on the travelling salesperson problem (**TSP**):

In [None]:
from scipy.spatial.distance import cdist

from pymoo.algorithms.soo.nonconvex.ga import GA

from pymoo.core.problem import ElementwiseProblem



from pymoo.optimize import minimize
from pymoo.operators.sampling.rnd import IntegerRandomSampling
from pymoo.operators.crossover.sbx import SBX
from pymoo.operators.mutation.pm import PM
from pymoo.operators.repair.rounding import RoundingRepair
from pymoo.termination.default import DefaultSingleObjectiveTermination

from pymoo.operators.sampling.rnd import PermutationRandomSampling
from pymoo.operators.crossover.ox import OrderCrossover
from pymoo.operators.mutation.inversion import InversionMutation

In [None]:
from pymoo.problems.single.traveling_salesman import visualize
from pymoo.problems.single.traveling_salesman import create_random_tsp_problem

In [None]:
import numpy as np
from pymoo.core.repair import Repair

class StartFromZeroRepair(Repair):
    """Repair operator that rewrites every tour so that it starts at city 0.

    The individuals are expected in the encoding of ``problem`` (it must offer
    ``decode`` and ``encode``). Every individual is decoded to a permutation,
    cyclically shifted so that the entry ``0`` comes first, and encoded again.
    """

    def _do(self, problem, X, **kwargs):
        """Return the repaired population as a new array of encoded individuals."""
        if len(X) == 0:
            # nothing to repair; the lookups below need at least one row
            return X

        decoded_X = np.array([problem.decode(x) for x in X])

        # Column of the entry 0 in every row. `argmax` yields exactly one index
        # per row, so rows and shifts can never get out of step.
        zero_positions = np.argmax(decoded_X == 0, axis=1)

        rotated_X = [np.roll(row, -shift) for row, shift in zip(decoded_X, zero_positions)]

        # plain Python lists are handed over so that `encode` never writes into
        # the intermediate arrays
        return np.array([problem.encode(row.tolist()) for row in rotated_X])

In [None]:
class TravelingSalesperson(ElementwiseProblem):
    """Two-dimensional traveling salesperson problem (TSP) on Lehmer-coded tours.

    A solution is an integer vector ``c`` with ``0 <= c[i] <= n_cities - 1 - i``
    (see the variable bounds set in ``__init__``). ``decode`` turns every such
    vector into a permutation of the city indices, so each point inside the
    bounds corresponds to a valid tour.
    """

    def __init__(self, cities, **kwargs):
        """
        Parameters
        ----------
        cities : numpy.array
            The cities with 2-dimensional coordinates provided by a matrix where each city is represented by a row.
        **kwargs
            Forwarded to ``ElementwiseProblem.__init__``.
        """
        n_cities, _ = cities.shape

        self.cities = cities
        # pairwise distances between all cities
        self.D = cdist(cities, cities)

        super().__init__(
            n_var=n_cities,
            n_obj=1,
            xl=np.zeros(n_cities),
            # upper bounds n_cities - 1, n_cities - 2, ..., 0
            xu=np.arange(n_cities - 1, -1, -1, dtype=float),
            vtype=int,
            **kwargs
        )

    @staticmethod
    def decode(lehmer_code: list[int]) -> list[int]:
        """Decode Lehmer code to permutation.

        Entry ``k`` of the code selects the ``k``-th of the indices that have not
        been used yet.

        Source: https://optuna.readthedocs.io/en/latest/faq.html#how-can-i-deal-with-permutation-as-a-parameter
        """
        remaining = list(range(len(lehmer_code)))
        return [remaining.pop(k) for k in lehmer_code]

    @staticmethod
    def encode(x: list[int]) -> list[int]:
        """Encode a permutation as Lehmer code (inverse of ``decode``).

        Entry ``i`` of the code counts the later entries of ``x`` that are smaller
        than ``x[i]``. The argument is left untouched and a new list is returned.

        Source: https://gist.github.com/theepicsnail/ec3c9e91d881468fa4a822feb85a6e0e
        """
        n = len(x)
        return [sum(1 for j in range(i + 1, n) if x[j] < x[i]) for i in range(n)]

    def _evaluate(self, x, out, *args, **kwargs):
        out['F'] = self.get_route_length(x)

    def get_route_length(self, x):
        """Length of the closed tour described by the Lehmer code ``x``."""
        route = np.asarray(self.decode(x), dtype=int)
        # np.roll pairs every city with its successor and the last one with the first
        return self.D[route, np.roll(route, -1)].sum()

def create_cities(n_cities, grid_width=100.0, grid_height=None, seed=None):
    """Draw ``n_cities`` uniformly distributed city coordinates.

    The x-coordinates lie in ``[0, grid_width)`` and the y-coordinates in
    ``[0, grid_height)``; a missing ``grid_height`` defaults to ``grid_width``.
    If ``seed`` is given, NumPy's global random state is seeded with it first.
    Returns an array of shape ``(n_cities, 2)``.
    """
    if grid_height is None:
        grid_height = grid_width

    if seed is not None:
        np.random.seed(seed)

    unit_coordinates = np.random.random((n_cities, 2))
    return unit_coordinates * [grid_width, grid_height]

def visualize(problem, x, fig=None, ax=None, show=True, label=True):
    """Draw the cities of ``problem`` and the closed route ``x`` through them.

    Parameters
    ----------
    problem
        Object providing ``cities``, ``encode`` and ``get_route_length``.
    x
        The route as a sequence of city indices (a permutation, not its encoding).
    fig, ax
        Existing figure and axes to draw into; a new pair is created unless both are given.
    show
        Whether to call ``fig.show()`` at the end.
    label
        Whether to annotate each city with its index.
    """
    with plt.style.context('ggplot'):

        if fig is None or ax is None:
            fig, ax = plt.subplots()

        # plot cities using scatter plot
        ax.scatter(problem.cities[:, 0], problem.cities[:, 1], s=250)
        if label:
            # annotate cities
            for i, c in enumerate(problem.cities):
                ax.annotate(str(i), xy=c, fontsize=10, ha="center", va="center", color="white")

        # plot the line on the path
        for i in range(len(x)):
            current = x[i]
            next_ = x[(i + 1) % len(x)]
            ax.plot(problem.cities[[current, next_], 0], problem.cities[[current, next_], 1], 'r--')

        # hand a copy to `encode` so that the caller's `x` is never modified
        route_length = problem.get_route_length(problem.encode(list(x)))
        fig.suptitle("Route length: %.4f" % route_length)

        if show:
            fig.show()

In [None]:
cities = create_cities(20)

In [None]:
# Integer GA on the Lehmer code used by `TravelingSalesperson`: any integer
# vector inside the problem's bounds decodes to a permutation, so the generic
# SBX / PM operators can be applied and only need rounding back to integers.
opt_algorithm = GA(
    pop_size=20,
    sampling=IntegerRandomSampling(),
    crossover=SBX(vtype=float, repair=RoundingRepair()),
    mutation=PM(vtype=float, repair=RoundingRepair()),
    eliminate_duplicates=True,
)

opt_problem = TravelingSalesperson(cities)

# no fixed limit on the number of generations (n_max_gen=np.inf)
termination = DefaultSingleObjectiveTermination(period=200, n_max_gen=np.inf)

res = minimize(
    problem=opt_problem,
    algorithm=opt_algorithm,
    termination=termination,
    seed=0,
    save_history=False,
    verbose=True,
)

In [None]:
visualize(opt_problem, opt_problem.decode(res.X))

In [None]:
visualize(opt_problem, opt_problem.decode(opt_problem.encode(opt_problem.decode(res.X))))

---

In [None]:
from pymoo.problems.single.traveling_salesman import create_random_tsp_problem, TravelingSalesman
from pymoo.problems.single.traveling_salesman import visualize

In [None]:
from pymoo.core.mixed import MixedVariableGA, MixedVariableDuplicateElimination
from pymoo.core.variable import Integer

import mixed_GA
from mixed_GA import MixedVariableSampling, MixedVariableMating, Permutation

In [None]:
opt_algorithm = MixedVariableGA(
    pop_size=20,
    sampling=MixedVariableSampling(),
    mating=MixedVariableMating(eliminate_duplicates=MixedVariableDuplicateElimination())
)
opt_algorithm

In [None]:
class DummyProblem(ElementwiseProblem):
    """Toy problem mixing ``Permutation`` variables with an ``Integer`` variable.

    The variables "a", "b" and "c" are used as indices into ``possible_values``,
    "x0" enters the objective as a plain additive integer.
    """

    def __init__(self):
        variables = {
            "a": Permutation(bounds=(0,3)),
            "b": Permutation(bounds=(0,3)),
            "c": Permutation(bounds=(0,3)),
            "x0": Integer(bounds=(0,10)),
        }
        super().__init__(vars=variables, n_obj=1)

        self.possible_values = [10, 2, 5]

        # split the variable names by type, keeping their definition order
        self.permutation_keys = [
            name for name, var in variables.items() if isinstance(var, Permutation)
        ]
        self.non_permutation_keys = [
            name for name, var in variables.items() if not isinstance(var, Permutation)
        ]

    def _evaluate(self, x, out, *args, **kwargs):
        lookup = self.possible_values
        numerator = (lookup[x["a"]] - 2)**2 + lookup[x["b"]]
        out["F"] = numerator / lookup[x["c"]] + x["x0"]

In [None]:
opt_problem = DummyProblem()

In [None]:
res = minimize(
    opt_problem,
    opt_algorithm,
    seed=1,
    verbose=True
)

In [None]:
from multitasking_tsp import StartFromZeroRepair, create_cities, visualize, MultitaskingTravellingSalespersonProblem

In [None]:
opt_algorithm = MixedVariableGA(
    pop_size=20,
    sampling=mixed_GA.MixedVariableSampling(),
    mating=mixed_GA.MixedVariableMating(eliminate_duplicates=MixedVariableDuplicateElimination()),
    # repair=StartFromZeroRepair(),
)

In [None]:
# cities = create_cities(20)
opt_problem = MultitaskingTravellingSalespersonProblem(cities)

In [None]:
opt_problem.non_permutation_keys

In [None]:
termination = DefaultSingleObjectiveTermination(period=200, n_max_gen=np.inf)

res = minimize(
    opt_problem,
    opt_algorithm,
    termination,
    seed=1,
)

In [None]:
visualize(opt_problem, np.array(itemgetter(*opt_problem.permutation_keys)(res.X)))

In [None]:
res.X

---

In [None]:
from scipy.spatial.distance import cdist

from pymoo.algorithms.soo.nonconvex.ga import GA

from pymoo.core.problem import ElementwiseProblem

from pymoo.optimize import minimize
from pymoo.operators.sampling.rnd import IntegerRandomSampling
from pymoo.operators.crossover.sbx import SBX
from pymoo.operators.mutation.pm import PM
from pymoo.operators.repair.rounding import RoundingRepair
from pymoo.termination.default import DefaultSingleObjectiveTermination

from pymoo.operators.sampling.rnd import PermutationRandomSampling
from pymoo.operators.crossover.ox import OrderCrossover
from pymoo.operators.mutation.inversion import InversionMutation

from pymoo.problems.single.traveling_salesman import TravelingSalesman, visualize

In [None]:
class StartFromZeroRepair(Repair):
    """Repair operator that rotates every permutation so that it starts with 0.

    Works directly on permutation-encoded individuals; ``X`` is modified in
    place and returned.
    """

    def _do(self, problem, X, **kwargs):
        # column index of the entry 0, one per row of the population
        zero_columns = np.where(X == 0)[1]

        for row_idx, tour in enumerate(X):
            # `tour` is a view into X, so the slice assignment updates X itself;
            # rolling left by the zero position moves the 0 to the front
            tour[:] = np.roll(tour, -zero_columns[row_idx])

        return X

In [None]:
opt_algorithm = GA(
    pop_size=20,
    sampling=PermutationRandomSampling(),
    mutation=InversionMutation(),
    crossover=OrderCrossover(),
    repair=StartFromZeroRepair(),
    eliminate_duplicates=True
)

In [None]:
opt_problem = TravelingSalesman(cities)

termination = DefaultSingleObjectiveTermination(period=200, n_max_gen=np.inf)

res = minimize(
    opt_problem,
    opt_algorithm,
    termination,
    seed=1,
)

In [None]:
visualize(opt_problem, res.X)

In [None]:
opt_problem.vars

In [None]:
from exciting_exciting_systems.related_work.excitation_utils import latin_hypercube_sampling, generate_aprbs, simulate_ahead_with_env, compress_datapoints, audze_eglais, soft_penalty

In [None]:
class GoatsProblem(ElementwiseProblem):
    """pymoo-API optimization problem for the GOATs and sGOATs algorithms.

    Optimizes amplitude permutations and durations of each specific amplitude.
    The amplitude levels are chosen beforehand.

    TODO: arbitrary observation and input dimensions

    Parameters
    ----------
    amplitudes
        Array of amplitude levels; its first axis enumerates the levels.
    env, obs, env_state
        Environment together with the observation and state it is simulated from;
        all three are passed on to ``simulate_ahead_with_env``.
    featurize
        Callable applied to the observations before the score is computed.
    bounds_duration
        Bounds of the integer duration variable of every amplitude.
    starting_observations, starting_actions
        Optional previously gathered data that is appended to the simulated data
        before scoring. Both have to be given together.
    compress_data, target_N
        If ``compress_data`` is set, the datapoints are reduced with
        ``compress_datapoints(..., N_c=target_N)`` before scoring.
    rho_obs, rho_act
        Weights of the soft penalties on observations and actions.
    compression_feat_dim
        Passed as ``feature_dimension`` to ``compress_datapoints``.
    """

    def __init__(
        self,
        amplitudes,
        env,
        obs,
        env_state,
        featurize,
        bounds_duration=(1, 50),
        starting_observations=None,
        starting_actions=None,
        compress_data=True,
        target_N=100,
        rho_obs=1,
        rho_act=1,
        compression_feat_dim=2,
    ):

        n_amplitudes = amplitudes.shape[0]

        self.env = env
        self.obs = obs
        self.env_state = env_state
        self.featurize = featurize

        # one permutation variable (position of an amplitude) and one integer
        # variable (its duration) per amplitude level
        amplitude_variables = {f"a_{number}": Permutation(bounds=(0, n_amplitudes)) for number in range(n_amplitudes)}
        duration_variables = {f"d_{number}": Integer(bounds=bounds_duration) for number in range(n_amplitudes)}

        self.permutation_keys = tuple(amplitude_variables.keys())
        self.non_permutation_keys = tuple(duration_variables.keys())

        all_vars = dict(amplitude_variables, **duration_variables)

        super().__init__(
            vars=all_vars,
            n_obj=1,
        )

        self.amplitudes = amplitudes
        self.n_amplitudes = n_amplitudes
        if starting_observations is not None:
            self.starting_observations = featurize(starting_observations)
        else:
            self.starting_observations = None
        self.starting_actions = starting_actions
        self.compress_data = compress_data
        self.target_N = target_N
        self.rho_obs = rho_obs
        self.rho_act = rho_act
        self.compression_feat_dim = compression_feat_dim

    def _evaluate(self, x, out, *args, **kwargs):
        """Simulate the candidate excitation signal and write its loss to ``out["F"]``."""
        # explicit lists keep these one-dimensional even for a single amplitude
        amplitude_order = np.array([x[key] for key in self.permutation_keys])
        durations = np.array([x[key] for key in self.non_permutation_keys])

        applied_amplitudes = self.amplitudes[amplitude_order]

        actions = generate_aprbs(amplitudes=applied_amplitudes, durations=durations)[:, None]

        observations, _ = simulate_ahead_with_env(
            self.env,
            self.obs,
            self.env_state,
            actions,
        )

        feat_observations = self.featurize(observations)
        all_actions = actions
        if self.starting_observations is not None:
            assert (
                self.starting_actions is not None
            ), "There are starting observations, but no corresponding starting actions!"
            # NOTE(review): the new data is put in front of the starting data and
            # the last observation row is dropped below -- confirm that observation
            # and action rows still line up for the starting data.
            feat_observations = np.concatenate([feat_observations, self.starting_observations])
            all_actions = np.concatenate([actions, self.starting_actions])

        # pair every (featurized) observation with an action
        feat_datapoints = np.concatenate([feat_observations[:-1, ...], all_actions], axis=-1)

        if self.compress_data:
            feat_datapoints, _ = compress_datapoints(
                feat_datapoints, N_c=self.target_N, feature_dimension=self.compression_feat_dim
            )

        score = audze_eglais(feat_datapoints)

        # the penalties are computed on the raw (not featurized) simulated data
        penalty_terms = self.rho_obs * soft_penalty(a=observations, a_max=1) + self.rho_act * soft_penalty(
            a=actions, a_max=1
        )

        out["F"] = score + penalty_terms.item()

In [None]:
obs, env_state = env.reset()
obs = obs.astype(np.float32)[0]
env_state = env_state.astype(np.float32)[0]

opt_algorithm = MixedVariableGA(
    pop_size=20,
    sampling=MixedVariableSampling(),
    mating=MixedVariableMating(eliminate_duplicates=MixedVariableDuplicateElimination())
)

In [None]:
seed = 0
n_amplitudes = 600

In [None]:
opt_problem = GoatsProblem(
    amplitudes=latin_hypercube_sampling(d=env.action_space.shape[-1], n=n_amplitudes, seed=seed),
    env=env,
    obs=obs,
    env_state=env_state,
    featurize=featurize_theta
)

In [None]:
res = minimize(
    problem=opt_problem,
    algorithm=opt_algorithm,
    seed=seed,
    save_history=False,
    verbose=True,
)

In [None]:
# Turn the solution `res.X` returned by `minimize` back into an action sequence
# (same steps as in `GoatsProblem._evaluate`) and simulate it once more.
indices = np.array(itemgetter(*opt_problem.permutation_keys)(res.X))
applied_amplitudes = opt_problem.amplitudes[indices]

applied_durations = np.array(itemgetter(*opt_problem.non_permutation_keys)(res.X))

actions = generate_aprbs(amplitudes=applied_amplitudes, durations=applied_durations)[:, None]

observations, last_env_state = simulate_ahead_with_env(
    env,
    obs,
    env_state,
    actions,
)

In [None]:
goats_actions = actions
goats_observations = observations

In [None]:
print("goats actions.shape:", goats_actions.shape)
print("goats observations.shape:", goats_observations.shape)

fig, axs = plot_sequence(
    observations=goats_observations,
    actions=goats_actions,
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
);

plt.show()