In [1]:
import os
import subprocess
import logging

try:
    # Fail fast (into the except below) when no NVIDIA GPU is reachable.
    if subprocess.run('nvidia-smi').returncode:
        raise RuntimeError(
                'Cannot communicate with GPU. '
                'Make sure you are using a GPU Colab runtime. '
                'Go to the Runtime menu and select Choose runtime type.'
                )

    # Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
    # This is usually installed as part of an Nvidia driver package, but the Colab
    # kernel doesn't install its driver via APT, and as a result the ICD is missing.
    # (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
    NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
    if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
        with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
            f.write(
                    """{
                            "file_format_version" : "1.0.0",
                            "ICD" : {
                                "library_path" : "libEGL_nvidia.so.0"
                            }
                        }
                        """
                    )

    # Configure MuJoCo to use the EGL rendering backend (requires GPU).
    # Set through os.environ (equivalent to `%env`) so the cell is plain Python;
    # it must happen before `mujoco` is imported further below.
    print('Setting environment variable to use GPU rendering:')
    os.environ['MUJOCO_GL'] = 'egl'
    print('env: MUJOCO_GL=egl')

    # Check if jax finds the GPU
    import jax

    print(jax.devices('gpu'))
except Exception as e:
    # Best-effort setup: fall back to CPU, but surface *why* the GPU path failed
    # (e.g. the RuntimeError message above, a missing nvidia-smi, or a permission
    # error writing the ICD file) instead of discarding it.
    logging.warning(
            "Failed to initialize GPU (%s: %s). Everything will run on the cpu.",
            type(e).__name__,
            e,
            )

try:
    print('Checking that the mujoco installation succeeded:')
    import mujoco

    # Compile a minimal (empty) model to check that the native library loads.
    mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
    # Raise the user-facing hint and chain the original error as its cause, so
    # the real traceback is kept ("The above exception was the direct cause...").
    raise RuntimeError(
            'Something went wrong during installation. Check the shell output above '
            'for more information.\n'
            'If using a hosted Colab runtime, make sure you enable GPU acceleration '
            'by going to the Runtime menu and selecting "Choose runtime type".'
            ) from e

print('MuJoCo installation successful.')

Fri Feb  2 13:53:41 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.154.05             Driver Version: 535.154.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          On  | 00000000:C1:00.0 Off |                    0 |
| N/A   33C    P0              62W / 500W |      0MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [2]:
import mediapy as media
from mujoco_utils.environment.base import MuJoCoEnvironmentConfiguration
from typing import List
import numpy as np


def post_render(
        render_output: List[np.ndarray],
        environment_configuration: MuJoCoEnvironmentConfiguration
        ) -> np.ndarray:
    if render_output is None:
        # Temporary workaround until https://github.com/google-deepmind/mujoco/issues/1379 is fixed
        return None

    num_cameras = len(environment_configuration.camera_ids)
    num_envs = len(render_output) // num_cameras

    if num_cameras > 1:
        # Horizontally stack frames of the same environment
        frames_per_env = np.array_split(render_output, num_envs)
        render_output = [np.concatenate(env_frames, axis=1) for env_frames in frames_per_env]

    # Vertically stack frames of different environments
    render_output = np.concatenate(render_output, axis=0)

    return render_output[:, :, ::-1]  # RGB to BGR


def show_video(
        images: List[np.ndarray | None],
        path: str | None = None
        ) -> str | None:
    # Temporary workaround until https://github.com/google-deepmind/mujoco/issues/1379 is fixed
    filtered_images = [image for image in images if image is not None]
    num_nones = len(images) - len(filtered_images)
    if num_nones > 0:
        logging.warning(
                f"env.render produced {num_nones} None's. Resulting video might be a bit choppy (consquence of https://github.com/google-deepmind/mujoco/issues/1379)."
                )
    if path:
        media.write_video(path=path, images=filtered_images)
    return media.show_video(images=filtered_images)

In [11]:
from typing import Callable, Tuple

import chex
import jax.numpy as jnp
import jax.random
import numpy as np
from mujoco_utils.environment import mjx_spaces

from brb.brittle_star.environment.light_escape.dual import BrittleStarLightEscapeEnvironment
from brb.brittle_star.environment.light_escape.shared import BrittleStarLightEscapeEnvironmentConfiguration
from brb.brittle_star.mjcf.arena.aquarium import AquariumArenaConfiguration, MJCFAquariumArena
from brb.brittle_star.mjcf.morphology.morphology import MJCFBrittleStarMorphology
from brb.brittle_star.mjcf.morphology.specification.default import default_brittle_star_morphology_specification


