![MuJoCo banner](https://raw.githubusercontent.com/google-deepmind/mujoco/main/banner.png)

# <h1><center>Toy Example Tutorial</center></h1>

This notebook provides an introductory tutorial for [**the Bio-inspired Robotics Benchmark (BRB)**](https://github.com/Co-Evolve/brb), a collection of bio-inspired robotics environments implemented in MuJoCo and MuJoCo XLA (MJX). Specifically, this tutorial covers the usage of the [toy example](https://github.com/Co-Evolve/brb/tree/new-framework/brb/toy_example) environment.

**A Colab runtime with GPU acceleration is required.** If you're using a CPU-only runtime, you can switch using the menu "Runtime > Change runtime type".










# Installation

In [2]:
# Check if this points towards your conda environment
import sys
print(sys.executable)

/data/gent/432/vsc43242/conda/envs/toy-example/bin/python


In [3]:
try:
    import brb
except ImportError:
    # Install BRB from the new-framework branch if it is not available yet
    !{sys.executable} -m pip install git+https://github.com/Co-Evolve/brb@new-framework

In [4]:
#@title Check if MuJoCo installation was successful
import os
import subprocess
if subprocess.run('nvidia-smi').returncode:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU Colab runtime. '
      'Go to the Runtime menu and select "Change runtime type".')

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

# Configure MuJoCo to use the EGL rendering backend (requires GPU)
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl

try:
  print('Checking that the installation succeeded:')
  import mujoco
  mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
  # Raise the explanatory error and keep the original exception as its cause
  raise RuntimeError(
      'Something went wrong during installation. Check the shell output above '
      'for more information.\n'
      'If using a hosted Colab runtime, make sure you enable GPU acceleration '
      'by going to the Runtime menu and selecting "Change runtime type".') from e

print('Installation successful.')

Tue Jan  9 23:19:25 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A2                      On  | 00000000:3B:00.0 Off |                    0 |
|  0%   44C    P8               8W /  60W |      4MiB / 15356MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [15]:
#@title Import packages for plotting and creating graphics
import time
import itertools
import numpy as np
from typing import Callable, NamedTuple, Optional, Union, List

# Graphics and plotting.
print('Installing mediapy:')
# `command -v` prints the path to ffmpeg, or nothing at all if it is not installed
ffmpeg_v = !command -v ffmpeg
assert ffmpeg_v, "FFmpeg is needed for visualizations. Please restart the interactive HPC session but include 'ml load FFmpeg/6.0-GCCcore-12.3.0' in the 'Custom code' text box."
!{sys.executable} -m pip install -q mediapy
import mediapy as media
import matplotlib.pyplot as plt

# More legible printing from numpy.
np.set_printoptions(precision=3, suppress=True, linewidth=100)

Installing mediapy:


# Toy Example Environment


## Imports

In [16]:
from brb.toy_example.arena.arena import MJCFPlaneWithTargetArena, PlaneWithTargetArenaConfiguration
from brb.toy_example.mjc.env import ToyExampleEnvironmentConfiguration, ToyExampleMJCEnvironment
from brb.toy_example.mjx.env import ToyExampleMJXEnvironment
from brb.toy_example.morphology.morphology import MJCFToyExampleMorphology
from brb.toy_example.morphology.specification.specification import ToyExampleMorphologySpecification
from brb.toy_example.morphology.specification.default import default_toy_example_morphology_specification
from mujoco_utils.environment.mjx_env import MJXGymEnvWrapper
import gymnasium
import jax
import jax.numpy as jnp
import time

## Creating the environment

* Todo: Show Image and explain environment goal

### Morphology

In [17]:
def create_morphology(
    morphology_specification: ToyExampleMorphologySpecification
    ) -> MJCFToyExampleMorphology:
    morphology = MJCFToyExampleMorphology(
        specification=morphology_specification
        )
    return morphology
morphology_specification = default_toy_example_morphology_specification(
  num_arms=4,
  num_segments_per_arm=2
  )

### Arena

In [18]:
def create_arena(
  arena_configuration: PlaneWithTargetArenaConfiguration
  ) -> MJCFPlaneWithTargetArena:
    arena = MJCFPlaneWithTargetArena(
      configuration=arena_configuration
      )
    return arena
arena_configuration = PlaneWithTargetArenaConfiguration(
    size=(10, 10)
)


### Environment
* Todo: Explain differences between MJC and MJX.
* Todo: refer to gymnasium api: https://gymnasium.farama.org/
  * explain reset, step, render, close
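
As a rough illustration of the Gymnasium calling convention mentioned above, the sketch below runs one episode through `reset`, `step`, and `close`. `CountdownEnv` is a hypothetical stand-in written only for this example; it is not part of BRB or Gymnasium and only mimics the method signatures (`reset` returns `(observation, info)`, `step` returns `(observation, reward, terminated, truncated, info)`).

```python
from typing import Any, Dict, Optional, Tuple


class CountdownEnv:
    """Hypothetical stand-in that mimics the Gymnasium method signatures."""

    def __init__(self, max_steps: int = 5) -> None:
        self.max_steps = max_steps
        self.steps = 0
        self.closed = False

    def reset(self, seed: Optional[int] = None) -> Tuple[float, Dict[str, Any]]:
        # Start a new episode and return the first observation plus an info dict
        self.steps = 0
        return 0.0, {"time": 0.0}

    def step(self, action: float) -> Tuple[float, float, bool, bool, Dict[str, Any]]:
        # Advance one control step
        self.steps += 1
        observation = float(self.steps)
        reward = -abs(action)
        terminated = False                         # task-specific end condition
        truncated = self.steps >= self.max_steps   # time limit reached
        return observation, reward, terminated, truncated, {"time": self.steps * 0.01}

    def render(self) -> None:
        # A real environment would return an image here
        return None

    def close(self) -> None:
        # Release resources such as renderers
        self.closed = True


env = CountdownEnv(max_steps=5)
obs, info = env.reset(seed=0)
done = False
num_steps = 0
while not done:
    action = 0.0  # a controller would compute this from obs or info
    obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated
    num_steps += 1
env.close()
print(num_steps)
```

The rollout functions later in this notebook follow the same loop, with the action coming from an open-loop controller that reads `info["time"]`.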


In [19]:
environment_configuration = ToyExampleEnvironmentConfiguration(
    target_distance=3.0,
    randomization_noise_scale=0.0,
    num_physics_steps_per_control_step=5,
    simulation_time=5.0,
    camera_ids=[0, 1],
    render_mode="rgb_array"
)

#### MJC

In [20]:
def create_mjc_environment(
    morphology_specification: ToyExampleMorphologySpecification,
    arena_configuration: PlaneWithTargetArenaConfiguration,
    environment_configuration: ToyExampleEnvironmentConfiguration,
    num_envs: int = 1
    ) -> Union[ToyExampleMJCEnvironment, gymnasium.vector.AsyncVectorEnv]:
    def _create_env() -> ToyExampleMJCEnvironment:
        morphology = create_morphology(
            morphology_specification=morphology_specification
        )
        arena = create_arena(
            arena_configuration=arena_configuration
        )
        env = ToyExampleMJCEnvironment(
                morphology=morphology,
                arena=arena,
                configuration=environment_configuration
                )
        return env


    if num_envs == 1:
        env = _create_env()
    else:
        env = gymnasium.vector.AsyncVectorEnv(env_fns=[_create_env for _ in range(num_envs)])

    return env

#### MJX

In [21]:
def create_mjx_environment(
    morphology_specification: ToyExampleMorphologySpecification,
    arena_configuration: PlaneWithTargetArenaConfiguration,
    environment_configuration: ToyExampleEnvironmentConfiguration,
    num_envs: int = 1
    ) -> MJXGymEnvWrapper:
    morphology = create_morphology(
        morphology_specification=morphology_specification
    )
    arena = create_arena(
        arena_configuration=arena_configuration
    )
    env = ToyExampleMJXEnvironment(
            morphology=morphology,
            arena=arena,
            configuration=environment_configuration
            )
    env = MJXGymEnvWrapper(env=env, num_envs=num_envs)
    return env

In [22]:
mjc_env = create_mjc_environment(
    morphology_specification=morphology_specification,
    arena_configuration=arena_configuration,
    environment_configuration=environment_configuration
)
mjx_env = create_mjx_environment(
    morphology_specification=morphology_specification,
    arena_configuration=arena_configuration,
    environment_configuration=environment_configuration
)

#### Observation space

In [23]:
print(mjc_env.observation_space)
print()
print(mjx_env.observation_space)
# todo: show vectorized observation space

Dict('in_plane_joint_position': Box(-0.34906584, 0.34906584, (8,), float32), 'in_plane_joint_velocity': Box(-inf, inf, (8,), float32), 'out_of_plane_joint_position': Box(-0.34906584, 0.34906584, (8,), float32), 'out_of_plane_joint_velocity': Box(-inf, inf, (8,), float32), 'segment_ground_contact': Box(0.0, inf, (8,), float32), 'torso_angular_velocity': Box(-inf, inf, (3,), float32), 'torso_linear_velocity': Box(-inf, inf, (3,), float32), 'torso_rotation': Box(-3.1415927, 3.1415927, (3,), float32), 'unit_xy_direction_to_target': Box(-1.0, 1.0, (2,), float32), 'xy_distance_to_target': Box(0.0, inf, (1,), float32))

Dict('in_plane_joint_position': Box(-0.34906584, 0.34906584, (8,), float32), 'in_plane_joint_velocity': Box(-inf, inf, (8,), float32), 'out_of_plane_joint_position': Box(-0.34906584, 0.34906584, (8,), float32), 'out_of_plane_joint_velocity': Box(-inf, inf, (8,), float32), 'segment_ground_contact': Box(0.0, 1.0, (8,), float32), 'torso_angular_velocity': Box(-inf, inf, (3,), flo

#### Action space

In [24]:
print(mjc_env.action_space)
print()
print(mjx_env.action_space)
# todo: show vectorized action space

Box(-0.34906584, 0.34906584, (16,), float32)

Box(-0.34906584, 0.34906584, (16,), float32)


To avoid memory leaks, always close your environment after use!

In [25]:
mjc_env.close()
mjx_env.close()
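
One way to make sure `close()` is called even when a rollout raises an exception is `contextlib.closing` from the standard library. The sketch below uses a hypothetical `DummyEnv` stand-in; it assumes only that the environment exposes a `close()` method, as the environments in this notebook do.

```python
from contextlib import closing


class DummyEnv:
    """Hypothetical stand-in for an environment that exposes close()."""

    def __init__(self) -> None:
        self.closed = False

    def close(self) -> None:
        self.closed = True


env = DummyEnv()
try:
    # closing() calls env.close() when the block exits, even on an error
    with closing(env):
        raise RuntimeError("simulated failure during a rollout")
except RuntimeError:
    pass

print(env.closed)  # True
```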

# Rollouts


* Todo: explain what a rollout is.
* Todo: explain that we will use MJX with many sub-environments, which allows fast data collection in randomized environments and with stochastic controllers.

## MJC

### Defining a simple open-loop controller

In [26]:
def create_mjc_open_loop_controller(
        single_action_space: gymnasium.spaces.Box,
        num_envs: int
        ) -> Callable[[float], np.ndarray]:
    def open_loop_controller(
            t: float
            ) -> np.ndarray:
        actions = np.ones(single_action_space.shape)
        actions[::2] = np.cos(5 * t)
        actions[1::2] = np.sin(5 * t)
        actions[-actions.shape[0] // 2::2] *= -1
        return actions

    if num_envs > 1:
        # Vectorized environments pass one time value per sub-environment
        def batched_open_loop_controller(
                t: np.ndarray
                ) -> np.ndarray:
            return np.stack([open_loop_controller(tt) for tt in t])

        return batched_open_loop_controller

    return open_loop_controller
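
To see what this controller produces, the sketch below re-implements the same pattern for a 16-dimensional action vector (the size reported by the action space above) and evaluates it at `t = 0` and for a small batch of time values. The `num_actuators` parameter is introduced only for this example so that it runs without an environment.

```python
import numpy as np


def open_loop_controller(t: float, num_actuators: int = 16) -> np.ndarray:
    # Same pattern as the controller above, with the action size passed directly
    actions = np.ones(num_actuators)
    actions[::2] = np.cos(5 * t)             # even indices follow a cosine
    actions[1::2] = np.sin(5 * t)            # odd indices follow a sine
    actions[-num_actuators // 2::2] *= -1    # flip even indices in the second half
    return actions


single = open_loop_controller(0.0)
print(single)

# One time value per sub-environment gives one action row per sub-environment
times = np.array([0.0, 0.1, 0.2])
batched = np.stack([open_loop_controller(t) for t in times])
print(batched.shape)
```

At `t = 0` the cosine entries are 1 and the sine entries are 0, so the first half of the vector alternates `1, 0` and the second half alternates `-1, 0`.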

### Doing a rollout

In [51]:
def do_mjc_rollout(
    env: Union[ToyExampleMJCEnvironment, gymnasium.vector.AsyncVectorEnv],
    return_frames: bool,
    fps: int = 30
    ) -> List[np.ndarray]:

    try:
        num_envs = env.num_envs
        single_action_space = env.single_action_space
        render_fn = lambda: env.call(name="render")
    except AttributeError:
        num_envs = 1
        single_action_space = env.action_space
        render_fn = env.render
        

    controller = create_mjc_open_loop_controller(
        single_action_space=single_action_space, num_envs=num_envs
    )

    done = False
    steps = 0
    frames = []
    
    obs, info = env.reset()
    
    start_time = time.time()
    while not done:
        ts = info["time"]
        actions = controller(ts)

        obs, reward, terminated, truncated, info = env.step(
            actions=actions
            )

        done = np.any(terminated | truncated)
        if return_frames:
          if steps % round((1 / fps) / environment_configuration.control_timestep) == 0:
              frame = render_fn()
              if num_envs == 1:
                  frame = [frame]
                  
              if len(environment_configuration.camera_ids) > 1:
                  # If we have multiple cameras, stack their images horizontally
                  frame = [np.concatenate(env_frames, axis=1) for env_frames in frame]

              # If we have multiple environments, stack their frames vertically
              frame = np.concatenate(frame, axis=0)

              frame = frame[:, :, ::-1] # rgb to bgr
              frames.append(frame)

        steps += 1

    stop_time = time.time()

    total_steps = steps * num_envs * \
      environment_configuration.num_physics_steps_per_control_step
    total_time = stop_time - start_time

    print(f"Did {total_steps} env steps in {round(total_time, 2)}s")
    print(f"\tSPS: {round(total_steps / total_time, 2)}")

    return frames
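
The frame-selection condition `steps % round((1 / fps) / control_timestep) == 0` can be checked with a small worked example. The value `control_timestep = 0.01` is an assumption inferred from the outputs in this notebook (about 500 control steps for 5.0 s of simulation); it is not read from the environment. The `max(1, ...)` guard is an addition for this sketch: without it, a control timestep well above `1 / fps` would round the stride to 0 and the modulo would raise a `ZeroDivisionError`.

```python
def control_steps_per_frame(fps: int, control_timestep: float) -> int:
    """Number of control steps between two rendered frames (at least 1)."""
    return max(1, round((1 / fps) / control_timestep))


# Assumed value: 5.0 s of simulation over roughly 500 control steps
control_timestep = 0.01

stride = control_steps_per_frame(fps=30, control_timestep=control_timestep)

# 501 control steps, as implied by the "2505 env steps" output (2505 / 5)
rendered_steps = [step for step in range(501) if step % stride == 0]
playback_seconds = len(rendered_steps) / 30

print(stride)
print(len(rendered_steps))
print(round(playback_seconds, 2))
```

Under these assumptions the stride is 3 control steps and 167 frames are collected, so a 30 fps video plays for about 5.6 s, slightly longer than the 5 s of simulated time, because the stride is rounded down from 3.33.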

In [52]:
mjc_env = create_mjc_environment(
    morphology_specification=morphology_specification,
    arena_configuration=arena_configuration,
    environment_configuration=environment_configuration,
)
do_mjc_rollout(
    env=mjc_env,
    return_frames=False
    )
mjc_env.close()

Did 2505 env steps in 0.23s
	SPS: 10761.83


In [53]:
# todo: explain why it's possible that we don't see an SPS increase here -> due to the number of cores / threads available / OMP settings
mjc_env = create_mjc_environment(
    morphology_specification=morphology_specification,
    arena_configuration=arena_configuration,
    environment_configuration=environment_configuration,
    num_envs=5
)
do_mjc_rollout(
    env=mjc_env,
    return_frames=False
    )
mjc_env.close()

Did 12525 env steps in 1.19s
	SPS: 10491.66
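
Since the todo above points to the number of available cores as a possible reason for the flat SPS, a quick check of how many cores the current process may use can help. This is a minimal sketch using only the standard library. `os.sched_getaffinity` is not available on every platform, so the code falls back to `os.cpu_count()`. On a shared cluster the usable count can be lower than the machine total.

```python
import os

# Total logical cores of the machine (may be None if undetermined)
logical_cores = os.cpu_count() or 1

# Cores this process is actually allowed to run on (available on Linux)
if hasattr(os, "sched_getaffinity"):
    usable_cores = len(os.sched_getaffinity(0))
else:
    usable_cores = logical_cores

print(f"Logical cores: {logical_cores}")
print(f"Usable cores:  {usable_cores}")
```

If the usable count is smaller than `num_envs`, the worker processes of the vectorized environment have to share cores, which would be consistent with the nearly identical SPS values above.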


### Visualizing a rollout

In [54]:
mjc_env = create_mjc_environment(
    morphology_specification=morphology_specification,
    arena_configuration=arena_configuration,
    environment_configuration=environment_configuration,
    num_envs=1
)
frames = do_mjc_rollout(
    env=mjc_env,
    return_frames=True
    )
mjc_env.close()

Did 2505 env steps in 0.66s
	SPS: 3784.42


In [55]:
media.show_video(frames, fps=30)



## MJX

### Defining a simple open-loop controller

In [56]:
def create_mjx_open_loop_controller(
        single_action_space: gymnasium.spaces.Box,
        num_envs: int
        ) -> Callable[[float], jnp.ndarray]:
    def open_loop_controller(
            t: float
            ) -> jnp.ndarray:
        actions = jnp.ones(single_action_space.shape)
        actions = actions.at[jnp.arange(0, len(actions), 2)].set(jnp.cos(5 * t))
        actions = actions.at[jnp.arange(1, len(actions), 2)].set(jnp.sin(5 * t))
        actions = actions.at[jnp.arange(len(actions) // 2, len(actions), 2)].set(
                actions[jnp.arange(len(actions) // 2, len(actions), 2)] * -1
                )
        return actions

    if num_envs > 1:
        open_loop_controller = jax.vmap(open_loop_controller)

    open_loop_controller = jax.jit(open_loop_controller)

    return open_loop_controller

### Doing a rollout

In [57]:
def do_rollout(
    env: MJXGymEnvWrapper,
    return_frames: bool,
    fps: int = 30
    ) -> List[np.ndarray]:
    controller = create_mjx_open_loop_controller(
        single_action_space=env.single_action_space,
        num_envs=env.number_of_environments
        )

    done = False
    steps = 0
    frames = []
    start_time = None

    obs, info = env.reset()
    while not done:
        t = info["time"]
        actions = controller(t)
        obs, reward, terminated, truncated, info = env.step(actions=actions)
        done = jnp.any((terminated | truncated))

        if start_time is None:
            # Start timing only after the first step, so that JIT compilation
            # of that step is not included in the measured time
            start_time = time.time()

        if return_frames:
          if steps % round((1 / fps) / environment_configuration.control_timestep) == 0:
              frame = env.render()

              if env.number_of_environments > 1:
                  # Only save frame of first environment
                  frame = frame[0]

              if len(environment_configuration.camera_ids) > 1:
                  # If we have multiple cameras, stack their images horizontally
                  frame = np.concatenate(frame, axis=1)

              frame = frame[:, :, ::-1] # rgb to bgr
              frames.append(frame)

        steps += 1
    stop_time = time.time()

    total_steps = steps * env.number_of_environments * \
      environment_configuration.num_physics_steps_per_control_step
    total_time = stop_time - start_time

    print(f"Did {total_steps} env steps in {total_time:2f}s")
    print(f"\tSPS: {(total_steps / total_time):2f}")

    return frames

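* The first steps take a while because JAX JIT-compiles the functions involved. This will not be an issue when optimizing a controller, since the compiled functions are then reused over many steps.
* With many parallel environments, we should see a decent increase in SPS compared to MJC.
* A GPU-based simulation environment goes nicely with a GPU-based RL algorithm, neuroevolution method, or controller, because no data has to be transferred between CPU and GPU.
* The current HPC cluster is meant for implementation and debugging; we will get even higher SPS values when using better clusters (joltik, accelgor).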

In [59]:
mjx_env = create_mjx_environment(
    morphology_specification=morphology_specification,
    arena_configuration=arena_configuration,
    environment_configuration=environment_configuration,
    num_envs=1280 # A2 GPU has 1280 cuda cores
)
do_rollout(
    env=mjx_env,
    return_frames=False
    )
mjx_env.close()

Did 3200000 env steps in 96.473444s
	SPS: 33169.749889
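
The SPS figure above can be reproduced by hand from the formula used in the rollout functions: control steps × number of environments × physics steps per control step, divided by wall-clock time. The sketch below plugs in the numbers printed in this notebook; the count of 500 control steps is inferred from the reported total (3200000 / (1280 × 5)), and the MJC reference is the single-environment run shown earlier.

```python
def steps_per_second(
    control_steps: int,
    num_envs: int,
    physics_steps_per_control_step: int,
    wall_time_s: float,
) -> float:
    """Total simulated physics steps divided by the wall-clock time."""
    total_steps = control_steps * num_envs * physics_steps_per_control_step
    return total_steps / wall_time_s


# Values taken from the outputs recorded in this notebook
mjx_total_steps = 500 * 1280 * 5
mjx_sps = steps_per_second(500, 1280, 5, 96.473444)
mjc_sps = 10761.83  # single-environment MJC rollout

speedup = mjx_sps / mjc_sps

print(mjx_total_steps)
print(round(mjx_sps, 2))
print(round(speedup, 2))
```

With these recorded values, the MJX run with 1280 environments reaches roughly three times the throughput of the single-environment MJC run on this hardware.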


### Visualizing a rollout

In [60]:
mjx_env = create_mjx_environment(
    morphology_specification=morphology_specification,
    arena_configuration=arena_configuration,
    environment_configuration=environment_configuration,
    num_envs=1 # Rendering too many environments can cause RAM overflows
)
frames = do_rollout(
    env=mjx_env,
    return_frames=True
    )
mjx_env.close()

Did 2500 env steps in 41.647275s
	SPS: 60.027937


In [61]:
media.show_video(frames, fps=30)



# Conclusions


# Next steps