In [None]:
# Reload imported modules (e.g. the project helpers imported below) before each
# cell executes, so edits to those .py files are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2
# line_profiler provides the %lprun magic for line-by-line profiling.
%reload_ext line_profiler

In [None]:
from functools import partial

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
# plt.rcParams['text.usetex'] = True
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import exciting_environments as excenvs

from density_estimation import update_kde_grid, update_kde_grid_multiple_observations
from metrics import JSDLoss
from model_utils import simulate_ahead
from plotting_utils import plot_sequence
from signals import generate_constant_action, aprbs
from optimization_utils import loss_function, optimize, soft_penalty

---

In [None]:
# setup PRNG: a single seed is split into `data_key`, `model_key`, and a fresh
# `key` that remains available for any further splits.
data_key, model_key, key = jax.random.split(jax.random.PRNGKey(seed=21), num=3)

# Stateful key stream for data generation, seeded from `data_key`.
data_rng = PRNGSequence(data_key)

- 1d: `soft_penalty` on a scalar sweep of 100 observations over $[-2, 2]$

In [None]:
# 100 scalar observations over [-2, 2], shaped (100, 1, 1, 1) so that vmap maps
# soft_penalty over the leading (sample) axis.
observations = jnp.linspace(-2, 2, 100)[..., None, None, None]
result = jax.vmap(soft_penalty)(observations)

fig, ax = plt.subplots()
ax.plot(jnp.squeeze(observations), result)
# Trailing ';' suppresses the repr of ax.set's return value in the cell output.
ax.set(title="soft_penalty, 1-d observations", xlabel="observation", ylabel="soft_penalty");

- multidim: the same sweep replicated across 5 observation dimensions

In [None]:
# Replicate the 1-d sweep over [-2, 2] into 5 identical observation dimensions.
# Result shape: (100, 1, 1, 5) -- leading sample axis, observation dims last.
observations = jnp.tile(jnp.linspace(-2, 2, 100)[:, None, None, None], (1, 1, 1, 5))

In [None]:
# Inspect the stacked multidim observations; the previous cell builds (100, 1, 1, 5).
jnp.shape(observations)

In [None]:
# in_axes=0 maps soft_penalty over the leading (sample) axis of `observations`.
result = jax.vmap(soft_penalty, in_axes=0)(observations)

fig, ax = plt.subplots()
# Every observation dimension holds the same sweep, so dim 0 serves as the x-axis.
ax.plot(jnp.squeeze(observations[..., 0]), result)
ax.set(title="soft_penalty, 5-d observations", xlabel="observation (dim 0)", ylabel="soft_penalty");

- unreduced: element-wise penalty $\operatorname{relu}(|a| - a_{\max})$, with no reduction over axes

In [None]:
def penalty_unreduced(a, a_max):
    """Element-wise hinge penalty on the magnitude of ``a``.

    Computes ``relu(|a| - a_max)``. The result is zero wherever ``|a| <= a_max``
    and grows linearly with the excess otherwise. No reduction is applied, so the
    output has the broadcast shape of ``a`` and ``a_max``.

    Args:
        a: Array of values to penalize.
        a_max: Allowed magnitude bound (scalar or array broadcastable with ``a``).

    Returns:
        Array of per-element penalties.
    """
    return jax.nn.relu(jnp.abs(a) - a_max)

In [None]:
# Sweep over [-2, 2] shaped (1, 100, 1). penalty_unreduced is element-wise,
# so it is applied directly without vmap.
observations = jnp.linspace(-2, 2, 100)[None, :, None]
result = penalty_unreduced(observations, a_max=1)

fig, ax = plt.subplots()
ax.plot(jnp.squeeze(observations), jnp.squeeze(result))
# Trailing ';' suppresses the repr of ax.set's return value in the cell output.
ax.set(title="penalty_unreduced, a_max=1", xlabel="observation", ylabel="penalty");