In [1]:
import os
import yaml
import jax
from jax import numpy as jnp
from jax_cosmo.scipy.interpolate import InterpolatedUnivariateSpline
import functools
from dial_mpc.utils.io_utils import load_dataclass_from_dict
from dial_mpc.core.dial_config import DialConfig
import jax
import jax.numpy as jnp
from functools import partial
from brax.base import System
from brax.envs.base import PipelineEnv
from dataclasses import dataclass

In [2]:
@dataclass
class BaseEnvConfig:
    """Settings shared by every environment derived from ``BaseEnv``.

    ``BaseEnv.__init__`` derives the number of simulator sub-steps per
    environment step from ``dt`` and ``timestep`` and forwards ``backend`` and
    ``debug`` to the ``PipelineEnv`` constructor. The PD gains and
    ``action_scale`` are read by ``BaseEnv.act2tau`` / ``BaseEnv.act2joint``.
    """

    # Task identifier. Not read by any code visible in this notebook.
    task_name: str = "default"
    # Whether to randomize the task. Stored in the state info under
    # "randomize_target" by UnitreeGo2Env.reset.
    randomize_tasks: bool = False
    # P gain, or a list of P gains for each joint.
    kp: float = 30.0
    # D gain, or a list of D gains for each joint.
    kd: float = 1.0
    # Forwarded to the PipelineEnv constructor as its debug argument.
    debug: bool = False
    # dt of the environment step, not the underlying simulator step.
    dt: float = 0.02
    # Timestep of the underlying simulator step. The user is responsible for
    # making sure it matches their model.
    timestep: float = 0.02
    # Backend of the environment, forwarded to the PipelineEnv constructor.
    backend: str = "mjx"
    # Control method for the joints, either "torque" or "position".
    leg_control: str = "torque"
    # Scale applied to the raw action before it is mapped to joint targets.
    action_scale: float = 1.0


class BaseEnv(PipelineEnv):
    """Common base for environments driven by normalized joint actions.

    Subclasses implement ``make_system`` to provide the simulated system. This
    class derives the number of simulator sub-steps per environment step from
    the config, caches joint/torque limits, and offers helpers that convert a
    normalized action into joint position targets (``act2joint``) or into
    PD torques (``act2tau``).
    """

    def __init__(self, config: BaseEnvConfig):
        """Build the system and initialise the underlying ``PipelineEnv``.

        Args:
            config: environment settings; ``config.dt`` must be a positive
                integer multiple of ``config.timestep``.

        Raises:
            AssertionError: if ``dt`` is not (within rounding error) a positive
                integer multiple of ``timestep``.
        """
        # Float `%` and truncating `int()` are unreliable here: for instance
        # 0.3 % 0.1 != 0 and int(0.3 / 0.1) == 2. Round the ratio and compare
        # with a small relative tolerance instead.
        frames_ratio = config.dt / config.timestep
        n_frames = int(round(frames_ratio))
        assert (
            n_frames >= 1 and abs(frames_ratio - n_frames) <= 1e-6 * n_frames
        ), "dt must be a positive integer multiple of timestep"
        self._config = config
        sys = self.make_system(config)
        super().__init__(sys, config.backend, n_frames, config.debug)

        # joint limit definitions
        # The first row of jnt_range is skipped (presumably the floating-base
        # joint -- confirm against the model).
        self.physical_joint_range = self.sys.jnt_range[1:]
        # Subclasses may overwrite joint_range with a tighter action range.
        self.joint_range = self.physical_joint_range
        self.joint_torque_range = self.sys.actuator_ctrlrange

        # number of everything
        self._nv = self.sys.nv
        self._nq = self.sys.nq

    def make_system(self, config: BaseEnvConfig) -> System:
        """
        Make the system for the environment. Called in BaseEnv.__init__.

        Subclasses must override this; the base implementation always raises
        NotImplementedError.
        """
        raise NotImplementedError

    @partial(jax.jit, static_argnums=(0,))
    def act2joint(self, act: jax.Array) -> jax.Array:
        """Map a normalized action to joint position targets.

        The action is multiplied by ``action_scale``, shifted from a nominal
        [-1, 1] to [0, 1], interpolated linearly across ``joint_range`` and
        finally clipped to ``physical_joint_range``.

        Args:
            act: action array; must broadcast against the rows of
                ``joint_range``.

        Returns:
            Joint position targets, one per row of ``joint_range``.
        """
        act_normalized = (
            act * self._config.action_scale + 1.0
        ) / 2.0  # normalize to [0, 1]
        joint_targets = self.joint_range[:, 0] + act_normalized * (
            self.joint_range[:, 1] - self.joint_range[:, 0]
        )  # scale to joint range
        joint_targets = jnp.clip(
            joint_targets,
            self.physical_joint_range[:, 0],
            self.physical_joint_range[:, 1],
        )
        return joint_targets

    @partial(jax.jit, static_argnums=(0,))
    def act2tau(self, act: jax.Array, pipline_state) -> jax.Array:
        """Convert a normalized action to joint torques with a PD law.

        Computes ``kp * (target - q) - kd * qd`` where ``target`` comes from
        ``act2joint`` and clips the result to ``joint_torque_range``.

        Args:
            act: normalized action, as accepted by ``act2joint``.
            pipline_state: simulator state exposing ``qpos`` and ``qvel``.

        Returns:
            Clipped joint torques.
        """
        joint_target = self.act2joint(act)

        # Skip the leading 7 position / 6 velocity entries (presumably the
        # floating base pose and twist -- confirm) and keep only as many
        # joints as there are targets.
        q = pipline_state.qpos[7:]
        q = q[: len(joint_target)]
        qd = pipline_state.qvel[6:]
        qd = qd[: len(joint_target)]
        q_err = joint_target - q
        tau = self._config.kp * q_err - self._config.kd * qd

        tau = jnp.clip(
            tau, self.joint_torque_range[:, 0], self.joint_torque_range[:, 1]
        )
        return tau

