In [None]:
import os
# Expose only the first CUDA device to JAX and stop XLA from preallocating most of
# the GPU memory up front. These are set here, before `jax` is imported and its
# devices are queried in the next cell.
os.environ.update(
    {
        "CUDA_VISIBLE_DEVICES": "0",
        "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
    }
)

In [None]:
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
# Devices of JAX's default backend. Despite the variable name, this list holds CPU
# devices when no GPU is visible.
gpus = jax.devices()

# Make the first of those devices the default placement for new arrays.
jax.config.update(name="jax_default_device", val=gpus[0])

import jax_dataclasses as jdc
from haiku import PRNGSequence

In [None]:
def plot_sequence(observations, actions, tau, obs_labels, action_labels, fig=None, axs=None, dotted=False):
    """Plot a sequence of observations and (optionally) actions on three panels.

    Panel 0 shows every observation channel over time. Panel 1 shows the
    observation plane and is only filled when there are exactly two observation
    channels. Panel 2 shows every action channel over time.

    Args:
        observations: Array indexed as ``[time, ..., channel]``; sample ``i`` is
            drawn at ``i * tau``.
        actions: Array indexed as ``[time, ..., channel]``, or ``None`` to leave
            the action panel empty. Sample ``i`` is drawn at ``i * tau``. It may
            have one sample fewer than ``observations`` (initial state plus one
            action per step) or the same number of samples.
        tau: Sampling interval in seconds.
        obs_labels: Legend label per observation channel; entries 0 and 1 also
            label the axes of the observation-plane panel.
        action_labels: Legend label per action channel; unused if ``actions``
            is ``None``.
        fig, axs: Figure and sequence of three axes to draw into. A new 1x3
            figure is created if either is ``None``.
        dotted: Draw the observation timeseries with dot markers instead of
            lines.

    Returns:
        The ``(fig, axs)`` pair that was drawn into.
    """
    if fig is None or axs is None:
        fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(18, 6))

    # Time stamps of the observation samples.
    t = jnp.arange(observations.shape[0]) * tau

    for observation_idx in range(observations.shape[-1]):
        axs[0].plot(
            t,
            jnp.squeeze(observations[..., observation_idx]),
            "." if dotted else "-",
            markersize=1,
            label=obs_labels[observation_idx],
        )

    axs[0].title.set_text("observations, timeseries")
    axs[0].legend()
    axs[0].set_ylabel(r"$x$")
    axs[0].set_xlabel("t in seconds")

    if observations.shape[-1] == 2:
        axs[1].scatter(jnp.squeeze(observations[..., 0]), jnp.squeeze(observations[..., 1]), s=1)
        axs[1].title.set_text("observation plane")
        axs[1].set_ylabel(obs_labels[1])
        axs[1].set_xlabel(obs_labels[0])

    if actions is not None:
        # Time-stamp the actions from their own length rather than slicing `t`.
        # This way both the "one fewer action than observations" layout and the
        # equal-length layout plot without a length mismatch.
        t_actions = jnp.arange(actions.shape[0]) * tau
        for action_idx in range(actions.shape[-1]):
            axs[2].plot(t_actions, jnp.squeeze(actions[..., action_idx]), label=action_labels[action_idx])
        axs[2].title.set_text("actions, timeseries")
        axs[2].legend()
        # Math text needs both "$" delimiters; an unpaired "$" is shown verbatim.
        axs[2].set_ylabel(r"$u$")
        axs[2].set_xlabel("t in seconds")

    for ax in axs:
        ax.grid(True)

    fig.tight_layout()
    return fig, axs

In [None]:
@jdc.pytree_dataclass
class StaticParams:
    """Parameters of the environment simulated by ``msdc_step``.

    Registered as a JAX pytree via ``jdc.pytree_dataclass``. Despite the class
    name, no field is wrapped in ``jdc.Static[...]``, so every field is a pytree
    leaf rather than static metadata.

    NOTE(review): the fields are annotated as ``jax.Array``, but this notebook
    constructs the class from plain Python scalars (e.g. ``c=10``).
    """
    # Coefficient of the velocity term (``c * x2``) subtracted in ``msdc_step``.
    c: jax.Array
    # Spring-term factor; ``msdc_step`` only uses it through the product ``s * l``.
    s: jax.Array
    # Spring-term factor; ``msdc_step`` only uses it through the product ``s * l``.
    l: jax.Array
    # Offset in the spring term's denominator ``sqrt(x1**2 + a**2)``.
    a: jax.Array
    # Divisor of the net force in ``msdc_step`` (the mass).
    m: jax.Array

