In [28]:
import os
import subprocess
import logging

try:
    if subprocess.run('nvidia-smi').returncode:
        raise RuntimeError(
                'Cannot communicate with GPU. '
                'Make sure you are using a GPU Colab runtime. '
                'Go to the Runtime menu and select Choose runtime type.'
                )

    # Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
    # This is usually installed as part of an Nvidia driver package, but the Colab
    # kernel doesn't install its driver via APT, and as a result the ICD is missing.
    # (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
    NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
    if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
        with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
            f.write(
                    """{
                            "file_format_version" : "1.0.0",
                            "ICD" : {
                                "library_path" : "libEGL_nvidia.so.0"
                            }
                        }
                        """
                    )

    # Configure MuJoCo to use the EGL rendering backend (requires GPU)
    print('Setting environment variable to use GPU rendering:')
    %env MUJOCO_GL=egl

    # Check if jax finds the GPU
    import jax

    print(jax.devices('gpu'))
except Exception:
    logging.warning("Failed to initialize GPU. Everything will run on the cpu.")

try:
    print('Checking that the mujoco installation succeeded:')
    import mujoco

    # Compiling the smallest possible model proves the native library loads and works.
    mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
    # Raise the user-facing RuntimeError and attach the original failure as its
    # __cause__ (the previous `raise e from RuntimeError(...)` chained them backwards,
    # hiding the guidance message behind the low-level error).
    raise RuntimeError(
            'Something went wrong during installation. Check the shell output above '
            'for more information.\n'
            'If using a hosted Colab runtime, make sure you enable GPU acceleration '
            'by going to the Runtime menu and selecting "Choose runtime type".'
            ) from e

print('MuJoCo installation successful.')



Checking that the mujoco installation succeeded:
MuJoCo installation successful.


In [29]:
import mediapy as media
from mujoco_utils.environment.base import MuJoCoEnvironmentConfiguration
from typing import List
import numpy as np


def post_render(
        render_output: List[np.ndarray] | None,
        environment_configuration: MuJoCoEnvironmentConfiguration
        ) -> np.ndarray | None:
    """Tile a flat list of rendered frames into a single image.

    Frames are grouped into consecutive chunks of ``len(camera_ids)`` frames, one chunk per
    environment (so all cameras of one environment must be contiguous in ``render_output``).
    With more than one camera, the frames of a chunk are stacked side by side (axis 1).
    The per-environment images are then stacked on top of each other (axis 0). Finally the
    last axis is reversed.

    NOTE(review): assumes each frame is an (H, W, C) image array; the final ``[:, :, ::-1]``
    reverses the channel axis under that assumption.

    Args:
        render_output: frames to tile, or ``None``.
        environment_configuration: provides ``camera_ids``; its length is the number of
            frames per environment.

    Returns:
        The tiled image, or ``None`` when ``render_output`` is ``None``.
    """
    if render_output is None:
        # Temporary workaround until https://github.com/google-deepmind/mujoco/issues/1379 is fixed
        return None

    num_cameras = len(environment_configuration.camera_ids)
    num_envs = len(render_output) // num_cameras

    if num_cameras > 1:
        # Horizontally stack frames of the same environment
        frames_per_env = np.array_split(render_output, num_envs)
        render_output = [np.concatenate(env_frames, axis=1) for env_frames in frames_per_env]

    # Vertically stack frames of different environments
    render_output = np.concatenate(render_output, axis=0)

    return render_output[:, :, ::-1]  # RGB to BGR


def show_video(
        images: List[np.ndarray | None],
        path: str | None = None
        ) -> str | None:
    """Display (and optionally save) a sequence of rendered frames.

    ``None`` entries are removed before anything is written or shown; if any were removed,
    a warning reports how many.

    Args:
        images: frames to show, possibly interleaved with ``None`` placeholders.
        path: when truthy, the kept frames are also written to this file via
            ``media.write_video``.

    Returns:
        Whatever ``media.show_video`` returns for the kept frames.
    """
    # Temporary workaround until https://github.com/google-deepmind/mujoco/issues/1379 is fixed:
    # rendering can yield None for some frames, so strip those out first.
    kept_frames = list(filter(lambda frame: frame is not None, images))
    num_nones = len(images) - len(kept_frames)

    if num_nones:
        logging.warning(
                f"env.render produced {num_nones} None's. Resulting video might be a bit choppy (consquence of https://github.com/google-deepmind/mujoco/issues/1379)."
                )

    if path:
        media.write_video(path=path, images=kept_frames)

    return media.show_video(images=kept_frames)

In [30]:
from typing import  Tuple
import jax.numpy as jnp
import jax.random
import numpy as np
import mujoco.mjx as mjx

