In [None]:
#@title Setup Runtime

# install jax with cuda enabled
!pip install "jax[cuda12]==0.5.1"

# Install the necessary prerequisite libraries
!sudo apt install -y libx11-dev libxrandr-dev libxinerama-dev libxcursor-dev libxi-dev mesa-common-dev

# get madrona mjx and its subpackages
!mkdir modules
!git clone https://github.com/shacklettbp/madrona_mjx.git modules/madrona_mjx
!git -C modules/madrona_mjx submodule update --init --recursive

# prepare the build directory
!mkdir modules/madrona_mjx/build

# Build Madrona MJX
!mkdir modules/madrona_mjx/build
!cd modules/madrona_mjx/build && cmake -DLOAD_VULKAN=OFF .. && make -j 8

# Install Madrona MJX
!pip install -e modules/madrona_mjx

# Clone and install our Quadruped RL package
!git clone https://github.com/alexeiplatzer/unitree-go2-mjx-rl.git
!pip install -e unitree-go2-mjx-rl

In [None]:
#@title Refresh the package
repo_path = "./unitree-go2-mjx-rl"
!git -C {repo_path} pull

In [None]:
#@title Setup Session

# Configure logging
import logging
logging.basicConfig(level=logging.INFO, force=True)
logging.info("Logging switched on.")

import os
# Ensure that Madrona gets the chance to pre-allocate memory before Jax
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# On your second reading, load the compiled rendering backend to save time!
use_madrona_cache = False #@param {"type":"boolean"}
if use_madrona_cache:
    os.environ["MADRONA_MWGPU_KERNEL_CACHE"] = "modules/madrona_mjx/build/cache"

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
    with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
        f.write("""{
        "file_format_version" : "1.0.0",
        "ICD" : {
            "library_path" : "libEGL_nvidia.so.0"
        }
    }
    """)

# Configure MuJoCo to use the EGL rendering backend (requires GPU)
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl

# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags

# More legible printing from numpy.
import numpy as np
np.set_printoptions(precision=3, suppress=True, linewidth=100)

# Prepare paths: the robot's scene files live inside the cloned repository,
# results go to a local "results" directory (created if missing).
from pathlib import Path
repo_path = Path("unitree-go2-mjx-rl")
robot_name = "unitree_go2"
scenes_path = repo_path / "resources" / robot_name
results_path = Path("results")
results_path.mkdir(parents=True, exist_ok=True)

# Prepare robot config
from quadruped_mjx_rl.robots import predefined_robot_configs
robot_config = predefined_robot_configs[robot_name]()

# Prepare env config
# (module name fixed: the package is `environments`, as imported elsewhere in
# this notebook; the misspelled `envrionments` raised ModuleNotFoundError)
from quadruped_mjx_rl.environments import QuadrupedVisionBaseEnvConfig
CameraConfig = QuadrupedVisionBaseEnvConfig.ObservationConfig.CameraInputConfig
# NOTE(review): the keyword `observation_noise` receives an ObservationConfig;
# confirm against QuadrupedVisionBaseEnvConfig that this is the intended field.
# NOTE(review): the camera names used here ("frontal_ego", "terrain") differ
# from the ones rendered later in this cell ("ego_frontal", "privileged");
# verify which names the scene / config actually expect.
env_config = QuadrupedVisionBaseEnvConfig(
    observation_noise=QuadrupedVisionBaseEnvConfig.ObservationConfig(
        [
            CameraConfig("frontal_ego", True, True, True),
            CameraConfig("terrain", True, True, True),
        ]
    )
)

# Prepare scene file
init_scene_path = scenes_path / "scene_mjx_empty_arena.xml"

# Prepare the terrain
from quadruped_mjx_rl.terrain_gen import make_plain_tiled_terrain
env_model = make_plain_tiled_terrain(init_scene_path)

# Render the model from several cameras in mujoco and display the images.
import functools
from quadruped_mjx_rl.environments.rendering import (
    large_overview_camera,
    render_model,
    show_image,
)

# Bind the arguments shared by every preview render; only the camera varies.
render_cam = functools.partial(
    render_model,
    env_model=env_model,
    initial_keyframe=robot_config.initial_keyframe,
)

# All four renders are produced first (in this order), then shown in the same order.
image_overview, image_tracking, image_terrain, image_egocentric = (
    render_cam(camera=_camera)
    for _camera in (large_overview_camera(), "track", "privileged", "ego_frontal")
)
for _image in (image_overview, image_tracking, image_terrain, image_egocentric):
    show_image(_image)

# Prepare vision config
from quadruped_mjx_rl.robotic_vision import VisionConfig

# Colab form fields; the values are forwarded unchanged to VisionConfig below.
num_envs = 64 #@param {"type":"integer"}
enabled_cameras = [1, 2] #@param
enabled_geom_groups = [0, 1, 2] #@param
render_width = 64 #@param {"type":"integer"}
render_height = 64 #@param {"type":"integer"}

# One rendered environment per batch entry: the batch size equals num_envs.
vision_config = VisionConfig(
    render_width=render_width,
    render_height=render_height,
    render_batch_size=num_envs,
    enabled_cameras=enabled_cameras,
    enabled_geom_groups=enabled_geom_groups,
)

# Make the env factory
from quadruped_mjx_rl.environments import get_env_factory
from quadruped_mjx_rl.robotic_vision import get_renderer
import functools
from jax import numpy as jnp

# The renderer is created lazily by the factory; bind its fixed arguments here.
renderer_maker = functools.partial(get_renderer, vision_config=vision_config, debug=False)
env_factory = get_env_factory(
    robot_config=robot_config,
    environment_config=env_config,
    env_model=env_model,
    customize_model=True,
    vision_config=vision_config,
    renderer_maker=renderer_maker,
)
# TODO: add support for non-vision envs with a metafactory
# Log the batch size in use. `training_config` is not defined anywhere in this
# notebook (the old line raised NameError), so report num_envs, which is the
# value passed as render_batch_size to the vision config.
logging.info("Rendering batch size (num_envs): %s", num_envs)

# Put the model on the accelerator and run one forward pass on fresh data.
from mujoco import mjx
mjx_model = mjx.put_model(env_model)
mjx_data = mjx.make_data(mjx_model)
mjx_data = mjx.forward(mjx_model, mjx_data)

# Create the environment
print("Setup finished, initializing the environment...")
env = env_factory()