In [None]:
# Re-import edited project modules automatically before each cell executes,
# so changes to the exciting_exciting_systems package are picked up live.
%load_ext autoreload
%autoreload 2
# Enables the line_profiler magic (%lprun); no %lprun call appears in this view.
%reload_ext line_profiler

In [None]:
from operator import itemgetter
from functools import partial
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['text.usetex'] = True
mpl.rcParams['text.latex.preamble']=r"\usepackage{bm}"
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp
jax.config.update('jax_platform_name', 'cpu')
# jax.config.update("jax_enable_x64", True)
import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
from exciting_exciting_systems.utils.signals import aprbs
from exciting_exciting_systems.evaluation.plotting_utils import (
    plot_sequence, append_predictions_to_sequence_plot, plot_sequence_and_prediction, plot_model_performance
)

In [None]:
import exciting_environments as excenvs

import exciting_exciting_systems
from exciting_exciting_systems.related_work.algorithms import excite_with_sGOATS, excite_with_GOATS, excite_with_iGOATS

In [None]:
# --- CartPole-v0: generate excitation data with iGOATS ----------------------
seed = 0

# Environment description; stored as a dict so it can be logged with the run.
env_params = {
    "batch_size": 1,
    "tau": 2e-2,
    "max_force": 10,
    "static_params": dict(
        mu_p=0.002,
        mu_c=0.5,
        l=0.5,
        m_p=0.1,
        m_c=1,
        g=9.81,
    ),
    "physical_constraints": dict(
        deflection=2.4,
        velocity=8,
        theta=jnp.pi,
        omega=8,
    ),
    "env_solver": diffrax.Tsit5(),
}
env = excenvs.make(
    env_id="CartPole-v0",
    batch_size=env_params["batch_size"],
    action_constraints={"force": env_params["max_force"]},
    physical_constraints=env_params["physical_constraints"],
    static_params=env_params["static_params"],
    solver=env_params["env_solver"],
    tau=env_params["tau"],
)


# Prediction and application horizon share the same length here.
h = a = 10

alg_params = {
    "prediction_horizon": h,
    "application_horizon": a,
    "bounds_amplitude": (-1, 1),
    "bounds_duration": (1, 100),
    "population_size": 50,
    "n_generations": 25,
    "featurize": lambda obs: obs,
    # Placeholder only: the generator actually used is created in the call below.
    "rng": None,
    "compress_data": True,
    "compression_target_N": 500,
    "rho_obs": 1e3,
    "rho_act": 1e3,
    "penalty_order": 2,
    "compression_feat_dim": -2,
    "compression_dist_th": 0.1,
}

exp_params = {
    "n_timesteps": 15_000,
    "seed": int(seed),
    "alg_params": alg_params,
    "env_params": env_params,
}

# Run the excitation algorithm. Every algorithm hyperparameter is forwarded
# from `alg_params` by name; its `rng` entry is skipped because a freshly
# seeded generator is passed explicitly instead.
observations, actions = excite_with_iGOATS(
    n_timesteps=exp_params["n_timesteps"],
    env=env,
    rng=np.random.default_rng(seed),
    plot_subsequences=True,
    **{name: value for name, value in alg_params.items() if name != "rng"},
)


In [None]:
# --- Pendulum-v0: generate excitation data with iGOATS ----------------------
seed = 0
env_params = {
    "batch_size": 1,
    "tau": 2e-2,
    "max_torque": 5,
    "g": 9.81,
    "l": 1,
    "m": 1,
    "env_solver": diffrax.Tsit5(),
}
env = excenvs.make(
    env_id="Pendulum-v0",
    batch_size=env_params["batch_size"],
    action_constraints={"torque": env_params["max_torque"]},
    # Only the physical constants g, l, m are forwarded as static parameters.
    static_params={key: env_params[key] for key in ("g", "l", "m")},
    solver=env_params["env_solver"],
    tau=env_params["tau"],
)

h = a = 10

alg_params = {
    "prediction_horizon": h,
    "application_horizon": a,
    "bounds_amplitude": (-1, 1),
    "bounds_duration": (10, 100),
    "population_size": 50,
    "n_generations": 25,
    "featurize": lambda obs: obs,
    # Placeholder only: the generator actually used is created in the call below.
    "rng": None,
    "compress_data": True,
    "compression_target_N": 500,
    "rho_obs": 1e3,
    "rho_act": 1e3,
    "compression_feat_dim": -2,
    "compression_dist_th": 0.1,
}