In [3]:
from dataclasses import dataclass
from typing import Any, Union
import jax
import jax.numpy as jnp
from brax import math
import brax.base as base
from brax.base import System
from brax.envs.base import State
from brax.io import mjcf
import mujoco

from dial_mpc.utils.io_utils import get_model_path


@dataclass
class UnitreeGo2EnvConfig(BaseEnvConfig):
    """Configuration for ``UnitreeGo2Env``.

    Extends ``BaseEnvConfig``: the PD gains may also be arrays, and the
    default D gain is lowered to 0.0. The velocity, ramp-up and gait fields
    are not read by the ``UnitreeGo2Env`` defined in this notebook (presumably
    they are consumed by the full task environment -- confirm).
    """

    # P gain: a scalar or an array of gains.
    kp: Union[float, jax.Array] = 30.0
    # D gain: a scalar or an array of gains.
    kd: Union[float, jax.Array] = 0.0
    # Default commanded velocities (presumably forward / lateral / yaw rate --
    # confirm units and frame).
    default_vx: float = 1.0
    default_vy: float = 0.0
    default_vyaw: float = 0.0
    # Ramp-up duration (presumably seconds -- confirm).
    ramp_up_time: float = 2.0
    # Name of the gait.
    gait: str = "trot"


class UnitreeGo2Env(BaseEnv):
    """Unitree Go2 environment with a stripped-down step function.

    The environment loads the ``unitree_go2`` model, starts from the "home"
    keyframe, and on each step only advances the physics, refreshes the
    observation and evaluates the termination condition. No task reward is
    computed: ``step`` leaves ``state.reward`` untouched.
    """

    def __init__(self, config: UnitreeGo2EnvConfig):
        """Initialise the base env and cache model-specific lookups.

        Args:
            config: environment settings.

        Raises:
            AssertionError: if any of the four foot sites is missing from the
                model.
        """
        super().__init__(config)

        self._foot_radius = 0.0175
        self._torso_idx = mujoco.mj_name2id(
            self.sys.mj_model, mujoco.mjtObj.mjOBJ_BODY.value, "base"
        )

        # Initial configuration and nominal joint pose from the "home" keyframe.
        self._init_q = jnp.array(self.sys.mj_model.keyframe("home").qpos)
        self._default_pose = self.sys.mj_model.keyframe("home").qpos[7:]

        # Action range per joint as [lower, upper]; this overrides the range
        # taken from the model in BaseEnv.__init__. Rows come in groups of
        # three per leg (presumably hip / thigh / calf in the same leg order
        # as the model's joints -- confirm).
        self.joint_range = jnp.array(
            [
                [-0.5, 0.5],
                [0.4, 1.4],
                [-2.3, -0.85],
                [-0.5, 0.5],
                [0.4, 1.4],
                [-2.3, -0.85],
                [-0.5, 0.5],
                [0.4, 1.4],
                [-2.3, -1.3],
                [-0.5, 0.5],
                [0.4, 1.4],
                [-2.3, -1.3],
            ]
        )
        feet_site = [
            "FL_foot",
            "FR_foot",
            "RL_foot",
            "RR_foot",
        ]
        feet_site_id = [
            mujoco.mj_name2id(self.sys.mj_model, mujoco.mjtObj.mjOBJ_SITE.value, f)
            for f in feet_site
        ]
        # mj_name2id reports a missing name as -1.
        assert not any(id_ == -1 for id_ in feet_site_id), "Site not found."
        self._feet_site_id = jnp.array(feet_site_id)

    def make_system(self, config: UnitreeGo2EnvConfig) -> System:
        """Load the Go2 model and set its simulator timestep from the config."""
        model_path = get_model_path("unitree_go2", "mjx_scene_force.xml")
        sys = mjcf.load(model_path)
        sys = sys.tree_replace({"opt.timestep": config.timestep})
        return sys

    def reset(self) -> State:  # pytype: disable=signature-mismatch
        """Return the initial state: "home" keyframe pose with zero velocity."""
        pipeline_state = self.pipeline_init(self._init_q, jnp.zeros(self._nv))

        state_info = {
            "pos_tar": jnp.array([0.282, 0.0, 0.3]),
            "vel_tar": jnp.array([0.0, 0.0, 0.0]),
            "ang_vel_tar": jnp.array([0.0, 0.0, 0.0]),
            "yaw_tar": 0.0,
            "step": 0,
            "z_feet": jnp.zeros(4),
            "z_feet_tar": jnp.zeros(4),
            "randomize_target": self._config.randomize_tasks,
            "last_contact": jnp.zeros(4, dtype=jnp.bool),
            "feet_air_time": jnp.zeros(4),
        }

        obs = self._get_obs(pipeline_state, state_info)
        reward, done = jnp.zeros(2)
        metrics = {}
        state = State(pipeline_state, obs, reward, done, metrics, state_info)
        return state

    def step(self, state: State, action: jax.Array) -> State:
        """Advance the simulation by one environment step.

        Args:
            state: current environment state.
            action: normalized action, as accepted by ``act2joint``.

        Returns:
            The state with ``pipeline_state``, ``obs`` and ``done`` replaced.
            ``reward``, ``metrics`` and ``info`` are carried over unchanged.

        Raises:
            ValueError: if ``config.leg_control`` is neither "position" nor
                "torque".
        """
        # physics step
        leg_control = self._config.leg_control
        if leg_control == "position":
            ctrl = self.act2joint(action)
        elif leg_control == "torque":
            ctrl = self.act2tau(action, state.pipeline_state)
        else:
            # Without this branch `ctrl` would be unbound below.
            raise ValueError(
                f"Unsupported leg_control {leg_control!r}; "
                "expected 'position' or 'torque'."
            )
        pipeline_state = self.pipeline_step(state.pipeline_state, ctrl)
        x = pipeline_state.x

        # observation data
        obs = self._get_obs(pipeline_state, state.info)
        up = jnp.array([0.0, 0.0, 1.0])
        joint_angles = pipeline_state.q[7:]
        # Terminate when the torso's up axis points below the horizon, a joint
        # leaves its action range, or the torso height drops under 0.18.
        # The `- 1` offset presumably accounts for the world body being absent
        # from `x` -- confirm.
        done = jnp.dot(math.rotate(up, x.rot[self._torso_idx - 1]), up) < 0
        done |= jnp.any(joint_angles < self.joint_range[:, 0])
        done |= jnp.any(joint_angles > self.joint_range[:, 1])
        done |= pipeline_state.x.pos[self._torso_idx - 1, 2] < 0.18
        done = done.astype(jnp.float32)
        state = state.replace(
            pipeline_state=pipeline_state, obs=obs, done=done
        )
        return state

    def _get_obs(
        self,
        pipeline_state: base.State,
        state_info: dict[str, Any],
    ) -> jax.Array:
        """Concatenate velocity targets, controls, positions and velocities.

        Order: ``vel_tar``, ``ang_vel_tar``, ``ctrl``, ``qpos``, ``qvel``.
        """
        obs = jnp.concatenate(
            [
                state_info["vel_tar"],
                state_info["ang_vel_tar"],
                pipeline_state.ctrl,
                pipeline_state.qpos,
                pipeline_state.qvel,
            ]
        )
        return obs


