In [None]:
# Automatically reload edited local modules (density_estimation, model_utils, ...)
# before each cell executes.
%load_ext autoreload
%autoreload 2
# Make the %lprun line-profiling magic available.
%reload_ext line_profiler

In [None]:
from functools import partial

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
# plt.rcParams['text.usetex'] = True
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import exciting_environments as excenvs

from density_estimation import update_kde_grid, update_kde_grid_multiple_observations
from metrics import JSDLoss
from model_utils import simulate_ahead, simulate_ahead_with_env
import plotting_utils
from plotting_utils import plot_sequence, append_predictions_to_sequence_plot, plot_sequence_and_prediction
from signals import generate_constant_action, aprbs
from optimization_utils import loss_function, optimize, soft_penalty

---

In [None]:
# PRNG setup: one seeded root key is split three ways. `data_key` drives the
# data-generating RNG stream, `model_key` initializes the model, and `key` is
# carried forward for later splits.
SEED = 21
root_key = jax.random.PRNGKey(seed=SEED)
data_key, model_key, key = jax.random.split(root_key, 3)
data_rng = PRNGSequence(data_key)  # next(data_rng) yields a fresh subkey

In [None]:
# Batched pendulum environment: `batch_size` parallel trajectories, integration
# time step `tau`.
batch_size = 20
tau = 1e-2

env = excenvs.make(env_id='Pendulum-v0', batch_size=batch_size, tau=tau)

### Test simulation:

- starting from the initial state/observation ($\mathbf{x}_0$ / $\mathbf{y}_0$)
- apply $N = 999$ actions $\mathbf{u}_0 \dots \mathbf{u}_{N-1}$ (**here**: random APRBS actions)
- which results in the state trajectory $\mathbf{x}_0 \dots \mathbf{x}_N$ with $N+1 = 1000$ elements

In [None]:
# Number of actions to apply; the resulting trajectory has n_steps + 1 observations.
n_steps = 999

# Fresh initial observation/state for every batch element, cast to float32.
obs, state = env.reset()
obs, state = obs.astype(jnp.float32), state.astype(jnp.float32)

# Random APRBS excitation (200, 500: presumably the min/max hold durations —
# see signals.aprbs).
actions = aprbs(n_steps, batch_size, 200, 500, next(data_rng))

In [None]:
# Simulate all batch elements in parallel. The environment object is shared
# (in_axes None); initial obs/state, action sequences, normalizers and static
# parameters are mapped over their leading (batch) axis.
batched_simulate = jax.vmap(simulate_ahead_with_env, in_axes=(None, 0, 0, 0, 0, 0, 0))
observations = batched_simulate(
    env, obs, state, actions,
    env.env_state_normalizer, env.action_normalizer, env.static_params,
)

print("actions.shape:", actions.shape)
print("observations.shape:", observations.shape)

print(" \n One of the trajectories:")
fig, axs = plot_sequence(
    observations=observations[0, ...],
    actions=actions[0, ...],
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
)
plt.show()

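### Optimize trajectory:

- use the environment itself as a (perfect) model to predict the upcoming states, and use these predictions to optimize the action sequence so that the distribution of visited observations moves towards the uniform `target_distribution`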

In [None]:
# Grid on which the state density is estimated: 40 equally spaced points per
# observation dimension, both dimensions sharing the same bounds (assumes
# `low`/`high` are scalars — TODO confirm), flattened to shape (40 * 40, 2).
grid_axis = jnp.linspace(env.env_observation_space.low, env.env_observation_space.high, 40)
x1, x2 = grid_axis, grid_axis
x = jnp.stack(jnp.meshgrid(x1, x2), axis=-1).reshape(-1, 2)
n_grid_points = x.shape[0]

# Running KDE state: one density estimate per batch element, evaluated on the grid.
start_n_measurments = jnp.array([0])
bandwidth = 0.05
p_est = jnp.zeros([batch_size, n_grid_points, 1])

# Uniform target density over the (square) observation space.
obs_range = env.env_observation_space.high - env.env_observation_space.low
target_distribution = jnp.ones(shape=(n_grid_points, 1)) * (1 / obs_range**2)

n_prediction_steps = 50  # length of the planned action sequence
n_time_steps = 4_999  # results in 5000 observations per trajectory

In [None]:
grad_loss_function = jax.grad(loss_function, argnums=3)  # gradient w.r.t. argument 3: the actions

In [None]:
# Closed-loop excitation with the environment as a perfect model. At every step
# the planned action sequence is re-optimized, the first action is applied, the
# KDE of visited observations is updated, and the plan is shifted by one step.
#
# The KDE state is (re)initialized here so that re-running this cell starts from
# scratch instead of continuing from the estimate of a previous run while
# `observations` restarts empty.
p_est = jnp.zeros([batch_size, n_grid_points, 1])
start_n_measurments = jnp.array([0])

observations = []
actions = []

obs, state = env.reset()
obs = obs.astype(jnp.float32)
state = state.astype(jnp.float32)

observations.append(obs)

proposed_actions = aprbs(n_prediction_steps, batch_size, 20, 50, next(data_rng))

for k in tqdm(range(n_time_steps)):

    proposed_actions, final_loss = optimize(
        grad_loss_function=grad_loss_function,
        proposed_actions=proposed_actions,
        model=env,
        init_obs=obs,
        init_state=state,
        p_est=p_est,
        x=x,
        start_n_measurments=start_n_measurments,
        bandwidth=bandwidth,
        tau=tau,
        target_distribution=target_distribution
    )

    if k % 1000 == 0 and k > 0:
        print(f"Loss iteration: {k}: {final_loss[0]}")
        fig, axs = plot_sequence_and_prediction(
            observations=jnp.stack(observations, axis=1)[0, :],
            actions=jnp.stack(actions, axis=1)[0, :],
            tau=tau,
            obs_labels=[r"$\theta$", r"$\omega$"],
            actions_labels=[r"$u$"],
            model=env,
            init_obs=obs[0, :],
            init_state=state[0, :],
            proposed_actions=proposed_actions[0, :]
        )
        plt.show()

    # Add the current observation to each batch element's KDE; `k` is passed as
    # the number of observations already contained in the estimate (presumably —
    # see density_estimation.update_kde_grid).
    p_est = jax.vmap(update_kde_grid, in_axes=[0, None, 0, None, None])(
        p_est, x, obs, k, bandwidth
    )
    start_n_measurments = start_n_measurments + 1

    # Apply the first planned action, then shift the plan left by one step. The
    # last entry is left unchanged, so the next optimization starts from a full plan.
    action = proposed_actions[:, 0, :]
    obs, _, _, _, state = env.step(action, state)
    proposed_actions = proposed_actions.at[:, :-1, :].set(proposed_actions[:, 1:, :])

    # store current trajectory
    actions.append(action)
    observations.append(obs)

In [None]:
# Stack the per-step lists once into (batch, time, dim) arrays instead of
# re-stacking thousands of arrays on every loop iteration.
stacked_observations = jnp.stack(observations, axis=1)
stacked_actions = jnp.stack(actions, axis=1)

for j in range(min(batch_size, 10)):  # plot at most 10 of the batch elements
    fig, axs = plot_sequence(
        observations=stacked_observations[j, ...],
        actions=stacked_actions[j, ...],
        tau=tau,
        obs_labels=[r"$\theta$", r"$\omega$"],
        action_labels=[r"$u$"],
    )
    plt.show()

In [None]:
# KDE of every observation visited along trajectory 1, shown as filled contours.
trajectory_obs = jnp.stack(observations, axis=1)[1, ...]
p_est = update_kde_grid_multiple_observations(
    jnp.zeros([n_grid_points, 1]),
    x,
    trajectory_obs,
    n_observations=0,
    bandwidth=bandwidth,
)
fig, axs, cax = plotting_utils.plot_2d_kde_as_contourf(p_est, x, [r"$\theta$", r"$\omega$"])
# fig.savefig("excited_pendulum_kde_contourf.png")

In [None]:
# KDE of trajectory 0 and its absolute deviation from the uniform target density.
p_est = update_kde_grid_multiple_observations(
    jnp.zeros([n_grid_points, 1]),
    x,
    jnp.stack(observations, axis=1)[0, ...],
    n_observations=0,
    bandwidth=bandwidth,
)
abs_difference = jnp.abs(p_est - target_distribution)

fig, axs = plotting_utils.plot_2d_kde_as_surface(p_est, x, [r"$\theta$", r"$\omega$"])
fig.suptitle("Vanilla KDE")
# fig.savefig("excited_pendulum_kde_surface.png")
plt.show()

fig, axs = plotting_utils.plot_2d_kde_as_surface(abs_difference, x, [r"$\theta$", r"$\omega$"])
fig.suptitle("Difference")
plt.show()

fig, axs, cax = plotting_utils.plot_2d_kde_as_contourf(abs_difference, x, [r"$\theta$", r"$\omega$"])
plt.colorbar(cax)
fig.suptitle("Abs Difference")
plt.show()

### Learn a neural ODE (NODE) on the optimized trajectories:

In [None]:
# Convert the per-step lists into (batch, time, dim) arrays for training.
# Guarded so that re-running this cell does not stack the already-stacked
# arrays a second time, which would silently swap the batch and time axes.
if isinstance(observations, list):
    observations = jnp.stack(observations, axis=1)
print(observations.shape)

if isinstance(actions, list):
    actions = jnp.stack(actions, axis=1)
print(actions.shape)

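- batches (`batch_size = 128`, `sequence_length = 50`):
    - observations $\mathbf{y}_b$: $(128, 50, 2)$
    - actions $\mathbf{u}_b$: $(128, 49, 1)$ (a window of 50 observations is driven by 49 actions)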

In [None]:
from models import NeuralEulerODE, NeuralEulerODEPendulum

In [None]:
def featurize_theta(obs):
    """Map observations ``(theta, omega, ...)`` to ``(sin, cos, omega)`` features.

    The angle entry is multiplied by pi before sin/cos are taken, so it is
    presumably a normalized angle in [-1, 1] (TODO confirm against the env's
    normalizer). Placing the angle on the unit circle makes wrap-around values
    such as -pi and pi produce identical features, so the loss does not penalize
    equivalent angles. Only the first two entries of the last axis are used; all
    leading axes are preserved.
    """
    angle = jnp.pi * obs[..., 0]
    angular_velocity = obs[..., 1]
    features = (jnp.sin(angle), jnp.cos(angle), angular_velocity)
    return jnp.stack(features, axis=-1)

@eqx.filter_value_and_grad
def grad_loss(model, true_obs, actions, tau):
    """Mean squared error between simulated and recorded (featurized) observations.

    Each sequence in the batch is rolled out with `simulate_ahead`, starting from
    its first recorded observation and driven by its actions. Predictions and
    targets are compared after `featurize_theta`, so angle wrap-around is not
    penalized. Wrapped in `eqx.filter_value_and_grad`, so calling it returns
    `(loss, grads)` with gradients w.r.t. `model`.
    """
    batched_simulate = jax.vmap(simulate_ahead, in_axes=(None, 0, 0, None))
    initial_obs = true_obs[:, 0, :]
    pred_obs = batched_simulate(model, initial_obs, actions, tau)

    # featurize_theta only indexes the last axis, so it broadcasts over the
    # batch and time axes without an explicit vmap.
    squared_error = (featurize_theta(pred_obs) - featurize_theta(true_obs)) ** 2
    return jnp.mean(squared_error)

@eqx.filter_jit
def make_step(model, observations, actions, tau, opt_state):
    """Perform one optimization step of `model` on a single batch.

    Evaluates `grad_loss` (loss value and gradients w.r.t. the model), turns the
    gradients into updates with the module-level optax transformation `optim`,
    and applies them to the model.

    Returns:
        (loss, updated_model, updated_opt_state)

    NOTE(review): `optim` and `grad_loss` are globals read while tracing; after
    redefining `optim` (e.g. a new learning rate), re-run this cell too so the
    jitted function is retraced — TODO confirm.
    """
    value, gradients = grad_loss(model, observations, actions, tau)
    param_updates, next_opt_state = optim.update(gradients, opt_state)
    return value, eqx.apply_updates(model, param_updates), next_opt_state

In [None]:
def dataloader(observations_dataset, actions_dataset, batch_size, *, key):
    """Yield shuffled mini-batches of (observations, actions) forever.

    Each epoch draws a fresh random permutation of the sequence indices and
    yields consecutive, non-overlapping groups of ``batch_size`` sequences. All
    full batches are yielded, including the last one when ``dataset_size`` is a
    multiple of ``batch_size``. A trailing partial batch is dropped, so every
    batch has the same shape.

    Args:
        observations_dataset: array with sequences on axis 0 and time on axis 1.
        actions_dataset: array with the same number of sequences as
            ``observations_dataset`` and one fewer time step.
        batch_size: number of sequences per batch.
        key: JAX PRNG key used for shuffling.

    Yields:
        ``(observations_batch, actions_batch)`` tuples.

    Raises:
        ValueError: if ``batch_size`` is not in ``[1, dataset_size]``; otherwise
            no full batch exists and the generator would loop without yielding.
    """
    dataset_size = observations_dataset.shape[0]
    assert actions_dataset.shape[0] == dataset_size
    assert actions_dataset.shape[1] == observations_dataset.shape[1] - 1
    if not 0 < batch_size <= dataset_size:
        raise ValueError(f"batch_size must be in [1, {dataset_size}], got {batch_size}")

    indices = jnp.arange(dataset_size)
    n_full_batches = dataset_size // batch_size

    while True:
        # Fresh subkey per epoch; the carried key is never used for sampling directly.
        key, perm_key = jax.random.split(key)
        perm = jax.random.permutation(perm_key, indices)
        for batch_idx in range(n_full_batches):
            batch_perm = perm[batch_idx * batch_size:(batch_idx + 1) * batch_size]
            yield observations_dataset[batch_perm], actions_dataset[batch_perm]

- training on the gathered trajectories:

In [None]:
obs_dim = env.env_observation_space.shape[-1]
action_dim = env.action_space.shape[-1]

model = NeuralEulerODEPendulum(obs_dim=obs_dim, action_dim=action_dim, width_size=128, depth=3, key=model_key)

batch_size = 128  # NOTE: shadows the environment batch size (20) defined earlier
sequence_length = n_prediction_steps

lr = 1e-4
n_iters = 5_000
print_every = 100

# Split every trajectory into non-overlapping windows of `sequence_length`
# observations. The trajectory length must be a multiple of `sequence_length`;
# otherwise the reshape would fail or build windows that straddle two trajectories.
n_trajectories, trajectory_length, _ = observations.shape
assert trajectory_length % sequence_length == 0, (
    f"trajectory length {trajectory_length} is not a multiple of sequence_length {sequence_length}"
)
assert actions.shape[:2] == (n_trajectories, trajectory_length - 1)

observations_dataset = observations.reshape(-1, sequence_length, observations.shape[-1])
print("observations_datset.shape", observations_dataset.shape)

# Pad every action trajectory with one dummy step so it has as many steps as the
# observations, window it identically, then drop the last action of each window:
# `sequence_length` observations are driven by `sequence_length - 1` actions.
# The padded value always lands in a dropped position and is never used.
actions_dataset = jnp.concatenate([actions, jnp.ones((actions.shape[0], 1, actions.shape[-1]))], axis=1)
actions_dataset = actions_dataset.reshape(-1, sequence_length, actions_dataset.shape[-1])
actions_dataset = actions_dataset[:, :-1, :]
print("actions_dataset.shape", actions_dataset.shape)

optim = optax.adabelief(lr)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

loader_key, key = jax.random.split(key)

In [None]:
# Training loop. `make_step` is jitted and JAX dispatches work asynchronously,
# so we block on the loss before stopping the timer; otherwise only the dispatch
# time would be reported. The first iteration also includes compilation time.
train_loader = dataloader(observations_dataset, actions_dataset, batch_size=batch_size, key=loader_key)
for step, (observations_batch, actions_batch) in zip(range(n_iters), train_loader):
    start = time.time()
    loss, model, opt_state = make_step(model, observations_batch, actions_batch, tau, opt_state)
    loss.block_until_ready()
    end = time.time()

    if (step % print_every) == 0 or step == n_iters - 1:
        print(f"iter: {step}, Loss: {loss}, Computation time: {end - start}")

- eval on the gathered trajectories:

In [None]:
# Shapes of the stacked (batch, time, dim) observation and action arrays.
print(observations.shape, actions.shape, sep="\n")

In [None]:
from plotting_utils import plot_model_performance

In [None]:
# Compare the learned model against the recorded data on the first 1000
# observations of trajectory 0 (999 actions connect 1000 observations).
fig, axs = plot_model_performance(
    model=model,
    true_observations=observations[0, :1000],
    actions=actions[0, :999],
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
)
plt.show()

In [None]:
# Sanity check of the dataloader: print the shapes of a few batches.
# A dedicated counter is used so the training `n_iters` is not overwritten;
# otherwise re-running the training cell afterwards would train for only 10 steps.
loader_key, key = jax.random.split(key)
n_check_batches = 10

check_loader = dataloader(observations_dataset, actions_dataset, batch_size=batch_size, key=loader_key)
for _, (observations_batch, actions_batch) in zip(range(n_check_batches), check_loader):
    print(observations_batch.shape)
    print(actions_batch.shape)

### Use the learned model to optimize trajectories

In [None]:
# Display the trained model as this cell's output.
model

In [None]:
# A single environment (batch_size=1) for running the optimizer with the learned model.
env = excenvs.make(env_id='Pendulum-v0', batch_size=1, tau=tau)

In [None]:
# Density-estimation setup for the single-environment run: same 40 x 40 grid,
# bandwidth and uniform target as before, but one KDE for batch_size=1.
# (Assumes `low`/`high` are scalars shared by both axes — TODO confirm.)
grid_axis = jnp.linspace(env.env_observation_space.low, env.env_observation_space.high, 40)
x1, x2 = grid_axis, grid_axis
x = jnp.stack(jnp.meshgrid(x1, x2), axis=-1).reshape(-1, 2)
n_grid_points = x.shape[0]

start_n_measurments = jnp.array([0])

bandwidth = 0.05
p_est = jnp.zeros([1, n_grid_points, 1])

obs_range = env.env_observation_space.high - env.env_observation_space.low
target_distribution = jnp.ones(shape=(n_grid_points, 1)) * (1 / obs_range**2)

n_prediction_steps = 50  # length of the planned action sequence
n_time_steps = 5_000

In [None]:
grad_loss_function = jax.grad(loss_function, argnums=3)  # gradient w.r.t. argument 3: the actions

In [None]:
# Closed-loop excitation, now with the *learned* model inside the optimizer and
# the real environment producing the measurements via `env.step`.
#
# The KDE state is (re)initialized here so that re-running this cell starts from
# scratch instead of continuing from the estimate of a previous run while
# `observations` restarts empty.
p_est = jnp.zeros([1, n_grid_points, 1])
start_n_measurments = jnp.array([0])

observations = []
actions = []

obs, state = env.reset()
obs = obs.astype(jnp.float32)
state = state.astype(jnp.float32)

observations.append(obs)

proposed_actions = aprbs(n_prediction_steps, 1, 20, 50, next(data_rng))

for k in tqdm(range(n_time_steps)):

    proposed_actions, final_loss = optimize(
        grad_loss_function=grad_loss_function,
        proposed_actions=proposed_actions,
        model=model,
        init_obs=obs,
        init_state=state,
        p_est=p_est,
        x=x,
        start_n_measurments=start_n_measurments,
        bandwidth=bandwidth,
        tau=tau,
        target_distribution=target_distribution
    )

    if k % 1000 == 0 and k > 0:
        print(f"Loss iteration: {k}: {final_loss[0]}")
        # NOTE(review): `model=env` here, whereas the optimizer above uses the
        # learned `model`; the plotted prediction presumably shows how the plan
        # plays out on the true system — confirm this is intended.
        fig, axs = plot_sequence_and_prediction(
            observations=jnp.stack(observations, axis=1)[0, :],
            actions=jnp.stack(actions, axis=1)[0, :],
            tau=tau,
            obs_labels=[r"$\theta$", r"$\omega$"],
            actions_labels=[r"$u$"],
            model=env,
            init_obs=obs[0, :],
            init_state=state[0, :],
            proposed_actions=proposed_actions[0, :]
        )
        plt.show()

    # Add the current observation to the KDE; `k` is passed as the number of
    # observations already contained in the estimate (presumably — see
    # density_estimation.update_kde_grid).
    p_est = jax.vmap(update_kde_grid, in_axes=[0, None, 0, None, None])(
        p_est, x, obs, k, bandwidth
    )
    start_n_measurments = start_n_measurments + 1

    # Apply the first planned action, then shift the plan left by one step. The
    # last entry is left unchanged, so the next optimization starts from a full plan.
    action = proposed_actions[:, 0, :]
    obs, _, _, _, state = env.step(action, state)
    proposed_actions = proposed_actions.at[:, :-1, :].set(proposed_actions[:, 1:, :])

    # store current trajectory
    actions.append(action)
    observations.append(obs)

In [None]:
# Plot the trajectory obtained with the learned model. Each stored step has shape
# (1, dim) because batch_size=1, so stack along a new time axis and select the
# single batch element to obtain a (time, dim) trajectory. This matches how
# plot_sequence is called for the batched runs above.
fig, axs = plot_sequence(
    observations=jnp.stack(observations, axis=1)[0, ...],
    actions=jnp.stack(actions, axis=1)[0, ...],
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
)
plt.show()

## What has been done in this notebook:
- use the perfect model (the environment itself) to compute trajectories with informative data
- use the gathered data to fit a model (neural ODE)
- use the fitted model for prediction inside the optimizer to produce new trajectories

### Results:
Works decently, but the optimization sometimes gets stuck and repeats the same behavior over and over, simply staying at the origin. It is not yet clear whether this is a bug or something else. **TODO: to be debugged.**