In [153]:
%load_ext autoreload
%autoreload 2
%load_ext line_profiler
from cs285.envs.pendulum.pendulum_env import PendulumEnv
from cs285.envs.dt_sampler import ConstantSampler
from cs285.infrastructure.replay_buffer import ReplayBufferTrajectories
from cs285.infrastructure.utils import sample_n_trajectories, RandomPolicy
from typing import Callable, Optional, Tuple, Sequence
import numpy as np
import torch.nn as nn
import torch
import gym
from cs285.infrastructure import pytorch_util as ptu
from torchdiffeq import odeint
from tqdm import trange
import jax
import jax.numpy as jnp
import equinox as eqx
import diffrax
from diffrax import diffeqsolve, Dopri5
import optax

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
The line_profiler extension is already loaded. To reload it, use:
  %reload_ext line_profiler


In [139]:
key = jax.random.PRNGKey(0)

In [129]:
device="cuda"

In [172]:
class NeuralODE_jax(eqx.Module):
    _str_to_activation = {
        "relu": jax.nn.relu,
        "tanh": jax.nn.tanh,
        "leaky_relu": jax.nn.leaky_relu,
        "sigmoid": jax.nn.sigmoid,
        "selu": jax.nn.selu,
        "softplus": jax.nn.softplus,
        "identity": lambda x: x,
    }
    mlp: eqx.nn.MLP
    def __init__(
            self,
            hidden_size,
            num_layers,
            ob_dim,
            ac_dim,
            key,
            activation="relu",
            output_activation="identity",
        ):
        super().__init__()
        activation = self._str_to_activation[activation]
        output_activation = self._str_to_activation[output_activation]
        # hidden_size is an integer
        self.mlp = eqx.nn.MLP(in_size=ob_dim+ac_dim,
                              out_size=ob_dim,
                              width_size=hidden_size,
                              depth=num_layers,
                              activation=activation,
                              final_activation=output_activation,
                              key=key)

    @jax.jit
    def __call__(self, t, y, args):
        # args is a dictionary that contains times and actions
        times = args["times"] # (ep_len,)
        actions = args["actions"] # (ep_len, ac_dim)
        idx = jnp.searchsorted(times, t, side="right") - 1
        action = actions[idx] # (ac_dim)
        # althoug I believe this should also work for batched 
        return self.mlp(jnp.concatenate((y, action), axis=-1))
    