exp_params = {
    "n_timesteps": 15000,
    "seed": int(seed),
    "alg_params": alg_params,
    "env_params": env_params,
}

# Run the excitation algorithm with all hyperparameters forwarded by name
# (the stored `rng` placeholder is replaced by a seeded generator).
observations, actions = excite_with_iGOATS(
    n_timesteps=exp_params["n_timesteps"],
    env=env,
    rng=np.random.default_rng(seed),
    plot_subsequences=True,
    **{name: value for name, value in alg_params.items() if name != "rng"},
)

In [None]:
# --- FluidTank-v0: generate excitation data with iGOATS ---------------------
seed=0

env_params = dict(
    batch_size=1,
    tau=5,
    max_height=3,
    max_inflow=0.2,
    base_area=jnp.pi,
    orifice_area=jnp.pi * 0.1**2,
    c_d=0.6,
    g=9.81,
    env_solver=diffrax.Tsit5(),
)
env = excenvs.make(
    "FluidTank-v0",
    # Forward the batch size explicitly (as the CartPole/Pendulum cells do) so the
    # environment matches the `batch_size` recorded in env_params instead of
    # silently falling back to the library default.
    batch_size=env_params["batch_size"],
    physical_constraints=dict(height=env_params["max_height"]),
    action_constraints=dict(inflow=env_params["max_inflow"]),
    static_params=dict(
        base_area=env_params["base_area"],
        orifice_area=env_params["orifice_area"],
        c_d=env_params["c_d"],
        g=env_params["g"],
    ),
    tau=env_params["tau"],
    solver=env_params["env_solver"],
)

h = 10
a = 10

alg_params = dict(
    prediction_horizon=h,
    application_horizon=a,
    bounds_amplitude=(-1, 1),
    bounds_duration=(5, 50),
    population_size=50,
    n_generations=25,
    featurize=lambda x: x,
    # Placeholder only: the generator actually used is created in the call below.
    rng=None,
    compress_data=True,
    compression_target_N=500,
    rho_obs=1e3,
    rho_act=1e3,
    compression_feat_dim=-2,
    compression_dist_th=0.1,
)

exp_params = dict(
    n_timesteps=15000,
    seed=int(seed),
    alg_params=alg_params,
    env_params=env_params,
)

# run excitation algorithm
observations, actions = excite_with_iGOATS(
    n_timesteps=exp_params["n_timesteps"],
    env=env,
    prediction_horizon=alg_params["prediction_horizon"],
    application_horizon=alg_params["application_horizon"],
    bounds_amplitude=alg_params["bounds_amplitude"],
    bounds_duration=alg_params["bounds_duration"],
    population_size=alg_params["population_size"],
    n_generations=alg_params["n_generations"],
    featurize=alg_params["featurize"],
    rng=np.random.default_rng(seed),
    compress_data=alg_params["compress_data"],
    compression_target_N=alg_params["compression_target_N"],
    rho_obs=alg_params["rho_obs"],
    rho_act=alg_params["rho_act"],
    compression_feat_dim=alg_params["compression_feat_dim"],
    compression_dist_th=alg_params["compression_dist_th"],
    plot_subsequences=True,
)

In [None]:
# --- Pendulum-v0 environment and sGOATS hyperparameters ---------------------
env_params = {
    "batch_size": 1,
    "tau": 2e-2,
    "max_torque": 5,
    "g": 9.81,
    "l": 1,
    "m": 1,
    "env_solver": diffrax.Tsit5(),
}
env = excenvs.make(
    env_id="Pendulum-v0",
    batch_size=env_params["batch_size"],
    action_constraints={"torque": env_params["max_torque"]},
    static_params={key: env_params[key] for key in ("g", "l", "m")},
    solver=env_params["env_solver"],
    tau=env_params["tau"],
)

