In [1]:
#@title Setup Runtime

# install jax with cuda enabled
!pip install "jax[cuda12]==0.5.1"

# Install the necessary prerequisite libraries
!sudo apt install -y libx11-dev libxrandr-dev libxinerama-dev libxcursor-dev libxi-dev mesa-common-dev

# get madrona mjx and its subpackages
!mkdir modules
!git clone https://github.com/shacklettbp/madrona_mjx.git modules/madrona_mjx
!git -C modules/madrona_mjx submodule update --init --recursive

# prepare the build directory
!mkdir modules/madrona_mjx/build

# Build Madrona MJX
!mkdir modules/madrona_mjx/build
!cd modules/madrona_mjx/build && cmake -DLOAD_VULKAN=OFF .. && make -j 8

# Install Madrona MJX
!pip install -e modules/madrona_mjx

# Clone and install our Quadruped RL package
!git clone https://github.com/alexeiplatzer/unitree-go2-mjx-rl.git
!pip install -e unitree-go2-mjx-rl

Collecting jax==0.5.1 (from jax[cuda12]==0.5.1)
  Downloading jax-0.5.1-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib<=0.5.1,>=0.5.1 (from jax==0.5.1->jax[cuda12]==0.5.1)
  Downloading jaxlib-0.5.1-cp312-cp312-manylinux2014_x86_64.whl.metadata (978 bytes)
Collecting jax-cuda12-plugin<=0.5.1,>=0.5.1 (from jax-cuda12-plugin[with_cuda]<=0.5.1,>=0.5.1; extra == "cuda12"->jax[cuda12]==0.5.1)
  Downloading jax_cuda12_plugin-0.5.1-cp312-cp312-manylinux2014_x86_64.whl.metadata (1.2 kB)
Collecting jax-cuda12-pjrt==0.5.1 (from jax-cuda12-plugin<=0.5.1,>=0.5.1->jax-cuda12-plugin[with_cuda]<=0.5.1,>=0.5.1; extra == "cuda12"->jax[cuda12]==0.5.1)
  Downloading jax_cuda12_pjrt-0.5.1-py3-none-manylinux2014_x86_64.whl.metadata (348 bytes)