class ODEAgent_jax():
    def __init__(
        self,
        env: gym.Env,
        key,
        hidden_size: int,
        num_layers: int,
        ensemble_size: int,
        train_timestep: float,
        mpc_horizon_steps: int,
        mpc_timestep: float,
        mpc_strategy: str,
        mpc_num_action_sequences: int,
        cem_num_iters: Optional[int] = None,
        cem_num_elites: Optional[int] = None,
        cem_alpha: Optional[float] = None,
        activation: str = "relu",
        output_activation: str = "identity",
        lr: float=0.001
    ):
        # super().__init__()
        self.env = env
        self.train_timestep = train_timestep
        self.mpc_horizon_steps = mpc_horizon_steps # in terms of timesteps
        self.mpc_strategy = mpc_strategy
        self.mpc_num_action_sequences = mpc_num_action_sequences
        self.cem_num_iters = cem_num_iters
        self.cem_num_elites = cem_num_elites
        self.cem_alpha = cem_alpha
        self.mpc_timestep = mpc_timestep # when evaluating

        assert mpc_strategy in (
            "random",
            "cem",
        ), f"'{mpc_strategy}' is not a valid MPC strategy"

        # ensure the environment is state-based
        assert len(env.observation_space.shape) == 1
        assert len(env.action_space.shape) == 1

        self.ob_dim = env.observation_space.shape[0]
        self.ac_dim = env.action_space.shape[0]

        self.ensemble_size = ensemble_size
        keys = jax.random.split(key, ensemble_size)
        self.ode_functions = [NeuralODE_jax(
            hidden_size=hidden_size,
            num_layers=num_layers,
            ob_dim=self.ob_dim,
            ac_dim=self.ac_dim,
            activation=activation,
            output_activation=output_activation,
            key = keys[n]
            ) for n in range(ensemble_size)]
        self.optims = [optax.adamw(lr) for _ in range(ensemble_size)]
        self.optim_states = [self.optims[n].init(eqx.filter(self.ode_functions[n].mlp, eqx.is_array)) for n in range(self.ensemble_size)]

        self.solver = Dopri5()
    
    def update(self, i: int, obs: np.ndarray, acs: np.ndarray, times: np.ndarray):
        """
        Update self.dynamics_models[i] using the given trajectory

        Args:
            i: index of the dynamics model to update
            obs: (ep_len, ob_dim)
            acs: (ep_len, ac_dim)
            times: (ep_len)
        """
        @eqx.filter_value_and_grad
        def loss_grad(ode_func):
            sol = diffeqsolve(
                diffrax.ODETerm(ode_func), 
                self.solver, 
                t0=times[0], 
                t1=times[-1], 
                dt0=self.train_timestep,
                y0 = obs[0, :],
                args={"times": times, "actions": acs},
                saveat=diffrax.SaveAt(ts=times)
            )
            assert sol.ys.shape == obs.shape
            return jnp.mean((sol.ys - obs) ** 2) # do we want a  "discount"-like trick

        @eqx.filter_jit
        def make_step(ode_func, optim, opt_state):
            loss, grad = loss_grad(ode_func)
            updates, opt_state = optim.update(grad, opt_state)
            ode_func = eqx.apply_updates(ode_func, updates)
            return loss, ode_func, opt_state
        
        ode_func, optim, opt_state = self.ode_functions[i], self.optims[i], self.optim_states[i]
        loss, ode_func, opt_state = make_step(ode_func, optim, opt_state)
        self.ode_functions[i], self.optim_states[i] = ode_func, opt_state
        return loss.item()

In [173]:
dt_sampler = ConstantSampler(dt=0.05)
env = PendulumEnv(
    dt_sampler=dt_sampler
)
mb_agent_jas = ODEAgent_jax(
    env=env,
    hidden_size=128,
    num_layers=4,
    ensemble_size=10,
    train_timestep=0.005,
    mpc_horizon_steps=100,
    mpc_timestep=0.005,
    mpc_strategy="random",
    mpc_num_action_sequences=10,
    key=key
)
replay_buffer = ReplayBufferTrajectories(seed=0)
trajs, _ = sample_n_trajectories(env, RandomPolicy(env=env), ntraj=10, max_length=200)
replay_buffer.add_rollouts(trajs)

for n in trange(1000):
    for i in range(mb_agent_jas.ensemble_size):
        traj = replay_buffer.sample_rollout()
        mb_agent_jas.update(i, traj["observations"], traj["actions"], jnp.cumsum(traj["dts"]))

100%|██████████| 10/10 [00:00<00:00, 55.33it/s]
  0%|          | 0/1000 [00:00<?, ?it/s]


TypeError: Cannot interpret value of type <class 'jax._src.custom_derivatives.custom_jvp'> as an abstract array; it does not have a dtype attribute