def create_env(
        backend: str,
        render_mode: str
        ) -> BrittleStarLightEscapeEnvironment:
    """Build the brittle-star light-escape environment used in this notebook.

    The morphology is the default brittle star with 5 arms of 3 segments each
    (``use_p_control=True``), placed in an aquarium arena with a sand-coloured
    ground. The episode lasts ``simulation_time=5`` with 10 physics steps per
    control step, rendered from cameras 0 and 1.

    :param backend: Simulation backend passed to ``from_morphology_and_arena``
        (this notebook uses ``"MJX"``).
    :param render_mode: Render mode for the environment configuration.
    :returns: The constructed environment.
    """
    brittle_star = MJCFBrittleStarMorphology(
            default_brittle_star_morphology_specification(
                    num_arms=5, num_segments_per_arm=3, use_p_control=True
                    )
            )
    aquarium = MJCFAquariumArena(
            configuration=AquariumArenaConfiguration(sand_ground_color=True)
            )

    configuration = BrittleStarLightEscapeEnvironmentConfiguration(
            render_mode=render_mode,
            light_perlin_noise_scale=0,
            num_physics_steps_per_control_step=10,
            simulation_time=5,
            time_scale=1,
            camera_ids=[0, 1]
            )
    # The solver iteration counts are assigned as attributes after construction
    # rather than passed to the constructor.
    configuration.solver_iterations = 5
    configuration.solver_ls_iterations = 5

    return BrittleStarLightEscapeEnvironment.from_morphology_and_arena(
            morphology=brittle_star,
            arena=aquarium,
            configuration=configuration,
            backend=backend
            )



# Simulation backend and render mode used for the rest of the notebook.
BACKEND = "MJX"
RENDER_MODE = "rgb_array"

env = create_env(backend=BACKEND, render_mode=RENDER_MODE)

# Wrap the environment's reset / step in jax.jit (compiled lazily on first call).
reset_fn = jax.jit(env.reset)
step_fn = jax.jit(env.step)

# Split a fixed seed into two keys; env_rng is used to reset the environment.
# NOTE(review): action_rng is not used in the visible cells.
env_rng, action_rng = jax.random.split(jax.random.PRNGKey(0), num=2)

In [12]:
def create_mjx_open_loop_controller(
        single_action_space: "mjx_spaces.Box",
        frequency: float = 5.0,
        ) -> Callable[[float], jnp.ndarray]:
    """Build a jitted open-loop controller that emits sinusoidal actions.

    For an action vector ``a`` of length ``n`` and time ``t``:
      * even indices are set to ``cos(frequency * t)``,
      * odd indices are set to ``sin(frequency * t)``,
      * the entries at ``n // 2, n // 2 + 2, ...`` then have their sign flipped.

    :param single_action_space: Space whose ``shape`` gives the action shape.
        Indexing uses ``len`` of the action array, so a 1-D shape is assumed.
    :param frequency: Angular frequency of the sinusoids (default 5, the
        previously hard-coded value).
    :returns: A jitted function mapping time ``t`` to an action array.
    """

    @jax.jit
    def open_loop_controller(
            t: float
            ) -> jnp.ndarray:
        actions = jnp.ones(single_action_space.shape)
        num_actions = len(actions)
        phase = frequency * t

        even = jnp.arange(0, num_actions, 2)
        odd = jnp.arange(1, num_actions, 2)
        # Starts at num_actions // 2 with stride 2: when that start is odd the
        # sine entries are flipped, otherwise the cosine ones -- TODO confirm intent.
        flipped = jnp.arange(num_actions // 2, num_actions, 2)

        actions = actions.at[even].set(jnp.cos(phase))
        actions = actions.at[odd].set(jnp.sin(phase))
        actions = actions.at[flipped].multiply(-1)
        return actions

    return open_loop_controller
# Open-loop controller sized to the environment's action space.
controller = create_mjx_open_loop_controller(single_action_space=env.action_space)

In [13]:

# Roll out one episode with the open-loop controller, rendering after every step.
state = reset_fn(env_rng)
frames = []
while not (state.terminated | state.truncated):
    t = state.info["time"]
    # Feed the controller output as-is (it was previously multiplied by 0,
    # which zeroed every action and made the controller a no-op).
    action = controller(t)

    state = step_fn(state=state, action=action)
    frames.append(post_render(env.render(state=state), environment_configuration=env.environment_configuration))

show_video(frames)

0
This browser does not support the video tag.


In [None]:
# Close the environment once the rollout above is finished.
env.close()