In [6]:
#@title Import packages for plotting and creating graphics
import time
import itertools
import numpy as np
from typing import Callable, NamedTuple, Optional, Union, List

# Graphics and plotting.
import mediapy as media
import matplotlib.pyplot as plt

# More legible printing from numpy.
np.set_printoptions(precision=3, suppress=True, linewidth=100)
#@title Import MuJoCo, MJX, and Brax
from datetime import datetime
from etils import epath
import functools
from IPython.display import HTML
from typing import Any, Dict, Sequence, Tuple, Union
import os
os.environ["MUJOCO_GL"] = "egl"
from ml_collections import config_dict


import jax
from jax import numpy as jp
import numpy as np
from flax.training import orbax_utils
from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
from orbax import checkpoint as ocp

import mujoco
from mujoco import mjx

from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.base import State as PipelineState
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model


In [7]:
path = './mujoco_playground/_src/locomotion/naov6/xmls/scene.xml'

# Make model, data, and renderer
mj_model = mujoco.MjModel.from_xml_path(path)
mj_data = mujoco.MjData(mj_model)
renderer = mujoco.Renderer(mj_model)

In [8]:
mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)

In [9]:
print(mj_data.qpos, type(mj_data.qpos))
print(mjx_data.qpos, type(mjx_data.qpos), mjx_data.qpos.devices())

[ 0.    0.    0.35  1.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.
  0.    0.    0.    0.    0.   -1.5   0.    0.    0.    1.57  0.    0.    0.    0.    0.    0.
  0.    0.    0.   -1.5   0.    0.    0.   -1.57  0.    0.    0.    0.    0.    0.    0.    0.
  0.  ] <class 'numpy.ndarray'>
[ 0.    0.    0.35  1.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.
  0.    0.    0.    0.    0.   -1.5   0.    0.    0.    1.57  0.    0.    0.    0.    0.    0.
  0.    0.    0.   -1.5   0.    0.    0.   -1.57  0.    0.    0.    0.    0.    0.    0.    0.
  0.  ] <class 'jaxlib.xla_extension.ArrayImpl'> {CudaDevice(id=0)}


In [10]:
# enable joint visualization option:
scene_option = mujoco.MjvOption()
scene_option.flags[mujoco.mjtVisFlag.mjVIS_JOINT] = True

duration = 3.8  # (seconds)
framerate = 60  # (Hz)

frames = []
mujoco.mj_resetData(mj_model, mj_data)
while mj_data.time < duration:
  mujoco.mj_step(mj_model, mj_data)
  if len(frames) < mj_data.time * framerate:
    renderer.update_scene(mj_data, scene_option=scene_option)
    pixels = renderer.render()
    frames.append(pixels)

# Simulate and display video.
media.show_video(frames, fps=framerate)

0
This browser does not support the video tag.


In [11]:

jit_step = jax.jit(mjx.step)

frames = []
mujoco.mj_resetData(mj_model, mj_data)
mjx_data = mjx.put_data(mj_model, mj_data)
while mjx_data.time < duration:
  mjx_data = jit_step(mjx_model, mjx_data)
  if len(frames) < mjx_data.time * framerate:
    mj_data = mjx.get_data(mj_model, mjx_data)
    renderer.update_scene(mj_data, scene_option=scene_option)
    pixels = renderer.render()
    frames.append(pixels)

media.show_video(frames, fps=framerate)

  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


0
This browser does not support the video tag.


In [15]:
rng = jax.random.PRNGKey(0)
rng = jax.random.split(rng, 2048)
batch = jax.vmap(lambda rng: mjx_data.replace(qpos=jax.random.uniform(rng, (49,))))(rng)

jit_step = jax.jit(jax.vmap(mjx.step, in_axes=(None, 0)))
batch = jit_step(mjx_model, batch)

print(batch.qpos)

[[0.842 0.182 0.227 ... 0.755 0.367 0.689]
 [0.007 0.021 0.581 ... 0.972 0.711 0.468]
 [0.903 0.912 0.341 ... 0.896 0.373 0.284]
 ...
 [0.481 0.43  0.798 ... 0.69  0.19  0.967]
 [0.693 0.083 0.755 ... 0.939 0.225 0.452]
 [0.544 0.422 0.599 ... 0.268 0.355 0.206]]


In [16]:
batched_mj_data = mjx.get_data(mj_model, batch)
print([d.qpos for d in batched_mj_data])

