In [2]:
%load_ext autoreload
%autoreload 2
%load_ext line_profiler
from cs285.envs.pendulum.pendulum_env import PendulumEnv
from cs285.envs.dt_sampler import ConstantSampler
from cs285.infrastructure.replay_buffer import ReplayBufferTrajectories
from cs285.infrastructure.utils import sample_n_trajectories, RandomPolicy
from cs285.agents.ode_agent import ODEAgent
from cs285.agents.utils import save_leaves, load_leaves
from cs285.infrastructure import utils
from typing import Callable, Optional, Tuple, Sequence
import numpy as np
import gym
from cs285.infrastructure import pytorch_util as ptu
from tqdm import trange
import jax
import jax.numpy as jnp
import equinox as eqx
import diffrax
from diffrax import diffeqsolve, Dopri5
import optax
import pickle
from tqdm import trange

In [3]:
# Root PRNG key for the notebook; downstream cells derive their keys from it.
key = jax.random.PRNGKey(0)
# Report which backend JAX selected (e.g. "cpu", "gpu", "tpu").
# `jax.default_backend()` is the public API for this and avoids importing the
# internal `jax.lib.xla_bridge` module in the middle of the notebook.
print(jax.default_backend())

gpu


In [13]:
# Environment: pendulum whose step size comes from a constant-dt sampler.
dt_sampler = ConstantSampler(dt=0.05)
env = PendulumEnv(dt_sampler=dt_sampler)

# The MPC planner gets its own sampler instance with the same constant dt.
mpc_dt_sampler = ConstantSampler(dt=0.05)

# Split the root key: `agent_key` is handed to the agent below,
# `new_agent_key` is left unused by this cell.
agent_key, new_agent_key = jax.random.split(key)

# Dynamics-model configuration, forwarded verbatim to ODEAgent.
neural_ode_name = "vanilla"
neural_ode_kwargs = dict(
    ode_dt0=0.005,
    mlp_dynamics_setup=dict(
        hidden_size=128,
        num_layers=4,
        activation="tanh",
        output_activation="identity",
    ),
)

# Optimizer configuration, forwarded verbatim to ODEAgent.
optimizer_name = "adamw"
optimizer_kwargs = dict(learning_rate=1e-3)

# Planner (MPC / CEM) settings, grouped so they can be read at a glance.
mpc_settings = dict(
    mpc_horizon_steps=20,
    mpc_dt_sampler=mpc_dt_sampler,
    mpc_strategy="cem",
    mpc_discount=0.9,
    mpc_num_action_sequences=1000,
    cem_num_iters=4,
    cem_num_elites=5,
    cem_alpha=1,
)

mb_agent = ODEAgent(
    env=env,
    key=agent_key,
    neural_ode_name=neural_ode_name,
    neural_ode_kwargs=neural_ode_kwargs,
    optimizer_name=optimizer_name,
    optimizer_kwargs=optimizer_kwargs,
    ensemble_size=10,
    train_discount=1,
    **mpc_settings,
)

In [None]:
# Fill a trajectory replay buffer with rollouts from a uniformly random policy.
replay_buffer = ReplayBufferTrajectories(seed=42)
# Use the spare sub-key produced by the earlier `jax.random.split(key)` rather
# than `key` itself: a key that has already been split must not be consumed
# again, otherwise the random streams derived from it are not independent.
trajs, _ = sample_n_trajectories(
    env,
    RandomPolicy(env=env),
    ntraj=1000,
    max_length=200,
    key=new_agent_key,
)
replay_buffer.add_rollouts(trajs)

In [4]:
# Disabled snippet kept as a bare string literal: nothing inside it is executed,
# the cell merely evaluates to (and echoes) the string itself.
# The text it holds would pickle `replay_buffer` to the file "random_replay_buffer".
"""
with open("random_replay_buffer", "wb") as f:
    pickle.dump(replay_buffer, f)
"""

'\nwith open("random_replay_buffer", "wb") as f:\n    pickle.dump(replay_buffer, f)\n'