In [4]:
# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs.
# Only append the flag when it is missing so that re-running this cell does
# not keep growing XLA_FLAGS with duplicates.
_TRITON_GEMM_FLAG = "--xla_gpu_triton_gemm_any=True"
xla_flags = os.environ.get("XLA_FLAGS", "")
if _TRITON_GEMM_FLAG not in xla_flags.split():
    xla_flags += " " + _TRITON_GEMM_FLAG
os.environ["XLA_FLAGS"] = xla_flags


In [5]:
def rollout_us(step_env, state, us):
    """Roll the environment forward along a sequence of controls.

    Args:
        step_env: callable ``(state, u) -> next_state``; the returned state
            must expose ``reward`` and ``pipeline_state``.
        state: state to start the rollout from.
        us: controls, scanned along their leading axis.

    Returns:
        ``(rews, pipeline_states)``: the per-step rewards and pipeline states,
        stacked along a new leading axis by ``jax.lax.scan``.
    """

    def _advance(carry, control):
        nxt = step_env(carry, control)
        return nxt, (nxt.reward, nxt.pipeline_state)

    # The final carry is not needed; only the stacked per-step outputs are.
    _final_state, stacked = jax.lax.scan(_advance, state, us)
    rewards, stacked_pipeline_states = stacked
    return rewards, stacked_pipeline_states

