In [1]:
%load_ext autoreload
%autoreload 2
%load_ext line_profiler
from cs285.envs.pendulum.pendulum_env import PendulumEnv
from cs285.envs.dt_sampler import ConstantSampler
from cs285.infrastructure.replay_buffer import ReplayBufferTrajectories
from cs285.infrastructure.utils import sample_n_trajectories, RandomPolicy
from cs285.agents.ode_agent import ODEAgent
from cs285.agents.utils import save_leaves, load_leaves
from typing import Callable, Optional, Tuple, Sequence
import numpy as np
import gym
from cs285.infrastructure import pytorch_util as ptu
from tqdm import trange
import jax
import jax.numpy as jnp
import equinox as eqx
import diffrax
from diffrax import diffeqsolve, Dopri5
import optax

In [2]:
# Root PRNG key for the whole notebook; every consumer below should split from it.
key = jax.random.PRNGKey(0)
# Report which XLA backend JAX is using (e.g. "cpu" / "gpu"). `jax.default_backend()`
# is the public API; `jax.lib.xla_bridge.get_backend()` is deprecated.
print(jax.default_backend())

cpu


In [6]:
# Experiment: ODEAgent ensemble with the "vanilla" neural ODE, trained on random-policy
# pendulum rollouts, then evaluated as a CEM-MPC policy.
dt_sampler = ConstantSampler(dt=0.05)
env = PendulumEnv(
    dt_sampler=dt_sampler
)
mpc_dt_sampler = ConstantSampler(dt=0.05)
# Give every consumer its own subkey: reusing one JAX key yields correlated randomness.
agent_key, data_key, eval_key, key = jax.random.split(key, 4)
neural_ode_name = "vanilla"
neural_ode_kwargs = {
    "ode_dt0": 0.005,
    "mlp_dynamics_setup": {
        "hidden_size":128,
        "num_layers":4,
        "activation":"relu",
        "output_activation":"identity"
    }
}
optimizer_name = "adamw"
optimizer_kwargs = {"learning_rate": 1e-3}
mb_agent = ODEAgent(
    env=env,
    key=agent_key,
    neural_ode_name=neural_ode_name,
    neural_ode_kwargs=neural_ode_kwargs,
    optimizer_name=optimizer_name,
    optimizer_kwargs=optimizer_kwargs,
    ensemble_size=10,
    train_discount=1,
    mpc_horizon_steps=100,
    mpc_dt_sampler=mpc_dt_sampler,
    mpc_strategy="cem",
    mpc_discount=0.9,
    mpc_num_action_sequences=1000,
    cem_num_iters=4,
    cem_num_elites=5,
    cem_alpha=1,
)
max_length = 200
replay_buffer = ReplayBufferTrajectories(seed=42)
trajs, _ = sample_n_trajectories(env, RandomPolicy(env=env), ntraj=10, max_length=max_length, key=data_key)
replay_buffer.add_rollouts(trajs)

batch_size = 64
# Points stored per rollout; assumes every rollout has max_length + 1 entries — TODO confirm.
seq_len = max_length + 1
for n in trange(2):
    for i in range(mb_agent.ensemble_size):
        # Fill NumPy buffers in place and convert once. JAX arrays are immutable:
        # `x.at[m].set(v)` returns a NEW array, so calling it without reassignment
        # left the batch untouched (the model was being trained on an empty buffer).
        obs = np.zeros((batch_size, seq_len, 3), dtype=np.float32)
        acs = np.zeros((batch_size, seq_len, 1), dtype=np.float32)
        times = np.zeros((batch_size, seq_len), dtype=np.float32)
        for m in range(batch_size):
            traj = replay_buffer.sample_rollout()
            obs[m] = traj["observations"]
            acs[m] = traj["actions"]
            times[m] = np.cumsum(np.asarray(traj["dts"]))
        loss = mb_agent.batched_update(
            i=i, obs=jnp.asarray(obs), acs=jnp.asarray(acs), times=jnp.asarray(times)
        )
# Training is the slow part; use %lprun (line_profiler) to find where to improve it.

trajs, _ = sample_n_trajectories(
    env=env,
    policy=mb_agent,
    ntraj=10,
    max_length=max_length,
    key=eval_key
)

100%|██████████| 10/10 [00:00<00:00, 10.12it/s]
100%|██████████| 2/2 [00:47<00:00, 23.57s/it]
100%|██████████| 10/10 [00:16<00:00,  1.60s/it]