from brb.brittle_star.mjcf.arena.aquarium import AquariumArenaConfiguration, MJCFAquariumArena
from brb.brittle_star.mjcf.morphology.morphology import MJCFBrittleStarMorphology
from brb.brittle_star.mjcf.morphology.specification.default import default_brittle_star_morphology_specification

# Brittle star with 5 arms of 5 segments each.
# NOTE(review): use_p_control=True presumably selects position-controlled actuators;
# confirm against the brb specification helper.
morphology_spec = default_brittle_star_morphology_specification(
        num_arms=5,
        num_segments_per_arm=5,
        use_p_control=True,
        )
morphology = MJCFBrittleStarMorphology(morphology_spec)

# Attach the morphology to an aquarium arena built from its default configuration.
arena_config = AquariumArenaConfiguration()
arena = MJCFAquariumArena(configuration=arena_config)
arena.attach(other=morphology)

# Serialize the combined scene to an MJCF XML string.
xml = arena.get_mjcf_str()


In [31]:
# Single source of truth for the render resolution: the renderer's requested size has to
# fit in the model's offscreen framebuffer, so both are set from these constants.
RENDER_HEIGHT = 480
RENDER_WIDTH = 640

mj_model = mujoco.MjModel.from_xml_string(xml=xml)
mj_model.vis.global_.offheight = RENDER_HEIGHT
mj_model.vis.global_.offwidth = RENDER_WIDTH
mj_data = mujoco.MjData(mj_model)

renderer = mujoco.Renderer(
    model=mj_model,
    height=RENDER_HEIGHT,
    width=RENDER_WIDTH
    )

In [32]:
def reset_env(
        morphology_pos: Tuple[float, float, float] = (0.0, 0.0, 0.11)
        ) -> Tuple[mjx.Model, mjx.Data]:
    """Build a fresh MJX model/data pair with the brittle star placed at ``morphology_pos``.

    The position is written into the model's ``body_pos`` entry of the
    "BrittleStarMorphology/central_disk" body (not into joint state), after which
    ``mjx.forward`` is run on the initial data.

    Args:
        morphology_pos: position assigned to the central disk's ``body_pos``; defaults to the
            previously hard-coded ``(0.0, 0.0, 0.11)``.

    Returns:
        ``(mjx_model, mjx_data)`` built from the module-level ``mj_model`` / ``mj_data``.
    """
    mjx_model = mjx.put_model(mj_model)
    mjx_data = mjx.put_data(mj_model, mj_data)

    # Set the initial brittle star position in the mjx model.
    #   While this could be done through the mjx data and with the joints, this functionality
    #   is also needed for other environments (which randomly position objects that have no joints).
    disk_body_id = mj_model.body("BrittleStarMorphology/central_disk").id
    mjx_model = mjx_model.replace(
            body_pos=mjx_model.body_pos.at[disk_body_id].set(jnp.asarray(morphology_pos))
            )

    mjx_data = mjx.forward(mjx_model, mjx_data)
    return mjx_model, mjx_data

@jax.jit
def step_env(
        mjx_data: mjx.Data,
        ctrl: jnp.ndarray,
        model: mjx.Model | None = None
        ) -> mjx.Data:
    """Advance the simulation by 10 ``mjx.step`` calls while holding ``ctrl`` fixed.

    Args:
        mjx_data: current simulation state.
        ctrl: controls written into ``ctrl`` before every one of the 10 sub-steps.
        model: the ``mjx.Model`` to simulate. When omitted, the module-level ``mjx_model``
            is used. Under ``jax.jit`` that global is read only while tracing and is baked
            into the compiled function, so later reassignments of ``mjx_model`` (e.g. a
            ``reset_env`` with a different spawn position) may NOT be seen. Pass ``model``
            explicitly whenever the model can change between calls.

    Returns:
        The state after the 10 sub-steps.
    """
    # Resolved at trace time; see the note on ``model`` above.
    sim_model = mjx_model if model is None else model

    def _simulation_step(
            _data,
            _
            ) -> Tuple[mjx.Data, None]:
        _data = _data.replace(ctrl=ctrl)
        return mjx.step(sim_model, _data), None

    mjx_data, _ = jax.lax.scan(
            _simulation_step, mjx_data, (), 10
            )

    return mjx_data

In [33]:
import copy


