In [1]:
%load_ext autoreload
%autoreload 2
%load_ext line_profiler
from cs285.envs.pendulum.pendulum_env import PendulumEnv
from cs285.envs.dt_sampler import ConstantSampler, UniformSampler, ExponentialSampler
from cs285.infrastructure.replay_buffer import ReplayBufferTrajectories
from cs285.infrastructure.utils import sample_n_trajectories, RandomPolicy
from cs285.agents.ode_agent import ODEAgent
from cs285.agents.nueral_ode import Base_NeuralODE, NeuralODE_Vanilla, Pendulum_True_Dynamics, NeuralODE_Augmented, NeuralODE_Latent_MLP, ODE_RNN
from cs285.agents.utils import save_leaves, load_leaves
from cs285.infrastructure import utils
from typing import Callable, Optional, Tuple, Sequence
import numpy as np
import gym
from cs285.infrastructure import pytorch_util as ptu
from tqdm import trange
import jax
import jax.numpy as jnp
import equinox as eqx
import diffrax
from diffrax import diffeqsolve, Dopri5
import optax
import pickle
from tqdm import trange
import matplotlib.pyplot as plt

In [2]:
# Root PRNG key for the notebook (fixed seed 0). Later cells derive sub-keys
# from it with jax.random.split and rebind `key` to the remainder.
key = jax.random.PRNGKey(0)

In [3]:
def train(agent, i, replay_buffer, train_config, key):
    """Fit the ``i``-th neural ODE of ``agent`` on windows sampled from ``replay_buffer``.

    Each step samples rollouts, cuts them into windows of ``ep_len`` steps,
    predicts every window from its first observation, and minimises the
    discount-weighted squared error between predicted and recorded observations.

    Args:
        agent: Object exposing the indexable attributes ``neural_odes``,
            ``optims`` and ``optim_states``; entry ``i`` of each is read and
            then overwritten with the trained result.
        i: Index of the ensemble member to train.
        replay_buffer: Object providing ``sample_rollouts(batch_size=..., key=...)``
            that returns a mapping with ``"observations"``, ``"actions"`` and
            ``"dts"`` entries.
        train_config: Mapping with keys ``"batch_size"``, ``"steps"``,
            ``"ep_len"``, ``"stride"`` and ``"discount"``.
        key: JAX PRNG key; split once per step to draw the batch.

    Returns:
        Tuple ``(agent, losses)`` where ``agent`` is the same object that was
        passed in (mutated in place) and ``losses`` is a list with one Python
        float per training step.

    Side effects:
        Plots the loss curve on the current matplotlib axes. The agent is only
        written back after the loop finishes, so an interrupted run leaves it
        unchanged.
    """
    optim = agent.optims[i]
    opt_state = agent.optim_states[i]
    neural_ode = agent.neural_odes[i]

    ep_len = train_config["ep_len"]
    stride = train_config["stride"]
    batch_size = train_config["batch_size"]
    # Per-time-step weights discount**t for t = 0 .. ep_len-1.
    discount_array = train_config["discount"] ** jnp.arange(ep_len)

    @eqx.filter_jit
    @eqx.filter_value_and_grad
    def get_loss_grad(neural_ode, obs, acs, times):
        # Roll the model forward from the first observation of every window.
        obs_pred = neural_ode.batched_pred(ob=obs[:, 0, :], acs=acs, times=times)
        l2_losses = jnp.sum((obs - obs_pred) ** 2, axis=-1)  # (batch_size, ep_len)
        return jnp.mean(discount_array * l2_losses)

    def split_windows(arr):
        # Cut every trajectory into windows of ep_len steps.
        return utils.split_arr(arr, length=ep_len, stride=stride)

    def get_data(sample_key):
        traj = replay_buffer.sample_rollouts(batch_size=batch_size, key=sample_key)
        obs = split_windows(np.array(traj["observations"]))
        acs = split_windows(np.array(traj["actions"]))
        # dts gets a trailing axis for the split and loses it again afterwards.
        dts = split_windows(np.array(traj["dts"])[..., np.newaxis]).squeeze(-1)

        n_traj, num_splitted, train_ep_len, ob_dim = obs.shape
        ac_dim = acs.shape[-1]
        n_windows = n_traj * num_splitted
        # Flatten (trajectory, window) into one batch axis.
        obs = jnp.array(obs).reshape(n_windows, train_ep_len, ob_dim)
        acs = jnp.array(acs).reshape(n_windows, train_ep_len, ac_dim)
        # Time stamps within a window are the running sum of its step sizes.
        times = jnp.cumsum(dts, axis=-1).reshape(n_windows, train_ep_len)
        return obs, acs, times

    losses = []
    for _ in trange(train_config["steps"]):
        sample_key, key = jax.random.split(key)
        obs, acs, times = get_data(sample_key)
        loss, grad = get_loss_grad(neural_ode, obs, acs, times)
        # Optimisers that read the parameters (e.g. adamw weight decay) need a
        # params tree matching the gradient tree: hand over only the array
        # leaves, not static or callable fields of the module.
        params = eqx.filter(neural_ode, eqx.is_array)
        updates, opt_state = optim.update(grad, opt_state, params)
        neural_ode = eqx.apply_updates(neural_ode, updates)
        losses.append(loss.item())

    plt.plot(np.arange(len(losses)), losses)
    plt.xlabel("training step")
    plt.ylabel("discount-weighted squared error")

    # Write the trained state back so the caller's agent reflects this run.
    agent.neural_odes[i] = neural_ode
    agent.optims[i] = optim
    agent.optim_states[i] = opt_state
    return agent, losses
    