In [8]:
# Experiment: ODEAgent ensemble with the "augmented" neural ODE (aug_dim=3), trained on
# random-policy pendulum rollouts, then evaluated as a CEM-MPC policy.
dt_sampler = ConstantSampler(dt=0.05)
env = PendulumEnv(
    dt_sampler=dt_sampler
)
mpc_dt_sampler = ConstantSampler(dt=0.05)
# Give every consumer its own subkey: reusing one JAX key yields correlated randomness.
agent_key, data_key, eval_key, key = jax.random.split(key, 4)
neural_ode_name = "augmented"
neural_ode_kwargs = {
    "ode_dt0": 0.005,
    "mlp_dynamics_setup": {
        "hidden_size":128,
        "num_layers":4,
        "activation":"relu",
        "output_activation":"identity"
    },
    "aug_dim": 3
}
optimizer_name = "adamw"
optimizer_kwargs = {"learning_rate": 1e-3}
mb_agent = ODEAgent(
    env=env,
    key=agent_key,
    neural_ode_name=neural_ode_name,
    neural_ode_kwargs=neural_ode_kwargs,
    optimizer_name=optimizer_name,
    optimizer_kwargs=optimizer_kwargs,
    ensemble_size=10,
    train_discount=1,
    mpc_horizon_steps=100,
    mpc_dt_sampler=mpc_dt_sampler,
    mpc_strategy="cem",
    mpc_discount=0.9,
    mpc_num_action_sequences=1000,
    cem_num_iters=4,
    cem_num_elites=5,
    cem_alpha=1,
)
max_length = 200
replay_buffer = ReplayBufferTrajectories(seed=42)
trajs, _ = sample_n_trajectories(env, RandomPolicy(env=env), ntraj=10, max_length=max_length, key=data_key)
replay_buffer.add_rollouts(trajs)

batch_size = 64
# Points stored per rollout; assumes every rollout has max_length + 1 entries — TODO confirm.
seq_len = max_length + 1
for n in trange(2):
    for i in range(mb_agent.ensemble_size):
        # Fill NumPy buffers in place and convert once. JAX arrays are immutable:
        # `x.at[m].set(v)` returns a NEW array, so calling it without reassignment
        # left the batch untouched (the model was being trained on an empty buffer).
        obs = np.zeros((batch_size, seq_len, 3), dtype=np.float32)
        acs = np.zeros((batch_size, seq_len, 1), dtype=np.float32)
        times = np.zeros((batch_size, seq_len), dtype=np.float32)
        for m in range(batch_size):
            traj = replay_buffer.sample_rollout()
            obs[m] = traj["observations"]
            acs[m] = traj["actions"]
            times[m] = np.cumsum(np.asarray(traj["dts"]))
        loss = mb_agent.batched_update(
            i=i, obs=jnp.asarray(obs), acs=jnp.asarray(acs), times=jnp.asarray(times)
        )
# Training is the slow part; use %lprun (line_profiler) to find where to improve it.

trajs, _ = sample_n_trajectories(
    env=env,
    policy=mb_agent,
    ntraj=10,
    max_length=max_length,
    key=eval_key
)

100%|██████████| 10/10 [00:00<00:00, 11.19it/s]
100%|██████████| 2/2 [00:47<00:00, 23.99s/it]
100%|██████████| 10/10 [00:16<00:00,  1.66s/it]


In [10]:
# Experiment: ODEAgent ensemble with the "latent_mlp" neural ODE (MLP encoders/decoder,
# latent dims of 4), trained on random-policy pendulum rollouts, then evaluated as a
# CEM-MPC policy.
dt_sampler = ConstantSampler(dt=0.05)
env = PendulumEnv(
    dt_sampler=dt_sampler
)
mpc_dt_sampler = ConstantSampler(dt=0.05)
# Give every consumer its own subkey: reusing one JAX key yields correlated randomness.
agent_key, data_key, eval_key, key = jax.random.split(key, 4)
mlp_setup = {
    "hidden_size":128,
    "num_layers":4,
    "activation":"relu",
    "output_activation":"identity"
}
neural_ode_name = "latent_mlp"
neural_ode_kwargs = {
    "ode_dt0": 0.005,
    "ac_latent_dim": 4,
    "ob_latent_dim": 4,
    "mlp_dynamics_setup": mlp_setup,
    "mlp_ob_encoder_setup": mlp_setup,
    "mlp_ob_decoder_setup": mlp_setup,
    "mlp_ac_encoder_setup": mlp_setup,
}
optimizer_name = "adamw"
optimizer_kwargs = {"learning_rate": 1e-3}
mb_agent = ODEAgent(
    env=env,
    key=agent_key,
    neural_ode_name=neural_ode_name,
    neural_ode_kwargs=neural_ode_kwargs,
    optimizer_name=optimizer_name,
    optimizer_kwargs=optimizer_kwargs,
    ensemble_size=10,
    train_discount=1,
    mpc_horizon_steps=100,
    mpc_dt_sampler=mpc_dt_sampler,
    mpc_strategy="cem",
    mpc_discount=0.9,
    mpc_num_action_sequences=1000,
    cem_num_iters=4,
    cem_num_elites=5,
    cem_alpha=1,
)
max_length = 200
replay_buffer = ReplayBufferTrajectories(seed=42)
trajs, _ = sample_n_trajectories(env, RandomPolicy(env=env), ntraj=10, max_length=max_length, key=data_key)
replay_buffer.add_rollouts(trajs)