In [None]:
class NeuralODE(nn.Module):
    _str_to_activation = {
        "relu": nn.ReLU(),
        "tanh": nn.Tanh(),
        "leaky_relu": nn.LeakyReLU(),
        "sigmoid": nn.Sigmoid(),
        "selu": nn.SELU(),
        "softplus": nn.Softplus(),
        "identity": nn.Identity(),
    }
    def __init__(self, hidden_dims, ob_dim, ac_dim, activation="relu", output_activation='identity'):
        super().__init__()
        self.ac_dim = ac_dim
        self.ob_dim = ob_dim
        activation = self._str_to_activation[activation]
        output_activation = self._str_to_activation[output_activation]
        layers = []
        hidden_dims = [ob_dim + ac_dim] + hidden_dims
        for n in range(len(hidden_dims) - 1):
            layers.append(nn.Linear(hidden_dims[n], hidden_dims[n+1]))
            layers.append(activation)
        layers.append(nn.Linear(hidden_dims[-1], ob_dim))
        layers.append(output_activation)
        self.net = nn.Sequential(*layers)

        for m in self.net.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=0.1)
                nn.init.constant_(m.bias, val=0)
    
    def update_action(self, actions: torch.Tensor, times: torch.Tensor):
        ep_len = actions.shape[0]
        assert actions.shape == (ep_len, self.ac_dim) and times.shape == (ep_len,)
        # times = times - times[0] # start with t=0
        # right now, do not assume t0 = 0
        self.register_buffer("times", times)
        self.register_buffer("actions", actions)

    def _get_action(self, t):
        idx = torch.searchsorted(self.times, t, right=True) - 1
        return self.actions[idx]

    def forward(self, t, y):
        ac = self._get_action(t)
        return self.net(torch.cat((y, ac), dim=-1))

    
class ODEAgent(nn.Module):
    def __init__(
        self,
        env: gym.Env,
        hidden_dims: Sequence[int],
        make_optimizer: Callable[[nn.ParameterList], torch.optim.Optimizer],
        ensemble_size: int,
        mpc_horizon_steps: int,
        mpc_timestep: float,
        mpc_strategy: str,
        mpc_num_action_sequences: int,
        cem_num_iters: Optional[int] = None,
        cem_num_elites: Optional[int] = None,
        cem_alpha: Optional[float] = None,
        activation: str = "relu",
        output_activation: str = "identity"
    ):
        super().__init__()
        self.env = env
        self.mpc_horizon_steps = mpc_horizon_steps # in terms of timesteps
        self.mpc_strategy = mpc_strategy
        self.mpc_num_action_sequences = mpc_num_action_sequences
        self.cem_num_iters = cem_num_iters
        self.cem_num_elites = cem_num_elites
        self.cem_alpha = cem_alpha
        self.mpc_timestep = mpc_timestep # when evaluating

        assert mpc_strategy in (
            "random",
            "cem",
        ), f"'{mpc_strategy}' is not a valid MPC strategy"

        # ensure the environment is state-based
        assert len(env.observation_space.shape) == 1
        assert len(env.action_space.shape) == 1

        self.ob_dim = env.observation_space.shape[0]
        self.ac_dim = env.action_space.shape[0]

        self.ensemble_size = ensemble_size
        self.ode_functions = nn.ModuleList(
            [
                NeuralODE(
                    hidden_dims,
                    self.ob_dim,
                    self.ac_dim,
                    activation,
                    output_activation
                ).to(ptu.device)
                for _ in range(ensemble_size)
            ]
        )
        self.optimizer = make_optimizer(self.ode_functions.parameters())
        self.loss_fn = nn.MSELoss()

    def update(self, i: int, obs: np.ndarray, acs: np.ndarray, times: np.ndarray):
        """
        Update self.dynamics_models[i] using the given trajectory

        Args:
            i: index of the dynamics model to update
            obs: (ep_len, ob_dim)
            acs: (ep_len, ac_dim)
            times: (ep_len)
        """
        obs = ptu.from_numpy(obs)
        acs = ptu.from_numpy(acs)
        times = ptu.from_numpy(times)
        ode_func = self.ode_functions[i]
        ode_func.update_action(acs, times)
        ode_out = odeint(ode_func, obs[0, :], times) # t0 = times[0] in torchdiffeq
        # possible problem: the ode function is only "evaluating" on times
        # I am not sure whether there is an implicit dt or dt[i] = times[i+1] - times[i]
        # I know for diffrax in jax, there is a separate dt argument passed into odeint()
        assert ode_out.shape == obs.shape
        loss = self.loss_fn(ode_out, obs)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return ptu.to_numpy(loss)
    
    def update_statistics(self, **kwargs):
        pass

    @torch.no_grad()
    def evaluate_action_sequences(self, obs: np.ndarray, acs: np.ndarray):
        obs = ptu.from_numpy(obs) # (ob_dim)
        acs_np = acs
        acs = ptu.from_numpy(acs) # (N, steps, ac_dim)
        times = torch.linspace(0, (self.mpc_horizon_steps - 1) * self.mpc_timestep, self.mpc_horizon_steps, device=ptu.device)
        reward_arr = np.zeros((self.mpc_num_action_sequences, self.ensemble_size))
        for n in range(self.mpc_num_action_sequences):
            for i in range(self.ensemble_size):
                ode_func = self.ode_functions[i]
                ode_func.update_action(acs[n, :, :], times)
                ode_out = odeint(ode_func, obs, times) # (steps, ob_dim)
                rewards, _ = self.env.get_reward(ptu.to_numpy(ode_out), acs_np[n, :, :])
                avg_reward = np.mean(rewards)
                reward_arr[n, i] = avg_reward
        return np.mean(reward_arr, axis=1)
    # maybe I should manually implement batched Euler solver
    # to make inference faster

    @torch.no_grad()
    def get_action(self, obs: np.ndarray):
        """
        Choose the best action using model-predictive control.

        Args:
            obs: (ob_dim,)
        """
        # always start with uniformly random actions
        actions = np.random.uniform(
            self.env.action_space.low,
            self.env.action_space.high,
            size=(self.mpc_num_action_sequences, self.mpc_horizon_steps, self.ac_dim),
        )

        if self.mpc_strategy == "random":
            # evaluate each action sequence and return the best one
            rewards = self.evaluate_action_sequences(obs, actions)
            assert rewards.shape == (self.mpc_num_action_sequences,)
            best_index = np.argmax(rewards)
            return actions[best_index, 0, :]
        elif self.mpc_strategy == "cem":
            raise NotImplementedError
        else:
            raise ValueError(f"Invalid MPC strategy '{self.mpc_strategy}'")