In [14]:
# Training hyperparameters consumed by the training-loop cell.
train_steps = 100    # number of outer training iterations
batch_size = 16      # trajectories requested from the replay buffer per update
train_ep_len = 20    # `length` argument passed to utils.split_arr
train_stride = 20    # `stride` argument passed to utils.split_arr

In [15]:
# Train every ensemble member once per step on its own freshly sampled batch,
# checkpoint the agent after each step, and record the mean loss per step.
all_losses = []
for step in trange(train_steps, dynamic_ncols=True):
    step_losses = []
    for i in range(mb_agent.ensemble_size):
        traj = replay_buffer.sample_rollouts(batch_size=batch_size)
        # Cut each trajectory into fixed-length segments.
        obs = utils.split_arr(np.array(traj["observations"]), length=train_ep_len, stride=train_stride)
        acs = utils.split_arr(np.array(traj["actions"]), length=train_ep_len, stride=train_stride)
        # dts gets a trailing axis only so split_arr sees the same rank as obs/acs;
        # the axis is squeezed away again right after.
        dts = utils.split_arr(np.array(traj["dts"])[..., np.newaxis], length=train_ep_len, stride=train_stride).squeeze(-1)
        # Use loop-local names for the shape. Unpacking into `batch_size` /
        # `train_ep_len` here would overwrite the config globals and silently
        # change what the next `sample_rollouts` / `split_arr` call receives.
        num_trajs, num_splitted, segment_len, ob_dim = obs.shape
        ac_dim = acs.shape[-1]
        num_segments = num_trajs * num_splitted
        # Flatten (trajectory, segment) into a single leading batch axis.
        obs = jnp.array(obs).reshape(num_segments, segment_len, ob_dim)
        acs = jnp.array(acs).reshape(num_segments, segment_len, ac_dim)
        # Timestamps are the running sum of dts within each segment (last axis).
        times = jnp.cumsum(dts, axis=-1).reshape(num_segments, segment_len)
        loss = mb_agent.batched_update(
            i=i,
            obs=obs,
            acs=acs,
            times=times,
        )
        step_losses.append(loss)
    save_leaves(mb_agent, f"checkpoint/checkpoint{step}")
    # Compute the step mean once and reuse it for both the log and the history.
    mean_step_loss = np.mean(step_losses)
    all_losses.append(mean_step_loss)
    print(f"epoch {step}, loss {mean_step_loss}")

  1%|          | 1/100 [00:20<33:09, 20.10s/it]

epoch 0, loss 18.39639196395874


  2%|▏         | 2/100 [00:39<32:20, 19.80s/it]

epoch 1, loss 16.3913480758667


  3%|▎         | 3/100 [00:59<32:04, 19.84s/it]

epoch 2, loss 15.876043033599853


  4%|▍         | 4/100 [01:19<31:47, 19.87s/it]

epoch 3, loss 16.638383769989012


  5%|▌         | 5/100 [01:39<31:31, 19.91s/it]

epoch 4, loss 17.07912998199463


  6%|▌         | 6/100 [01:59<31:23, 20.03s/it]

epoch 5, loss 14.973362064361572


  7%|▋         | 7/100 [02:20<31:17, 20.19s/it]

epoch 6, loss 14.370085906982421


  8%|▊         | 8/100 [02:40<30:59, 20.21s/it]

epoch 7, loss 12.581250762939453


  9%|▉         | 9/100 [03:01<30:55, 20.39s/it]

epoch 8, loss 11.768706798553467


 10%|█         | 10/100 [03:22<30:46, 20.52s/it]

epoch 9, loss 11.928400421142578


 11%|█         | 11/100 [03:43<30:40, 20.68s/it]

epoch 10, loss 11.369459056854248


 12%|█▏        | 12/100 [04:04<30:24, 20.74s/it]

epoch 11, loss 10.592216491699219


 13%|█▎        | 13/100 [04:25<30:14, 20.86s/it]

epoch 12, loss 10.926023769378663


 14%|█▍        | 14/100 [04:46<30:05, 20.99s/it]

epoch 13, loss 10.953106594085693


 15%|█▌        | 15/100 [05:08<30:00, 21.18s/it]

