In [1]:
%load_ext autoreload
%autoreload 2
from cs285.envs.pendulum.pendulum_env import PendulumEnv
from cs285.envs.dt_sampler import ConstantSampler
from cs285.infrastructure.replay_buffer import ReplayBufferTransitions
from cs285.infrastructure.utils import sample_n_trajectories, RandomPolicy
from cs285.agents.model_based_agent import ModelBasedAgent
from typing import Callable, Optional, Tuple, Sequence
import numpy as np
import gym
from tqdm import trange
import jax
import jax.numpy as jnp
import equinox as eqx
import diffrax
from diffrax import diffeqsolve, Dopri5
import optax

In [2]:
# Root PRNG key for the notebook; later cells split it to obtain fresh subkeys.
key = jax.random.PRNGKey(0)

# Show which platform JAX is running on. `jax.default_backend()` is the public
# API for this; `jax.lib.xla_bridge` is a private module that has been
# deprecated and removed in recent JAX releases.
print(jax.default_backend())

cpu


In [3]:
# Experiment: build a ModelBasedAgent with mode="vanilla", fit its ensemble on
# transitions gathered by a random policy, then roll the agent out in the env.
# NOTE(review): what each `mode` does is defined inside ModelBasedAgent and is
# not visible from this cell.

# Constant-dt samplers for the environment and for the MPC planner.
dt_sampler = ConstantSampler(dt=0.05)
mpc_dt_sampler = ConstantSampler(dt=0.05)
env = PendulumEnv(
    dt_sampler=dt_sampler
)

# Network / optimiser hyper-parameters.
num_layers = 4
hidden_size = 128
lr = 0.001

# Rollout and training sizes (shared by data collection and evaluation).
ntraj = 10
max_length = 200
batch_size = 64
num_train_iters = 10

# A JAX PRNG key must never be consumed twice. Derive one independent subkey
# per consumer (agent init, random data collection, agent evaluation) and keep
# a fresh carry-over `key` for subsequent cells to split from.
agent_key, collect_key, eval_key, key = jax.random.split(key, 4)

mb_agent = ModelBasedAgent(
    env=env,
    key=agent_key,
    hidden_size=hidden_size,
    num_layers=num_layers,
    activation="relu",
    output_activation="identity",
    lr=lr,
    ensemble_size=10,
    mpc_horizon_steps=100,
    mpc_strategy="random",
    mpc_discount=1.0,
    mpc_num_action_sequences=100,
    mpc_dt_sampler=mpc_dt_sampler,
    cem_num_iters=4,
    cem_num_elites=5,
    cem_alpha=1,
    mode="vanilla"
)

# Collect random-policy trajectories and store their transitions.
replay_buffer = ReplayBufferTransitions()
trajs, _ = sample_n_trajectories(
    env, RandomPolicy(env=env), ntraj=ntraj, max_length=max_length, key=collect_key
)
for traj in trajs:
    replay_buffer.batched_insert(
        observations=traj["observation"],
        actions=traj["action"],
        rewards=traj["reward"],
        next_observations=traj["next_observation"],
        dones=traj["done"],
        dts=traj["dt"],
    )

# Refresh the agent's statistics from everything currently in the buffer.
mb_agent.update_statistics(
    obs=replay_buffer.observations,
    acs=replay_buffer.actions,
    next_obs=replay_buffer.next_observations
)

# Each iteration updates every ensemble member on its own freshly sampled batch.
for _ in trange(num_train_iters):
    for i in range(mb_agent.ensemble_size):
        batch = replay_buffer.sample(batch_size)
        mb_agent.batched_update(
            i,
            batch["observations"],
            batch["actions"],
            batch["next_observations"],
            batch["dts"],
        )

# Roll out the trained agent; uses its own key, independent of data collection.
trajs, _ = sample_n_trajectories(
    env=env,
    policy=mb_agent,
    ntraj=ntraj,
    max_length=max_length,
    key=eval_key
)

100%|██████████| 10/10 [00:01<00:00,  9.86it/s]
100%|██████████| 10/10 [00:09<00:00,  1.03it/s]
100%|██████████| 10/10 [23:04<00:00, 138.40s/it]


In [4]:
# Experiment: build a ModelBasedAgent with mode="mul_dt", fit its ensemble on
# transitions gathered by a random policy, then roll the agent out in the env.
# NOTE(review): what each `mode` does is defined inside ModelBasedAgent and is
# not visible from this cell.

# Constant-dt samplers for the environment and for the MPC planner.
dt_sampler = ConstantSampler(dt=0.05)
mpc_dt_sampler = ConstantSampler(dt=0.05)
env = PendulumEnv(
    dt_sampler=dt_sampler
)

# Network / optimiser hyper-parameters.
num_layers = 4
hidden_size = 128
lr = 0.001

# Rollout and training sizes (shared by data collection and evaluation).
ntraj = 10
max_length = 200
batch_size = 64
num_train_iters = 10

# A JAX PRNG key must never be consumed twice. Derive one independent subkey
# per consumer (agent init, random data collection, agent evaluation) and keep
# a fresh carry-over `key` for subsequent cells to split from.
agent_key, collect_key, eval_key, key = jax.random.split(key, 4)