class XLAEngine:
    """Batched rollout helper that parameterises controls by spline nodes.

    A control trajectory is described by a few nodes on ``step_nodes``; a
    quadratic spline through them is sampled on ``step_us`` to obtain the
    per-step controls that are fed to the environment.
    """

    def __init__(self, env, ctrl_dt=0.02, Hsample=15, Hnode=5):
        """Set up the time grids and the jitted / vmapped helper functions.

        Args:
            env: environment providing ``step`` and ``action_size``.
            ctrl_dt: control period used to size the time horizon.
            Hsample: number of points on the control grid ``step_us``.
            Hnode: number of points on the node grid ``step_nodes``.
        """
        self.env = env
        self.nu = env.action_size

        # Time grids: both span [0, ctrl_dt * Hsample].
        self.ctrl_dt = ctrl_dt
        horizon = self.ctrl_dt * Hsample
        self.step_us = jnp.linspace(0, horizon, Hsample)
        self.step_nodes = jnp.linspace(0, horizon, Hnode)
        # NOTE(review): a linspace with Hnode points has spacing
        # horizon / (Hnode - 1), so node_dt is not the actual node spacing.
        # It is not used elsewhere in this class -- confirm the convention.
        self.node_dt = horizon / Hnode

        # Rollout of one control sequence, and of a batch of sequences that
        # all start from the same state.
        single_rollout = jax.jit(functools.partial(rollout_us, self.env.step))
        self.rollout_us = single_rollout
        self.rollout_us_vmap = jax.jit(
            jax.vmap(single_rollout, in_axes=(None, 0))
        )

        # Spline conversions mapped over axis 1 of a 2-D input ...
        self.node2u_vmap = jax.jit(jax.vmap(self.node2u, in_axes=1, out_axes=1))
        self.u2node_vmap = jax.jit(jax.vmap(self.u2node, in_axes=1, out_axes=1))
        # ... and additionally over a leading batch axis of a 3-D input.
        self.node2u_vvmap = jax.jit(jax.vmap(self.node2u_vmap, in_axes=0))
        self.u2node_vvmap = jax.jit(jax.vmap(self.u2node_vmap, in_axes=0))

    @functools.partial(jax.jit, static_argnums=(0,))
    def node2u(self, nodes):
        """Sample the quadratic spline through ``nodes`` on ``step_us``."""
        return InterpolatedUnivariateSpline(self.step_nodes, nodes, k=2)(
            self.step_us
        )

    @functools.partial(jax.jit, static_argnums=(0,))
    def u2node(self, us):
        """Sample the quadratic spline through ``us`` on ``step_nodes``."""
        return InterpolatedUnivariateSpline(self.step_us, us, k=2)(
            self.step_nodes
        )

    @functools.partial(jax.jit, static_argnums=(0,))
    def rollout(self, state, Y0s):
        """Roll out a batch of node trajectories from a common state.

        Args:
            state: environment state every rollout starts from.
            Y0s: batch of node trajectories, converted to controls with
                ``node2u_vvmap``.

        Returns:
            Dict with the batched generalized positions ("qss"), generalized
            velocities ("qdss"), body positions ("xss") and rewards ("rewss").
        """
        # Nodes -> per-step controls for every sample in the batch.
        controls = self.node2u_vvmap(Y0s)
        # Simulate every control sequence from the same initial state.
        rewards, trajectories = self.rollout_us_vmap(state, controls)
        return {
            "qss": trajectories.q,
            "qdss": trajectories.qd,
            "xss": trajectories.x.pos,
            "rewss": rewards,
        }