[array([0.842, 0.182, 0.227, 0.112, 0.178, 0.671, 0.711, 0.152, 0.947, 0.029, 0.099, 0.552, 0.124,
       0.594, 0.958, 0.693, 0.721, 0.318, 0.82 , 0.641, 0.263, 0.192, 0.776, 0.856, 0.81 , 0.31 ,
       0.813, 0.95 , 0.006, 0.671, 0.965, 0.12 , 0.278, 0.984, 0.789, 0.422, 0.041, 0.206, 0.009,
       0.941, 0.92 , 0.637, 0.525, 0.874, 0.416, 0.367, 0.755, 0.367, 0.689]), array([0.007, 0.021, 0.581, 0.788, 0.486, 0.26 , 0.273, 0.617, 0.095, 0.981, 0.425, 0.829, 0.097,
       0.227, 0.354, 0.671, 0.117, 0.472, 0.169, 0.025, 0.134, 0.374, 0.908, 0.792, 0.122, 0.293,
       0.79 , 0.951, 0.341, 0.521, 0.258, 0.039, 0.416, 0.822, 0.097, 0.237, 0.008, 0.047, 0.319,
       0.044, 0.023, 0.049, 0.392, 0.236, 0.653, 0.173, 0.972, 0.711, 0.468]), array([0.903, 0.912, 0.341, 0.197, 0.582, 0.456, 0.644, 0.226, 0.59 , 0.068, 0.18 , 0.957, 0.169,
       0.333, 0.792, 0.883, 0.456, 0.183, 0.916, 0.552, 0.277, 0.322, 0.285, 0.609, 0.43 , 0.951,
       0.238, 0.979, 0.833, 0.043, 0.155, 0.374, 0.177, 0

In [None]:
#@title Humanoid Env

HUMANOID_ROOT_PATH = epath.Path(epath.resource_path('mujoco')) / 'mjx/test_data/humanoid'

class Humanoid(PipelineEnv):

  def __init__(
      self,
      forward_reward_weight=1.25,
      ctrl_cost_weight=0.1,
      healthy_reward=5.0,
      terminate_when_unhealthy=True,
      healthy_z_range=(1.0, 2.0),
      reset_noise_scale=1e-2,
      exclude_current_positions_from_observation=True,
      **kwargs,
  ):
#
    mj_model = mujoco.MjModel.from_xml_path(
        (HUMANOID_ROOT_PATH / 'humanoid.xml').as_posix())
    mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
    mj_model.opt.iterations = 6
    mj_model.opt.ls_iterations = 6

    sys = mjcf.load_model(mj_model)

    physics_steps_per_control_step = 5
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)
    kwargs['backend'] = 'mjx'

    super().__init__(sys, **kwargs)

    self._forward_reward_weight = forward_reward_weight
    self._ctrl_cost_weight = ctrl_cost_weight
    self._healthy_reward = healthy_reward
    self._terminate_when_unhealthy = terminate_when_unhealthy
    self._healthy_z_range = healthy_z_range
    self._reset_noise_scale = reset_noise_scale
    self._exclude_current_positions_from_observation = (
        exclude_current_positions_from_observation
    )

  def reset(self, rng: jp.ndarray) -> State:
    """Resets the environment to an initial state."""
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -self._reset_noise_scale, self._reset_noise_scale
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    data = self.pipeline_init(qpos, qvel)

    obs = self._get_obs(data, jp.zeros(self.sys.nu))
    reward, done, zero = jp.zeros(3)
    metrics = {
        'forward_reward': zero,
        'reward_linvel': zero,
        'reward_quadctrl': zero,
        'reward_alive': zero,
        'x_position': zero,
        'y_position': zero,
        'distance_from_origin': zero,
        'x_velocity': zero,
        'y_velocity': zero,
    }
    return State(data, obs, reward, done, metrics)

  def step(self, state: State, action: jp.ndarray) -> State:
    """Runs one timestep of the environment's dynamics."""
    data0 = state.pipeline_state
    data = self.pipeline_step(data0, action)

    com_before = data0.subtree_com[1]
    com_after = data.subtree_com[1]
    velocity = (com_after - com_before) / self.dt
    forward_reward = self._forward_reward_weight * velocity[0]

    min_z, max_z = self._healthy_z_range
    is_healthy = jp.where(data.q[2] < min_z, 0.0, 1.0)
    is_healthy = jp.where(data.q[2] > max_z, 0.0, is_healthy)
    if self._terminate_when_unhealthy:
      healthy_reward = self._healthy_reward
    else:
      healthy_reward = self._healthy_reward * is_healthy

    ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

    obs = self._get_obs(data, action)
    reward = forward_reward + healthy_reward - ctrl_cost
    done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0
    state.metrics.update(
        forward_reward=forward_reward,
        reward_linvel=forward_reward,
        reward_quadctrl=-ctrl_cost,
        reward_alive=healthy_reward,
        x_position=com_after[0],
        y_position=com_after[1],
        distance_from_origin=jp.linalg.norm(com_after),
        x_velocity=velocity[0],
        y_velocity=velocity[1],
    )

    return state.replace(
        pipeline_state=data, obs=obs, reward=reward, done=done
    )

  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray
  ) -> jp.ndarray:
    """Observes humanoid body position, velocities, and angles."""
    position = data.qpos
    if self._exclude_current_positions_from_observation:
      position = position[2:]

    # external_contact_forces are excluded
    return jp.concatenate([
        position,
        data.qvel,
        data.cinert[1:].ravel(),
        data.cvel[1:].ravel(),
        data.qfrc_actuator,
    ])


envs.register_environment('humanoid', Humanoid)