batch_size = 64
# Points stored per rollout; assumes every rollout has max_length + 1 entries — TODO confirm.
seq_len = max_length + 1
for n in trange(2):
    for i in range(mb_agent.ensemble_size):
        # Fill NumPy buffers in place and convert once. JAX arrays are immutable:
        # `x.at[m].set(v)` returns a NEW array, so calling it without reassignment
        # left the batch untouched (the model was being trained on an empty buffer).
        obs = np.zeros((batch_size, seq_len, 3), dtype=np.float32)
        acs = np.zeros((batch_size, seq_len, 1), dtype=np.float32)
        times = np.zeros((batch_size, seq_len), dtype=np.float32)
        for m in range(batch_size):
            traj = replay_buffer.sample_rollout()
            obs[m] = traj["observations"]
            acs[m] = traj["actions"]
            times[m] = np.cumsum(np.asarray(traj["dts"]))
        loss = mb_agent.batched_update(
            i=i, obs=jnp.asarray(obs), acs=jnp.asarray(acs), times=jnp.asarray(times)
        )
# Training is the slow part; use %lprun (line_profiler) to find where to improve it.

trajs, _ = sample_n_trajectories(
    env=env,
    policy=mb_agent,
    ntraj=10,
    max_length=max_length,
    key=eval_key
)

100%|██████████| 10/10 [00:00<00:00, 11.62it/s]
100%|██████████| 2/2 [01:01<00:00, 30.82s/it]
100%|██████████| 10/10 [00:21<00:00,  2.13s/it]


In [11]:
# Experiment: ODEAgent ensemble with the "ode_rnn" neural ODE (rnn_type="lstm",
# latent_dim=4), trained on random-policy pendulum rollouts, then evaluated as a
# CEM-MPC policy.
dt_sampler = ConstantSampler(dt=0.05)
env = PendulumEnv(
    dt_sampler=dt_sampler
)
mpc_dt_sampler = ConstantSampler(dt=0.05)
# Give every consumer its own subkey: reusing one JAX key yields correlated randomness.
agent_key, data_key, eval_key, key = jax.random.split(key, 4)
mlp_setup = {
    "hidden_size":128,
    "num_layers":4,
    "activation":"relu",
    "output_activation":"identity"
}
neural_ode_name = "ode_rnn"
neural_ode_kwargs = {
    "ode_dt0": 0.005,
    "latent_dim": 4,
    "rnn_type": "lstm",
    "mlp_dynamics_setup": mlp_setup,
    "mlp_ob_encoder_setup": mlp_setup,
    "mlp_ob_decoder_setup": mlp_setup,
}
optimizer_name = "adamw"
optimizer_kwargs = {"learning_rate": 1e-3}
mb_agent = ODEAgent(
    env=env,
    key=agent_key,
    neural_ode_name=neural_ode_name,
    neural_ode_kwargs=neural_ode_kwargs,
    optimizer_name=optimizer_name,
    optimizer_kwargs=optimizer_kwargs,
    ensemble_size=10,
    train_discount=1,
    mpc_horizon_steps=100,
    mpc_dt_sampler=mpc_dt_sampler,
    mpc_strategy="cem",
    mpc_discount=0.9,
    mpc_num_action_sequences=1000,
    cem_num_iters=4,
    cem_num_elites=5,
    cem_alpha=1,
)
max_length = 200
replay_buffer = ReplayBufferTrajectories(seed=42)
trajs, _ = sample_n_trajectories(env, RandomPolicy(env=env), ntraj=10, max_length=max_length, key=data_key)
replay_buffer.add_rollouts(trajs)

