In [None]:
# Reload edited project modules automatically before executing code, so changes to
# exciting_exciting_systems are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Make the %lprun line-by-line profiler magic available.
%reload_ext line_profiler

In [None]:
import os
# Expose only GPU 0 to CUDA and stop XLA from preallocating most of the GPU memory
# at start-up. Both must be set before JAX initialises its backends.
# NOTE(review): the JAX setup cell forces the CPU platform, so these only take
# effect if that setting is changed.
os.environ.update({
    "CUDA_VISIBLE_DEVICES": "0",
    "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
})

import pathlib
from functools import partial

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
# Render all matplotlib text through LaTeX (needs a working TeX installation) and
# load the `bm` package so bold math symbols can be used in figure labels.
mpl.rcParams.update({
    'text.usetex': True,
    'text.latex.preamble': r"\usepackage{bm}",
})
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp
# jax.config.update("jax_enable_x64", True)
# gpus = jax.devices()
# jax.config.update("jax_default_device", gpus[0])
jax.config.update('jax_platform_name', 'cpu')

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import exciting_environments as excenvs

import exciting_exciting_systems
from exciting_exciting_systems.models import NeuralEulerODEPendulum, NeuralODEPendulum, NeuralEulerODE, NeuralEulerODECartpole
from exciting_exciting_systems.models.model_utils import simulate_ahead_with_env
from exciting_exciting_systems.models.model_training import ModelTrainer
from exciting_exciting_systems.excitation import loss_function, Exciter

from exciting_exciting_systems.utils.density_estimation import (
    update_density_estimate_single_observation, update_density_estimate_multiple_observations, DensityEstimate, select_bandwidth
)
from exciting_exciting_systems.utils.signals import aprbs
from exciting_exciting_systems.evaluation.plotting_utils import (
    plot_sequence, append_predictions_to_sequence_plot, plot_sequence_and_prediction, plot_model_performance
)
from exciting_exciting_systems.evaluation.experiment_utils import (
    get_experiment_ids, load_experiment_results, quick_eval, evaluate_experiment_metrics, evaluate_algorithm_metrics, evaluate_metrics
)

---

In [None]:
# setup PRNG
key = jax.random.PRNGKey(seed=2)

data_key, model_key, loader_key, expl_key, key = jax.random.split(key, 5)
data_rng = PRNGSequence(data_key)

In [None]:
from exciting_environments.pmsm.pmsm_env import PMSM, PMSM_Physical

In [None]:
batch_size = 1

# Physical motor parameters handed to PMSM_Physical.
# NOTE(review): units are not stated here (they look like SI values) — confirm
# against exciting_environments.pmsm before changing them.
motor_params = {
    "p": 3,
    "r_s": 15e-3,
    "l_d": 0.37e-3,
    "l_q": 1.2e-3,
    "psi_p": 65.6e-3,
    "u_dc": 400,
    "i_n": 250,
    "max_omega_el": 3000 / 60 * 2 * jnp.pi,
}

# Settings forwarded to PMSM as `static_params`; their meaning is defined by
# exciting_environments, not by this notebook.
env_static_params = {
    "p_omega": 0.00005,
    "p_reference": 0.0002,
    "p_reset": 1.0,
    "i_lim_multiplier": 1.2,
    "constant_omega": True,
    "omega_ramp_min": 20000,
    "omega_ramp_max": 25000,
}

# Saturated, torque-controlled motor model without deadtime.
physical_model = PMSM_Physical(
    control_state="torque",
    deadtime=0,
    batch_size=batch_size,
    saturated=True,
    params=motor_params,
)

env = PMSM(
    pmsm_physical=physical_model,
    gamma=0.85,
    batch_size=batch_size,
    static_params=env_static_params,
)

In [None]:
# Constant test action: the same (0.03, 0.03) pair for every environment in the
# batch, shaped (batch_size, 2). Uses `batch_size` from the environment-setup cell
# (there is no `BATCH_SIZE` constant in this notebook).
act = jnp.tile(jnp.array([0.03, 0.03]), (batch_size, 1))
act.shape

In [None]:
# Roll out the PMSM environment under APRBS excitation and record the first two
# observation channels at every step (including the initial one from reset).
obs, state = env.reset()

n_steps = 99
# One independent APRBS sequence per action channel (two channels), each drawn with a
# fresh key from `data_rng` and stacked along the last axis. The list comprehension
# draws the keys in the same order as two explicit calls would.
# NOTE(review): the meaning of aprbs' 3rd/4th arguments (1, 10) is not visible here — confirm.
actions = jnp.concatenate(
    [aprbs(n_steps, batch_size, 1, 10, next(data_rng)) for _ in range(2)],
    axis=-1,
)

observations = [obs[..., 0:2]]

# Step the environment constructed above; axis 1 of `actions` is time.
for i in range(actions.shape[1]):
    obs, state = env.vmap_step(state, actions[:, i, :])
    observations.append(obs[..., 0:2])

In [None]:
# Plot the recorded first two observation channels against the applied actions.
# NOTE(review): both inputs are flattened along their leading axis. `observations` is a
# per-step list while `actions` presumably leads with the batch axis, so the two only
# line up for batch_size=1.
observation_trace = np.concatenate(observations)
action_trace = np.concatenate(actions)
plot_sequence(
    observation_trace,
    action_trace,
    env.tau,
    obs_labels=env.obs_description[:2],
    action_labels=['u_d', 'u_q'],
)