mb_agent = ModelBasedAgent(
    env=env,
    key=agent_key,
    hidden_size=hidden_size,
    num_layers=num_layers,
    activation="relu",
    output_activation="identity",
    lr=lr,
    ensemble_size=10,
    mpc_horizon_steps=100,
    mpc_strategy="random",
    mpc_discount=1.0,
    mpc_num_action_sequences=100,
    mpc_dt_sampler=mpc_dt_sampler,
    cem_num_iters=4,
    cem_num_elites=5,
    cem_alpha=1,
    mode="mul_dt"
)

# Collect random-policy trajectories and store their transitions.
replay_buffer = ReplayBufferTransitions()
trajs, _ = sample_n_trajectories(
    env, RandomPolicy(env=env), ntraj=ntraj, max_length=max_length, key=collect_key
)
for traj in trajs:
    replay_buffer.batched_insert(
        observations=traj["observation"],
        actions=traj["action"],
        rewards=traj["reward"],
        next_observations=traj["next_observation"],
        dones=traj["done"],
        dts=traj["dt"],
    )

# Refresh the agent's statistics from everything currently in the buffer.
mb_agent.update_statistics(
    obs=replay_buffer.observations,
    acs=replay_buffer.actions,
    next_obs=replay_buffer.next_observations
)

# Each iteration updates every ensemble member on its own freshly sampled batch.
for _ in trange(num_train_iters):
    for i in range(mb_agent.ensemble_size):
        batch = replay_buffer.sample(batch_size)
        mb_agent.batched_update(
            i,
            batch["observations"],
            batch["actions"],
            batch["next_observations"],
            batch["dts"],
        )

# Roll out the trained agent; uses its own key, independent of data collection.
trajs, _ = sample_n_trajectories(
    env=env,
    policy=mb_agent,
    ntraj=ntraj,
    max_length=max_length,
    key=eval_key
)

100%|██████████| 10/10 [00:00<00:00, 10.60it/s]
100%|██████████| 10/10 [00:09<00:00,  1.03it/s]
100%|██████████| 10/10 [23:03<00:00, 138.38s/it]


In [5]:
# Experiment: build a ModelBasedAgent with mode="dt_in", fit its ensemble on
# transitions gathered by a random policy, then roll the agent out in the env.
# NOTE(review): what each `mode` does is defined inside ModelBasedAgent and is
# not visible from this cell.

# Constant-dt samplers for the environment and for the MPC planner.
dt_sampler = ConstantSampler(dt=0.05)
mpc_dt_sampler = ConstantSampler(dt=0.05)
env = PendulumEnv(
    dt_sampler=dt_sampler
)

# Network / optimiser hyper-parameters.
num_layers = 4
hidden_size = 128
lr = 0.001

# Rollout and training sizes (shared by data collection and evaluation).
ntraj = 10
max_length = 200
batch_size = 64
num_train_iters = 10

# A JAX PRNG key must never be consumed twice. Derive one independent subkey
# per consumer (agent init, random data collection, agent evaluation) and keep
# a fresh carry-over `key` for subsequent cells to split from.
agent_key, collect_key, eval_key, key = jax.random.split(key, 4)

mb_agent = ModelBasedAgent(
    env=env,
    key=agent_key,
    hidden_size=hidden_size,
    num_layers=num_layers,
    activation="relu",
    output_activation="identity",
    lr=lr,
    ensemble_size=10,
    mpc_horizon_steps=100,
    mpc_strategy="random",
    mpc_discount=1.0,
    mpc_num_action_sequences=100,
    mpc_dt_sampler=mpc_dt_sampler,
    cem_num_iters=4,
    cem_num_elites=5,
    cem_alpha=1,
    mode="dt_in"
)

# Collect random-policy trajectories and store their transitions.
replay_buffer = ReplayBufferTransitions()
trajs, _ = sample_n_trajectories(
    env, RandomPolicy(env=env), ntraj=ntraj, max_length=max_length, key=collect_key
)
for traj in trajs:
    replay_buffer.batched_insert(
        observations=traj["observation"],
        actions=traj["action"],
        rewards=traj["reward"],
        next_observations=traj["next_observation"],
        dones=traj["done"],
        dts=traj["dt"],
    )

# Refresh the agent's statistics from everything currently in the buffer.
mb_agent.update_statistics(
    obs=replay_buffer.observations,
    acs=replay_buffer.actions,
    next_obs=replay_buffer.next_observations
)

# Each iteration updates every ensemble member on its own freshly sampled batch.
for _ in trange(num_train_iters):
    for i in range(mb_agent.ensemble_size):
        batch = replay_buffer.sample(batch_size)
        mb_agent.batched_update(
            i,
            batch["observations"],
            batch["actions"],
            batch["next_observations"],
            batch["dts"],
        )

# Roll out the trained agent; uses its own key, independent of data collection.
trajs, _ = sample_n_trajectories(
    env=env,
    policy=mb_agent,
    ntraj=ntraj,
    max_length=max_length,
    key=eval_key
)

100%|██████████| 10/10 [00:00<00:00, 10.58it/s]
100%|██████████| 10/10 [00:09<00:00,  1.07it/s]
100%|██████████| 10/10 [27:31<00:00, 165.12s/it]