In [6]:

# NOTE(review): machine-specific absolute path; point this at your own
# checkout of the example configuration before running.
example_cfg_path = '/home/go2-laptop/dial-mpc/dial_mpc/examples/unitree_go2_trot.yaml'
# Use a context manager so the file handle is closed deterministically.
with open(example_cfg_path) as cfg_file:
    config_dict = yaml.safe_load(cfg_file)
dial_config = load_dataclass_from_dict(DialConfig, config_dict)

In [7]:
# Display the parsed YAML configuration.
config_dict

{'seed': 0,
 'output_dir': 'unitree_go2_trot',
 'n_steps': 400,
 'env_name': 'unitree_go2_walk',
 'Nsample': 2048,
 'Hsample': 16,
 'Hnode': 4,
 'Ndiffuse': 2,
 'Ndiffuse_init': 10,
 'temp_sample': 0.05,
 'horizon_diffuse_factor': 0.9,
 'traj_diffuse_factor': 0.5,
 'update_method': 'mppi',
 'dt': 0.02,
 'timestep': 0.02,
 'leg_control': 'torque',
 'action_scale': 1.0,
 'default_vx': 0.8,
 'default_vy': 0.0,
 'default_vyaw': 0.0,
 'ramp_up_time': 1.0,
 'gait': 'trot'}

In [8]:
# Build the Go2 environment from the same YAML dictionary that produced
# dial_config; list values are converted to arrays.
env_config = load_dataclass_from_dict(
    UnitreeGo2EnvConfig,
    config_dict,
    convert_list_to_array=True,
)
env = UnitreeGo2Env(env_config)

# JIT-compiled entry points of the environment.
reset_env, step_env = jax.jit(env.reset), jax.jit(env.step)

# Random key seeded from the configuration.
rng = jax.random.PRNGKey(seed=dial_config.seed)

In [9]:
# Initial environment state that every rollout starts from.
state_init = reset_env()
# Rollout engine with its default control period and grid sizes.
mbdpi = XLAEngine(env=env)

In [10]:
# All-zero node trajectories with shape (samples, nodes, action dims).
# The sizes are taken from the config and the engine rather than hard-coded,
# so they stay consistent with the engine's node grid and the env's actions.
Ys = jnp.zeros((dial_config.Nsample, mbdpi.step_nodes.shape[0], mbdpi.nu))
info = mbdpi.rollout(state_init, Ys)

In [11]:
# First three generalized coordinates of sample 0 at every rollout step
# (presumably the base x/y/z position of the floating-base model -- confirm).
info['qss'][0,:,:3]

Array([[-3.8305271e-04, -1.0669973e-07,  2.6867348e-01],
       [-1.1256181e-03, -3.4669438e-07,  2.6583236e-01],
       [-2.7895288e-03, -7.7797552e-07,  2.6300976e-01],
       [-5.4255854e-03, -1.0492379e-06,  2.6037541e-01],
       [-8.9999186e-03, -1.0691543e-06,  2.5798476e-01],
       [-1.3402898e-02, -9.0620722e-07,  2.5578323e-01],
       [-1.8490788e-02, -6.8422798e-07,  2.5366411e-01],
       [-2.4124030e-02, -4.9753572e-07,  2.5152180e-01],
       [-3.0190257e-02, -3.9491371e-07,  2.4927591e-01],
       [-3.6611184e-02, -3.8397195e-07,  2.4687603e-01],
       [-4.3337241e-02, -4.4588114e-07,  2.4429187e-01],
       [-5.0336055e-02, -5.5281714e-07,  2.4150051e-01],
       [-5.7580158e-02, -6.7653775e-07,  2.3847647e-01],
       [-6.5037802e-02, -7.9567832e-07,  2.3518856e-01],
       [-7.2668836e-02, -8.9602412e-07,  2.3160297e-01]], dtype=float32)

In [35]:
# Same three coordinates in the initial state, for comparison with the rollout.
state_init.pipeline_state.qpos[:3]

Array([0.  , 0.  , 0.27], dtype=float32)