In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
from functools import partial

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import os
# Sets the C++ log-level environment variable for this process
# (presumably "0" means most verbose XLA/TF logging — TODO confirm).
# NOTE(review): this cell runs after `import jax` in the cell above; confirm
# the variable is still picked up when set this late.
os.environ["TF_CPP_MIN_LOG_LEVEL"]="0"

In [None]:
import exciting_environments as excenvs
from models import NeuralEulerODE

### NODE training on pendulum env

In [None]:
# setup PRNG
# A fixed seed keeps data generation and model initialisation reproducible.
key = jax.random.PRNGKey(seed=21)

# Split into a key for data generation, a key for model initialisation, and a
# remaining `key` that replaces the root key.
data_key, model_key, key = jax.random.split(key, 3)
# `data_rng` is consumed with next(data_rng) wherever a key for action sampling is needed.
data_rng = PRNGSequence(data_key)

### Simulate from the env:

In [None]:
# Size of the batch handed to the environment and to the action generator.
batch_size = 64
# Number of time steps recorded per episode.
n_steps = 1000
# Passed to the environment and later to the model as `tau`
# (presumably the integration step size in seconds — TODO confirm).
tau = 1e-3

env = excenvs.make(
    env_id='Pendulum-v0',
    batch_size=batch_size,
    tau=tau
)

In [None]:
@partial(jax.jit, static_argnums=(0, 1))
def gen_actions(n_steps, batch_size, key):
    """Sample one random action per batch element and hold it for the whole episode.

    Args:
        n_steps: Number of time steps per episode (static under jit).
        batch_size: Number of episodes in the batch (static under jit).
        key: JAX PRNG key used for the sampling.

    Returns:
        Array of shape (batch_size, n_steps, 1) with values drawn uniformly
        from [-1, 1); every episode is constant along the time axis.
    """
    held_value = jax.random.uniform(key, shape=(batch_size, 1, 1), minval=-1, maxval=1)
    # Copy the single sample of each episode along the time axis.
    return jnp.broadcast_to(held_value, (batch_size, n_steps, 1))
                                 

def aprbs2(len, t_min, t_max, key):
    """Generate one piecewise-constant random signal (APRBS-style excitation).

    The signal is built from consecutive segments. Each segment has a hold
    time drawn with ``jax.random.randint(minval=t_min, maxval=t_max)`` and a
    level drawn uniformly from [-1, 1). Segments are appended until at least
    ``len`` samples exist; the result is then truncated to ``len`` samples.

    This function runs a Python loop with data-dependent length, so it
    cannot be wrapped in ``jax.jit`` as written.

    Args:
        len: Number of samples to return. The name shadows the builtin but is
            kept because it is part of the signature.
        t_min: Lower bound (inclusive) for the hold time in samples.
        t_max: Upper bound passed to ``randint`` as ``maxval``.
        key: JAX PRNG key.

    Returns:
        1-D array with ``len`` samples; empty if ``len <= 0``.

    Raises:
        ValueError: If the hold-time range can never produce a positive hold
            time, which would otherwise make the loop run forever.
    """
    # Nothing to generate; concatenating an empty segment list would fail.
    if len <= 0:
        return jnp.zeros((0,))

    # Largest hold time the sampler can produce (assuming randint falls back
    # to `t_min` when the range is empty — TODO confirm).
    longest_hold = t_max - 1 if t_max > t_min else t_min
    if longest_hold < 1:
        raise ValueError(
            f"hold times drawn from [{t_min}, {t_max}) are never positive; "
            "the signal would never advance"
        )

    segments = []
    t = 0
    while t < len:
        steps_key, value_key, key = jax.random.split(key, 3)

        # Pull the hold time to the host once; it drives the Python loop and
        # the segment length.
        hold = int(jax.random.randint(steps_key, shape=(1,), minval=t_min, maxval=t_max)[0])
        level = jax.random.uniform(value_key, shape=(1,), minval=-1, maxval=1)

        segments.append(jnp.repeat(level, hold))
        t += hold

    # The last segment usually overshoots, so cut back to the requested length.
    return jnp.concatenate(segments)[:len]

def aprbs(n_steps, batch_size, t_min, t_max, key):
    """Build a batch of independent `aprbs2` signals.

    Args:
        n_steps: Number of samples per signal.
        batch_size: Number of signals to generate.
        t_min: Forwarded to `aprbs2` as the lower hold-time bound.
        t_max: Forwarded to `aprbs2` as the upper hold-time bound.
        key: JAX PRNG key; a sub-key is split off for every signal.

    Returns:
        Array of shape (batch_size, n_steps, 1).
    """
    signals = []
    remaining_key = key
    for _ in range(batch_size):
        signal_key, remaining_key = jax.random.split(remaining_key)
        signal = aprbs2(n_steps, t_min, t_max, signal_key)
        # Append a trailing action dimension of size one.
        signals.append(jnp.expand_dims(signal, axis=-1))
    return jnp.stack(signals, axis=0)