epoch 14, loss 10.844387435913086


 16%|█▌        | 16/100 [05:29<29:44, 21.25s/it]

epoch 15, loss 10.101641368865966


 17%|█▋        | 17/100 [05:51<29:30, 21.33s/it]

epoch 16, loss 9.455598068237304


 18%|█▊        | 18/100 [06:12<29:22, 21.50s/it]

epoch 17, loss 10.078759670257568


 19%|█▉        | 19/100 [06:34<29:14, 21.66s/it]

epoch 18, loss 10.108087587356568


 20%|██        | 20/100 [06:56<28:51, 21.64s/it]

epoch 19, loss 9.524208402633667


 21%|██        | 21/100 [07:18<28:28, 21.62s/it]

epoch 20, loss 8.754386234283448


 22%|██▏       | 22/100 [07:40<28:21, 21.81s/it]

epoch 21, loss 9.43628225326538


 23%|██▎       | 23/100 [08:02<28:08, 21.93s/it]

epoch 22, loss 9.538161420822144


 24%|██▍       | 24/100 [08:24<27:57, 22.07s/it]

epoch 23, loss 8.022551727294921


 25%|██▌       | 25/100 [08:46<27:33, 22.04s/it]

epoch 24, loss 8.927814960479736


 26%|██▌       | 26/100 [09:08<27:10, 22.04s/it]

epoch 25, loss 8.836034774780273


 27%|██▋       | 27/100 [09:31<26:49, 22.05s/it]

epoch 26, loss 8.966203641891479


 28%|██▊       | 28/100 [09:53<26:27, 22.05s/it]

epoch 27, loss 7.633602333068848


 29%|██▉       | 29/100 [10:15<26:07, 22.08s/it]

epoch 28, loss 8.366142320632935


 30%|███       | 30/100 [10:37<25:41, 22.03s/it]

epoch 29, loss 8.111577987670898


 31%|███       | 31/100 [10:59<25:19, 22.03s/it]

epoch 30, loss 7.156509304046631


 32%|███▏      | 32/100 [11:21<24:59, 22.05s/it]

epoch 31, loss 7.191894292831421


 33%|███▎      | 33/100 [11:43<24:39, 22.09s/it]

epoch 32, loss 7.213384008407592


 34%|███▍      | 34/100 [12:05<24:16, 22.06s/it]

epoch 33, loss 7.319515371322632


 35%|███▌      | 35/100 [12:27<23:54, 22.07s/it]

epoch 34, loss 7.318087768554688


 36%|███▌      | 36/100 [12:49<23:30, 22.04s/it]

epoch 35, loss 6.486569213867187


 37%|███▋      | 37/100 [13:11<23:03, 21.95s/it]

epoch 36, loss 6.922734165191651


 38%|███▊      | 38/100 [13:33<22:45, 22.03s/it]

epoch 37, loss 6.457598924636841


 39%|███▉      | 39/100 [13:55<22:22, 22.01s/it]

epoch 38, loss 6.299310779571533


 40%|████      | 40/100 [14:17<21:59, 21.99s/it]

epoch 39, loss 6.0391669273376465


 41%|████      | 41/100 [14:39<21:37, 21.98s/it]

epoch 40, loss 5.722709155082702


 42%|████▏     | 42/100 [15:01<21:13, 21.96s/it]

epoch 41, loss 5.23240385055542


 43%|████▎     | 43/100 [15:23<20:51, 21.95s/it]

epoch 42, loss 5.578409576416016


 44%|████▍     | 44/100 [15:45<20:35, 22.07s/it]

epoch 43, loss 5.614706969261169


 45%|████▌     | 45/100 [16:07<20:16, 22.12s/it]

epoch 44, loss 5.4187705516815186


 46%|████▌     | 46/100 [16:29<19:52, 22.08s/it]

epoch 45, loss 5.428317141532898


 47%|████▋     | 47/100 [16:51<19:30, 22.09s/it]

epoch 46, loss 5.429158782958984


 48%|████▊     | 48/100 [17:13<19:07, 22.07s/it]

epoch 47, loss 5.447885918617248


 49%|████▉     | 49/100 [17:36<18:49, 22.16s/it]