In [101]:
dt_sampler = ConstantSampler(dt=0.05)
env = PendulumEnv(
    dt_sampler=dt_sampler
)
mb_agent = ODEAgent(
    env=env,
    hidden_dims=[128, 128, 128],
    make_optimizer=lambda param_list: torch.optim.AdamW(param_list),
    ensemble_size=10,
    mpc_horizon_steps=100,
    mpc_timestep=0.005,
    mpc_strategy="random",
    mpc_num_action_sequences=10,
)

In [102]:
ob = env.reset()

In [114]:
actions = np.random.uniform(
    mb_agent.env.action_space.low,
    mb_agent.env.action_space.high,
    size=(mb_agent.mpc_num_action_sequences, mb_agent.mpc_horizon_steps, mb_agent.ac_dim),
)

In [115]:
%lprun -f evaluate_action_sequences evaluate_action_sequences(mb_agent, ob, actions)

Timer unit: 1e-09 s

Total time: 82.8017 s
File: /tmp/ipykernel_186605/3267144233.py
Function: evaluate_action_sequences at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
     1                                           def evaluate_action_sequences(agent, obs: np.ndarray, acs: np.ndarray):
     2         1      13825.0  13825.0      0.0      with torch.no_grad():
     3         1      24795.0  24795.0      0.0          obs = ptu.from_numpy(obs)
     4         1        116.0    116.0      0.0          acs_np = acs
     5         1      83303.0  83303.0      0.0          acs = ptu.from_numpy(acs)
     6         1      24536.0  24536.0      0.0          times = torch.linspace(0, (agent.mpc_horizon_steps - 1) * agent.mpc_timestep, agent.mpc_horizon_steps, device=ptu.device)
     7         1       6850.0   6850.0      0.0          reward_arr = np.zeros((agent.mpc_num_action_sequences, agent.ensemble_size))
     8        11       2248.0    204.4      0.0          for