# Mujoco RL Renderer
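This notebook is the rendering stage of the MuJoCo RL pipeline: it rebuilds the Unitree Go2 environments, loads previously trained policies, and renders their rollouts to images and GIFs.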

## Common Initial Setup

### Install packages, clone repo

In [None]:
# Optional offline install from a wheel cache on Google Drive (disabled).
# Uncomment this block, and comment out the online installs below, to use it.
# drive_path = '/content/drive/MyDrive/ThesisLab'

# !pip install mujoco mujoco_mjx brax mediapy --no-index \
# --find-links $drive_path/mypackages \
# --target /content/mypackages

# Install MuJoCo, MJX, and Brax from PyPI.
# NOTE(review): versions are unpinned, so a fresh runtime may pull releases
# newer than the ones the saved policies were trained with -- consider pinning.
!pip install mujoco
!pip install mujoco_mjx
!pip install brax

In [None]:
# Install mediapy, which is used below to display and save images and videos.
# ffmpeg is installed through apt only when it is not already on the PATH.
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy

In [None]:
# Clone MuJoCo Menagerie into the working directory; the Unitree Go2 scene
# (unitree_go2/scene_mjx.xml) is loaded from this checkout further down.
# NOTE(review): on a re-run git will refuse to clone into the already existing
# directory; the existing checkout is then reused as-is.
!git clone https://github.com/google-deepmind/mujoco_menagerie.git

### Setup output directory and GPU rendering

In [None]:
# Directory that receives the images and GIFs rendered by this notebook.
media_dir = '/content/media'
!mkdir -p $media_dir

In [None]:
# Set up GPU rendering.
from google.colab import files
import distutils.util
import os
import subprocess
# Fail fast with an actionable message when no GPU is reachable. On a CPU-only
# runtime the `nvidia-smi` binary may not exist at all, in which case
# subprocess.run raises FileNotFoundError instead of returning a non-zero
# exit code; both situations are reported with the same RuntimeError.
try:
  gpu_unreachable = subprocess.run('nvidia-smi').returncode != 0
except FileNotFoundError:
  gpu_unreachable = True
if gpu_unreachable:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU Colab runtime. '
      'Go to the Runtime menu and select Choose runtime type.')

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  # The vendor directory itself may be absent on a minimal image; create it so
  # that the open() below cannot fail with FileNotFoundError.
  os.makedirs(os.path.dirname(NVIDIA_ICD_CONFIG_PATH), exist_ok=True)
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

In [None]:
# Configure MuJoCo to use the EGL rendering backend (requires GPU).
# NOTE(review): presumably MuJoCo reads MUJOCO_GL when its rendering backend is
# first loaded, so this cell has to run before `mujoco` is imported -- confirm.
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl

### Imports

In [None]:
# Import all necessary packages

# Supporting
import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.8"  # 0.9 causes too much lag.
import time
import itertools
import functools
from datetime import datetime
from etils import epath
from typing import Any, Dict, Sequence, Tuple, Callable, NamedTuple, Optional, Union, List
from ml_collections import config_dict
import matplotlib.pyplot as plt
import mediapy as media

# Math
import jax
import jax.numpy as jp
import numpy as np
np.set_printoptions(precision=3, suppress=True, linewidth=100)  # More legible printing from numpy.
from jax import config  # Analytical gradients work much better with double precision.
# config.update("jax_debug_nans", True)
# config.update("jax_enable_x64", True)
# config.update('jax_default_matmul_precision', 'high')
from flax.training import orbax_utils
from flax import struct
from orbax import checkpoint as ocp

# Sim
import mujoco
import mujoco.mjx as mjx

# Brax
from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.base import State as PipelineState
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.mjx.pipeline import _reformat_contact
from brax.training.acme import running_statistics
from brax.io import html, mjcf, model

# Algorithms
from brax.training.agents.apg import train as apg
from brax.training.agents.apg import networks as apg_networks
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks

In [None]:
# Sanity check: parsing a minimal, empty model raises immediately if the MuJoCo
# installation is broken.
mujoco.MjModel.from_xml_string(xml='<mujoco/>')

### Render Unitree Go2 initial standing scene

In [None]:
# Load the MJX-compatible Unitree Go2 scene from the Menagerie checkout.
menagerie_path = '/content/mujoco_menagerie'
xml_path = epath.Path(menagerie_path + '/unitree_go2/scene_mjx.xml').as_posix()

mj_model = mujoco.MjModel.from_xml_path(xml_path)

# Create the renderer only once per kernel session; re-running this cell
# reuses the existing instance instead of allocating another one.
if 'renderer' not in dir():
    renderer = mujoco.Renderer(mj_model)

# Generalized coordinates of the model's 'home' keyframe (standing pose).
init_q = mj_model.keyframe('home').qpos

mj_data = mujoco.MjData(mj_model)
mj_data.qpos = init_q
# Propagate qpos to the derived quantities the renderer reads.
mujoco.mj_forward(mj_model, mj_data)

renderer.update_scene(mj_data)
image = renderer.render()
# Write into media_dir (the directory created during setup) rather than a
# second hard-coded copy of the same path, so the output location is
# configured in exactly one place.
media.write_image(f'{media_dir}/standing.png', image)
media.show_image(image)

### Rollout Rendering function

In [None]:
# Rendering Rollouts
def render_rollout(title, reset_fn, step_fn,
                   inference_fn, env,
                   n_steps=500, camera=None,
                   the_command=None,
                   seed=0,
                   render_every=2):
    """Roll out a policy, show the result inline and save it as a GIF.

    Args:
        title: Base name of the output file; the GIF is written to
            ``{media_dir}/{title}.gif`` (``media_dir`` is a notebook global).
        reset_fn: Callable ``rng -> state`` producing the initial state.
        step_fn: Callable ``(state, ctrl) -> state`` advancing one control step.
        inference_fn: Callable ``(obs, rng) -> (ctrl, extras)``; only ``ctrl``
            is used.
        env: Object providing ``render(rollout, camera=...)`` and ``dt``.
        n_steps: Number of control steps to simulate.
        camera: Forwarded unchanged to ``env.render``.
        the_command: If not ``None``, stored in ``state.info['command']`` right
            after the reset.
        seed: Seed for the JAX PRNG key driving the reset and the policy.
        render_every: Keep every ``render_every``-th stepped state as a video
            frame. The playback frame rate is scaled accordingly so the video
            still runs in real time.

    Returns:
        None. The video is displayed and written to disk as side effects.
    """
    rng = jax.random.key(seed)
    state = reset_fn(rng)
    if the_command is not None:
        state.info['command'] = the_command
    # The initial state is always the first frame.
    rollout = [state.pipeline_state]

    for i in range(n_steps):
        act_rng, rng = jax.random.split(rng)
        ctrl, _ = inference_fn(state.obs, act_rng)
        state = step_fn(state, ctrl)
        if i % render_every == 0:
            rollout.append(state.pipeline_state)

    # Render exactly once and reuse the frames for both the inline preview and
    # the file on disk; rendering is the expensive part of this function.
    frames = env.render(rollout, camera=camera)
    fps = 1.0 / (env.dt * render_every)
    media.show_video(frames, fps=fps, codec='gif')
    media.write_video(f'{media_dir}/{title}.gif', frames, fps=fps, codec='gif')

## Unitree Go2 APG

### Reference Kinematics

In [None]:
def cos_wave(t, step_period, scale):
    """Raised cosine: 0 at t=0, `scale` at t=step_period/2, period `step_period`."""
    half_amplitude = scale / 2
    phase = ((2 * jp.pi) / step_period) * t
    return -jp.cos(phase) * half_amplitude + half_amplitude


def dcos_wave(t, step_period, scale):
    """
    Time derivative of `cos_wave`; used to build the reference joint velocity.
    """
    amplitude = (scale * jp.pi) / step_period
    phase = ((2 * jp.pi) / step_period) * t
    return amplitude * jp.sin(phase)


def make_kinematic_ref(sinusoid, step_k, scale=0.3, dt=1 / 50):
    """
    Build one trotting gait cycle of joint offsets for the 12 leg joints.

    `sinusoid(t, step_period, scale)` supplies the per-timestep wave, `step_k`
    is the number of timesteps a foot needs to be raised and lowered, and `dt`
    is the duration of one timestep. The result has shape (2 * step_k, 12):
    during the first `step_k` rows the two middle legs (columns 3:9) move while
    the outer legs (columns 0:3 and 9:12) stay at zero; during the second
    `step_k` rows the roles are swapped. A gait cycle therefore lasts
    2 * step_k * dt seconds.
    """
    step_period = step_k * dt
    t = jp.arange(step_k) * dt

    wave_col = sinusoid(t, step_period, scale).reshape(step_k, 1)
    idle_leg = jp.zeros((step_k, 3))

    # Per-leg command while it swings: first joint untouched, second joint
    # follows the wave, third joint follows -2x the wave.
    front_swing = jp.concatenate(
        [jp.zeros((step_k, 1)), wave_col, -2 * wave_col], axis=1)
    # Our standing config reverses front and hind legs, so hind legs reuse the
    # same command.
    hind_swing = 1 * front_swing

    first_half = jp.concatenate(
        [idle_leg, front_swing, hind_swing, jp.zeros((step_k, 3))], axis=1)
    second_half = jp.concatenate(
        [front_swing, idle_leg, jp.zeros((step_k, 3)), hind_swing], axis=1)

    # Each diagonal pair is active for one half of the cycle and idle for the
    # other.
    return jp.concatenate([first_half, second_half], axis=0)

In [None]:
# One full gait cycle of joint offsets relative to the standing pose.
poses = make_kinematic_ref(cos_wave, step_k=25)

# Start from the 'home' keyframe; its joint part (qpos[7:]) is the pose the
# offsets are added to.
init_q = mj_model.keyframe('home').qpos
default_ap = init_q[7:]
mj_data.qpos = init_q

# Pose the model kinematically (no dynamics) and grab one frame per timestep.
frames = []
for pose in poses:
    mj_data.qpos[7:] = pose + default_ap
    mujoco.mj_forward(mj_model, mj_data)
    renderer.update_scene(mj_data)
    frames.append(renderer.render())

media.show_video(frames, fps=50, codec='gif')
media.write_video(f'{media_dir}/stepping.gif', frames, fps=50, codec='gif')

### In-place Trotting Environment

In [None]:
def get_config():
    """Return the reward configuration of the in-place trotting environment.

    The result is a ConfigDict of the form ``rewards.scales.<term>`` holding
    the weight of each reward term.
    """
    reward_scales = config_dict.ConfigDict({
        'min_reference_tracking': -2.5 * 3e-3,  # to equalize the magnitude
        'reference_tracking': -1.0,
        'feet_height': -1.0,
    })
    rewards = config_dict.ConfigDict({'scales': reward_scales})
    return config_dict.ConfigDict({'rewards': rewards})


# Math functions from (https://github.com/jiawei-ren/diffmimic)
def quaternion_to_matrix(quaternions):
    """Convert real-first quaternions of shape (..., 4) to (..., 3, 3) matrices.

    The quaternions do not need to be normalized: every term is divided by the
    squared norm.
    """
    w, x, y, z = (quaternions[..., n] for n in range(4))
    scale = 2.0 / (quaternions * quaternions).sum(-1)

    row0 = (1 - scale * (y * y + z * z),
            scale * (x * y - z * w),
            scale * (x * z + y * w))
    row1 = (scale * (x * y + z * w),
            1 - scale * (x * x + z * z),
            scale * (y * z - x * w))
    row2 = (scale * (x * z - y * w),
            scale * (y * z + x * w),
            1 - scale * (x * x + y * y))

    # Stack the nine entries row-major on a new last axis, then fold into 3x3.
    flat = jp.stack(row0 + row1 + row2, -1)
    return flat.reshape(quaternions.shape[:-1] + (3, 3))


def matrix_to_rotation_6d(matrix):
    """Flatten the first two rows of (..., 3, 3) matrices into (..., 6) vectors."""
    leading_shape = matrix.shape[:-2]
    top_two_rows = matrix[..., :2, :]
    return top_two_rows.reshape(leading_shape + (6,))


def quaternion_to_rotation_6d(quaternion):
    """Map quaternions to the 6D rotation representation via rotation matrices."""
    rotation_matrix = quaternion_to_matrix(quaternion)
    return matrix_to_rotation_6d(rotation_matrix)


class TrotGo2(PipelineEnv):
    """Brax environment in which the Go2 imitates an in-place trotting reference.

    The reference is a cyclic joint-space trajectory produced by
    `make_kinematic_ref`. Rewards penalize deviation from it, and the
    simulation state is snapped back onto the reference whenever the bodies
    drift further than `err_threshold` away from it.

    NOTE(review): the reference is indexed with ``state.info['steps']``, which
    this class initializes to 0.0 but never increments itself; presumably the
    episode wrapper applied by callers maintains that counter -- confirm.
    """

    def __init__(
            self,
            termination_height: float = 0.25,
            **kwargs,
    ):
        """Load the model, configure the actuators and build the reference.

        Args:
            termination_height: The episode is flagged done when the z position
                of body 0 drops below this value.
            **kwargs: ``step_k`` (default 25) is consumed here and passed to
                `make_kinematic_ref`; ``n_frames`` defaults to 10 physics steps
                per control step; everything else goes to ``PipelineEnv``.
        """
        step_k = kwargs.pop('step_k', 25)

        physics_steps_per_control_step = 10
        kwargs['n_frames'] = kwargs.get(
            'n_frames', physics_steps_per_control_step)

        # xml_path is a notebook-level global set when the scene is first loaded.
        mj_model = mujoco.MjModel.from_xml_path(xml_path)
        # Overwrite the actuator gain/bias parameters with a common kp.
        # Presumably this sets the stiffness of position actuators -- TODO confirm.
        kp = 230
        mj_model.actuator_gainprm[:, 0] = kp
        mj_model.actuator_biasprm[:, 1] = -kp

        sys = mjcf.load_model(mj_model)

        super().__init__(sys=sys, **kwargs)

        self.termination_height = termination_height

        # Full generalized position of the 'home' keyframe (19 entries, see the
        # reshape further down).
        self._init_q = mj_model.keyframe('home').qpos

        self.err_threshold = 0.4  # diffmimic; value from paper.

        # Joint part of the home pose; qpos[:7] is skipped (presumably the
        # floating-base position and quaternion -- confirm).
        self._default_ap_pose = mj_model.keyframe('home').qpos[7:]
        self.reward_config = get_config()

        # Policy outputs in [-1, 1] are mapped to action_loc + action * action_scale.
        self.action_loc = self._default_ap_pose
        self.action_scale = jp.array([0.2, 0.8, 0.8] * 4)

        # Indices into data.geom_xpos of the four foot geoms.
        self.feet_inds = jp.array([21, 28, 35, 42])  # LF, RF, LH, RH

        #### Imitation reference
        kinematic_ref_qpos = make_kinematic_ref(
            cos_wave, step_k, scale=0.3, dt=self.dt)
        kinematic_ref_qvel = make_kinematic_ref(
            dcos_wave, step_k, scale=0.3, dt=self.dt)

        # Number of control steps in one gait cycle (2 * step_k).
        self.l_cycle = jp.array(kinematic_ref_qpos.shape[0])

        # Expand to entire state space.

        # qpos reference: the first 7 entries stay at the home keyframe for the
        # whole cycle, the joint entries follow the kinematic reference.
        kinematic_ref_qpos += self._default_ap_pose
        ref_qs = np.tile(self._init_q.reshape(1, 19), (self.l_cycle, 1))
        ref_qs[:, 7:] = kinematic_ref_qpos
        self.kinematic_ref_qpos = jp.array(ref_qs)

        # qvel reference: zero for the first 6 entries, wave derivative for the joints.
        ref_qvels = np.zeros((self.l_cycle, 18))
        ref_qvels[:, 6:] = kinematic_ref_qvel
        self.kinematic_ref_qvel = jp.array(ref_qvels)

        # Can decrease jit time and training wall-clock time significantly.
        self.pipeline_step = jax.checkpoint(self.pipeline_step,
                                            policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)

    def reset(self, rng: jax.Array) -> State:
        """Return the initial state: home pose, zero velocity, resting on the ground.

        `rng` is only stored in ``state.info['rng']``; the initial pose itself
        is deterministic.
        """
        # Deterministic init

        qpos = jp.array(self._init_q)
        qvel = jp.zeros(18)

        data = self.pipeline_init(qpos, qvel)

        # Position onto ground
        # Shift the base along z by the smallest contact distance, then
        # re-initialize the pipeline with the corrected pose.
        pen = jp.min(data.contact.dist)
        qpos = qpos.at[2].set(qpos[2] - pen)
        data = self.pipeline_init(qpos, qvel)

        state_info = {
            'rng': rng,
            'steps': 0.0,
            'reward_tuple': {
                'reference_tracking': 0.0,
                'min_reference_tracking': 0.0,
                'feet_height': 0.0
            },
            'last_action': jp.zeros(12),  # from MJX tutorial.
            'kinematic_ref': jp.zeros(19),
        }

        x, xd = data.x, data.xd
        obs = self._get_obs(data.qpos, x, xd, state_info)
        reward, done = jp.zeros(2)
        # Metrics mirror the individual reward terms.
        metrics = {}
        for k in state_info['reward_tuple']:
            metrics[k] = state_info['reward_tuple'][k]
        state = State(data, obs, reward, done, metrics, state_info)
        return jax.lax.stop_gradient(state)

    def step(self, state: State, action: jax.Array) -> State:
        """Advance one control step and score it against the reference.

        Args:
            state: Current environment state; ``state.info`` and
                ``state.metrics`` are updated in place.
            action: Raw policy output; clipped to [-1, 1] and mapped to
                ``action_loc + action * action_scale`` before being applied.

        Returns:
            The next state. If the mean per-body position error against the
            reference exceeds `err_threshold`, its pipeline state and
            observation are those of the reference pose instead of the
            simulated ones.
        """
        action = jp.clip(action, -1, 1)  # Raw action

        action = self.action_loc + (action * self.action_scale)

        data = self.pipeline_step(state.pipeline_state, action)

        # Look up the reference for the current phase of the gait cycle.
        ref_qpos = self.kinematic_ref_qpos[jp.array(state.info['steps'] % self.l_cycle, int)]
        ref_qvel = self.kinematic_ref_qvel[jp.array(state.info['steps'] % self.l_cycle, int)]

        # Calculate maximal coordinates
        # by placing the reference qpos/qvel into a copy of the simulated data
        # and running forward kinematics on it.
        ref_data = data.replace(qpos=ref_qpos, qvel=ref_qvel)
        ref_data = mjx.forward(self.sys, ref_data)
        ref_x, ref_xd = ref_data.x, ref_data.xd

        state.info['kinematic_ref'] = ref_qpos

        # observation data
        x, xd = data.x, data.xd
        obs = self._get_obs(data.qpos, x, xd, state.info)

        # Terminate if flipped over or fallen down.
        done = 0.0
        done = jp.where(x.pos[0, 2] < self.termination_height, 1.0, done)
        up = jp.array([0.0, 0.0, 1.0])
        done = jp.where(jp.dot(math.rotate(up, x.rot[0]), up) < 0, 1.0, done)

        # reward
        # NOTE(review): `state` still holds the pipeline state from before this
        # step, so 'min_reference_tracking' compares the previous qpos/qvel
        # with the reference -- confirm this is intended.
        reward_tuple = {
            'reference_tracking': (
                    self._reward_reference_tracking(x, xd, ref_x, ref_xd)
                    * self.reward_config.rewards.scales.reference_tracking
            ),
            'min_reference_tracking': (
                    self._reward_min_reference_tracking(ref_qpos, ref_qvel, state)
                    * self.reward_config.rewards.scales.min_reference_tracking
            ),
            'feet_height': (
                    self._reward_feet_height(data.geom_xpos[self.feet_inds][:, 2]
                                             , ref_data.geom_xpos[self.feet_inds][:, 2])
                    * self.reward_config.rewards.scales.feet_height
            )
        }

        reward = sum(reward_tuple.values())

        # state management
        state.info['reward_tuple'] = reward_tuple
        state.info['last_action'] = action  # used for observation.

        for k in state.info['reward_tuple'].keys():
            state.metrics[k] = state.info['reward_tuple'][k]

        state = state.replace(
            pipeline_state=data, obs=obs, reward=reward,
            done=done)

        #### Reset state to reference if it gets too far
        # error: Euclidean distance per body, averaged over bodies.
        error = (((x.pos - ref_x.pos) ** 2).sum(-1) ** 0.5).mean()
        to_reference = jp.where(error > self.err_threshold, 1.0, 0.0)

        to_reference = jp.array(to_reference, dtype=int)  # keeps output types same as input.
        ref_data = self.mjx_to_brax(ref_data)

        # Branch-free select: with to_reference in {0, 1} every leaf of `data`
        # is either kept or replaced by the matching leaf of `ref_data`.
        data = jax.tree_util.tree_map(lambda x, y:
                                      jp.array((1 - to_reference) * x + to_reference * y, x.dtype), data, ref_data)

        x, xd = data.x, data.xd  # Data may have changed.
        obs = self._get_obs(data.qpos, x, xd, state.info)

        return state.replace(pipeline_state=data, obs=obs)

    def _get_obs(self, qpos: jax.Array, x: Transform, xd: Motion,
                 state_info: Dict[str, Any]) -> jax.Array:
        """Build the observation vector, clipped to [-100, 100].

        Layout: scaled yaw rate (1), gravity direction in the base frame (3),
        joint angles relative to the home pose (12), last action (12) and the
        joint part of the current kinematic reference (12).
        """

        inv_base_orientation = math.quat_inv(x.rot[0])
        local_rpyrate = math.rotate(xd.ang[0], inv_base_orientation)

        obs_list = []
        # yaw rate
        obs_list.append(jp.array([local_rpyrate[2]]) * 0.25)
        # projected gravity
        obs_list.append(
            math.rotate(jp.array([0.0, 0.0, -1.0]), inv_base_orientation))
        # motor angles
        angles = qpos[7:19]
        obs_list.append(angles - self._default_ap_pose)
        # last action
        obs_list.append(state_info['last_action'])
        # kinematic reference
        kin_ref = self.kinematic_ref_qpos[jp.array(state_info['steps'] % self.l_cycle, int)]
        obs_list.append(kin_ref[7:])  # First 7 indices are fixed

        obs = jp.clip(jp.concatenate(obs_list), -100.0, 100.0)

        return obs

    def mjx_to_brax(self, data):
        """
        Recompute the Brax-side fields (q, qd, x, xd, contact) of `data` from
        its core MJX fields (qpos, qvel, xpos, xquat, cvel, subtree_com).

        Needed for data produced by calling `mjx.forward` directly, which does
        not refresh those Brax fields.
        """
        q, qd = data.qpos, data.qvel
        # Index 0 is skipped (presumably the world body -- confirm).
        x = Transform(pos=data.xpos[1:], rot=data.xquat[1:])
        cvel = Motion(vel=data.cvel[1:, 3:], ang=data.cvel[1:, :3])
        # Offset of each body origin from the subtree COM of its root body,
        # applied to cvel to obtain xd.
        offset = data.xpos[1:, :] - data.subtree_com[self.sys.body_rootid[1:]]
        offset = Transform.create(pos=offset)
        xd = offset.vmap().do(cvel)
        data = _reformat_contact(self.sys, data)
        return data.replace(q=q, qd=qd, x=x, xd=xd)

    # ------------ reward functions----------------
    def _reward_reference_tracking(self, x, xd, ref_x, ref_xd):
        """
        Weighted sum of squared tracking errors in maximal coordinates: body
        positions, orientations (compared in the 6D rotation representation),
        linear velocities and angular velocities. Each term is the squared
        error summed over the last axis and averaged over bodies.
        """

        f = lambda x, y: ((x - y) ** 2).sum(-1).mean()

        _mse_pos = f(x.pos, ref_x.pos)
        _mse_rot = f(quaternion_to_rotation_6d(x.rot),
                     quaternion_to_rotation_6d(ref_x.rot))
        _mse_vel = f(xd.vel, ref_xd.vel)
        _mse_ang = f(xd.ang, ref_xd.ang)

        # Tuned to be about the same size.
        return _mse_pos \
            + 0.1 * _mse_rot \
            + 0.01 * _mse_vel \
            + 0.001 * _mse_ang

    def _reward_min_reference_tracking(self, ref_qpos, ref_qvel, state):
        """
        Tracking error in minimal coordinates: L2 norm of the position error
        (qpos[:3] and the joints, skipping qpos[3:7]) plus L2 norm of the qvel
        error, both read from `state.pipeline_state`. Improves accuracy of
        joint angle tracking.
        """
        pos = jp.concatenate([
            state.pipeline_state.qpos[:3],
            state.pipeline_state.qpos[7:]])
        pos_targ = jp.concatenate([
            ref_qpos[:3],
            ref_qpos[7:]])
        pos_err = jp.linalg.norm(pos_targ - pos)
        vel_err = jp.linalg.norm(state.pipeline_state.qvel - ref_qvel)

        return pos_err + vel_err

    def _reward_feet_height(self, feet_pos, feet_pos_ref):
        """Sum of absolute differences between actual and reference foot heights."""
        return jp.sum(jp.abs(feet_pos - feet_pos_ref))  # try to drive it to 0 using the l1 norm.


# Register the class under the name that envs.get_environment looks up below.
envs.register_environment('trotting_go2', TrotGo2)

### Policy Rollout

In [None]:
# step_k=13 control steps per foot swing, i.e. 26 control steps per gait cycle.
# Each foot contacts the ground twice/sec (assumes a 0.02 s control step, so a
# cycle lasts about 0.52 s -- TODO confirm against the model's timestep).
env = envs.get_environment("trotting_go2", step_k=13)

In [None]:
# Reconstruct the trotting inference function
# The network architecture must match the one the saved parameters were trained
# with (hidden layers of 256 and 128 units here).
make_networks_factory = functools.partial(
    apg_networks.make_apg_networks,
    hidden_layer_sizes=(256, 128)
)

nets = make_networks_factory(observation_size=1, # Observation_size argument doesn't matter since it's only used for param init.
                             action_size=12,
                             preprocess_observations_fn=running_statistics.normalize)

# Factory that turns a set of parameters into a callable policy.
make_inference_fn = apg_networks.make_inference_fn(nets)

In [None]:
# Saved parameters of the in-place trotting policy; the file has to be placed
# at this path before the cell is run.
model_path = '/content/trotting_2hz_policy'
params = model.load_params(model_path)

# Inference function for in-place trotting
# It is also handed to the forward-motion environment as its baseline policy.
baseline_inference_fn = make_inference_fn(params)

In [None]:
# Wrap the environment for episode bookkeeping.
# NOTE(review): the environment indexes its gait reference with
# state.info['steps'] but never increments it itself; presumably
# EpisodeWrapper maintains that counter -- confirm.
demo_env = envs.training.EpisodeWrapper(env,
                                        episode_length=1000,
                                        action_repeat=1)

# Roll out the trotting policy for 200 control steps; the result is shown
# inline and written to go2_trotting.gif in media_dir.
render_rollout(
  "go2_trotting",
  jax.jit(demo_env.reset),
  jax.jit(demo_env.step),
  jax.jit(baseline_inference_fn),
  demo_env,
  n_steps=200,
)

### Forward Motion Environment

In [None]:
def axis_angle_to_quaternion(v: jp.ndarray, theta:jp.float_):
    """
    Convert an axis-angle rotation (angle `theta` about axis `v`) into a
    real-first quaternion of shape (4,). `v` is used as given, so it should be
    a unit vector.
    """
    half_angle = 0.5 * theta
    real_part = jp.cos(half_angle).reshape(1)
    vector_part = jp.sin(half_angle) * v.reshape(3)
    return jp.concatenate([real_part, vector_part])

def get_config():
  """Returns reward config for go2 quadruped environment.

  The result is a ConfigDict of the form ``rewards.scales.<term>`` holding the
  weight of each reward term; negative weights turn a term into a penalty.
  """
  reward_scales = config_dict.ConfigDict({
      'tracking_lin_vel': 1.0,
      'orientation': -1.0,  # non-flat base
      'height': 0.5,
      'lin_vel_z': -1.0,  # prevents the suicide policy
      'torque': -0.01,
      'feet_pos': -1,  # Bad action hard-coding.
      'feet_height': -1,  # prevents it from just standing still
      'joint_velocity': -0.001,
  })
  rewards = config_dict.ConfigDict({'scales': reward_scales})
  return config_dict.ConfigDict({'rewards': rewards})

class FwdTrotGo2(PipelineEnv):
  """Brax environment for forward trotting on top of a baseline trotting policy.

  The policy acts as a residual: in `step` its clipped action is added to the
  baseline policy's action, which is carried in the last 12 entries of the
  observation. Rewards track a forward target velocity and foot placement
  targets that are refreshed at the start of every step of the gait.

  NOTE(review): the gait phase is derived from ``state.info['steps']``, which
  this class initializes to 0.0 but never increments itself; presumably the
  episode wrapper applied by callers maintains that counter -- confirm.
  """

  def __init__(
      self,
      termination_height: float=0.25,
      **kwargs,
  ):
    """Load the model, configure the actuators and set up the gait schedule.

    Args:
      termination_height: The episode is flagged done when the z position of
        body 0 drops below this value.
      **kwargs: Consumed here: ``target_vel`` (default 0.75), ``step_k``
        (default 25) and ``baseline_inference_fn`` (required; a KeyError is
        raised if it is missing). ``n_frames`` defaults to 10 physics steps per
        control step; everything else goes to ``PipelineEnv``.
    """

    self.target_vel = kwargs.pop('target_vel', 0.75)
    step_k = kwargs.pop('step_k', 25)
    self.baseline_inference_fn = kwargs.pop("baseline_inference_fn")
    physics_steps_per_control_step = 10
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)
    self.termination_height = termination_height

    # xml_path is a notebook-level global set when the scene is first loaded.
    mj_model = mujoco.MjModel.from_xml_path(xml_path)
    # Overwrite the actuator gain/bias parameters with a common kp.
    # Presumably this sets the stiffness of position actuators -- TODO confirm.
    kp = 230
    mj_model.actuator_gainprm[:, 0] = kp
    mj_model.actuator_biasprm[:, 1] = -kp
    self._init_q = mj_model.keyframe('home').qpos
    self._default_ap_pose = mj_model.keyframe('home').qpos[7:]
    self.reward_config = get_config()

    # Actions are mapped to action_loc + action * action_scale.
    self.action_loc = self._default_ap_pose
    self.action_scale = jp.array([0.2, 0.8, 0.8] * 4)

    # Target base height for the height reward: z of the home keyframe.
    self.target_h = self._init_q[2]

    sys = mjcf.load_model(mj_model)
    super().__init__(sys=sys, **kwargs)

    """
    Kinematic references are used for gait scheduling.
    """

    # Joint-space reference (l_cycle x 12); it is only fed into the observation.
    kinematic_ref_qpos = make_kinematic_ref(
      cos_wave, step_k, scale=0.3, dt=self.dt)
    self.l_cycle = jp.array(kinematic_ref_qpos.shape[0])
    self.kinematic_ref_qpos = jp.array(kinematic_ref_qpos + self._default_ap_pose)

    """
    Foot tracking
    """
    # A gait cycle consists of two steps, one per diagonal leg pair.
    gait_k = step_k * 2
    self.gait_period = gait_k * self.dt

    self.step_k = step_k
    # Indices into data.geom_xpos.
    self.feet_inds = jp.array([21,28,35,42]) # LF, RF, LH, RH
    # Presumably the hip geoms sit 6 geom indices before each foot -- confirm
    # against the model.
    self.hip_inds = self.feet_inds - 6

    self.pipeline_step = jax.checkpoint(self.pipeline_step,
      policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)

  def reset(self, rng: jax.Array) -> State:
    """Return a randomized initial state resting on the ground.

    The home pose is perturbed in base position, base orientation, joint angles
    and joint velocities. The returned observation is the inner observation
    concatenated with the baseline policy's action for the first step.
    """
    rng, key_xyz, key_ang, key_ax, key_q, key_qd = jax.random.split(rng, 6)

    qpos = jp.array(self._init_q)
    qvel = jp.zeros(18)

    #### Add Randomness ####

    # Base position offset, uniform in [-0.1, 0.1) per axis.
    r_xyz = 0.2 * (jax.random.uniform(key_xyz, (3,))-0.5)
    # Rotation angle uniform in [-pi/24, pi/24) about a random unit axis.
    r_angle = (jp.pi/12) * (jax.random.uniform(key_ang, (1,)) - 0.5) # 15 deg range
    r_axis = (jax.random.uniform(key_ax, (3,)) - 0.5)
    r_axis = r_axis / jp.linalg.norm(r_axis)
    r_quat = axis_angle_to_quaternion(r_axis, r_angle)

    # Joint angle offsets in [-0.1, 0.1), joint velocities in [-0.05, 0.05).
    r_joint_q = 0.2 * (jax.random.uniform(key_q, (12,)) - 0.5)
    r_joint_qd = 0.1 * (jax.random.uniform(key_qd, (12,)) - 0.5)

    qpos = qpos.at[0:3].set(qpos[0:3] + r_xyz)
    # The random quaternion replaces the keyframe orientation (it is not
    # composed with it).
    qpos = qpos.at[3:7].set(r_quat)
    qpos = qpos.at[7:19].set(qpos[7:19] + r_joint_q)
    qvel = qvel.at[6:18].set(qvel[6:18] + r_joint_qd)

    data = self.pipeline_init(qpos, qvel)

    # Ensure you're not sunken into the ground nor above it.
    # Shift the base along z by the smallest contact distance and re-initialize.
    pen = jp.min(data.contact.dist)
    qpos = qpos.at[2].set(qpos[2] - pen)
    data = self.pipeline_init(qpos, qvel)

    # 'xy0' / 'k0': foot xy positions and step counter at the start of the
    # current step; 'xy*': foot xy targets for the end of it (one row per foot).
    state_info = {
        'rng': rng,
        'steps': 0.0,
        'reward_tuple': {
            'tracking_lin_vel': 0.0,
            'orientation': 0.0,
            'height': 0.0,
            'lin_vel_z': 0.0,
            'torque': 0.0,
            'joint_velocity': 0.0,
            'feet_pos': 0.0,
            'feet_height': 0.0
        },
        'last_action': jp.zeros(12), # from MJX tutorial.
        'baseline_action': jp.zeros(12),
        'xy0': jp.zeros((4, 2)),
        'k0': 0.0,
        'xy*': jp.zeros((4, 2))
    }

    x, xd = data.x, data.xd
    _obs = self._get_obs(data.qpos, x, xd, state_info) # inner obs; to trotter

    # Query the baseline policy and append its action to the observation.
    action_key, key = jax.random.split(state_info['rng'])
    state_info['rng'] = key
    next_action, _ = self.baseline_inference_fn(_obs, action_key)

    obs = jp.concatenate([_obs, next_action])

    reward, done = jp.zeros(2)
    # Metrics mirror the individual reward terms.
    metrics = {}
    for k in state_info['reward_tuple']:
      metrics[k] = state_info['reward_tuple'][k]
    state = State(data, obs, reward, done, metrics, state_info)
    return jax.lax.stop_gradient(state)

  def step(self, state: State, action: jax.Array) -> State:
    """Advance one control step.

    Args:
      state: Current environment state; ``state.info`` and ``state.metrics``
        are updated in place.
      action: Residual policy output; clipped to [-1, 1], added to the baseline
        action stored in ``state.obs[-12:]`` and then mapped to
        ``action_loc + action * action_scale``.

    Returns:
      The next state. Its observation ends with the baseline action to be used
      on the following step.
    """

    action = jp.clip(action, -1, 1)

    # Baseline action computed at the end of the previous step (or in reset).
    cur_base = state.obs[-12:]
    action += cur_base
    state.info['baseline_action'] = cur_base

    action = self.action_loc + (action * self.action_scale)

    data = self.pipeline_step(state.pipeline_state, action)

    # observation data
    x, xd = data.x, data.xd
    obs = self._get_obs(data.qpos, x, xd, state.info)

    # Terminate if flipped over or fallen down.
    done = 0.0
    done = jp.where(x.pos[0, 2] < self.termination_height, 1.0, done)
    up = jp.array([0.0, 0.0, 1.0])
    done = jp.where(jp.dot(math.rotate(up, x.rot[0]), up) < 0, 1.0, done)

    #### Foot Position Reference Updating ####

    # Detect the start of a new step
    s = state.info['steps']
    step_num = s // (self.step_k)
    even_step = step_num % 2 == 0
    new_step = (s % self.step_k) == 0
    new_even_step = jp.logical_and(new_step, even_step)
    new_odd_step = jp.logical_and(new_step, jp.logical_not(even_step))

    # Apply Raibert heuristic to calculate target foot position, after step
    # target = hip position + half a step period of travel at the current
    # base xy velocity.
    hip_xy = data.geom_xpos[self.hip_inds][:,:2] # 4 x 2
    v_body = data.qvel[0:2]
    step_period = self.gait_period/2
    raibert_xy = hip_xy + (step_period/2) * v_body

    # Update.
    cur_tars = state.info['xy*']
    i_RFLH = jp.array([1, 2])
    i_LFRH = jp.array([0, 3])
    feet_xy = data.geom_xpos[self.feet_inds][:,:2]

    # With the trotting gait, we will move one pair of opposite legs,
    # and keep the other pair fixed in place.
    # case_c1 (new even step): LF/RH targets are their current positions, RF/LH
    # get Raibert targets. case_c2 (new odd step): the roles are swapped.
    # Outside of step starts the previous targets are kept.
    case_c1 = raibert_xy.at[i_LFRH].set(feet_xy[i_LFRH])
    case_c2 = raibert_xy.at[i_RFLH].set(feet_xy[i_RFLH])
    xy_tars = jp.where(new_even_step, case_c1, cur_tars)
    xy_tars = jp.where(new_odd_step, case_c2, xy_tars)
    state.info['xy*'] = xy_tars

    # Save timestep and location at start of step.
    state.info['k0'] = jp.where(new_step,
                                state.info['steps'],
                                state.info['k0'])
    state.info['xy0'] = jp.where(new_step,
                                 feet_xy,
                                 state.info['xy0'])

    # reward
    # Each term is the raw value from its _reward_* method times the scale
    # configured in get_config().
    reward_tuple = {
        'tracking_lin_vel': (
            self._reward_tracking_lin_vel(jp.array([self.target_vel, 0, 0]), x, xd)
            * self.reward_config.rewards.scales.tracking_lin_vel
        ),
        'orientation': (
          self._reward_orientation(x)
          * self.reward_config.rewards.scales.orientation
        ),
        'lin_vel_z': (
            self._reward_lin_vel_z(xd)
            * self.reward_config.rewards.scales.lin_vel_z
        ),
        'height': (
          self._reward_height(data.qpos)
          * self.reward_config.rewards.scales.height
        ),
        'torque': (
          self._reward_action(data.qfrc_actuator)
          * self.reward_config.rewards.scales.torque
        ),
        'joint_velocity': (
          self._reward_joint_velocity(data.qvel)
          * self.reward_config.rewards.scales.joint_velocity
        ),
        'feet_pos': (
          self._reward_feet_pos(data, state)
          * self.reward_config.rewards.scales.feet_pos
        ),
        'feet_height': (
          self._reward_feet_height(data, state.info)
          * self.reward_config.rewards.scales.feet_height
        )
    }

    reward = sum(reward_tuple.values())

    # state management
    state.info['reward_tuple'] = reward_tuple
    state.info['last_action'] = action

    for k in state.info['reward_tuple'].keys():
      state.metrics[k] = state.info['reward_tuple'][k]

    # next action
    # Query the baseline policy on the new inner observation and append its
    # action, so the next call to step can read it from obs[-12:].
    action_key, key = jax.random.split(state.info['rng'])
    state.info['rng'] = key
    next_action, _ = self.baseline_inference_fn(obs, action_key)
    obs = jp.concatenate([obs, next_action])

    state = state.replace(
        pipeline_state=data, obs=obs, reward=reward,
        done=done)
    return state

  def _get_obs(self, qpos: jax.Array, x: Transform, xd: Motion,
               state_info: Dict[str, Any]) -> jax.Array:
    """Build the inner observation (the baseline policy's input), clipped to [-100, 100].

    Layout: scaled yaw rate (1), gravity direction in the base frame (3), joint
    angles relative to the home pose (12), last action (12) and the current row
    of the joint-space gait reference (12).
    """

    inv_base_orientation = math.quat_inv(x.rot[0])
    local_rpyrate = math.rotate(xd.ang[0], inv_base_orientation)

    obs_list = []
    # yaw rate
    obs_list.append(jp.array([local_rpyrate[2]]) * 0.25)
    # projected gravity
    obs_list.append(
        math.rotate(jp.array([0.0, 0.0, -1.0]), inv_base_orientation))
    # motor angles
    angles = qpos[7:19]
    obs_list.append(angles - self._default_ap_pose)
    # last action
    obs_list.append(state_info['last_action'])
    # gait schedule
    kin_ref = self.kinematic_ref_qpos[jp.array(state_info['steps']%self.l_cycle, int)]
    obs_list.append(kin_ref)

    obs = jp.clip(jp.concatenate(obs_list), -100.0, 100.0)

    return obs

  # ------------ reward functions----------------
  def _reward_tracking_lin_vel(
      self, commands: jax.Array, x: Transform, xd: Motion) -> jax.Array:
    """exp(-squared error) between the commanded and base-frame xy velocity."""
    # Tracking of linear velocity commands (xy axes)
    local_vel = math.rotate(xd.vel[0], math.quat_inv(x.rot[0]))
    lin_vel_error = jp.sum(jp.square(commands[:2] - local_vel[:2]))
    lin_vel_reward = jp.exp(-lin_vel_error)
    return lin_vel_reward
  def _reward_orientation(self, x: Transform) -> jax.Array:
    """Squared xy components of the base's rotated up vector (0 when level)."""
    # Penalize non flat base orientation
    up = jp.array([0.0, 0.0, 1.0])
    rot_up = math.rotate(up, x.rot[0])
    return jp.sum(jp.square(rot_up[:2]))
  def _reward_lin_vel_z(self, xd: Motion) -> jax.Array:
    """Squared z velocity of body 0, clipped to [0, 10]."""
    # Penalize z axis base linear velocity
    return jp.clip(jp.square(xd.vel[0, 2]), 0, 10)
  def _reward_joint_velocity(self, qvel):
      """L2 norm of qvel[6:] (the joint velocities), clipped to [0, 100]."""
      return jp.clip(jp.sqrt(jp.sum(jp.square(qvel[6:]))), 0, 100)
  def _reward_height(self, qpos) -> jax.Array:
    """exp(-|z - target_h|) for the base height qpos[2]."""
    return jp.exp(-jp.abs(qpos[2] - self.target_h)) # Not going to be > 1 meter tall.
  def _reward_action(self, action) -> jax.Array:
    """L2 norm of the given vector; `step` passes data.qfrc_actuator."""
    return jp.sqrt(jp.sum(jp.square(action)))
  def _reward_feet_pos(self, data, state):
    """Squared xy distance of the feet from their interpolated targets.

    The target moves linearly from 'xy0' to 'xy*' over one step period. The
    error is clipped to [0, 10] per foot and summed over the four feet.
    """
    dt = (state.info['steps'] - state.info['k0']) * self.dt # scalar
    step_period = self.gait_period / 2
    xyt = state.info['xy0'] + (state.info['xy*'] - state.info['xy0']) * (dt/step_period)

    feet_pos = data.geom_xpos[self.feet_inds][:, :2]

    rews = jp.sum(jp.square(feet_pos - xyt), axis=1)
    rews = jp.clip(rews, 0, 10)
    return jp.sum(rews)
  def _reward_feet_height(self, data, state_info):
    """
    Feet height tracks rectified sine waves.

    Each diagonal pair follows a sine of amplitude 0.1 with period
    `gait_period`, half a period out of phase with the other pair. Negative
    values are clipped to 0 and 0.02 is added as the resting height. Returns
    the sum over the feet of the squared height error, each clipped to [0, 10].
    """
    h_tar = 0.1
    t = state_info['steps'] * self.dt
    offset = self.gait_period/2
    ref1 = jp.sin((2*jp.pi/self.gait_period)*t) # RF and LH feet
    ref2 = jp.sin((2*jp.pi/self.gait_period)*(t - offset)) # LF and RH

    ref1, ref2 = ref1 * h_tar, ref2 * h_tar
    # Ordered like feet_inds: LF, RF, LH, RH.
    h_tars = jp.array([ref2, ref1, ref1, ref2])
    h_tars = h_tars.clip(min=0, max=None) + 0.02 # offset height of feet.

    feet_height = data.geom_xpos[self.feet_inds][:,2]
    errs = jp.clip(jp.square(feet_height - h_tars), 0, 10)
    return jp.sum(errs)

# Register the class under the name that envs.get_environment looks up below.
envs.register_environment('go2', FwdTrotGo2)

### Forward Motion Rollout

In [None]:
# Construct Environment
# The in-place trotting policy loaded earlier serves as the baseline that the
# forward-motion policy adds its actions to.
env_kwargs = dict(target_vel=0.75, step_k=13,
                  baseline_inference_fn=baseline_inference_fn)

env = envs.get_environment("go2", **env_kwargs)

In [None]:
# Reconstruct the locomotion inference function
# The network architecture must match the one the saved parameters were trained
# with (hidden layers of 128 and 64 units here, smaller than the trotting policy).
make_networks_factory = functools.partial(
    apg_networks.make_apg_networks,
    hidden_layer_sizes=(128, 64)
)

nets = make_networks_factory(observation_size=1, # Observation_size argument doesn't matter since it's only used for param init.
                             action_size=12,
                             preprocess_observations_fn=running_statistics.normalize)

# Factory that turns a set of parameters into a callable policy.
make_inference_fn = apg_networks.make_inference_fn(nets)

In [None]:
# Saved parameters of the forward-locomotion policy; the file has to be placed
# at this path before the cell is run.
model_path = '/content/running_2hz_policy'
params = model.load_params(model_path)

# Inference function for forward locomotion
locomotion_inference_fn = make_inference_fn(params)

In [None]:
# Wrap the forward-motion environment for episode bookkeeping.
# NOTE(review): the environment derives its gait phase from
# state.info['steps'] but never increments it itself; presumably
# EpisodeWrapper maintains that counter -- confirm.
demo_env = envs.training.EpisodeWrapper(env,
                                        episode_length=1000,
                                        action_repeat=1)

# Roll out the locomotion policy for 200 control steps using the camera named
# "track"; the result is shown inline and written to forward_trot.gif in
# media_dir.
render_rollout(
  "forward_trot",
  jax.jit(demo_env.reset),
  jax.jit(demo_env.step),
  jax.jit(locomotion_inference_fn),
  demo_env,
  n_steps=200,
  camera="track"
)

## Unitree Go2 PPO

In [None]:
#@title Unitree Go2 Quadruped Env
menagerie_path = '/content/mujoco_menagerie'
GO2_ROOT_PATH = epath.Path(menagerie_path + '/unitree_go2')


def get_config():
  """Returns the reward configuration read by Go2JoystickEnv.__init__.

  The result is a ConfigDict with a single `rewards` entry that holds the
  per-term `scales` and the `tracking_sigma` shared by both tracking rewards.

  NOTE(review): the Barkour cell further down defines another `get_config`
  with the same values; once that cell has run it shadows this definition.
  """
  # Weight of every reward term. Physical quantities are in SI units unless
  # stated otherwise: joint positions in rad, positions in meters, torques in
  # Nm, time in seconds and forces in Newtons. Positive weights reward a term,
  # negative weights turn it into a penalty.
  reward_scales = config_dict.ConfigDict({
      # --- Tracking terms, computed as exp(-error^2 / tracking_sigma). ---
      # Base x-y velocity tracking (z velocity is not tracked).
      'tracking_lin_vel': 1.5,
      # Yaw-rate tracking (angular velocity about the z-axis).
      'tracking_ang_vel': 0.8,
      # --- Regularization terms: base state, joints, other behavior. ---
      # Squared base velocity along z.
      'lin_vel_z': -2.0,
      # Sum of squared base roll and pitch rates.
      'ang_vel_xy': -0.05,
      # Squared x-y components of the rotated up vector (non-flat base).
      'orientation': -5.0,
      # Joint torque regularization; _reward_torques adds the L2 norm and
      # the L1 norm of the actuator forces.
      'torques': -0.0002,
      # Smoothness of the actions, |action - last_action|^2.
      'action_rate': -0.01,
      # Encourages long swing steps; it does not encourage high clearances.
      'feet_air_time': 0.2,
      # Discourages motion at (near) zero command; _reward_stand_still sums
      # |q - q_default| over the joints.
      'stand_still': -0.5,
      # Early termination penalty.
      'termination': -1.0,
      # Foot slipping on the ground.
      'foot_slip': -0.1,
  })

  rewards = config_dict.ConfigDict({
      'scales': reward_scales,
      # Tracking reward = exp(-error^2 / sigma); sigma can be tuned.
      'tracking_sigma': 0.25,
  })

  return config_dict.ConfigDict({'rewards': rewards})


class Go2JoystickEnv(PipelineEnv):
  """Environment for training the go2 quadruped joystick policy in MJX.

  The agent receives a command (x velocity, y velocity, yaw rate) as part of
  its observation. Its action is scaled by `action_scale` and added to the
  joint pose of the 'home' keyframe to form the motor targets. Every `step`
  advances the physics pipeline, evaluates the reward terms whose weights come
  from `get_config()`, and updates the bookkeeping stored in `state.info`.
  """

  def __init__(
      self,
      obs_noise: float = 0.05,
      action_scale: float = 0.3,
      kick_vel: float = 0.05,
      scene_file: str = 'scene_mjx.xml',
      **kwargs,
  ):
    """Loads the Go2 scene and caches the indices and constants used later.

    Args:
      obs_noise: Scale of the uniform noise in [-1, 1] added to each newly
        computed observation.
      action_scale: Multiplier applied to the action before it is added to the
        default joint pose.
      kick_vel: Scale of the random unit vector added to the first two qvel
        entries when a kick is applied.
      scene_file: Scene XML file name, resolved relative to GO2_ROOT_PATH.
      **kwargs: `n_frames` overrides the number of physics steps per
        environment step. Any key ending in `_scale` overrides the reward
        scale named by the key without that suffix. Other keys are ignored.
    """
    path = GO2_ROOT_PATH / scene_file
    sys = mjcf.load(path.as_posix())
    self._dt = 0.02  # this environment is 50 fps
    # Physics timestep; together with self._dt it fixes the default n_frames.
    sys = sys.tree_replace({'opt.timestep': 0.004})

    # override menagerie params for smoother policy
    # dof_damping[6:] skips the first six dofs (presumably the floating base).
    # NOTE(review): gainprm[:, 0] / biasprm[:, 1] look like the kp terms of
    # position actuators -- confirm against the scene's actuator definitions.
    sys = sys.replace(
        dof_damping=sys.dof_damping.at[6:].set(0.5239),
        actuator_gainprm=sys.actuator_gainprm.at[:, 0].set(35.0),
        actuator_biasprm=sys.actuator_biasprm.at[:, 1].set(-35.0),
    )

    # Default: control period divided by the physics timestep set above.
    n_frames = kwargs.pop('n_frames', int(self._dt / sys.opt.timestep))
    super().__init__(sys, backend='mjx', n_frames=n_frames)

    self.reward_config = get_config()
    # set custom from kwargs
    # e.g. a keyword `foo_scale=v` sets reward_config.rewards.scales['foo'] = v.
    for k, v in kwargs.items():
      if k.endswith('_scale'):
        self.reward_config.rewards.scales[k[:-6]] = v

    # Body id of the 'base' body in the MuJoCo model.
    self._torso_idx = mujoco.mj_name2id(
        sys.mj_model, mujoco.mjtObj.mjOBJ_BODY.value, 'base'
    )
    self._action_scale = action_scale
    self._obs_noise = obs_noise
    self._kick_vel = kick_vel
    # Full initial qpos from the 'home' keyframe; entries from index 7 onward
    # are used as the default joint pose.
    self._init_q = jp.array(sys.mj_model.keyframe('home').qpos)
    self._default_pose = sys.mj_model.keyframe('home').qpos[7:]
    # Per-leg limit triple repeated for the four legs. Used both to clip the
    # motor targets and as a termination condition on the joint angles.
    self.lowers = jp.array([-0.7, -1.0, -2.3] * 4)
    self.uppers = jp.array([0.52, 2.1, -1] * 4)
    # Foot sites; mj_name2id returns -1 for unknown names, hence the assert.
    feet_site = [
        'FL_foot',
        'RL_foot',
        'FR_foot',
        'RR_foot',
    ]
    feet_site_id = [
        mujoco.mj_name2id(sys.mj_model, mujoco.mjtObj.mjOBJ_SITE.value, f)
        for f in feet_site
    ]
    assert not any(id_ == -1 for id_ in feet_site_id), 'Site not found.'
    self._feet_site_id = np.array(feet_site_id)
    # Calf bodies, listed in the same leg order as the foot sites above.
    lower_leg_body = [
        'FL_calf',
        'RL_calf',
        'FR_calf',
        'RR_calf',
    ]
    lower_leg_body_id = [
        mujoco.mj_name2id(sys.mj_model, mujoco.mjtObj.mjOBJ_BODY.value, l)
        for l in lower_leg_body
    ]
    assert not any(id_ == -1 for id_ in lower_leg_body_id), 'Body not found.'
    self._lower_leg_body_id = np.array(lower_leg_body_id)
    # Subtracted from the foot site height when testing for ground contact.
    self._foot_radius = 0.0175
    self._nv = sys.nv

  def sample_command(self, rng: jax.Array) -> jax.Array:
    """Samples a command [lin_vel_x, lin_vel_y, ang_vel_yaw] uniformly.

    Args:
      rng: PRNG key; it is split four ways and the first sub-key is unused.

    Returns:
      Array of three values, each drawn from the range listed below.
    """
    lin_vel_x = [-0.6, 1.5]  # min max [m/s]
    lin_vel_y = [-0.8, 0.8]  # min max [m/s]
    ang_vel_yaw = [-0.7, 0.7]  # min max [rad/s]

    _, key1, key2, key3 = jax.random.split(rng, 4)
    # The range lists above are rebound to the sampled (1,)-shaped arrays.
    lin_vel_x = jax.random.uniform(
        key1, (1,), minval=lin_vel_x[0], maxval=lin_vel_x[1]
    )
    lin_vel_y = jax.random.uniform(
        key2, (1,), minval=lin_vel_y[0], maxval=lin_vel_y[1]
    )
    ang_vel_yaw = jax.random.uniform(
        key3, (1,), minval=ang_vel_yaw[0], maxval=ang_vel_yaw[1]
    )
    new_cmd = jp.array([lin_vel_x[0], lin_vel_y[0], ang_vel_yaw[0]])
    return new_cmd

  def reset(self, rng: jax.Array) -> State:  # pytype: disable=signature-mismatch
    """Returns the initial state: 'home' pose, zero velocity, fresh command.

    Args:
      rng: PRNG key; one half is stored in the state info, the other is used
        to sample the first command.

    Returns:
      State with zero reward and done, zeroed metrics and an observation whose
      history portion is all zeros.
    """
    rng, key = jax.random.split(rng)

    pipeline_state = self.pipeline_init(self._init_q, jp.zeros(self._nv))

    # Bookkeeping carried between steps; `step` updates these entries in place.
    state_info = {
        'rng': rng,
        'last_act': jp.zeros(12),
        'last_vel': jp.zeros(12),
        'command': self.sample_command(key),
        'last_contact': jp.zeros(4, dtype=bool),
        'feet_air_time': jp.zeros(4),
        'rewards': {k: 0.0 for k in self.reward_config.rewards.scales.keys()},
        'kick': jp.array([0.0, 0.0]),
        'step': 0,
    }

    obs_history = jp.zeros(15 * 31)  # store 15 steps of history
    obs = self._get_obs(pipeline_state, state_info, obs_history)
    reward, done = jp.zeros(2)
    # One metric per reward term, plus the distance proxy updated in `step`.
    metrics = {'total_dist': 0.0}
    for k in state_info['rewards']:
      metrics[k] = state_info['rewards'][k]
    state = State(pipeline_state, obs, reward, done, metrics, state_info)  # pytype: disable=wrong-arg-types
    return state

  def step(self, state: State, action: jax.Array) -> State:  # pytype: disable=signature-mismatch
    """Applies a kick (when due), steps physics and computes reward and done.

    `state.info` and `state.metrics` are updated in place; the order of those
    updates matters because several reward terms read `state.info` before it
    is overwritten at the end of the step.

    Args:
      state: Current environment state.
      action: Policy output; scaled and added to the default joint pose.

    Returns:
      The state with the new pipeline state, observation, reward and done flag.
    """
    rng, cmd_rng, kick_noise_2 = jax.random.split(state.info['rng'], 3)

    # kick
    # A unit vector with a random direction; the comparison below zeroes it
    # unless the step counter is a multiple of push_interval.
    push_interval = 10
    kick_theta = jax.random.uniform(kick_noise_2, maxval=2 * jp.pi)
    kick = jp.array([jp.cos(kick_theta), jp.sin(kick_theta)])
    kick *= jp.mod(state.info['step'], push_interval) == 0
    qvel = state.pipeline_state.qvel  # pytype: disable=attribute-error
    qvel = qvel.at[:2].set(kick * self._kick_vel + qvel[:2])
    state = state.tree_replace({'pipeline_state.qvel': qvel})

    # physics step
    motor_targets = self._default_pose + action * self._action_scale
    motor_targets = jp.clip(motor_targets, self.lowers, self.uppers)
    pipeline_state = self.pipeline_step(state.pipeline_state, motor_targets)
    x, xd = pipeline_state.x, pipeline_state.xd

    # observation data
    # state.obs is the previous stacked observation and serves as the history.
    obs = self._get_obs(pipeline_state, state.info, state.obs)
    joint_angles = pipeline_state.q[7:]
    joint_vel = pipeline_state.qd[6:]

    # foot contact data based on z-position
    foot_pos = pipeline_state.site_xpos[self._feet_site_id]  # pytype: disable=attribute-error
    foot_contact_z = foot_pos[:, 2] - self._foot_radius
    contact = foot_contact_z < 1e-3  # a mm or less off the floor
    # Filtered contact: touching now or on the previous step.
    contact_filt_mm = contact | state.info['last_contact']
    contact_filt_cm = (foot_contact_z < 3e-2) | state.info['last_contact']
    # A foot with accumulated air time that is now in (filtered) contact.
    first_contact = (state.info['feet_air_time'] > 0) * contact_filt_mm
    state.info['feet_air_time'] += self.dt

    # done if joint limits are reached or robot is falling
    # x is indexed with `_torso_idx - 1` because it has no world-body entry.
    up = jp.array([0.0, 0.0, 1.0])
    done = jp.dot(math.rotate(up, x.rot[self._torso_idx - 1]), up) < 0
    done |= jp.any(joint_angles < self.lowers)
    done |= jp.any(joint_angles > self.uppers)
    done |= pipeline_state.x.pos[self._torso_idx - 1, 2] < 0.18

    # reward
    # Raw terms first; each one is multiplied by its configured scale below.
    rewards = {
        'tracking_lin_vel': (
            self._reward_tracking_lin_vel(state.info['command'], x, xd)
        ),
        'tracking_ang_vel': (
            self._reward_tracking_ang_vel(state.info['command'], x, xd)
        ),
        'lin_vel_z': self._reward_lin_vel_z(xd),
        'ang_vel_xy': self._reward_ang_vel_xy(xd),
        'orientation': self._reward_orientation(x),
        'torques': self._reward_torques(pipeline_state.qfrc_actuator),  # pytype: disable=attribute-error
        'action_rate': self._reward_action_rate(action, state.info['last_act']),
        'stand_still': self._reward_stand_still(
            state.info['command'], joint_angles,
        ),
        'feet_air_time': self._reward_feet_air_time(
            state.info['feet_air_time'],
            first_contact,
            state.info['command'],
        ),
        'foot_slip': self._reward_foot_slip(pipeline_state, contact_filt_cm),
        'termination': self._reward_termination(done, state.info['step']),
    }
    rewards = {
        k: v * self.reward_config.rewards.scales[k] for k, v in rewards.items()
    }
    # The total is scaled by dt and clipped to [0, 10000], so a negative sum
    # yields a reward of zero.
    reward = jp.clip(sum(rewards.values()) * self.dt, 0.0, 10000.0)

    # state management
    state.info['kick'] = kick
    state.info['last_act'] = action
    state.info['last_vel'] = joint_vel
    # Air time is zeroed for feet in (filtered) contact.
    state.info['feet_air_time'] *= ~contact_filt_mm
    state.info['last_contact'] = contact
    state.info['rewards'] = rewards
    state.info['step'] += 1
    state.info['rng'] = rng

    # sample new command if more than 500 timesteps achieved
    state.info['command'] = jp.where(
        state.info['step'] > 500,
        self.sample_command(cmd_rng),
        state.info['command'],
    )
    # reset the step counter when done
    state.info['step'] = jp.where(
        done | (state.info['step'] > 500), 0, state.info['step']
    )

    # log total displacement as a proxy metric
    # NOTE(review): element [1] of math.normalize is presumably the norm of the
    # torso position -- confirm against brax.math.
    state.metrics['total_dist'] = math.normalize(x.pos[self._torso_idx - 1])[1]
    state.metrics.update(state.info['rewards'])

    done = jp.float32(done)
    state = state.replace(
        pipeline_state=pipeline_state, obs=obs, reward=reward, done=done
    )
    return state

  def _get_obs(
      self,
      pipeline_state: base.State,
      state_info: dict[str, Any],
      obs_history: jax.Array,
  ) -> jax.Array:
    """Builds the newest observation and pushes it onto the history buffer.

    Args:
      pipeline_state: Physics state to observe.
      state_info: Info dict; 'command', 'last_act' and 'rng' are read.
      obs_history: Flat buffer of previous observations, newest first.

    Returns:
      The history buffer shifted by one observation, with the new observation
      written at the front. The buffer length is unchanged.
    """
    inv_torso_rot = math.quat_inv(pipeline_state.x.rot[0])
    local_rpyrate = math.rotate(pipeline_state.xd.ang[0], inv_torso_rot)

    obs = jp.concatenate([
        jp.array([local_rpyrate[2]]) * 0.25,                 # yaw rate
        math.rotate(jp.array([0, 0, -1]), inv_torso_rot),    # projected gravity
        state_info['command'] * jp.array([2.0, 2.0, 0.25]),  # command
        pipeline_state.q[7:] - self._default_pose,           # motor angles
        state_info['last_act'],                              # last action
    ])

    # clip, noise
    # NOTE(review): the noise is drawn with state_info['rng'] directly; the key
    # is not split inside this method.
    obs = jp.clip(obs, -100.0, 100.0) + self._obs_noise * jax.random.uniform(
        state_info['rng'], obs.shape, minval=-1, maxval=1
    )
    # stack observations through time
    # Rolling by obs.size moves the oldest entries to the front, where they
    # are then overwritten by the new observation.
    obs = jp.roll(obs_history, obs.size).at[:obs.size].set(obs)

    return obs

  # ------------ reward functions----------------
  # Each method returns the unscaled term; `step` applies the configured scale.
  def _reward_lin_vel_z(self, xd: Motion) -> jax.Array:
    """Returns the squared z component of the first body's linear velocity."""
    # Penalize z axis base linear velocity
    return jp.square(xd.vel[0, 2])

  def _reward_ang_vel_xy(self, xd: Motion) -> jax.Array:
    """Returns the sum of squared x/y angular velocities of the first body."""
    # Penalize xy axes base angular velocity
    return jp.sum(jp.square(xd.ang[0, :2]))

  def _reward_orientation(self, x: Transform) -> jax.Array:
    """Returns the squared x/y components of the base-rotated up vector."""
    # Penalize non flat base orientation
    up = jp.array([0.0, 0.0, 1.0])
    rot_up = math.rotate(up, x.rot[0])
    return jp.sum(jp.square(rot_up[:2]))

  def _reward_torques(self, torques: jax.Array) -> jax.Array:
    """Returns the L2 norm plus the L1 norm of `torques`."""
    # Penalize torques
    return jp.sqrt(jp.sum(jp.square(torques))) + jp.sum(jp.abs(torques))

  def _reward_action_rate(
      self, act: jax.Array, last_act: jax.Array
  ) -> jax.Array:
    """Returns the squared difference between the action and the previous one."""
    # Penalize changes in actions
    return jp.sum(jp.square(act - last_act))

  def _reward_tracking_lin_vel(
      self, commands: jax.Array, x: Transform, xd: Motion
  ) -> jax.Array:
    """Returns exp(-error / tracking_sigma) for the commanded x/y velocity.

    The error is the squared difference between `commands[:2]` and the base
    velocity rotated by the inverse base rotation.
    """
    # Tracking of linear velocity commands (xy axes)
    local_vel = math.rotate(xd.vel[0], math.quat_inv(x.rot[0]))
    lin_vel_error = jp.sum(jp.square(commands[:2] - local_vel[:2]))
    lin_vel_reward = jp.exp(
        -lin_vel_error / self.reward_config.rewards.tracking_sigma
    )
    return lin_vel_reward

  def _reward_tracking_ang_vel(
      self, commands: jax.Array, x: Transform, xd: Motion
  ) -> jax.Array:
    """Returns exp(-error / tracking_sigma) for the commanded yaw rate."""
    # Tracking of angular velocity commands (yaw)
    base_ang_vel = math.rotate(xd.ang[0], math.quat_inv(x.rot[0]))
    ang_vel_error = jp.square(commands[2] - base_ang_vel[2])
    return jp.exp(-ang_vel_error / self.reward_config.rewards.tracking_sigma)

  def _reward_feet_air_time(
      self, air_time: jax.Array, first_contact: jax.Array, commands: jax.Array
  ) -> jax.Array:
    """Returns the summed (air_time - 0.1) of feet flagged in `first_contact`.

    The result is zeroed unless the x/y command magnitude exceeds 0.05.
    """
    # Reward air time.
    rew_air_time = jp.sum((air_time - 0.1) * first_contact)
    rew_air_time *= (
        math.normalize(commands[:2])[1] > 0.05
    )  # no reward for zero command
    return rew_air_time

  def _reward_stand_still(
      self,
      commands: jax.Array,
      joint_angles: jax.Array,
  ) -> jax.Array:
    """Returns the summed absolute joint deviation from the default pose.

    The result is zeroed unless the x/y command magnitude is below 0.1.
    """
    # Penalize motion at zero commands
    return jp.sum(jp.abs(joint_angles - self._default_pose)) * (
        math.normalize(commands[:2])[1] < 0.1
    )

  def _reward_foot_slip(
      self, pipeline_state: base.State, contact_filt: jax.Array
  ) -> jax.Array:
    """Returns the summed squared x/y foot velocity of feet in contact.

    Args:
      pipeline_state: Physics state after the step.
      contact_filt: Per-foot contact flags; feet not in contact contribute 0.
    """
    # get velocities at feet which are offset from lower legs
    # pytype: disable=attribute-error
    pos = pipeline_state.site_xpos[self._feet_site_id]  # feet position
    feet_offset = pos - pipeline_state.xpos[self._lower_leg_body_id]
    # pytype: enable=attribute-error
    offset = base.Transform.create(pos=feet_offset)
    foot_indices = self._lower_leg_body_id - 1  # we got rid of the world body
    foot_vel = offset.vmap().do(pipeline_state.xd.take(foot_indices)).vel

    # Penalize large feet velocity for feet that are in contact with the ground.
    return jp.sum(jp.square(foot_vel[:, :2]) * contact_filt.reshape((-1, 1)))

  def _reward_termination(self, done: jax.Array, step: jax.Array) -> jax.Array:
    """Returns True when `done` is set while the step counter is below 500."""
    return done & (step < 500)

  def render(
      self, trajectory: List[base.State], camera: str | None = None,
      width: int = 240, height: int = 320,
  ) -> Sequence[np.ndarray]:
    """Renders `trajectory` via the parent class, defaulting to camera 'track'."""
    camera = camera or 'track'
    return super().render(trajectory, camera=camera, width=width, height=height)

# Register the class under the name 'joystick_go2' so that the rollout section
# can construct it with envs.get_environment.
envs.register_environment('joystick_go2', Go2JoystickEnv)

### Rollout

In [None]:
# Construct the Go2 joystick environment with its default constructor
# arguments. Note that this rebinds the global `env`.
env_name = 'joystick_go2'
env = envs.get_environment(env_name)

In [None]:
# Reconstruct the PPO inference function for the Go2 joystick policy
make_networks_factory = functools.partial(
    ppo_networks.make_ppo_networks,
    policy_hidden_layer_sizes=(128, 128, 128, 128)
)

# observation_size is a placeholder: it is only used for parameter
# initialisation, and the parameters are loaded from disk instead.
nets = make_networks_factory(observation_size=1,
                             action_size=12,
                             preprocess_observations_fn=running_statistics.normalize)

# The networks above are PPO networks, so the inference function has to come
# from ppo_networks too: it converts the policy head's distribution parameters
# into a 12-dimensional action. apg_networks.make_inference_fn would return the
# raw network output instead, which is not a valid action for the environment.
make_inference_fn = ppo_networks.make_inference_fn(nets)

In [None]:
# Load the saved Go2 policy parameters from disk.
model_path = '/content/mjx_brax_go2_policy'
params = model.load_params(model_path)

# Inference function for the Go2 joystick policy (used by the "ppo_run"
# rollout below).
ppo_inference_fn = make_inference_fn(params)

In [None]:
# @markdown Commands **only used for Barkour Env**:
x_vel = 2.0  #@param {type: "number"}
y_vel = 0.0  #@param {type: "number"}
ang_vel = 0.0  #@param {type: "number"}

# Command vector in the order [x velocity, y velocity, yaw rate]; it is passed
# to render_rollout as `the_command` in the next cell.
# NOTE(review): Go2JoystickEnv.sample_command draws x velocities from
# [-0.6, 1.5], so x_vel = 2.0 lies outside that sampling range.
the_command = jp.array([x_vel, y_vel, ang_vel])

In [None]:
# Wrap the environment with EpisodeWrapper (episode_length=1000, no action
# repeat) before rolling it out.
demo_env = envs.training.EpisodeWrapper(
    env, episode_length=1000, action_repeat=1
)

# Render a 500-step rollout of the PPO policy under the fixed command, using
# the "track" camera. reset, step and the policy are jitted first.
render_rollout(
    "ppo_run",
    jax.jit(demo_env.reset),
    jax.jit(demo_env.step),
    jax.jit(ppo_inference_fn),
    demo_env,
    n_steps=500,
    the_command=the_command,
    camera="track",
)

## Barkour vb PPO

### Environment

In [None]:
#@title Barkour vb Quadruped Env
# Root of the cloned mujoco_menagerie checkout and the Barkour vb model
# directory inside it; BarkourEnv resolves its scene file relative to
# BARKOUR_ROOT_PATH.
menagerie_path = '/content/mujoco_menagerie'
BARKOUR_ROOT_PATH = epath.Path(menagerie_path + '/google_barkour_vb')

def get_config():
  """Returns the reward configuration read by BarkourEnv.__init__.

  The result is a ConfigDict with a single `rewards` entry that holds the
  per-term `scales` and the `tracking_sigma` shared by both tracking rewards.

  NOTE(review): this replaces the `get_config` defined earlier in the Go2
  cell; both definitions hold the same values.
  """
  # Weight of every reward term. Physical quantities are in SI units unless
  # stated otherwise: joint positions in rad, positions in meters, torques in
  # Nm, time in seconds and forces in Newtons. Positive weights reward a term,
  # negative weights turn it into a penalty.
  reward_scales = config_dict.ConfigDict({
      # --- Tracking terms, computed as exp(-error^2 / tracking_sigma). ---
      # Base x-y velocity tracking (z velocity is not tracked).
      'tracking_lin_vel': 1.5,
      # Yaw-rate tracking (angular velocity about the z-axis).
      'tracking_ang_vel': 0.8,
      # --- Regularization terms: base state, joints, other behavior. ---
      # Squared base velocity along z.
      'lin_vel_z': -2.0,
      # Sum of squared base roll and pitch rates.
      'ang_vel_xy': -0.05,
      # Squared x-y components of the rotated up vector (non-flat base).
      'orientation': -5.0,
      # Joint torque regularization; _reward_torques adds the L2 norm and
      # the L1 norm of the actuator forces.
      'torques': -0.0002,
      # Smoothness of the actions, |action - last_action|^2.
      'action_rate': -0.01,
      # Encourages long swing steps; it does not encourage high clearances.
      'feet_air_time': 0.2,
      # Discourages motion at (near) zero command; _reward_stand_still sums
      # |q - q_default| over the joints.
      'stand_still': -0.5,
      # Early termination penalty.
      'termination': -1.0,
      # Foot slipping on the ground.
      'foot_slip': -0.1,
  })

  rewards = config_dict.ConfigDict({
      'scales': reward_scales,
      # Tracking reward = exp(-error^2 / sigma); sigma can be tuned.
      'tracking_sigma': 0.25,
  })

  return config_dict.ConfigDict({'rewards': rewards})


class BarkourEnv(PipelineEnv):
  """Environment for training the barkour quadruped joystick policy in MJX.

  The agent receives a command (x velocity, y velocity, yaw rate) as part of
  its observation. Its action is scaled by `action_scale` and added to the
  joint pose of the 'home' keyframe to form the motor targets. Every `step`
  advances the physics pipeline, evaluates the reward terms whose weights come
  from `get_config()`, and updates the bookkeeping stored in `state.info`.
  """

  def __init__(
      self,
      obs_noise: float = 0.05,
      action_scale: float = 0.3,
      kick_vel: float = 0.05,
      scene_file: str = 'scene_mjx.xml',
      **kwargs,
  ):
    """Loads the Barkour scene and caches the indices and constants used later.

    Args:
      obs_noise: Scale of the uniform noise in [-1, 1] added to each newly
        computed observation.
      action_scale: Multiplier applied to the action before it is added to the
        default joint pose.
      kick_vel: Scale of the random unit vector added to the first two qvel
        entries when a kick is applied.
      scene_file: Scene XML file name, resolved relative to BARKOUR_ROOT_PATH.
      **kwargs: `n_frames` overrides the number of physics steps per
        environment step. Any key ending in `_scale` overrides the reward
        scale named by the key without that suffix. Other keys are ignored.
    """
    path = BARKOUR_ROOT_PATH / scene_file
    sys = mjcf.load(path.as_posix())
    self._dt = 0.02  # this environment is 50 fps
    # Physics timestep; together with self._dt it fixes the default n_frames.
    sys = sys.tree_replace({'opt.timestep': 0.004})

    # override menagerie params for smoother policy
    # dof_damping[6:] skips the first six dofs (presumably the floating base).
    # NOTE(review): gainprm[:, 0] / biasprm[:, 1] look like the kp terms of
    # position actuators -- confirm against the scene's actuator definitions.
    sys = sys.replace(
        dof_damping=sys.dof_damping.at[6:].set(0.5239),
        actuator_gainprm=sys.actuator_gainprm.at[:, 0].set(35.0),
        actuator_biasprm=sys.actuator_biasprm.at[:, 1].set(-35.0),
    )

    # Default: control period divided by the physics timestep set above.
    n_frames = kwargs.pop('n_frames', int(self._dt / sys.opt.timestep))
    super().__init__(sys, backend='mjx', n_frames=n_frames)

    self.reward_config = get_config()
    # set custom from kwargs
    # e.g. a keyword `foo_scale=v` sets reward_config.rewards.scales['foo'] = v.
    for k, v in kwargs.items():
      if k.endswith('_scale'):
        self.reward_config.rewards.scales[k[:-6]] = v

    # Body id of the 'torso' body in the MuJoCo model.
    self._torso_idx = mujoco.mj_name2id(
        sys.mj_model, mujoco.mjtObj.mjOBJ_BODY.value, 'torso'
    )
    self._action_scale = action_scale
    self._obs_noise = obs_noise
    self._kick_vel = kick_vel
    # Full initial qpos from the 'home' keyframe; entries from index 7 onward
    # are used as the default joint pose.
    self._init_q = jp.array(sys.mj_model.keyframe('home').qpos)
    self._default_pose = sys.mj_model.keyframe('home').qpos[7:]
    # Per-leg limit triple repeated for the four legs. Used both to clip the
    # motor targets and as a termination condition on the joint angles.
    self.lowers = jp.array([-0.7, -1.0, 0.05] * 4)
    self.uppers = jp.array([0.52, 2.1, 2.1] * 4)
    # Foot sites; mj_name2id returns -1 for unknown names, hence the assert.
    feet_site = [
        'foot_front_left',
        'foot_hind_left',
        'foot_front_right',
        'foot_hind_right',
    ]
    feet_site_id = [
        mujoco.mj_name2id(sys.mj_model, mujoco.mjtObj.mjOBJ_SITE.value, f)
        for f in feet_site
    ]
    assert not any(id_ == -1 for id_ in feet_site_id), 'Site not found.'
    self._feet_site_id = np.array(feet_site_id)
    # Lower-leg bodies, listed in the same leg order as the foot sites above.
    lower_leg_body = [
        'lower_leg_front_left',
        'lower_leg_hind_left',
        'lower_leg_front_right',
        'lower_leg_hind_right',
    ]
    lower_leg_body_id = [
        mujoco.mj_name2id(sys.mj_model, mujoco.mjtObj.mjOBJ_BODY.value, l)
        for l in lower_leg_body
    ]
    assert not any(id_ == -1 for id_ in lower_leg_body_id), 'Body not found.'
    self._lower_leg_body_id = np.array(lower_leg_body_id)
    # Subtracted from the foot site height when testing for ground contact.
    self._foot_radius = 0.0175
    self._nv = sys.nv

  def sample_command(self, rng: jax.Array) -> jax.Array:
    """Samples a command [lin_vel_x, lin_vel_y, ang_vel_yaw] uniformly.

    Args:
      rng: PRNG key; it is split four ways and the first sub-key is unused.

    Returns:
      Array of three values, each drawn from the range listed below.
    """
    lin_vel_x = [-0.6, 1.5]  # min max [m/s]
    lin_vel_y = [-0.8, 0.8]  # min max [m/s]
    ang_vel_yaw = [-0.7, 0.7]  # min max [rad/s]

    _, key1, key2, key3 = jax.random.split(rng, 4)
    # The range lists above are rebound to the sampled (1,)-shaped arrays.
    lin_vel_x = jax.random.uniform(
        key1, (1,), minval=lin_vel_x[0], maxval=lin_vel_x[1]
    )
    lin_vel_y = jax.random.uniform(
        key2, (1,), minval=lin_vel_y[0], maxval=lin_vel_y[1]
    )
    ang_vel_yaw = jax.random.uniform(
        key3, (1,), minval=ang_vel_yaw[0], maxval=ang_vel_yaw[1]
    )
    new_cmd = jp.array([lin_vel_x[0], lin_vel_y[0], ang_vel_yaw[0]])
    return new_cmd

  def reset(self, rng: jax.Array) -> State:  # pytype: disable=signature-mismatch
    """Returns the initial state: 'home' pose, zero velocity, fresh command.

    Args:
      rng: PRNG key; one half is stored in the state info, the other is used
        to sample the first command.

    Returns:
      State with zero reward and done, zeroed metrics and an observation whose
      history portion is all zeros.
    """
    rng, key = jax.random.split(rng)

    pipeline_state = self.pipeline_init(self._init_q, jp.zeros(self._nv))

    # Bookkeeping carried between steps; `step` updates these entries in place.
    state_info = {
        'rng': rng,
        'last_act': jp.zeros(12),
        'last_vel': jp.zeros(12),
        'command': self.sample_command(key),
        'last_contact': jp.zeros(4, dtype=bool),
        'feet_air_time': jp.zeros(4),
        'rewards': {k: 0.0 for k in self.reward_config.rewards.scales.keys()},
        'kick': jp.array([0.0, 0.0]),
        'step': 0,
    }

    obs_history = jp.zeros(15 * 31)  # store 15 steps of history
    obs = self._get_obs(pipeline_state, state_info, obs_history)
    reward, done = jp.zeros(2)
    # One metric per reward term, plus the distance proxy updated in `step`.
    metrics = {'total_dist': 0.0}
    for k in state_info['rewards']:
      metrics[k] = state_info['rewards'][k]
    state = State(pipeline_state, obs, reward, done, metrics, state_info)  # pytype: disable=wrong-arg-types
    return state

  def step(self, state: State, action: jax.Array) -> State:  # pytype: disable=signature-mismatch
    """Applies a kick (when due), steps physics and computes reward and done.

    `state.info` and `state.metrics` are updated in place; the order of those
    updates matters because several reward terms read `state.info` before it
    is overwritten at the end of the step.

    Args:
      state: Current environment state.
      action: Policy output; scaled and added to the default joint pose.

    Returns:
      The state with the new pipeline state, observation, reward and done flag.
    """
    rng, cmd_rng, kick_noise_2 = jax.random.split(state.info['rng'], 3)

    # kick
    # A unit vector with a random direction; the comparison below zeroes it
    # unless the step counter is a multiple of push_interval.
    push_interval = 10
    kick_theta = jax.random.uniform(kick_noise_2, maxval=2 * jp.pi)
    kick = jp.array([jp.cos(kick_theta), jp.sin(kick_theta)])
    kick *= jp.mod(state.info['step'], push_interval) == 0
    qvel = state.pipeline_state.qvel  # pytype: disable=attribute-error
    qvel = qvel.at[:2].set(kick * self._kick_vel + qvel[:2])
    state = state.tree_replace({'pipeline_state.qvel': qvel})

    # physics step
    motor_targets = self._default_pose + action * self._action_scale
    motor_targets = jp.clip(motor_targets, self.lowers, self.uppers)
    pipeline_state = self.pipeline_step(state.pipeline_state, motor_targets)
    x, xd = pipeline_state.x, pipeline_state.xd

    # observation data
    # state.obs is the previous stacked observation and serves as the history.
    obs = self._get_obs(pipeline_state, state.info, state.obs)
    joint_angles = pipeline_state.q[7:]
    joint_vel = pipeline_state.qd[6:]

    # foot contact data based on z-position
    foot_pos = pipeline_state.site_xpos[self._feet_site_id]  # pytype: disable=attribute-error
    foot_contact_z = foot_pos[:, 2] - self._foot_radius
    contact = foot_contact_z < 1e-3  # a mm or less off the floor
    # Filtered contact: touching now or on the previous step.
    contact_filt_mm = contact | state.info['last_contact']
    contact_filt_cm = (foot_contact_z < 3e-2) | state.info['last_contact']
    # A foot with accumulated air time that is now in (filtered) contact.
    first_contact = (state.info['feet_air_time'] > 0) * contact_filt_mm
    state.info['feet_air_time'] += self.dt

    # done if joint limits are reached or robot is falling
    # x is indexed with `_torso_idx - 1` because it has no world-body entry.
    up = jp.array([0.0, 0.0, 1.0])
    done = jp.dot(math.rotate(up, x.rot[self._torso_idx - 1]), up) < 0
    done |= jp.any(joint_angles < self.lowers)
    done |= jp.any(joint_angles > self.uppers)
    done |= pipeline_state.x.pos[self._torso_idx - 1, 2] < 0.18

    # reward
    # Raw terms first; each one is multiplied by its configured scale below.
    rewards = {
        'tracking_lin_vel': (
            self._reward_tracking_lin_vel(state.info['command'], x, xd)
        ),
        'tracking_ang_vel': (
            self._reward_tracking_ang_vel(state.info['command'], x, xd)
        ),
        'lin_vel_z': self._reward_lin_vel_z(xd),
        'ang_vel_xy': self._reward_ang_vel_xy(xd),
        'orientation': self._reward_orientation(x),
        'torques': self._reward_torques(pipeline_state.qfrc_actuator),  # pytype: disable=attribute-error
        'action_rate': self._reward_action_rate(action, state.info['last_act']),
        'stand_still': self._reward_stand_still(
            state.info['command'], joint_angles,
        ),
        'feet_air_time': self._reward_feet_air_time(
            state.info['feet_air_time'],
            first_contact,
            state.info['command'],
        ),
        'foot_slip': self._reward_foot_slip(pipeline_state, contact_filt_cm),
        'termination': self._reward_termination(done, state.info['step']),
    }
    rewards = {
        k: v * self.reward_config.rewards.scales[k] for k, v in rewards.items()
    }
    # The total is scaled by dt and clipped to [0, 10000], so a negative sum
    # yields a reward of zero.
    reward = jp.clip(sum(rewards.values()) * self.dt, 0.0, 10000.0)

    # state management
    state.info['kick'] = kick
    state.info['last_act'] = action
    state.info['last_vel'] = joint_vel
    # Air time is zeroed for feet in (filtered) contact.
    state.info['feet_air_time'] *= ~contact_filt_mm
    state.info['last_contact'] = contact
    state.info['rewards'] = rewards
    state.info['step'] += 1
    state.info['rng'] = rng

    # sample new command if more than 500 timesteps achieved
    state.info['command'] = jp.where(
        state.info['step'] > 500,
        self.sample_command(cmd_rng),
        state.info['command'],
    )
    # reset the step counter when done
    state.info['step'] = jp.where(
        done | (state.info['step'] > 500), 0, state.info['step']
    )

    # log total displacement as a proxy metric
    # NOTE(review): element [1] of math.normalize is presumably the norm of the
    # torso position -- confirm against brax.math.
    state.metrics['total_dist'] = math.normalize(x.pos[self._torso_idx - 1])[1]
    state.metrics.update(state.info['rewards'])

    done = jp.float32(done)
    state = state.replace(
        pipeline_state=pipeline_state, obs=obs, reward=reward, done=done
    )
    return state

  def _get_obs(
      self,
      pipeline_state: base.State,
      state_info: dict[str, Any],
      obs_history: jax.Array,
  ) -> jax.Array:
    """Builds the newest observation and pushes it onto the history buffer.

    Args:
      pipeline_state: Physics state to observe.
      state_info: Info dict; 'command', 'last_act' and 'rng' are read.
      obs_history: Flat buffer of previous observations, newest first.

    Returns:
      The history buffer shifted by one observation, with the new observation
      written at the front. The buffer length is unchanged.
    """
    inv_torso_rot = math.quat_inv(pipeline_state.x.rot[0])
    local_rpyrate = math.rotate(pipeline_state.xd.ang[0], inv_torso_rot)

    obs = jp.concatenate([
        jp.array([local_rpyrate[2]]) * 0.25,                 # yaw rate
        math.rotate(jp.array([0, 0, -1]), inv_torso_rot),    # projected gravity
        state_info['command'] * jp.array([2.0, 2.0, 0.25]),  # command
        pipeline_state.q[7:] - self._default_pose,           # motor angles
        state_info['last_act'],                              # last action
    ])

    # clip, noise
    # NOTE(review): the noise is drawn with state_info['rng'] directly; the key
    # is not split inside this method.
    obs = jp.clip(obs, -100.0, 100.0) + self._obs_noise * jax.random.uniform(
        state_info['rng'], obs.shape, minval=-1, maxval=1
    )
    # stack observations through time
    # Rolling by obs.size moves the oldest entries to the front, where they
    # are then overwritten by the new observation.
    obs = jp.roll(obs_history, obs.size).at[:obs.size].set(obs)

    return obs

  # ------------ reward functions----------------
  # Each method returns the unscaled term; `step` applies the configured scale.
  def _reward_lin_vel_z(self, xd: Motion) -> jax.Array:
    """Returns the squared z component of the first body's linear velocity."""
    # Penalize z axis base linear velocity
    return jp.square(xd.vel[0, 2])

  def _reward_ang_vel_xy(self, xd: Motion) -> jax.Array:
    """Returns the sum of squared x/y angular velocities of the first body."""
    # Penalize xy axes base angular velocity
    return jp.sum(jp.square(xd.ang[0, :2]))

  def _reward_orientation(self, x: Transform) -> jax.Array:
    """Returns the squared x/y components of the base-rotated up vector."""
    # Penalize non flat base orientation
    up = jp.array([0.0, 0.0, 1.0])
    rot_up = math.rotate(up, x.rot[0])
    return jp.sum(jp.square(rot_up[:2]))

  def _reward_torques(self, torques: jax.Array) -> jax.Array:
    """Returns the L2 norm plus the L1 norm of `torques`."""
    # Penalize torques
    return jp.sqrt(jp.sum(jp.square(torques))) + jp.sum(jp.abs(torques))

  def _reward_action_rate(
      self, act: jax.Array, last_act: jax.Array
  ) -> jax.Array:
    """Returns the squared difference between the action and the previous one."""
    # Penalize changes in actions
    return jp.sum(jp.square(act - last_act))

  def _reward_tracking_lin_vel(
      self, commands: jax.Array, x: Transform, xd: Motion
  ) -> jax.Array:
    """Returns exp(-error / tracking_sigma) for the commanded x/y velocity.

    The error is the squared difference between `commands[:2]` and the base
    velocity rotated by the inverse base rotation.
    """
    # Tracking of linear velocity commands (xy axes)
    local_vel = math.rotate(xd.vel[0], math.quat_inv(x.rot[0]))
    lin_vel_error = jp.sum(jp.square(commands[:2] - local_vel[:2]))
    lin_vel_reward = jp.exp(
        -lin_vel_error / self.reward_config.rewards.tracking_sigma
    )
    return lin_vel_reward

  def _reward_tracking_ang_vel(
      self, commands: jax.Array, x: Transform, xd: Motion
  ) -> jax.Array:
    """Returns exp(-error / tracking_sigma) for the commanded yaw rate."""
    # Tracking of angular velocity commands (yaw)
    base_ang_vel = math.rotate(xd.ang[0], math.quat_inv(x.rot[0]))
    ang_vel_error = jp.square(commands[2] - base_ang_vel[2])
    return jp.exp(-ang_vel_error / self.reward_config.rewards.tracking_sigma)

  def _reward_feet_air_time(
      self, air_time: jax.Array, first_contact: jax.Array, commands: jax.Array
  ) -> jax.Array:
    """Returns the summed (air_time - 0.1) of feet flagged in `first_contact`.

    The result is zeroed unless the x/y command magnitude exceeds 0.05.
    """
    # Reward air time.
    rew_air_time = jp.sum((air_time - 0.1) * first_contact)
    rew_air_time *= (
        math.normalize(commands[:2])[1] > 0.05
    )  # no reward for zero command
    return rew_air_time

  def _reward_stand_still(
      self,
      commands: jax.Array,
      joint_angles: jax.Array,
  ) -> jax.Array:
    """Returns the summed absolute joint deviation from the default pose.

    The result is zeroed unless the x/y command magnitude is below 0.1.
    """
    # Penalize motion at zero commands
    return jp.sum(jp.abs(joint_angles - self._default_pose)) * (
        math.normalize(commands[:2])[1] < 0.1
    )

  def _reward_foot_slip(
      self, pipeline_state: base.State, contact_filt: jax.Array
  ) -> jax.Array:
    """Returns the summed squared x/y foot velocity of feet in contact.

    Args:
      pipeline_state: Physics state after the step.
      contact_filt: Per-foot contact flags; feet not in contact contribute 0.
    """
    # get velocities at feet which are offset from lower legs
    # pytype: disable=attribute-error
    pos = pipeline_state.site_xpos[self._feet_site_id]  # feet position
    feet_offset = pos - pipeline_state.xpos[self._lower_leg_body_id]
    # pytype: enable=attribute-error
    offset = base.Transform.create(pos=feet_offset)
    foot_indices = self._lower_leg_body_id - 1  # we got rid of the world body
    foot_vel = offset.vmap().do(pipeline_state.xd.take(foot_indices)).vel

    # Penalize large feet velocity for feet that are in contact with the ground.
    return jp.sum(jp.square(foot_vel[:, :2]) * contact_filt.reshape((-1, 1)))

  def _reward_termination(self, done: jax.Array, step: jax.Array) -> jax.Array:
    """Returns True when `done` is set while the step counter is below 500."""
    return done & (step < 500)

  def render(
      self, trajectory: List[base.State], camera: str | None = None,
      width: int = 240, height: int = 320,
  ) -> Sequence[np.ndarray]:
    """Renders `trajectory` via the parent class, defaulting to camera 'track'."""
    camera = camera or 'track'
    return super().render(trajectory, camera=camera, width=width, height=height)

# Register the class under the name 'barkour' so that the rollout section can
# construct it with envs.get_environment.
envs.register_environment('barkour', BarkourEnv)

### Rollout

In [None]:
# Construct the Barkour environment with its default constructor arguments.
# Note that this rebinds the global `env`.
env_name = 'barkour'
env = envs.get_environment(env_name)

In [None]:
# Reconstruct the PPO inference function for the Barkour policy
make_networks_factory = functools.partial(
    ppo_networks.make_ppo_networks,
    policy_hidden_layer_sizes=(128, 128, 128, 128)
)

# observation_size is a placeholder: it is only used for parameter
# initialisation, and the parameters are loaded from disk instead.
nets = make_networks_factory(observation_size=1,
                             action_size=12,
                             preprocess_observations_fn=running_statistics.normalize)

# The networks above are PPO networks, so the inference function has to come
# from ppo_networks too: it converts the policy head's distribution parameters
# into a 12-dimensional action. apg_networks.make_inference_fn would return the
# raw network output instead, which is not a valid action for the environment.
make_inference_fn = ppo_networks.make_inference_fn(nets)

In [None]:
# Load the saved Barkour policy parameters from disk.
model_path = '/content/mjx_brax_barkour_policy'
params = model.load_params(model_path)

# Inference function for the Barkour joystick policy (used by the
# "barkour_ppo_trot" rollout below).
ppo_inference_fn = make_inference_fn(params)

In [None]:
# @markdown Commands **only used for Barkour Env**:
x_vel = 3.0  #@param {type: "number"}
y_vel = 0.0  #@param {type: "number"}
ang_vel = 0.0  #@param {type: "number"}

# Command vector in the order [x velocity, y velocity, yaw rate]; it is passed
# to render_rollout as `the_command` in the next cell.
# NOTE(review): BarkourEnv.sample_command draws x velocities from [-0.6, 1.5],
# so x_vel = 3.0 lies outside that sampling range.
the_command = jp.array([x_vel, y_vel, ang_vel])

In [None]:
# Wrap the environment with EpisodeWrapper (episode_length=1000, no action
# repeat) before rolling it out.
demo_env = envs.training.EpisodeWrapper(
    env, episode_length=1000, action_repeat=1
)

# Render a 1000-step rollout of the PPO policy under the fixed command, using
# the "track" camera. reset, step and the policy are jitted first.
render_rollout(
    "barkour_ppo_trot",
    jax.jit(demo_env.reset),
    jax.jit(demo_env.step),
    jax.jit(ppo_inference_fn),
    demo_env,
    n_steps=1000,
    the_command=the_command,
    camera="track",
)