# sGOATS settings. NOTE: the compression axis is stored under the key
# "compression_feature_dim" here (the iGOATS cells use "compression_feat_dim");
# the sGOATS run reads it by this exact key, so keep the name in sync.
alg_params = {
    "n_amplitudes": 360,
    "n_amplitude_groups": 36,
    "reuse_observations": True,
    "bounds_duration": (10, 100),
    "population_size": 50,
    "n_generations": 25,
    "featurize": lambda obs: obs,
    "compress_data": True,
    "compression_target_N": 500,
    "compression_dist_th": 0.1,
    "compression_feature_dim": -2,
    "rho_obs": 1e3,
    "rho_act": 1e3,
}

In [None]:
# --- Run sGOATS on the environment / alg_params defined in the setup cell ---
seed = 0

exp_params = dict(
    seed=int(seed),
    alg_params=alg_params,
    env_params=env_params,
)

# Single seeded generator for the run. It is passed straight to sGOATS, so the
# seed logged in exp_params is exactly the one that drives the algorithm
# (previously this generator was created but unused, and a duplicate was built inline).
rng = np.random.default_rng(seed=exp_params["seed"])

# run excitation algorithm
observations, actions = excite_with_sGOATS(
    n_amplitudes=alg_params["n_amplitudes"],
    n_amplitude_groups=alg_params["n_amplitude_groups"],
    reuse_observations=alg_params["reuse_observations"],
    env=env,
    bounds_duration=alg_params["bounds_duration"],
    population_size=alg_params["population_size"],
    n_generations=alg_params["n_generations"],
    featurize=alg_params["featurize"],
    compress_data=alg_params["compress_data"],
    compression_target_N=alg_params["compression_target_N"],
    compression_dist_th=alg_params["compression_dist_th"],
    # alg_params stores this under "compression_feature_dim"; the function kwarg is compression_feat_dim.
    compression_feat_dim=alg_params["compression_feature_dim"],
    rho_obs=alg_params["rho_obs"],
    rho_act=alg_params["rho_act"],
    rng=rng,
    verbose=False,
    plot_every_subsequence=True,
)

In [None]:
# --- FluidTank-v0 with explicit Euler integration and a 0.5 s step ----------
env_params = dict(
    batch_size=1,
    tau=5e-1,
    max_height=3,
    max_inflow=0.2,
    base_area=jnp.pi,
    orifice_area=jnp.pi * 0.1**2,
    c_d=0.6,
    g=9.81,
    env_solver=diffrax.Euler(),
)
env = excenvs.make(
    "FluidTank-v0",
    # Forward the batch size explicitly (as the CartPole/Pendulum cells do) so the
    # environment matches env_params["batch_size"] rather than the library default.
    batch_size=env_params["batch_size"],
    physical_constraints=dict(height=env_params["max_height"]),
    action_constraints=dict(inflow=env_params["max_inflow"]),
    static_params=dict(
        base_area=env_params["base_area"],
        orifice_area=env_params["orifice_area"],
        c_d=env_params["c_d"],
        g=env_params["g"],
    ),
    tau=env_params["tau"],
    solver=env_params["env_solver"],
)

In [None]:
# iGOATS on the current `env` with short (4-step) horizons.
prediction_horizon = 4
application_horizon = 4

igoats_observations, igoats_actions = excite_with_iGOATS(
    env=env,
    n_timesteps=15000,
    rng=np.random.default_rng(0),
    # horizon lengths
    prediction_horizon=prediction_horizon,
    application_horizon=application_horizon,
    # bounds on the generated amplitude / duration values
    bounds_amplitude=[-1, 1],
    bounds_duration=[1, 100],
    # optimizer budget
    population_size=50,
    n_generations=50,
    # feature map and rho_* weights (presumably penalty weights — confirm in excite_with_iGOATS)
    featurize=lambda obs: obs,
    rho_obs=1e3,
    rho_act=1e3,
    # data compression settings
    compress_data=True,
    compression_target_N=500,
    compression_feat_dim=-2,
    compression_dist_th=0.1,
    plot_subsequences=True,
)


In [None]:
# Interactive post-mortem debugging (`%debug`) was left here from a debugging
# session. It is not part of a Restart & Run All, so run `%debug` manually in
# a scratch cell right after an exception instead of keeping it in the notebook.

In [None]:
# Plot the observation/action sequences produced by the iGOATS run above, passing
# the env step size and signal descriptions (presumably used for the time axis
# and signal labels — confirm against plot_sequence).
plot_sequence(igoats_observations, igoats_actions, env.tau, env.obs_description, env.action_description)