batch_size = 64
# Points stored per rollout; assumes every rollout has max_length + 1 entries — TODO confirm.
seq_len = max_length + 1
for n in trange(2):
    for i in range(mb_agent.ensemble_size):
        # Fill NumPy buffers in place and convert once. JAX arrays are immutable:
        # `x.at[m].set(v)` returns a NEW array, so calling it without reassignment
        # left the batch untouched (the model was being trained on an empty buffer).
        obs = np.zeros((batch_size, seq_len, 3), dtype=np.float32)
        acs = np.zeros((batch_size, seq_len, 1), dtype=np.float32)
        times = np.zeros((batch_size, seq_len), dtype=np.float32)
        for m in range(batch_size):
            traj = replay_buffer.sample_rollout()
            obs[m] = traj["observations"]
            acs[m] = traj["actions"]
            times[m] = np.cumsum(np.asarray(traj["dts"]))
        loss = mb_agent.batched_update(
            i=i, obs=jnp.asarray(obs), acs=jnp.asarray(acs), times=jnp.asarray(times)
        )
# Training is the slow part; use %lprun (line_profiler) to find where to improve it.

trajs, _ = sample_n_trajectories(
    env=env,
    policy=mb_agent,
    ntraj=10,
    max_length=max_length,
    key=eval_key
)

100%|██████████| 10/10 [00:01<00:00,  7.26it/s]
100%|██████████| 2/2 [01:27<00:00, 43.55s/it]
100%|██████████| 10/10 [00:35<00:00,  3.55s/it]


In [8]:
# Experiment: ODEAgent ensemble with the "ode_rnn" neural ODE (rnn_type="gru",
# latent_dim=4), trained on random-policy pendulum rollouts, then evaluated as a
# CEM-MPC policy.
dt_sampler = ConstantSampler(dt=0.05)
env = PendulumEnv(
    dt_sampler=dt_sampler
)
mpc_dt_sampler = ConstantSampler(dt=0.05)
# Give every consumer its own subkey: reusing one JAX key yields correlated randomness.
agent_key, data_key, eval_key, key = jax.random.split(key, 4)
mlp_setup = {
    "hidden_size":128,
    "num_layers":4,
    "activation":"relu",
    "output_activation":"identity"
}
neural_ode_name = "ode_rnn"
neural_ode_kwargs = {
    "ode_dt0": 0.005,
    "latent_dim": 4,
    "rnn_type": "gru",
    "mlp_dynamics_setup": mlp_setup,
    "mlp_ob_encoder_setup": mlp_setup,
    "mlp_ob_decoder_setup": mlp_setup,
}
optimizer_name = "adamw"
optimizer_kwargs = {"learning_rate": 1e-3}
mb_agent = ODEAgent(
    env=env,
    key=agent_key,
    neural_ode_name=neural_ode_name,
    neural_ode_kwargs=neural_ode_kwargs,
    optimizer_name=optimizer_name,
    optimizer_kwargs=optimizer_kwargs,
    ensemble_size=10,
    train_discount=1,
    mpc_horizon_steps=100,
    mpc_dt_sampler=mpc_dt_sampler,
    mpc_strategy="cem",
    mpc_discount=0.9,
    mpc_num_action_sequences=1000,
    cem_num_iters=4,
    cem_num_elites=5,
    cem_alpha=1,
)
max_length = 200
replay_buffer = ReplayBufferTrajectories(seed=42)
trajs, _ = sample_n_trajectories(env, RandomPolicy(env=env), ntraj=10, max_length=max_length, key=data_key)
replay_buffer.add_rollouts(trajs)

batch_size = 64
# Points stored per rollout; assumes every rollout has max_length + 1 entries — TODO confirm.
seq_len = max_length + 1
for n in trange(2):
    for i in range(mb_agent.ensemble_size):
        # Fill NumPy buffers in place and convert once. JAX arrays are immutable:
        # `x.at[m].set(v)` returns a NEW array, so calling it without reassignment
        # left the batch untouched (the model was being trained on an empty buffer).
        obs = np.zeros((batch_size, seq_len, 3), dtype=np.float32)
        acs = np.zeros((batch_size, seq_len, 1), dtype=np.float32)
        times = np.zeros((batch_size, seq_len), dtype=np.float32)
        for m in range(batch_size):
            traj = replay_buffer.sample_rollout()
            obs[m] = traj["observations"]
            acs[m] = traj["actions"]
            times[m] = np.cumsum(np.asarray(traj["dts"]))
        loss = mb_agent.batched_update(
            i=i, obs=jnp.asarray(obs), acs=jnp.asarray(acs), times=jnp.asarray(times)
        )
# Training is the slow part; use %lprun (line_profiler) to find where to improve it.

trajs, _ = sample_n_trajectories(
    env=env,
    policy=mb_agent,
    ntraj=10,
    max_length=max_length,
    key=eval_key
)

100%|██████████| 10/10 [00:01<00:00,  6.96it/s]
100%|██████████| 2/2 [01:23<00:00, 41.93s/it]
100%|██████████| 10/10 [00:36<00:00,  3.65s/it]
