### To be solved here:
- given a perfect model of the system
- predict the state trajectory
- choose the action using MPC
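- such that the loss $\mathcal{L} = \mathrm{JSD}(p(x), p^*(x))$, the Jensen–Shannon divergence between the state distribution $p(x)$ and the target distribution $p^*(x)$, is minimized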

---

In [None]:
# Automatically reload edited modules (e.g. `models`, `exciting_environments`) before running code.
%load_ext autoreload
%autoreload 2
# Load the line_profiler extension (line-by-line profiling magic).
%reload_ext line_profiler

In [None]:
from functools import partial

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import exciting_environments as excenvs
from models import NeuralEulerODE

---

- Find a suitable JAX MPC library? -> No, I don't think this is necessary.
- This is just a "simple" optimization problem and is only somewhat loosely connected to MPC?
- First off I need a fully trained differentiable model of our system
- We will use the derivative of the model w.r.t. the inputs to find the action(s) with the best predicted **endpoint** of the "loss"
- We are only interested that the final excitation state is a good one
- Maybe this needs to be relaxed when the model is still bad and the predictions are not that trustworthy

## Familiarization with jax.grad and application in optimization for a given env

In [None]:
# setup PRNG
key = jax.random.PRNGKey(seed=21)

data_key, model_key, key = jax.random.split(key, 3)
data_rng = PRNGSequence(data_key)

In [None]:
batch_size = 1  # leading (batch) dimension of the observations/actions used below
n_steps = 1000  # number of time steps per simulated episode
tau = 1e-3  # step size passed to the environment (presumably seconds — confirm)

# The "real" system.
env = excenvs.make(
    env_id='Pendulum-v0',
    batch_size=batch_size,
    tau=tau
)

# Built with exactly the same arguments as `env`, i.e. this is the
# "perfect model of the system" assumed in the problem statement.
model = excenvs.make(
    env_id='Pendulum-v0',
    batch_size=batch_size,
    tau=tau
)

In [None]:
# Display the environment object.
env

In [None]:
# Display the model object (configured identically to `env`).
model

In [None]:
@partial(jax.jit, static_argnums=(0, 1))
def gen_actions(n_steps, batch_size, key):
    """Draw one constant action per batch element and hold it for the whole episode.

    Args:
        n_steps: Episode length (static under jit).
        batch_size: Number of independent action sequences (static under jit).
        key: JAX PRNG key.

    Returns:
        Array of shape ``(batch_size, n_steps, 1)``. Each row holds a single
        value drawn uniformly between -1 and 1, repeated along the time axis.
    """
    level_per_row = jax.random.uniform(
        key, shape=(batch_size, 1, 1), minval=-1, maxval=1
    )
    return jnp.repeat(level_per_row, repeats=n_steps, axis=1)

In [None]:
# One random constant action per batch element, held for the whole episode.
actions = gen_actions(n_steps, batch_size, next(data_rng))

In [None]:
@partial(jax.jit, static_argnums=(0, 3))
def simulate_episode(env, obs, state, n_steps, actions):
    """Roll out ``env`` and record the observation at every step.

    Args:
        env: Environment (or model) exposing ``step(action, state)`` that returns
            ``(obs, reward, terminated, truncated, state)``. Static under jit,
            so it must be hashable.
        obs: Initial observation of shape ``(batch_size, obs_dim)``.
        state: Initial environment state.
        n_steps: Number of recorded observations (static under jit).
        actions: Action sequence indexed as ``actions[:, n, :]``.

    Returns:
        Array of shape ``(batch_size, n_steps, obs_dim)``. Entry 0 is the
        initial ``obs``. Entry ``n >= 1`` is the observation obtained by applying
        ``actions[:, n, :]`` to the state reached at step ``n - 1``.
        Consequently ``actions[:, 0, :]`` is never used.
    """
    n_batch, n_features = obs.shape
    trajectory = jnp.zeros([n_batch, n_steps, n_features]).at[:, 0, :].set(obs)

    def _advance(t, carry):
        # One environment step: the action stored at index t produces observation t.
        _, current_state, recorded = carry
        next_obs, _reward, _terminated, _truncated, next_state = env.step(
            actions[:, t, :], current_state
        )
        return next_obs, next_state, recorded.at[:, t, :].set(next_obs)

    _, _, trajectory = jax.lax.fori_loop(
        lower=1,
        upper=n_steps,
        body_fun=_advance,
        init_val=(obs, state, trajectory),
    )
    return trajectory


def plot_episode(observations, actions, max_n=2):
    """Plot up to ``max_n`` episodes from a batch.

    The figures are produced in this order: one theta/omega time-series figure per
    episode, then one theta-vs-omega scatter per episode, then one action
    time-series per episode.

    Args:
        observations: Array indexed as ``[episode, time, feature]``. Feature 0 is
            plotted as "theta" and feature 1 as "omega".
        actions: Array indexed as ``[episode, time, action_dim]``. Only action
            dimension 0 is plotted.
        max_n: Maximum number of episodes to plot.
    """
    episodes = range(min(max_n, observations.shape[0]))

    for ep in episodes:
        for feature, name in ((0, "theta"), (1, "omega")):
            plt.plot(observations[ep, :, feature], label=name)
        plt.grid()
        plt.title("observations, timeseries")
        plt.legend()
        plt.show()

    for ep in episodes:
        plt.plot(observations[ep, :, 0], observations[ep, :, 1], 'b.')
        plt.grid()
        plt.title("observations, together")
        plt.show()

    for ep in episodes:
        plt.plot(actions[ep, :, 0])
        plt.grid()
        plt.title("actions, timeseries")
        plt.show()

In [None]:
# Fresh initial condition. Cast to float32, presumably so that dtypes match the
# values produced inside the jitted rollout (confirm).
obs, state = env.reset()
obs = obs.astype(jnp.float32)
state = state.astype(jnp.float32)

# actions = aprbs(n_steps, batch_size, 5, 100, next(data_rng))  # `aprbs` is defined further below
actions = gen_actions(n_steps, batch_size, next(data_rng))

# Roll out the real system and the model from the same initial condition
# with the same actions.
observations = simulate_episode(
    env,
    obs,
    state,
    n_steps=n_steps,
    actions=actions
)

observations_model = simulate_episode(
    model, 
    obs,
    state,
    n_steps=n_steps,
    actions=actions
)

print(observations.shape)
print(observations_model.shape)
print(actions.shape)

In [None]:
# Sanity check: env and model are configured identically, so their trajectories should match exactly.
jnp.all(observations == observations_model)

In [None]:
# Trajectory of the real system.
plot_episode(observations, actions, max_n=1)

In [None]:
# Trajectory predicted by the model (expected to coincide with the previous plot).
plot_episode(observations_model, actions, max_n=1)

### Examine gradient properties of the model:

In [None]:
def wrap_model_for_theta(action, state, env_state_normalizer, action_normalizer, static_params):
    """Advance the global ``model`` one step and return entry 1 of the new state.

    The step is done by ``model._ode_exp_euler_step`` (presumably one
    explicit-Euler ODE step). ``action`` is the first argument, so that
    ``jax.grad(..., argnums=0)`` gives the sensitivity of this state entry to
    the action.

    NOTE(review): the name says "theta", but elsewhere in this notebook
    observation index 0 is plotted as theta and index 1 as omega. Confirm
    which quantity ``state[1]`` actually is.
    """
    next_state = model._ode_exp_euler_step(
        state, action, env_state_normalizer, action_normalizer, static_params
    )
    return next_state[1]

In [None]:
# Derivative of the returned state entry w.r.t. the action (argument 0).
grad_model = jax.grad(wrap_model_for_theta, argnums=0)

In [None]:
# jnp.linspace(-1, 1, 1) yields the single value -1.0, so `action` has shape (1, 1).
action = jnp.linspace(-1, 1, 1)[:, None]
obs, state = env.reset()
obs = obs.astype(jnp.float32)
state = state.astype(jnp.float32)

In [None]:
# Per-sample gradient of the stepped state entry w.r.t. the action.
# NOTE(review): vmap maps over axis 0 of *every* argument, including the
# normalizers and static_params; this only works if they carry a leading batch
# dimension. Confirm.
grads = jax.vmap(grad_model)(action, state, model.env_state_normalizer, model.action_normalizer, model.static_params)
grads

In [None]:
# The (non-differentiated) stepped state entry for the same inputs.
jax.vmap(wrap_model_for_theta)(action, state, model.env_state_normalizer, model.action_normalizer, model.static_params)

### 1D-choice: Try to follow the given trajectory:

- find the actions that produce the given trajectory?
- This is exactly DPC?

In [None]:
def featurize_theta(obs):
    """Encode the angle by its sine and cosine for use in a squared-error loss.

    Unlike the raw angle, sin/cos do not jump where the angle wraps around.
    This is presumably why the raw angle is hard to compare directly; confirm
    that the environment wraps theta.

    Args:
        obs: Array indexed as ``[time, feature]``, with feature 0 the angle and
            feature 1 the angular velocity. The angle is multiplied by ``pi``
            before ``sin``/``cos``, i.e. it is presumably normalized to
            ``[-1, 1]`` (confirm).

    Returns:
        Array of shape ``(time, 3)`` holding ``[sin(pi*theta), cos(pi*theta), omega]``.
    """
    angle = obs[:, 0] * jnp.pi
    omega = obs[:, 1]
    return jnp.stack([jnp.sin(angle), jnp.cos(angle), omega], axis=-1)

@partial(jax.jit, static_argnums=(1, 4))
def loss_function(true_obs, model, init_obs, init_state, n_steps, proposed_action):
    """Feature-space MSE between a constant-action model rollout and a reference trajectory.

    ``proposed_action`` is broadcast against a ``(batch, time, 1)`` array of
    ones, where batch and time are taken from ``true_obs``. The result is used
    as the action sequence for ``model``.

    Args:
        true_obs: Reference trajectory indexed as ``[batch, time, feature]``.
        model: Differentiable environment model (static under jit).
        init_obs: Initial observation of the rollout.
        init_state: Initial state of the rollout.
        n_steps: Rollout length (static under jit).
        proposed_action: Action value(s) broadcastable to ``(batch, time, 1)``.

    Returns:
        Scalar mean of the squared difference of ``featurize_theta`` features.
    """
    n_batch, n_time = true_obs.shape[0], true_obs.shape[1]
    action_sequence = jnp.ones(shape=(n_batch, n_time, 1)) * proposed_action

    predicted = simulate_episode(model, init_obs, init_state, n_steps, action_sequence)

    featurize_batch = jax.vmap(featurize_theta)
    residual = featurize_batch(predicted) - featurize_batch(true_obs)
    return jnp.mean(residual ** 2)

# Derivative of the loss w.r.t. the proposed action (argument 5).
grad_loss_function = jax.grad(loss_function, argnums=5)
# Second derivative, obtained by differentiating the gradient again. jax.grad
# needs a scalar output, so this is only valid for a scalar proposed_action.
hessian_loss_function = jax.grad(grad_loss_function, argnums=5)

@partial(jax.jit, static_argnums=(1, 4))
def step(true_obs, model, init_obs, init_state, n_steps, proposed_action):
    """Take one Newton step on a single constant action: ``a - L'(a) / L''(a)``.

    ``L`` is ``loss_function`` evaluated for a rollout that starts at
    ``init_obs`` / ``init_state``.

    Args:
        true_obs: Reference trajectory to follow.
        model: Differentiable model (static under jit).
        init_obs: Initial observation of the rollout.
        init_state: Initial state of the rollout.
        n_steps: Rollout length (static under jit).
        proposed_action: Current action estimate. It must contain exactly one
            element (e.g. shape ``(1, 1, 1)``), because the second derivative is
            taken with ``jax.grad`` on the squeezed scalar.

    Returns:
        Updated action with the same shape as ``proposed_action``.
    """
    # Roll out from the initial condition passed in, not from the notebook globals
    # `obs`/`state`. jit would otherwise freeze those globals as constants at trace time.
    scalar_action = jnp.squeeze(proposed_action)
    curvature = hessian_loss_function(
        true_obs, model, init_obs, init_state, n_steps, scalar_action
    )
    gradient = grad_loss_function(
        true_obs, model, init_obs, init_state, n_steps, proposed_action
    )
    return proposed_action - 1 / curvature * gradient

In [None]:
# New initial condition and a random constant action. The resulting trajectory
# `true_obs` is the reference the optimizer below has to reproduce.
obs, state = env.reset()
obs = obs.astype(jnp.float32)
state = state.astype(jnp.float32)

# actions = aprbs(n_steps, batch_size, 5, 100, next(data_rng))  # `aprbs` is defined further below
actions = gen_actions(n_steps, batch_size, next(data_rng))

true_obs = simulate_episode(
    env,
    obs,
    state,
    n_steps=n_steps,
    actions=actions
)

In [None]:
# Random initial guess for the constant action, compared with the action
# that actually generated `true_obs`.
proposed_action = gen_actions(1, batch_size, next(data_rng))
print("proposed: ", proposed_action)
print("true", actions[:, 0, :])

start = time.perf_counter()

# One Newton step. JAX dispatches asynchronously, so wait for the result before
# stopping the clock. The first call also includes JIT compilation time.
proposed_action = step(true_obs, model, obs, state, n_steps, proposed_action).block_until_ready()

end = time.perf_counter()

loss = loss_function(true_obs, model, obs, state, n_steps, proposed_action)

print(f"Loss: {loss}, Computation time: {end - start}")
print("proposed: ", proposed_action)
print("true", actions[:, 0, :])

#### use optax to optimize

- naturally this is a lot slower
- not sure yet if it makes sense to use something like this
- maybe it becomes necessary when model and actions are more complex?

In [None]:
import optax

In [None]:
def optimize(true_obs, model, init_obs, init_state, n_steps, initial_action,
             *, n_iter=1000, learning_rate=1):
    """Fit a constant action to ``true_obs`` with the AdaBelief optimizer.

    Args:
        true_obs: Reference trajectory to follow.
        model: Differentiable model, forwarded to ``grad_loss_function``.
        init_obs: Initial observation of the rollout.
        init_state: Initial state of the rollout.
        n_steps: Rollout length.
        initial_action: Starting value of the action.
        n_iter: Number of optimizer steps (default 1000).
        learning_rate: AdaBelief learning rate (default 1).

    Returns:
        The optimized action, with the same shape as ``initial_action``.
    """
    solver = optax.adabelief(learning_rate=learning_rate)
    action = initial_action
    opt_state = solver.init(action)

    for _ in range(n_iter):
        grad = grad_loss_function(true_obs, model, init_obs, init_state, n_steps, action)
        updates, opt_state = solver.update(grad, opt_state, action)
        action = optax.apply_updates(action, updates)

    return action

In [None]:
# Random initial guess for the constant action.
proposed_action = gen_actions(1, batch_size, next(data_rng))


print("proposed: ", proposed_action)
print("true", actions[:, 0, :])

start = time.perf_counter()

proposed_action = optimize(true_obs, model, obs, state, n_steps, proposed_action)
# JAX dispatches asynchronously; wait for the result so the timing is meaningful.
proposed_action.block_until_ready()

end = time.perf_counter()

loss = loss_function(true_obs, model, obs, state, n_steps, proposed_action)

print(f"Loss: {loss}, Computation time: {end - start}")
print("proposed: ", proposed_action)
print("true", actions[:, 0, :])

### Multi-dim choice: Try to follow the given trajectory

- optax likely already has some solvers?

In [None]:
def aprbs2(len, t_min, t_max, key):
    """Generate one amplitude-modulated pseudo-random binary sequence (APRBS).

    The signal is piecewise constant. Each segment has a random duration drawn
    with ``randint(t_min, t_max)`` (upper bound exclusive) and a random level
    drawn uniformly between -1 and 1. Segments are appended until at least
    ``len`` samples exist, then the result is truncated to exactly ``len``.

    Args:
        len: Number of samples to return. The parameter name is kept for
            compatibility and shadows the builtin inside this function.
        t_min: Minimum segment duration (inclusive).
        t_max: Maximum segment duration (exclusive).
        key: JAX PRNG key.

    Returns:
        1-D array of shape ``(len,)``.
    """
    n_samples = len
    segments = []
    filled = 0
    while filled < n_samples:
        steps_key, value_key, key = jax.random.split(key, 3)
        duration = jax.random.randint(steps_key, shape=(1,), minval=t_min, maxval=t_max)
        level = jax.random.uniform(value_key, shape=(1,), minval=-1, maxval=1)
        segments.append(jnp.ones(duration) * level)
        filled = filled + duration.item()
    return jnp.hstack(segments)[:n_samples]

def aprbs(n_steps, batch_size, t_min, t_max, key):
    """Stack ``batch_size`` independent APRBS signals into an action batch.

    Each signal comes from ``aprbs2`` with its own subkey.

    Returns:
        Array of shape ``(batch_size, n_steps, 1)``.
    """
    signals = []
    for _ in range(batch_size):
        subkey, key = jax.random.split(key)
        signal = aprbs2(n_steps, t_min, t_max, subkey)
        signals.append(jnp.expand_dims(signal, axis=-1))
    return jnp.stack(signals, axis=0)

In [None]:
# New reference trajectory, this time driven by a piecewise-constant APRBS
# action sequence with segment durations of 200–499 steps.
obs, state = env.reset()
obs = obs.astype(jnp.float32)
state = state.astype(jnp.float32)

actions = aprbs(n_steps, batch_size, 200, 500, next(data_rng))
# actions = gen_actions(n_steps, batch_size, next(data_rng))

true_obs = simulate_episode(
    env,
    obs,
    state,
    n_steps=n_steps,
    actions=actions
)

In [None]:
# Reference trajectory and the APRBS actions that produced it.
plot_episode(true_obs, actions, max_n=1)

In [None]:
def featurize_theta(obs):
    """Map ``[theta, omega]`` observations to ``[sin(pi*theta), cos(pi*theta), omega]``.

    Replacing the raw angle with its sine and cosine presumably avoids the jump
    where a normalized angle wraps from +1 to -1 (confirm against the
    environment). ``obs`` is indexed as ``[time, feature]``, and the result has
    shape ``(time, 3)``.
    """
    theta_rad = jnp.pi * obs[:, 0]
    features = (jnp.sin(theta_rad), jnp.cos(theta_rad), obs[:, 1])
    return jnp.stack(features, axis=-1)

@partial(jax.jit, static_argnums=(1, 4))
def loss_function(true_obs, model, init_obs, init_state, n_steps, proposed_actions):
    """Feature-space MSE between a model rollout under ``proposed_actions`` and ``true_obs``.

    Args:
        true_obs: Reference trajectory indexed as ``[batch, time, feature]``.
        model: Differentiable environment model (static under jit).
        init_obs: Initial observation of the rollout.
        init_state: Initial state of the rollout.
        n_steps: Rollout length (static under jit).
        proposed_actions: Full action sequence passed to ``simulate_episode``.
            With ``simulate_episode`` as written in this notebook,
            ``proposed_actions[:, 0, :]`` does not affect the loss.

    Returns:
        Scalar mean of the squared difference of ``featurize_theta`` features.
    """
    rollout = simulate_episode(model, init_obs, init_state, n_steps, proposed_actions)
    featurize_batch = jax.vmap(featurize_theta)
    return jnp.mean((featurize_batch(rollout) - featurize_batch(true_obs)) ** 2)

# Gradient of the loss w.r.t. the action sequence (argument 5).
grad_loss_function = jax.grad(loss_function, argnums=5)
# Full Hessian w.r.t. the action sequence, of shape actions.shape + actions.shape.
# `jax.grad(grad_loss_function)` cannot be used here: the gradient of a
# multi-element action sequence is not a scalar, and jax.grad requires one.
hessian_loss_function = jax.hessian(loss_function, argnums=5)

@partial(jax.jit, static_argnums=(1, 4))
def step(true_obs, model, init_obs, init_state, n_steps, proposed_actions, alpha):
    """One plain gradient-descent step on the action sequence: ``a - alpha * dL/da``.

    Args:
        true_obs: Reference trajectory to follow.
        model: Differentiable model (static under jit).
        init_obs: Initial observation of the rollout.
        init_state: Initial state of the rollout.
        n_steps: Rollout length (static under jit).
        proposed_actions: Current action sequence.
        alpha: Step size.

    Returns:
        Updated action sequence with the same shape as ``proposed_actions``.
    """
    # Use the initial condition passed in, not the notebook globals `obs`/`state`.
    # jit would otherwise freeze those globals as constants at trace time.
    gradient = grad_loss_function(
        true_obs, model, init_obs, init_state, n_steps, proposed_actions
    )
    return proposed_actions - alpha * gradient

def optimize(true_obs, model, init_obs, init_state, n_steps, proposed_actions,
             *, n_iter=100, learning_rate=1):
    """Fit an action sequence to ``true_obs`` with the AdaBelief optimizer.

    Args:
        true_obs: Reference trajectory to follow.
        model: Differentiable model, forwarded to ``grad_loss_function``.
        init_obs: Initial observation of the rollout.
        init_state: Initial state of the rollout.
        n_steps: Rollout length.
        proposed_actions: Initial action sequence.
        n_iter: Number of optimizer steps (default 100). Progress is shown with tqdm.
        learning_rate: AdaBelief learning rate (default 1).

    Returns:
        The optimized action sequence, with the same shape as ``proposed_actions``.
    """
    solver = optax.adabelief(learning_rate=learning_rate)
    opt_state = solver.init(proposed_actions)

    for _ in tqdm(range(n_iter)):
        grad = grad_loss_function(true_obs, model, init_obs, init_state, n_steps, proposed_actions)
        updates, opt_state = solver.update(grad, opt_state, proposed_actions)
        proposed_actions = optax.apply_updates(proposed_actions, updates)

    return proposed_actions

In [None]:
# Roll out the *real* system under the current guess `proposed_actions` and plot it.
# NOTE(review): `proposed_actions` is first assigned in the following cell. When the
# notebook runs top to bottom, this cell raises NameError; run the next cell first.
pred_obs = simulate_episode(
    env,
    obs,
    state,
    n_steps=n_steps,
    actions=proposed_actions
)
plot_episode(
    pred_obs,
    proposed_actions,
    max_n=1
)

In [None]:
# Random APRBS initial guess for the action sequence to be optimized.
proposed_actions = aprbs(n_steps, batch_size, 200, 500, next(data_rng))

In [None]:
print("loss before:", loss_function(true_obs, model, obs, state, n_steps, proposed_actions))
start = time.perf_counter()
# for n in range(50_000):
#     proposed_actions = step(true_obs, model, obs, state, n_steps, proposed_actions, alpha=1000)

proposed_actions_after_opt = optimize(true_obs, model, obs, state, n_steps, proposed_actions)
# JAX dispatches asynchronously; wait for the result before stopping the clock.
proposed_actions_after_opt.block_until_ready()

end = time.perf_counter()

print(f"Computation time: {end - start}")
print("loss after:", loss_function(true_obs, model, obs, state, n_steps, proposed_actions_after_opt))

# Mean squared error between the initial guess and the true action sequence.
print(jnp.mean((proposed_actions - actions) ** 2))

In [None]:
# Loss of the initial guess vs. the loss of the optimized action sequence.
print("loss before:", loss_function(true_obs, model, obs, state, n_steps, proposed_actions))
print("loss after:", loss_function(true_obs, model, obs, state, n_steps, proposed_actions_after_opt))

In [None]:
# Mean squared error between the optimized and the true action sequences.
print(jnp.mean((proposed_actions_after_opt - actions) ** 2))

In [None]:
# Roll out the real system under the optimized actions, to check whether it
# follows the reference trajectory.
pred_obs = simulate_episode(
    env,
    obs,
    state,
    n_steps=n_steps,
    actions=proposed_actions_after_opt
)
plot_episode(
    pred_obs,
    proposed_actions_after_opt,
    max_n=1
)

In [None]:
# Reference trajectory and the actions that produced it, for comparison.
plot_episode(
    true_obs,
    actions,
    max_n=1
)

In [None]:
def hessian(f, argnums=0):
    """Return a function computing the Hessian of ``f`` (forward-over-reverse).

    Args:
        f: Scalar-valued function to differentiate.
        argnums: Positional argument(s) to differentiate with respect to.
            Defaults to 0 (the first argument). For example, ``argnums=5``
            differentiates ``loss_function`` w.r.t. the action sequence rather
            than w.r.t. ``true_obs``.

    Returns:
        A function with the same signature as ``f`` that returns the Hessian.
        For a single array argument ``x``, its shape is ``x.shape + x.shape``.
    """
    return jax.jacfwd(jax.jacrev(f, argnums=argnums), argnums=argnums)

In [None]:
# Hessian of the loss w.r.t. the action sequence (argument 5). Differentiating
# `loss_function` with the default argnums=0 would use `true_obs` instead.
# The raw Hessian has shape actions.shape + actions.shape, so flatten it to a
# square (N, N) matrix with N = proposed_actions.size.
n_action_params = proposed_actions.size
action_hessian = jax.hessian(loss_function, argnums=5)(
    true_obs, model, obs, state, n_steps, proposed_actions
).reshape(n_action_params, n_action_params)
# simulate_episode never reads actions[:, 0, :], so the Hessian has an all-zero
# row and column and is singular. Use the pseudo-inverse instead of inv.
jnp.linalg.pinv(action_hessian)

In [None]:
# huh, I guess the vmapping needs to be moved up closer to the optimization routine
# not really sure where to start with this and also very hungry, continue tomorrow
# Also not really sure whether the Hessian is the way to go