In [None]:
# Draw a batch of held-constant random actions.
# NOTE(review): `actions` is generated again in the simulation cell below
# before it is first used, so this draw only advances `data_rng`.
# actions = aprbs(n_steps, batch_size, 5, 100, next(data_rng))
actions = gen_actions(n_steps, batch_size, next(data_rng))

In [None]:
@partial(jax.jit, static_argnums=(0, 3))
def simulate_epsiode(env, obs, state, n_steps, actions):
    """Roll the environment forward and record the observation at every step.

    Slot 0 of the result holds the initial observation `obs`. For every
    n in 1..n_steps-1 the environment is stepped with `actions[:, n, :]` and
    the returned observation is stored in slot n; `actions[:, 0, :]` is
    therefore never applied.

    Args:
        env: Environment exposing `step(action, state)` that returns a
            5-tuple whose first element is the observation and whose last
            element is the new state. Static under jit.
        obs: Initial observations, shape (batch, obs_dim).
        state: Initial environment state matching `obs`.
        n_steps: Number of slots in the recorded trajectory (static under jit).
        actions: Actions indexed as `actions[:, n, :]`.

    Returns:
        Array of shape (batch, n_steps, obs_dim) with the recorded observations.
    """
    n_envs, n_features = obs.shape
    # Pre-allocate the trajectory buffer and store the initial observation.
    trajectory = jnp.zeros([n_envs, n_steps, n_features]).at[:, 0, :].set(obs)

    def advance(step_idx, loop_state):
        _, env_state, buffer = loop_state
        next_obs, _reward, _terminated, _truncated, next_env_state = env.step(
            actions[:, step_idx, :], env_state
        )
        return next_obs, next_env_state, buffer.at[:, step_idx, :].set(next_obs)

    _, _, trajectory = jax.lax.fori_loop(
        lower=1,
        upper=n_steps,
        body_fun=advance,
        init_val=(obs, state, trajectory),
    )
    return trajectory

In [None]:
# Reset the environment to obtain the initial observation and state.
obs, state = env.reset()
# Cast both to float32 (presumably so the dtypes stay consistent inside the
# jitted rollout loop — TODO confirm).
obs = obs.astype(jnp.float32)
state = state.astype(jnp.float32)

# actions = aprbs(n_steps, batch_size, 5, 100, next(data_rng))
actions = gen_actions(n_steps, batch_size, next(data_rng))

# Roll out the environment under these actions; slot 0 of the result is `obs`.
observations = simulate_epsiode(
    env,
    obs,
    state,
    n_steps=n_steps,
    actions=actions
)
print(observations.shape)
print(actions.shape)

In [None]:
def plot_episode(observations, actions, max_n=2, labels=None):
    """Plot the first episodes of a batch: time series, phase plot, and actions.

    Figures are shown in three groups (all time-series plots, then all
    phase plots, then all action plots), one figure per episode in each group.

    Args:
        observations: Array indexed as (episode, time, feature).
        actions: Array indexed as (episode, time, action); only action
            column 0 is plotted.
        max_n: Maximum number of episodes to plot.
        labels: Optional legend labels, one per observation column. When
            omitted, two-column observations are labelled "theta"/"omega"
            and any other width gets generic "feature i" labels, so that
            e.g. featurised observations are not mislabelled.

    Raises:
        ValueError: If `labels` does not have one entry per observation column.
    """
    n_plots = min(max_n, observations.shape[0])
    n_features = observations.shape[-1]

    if labels is None:
        if n_features == 2:
            labels = ("theta", "omega")
        else:
            labels = tuple(f"feature {i}" for i in range(n_features))
    elif len(labels) != n_features:
        raise ValueError(
            f"expected {n_features} labels, got {len(labels)}"
        )

    # Every observation column over time.
    for idx in range(n_plots):
        for feature_idx, label in enumerate(labels):
            plt.plot(observations[idx, :, feature_idx], label=label)
        plt.grid()
        plt.title("observations, timeseries")
        plt.legend()
        plt.show()

    # First two observation columns against each other (needs two columns).
    if n_features >= 2:
        for idx in range(n_plots):
            plt.plot(observations[idx, :, 0], observations[idx, :, 1], 'b.')
            plt.grid()
            plt.title("observations, together")
            plt.show()

    # Applied actions over time.
    for idx in range(n_plots):
        plt.plot(actions[idx, :, 0])
        plt.grid()
        plt.title("actions, timeseries")
        plt.show()

