In [25]:
import os
import yaml
import jax
from jax import numpy as jnp
from jax_cosmo.scipy.interpolate import InterpolatedUnivariateSpline
import functools
from dial_mpc.utils.io_utils import load_dataclass_from_dict
from dial_mpc.core.dial_config import DialConfig
from dial_mpc.envs.unitree_go2_env import UnitreeGo2EnvConfig, UnitreeGo2Env

In [26]:
# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs.
# The flag is appended only when it is not already present, so re-running this
# cell does not keep growing XLA_FLAGS with duplicate entries.
# NOTE(review): XLA presumably reads XLA_FLAGS when the backend is first
# initialised, so this cell should run before any JAX computation - confirm.
_TRITON_GEMM_FLAG = "--xla_gpu_triton_gemm_any=True"
xla_flags = os.environ.get("XLA_FLAGS", "")
if _TRITON_GEMM_FLAG not in xla_flags.split():
    # strip() avoids a stray leading space when XLA_FLAGS was unset/empty.
    xla_flags = f"{xla_flags} {_TRITON_GEMM_FLAG}".strip()
os.environ["XLA_FLAGS"] = xla_flags


In [27]:
def rollout_us(step_env, state, us):
    """Roll an environment forward over a sequence of controls.

    Uses ``jax.lax.scan`` to apply ``step_env`` once per leading-axis entry
    of ``us``, threading the returned state through as the scan carry.

    Args:
        step_env: callable ``(state, u) -> state``; the returned state must
            expose ``reward`` and ``pipeline_state`` attributes.
        state: initial state passed to the first ``step_env`` call.
        us: controls, scanned along their leading axis.

    Returns:
        A ``(rewards, pipeline_states)`` pair holding the per-step
        ``state.reward`` and ``state.pipeline_state`` values stacked along
        the scan axis. The final carried state is discarded.
    """

    def advance(carry, control):
        next_state = step_env(carry, control)
        per_step = (next_state.reward, next_state.pipeline_state)
        return next_state, per_step

    _final_state, (rewards, pipeline_states) = jax.lax.scan(advance, state, us)
    return rewards, pipeline_states

class XLAEngine:
    def __init__(self, 
                 env,
                 ctrl_dt=0.02,
                 Hsample=15,
                 Hnode=5,
                 ):
        self.env = env
        self.nu = env.action_size

        # node to u
        self.ctrl_dt = ctrl_dt
        self.step_us = jnp.linspace(0, self.ctrl_dt * Hsample, Hsample)
        self.step_nodes = jnp.linspace(0, self.ctrl_dt * Hsample, Hnode)
        self.node_dt = self.ctrl_dt * (Hsample) / (Hnode)

        # setup function
        self.rollout_us = jax.jit(functools.partial(rollout_us, self.env.step))
        self.rollout_us_vmap = jax.jit(jax.vmap(self.rollout_us, in_axes=(None, 0)))
        self.node2u_vmap = jax.jit(
            jax.vmap(self.node2u, in_axes=(1), out_axes=(1))
        )  # process (horizon, node)
        self.u2node_vmap = jax.jit(jax.vmap(self.u2node, in_axes=(1), out_axes=(1)))
        self.node2u_vvmap = jax.jit(
            jax.vmap(self.node2u_vmap, in_axes=(0))
        )  # process (batch, horizon, node)
        self.u2node_vvmap = jax.jit(jax.vmap(self.u2node_vmap, in_axes=(0)))

    @functools.partial(jax.jit, static_argnums=(0,))
    def node2u(self, nodes):
        spline = InterpolatedUnivariateSpline(self.step_nodes, nodes, k=2)
        us = spline(self.step_us)
        return us

    @functools.partial(jax.jit, static_argnums=(0,))
    def u2node(self, us):
        spline = InterpolatedUnivariateSpline(self.step_us, us, k=2)
        nodes = spline(self.step_nodes)
        return nodes

    @functools.partial(jax.jit, static_argnums=(0,))
    def rollout(self, state, Y0s):
        # convert Y0s to us
        us = self.node2u_vvmap(Y0s)
        # esitimate mu_0tm1
        rewss, pipeline_statess = self.rollout_us_vmap(state, us)
        qss = pipeline_statess.q
        qdss = pipeline_statess.qd
        xss = pipeline_statess.x.pos
        info = {
            "qss": qss,
            "qdss": qdss,
            "xss": xss,
            "rewss": rewss,
        }
        return info


In [28]:

# NOTE(review): machine-specific absolute path - adjust it to your own checkout
# before running this cell on another machine.
example_cfg_path = '/home/go2-laptop/dial-mpc/dial_mpc/examples/unitree_go2_trot.yaml'
# Read the YAML inside a context manager so the file handle is always closed
# (the bare open(...) call left it open until garbage collection).
with open(example_cfg_path) as cfg_file:
    config_dict = yaml.safe_load(cfg_file)
# Build the typed DIAL configuration from the parsed dictionary.
dial_config = load_dataclass_from_dict(DialConfig, config_dict)

In [30]:
# Display the dictionary parsed from the YAML file for inspection.
config_dict

{'seed': 0,
 'output_dir': 'unitree_go2_trot',
 'n_steps': 400,
 'env_name': 'unitree_go2_walk',
 'Nsample': 2048,
 'Hsample': 16,
 'Hnode': 4,
 'Ndiffuse': 2,
 'Ndiffuse_init': 10,
 'temp_sample': 0.05,
 'horizon_diffuse_factor': 0.9,
 'traj_diffuse_factor': 0.5,
 'update_method': 'mppi',
 'dt': 0.02,
 'timestep': 0.02,
 'leg_control': 'torque',
 'action_scale': 1.0,
 'default_vx': 0.8,
 'default_vy': 0.0,
 'default_vyaw': 0.0,
 'ramp_up_time': 1.0,
 'gait': 'trot'}

In [31]:
# Seed the JAX PRNG from the loaded DIAL configuration.
rng = jax.random.PRNGKey(dial_config.seed)

# Build the environment config from the same YAML dictionary, then
# instantiate the environment from it.
env_config = load_dataclass_from_dict(
    UnitreeGo2EnvConfig,
    config_dict,
    convert_list_to_array=True,
)
env = UnitreeGo2Env(env_config)

# JIT-compiled wrappers around the environment's entry points.
step_env = jax.jit(env.step)
reset_env = jax.jit(env.reset)

In [32]:
# Only the environment is passed, so the engine runs with its own default
# ctrl_dt / Hsample / Hnode and not the Hsample / Hnode / dt values from the
# YAML config.
mbdpi = XLAEngine(env=env)

# Split off a dedicated key for the reset and keep `rng` for later use.
rng, rng_reset = jax.random.split(rng, 2)
state_init = reset_env(rng_reset)

In [33]:
# All-zero node trajectories, one per sample. The shape is derived from the
# loaded config and the engine instead of being hard-coded, so it stays valid
# if the sample count, node count or action size changes.
n_samples = config_dict["Nsample"]
n_nodes = mbdpi.step_nodes.shape[0]
Ys = jnp.zeros((n_samples, n_nodes, mbdpi.nu))

# Roll every sample out from the same initial state.
info = mbdpi.rollout(state_init, Ys)

In [34]:
# First sample (index 0), every rollout step, first three entries of q.
# NOTE(review): presumably the base x/y/z position - compare with the initial
# qpos[:3] shown in the next cell; confirm against the environment model.
info['qss'][0,:,:3]

Array([[-3.8305271e-04, -1.0669973e-07,  2.6867348e-01],
       [-1.1256181e-03, -3.4669438e-07,  2.6583236e-01],
       [-2.7895288e-03, -7.7797552e-07,  2.6300976e-01],
       [-5.4255854e-03, -1.0492379e-06,  2.6037541e-01],
       [-8.9999186e-03, -1.0691543e-06,  2.5798476e-01],
       [-1.3402898e-02, -9.0620722e-07,  2.5578323e-01],
       [-1.8490788e-02, -6.8422798e-07,  2.5366411e-01],
       [-2.4124030e-02, -4.9753572e-07,  2.5152180e-01],
       [-3.0190257e-02, -3.9491371e-07,  2.4927591e-01],
       [-3.6611184e-02, -3.8397195e-07,  2.4687603e-01],
       [-4.3337241e-02, -4.4588114e-07,  2.4429187e-01],
       [-5.0336055e-02, -5.5281714e-07,  2.4150051e-01],
       [-5.7580158e-02, -6.7653775e-07,  2.3847647e-01],
       [-6.5037802e-02, -7.9567832e-07,  2.3518856e-01],
       [-7.2668836e-02, -8.9602412e-07,  2.3160297e-01]], dtype=float32)

In [35]:
# First three qpos entries of the initial state, for comparison with the
# rolled-out q values displayed above.
state_init.pipeline_state.qpos[:3]

Array([0.  , 0.  , 0.27], dtype=float32)