In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
from functools import partial

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
# plt.rcParams['text.usetex'] = True
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import exciting_environments as excenvs

from density_estimation import update_kde_grid, update_kde_grid_multiple_observations
from metrics import JSDLoss
from model_utils import simulate_ahead
import plotting_utils
from plotting_utils import plot_sequence, append_predictions_to_sequence_plot, plot_sequence_and_prediction
from signals import generate_constant_action, aprbs
from optimization_utils import loss_function, optimize, soft_penalty

---

In [None]:
# setup PRNG: one fixed root seed so the notebook is reproducible
key = jax.random.PRNGKey(seed=21)

# data_key -> data_rng (signal generation), model_key -> NODE initialization,
# key -> kept as the root for later splits (e.g. the dataloader keys)
data_key, model_key, key = jax.random.split(key, 3)
# consumed via next(data_rng) wherever a fresh data key is needed
data_rng = PRNGSequence(data_key)

In [None]:
batch_size = 20  # number of parallel pendulum environments / trajectories
tau = 1e-2 # 1e-3; step size passed to excenvs.make and to the plotting helpers

# Plant environment (also provides the observation/action spaces).
env = excenvs.make(
    env_id='Pendulum-v0',
    batch_size=batch_size,
    tau=tau
)

# Separate, identically configured instance passed to simulate_ahead / optimize
# as the prediction model.
model = excenvs.make(
    env_id='Pendulum-v0',
    batch_size=batch_size,
    tau=tau
)

### Test simulation:

In [None]:
# Initial observation and state of every batch element, cast to float32.
obs, state = env.reset()
obs = obs.astype(jnp.float32)
state = state.astype(jnp.float32)
n_steps = 1_000

# Open-loop APRBS excitation, one sequence per batch element.
# NOTE(review): 200 / 500 are presumably the min/max hold durations in steps — confirm in signals.aprbs.
actions = aprbs(n_steps, batch_size, 200, 500, next(data_rng))

In [None]:
# Open-loop simulation of all batch elements with the imported `simulate_ahead`:
# obs, state and actions are mapped over their leading (batch) axis, while the model,
# step count, normalizers and static params are shared (taken from batch element 0).
# (`key` in the dict comprehension is local to it and does not touch the PRNG key.)
observations = jax.vmap(simulate_ahead, in_axes=(None, None, 0, 0, 0, None, None, None))(
    model,
    n_steps,
    obs,
    state,
    actions,
    model.env_state_normalizer[0, :],
    model.action_normalizer[0, :],
    {key: value[0, :] for (key, value) in model.static_params.items()}
)

print("observations.shape:", observations.shape)
print("actions.shape:", actions.shape)


# Plot a single trajectory (batch element 0).
print(" \n One of the trajectories:")
fig, axs = plot_sequence(
    observations=observations[0, ...],
    actions=actions[0, ...],
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
);
plt.show()

### Optimize trajectory:

In [None]:
# KDE evaluation grid: `n_grid_points_per_dim` equally spaced values per observation
# dimension, combined into all (theta, omega) pairs -> shape (n_grid_points_per_dim**2, 2).
n_grid_points_per_dim = 40
x1 = jnp.linspace(env.env_observation_space.low, env.env_observation_space.high, n_grid_points_per_dim)
x2 = jnp.linspace(env.env_observation_space.low, env.env_observation_space.high, n_grid_points_per_dim)

x = jnp.stack(jnp.meshgrid(x1, x2), axis=-1).reshape(-1, 2)
n_grid_points = x.shape[0]

# Number of observations already folded into the KDE estimate.
start_n_measurments = 0

bandwidth = 0.05  # KDE kernel bandwidth
p_est = jnp.zeros([batch_size, n_grid_points, 1])  # one density estimate per batch element

# Uniform target density: 1 / area of the observation space
# (assumes scalar low/high bounds shared by both dimensions).
target_distribution = jnp.ones(shape=(n_grid_points, 1))
target_distribution *= 1 / (env.env_observation_space.high - env.env_observation_space.low)**2

n_prediction_steps = 50  # look-ahead horizon of the action optimizer
n_time_steps = 5_000  # length of the closed-loop experiment

In [None]:
# Gradient of the imported loss w.r.t. its first positional argument only.
grad_loss_function = jax.grad(loss_function, argnums=(0))

In [None]:
# Closed-loop excitation: at every time step, optimize the next `n_prediction_steps`
# actions so the predicted observations push the KDE estimate `p_est` towards
# `target_distribution`, apply the first action to the plant, and shift the plan by
# one step as a warm start for the next iteration.
observations = []
actions = []

obs, state = env.reset()
obs = obs.astype(jnp.float32)
state = state.astype(jnp.float32)

proposed_actions = aprbs(n_prediction_steps, batch_size, 20, 50, next(data_rng))

for k in tqdm(range(n_time_steps)):

    proposed_actions, final_loss = optimize(
        grad_loss_function=grad_loss_function,
        proposed_actions=proposed_actions,
        model=model,
        init_obs=obs,
        init_state=state,
        n_steps=n_prediction_steps,
        p_est=p_est,
        x=x,
        start_n_measurments=start_n_measurments,
        bandwidth=bandwidth,
        target_distribution=target_distribution
    )

    if k % 1000 == 0 and k > 0:
        print(f"Loss iteration: {k}: {final_loss}")
        fig, axs = plot_sequence_and_prediction(
            observations=jnp.stack(observations, axis=1)[0, :],
            actions=jnp.stack(actions, axis=1)[0, :],
            tau=tau,
            obs_labels=[r"$\theta$", r"$\omega$"],
            actions_labels=[r"$u$"],
            model=model,
            n_prediction_steps=n_prediction_steps,
            init_obs=obs[0, :],
            init_state=state[0, :],
            proposed_actions=proposed_actions[0, :]
        )
        plt.show()

    # Fold the current observation of every batch element into its own KDE estimate.
    p_est = jax.vmap(update_kde_grid, in_axes=[0, None, 0, None, None])(
        p_est,
        x,
        obs,
        k,
        bandwidth
    )
    start_n_measurments += 1

    action = proposed_actions[:, 0, :]  # apply only the first planned action

    actions.append(action)
    observations.append(obs)

    # Step the plant (`env`), not the prediction model: `model` is only meant for
    # look-ahead and is rebound to a NeuralEulerODE (called as model(obs, action, tau))
    # in the training cell, after which `model.step` would no longer be valid.
    obs, _, _, _, state = env.step(action, state)

    # Receding horizon: shift the plan left by one step (the last action is repeated).
    proposed_actions = proposed_actions.at[:, :-1, :].set(proposed_actions[:, 1:, :])

In [None]:
# Stack the recorded per-step lists once -> (batch, n_time_steps, dim), instead of
# re-stacking thousands of arrays on every loop iteration.
observations_stacked = jnp.stack(observations, axis=1)
actions_stacked = jnp.stack(actions, axis=1)

for j in range(min(batch_size, 10)):  # plot at max 10 of the batches
    fig, axs = plot_sequence(
        observations=observations_stacked[j, ...],
        actions=actions_stacked[j, ...],
        tau=tau,
        obs_labels=[r"$\theta$", r"$\omega$"],
        action_labels=[r"$u$"],
    );
    plt.show()

In [None]:
# KDE of one recorded trajectory on the grid `x`, starting from an empty estimate.
# NOTE(review): this uses batch element 1, whereas the surface plot in the next cell
# uses element 0 — confirm which trajectory is intended for the saved figure.
p_est = jnp.zeros([n_grid_points, 1])
p_est = update_kde_grid_multiple_observations(p_est, x, jnp.stack(observations, axis=1)[1, ...], n_observations=0, bandwidth=bandwidth)
fig, axs, cax = plotting_utils.plot_2d_kde_as_contourf(p_est, x, [r"$\theta$", r"$\omega$"])
fig.savefig("excited_pendulum_kde_contourf.png")

In [None]:
# KDE of batch element 0 shown as a surface; this `p_est` is also what the two
# error plots that follow compare against the target.
p_est = jnp.zeros([n_grid_points, 1])
p_est = update_kde_grid_multiple_observations(p_est, x, jnp.stack(observations, axis=1)[0, ...], n_observations=0, bandwidth=bandwidth)
fig, axs = plotting_utils.plot_2d_kde_as_surface(p_est, x, [r"$\theta$", r"$\omega$"])
# fig.savefig("excited_pendulum_kde_surface.png")

In [None]:
# Absolute deviation of the estimated density from the uniform target, as a surface.
fig, axs = plotting_utils.plot_2d_kde_as_surface(jnp.abs(p_est - target_distribution), x, [r"$\theta$", r"$\omega$"])

In [None]:
# Same deviation as filled contours, with a colorbar.
fig, axs, cax = plotting_utils.plot_2d_kde_as_contourf(jnp.abs(p_est - target_distribution), x, [r"$\theta$", r"$\omega$"])
plt.colorbar(cax)

### Learn a neural ODE (NODE) on the optimized trajectories:

In [None]:
# Convert the recorded per-step lists into arrays of shape (batch, n_time_steps, dim).
# Guarded so re-running the cell is harmless: calling jnp.stack on an already
# stacked array iterates over its first axis and silently swaps batch and time.
if isinstance(observations, list):
    observations = jnp.stack(observations, axis=1)
print(observations.shape)

if isinstance(actions, list):
    actions = jnp.stack(actions, axis=1)
print(actions.shape)

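- batches (as sliced in the training loop below):
    - $u_b$ (actions): (batch_size, sequence_length, action_dim) = (128, 50, 1)
    - $x_b$ (observations): (batch_size, sequence_length, obs_dim) = (128, 50, 2)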

In [None]:
from models import NeuralEulerODE

In [None]:
@eqx.filter_jit
def evaluate_model(model, obs, actions, n_steps, tau):
    """Roll out the learned model for ``n_steps`` from a single initial observation.

    Args:
        model: callable ``model(obs, action, tau) -> next_obs``.
        obs: initial observation, a 1-D array whose first entry is the angle.
        actions: action sequence; row ``n - 1`` drives the transition into row ``n``.
        n_steps: number of rows of the returned trajectory (including ``obs``).
        tau: step size forwarded to the model.

    Returns:
        Array of shape (n_steps, obs.shape[0]) whose row 0 is ``obs``.
    """
    trajectory = jnp.zeros([n_steps, obs.shape[0]]).at[0, :].set(obs)

    def step(n, carry):
        current_obs, traj = carry
        next_obs = model(current_obs, actions[n - 1, :], tau)
        # Wrap the angle back into [-1, 1) (+/-1 corresponds to +/-pi, cf. featurize_theta);
        # the angular velocity is passed through unchanged.
        wrapped_theta = ((next_obs[..., 0] + 1) % 2) - 1
        next_obs = jnp.stack([wrapped_theta, next_obs[..., 1]], axis=-1)
        return next_obs, traj.at[n, :].set(next_obs)

    _, trajectory = jax.lax.fori_loop(1, n_steps, step, (obs, trajectory))
    return trajectory

def featurize_theta(obs):
    """Replace the angle by its sine and cosine so features are continuous across the wrap.

    The angle theta (``obs[..., 0]``, scaled so that +/-1 corresponds to +/-pi) jumps
    from +1 to -1 at the wrap point although both values describe the same physical
    position. Comparing (sin, cos) instead avoids penalizing that jump in a loss.

    Args:
        obs: array whose last axis holds (theta, omega, ...); only the first two
            entries are used.

    Returns:
        Array with last axis (sin(pi * theta), cos(pi * theta), omega).
    """
    theta = obs[..., 0] * jnp.pi
    return jnp.stack([jnp.sin(theta), jnp.cos(theta), obs[..., 1]], axis=-1)

In [None]:
@eqx.filter_value_and_grad
def grad_loss(model, true_obs, actions, n_steps, tau):
    """Mean squared error between featurized predicted and true trajectories.

    The decorator makes a call return ``(loss, grads)``, with gradients taken
    w.r.t. the first argument, ``model``.

    Args:
        model: model rolled out by ``evaluate_model``.
        true_obs: batch of true trajectories; ``true_obs[:, 0, :]`` (the first step of
            each sequence) is the initial observation of each rollout.
        actions: batch of action sequences, one per trajectory.
        n_steps: rollout length.
        tau: step size forwarded to the model.
    """
    batched_rollout = jax.vmap(evaluate_model, in_axes=(None, 0, 0, None, None))
    pred_obs = batched_rollout(model, true_obs[:, 0, :], actions, n_steps, tau)

    # Compare in (sin, cos, omega) feature space so the theta wrap is not penalized.
    residual = jax.vmap(featurize_theta)(pred_obs) - jax.vmap(featurize_theta)(true_obs)
    return jnp.mean(residual ** 2)

@eqx.filter_jit
def make_step(model, observations, actions, n_steps, tau, opt_state):
    """Run one optimizer step on a batch; returns ``(loss, updated_model, new_opt_state)``.

    NOTE(review): the optimizer is the notebook-global ``optim`` (defined in the
    training cell), not a parameter.
    """
    loss, grads = grad_loss(model, observations, actions, n_steps, tau)
    updates, new_opt_state = optim.update(grads, opt_state)
    return loss, eqx.apply_updates(model, updates), new_opt_state

In [None]:
def dataloader(arrays, batch_size, *, key):
    """Yield random mini-batches along the first axis, forever.

    Each epoch draws a fresh permutation of the rows and yields consecutive,
    non-overlapping batches of exactly ``batch_size`` rows; a trailing remainder
    smaller than ``batch_size`` is dropped.

    Args:
        arrays: array whose leading axis indexes the samples.
        batch_size: number of samples per batch.
        key: JAX PRNG key used for the per-epoch permutations.

    Yields:
        ``arrays[idx]`` for index batches of length ``batch_size``.

    Raises:
        ValueError: (on the first ``next``) if ``batch_size`` is not in
            ``[1, arrays.shape[0]]``. Otherwise the generator would spin forever
            without yielding, or yield empty batches endlessly.
    """
    dataset_size = arrays.shape[0]
    if not 1 <= batch_size <= dataset_size:
        raise ValueError(
            f"batch_size must be between 1 and the dataset size ({dataset_size}), got {batch_size}"
        )
    indices = jnp.arange(dataset_size)

    while True:
        perm = jax.random.permutation(key, indices)
        (key,) = jax.random.split(key, 1)  # new key for the next epoch's permutation
        # Inclusive bound, so the last full batch is kept when dataset_size is a
        # multiple of batch_size.
        for start in range(0, dataset_size - batch_size + 1, batch_size):
            yield arrays[perm[start:start + batch_size]]

In [None]:
obs_dim = env.env_observation_space.shape[-1]
action_dim = env.action_space.shape[-1]

model = NeuralEulerODE(obs_dim=obs_dim, action_dim=action_dim, width_size=128, depth=3, key=model_key)

# NOTE(review): this rebinds the notebook-global `batch_size` (previously the number of
# parallel environments); later cells that read `batch_size` now see 128.
batch_size = 128
sequence_length = n_prediction_steps

lr = 1e-4
n_iters = 5_000
print_every = 100
plot_every = 1_000
n_plot_steps = 1_000  # rollout length for the diagnostic plots

# Cut every recorded trajectory into non-overlapping windows of `sequence_length` steps
# -> data of shape (n_sequences, sequence_length, obs_dim + action_dim).
assert observations.shape[1] % sequence_length == 0, (
    "trajectory length must be a multiple of sequence_length, "
    "otherwise windows would straddle two trajectories"
)
data = jnp.concatenate([observations, actions], axis=-1)
data = data.reshape((-1, sequence_length, obs_dim + action_dim))

optim = optax.adabelief(lr)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

loader_key, key = jax.random.split(key)

for step, data_batch in zip(tqdm(range(n_iters)), dataloader(data, batch_size, key=loader_key)):

    observations_batch = data_batch[..., :obs_dim]
    actions_batch = data_batch[..., obs_dim:]

    start = time.time()
    loss, model, opt_state = make_step(model, observations_batch, actions_batch, sequence_length, tau, opt_state)
    # JAX dispatches asynchronously; wait for the result so the measured time covers
    # the computation and not only the dispatch.
    loss.block_until_ready()
    end = time.time()

    if (step % print_every) == 0 or step == n_iters - 1:
        print(f"iter: {step}, Loss: {loss}, Computation time: {end - start}")

    if (step % plot_every) == 0:
        # Free-running rollout of the current model vs. the recorded trajectory
        # (batch element 0).
        model_observations = jax.vmap(evaluate_model, in_axes=(None, 0, 0, None, None))(
            model, observations[:, 0, :], actions, n_plot_steps, tau
        )
        fig, axs = plot_sequence(
            observations=model_observations[0, ...],
            actions=actions[0, :n_plot_steps],
            tau=tau,
            obs_labels=[r"$\theta$", r"$\omega$"],
            action_labels=[r"$u$"],
        );
        plt.show()

        fig, axs = plot_sequence(
            observations=observations[0, :n_plot_steps],
            actions=actions[0, :n_plot_steps],
            tau=tau,
            obs_labels=[r"$\theta$", r"$\omega$"],
            action_labels=[r"$u$"],
        );
        plt.show()

In [None]:
# Free-running 1000-step rollout of the trained model from the first observation of every recorded trajectory.
model_observations = jax.vmap(evaluate_model, in_axes=(None, 0, 0, None, None))(model, observations[:, 0, :], actions, 1000, tau)

In [None]:
# obs = jnp.stack(
#     [(((obs[..., 0] + 1) % 2) - 1), obs[..., 1]],
#     axis=-1
# )

# model_observations = jnp.stack(
#     [(((model_observations[..., 0] + 1) % 2) - 1), model_observations[..., 1]],
#     axis=-1
# )

In [None]:
# Rollout of the trained model for batch element 0.
fig, axs = plot_sequence(
    observations=model_observations[0, ...],
    actions=actions[0, :1000],
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
);

In [None]:
# Recorded trajectory of batch element 0 over the same 1000 steps, for comparison.
fig, axs = plot_sequence(
    observations=observations[0, :1000],
    actions=actions[0, :1000],
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
);

---

In [None]:
# Sanity check of the dataloader: print the shapes of the first few batches.
loader_key, key = jax.random.split(key)
n_iters = 10


for _, data_batch in zip(range(n_iters), dataloader(data, batch_size=batch_size, key=loader_key)):
    print(data_batch.shape)

### Use the learned model to optimize trajectories

In [None]:
# Display the trained model's repr.
model

In [None]:
# Re-create `env` as a single (batch_size=1) plant for the closed-loop run with the
# learned model; this replaces the batched env used above.
env = excenvs.make(
    env_id='Pendulum-v0',
    batch_size=1,
    tau=tau
)

In [None]:
def plot_sequence_and_prediction(
        observations,
        actions,
        tau,
        obs_labels,
        actions_labels,
        model,
        n_prediction_steps,
        init_obs,
        init_state,
        proposed_actions
    ):
    """Plot the recorded trajectory and append the model's prediction for the current plan.

    Args:
        observations: recorded observations so far, one row per time step.
        actions: recorded actions so far, one row per time step.
        tau: step size, used for the time axis and forwarded to ``simulate_ahead``.
        obs_labels: axis labels for the observation channels.
        actions_labels: axis labels for the action channels.
        model: prediction model forwarded to ``simulate_ahead``.
        n_prediction_steps: length of the predicted continuation.
        init_obs: observation the prediction starts from.
        init_state: state the prediction starts from.
        proposed_actions: planned action sequence driving the prediction.

    Returns:
        ``(fig, axs)`` of the combined plot.
    """
    fig, axs = plot_sequence(
        observations=observations,
        actions=actions,
        tau=tau,
        obs_labels=obs_labels,
        action_labels=actions_labels,
    )

    # This helper is called after the notebook redefines `simulate_ahead` (next cell)
    # with a `tau` parameter. The previously passed env_state_normalizer /
    # action_normalizer / static_params keywords are not in that signature and raised
    # a TypeError.
    pred_observations = simulate_ahead(
        model=model,
        n_steps=n_prediction_steps,
        obs=init_obs,
        state=init_state,
        actions=proposed_actions,
        tau=tau,
    )

    # Draw the prediction right after the last recorded step.
    fig, axs = append_predictions_to_sequence_plot(
        fig=fig,
        axs=axs,
        starting_step=observations.shape[0],
        pred_observations=pred_observations,
        proposed_actions=proposed_actions,
        tau=tau,
        obs_labels=obs_labels,
        action_labels=actions_labels,
    )

    return fig, axs

In [None]:
@eqx.filter_jit
def simulate_ahead(
    model,  # callable model(obs, action, tau) -> next_obs, e.g. the trained NeuralEulerODE
    n_steps: int,
    obs: jnp.ndarray,
    state: jnp.ndarray,
    actions: jnp.ndarray,
    tau
) -> jnp.ndarray:
    """Use a learned one-step model to look ahead and simulate future observations.

    Args:
        model: callable ``model(obs, action, tau) -> next_obs``. It is called, not
            stepped like an exciting_environments environment, so it is not a
            ``CoreEnvironment`` (the former type hint).
        n_steps: number of rows in the output, including the initial observation.
        obs: observation from which to start the simulation (1-D).
        state: carried through the loop unchanged; kept for interface compatibility.
        actions: action sequence; row ``n - 1`` drives the transition into row ``n``.
        tau: step size forwarded to the model.

    Returns:
        observations: array of shape (n_steps, obs.shape[0]) whose row 0 is ``obs``.
    """
    observations = jnp.zeros([n_steps, obs.shape[0]]).at[0, :].set(obs)

    def body_fun(n, carry):
        obs, state, observations = carry
        obs = model(obs, actions[n - 1, :], tau)
        # NOTE(review): unlike `evaluate_model`, theta is not wrapped back into [-1, 1)
        # here, although the model was trained on wrapped rollouts — confirm intent.
        observations = observations.at[n, :].set(obs)
        return (obs, state, observations)

    _, _, observations = jax.lax.fori_loop(
        lower=1, upper=n_steps, body_fun=body_fun, init_val=(obs, state, observations)
    )
    return observations

@eqx.filter_jit
def loss_function(
        actions,
        model,
        init_obs,
        init_state,
        n_steps,
        p_est,
        x,
        start_n_measurments,
        bandwidth,
        target_distribution,
        tau=1e-2,
        rho_obs=1e4,
        rho_act=1e4,
    ):
    """Excitation loss of a proposed action sequence.

    Simulates ``model`` ahead under ``actions``, folds the predicted observations into
    the current KDE estimate ``p_est`` and compares the result to
    ``target_distribution`` with ``JSDLoss`` (presumably the Jensen-Shannon
    divergence). Weighted ``soft_penalty`` terms (a_max=1) are added for observations
    and actions.

    Args:
        actions: proposed action sequence (the argument differentiated by jax.grad).
        model: prediction model forwarded to ``simulate_ahead``.
        init_obs: observation the simulation starts from.
        init_state: state the simulation starts from.
        n_steps: look-ahead horizon.
        p_est: current KDE estimate on the grid ``x``.
        x: KDE evaluation grid.
        start_n_measurments: number of observations already folded into ``p_est``.
        bandwidth: KDE bandwidth.
        target_distribution: target density on the grid ``x``.
        tau: step size forwarded to ``simulate_ahead`` (previously hard-coded to 1e-2).
        rho_obs: weight of the observation penalty.
        rho_act: weight of the action penalty.

    Returns:
        Scalar loss.
    """
    # Python side effect: runs only while JAX traces the function, so every print
    # marks a (re)compilation.
    print("Recompiling you idiot")

    observations = simulate_ahead(
        model=model,
        n_steps=n_steps,
        obs=init_obs,
        state=init_state,
        actions=actions,
        tau=tau
    )

    p_est = update_kde_grid_multiple_observations(p_est, x, observations, start_n_measurments, bandwidth)

    loss = JSDLoss(
        p=p_est,
        q=target_distribution
    )

    penalty_terms = rho_obs * soft_penalty(a=observations, a_max=1) + rho_act * soft_penalty(a=actions, a_max=1)

    return loss + penalty_terms


def optimize(
        grad_loss_function,
        proposed_actions,
        model,
        init_obs,
        init_state,
        n_steps,
        p_est,
        x,
        start_n_measurments,
        bandwidth,
        target_distribution,
        n_iterations=5,
        learning_rate=1e-1,
    ):
    """Refine a proposed action sequence with a few AdaBelief gradient steps.

    A fresh optimizer state is created on every call, so no optimizer momentum carries
    over between time steps; warm starting happens only through ``proposed_actions``.

    Args:
        grad_loss_function: gradient of the loss w.r.t. its first argument, called as
            ``grad_loss_function(actions, model, init_obs, init_state, n_steps, p_est,
            x, start_n_measurments, bandwidth, target_distribution)``.
        proposed_actions: initial action sequence (warm start).
        model: forwarded unchanged to ``grad_loss_function``.
        init_obs: forwarded unchanged to ``grad_loss_function``.
        init_state: forwarded unchanged to ``grad_loss_function``.
        n_steps: forwarded unchanged to ``grad_loss_function``.
        p_est: forwarded unchanged to ``grad_loss_function``.
        x: forwarded unchanged to ``grad_loss_function``.
        start_n_measurments: forwarded unchanged to ``grad_loss_function``.
        bandwidth: forwarded unchanged to ``grad_loss_function``.
        target_distribution: forwarded unchanged to ``grad_loss_function``.
        n_iterations: number of gradient steps (previously hard-coded to 5).
        learning_rate: AdaBelief learning rate (previously hard-coded to 1e-1).

    Returns:
        ``(proposed_actions, 0)``. NOTE(review): the second element is a placeholder;
        the loss value is never computed here, so callers that print it always see 0.
    """
    solver = optax.adabelief(learning_rate=learning_rate)
    opt_state = solver.init(proposed_actions)

    # Everything except the actions stays fixed during the optimization.
    loss_args = (
        model,
        init_obs,
        init_state,
        n_steps,
        p_est,
        x,
        start_n_measurments,
        bandwidth,
        target_distribution,
    )
    for _ in range(n_iterations):
        grad = grad_loss_function(proposed_actions, *loss_args)
        updates, opt_state = solver.update(grad, opt_state, proposed_actions)
        proposed_actions = optax.apply_updates(proposed_actions, updates)

    return proposed_actions, 0

In [None]:
# KDE evaluation grid for the single-environment run: `n_grid_points_per_dim` equally
# spaced values per observation dimension, combined into all (theta, omega) pairs.
n_grid_points_per_dim = 40
x1 = jnp.linspace(env.env_observation_space.low, env.env_observation_space.high, n_grid_points_per_dim)
x2 = jnp.linspace(env.env_observation_space.low, env.env_observation_space.high, n_grid_points_per_dim)

x = jnp.stack(jnp.meshgrid(x1, x2), axis=-1).reshape(-1, 2)
n_grid_points = x.shape[0]

# Number of observations already folded into the KDE estimate.
start_n_measurments = 0

bandwidth = 0.05  # KDE kernel bandwidth
p_est = jnp.zeros([n_grid_points, 1])  # single (unbatched) density estimate

# Uniform target density: 1 / area of the observation space
# (assumes scalar low/high bounds shared by both dimensions).
target_distribution = jnp.ones(shape=(n_grid_points, 1))
target_distribution *= 1 / (env.env_observation_space.high - env.env_observation_space.low)**2

n_prediction_steps = 10  # look-ahead horizon of the action optimizer
n_time_steps = 5_000  # length of the closed-loop experiment

In [None]:
# Re-bind the gradient to the notebook-local `loss_function` defined above (the earlier
# `grad_loss_function` wrapped the imported one); differentiates w.r.t. the actions.
grad_loss_function = jax.grad(loss_function, argnums=(0))

In [None]:
# Closed-loop excitation of the single plant, using the learned model for the look-ahead.
observations = []
actions = []

obs, state = env.reset()
obs = obs.astype(jnp.float32)[0, ...]
state = state.astype(jnp.float32)[0, ...]

# Generate exactly one sequence to match the single plant (batch_size=1), instead of
# relying on the global `batch_size`, which is 20 or 128 depending on which cells ran.
proposed_actions = aprbs(n_prediction_steps, 1, 20, 50, next(data_rng))[0, ...]

for k in tqdm(range(n_time_steps)):

    proposed_actions, final_loss = optimize(
        grad_loss_function=grad_loss_function,
        proposed_actions=proposed_actions,
        model=model,
        init_obs=obs,
        init_state=state,
        n_steps=n_prediction_steps,
        p_est=p_est,
        x=x,
        start_n_measurments=start_n_measurments,
        bandwidth=bandwidth,
        target_distribution=target_distribution
    )

    if k % 100 == 0 and k > 0:
        print(f"Loss iteration: {k}: {final_loss}")
        fig, axs = plot_sequence_and_prediction(
            observations=jnp.stack(observations, axis=0),
            actions=jnp.stack(actions, axis=0),
            tau=tau,
            obs_labels=[r"$\theta$", r"$\omega$"],
            actions_labels=[r"$u$"],
            model=model,
            n_prediction_steps=n_prediction_steps,
            init_obs=obs,
            init_state=state,
            proposed_actions=proposed_actions
        )
        plt.show()

    # Fold the current observation into the KDE estimate.
    p_est = update_kde_grid(
        p_est,
        x,
        obs,
        k,
        bandwidth
    )
    start_n_measurments += 1

    action = proposed_actions[0, :]  # apply only the first planned action

    actions.append(action)
    observations.append(obs)

    # The env is batched with batch_size=1: add the batch axis for the step and strip it
    # again by indexing. Unlike jnp.squeeze, indexing cannot also drop a genuine size-1 axis.
    obs, _, _, _, state = env.step(action[None, :], state[None, :])
    obs = obs[0]
    state = state[0]

    # Receding horizon: shift the plan left by one step (the last action is repeated).
    proposed_actions = proposed_actions.at[:-1, :].set(proposed_actions[1:, :])

In [None]:
# Full closed-loop trajectory of the single environment.
fig, axs = plot_sequence(
    observations=jnp.stack(observations),
    actions=jnp.stack(actions),
    tau=tau,
    obs_labels=[r"$\theta$", r"$\omega$"],
    action_labels=[r"$u$"],
);

In [None]:
# Shape of the recorded observations, expected (n_time_steps, obs_dim).
jnp.stack(observations, axis=0).shape

### Debug JIT recompilation problems:

In [None]:
@eqx.filter_jit
def loss_function(
        actions,
        model,
        init_obs,
        init_state,
        n_steps,
        p_est,
        x,
        start_n_measurments,
        bandwidth,
        target_distribution,
        tau=1e-2,
        rho_obs=1e4,
        rho_act=1e4,
    ):
    """Excitation loss of a proposed action sequence (redefinition without the trace print).

    Simulates ``model`` ahead under ``actions``, folds the predicted observations into
    the KDE estimate ``p_est`` and compares it to ``target_distribution`` with
    ``JSDLoss`` (presumably the Jensen-Shannon divergence). Weighted ``soft_penalty``
    terms (a_max=1) are added for observations and actions. This definition shadows
    the earlier one of the same name.

    Args:
        actions: proposed action sequence (the argument differentiated by jax.grad).
        model: prediction model forwarded to ``simulate_ahead``.
        init_obs: observation the simulation starts from.
        init_state: state the simulation starts from.
        n_steps: look-ahead horizon.
        p_est: current KDE estimate on the grid ``x``.
        x: KDE evaluation grid.
        start_n_measurments: number of observations already folded into ``p_est``.
        bandwidth: KDE bandwidth.
        target_distribution: target density on the grid ``x``.
        tau: step size forwarded to ``simulate_ahead`` (previously hard-coded to 1e-2).
        rho_obs: weight of the observation penalty.
        rho_act: weight of the action penalty.

    Returns:
        Scalar loss.
    """
    observations = simulate_ahead(
        model=model,
        n_steps=n_steps,
        obs=init_obs,
        state=init_state,
        actions=actions,
        tau=tau
    )

    p_est = update_kde_grid_multiple_observations(p_est, x, observations, start_n_measurments, bandwidth)

    loss = JSDLoss(
        p=p_est,
        q=target_distribution
    )

    penalty_terms = rho_obs * soft_penalty(a=observations, a_max=1) + rho_act * soft_penalty(a=actions, a_max=1)

    return loss + penalty_terms

In [None]:
# Timing check: once compiled, every call should be fast; consistently slow calls would
# indicate that the loss is being recompiled.
obs, state = env.reset()
obs = obs.astype(jnp.float32)[0, ...]
state = state.astype(jnp.float32)[0, ...]

# Build the gradient function once instead of re-wrapping it on every iteration.
timed_grad = jax.grad(loss_function, argnums=(0))

for i in range(1000):
    start = time.time()
    grad = timed_grad(
        proposed_actions,
        model,
        obs,
        state,
        n_prediction_steps,  # horizon that matches the length of proposed_actions
        p_est,
        x,
        0,
        bandwidth,
        target_distribution
    )
    # JAX dispatches asynchronously; block so the timing covers the actual computation.
    grad.block_until_ready()
    end = time.time()
    if (i % 100) == 0:
        print(end - start)

In [None]:
@eqx.filter_jit
@eqx.debug.assert_max_traces(max_traces=1)
def dummy_function(obs, model, tau):
    """Apply one model step; fails loudly if JAX traces the function more than once.

    NOTE(review): ``action`` is not a parameter. It is looked up in the notebook's global
    namespace (the last action of the closed-loop cell), so its value is baked into the
    compiled function as a constant at trace time.
    """
    # Python print only executes while tracing, so it reveals each (re)trace.
    print("doing it again baby")
    next_obs = model(obs, action, tau)
    return next_obs

In [None]:
# Call dummy_function repeatedly; equinox's assert_max_traces complains if it retraces.
obs, state = env.reset()
obs = obs.astype(jnp.float32)[0, ...]

for i in range(1000):
    obs = dummy_function(obs, model, tau)