In [None]:
# Visualise the first simulated episodes of the batch.
plot_episode(observations, actions)

### Build NODE model

In [None]:
# Take the model's input sizes from the last axis of the environment's spaces.
obs_dim = env.env_observation_space.shape[-1]
action_dim = env.action_space.shape[-1]

# Initial model, initialised from `model_key`.
model = NeuralEulerODE(obs_dim=obs_dim, action_dim=action_dim, width_size=64, depth=1, key=model_key)

In [None]:
# NOTE(review): filter_jit is used here instead of jax.jit, presumably because
# `model` is an Equinox module with non-array leaves — TODO confirm.
@eqx.filter_jit
def evaluate_model(model, obs, actions, n_steps, tau):
    """Roll the model forward from `obs` and record every predicted observation.

    Slot 0 of the result holds the initial observation. For every n in
    1..n_steps-1 the model is applied to the previous prediction together
    with `actions[:, n, :]`, and the output is stored in slot n.

    Args:
        model: Callable invoked per batch element as `model(obs_i, action_i, tau)`.
        obs: Initial observations, shape (batch, obs_dim).
        actions: Actions indexed as `actions[:, n, :]`.
        n_steps: Number of slots in the recorded trajectory.
        tau: Third argument handed to the model unchanged for every batch element.

    Returns:
        Array of shape (batch, n_steps, obs_dim) with the predicted observations.
    """
    n_envs, n_features = obs.shape
    # Pre-allocate the rollout buffer and store the initial observation.
    rollout = jnp.zeros([n_envs, n_steps, n_features]).at[:, 0, :].set(obs)

    # Map the model over the batch axis of obs and action; tau is shared.
    batched_model = jax.vmap(model, in_axes=(0, 0, None))

    def advance(step_idx, loop_state):
        current_obs, buffer = loop_state
        predicted = batched_model(current_obs, actions[:, step_idx, :], tau)
        return predicted, buffer.at[:, step_idx, :].set(predicted)

    _, rollout = jax.lax.fori_loop(
        lower=1,
        upper=n_steps,
        body_fun=advance,
        init_val=(obs, rollout),
    )
    return rollout

def featurize_theta(obs):
    """Replace the angle column by its sine and cosine.

    Column 0 of `obs` is multiplied by pi before sin/cos are taken
    (presumably because the angle is normalised to [-1, 1] — TODO confirm
    against the environment). The sin/cos pair does not jump where the angle
    wraps around, unlike the raw angle.

    Args:
        obs: Array of shape (..., d) with d >= 2; column 0 is treated as the
            angle and column 1 is passed through. Any number of leading
            dimensions is accepted.

    Returns:
        Array of shape (..., 3) holding
        [sin(pi * obs[..., 0]), cos(pi * obs[..., 0]), obs[..., 1]].
    """
    theta, omega = obs[..., 0], obs[..., 1]
    angle = theta * jnp.pi
    return jnp.stack([jnp.sin(angle), jnp.cos(angle), omega], axis=-1)

In [None]:
# Roll out the freshly initialised model from the simulated initial
# observations under the same actions, then plot the prediction.
observations_model = evaluate_model(model, observations[:, 0, :], actions, n_steps, tau=tau)
plot_episode(observations_model, actions)

In [None]:
# Same model rollout in feature space; the featurised array has three columns:
# sin(pi * theta), cos(pi * theta), omega.
plot_episode(
    observations=jax.vmap(featurize_theta, in_axes=(0))(observations_model),
    actions=actions
)

In [None]:
# Simulated reference trajectories for comparison with the model rollout.
plot_episode(observations, actions)

In [None]:
# Simulated reference in feature space; the featurised array has three columns:
# sin(pi * theta), cos(pi * theta), omega.
plot_episode(
    observations=jax.vmap(featurize_theta, in_axes=(0))(observations),
    actions=actions
)

### Train NODE:

In [None]:
@eqx.filter_value_and_grad
def grad_loss(model, true_obs, actions, n_steps, tau):
    """Mean squared error between the model rollout and the true trajectory.

    The model is rolled out from the first true observation under `actions`;
    prediction and target are both mapped through `featurize_theta` before
    the squared error is averaged over all entries.

    Because of the decorator, calling this returns `(loss, grads)`
    (presumably with gradients taken w.r.t. `model` — confirm against the
    Equinox documentation).

    Args:
        model: Model to roll out.
        true_obs: Reference observations, shape (batch, n_steps, obs_dim).
        actions: Actions applied during the rollout.
        n_steps: Number of rollout slots.
        tau: Forwarded to the model.
    """
    featurize = jax.vmap(featurize_theta, in_axes=(0))

    rollout = evaluate_model(model, true_obs[:, 0, :], actions, n_steps, tau)
    residual = featurize(rollout) - featurize(true_obs)

    return jnp.mean(residual ** 2)