Downloading jax-0.5.1-py3-none-any.whl (2.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.4/2.4 MB[0m [31m28.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jaxlib-0.5.1-cp312-cp312-manylinux2014_x86_64.whl (105.2 MB)
[2K   [90m━━━━━━

In [None]:
#@title Refresh the package
repo_path = "./unitree-go2-mjx-rl"
!git -C {repo_path} pull

Already up to date.


In [1]:
#@title Setup Session

# Configure logging
import logging
logging.basicConfig(level=logging.INFO, force=True)
logging.info("Logging switched on.")

import os
# Ensure that Madrona gets the chance to pre-allocate memory before Jax
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# On your second reading, load the compiled rendering backend to save time!
use_madrona_cache = False #@param {"type":"boolean"}
if use_madrona_cache:
    os.environ["MADRONA_MWGPU_KERNEL_CACHE"] = "modules/madrona_mjx/build/cache"

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
    with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
        f.write("""{
        "file_format_version" : "1.0.0",
        "ICD" : {
            "library_path" : "libEGL_nvidia.so.0"
        }
    }
    """)

# Configure MuJoCo to use the EGL rendering backend (requires GPU)
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl

# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags

# More legible printing from numpy.
import numpy as np
np.set_printoptions(precision=3, suppress=True, linewidth=100)

# Prepare paths
from pathlib import Path
repo_path = Path("unitree-go2-mjx-rl")
robot_name = "unitree_go2"
scenes_path = repo_path / "resources" / robot_name
results_path = Path("results")
results_path.mkdir(parents=True, exist_ok=True)

# Relevant imports
import functools
import jax
from jax import numpy as jnp

# Prepare robot config
from quadruped_mjx_rl.robots import predefined_robot_configs
robot_config = predefined_robot_configs[robot_name]()

# Prepare env config
from quadruped_mjx_rl.environments import QuadrupedVisionBaseEnvConfig
CameraConfig = QuadrupedVisionBaseEnvConfig.ObservationConfig.CameraInputConfig
env_config = QuadrupedVisionBaseEnvConfig(
    observation_noise=QuadrupedVisionBaseEnvConfig.ObservationConfig(
        camera_inputs=[
            CameraConfig("frontal_ego", True, True, True),
            CameraConfig("terrain", True, True, True),
        ]
    )
)

# Prepare scene file
init_scene_path = scenes_path / "scene_mjx_empty_arena.xml"

# Prepare the terrain
from quadruped_mjx_rl.terrain_gen import make_plain_tiled_terrain
env_model = make_plain_tiled_terrain(init_scene_path)

# Render the environments from different cameras in mujoco
from quadruped_mjx_rl.environments.rendering import (
    render_model,
    show_image,
    large_overview_camera,
)
render_cam = functools.partial(
    render_model, env_model=env_model, initial_keyframe=robot_config.initial_keyframe
)
image_overview = render_cam(camera=large_overview_camera())
image_tracking = render_cam(camera="track")
image_terrain = render_cam(camera="privileged")
image_egocentric = render_cam(camera="ego_frontal")
# show_image(image_overview)
# show_image(image_tracking)
# show_image(image_terrain)
# show_image(image_egocentric)

# Prepare vision config
from quadruped_mjx_rl.robotic_vision import VisionConfig
num_envs=64 #@param {"type":"integer"}
enabled_cameras=[1, 2] #@param
enabled_geom_groups=[0, 1, 2] #@param
render_width=64 #@param {"type":"integer"}
render_height=64 #@param {"type":"integer"}
vision_config = VisionConfig(
    render_batch_size=num_envs,
    enabled_cameras=enabled_cameras,
    enabled_geom_groups=enabled_geom_groups,
    render_width=render_width,
    render_height=render_height,
)

# Make the env factory
from quadruped_mjx_rl.environments import get_env_factory
from quadruped_mjx_rl.robotic_vision import get_renderer
renderer_maker = functools.partial(get_renderer, vision_config=vision_config, debug=False)
env_factory = get_env_factory(
    robot_config=robot_config,
    environment_config=env_config,
    env_model=env_model,
    customize_model=True,
    vision_config=vision_config,
    renderer_maker=renderer_maker,
)

# Execute one environment step to initialize mjx
from mujoco import mjx
mjx_model = mjx.put_model(env_model)
mjx_data = mjx.make_data(mjx_model)
mjx_data = mjx.forward(mjx_model, mjx_data)

# Create the environment
print("Setup finished, initializing the environment...")
env = env_factory()

INFO:root:Logging switched on.


Setting environment variable to use GPU rendering:
env: MUJOCO_GL=egl


INFO:OpenGL.acceleratesupport:No OpenGL_accelerate module loaded: No module named 'OpenGL_accelerate'
INFO:2025-11-12 12:12:32,935:jax._src.xla_bridge:924: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:2025-11-12 12:12:32,939:jax._src.xla_bridge:924: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:root:Using JAX default device: cuda:0.
INFO:root:MJX Warp is disabled via MJX_WARP_ENABLED=false.
INFO:root:Using JAX default device: cuda:0.
INFO:root:MJX Warp is disabled via MJX_WARP_ENABLED=false.
INFO:root:Using JAX default device: cuda:0

Setup finished, initializing the environment...


In [2]:
import jax
from jax import numpy as jnp
import mujoco

from quadruped_mjx_rl.environments.physics_pipeline import (
    EnvModel,
    PipelineModel,
)


def collect_tile_ids(env_model: EnvModel, tile_body_prefix: str = "tile_") -> jax.Array:
    """Collects the ids of geoms that belong to tile bodies.

    Args:
        env_model: Compiled environment model containing the tiled terrain.
        tile_body_prefix: Prefix that identifies bodies representing tiles.

    Returns:
        An array with the ids of all geoms that belong to bodies with the provided prefix.
    """

    geom_ids: list[int] = []
    for body_id in range(env_model.nbody):
        name = mujoco.mj_id2name(env_model, mujoco.mjtObj.mjOBJ_BODY, body_id)
        if name is None or not name.startswith(tile_body_prefix):
            continue
        first_geom = env_model.body_geomadr[body_id]
        geom_ids.append(first_geom)

    if not geom_ids:
        raise ValueError(
            "Unable to locate any tile geoms. Ensure that the terrain builder "
            f"creates bodies with the prefix '{tile_body_prefix}'."
        )

    return jnp.array(geom_ids)


def randomize_tiles(
    pipeline_model: PipelineModel,
    rng: jax.Array,
    num_worlds: int,
    env_model: EnvModel,
    tile_body_prefix: str = "tile_",
):
    """Randomizes the mjx.Model. Assumes the ground is a square grid of square tiles."""
    tile_geom_ids = collect_tile_ids(env_model, tile_body_prefix)
    print(tile_geom_ids)
    num_variants = 2
    rgbas = jnp.array([[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0]])
    friction_min = 0.6
    friction_max = 1.4
    solref_min = 0.002
    solref_max = 0.05

    @jax.vmap
    def rand(rng):
        key_tiles, key_friction, key_solref = jax.random.split(rng, 3)

        colour_friction = jax.random.uniform(
            key_friction,
            shape=(num_variants,),
            minval=friction_min,
            maxval=friction_max,
        )
        colour_solref = jax.random.uniform(
            key_solref,
            shape=(num_variants,),
            minval=solref_min,
            maxval=solref_max,
        )
        tile_colour_indices = jax.random.randint(
            key_tiles,
            shape=(tile_geom_ids.shape[0],),
            minval=0,
            maxval=num_variants,
        )

        chosen_colors = rgbas[tile_colour_indices]
        chosen_frictions = colour_friction[tile_colour_indices]
        chosen_solrefs = colour_solref[tile_colour_indices]

        geom_rgba = pipeline_model.model.geom_rgba.at[tile_geom_ids].set(jnp.array([1.0, 0.0, 0.0, 1.0]))
        geom_friction = pipeline_model.model.geom_friction.at[tile_geom_ids, 0].set(
            chosen_frictions
        )
        geom_solref = pipeline_model.model.geom_solref.at[tile_geom_ids, 0].set(chosen_solrefs)

        return geom_rgba, geom_friction, geom_solref

    rgba, friction, solref = rand(rng)

    in_axes = jax.tree.map(lambda x: None, pipeline_model)
    in_axes = in_axes.replace(
        model=in_axes.model.tree_replace(
            {
                "geom_matid": 0,
                "geom_rgba": 0,
                "geom_friction": 0,
                "geom_solref": 0,
            }
        )
    )

    pipeline_model = pipeline_model.replace(
        model=pipeline_model.model.tree_replace(
            {
                "geom_matid": jnp.repeat(
                    jnp.expand_dims(
                        jnp.repeat(-1, pipeline_model.model.geom_matid.shape[0], 0), 0
                    ),
                    num_worlds,
                    axis=0,
                ),
                "geom_rgba": rgba,
                "geom_friction": friction,
                "geom_solref": solref,
            }
        )
    )

    return pipeline_model, in_axes


In [3]:
# Wrap the environment
from quadruped_mjx_rl.environments.wrappers import wrap_for_training
#from quadruped_mjx_rl.domain_randomization.randomized_tiles import randomize_tiles
rng_key = jax.random.PRNGKey(0)
domain_rand_key, reset_key = jax.random.split(rng_key, 2)
env_keys = jax.random.split(domain_rand_key, num_envs)
env = wrap_for_training(
    env=env,
    vision=True,
    num_vision_envs=num_envs,
    randomization_fn=functools.partial(randomize_tiles, num_worlds=num_envs, rng=env_keys, env_model=env_model)
)
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

[37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52]


In [5]:
print(env._sys_v.model.geom_rgba[0][39])

[1. 0. 0. 1.]


In [6]:
# Execute one step
state = jit_reset(jax.random.split(reset_key, num_envs))
state = jit_step(state, jnp.zeros((num_envs, env.action_size)))

INFO:root:Using JAX default device: cuda:0.
INFO:root:MJX Warp is disabled via MJX_WARP_ENABLED=false.


In [10]:
state.obs["pixels/frontal_ego/rgb"][0]

Array([[[0.   , 0.   , 0.   ],
        [0.   , 0.   , 0.   ],
        [0.   , 0.   , 0.   ],
        ...,
        [0.   , 0.   , 0.   ],
        [0.   , 0.   , 0.   ],
        [0.   , 0.   , 0.   ]],

       [[0.   , 0.   , 0.   ],
        [0.   , 0.   , 0.   ],
        [0.   , 0.   , 0.   ],
        ...,
        [0.   , 0.   , 0.   ],
        [0.   , 0.   , 0.   ],
        [0.   , 0.   , 0.   ]],

       [[0.   , 0.   , 0.   ],
        [0.   , 0.   , 0.   ],
        [0.   , 0.   , 0.   ],
        ...,
        [0.   , 0.   , 0.   ],
        [0.   , 0.   , 0.   ],
        [0.   , 0.   , 0.   ]],

       ...,

       [[1.   , 1.   , 0.863],
        [1.   , 1.   , 0.863],
        [1.   , 1.   , 0.863],
        ...,
        [1.   , 1.   , 0.863],
        [1.   , 1.   , 0.863],
        [1.   , 1.   , 0.863]],

       [[1.   , 1.   , 0.863],
        [1.   , 1.   , 0.863],
        [1.   , 1.   , 0.863],
        ...,
        [1.   , 1.   , 0.863],
        [1.   , 1.   , 0.863],
        [1.   ,

In [7]:
frontal_view = state.obs["pixels/frontal_ego/rgb"]
print(frontal_view.shape)
terrain_view = state.obs["pixels/terrain/rgb"]
print(terrain_view.shape)

(64, 64, 64, 3)
(64, 64, 64, 3)


In [8]:
def tile(img, d):
    assert img.shape[0] == d*d
    img = img.reshape((d,d)+img.shape[1:])
    return np.concat(np.concat(img, axis=1), axis=1)  # replace with 2 for multi-camera tensors!

# image = tile(rgb_tensor[:16], 4)
# image.shape
frontal_view_image = tile(frontal_view[:16], 4)
print(frontal_view_image.shape)
terrain_view_image = tile(terrain_view[:16], 4)
print(terrain_view_image.shape)

(256, 256, 3)
(256, 256, 3)


In [9]:
import mediapy as media
media.show_image(frontal_view_image, width=512)

In [11]:
media.show_image(terrain_view_image, width=512)