epoch 48, loss 5.8711151599884035


 50%|█████     | 50/100 [17:58<18:27, 22.14s/it]

epoch 49, loss 6.064206790924072


 51%|█████     | 51/100 [18:20<18:06, 22.18s/it]

epoch 50, loss 5.317134833335876


 52%|█████▏    | 52/100 [18:43<17:49, 22.28s/it]

epoch 51, loss 5.234978246688843


 53%|█████▎    | 53/100 [19:04<17:20, 22.14s/it]

epoch 52, loss 5.01282570362091


 54%|█████▍    | 54/100 [19:27<17:07, 22.34s/it]

epoch 53, loss 5.791096353530884


 55%|█████▌    | 55/100 [19:50<16:48, 22.40s/it]

epoch 54, loss 5.307946729660034


 56%|█████▌    | 56/100 [20:12<16:25, 22.41s/it]

epoch 55, loss 5.700955438613891


 57%|█████▋    | 57/100 [20:35<16:06, 22.47s/it]

epoch 56, loss 5.29694435596466


 58%|█████▊    | 58/100 [20:58<15:48, 22.57s/it]

epoch 57, loss 4.767285799980163


 59%|█████▉    | 59/100 [21:20<15:26, 22.61s/it]

epoch 58, loss 4.5603265285491945


 60%|██████    | 60/100 [21:43<15:05, 22.63s/it]

epoch 59, loss 5.280098867416382


 61%|██████    | 61/100 [22:05<14:40, 22.58s/it]

epoch 60, loss 5.050508451461792


 62%|██████▏   | 62/100 [22:28<14:17, 22.57s/it]

epoch 61, loss 4.6597275018692015


 63%|██████▎   | 63/100 [22:51<13:54, 22.56s/it]

epoch 62, loss 5.187973403930664


 64%|██████▍   | 64/100 [23:13<13:27, 22.44s/it]

epoch 63, loss 5.272667121887207


 65%|██████▌   | 65/100 [23:35<13:05, 22.45s/it]

epoch 64, loss 4.813701701164246


 66%|██████▌   | 66/100 [23:58<12:46, 22.55s/it]

epoch 65, loss 4.336658310890198


 67%|██████▋   | 67/100 [24:21<12:27, 22.66s/it]

epoch 66, loss 4.273366475105286


 68%|██████▊   | 68/100 [24:44<12:04, 22.65s/it]

epoch 67, loss 4.801747107505799


 69%|██████▉   | 69/100 [25:06<11:37, 22.51s/it]

epoch 68, loss 4.332536959648133


 70%|███████   | 70/100 [25:28<11:14, 22.48s/it]

epoch 69, loss 4.538830733299255


 71%|███████   | 71/100 [25:50<10:49, 22.41s/it]

epoch 70, loss 4.213283228874206


 72%|███████▏  | 72/100 [26:13<10:27, 22.41s/it]

epoch 71, loss 4.910214042663574


 73%|███████▎  | 73/100 [26:35<10:06, 22.48s/it]

epoch 72, loss 4.5139923095703125


 74%|███████▍  | 74/100 [26:58<09:44, 22.47s/it]

epoch 73, loss 4.528792357444763


 75%|███████▌  | 75/100 [27:20<09:21, 22.45s/it]

epoch 74, loss 4.011380410194397


 76%|███████▌  | 76/100 [27:43<08:59, 22.48s/it]

epoch 75, loss 4.20252366065979


 77%|███████▋  | 77/100 [28:05<08:36, 22.45s/it]

epoch 76, loss 4.126343846321106


 78%|███████▊  | 78/100 [28:28<08:15, 22.51s/it]

epoch 77, loss 4.897342348098755


 79%|███████▉  | 79/100 [28:51<07:55, 22.63s/it]

epoch 78, loss 3.9494580030441284


 80%|████████  | 80/100 [29:13<07:32, 22.61s/it]

epoch 79, loss 3.8960768938064576


 81%|████████  | 81/100 [29:36<07:07, 22.49s/it]

epoch 80, loss 3.552157199382782


 82%|████████▏ | 82/100 [29:58<06:44, 22.46s/it]