def mjx_get_model(
        mj_model: mujoco.MjModel,
        mjx_model: mjx.Model,
        n_mj_models: int = 1
        ) -> List[mujoco.MjModel]:
    """
    Transfer mjx.Model to mujoco.MjModel.

    Each returned model starts as a deep copy of ``mj_model``; then every attribute of
    ``mjx_model`` (after ``jax.device_get``) is copied onto it:

    - ndarray fields are overwritten in place. If the mjx value's shape differs from the
      MjModel field's shape, the mjx value is assumed to be batched along its leading axis
      and model ``i`` receives ``v[i]``.
    - other fields are assigned with ``setattr``.

    Fields that cannot be transferred (``AttributeError``: e.g. no such / read-only
    attribute; ``ValueError``: e.g. incompatible shapes) are skipped, independently for
    each model.

    Args:
        mj_model: template model to copy.
        mjx_model: source of the (possibly modified / batched) values.
        n_mj_models: number of MjModel copies to produce.

    Returns:
        A list of ``n_mj_models`` MjModel instances.
    """
    mj_models = [copy.deepcopy(mj_model) for _ in range(n_mj_models)]

    offloaded_mjx_model = jax.device_get(mjx_model)
    for key, v in vars(offloaded_mjx_model).items():
        for i, model in enumerate(mj_models):
            # Per-model try: a failure for model i must not leave models i+1.. un-updated.
            try:
                previous_value = getattr(model, key)
                if isinstance(previous_value, np.ndarray):
                    # Shape mismatch => treat v as batched along its leading axis.
                    actual_value = v[i] if previous_value.shape != v.shape else v
                    previous_value[...] = actual_value
                else:
                    setattr(model, key, v)
            except (AttributeError, ValueError):
                # Best-effort transfer: skip this field for this model only.
                pass
    return mj_models

def render_env(
        mjx_model: mjx.Model,
        mjx_data: mjx.Data,
        camera: int | str = 1
        ) -> np.ndarray:
    """Render an MJX simulation state with the module-level ``renderer``.

    ``mjx_model`` is transferred back to a ``mujoco.MjModel`` first (``reset_env`` edits the
    model, not only the data). The MJX state is then copied into an ``MjData`` for that
    model, and ``mj_forward`` is run before the scene is updated.

    Args:
        mjx_model: model the state belongs to.
        mjx_data: simulation state to render.
        camera: camera passed to ``renderer.update_scene`` (defaults to 1, as before).

    Returns:
        The image returned by ``renderer.render()``.
    """
    # mjx_get_model returns a list; with n_mj_models=1 we need its single element
    # (passing the list itself to mj_forward is a type error).
    model = mjx_get_model(mj_model=mj_model, mjx_model=mjx_model, n_mj_models=1)[0]
    data = mjx.get_data(m=model, d=mjx_data)
    mujoco.mj_forward(m=model, d=data)
    # Render the freshly synced `data`, not the module-level initial `mj_data`
    # (which made every frame show the starting state).
    # NOTE(review): `renderer` was built from the module-level `mj_model`; visual model
    # parameters changed only in `mjx_model` may not be reflected — confirm if needed.
    renderer.update_scene(data=data, camera=camera)
    return renderer.render()

In [34]:
def open_loop_controller(
        t: float
        ) -> jnp.ndarray:
    actions = jnp.ones(mj_model.nu)
    actions = actions.at[jnp.arange(0, len(actions), 2)].set(jnp.cos(5 * t))
    actions = actions.at[jnp.arange(1, len(actions), 2)].set(jnp.sin(5 * t))
    actions = actions.at[jnp.arange(len(actions) // 2, len(actions), 2)].set(
            actions[jnp.arange(len(actions) // 2, len(actions), 2)] * -1
            )
    return actions

In [None]:
# Roll out the open-loop controller until simulation time reaches 10, rendering one frame
# after every controlled step. Note: step_env reads this module-level `mjx_model`.
mjx_model, mjx_data = reset_env()
frames = []

while mjx_data.time < 10:
    # Controls depend only on the current simulation time.
    action = open_loop_controller(mjx_data.time)
    mjx_data = step_env(mjx_data, action)

    frame = render_env(mjx_model=mjx_model, mjx_data=mjx_data)
    frames.append(frame)

    print(mjx_data.time)

Exception ignored in: <function Renderer.__del__ at 0x1123bede0>
Traceback (most recent call last):
  File "/Users/driesmarzougui/miniforge3/envs/SEL3-2024/lib/python3.11/site-packages/mujoco/renderer.py", line 327, in __del__
    self.close()
  File "/Users/driesmarzougui/miniforge3/envs/SEL3-2024/lib/python3.11/site-packages/mujoco/renderer.py", line 315, in close
    if self._gl_context:
       ^^^^^^^^^^^^^^^^
AttributeError: 'Renderer' object has no attribute '_gl_context'


In [None]:
import mediapy as media
# Go through the show_video helper so None frames are filtered and reported consistently.
show_video(frames)