In [27]:
%load_ext autoreload
%autoreload 2
%load_ext line_profiler
from cs285.envs.pendulum.pendulum_env import PendulumEnv
from cs285.envs.dt_sampler import ConstantSampler
from cs285.infrastructure.replay_buffer import ReplayBufferTrajectories
from cs285.infrastructure.utils import sample_n_trajectories, RandomPolicy
from cs285.agents.ode_agent import ODEAgent_Vanilla
from typing import Callable, Optional, Tuple, Sequence
import numpy as np
import gym
from cs285.infrastructure import pytorch_util as ptu
from torchdiffeq import odeint
from tqdm import trange
import jax
import jax.numpy as jnp
import equinox as eqx
import diffrax
from diffrax import diffeqsolve, Dopri5
import optax

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
The line_profiler extension is already loaded. To reload it, use:
  %reload_ext line_profiler


In [28]:
key = jax.random.PRNGKey(0)
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

gpu


In [29]:
dt_sampler = ConstantSampler(dt=0.05)
env = PendulumEnv(
    dt_sampler=dt_sampler
)
mpc_dt_sampler = ConstantSampler(dt=0.05)
agent_key, new_agent_key = jax.random.split(key)
mlp_dyanmics_setup = {
    "hidden_size":128,
    "num_layers":4,
    "activation":"relu",
    "output_activation":"identity"
}
optimizer_name = "adamw"
optimizer_kwargs = {"learning_rate": 1e-3}
mb_agent = ODEAgent_Vanilla(
    env=env,
    key=agent_key,
    mlp_dynamics_setup=mlp_dyanmics_setup,
    optimizer_name=optimizer_name,
    optimizer_kwargs=optimizer_kwargs,
    ensemble_size=10,
    train_timestep=0.005,
    train_discount=1,
    mpc_horizon_steps=100,
    mpc_dt_sampler=mpc_dt_sampler,
    mpc_strategy="cem",
    mpc_discount=0.9,
    mpc_num_action_sequences=1000,
    mpc_timestep=0.05,
    cem_num_iters=4,
    cem_num_elites=5,
    cem_alpha=1,
)
replay_buffer = ReplayBufferTrajectories(seed=42)
trajs, _ = sample_n_trajectories(env, RandomPolicy(env=env), ntraj=10, max_length=200, key=key)
replay_buffer.add_rollouts(trajs)

batch_size = 64
for n in trange(2):
    for i in range(mb_agent.ensemble_size):
        obs, acs, times = jnp.empty(shape=(batch_size, 201, 3)), jnp.empty(shape=(batch_size, 201, 1)), jnp.empty(shape=(batch_size, 201))
        for m in range(batch_size):
            traj = replay_buffer.sample_rollout()
            obs.at[m].set(traj["observations"])
            acs.at[m].set(traj["actions"])
            times.at[m].set(jnp.cumsum(traj["dts"]))
        loss = mb_agent.batched_update(i=i, obs=obs, acs=acs, times=times)
# well, on cpu and gpu it all takes around 16 seconds per iteration

trajs, _ = sample_n_trajectories(
    env=env,
    policy=mb_agent,
    ntraj=10,
    max_length=200,
    key=key
)

100%|██████████| 10/10 [00:01<00:00,  5.05it/s]
100%|██████████| 2/2 [00:32<00:00, 16.01s/it]
100%|██████████| 10/10 [00:16<00:00,  1.66s/it]


In [30]:
with open("test", "wb") as f:
    eqx.tree_serialise_leaves(f, mb_agent)

In [31]:
new_mb_agent = ODEAgent_Vanilla(
    env=env,
    key=new_agent_key,
    mlp_dynamics_setup=mlp_dyanmics_setup,
    optimizer_name=optimizer_name,
    optimizer_kwargs=optimizer_kwargs,
    ensemble_size=10,
    train_timestep=0.005,
    train_discount=1,
    mpc_horizon_steps=100,
    mpc_dt_sampler=mpc_dt_sampler,
    mpc_strategy="cem",
    mpc_discount=0.9,
    mpc_num_action_sequences=1000,
    mpc_timestep=0.05,
    cem_num_iters=4,
    cem_num_elites=5,
    cem_alpha=1,
)

In [32]:
with open("test", "rb") as f:
    new_mb_agent = eqx.tree_deserialise_leaves(f, new_mb_agent)

In [33]:
new_mb_agent.ode_functions[0].mlp.layers[0].weight == mb_agent.ode_functions[0].mlp.layers[0].weight

Array([[False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
 

In [34]:
new_mb_agent.ode_functions[0].mlp.layers[0].weight

Array([[ 0.33654785, -0.39120317,  0.48370373,  0.4156152 ],
       [-0.06202388, -0.43098235, -0.40418017,  0.19395185],
       [ 0.4353912 ,  0.39089513, -0.30023468, -0.25820112],
       [ 0.25183928,  0.17674911,  0.44612384, -0.42434502],
       [-0.0796746 ,  0.34811842,  0.22427356, -0.16206837],
       [ 0.05441475, -0.02656889,  0.28011024,  0.48035014],
       [-0.24405074,  0.39190662,  0.10941935,  0.13282919],
       [ 0.19410288, -0.42802656, -0.38454342, -0.16508341],
       [-0.08832908,  0.31841052,  0.1397798 ,  0.03385365],
       [ 0.29156566, -0.29597306,  0.49948156,  0.27594852],
       [-0.09665036,  0.24842024,  0.29515135, -0.42700922],
       [-0.09207308,  0.49814785, -0.38990033,  0.4412793 ],
       [ 0.46399903,  0.20536566, -0.11606586,  0.19340885],
       [-0.145836  , -0.06244862,  0.20353329,  0.1559422 ],
       [ 0.21769631,  0.31571066,  0.28398633,  0.04734266],
       [-0.04182661,  0.4890033 ,  0.10447836,  0.45306945],
       [-0.36736226, -0.

In [35]:
mb_agent.ode_functions[0].mlp.layers[0].weight

Array([[-2.82259881e-01, -2.95423195e-02,  2.73427069e-01,
        -2.66601741e-01],
       [-3.74877751e-01,  1.64401859e-01, -3.15861881e-01,
        -3.15281212e-01],
       [-2.93381870e-01, -6.86187595e-02, -3.32467481e-02,
         2.62441218e-01],
       [-1.89717971e-02, -2.90948078e-02,  3.96218717e-01,
         1.51053160e-01],
       [ 4.72217441e-01, -1.00941524e-01,  1.02604851e-01,
         2.47997940e-01],
       [ 3.70253265e-01, -3.61286819e-01, -3.93150747e-01,
        -8.64373296e-02],
       [-3.28063190e-01,  3.33211601e-01,  2.97978222e-01,
        -1.63936347e-01],
       [-6.10761493e-02, -4.34538305e-01, -6.66882843e-02,
        -2.11347073e-01],
       [-4.73633766e-01, -1.10832199e-01, -2.28620231e-01,
        -1.95602387e-01],
       [ 1.20546799e-02,  3.93493474e-01,  1.55461639e-01,
        -4.59995270e-01],
       [-2.16246456e-01, -3.91764820e-01,  4.80449438e-01,
         3.20317090e-01],
       [-4.43247736e-01,  3.21073949e-01, -3.43698800e-01,
      