epoch 81, loss 4.473735618591308


 83%|████████▎ | 83/100 [30:20<06:21, 22.44s/it]

epoch 82, loss 3.594724416732788


 84%|████████▍ | 84/100 [30:42<05:57, 22.36s/it]

epoch 83, loss 4.911491966247558


 85%|████████▌ | 85/100 [31:05<05:35, 22.33s/it]

epoch 84, loss 4.257541704177856


 86%|████████▌ | 86/100 [31:27<05:11, 22.28s/it]

epoch 85, loss 3.819335412979126


 87%|████████▋ | 87/100 [31:50<04:51, 22.40s/it]

epoch 86, loss 3.621919131278992


 88%|████████▊ | 88/100 [32:12<04:27, 22.32s/it]

epoch 87, loss 3.696516275405884


 89%|████████▉ | 89/100 [32:34<04:05, 22.30s/it]

epoch 88, loss 4.1080464124679565


 90%|█████████ | 90/100 [32:57<03:43, 22.37s/it]

epoch 89, loss 3.0477606058120728


 91%|█████████ | 91/100 [33:19<03:21, 22.37s/it]

epoch 90, loss 3.999513053894043


 92%|█████████▏| 92/100 [33:41<02:59, 22.44s/it]

epoch 91, loss 4.125225019454956


 93%|█████████▎| 93/100 [34:04<02:36, 22.39s/it]

epoch 92, loss 4.097827291488647


 94%|█████████▍| 94/100 [34:26<02:14, 22.42s/it]

epoch 93, loss 4.4026193857193


 95%|█████████▌| 95/100 [34:49<01:52, 22.40s/it]

epoch 94, loss 3.6277783870697022


 96%|█████████▌| 96/100 [35:12<01:30, 22.59s/it]

epoch 95, loss 3.360838508605957


 97%|█████████▋| 97/100 [35:34<01:07, 22.59s/it]

epoch 96, loss 3.2330355763435366


 98%|█████████▊| 98/100 [35:56<00:44, 22.47s/it]

epoch 97, loss 3.2696501493453978


 99%|█████████▉| 99/100 [36:19<00:22, 22.48s/it]

epoch 98, loss 3.5318663239479067


100%|██████████| 100/100 [36:41<00:00, 22.02s/it]

epoch 99, loss 3.576762342453003





In [None]:
# TODO: GPU utilization is low during training; this needs investigation.

In [None]:
# Hand the current agent to the project helper `save_leaves` under a descriptive name.
# NOTE(review): the name says "500_random_steps", but the visible cells collect
# ntraj=1000 rollouts and run train_steps=100 updates — confirm the name is still accurate.
save_leaves(mb_agent, "vanilla_trained_on_500_random_steps")

In [8]:
# Rebind `mb_agent` to whatever the project helper `load_leaves` returns for this checkpoint.
# NOTE(review): "checkpoint_161" does not match the "checkpoint/checkpoint{step}" names
# written by the training loop in this notebook (train_steps = 100) — confirm where
# this file comes from before relying on it.
mb_agent = load_leaves(mb_agent, "checkpoint_161")

(ScaleByAdamState(count=Array(164, dtype=int32), mu=NeuralODE_Vanilla(
   mlp=MLP(
     layers=(
       Linear(
         weight=f32[128,4],
         bias=f32[128],
         in_features=4,
         out_features=128,
         use_bias=True
       ),
       Linear(
         weight=f32[128,128],
         bias=f32[128],
         in_features=128,
         out_features=128,
         use_bias=True
       ),
       Linear(
         weight=f32[128,128],
         bias=f32[128],
         in_features=128,
         out_features=128,
         use_bias=True
       ),
       Linear(
         weight=f32[128,128],
         bias=f32[128],
         in_features=128,
         out_features=128,
         use_bias=True
       ),
       Linear(
         weight=f32[3,128],
         bias=f32[3],
         in_features=128,
         out_features=3,
         use_bias=True
       )
     ),
     activation=None,
     final_activation=None,
     use_bias=True,
     use_final_bias=True,
     in_size=4,
     out_size=3,
  