In [None]:
def msdc_step(state, F, params, tau):
    """Advance the system by one explicit (forward) Euler step.

    With state ``(x1, x2)`` the dynamics are::

        dx1/dt = x2
        dx2/dt = (F - x1 * s * l / sqrt(x1**2 + a**2) - c * x2) / m

    Args:
        state: Length-2 array-like holding ``(x1, x2)``.
        F: Input force, either a scalar or a size-1 array (e.g. one ``(1,)``
            row of an action sequence).
        params: Object exposing ``c``, ``s``, ``l``, ``a`` and ``m``
            (e.g. ``StaticParams``).
        tau: Integration step size.

    Returns:
        ``jnp`` array of shape ``(2,)`` holding the next ``(x1, x2)``.
    """
    x1, x2 = state[0], state[1]
    # Use reshape(()) instead of .item(). It accepts plain Python scalars and
    # keeps working inside jax.jit / lax.scan, where .item() raises on tracers.
    # It still rejects inputs holding more than one value.
    force = jnp.asarray(F).reshape(())

    # JAX arrays are immutable, so the velocity needs no defensive copy.
    d_x1 = x2
    d_x2 = (
        force
        - x1 * (params.s * params.l) / jnp.sqrt(x1**2 + params.a**2)
        - params.c * x2
    ) / params.m

    return jnp.array([x1 + tau * d_x1, x2 + tau * d_x2])

In [None]:
# Parameter values for the system simulated by `msdc_step`.
static_params = StaticParams(
    c=10,    # coefficient of the velocity (damping) term
    s=800,   # spring-term factor (enters via s * l)
    l=0.17,  # spring-term factor (enters via s * l)
    a=0.25,  # offset in the spring term's sqrt(x1**2 + a**2)
    m=5,     # mass dividing the net force
)

# Integration / sampling step; the plots label time as i * tau "in seconds".
tau = 0.01

In [None]:
# Excited harmonic indices 21..204 as a column vector, so they broadcast against
# the row vector of sample indices `k` below.
l = jnp.arange(21, 205)[..., None]
N = 2048

# Sample indices 0..N-1 as a row vector.
k = jnp.linspace(0, N-1, N)[None, ...]

# Sampling frequency derived from the simulation step (tau is in seconds, so this
# is in Hz). Deriving it keeps the harmonic frequencies l * f_0 correct if `tau`
# changes, instead of silently duplicating 1 / tau as a hard-coded 100.
f_s = 1 / tau
# Frequency resolution of the N-sample grid; harmonic l sits at l * f_0.
f_0 = f_s / N

In [None]:
# Multisine excitation: one sinusoid per harmonic index in `l`, each with a
# uniformly random phase. The sinusoids are summed over harmonics and rescaled
# to a standard deviation of 8.
data_rng = PRNGSequence(jax.random.PRNGKey(seed=0))

phi_l = jax.random.uniform(key=next(data_rng), shape=l.shape, minval=0, maxval=2 * jnp.pi)

# Phase advance per sample for every harmonic, shape (len(l), 1).
phase_step = 2 * jnp.pi * l * f_0 / f_s
# (len(l), N) sinusoids -> summed over harmonics -> (N, 1) column.
actions = jnp.sin(phase_step * k + phi_l).sum(axis=0)[:, None]
actions = actions * 8 / actions.std()

In [None]:
# Run the simulation starting from rest at the origin, driven by `actions`.
# states[i] is the state at time i * tau, i.e. the state at which actions[i] is
# applied. Recording before stepping keeps the initial condition and aligns each
# state with its action; the state reached after the final action is left in
# `state` (and `next_state`).
state = jnp.array([0.0, 0.0])

states = []

for action in actions:
    states.append(state)
    next_state = msdc_step(state, action, static_params, tau)
    state = next_state

states = jnp.stack(states)

In [None]:
# One fewer action than observations is passed, i.e. the final action is not
# plotted. The trailing ";" suppresses the (fig, axs) repr below the figure.
plot_sequence(states, actions[:-1], tau, ["deflection in m", "velocity in m/s"], ["force in N"]);