In [4]:
# --- Environment ------------------------------------------------------------
# Two separate samplers, both built with the constant 0.05: one is handed to
# the environment, the other to the agent's MPC planner.
# NOTE(review): presumably 0.05 is the step size in seconds — confirm in dt_sampler.
dt_sampler = ConstantSampler(0.05)
mpc_dt_sampler = ConstantSampler(0.05)
env = PendulumEnv(dt_sampler=dt_sampler)

# --- Agent ------------------------------------------------------------------
# Split off a dedicated key for agent construction; `key` keeps the remainder.
agent_key, key = jax.random.split(key)
neural_ode_name = "vanilla"
neural_ode_kwargs = {
    "ode_dt0": 0.05,  # same value as the samplers above; presumably the solver's initial step — TODO confirm
    "mlp_dynamics_setup": {
        "hidden_size": 128,
        "num_layers": 4,
        "activation": "tanh",
        "output_activation": "identity",
    }
}
optimizer_name = "adamw"
optimizer_kwargs = {"learning_rate": 1e-3}
# NOTE(review): train() in this notebook takes its discount from
# train_config["discount"], not from the train_discount passed here.
mb_agent = ODEAgent(
    env=env,
    key=agent_key,
    neural_ode_name=neural_ode_name,
    neural_ode_kwargs=neural_ode_kwargs,
    optimizer_name=optimizer_name,
    optimizer_kwargs=optimizer_kwargs,
    ensemble_size=1,  # train() below is only called for member index 0
    train_discount=1,
    mpc_horizon_steps=20,
    mpc_dt_sampler=mpc_dt_sampler,
    mpc_strategy="random",
    mpc_discount=0.95,
    mpc_num_action_sequences=1000,
    cem_num_iters=4,
    cem_num_elites=5,
    cem_alpha=1,
)

# --- Replay buffer ----------------------------------------------------------
# Loaded from a path relative to the current working directory.
# SECURITY: pickle.load can execute arbitrary code embedded in the file; only
# load buffers that you produced yourself or otherwise trust.
# NOTE(review): file name suggests 1000 random-policy rollouts — confirm how it was generated.
with open("1000_random_replay_buffer", "rb") as f:
    replay_buffer = pickle.load(f)

In [5]:
# Hyper-parameters read by train(); kept in one mapping so they are easy to tune.
train_config = dict(
    batch_size=64,  # rollouts requested from the replay buffer per step
    steps=1000,     # number of optimisation steps
    ep_len=20,      # length of each training window
    stride=1,       # stride used when cutting rollouts into windows
    discount=1.0,   # base of the per-time-step loss weights
)

In [6]:
# Dedicated key for this training run; `key` keeps the remainder for later cells.
train_key, key = jax.random.split(key)
# Bind the result instead of leaving the call as the cell's last expression:
# otherwise the notebook echoes the whole (agent, losses) tuple, including the
# full per-step loss list. train() returns the same agent object it was given.
mb_agent, train_losses = train(mb_agent, 0, replay_buffer, train_config, key=train_key)

  0%|          | 0/1000 [01:24<?, ?it/s]


KeyboardInterrupt: 