@eqx.filter_jit
def make_step(model, observations, actions, n_steps, tau, opt_state):
    """Run one optimisation step on a batch of simulated trajectories.

    Reads the optimiser from the notebook-level variable `optim`, which must
    be defined before the first call.

    Args:
        model: Current model.
        observations: Reference observations used as the training target.
        actions: Actions applied during the rollout.
        n_steps: Number of rollout slots.
        tau: Forwarded to the model.
        opt_state: Current optimiser state.

    Returns:
        Tuple `(loss, model, opt_state)` with the loss of this batch, the
        updated model, and the updated optimiser state.
    """
    loss, grads = grad_loss(model, observations, actions, n_steps, tau)
    updates, new_opt_state = optim.update(grads, opt_state)
    updated_model = eqx.apply_updates(model, updates)
    return loss, updated_model, new_opt_state

In [None]:
# Training configuration; these rebind the notebook-level values set for the
# first simulation.
batch_size = 512
n_steps = 1_000
tau = 1e-3

obs_dim = env.env_observation_space.shape[-1]
action_dim = env.action_space.shape[-1]

# Fresh, deeper model for training; it reuses `model_key`, so the
# initialisation is reproducible.
model = NeuralEulerODE(obs_dim=obs_dim, action_dim=action_dim, width_size=64, depth=3, key=model_key)
lr = 1e-3


# Re-create the environment with the larger training batch size.
env = excenvs.make(
    env_id='Pendulum-v0',
    batch_size=batch_size,
    tau=tau
)

# `make_step` reads `optim` from the notebook scope.
optim = optax.adabelief(lr)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

n_episodes = 5000
print_every = 10

for episode in tqdm(range(n_episodes)):

    # Simulate a new batch of trajectories for every training step.
    obs, state = env.reset()
    obs = obs.astype(jnp.float32)
    state = state.astype(jnp.float32)

    # actions = aprbs(n_steps, batch_size, 5, 100, next(data_rng))
    actions = gen_actions(n_steps, batch_size, next(data_rng))
    observations = simulate_epsiode(
        env,
        obs,
        state,
        n_steps=n_steps,
        actions=actions
    )

    start = time.perf_counter()
    loss, model, opt_state = make_step(model, observations, actions, n_steps, tau, opt_state)
    # JAX dispatches work asynchronously; wait for the step's result so the
    # timer covers the computation and not just the dispatch.
    loss.block_until_ready()
    end = time.perf_counter()
    if (episode % print_every) == 0 or episode == n_episodes - 1:
        print(f"Episode: {episode}, Loss: {loss}, Computation time: {end - start}")

- Fix the GPU-related issues.
- Look through the code and move reusable components into `.py` files.
- Rethink the evaluation plots.
- Generate actions more efficiently.
    - Can this be jitted somehow? It's slowing down training immensely.

In [None]:
# Roll out the current `model` on the most recent batch (`observations`,
# `actions`) and plot the prediction.
observations_model = evaluate_model(model, observations[:, 0, :], actions, n_steps, tau=tau)
plot_episode(observations_model, actions)

In [None]:
# Simulated reference for the same batch, for comparison with the model rollout.
plot_episode(observations, actions)

In [None]:
# Simulate a new batch and compare the model rollout against the simulator.
obs, state = env.reset()
obs = obs.astype(jnp.float32)
state = state.astype(jnp.float32)

# actions = aprbs(n_steps, batch_size, 5, 100, next(data_rng))
actions = gen_actions(n_steps, batch_size, next(data_rng))
observations = simulate_epsiode(
    env,
    obs,
    state,
    n_steps=n_steps,
    actions=actions
)

observations_model = evaluate_model(model, observations[:, 0, :], actions, n_steps, tau=tau)

# Episode of the batch to inspect.
sample_idx = 22

# One figure per observation column: model prediction vs. simulation.
for feature_idx, feature_name in enumerate(("theta", "omega")):
    plt.plot(observations_model[sample_idx, :, feature_idx], label="model")
    plt.plot(observations[sample_idx, :, feature_idx], label="sim")
    plt.title(f"{feature_name}: model vs. sim (episode {sample_idx})")
    plt.legend()
    plt.grid()
    plt.show()