In [3]:
import logging
import os
import subprocess
import jax
import jax.numpy as jnp
from functools import partial

from biorobot.brittle_star.environment.directed_locomotion.shared import (
    BrittleStarDirectedLocomotionEnvironmentConfiguration,
)
import numpy as np
from moojoco.environment.base import MuJoCoEnvironmentConfiguration
from biorobot.brittle_star.environment.directed_locomotion.dual import (
    BrittleStarDirectedLocomotionEnvironment,
)
from typing import List
import mediapy as media
from biorobot.brittle_star.mjcf.morphology.morphology import MJCFBrittleStarMorphology
from biorobot.brittle_star.mjcf.morphology.specification.default import (
    default_brittle_star_morphology_specification,
)
from biorobot.brittle_star.mjcf.arena.aquarium import (
    AquariumArenaConfiguration,
    MJCFAquariumArena,
)


import environment as env

DEBUG = True

#################### Initialization code ####################
# Step 1: probe for an NVIDIA GPU. Any failure in this step is non-fatal: a
# warning is logged and execution continues.
try:
    if subprocess.run("nvidia-smi").returncode:
        raise RuntimeError("Cannot communicate with GPU.")

    # Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
    # This is usually installed as part of an Nvidia driver package, but the Colab
    # kernel doesn't install its driver via APT, and as a result the ICD is missing.
    # (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
    NVIDIA_ICD_CONFIG_PATH = "/usr/share/glvnd/egl_vendor.d/10_nvidia.json"
    if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
        with open(NVIDIA_ICD_CONFIG_PATH, "w") as f:
            f.write(
                """{
                            "file_format_version" : "1.0.0",
                            "ICD" : {
                                "library_path" : "libEGL_nvidia.so.0"
                            }
                        }
                        """
            )

    # Configure MuJoCo to use the EGL rendering backend (requires GPU)
    # NOTE(review): no environment variable is actually set after this message
    # (e.g. MUJOCO_GL) -- confirm whether that is intentional.
    print("Setting environment variable to use GPU rendering:")

    # xla_flags = os.environ.get('XLA_FLAGS', '')
    # xla_flags += ' --xla_gpu_triton_gemm_any=True'
    # os.environ['XLA_FLAGS'] = xla_flags

    # Check if jax finds the GPU (jax is already imported at the top of the cell;
    # this call raises when no GPU backend is available).
    print(jax.devices("gpu"))
except Exception:
    logging.warning("Failed to initialize GPU. Everything will run on the cpu.")

# Step 2: verify that MuJoCo can be imported and can build a trivial model.
# Unlike the GPU probe, a failure here is fatal.
try:
    print("Checking that the mujoco installation succeeded:")
    import mujoco

    mujoco.MjModel.from_xml_string("<mujoco/>")
except Exception as e:
    # Raise the explanatory error and chain the original exception as its cause,
    # so the traceback ends with the actionable message.
    raise RuntimeError(
        "Something went wrong during installation. Check the shell output above "
        "for more information.\n"
        "If using a hosted Colab runtime, make sure you enable GPU acceleration "
        'by going to the Runtime menu and selecting "Choose runtime type".'
    ) from e

print("MuJoCo installation successful.")



NVIDIA-SMI has failed because it couldn't communicate with the NVIDIA driver. Make sure that the latest NVIDIA driver is installed and running.

Checking that the mujoco installation succeeded:
MuJoCo installation successful.
NVIDIA-SMI has failed because it couldn't communicate with the NVIDIA driver. Make sure that the latest NVIDIA driver is installed and running.

Checking that the mujoco installation succeeded:
MuJoCo installation successful.


In [4]:
#################### Creating the experimental environment ####################
# Arena the robot moves in; a target is attached to it.
arena_configuration = AquariumArenaConfiguration(
    size=(1.5, 1.5),
    wall_height=1.5,
    wall_thickness=0.1,
    sand_ground_color=False,
    attach_target=True,
)

# Brittle star morphology: two arms with one segment each, position-controlled.
morphology_specification = default_brittle_star_morphology_specification(
    num_arms=2,
    num_segments_per_arm=[1, 1],
    use_p_control=True,
    use_torque_control=False,
)

# Simulation and rendering settings for the directed-locomotion task.
environment_configuration = BrittleStarDirectedLocomotionEnvironmentConfiguration(
    target_distance=1.2,
    joint_randomization_noise_scale=0.0,
    simulation_time=20,
    num_physics_steps_per_control_step=10,
    time_scale=2,
    render_mode="rgb_array",
    render_size=(480, 640),
    camera_ids=[0, 1],
)

experimental_env = env.create_environment(
    morphology_specification,
    arena_configuration,
    environment_configuration,
    "MJX",
)

# Root PRNG key; later cells split sub-keys off it.
rng = jax.random.PRNGKey(seed=0)

# JIT-compiled step function, plus a reset function with the target location fixed.
env_step_fn = jax.jit(experimental_env.step)
env_fixed_target_reset_fn = jax.jit(
    partial(
        experimental_env.reset,
        target_position=(-1.25, 0.75, 0.0),
    )
)

In [5]:
#################### Some Printing functions ####################
def post_render(
    render_output: List[np.ndarray],
    environment_configuration: MuJoCoEnvironmentConfiguration,
) -> np.ndarray:
    """Tile raw camera frames into a single image.

    Frames belonging to the same environment (one per camera) are placed side
    by side; the resulting per-environment strips are stacked on top of each
    other. The last axis of the final image is reversed (RGB to BGR).

    Args:
        render_output: Flat list of frames, grouped per environment, or None.
        environment_configuration: Only its ``camera_ids`` length is used, to
            work out how many frames belong to one environment.

    Returns:
        The tiled image, or None when ``render_output`` is None (despite the
        return annotation).
    """
    # Temporary workaround until https://github.com/google-deepmind/mujoco/issues/1379 is fixed
    if render_output is None:
        return None

    camera_count = len(environment_configuration.camera_ids)
    env_count = len(render_output) // camera_count

    if camera_count > 1:
        # Side-by-side layout for the cameras of each environment.
        render_output = [
            np.concatenate(camera_frames, axis=1)
            for camera_frames in np.array_split(render_output, env_count)
        ]

    # One row per environment.
    tiled = np.concatenate(render_output, axis=0)

    # Reverse the last axis: RGB to BGR.
    return np.array(tiled)[:, :, ::-1]


def show_video(
    images: List[np.ndarray | None], sim_time: float, path: str | None = None
) -> str | None:
    if path:
        media.write_video(path=path, images=[img for img in images if img is not None])
    return media.show_video(
        images=[img for img in images if img is not None], fps=len(images) // sim_time
    )

In [4]:
# Split off a sub-key for this reset so the root key can be reused later.
rng, sub_rng = jax.random.split(rng, 2)
env_state = env_fixed_target_reset_fn(sub_rng)

print("Observation space:")
print(experimental_env.observation_space)
print()
print("Action space:")
print(experimental_env.action_space)
print()
print("Info:")
print(env_state.info)
print()

# Driver issue ==> visualisation code doesn't work
initial_frame = post_render(
    experimental_env.render(env_state),
    environment_configuration=experimental_env.environment_configuration,
)
# post_render returns None when no render output is available; only display
# an image when there is one.
if initial_frame is not None:
    media.show_image(initial_frame)
else:
    logging.warning("Rendering returned no frame; skipping image display.")

Observation space:
Dict('joint_position': Box([-1.047 -0.785 -1.047 -0.785 -1.047 -0.785 -1.047 -0.785 -1.047 -0.785 -1.047 -0.785 -1.047 -0.785
 -1.047 -0.785 -1.047 -0.785 -1.047 -0.785 -1.047 -0.785 -1.047 -0.785 -1.047 -0.785 -1.047 -0.785
 -1.047 -0.785], [1.047 0.785 1.047 0.785 1.047 0.785 1.047 0.785 1.047 0.785 1.047 0.785 1.047 0.785 1.047 0.785
 1.047 0.785 1.047 0.785 1.047 0.785 1.047 0.785 1.047 0.785 1.047 0.785 1.047 0.785], (30,), <class 'jax.numpy.float32'>), 'joint_velocity': Box(-inf, inf, (30,), <class 'jax.numpy.float32'>), 'joint_actuator_force': Box(-inf, inf, (30,), <class 'jax.numpy.float32'>), 'actuator_force': Box([-5.    -5.    -4.167 -4.167 -3.333 -3.333 -5.    -5.    -4.167 -4.167 -3.333 -3.333 -5.    -5.
 -4.167 -4.167 -3.333 -3.333 -5.    -5.    -4.167 -4.167 -3.333 -3.333 -5.    -5.    -4.167 -4.167
 -3.333 -3.333], [5.    5.    4.167 4.167 3.333 3.333 5.    5.    4.167 4.167 3.333 3.333 5.    5.    4.167 4.167
 3.333 3.333 5.    5.    4.167 4.167 3.33

MESA-LOADER: failed to open radeonsi: /usr/lib/dri/radeonsi_dri.so: cannot open shared object file: No such file or directory (search paths /usr/lib/x86_64-linux-gnu/dri:\$${ORIGIN}/dri:/usr/lib/dri, suffix _dri)
failed to load driver: radeonsi
MESA-LOADER: failed to open radeonsi: /usr/lib/dri/radeonsi_dri.so: cannot open shared object file: No such file or directory (search paths /usr/lib/x86_64-linux-gnu/dri:\$${ORIGIN}/dri:/usr/lib/dri, suffix _dri)
failed to load driver: radeonsi
MESA-LOADER: failed to open swrast: /usr/lib/dri/swrast_dri.so: cannot open shared object file: No such file or directory (search paths /usr/lib/x86_64-linux-gnu/dri:\$${ORIGIN}/dri:/usr/lib/dri, suffix _dri)
/home/thibaud/anaconda3/envs/biorobot/lib/python3.11/site-packages/glfw/__init__.py:917: GLFWError: (65543) b'GLX: Failed to create context: BadValue (integer parameter out of range for operation)'


FatalError: gladLoadGL error

In [6]:
from moojoco.environment.mjx_env import MJXEnvState

# Arena size per axis, taken from the arena configuration, as a JAX array.
ARENA_SIZE = jnp.asarray(arena_configuration.size)

# The arena is discretised into a square grid of cells.
NUM_CELLS_PER_AXIS = 6
NUM_POSITION_STATES = NUM_CELLS_PER_AXIS * NUM_CELLS_PER_AXIS

# Position is currently the only component of the discrete state.
TOTAL_NUM_STATES = NUM_POSITION_STATES


def position_to_state_index(
        position: jnp.ndarray
        ) -> jnp.ndarray:
    """Map a 2D position to the flat index of the grid cell containing it.

    The arena is divided into ``NUM_CELLS_PER_AXIS`` cells per axis and the
    returned index is ``x_cell + NUM_CELLS_PER_AXIS * y_cell``.

    Args:
        position: Array holding the (x, y) position, expected to lie in
            [-ARENA_SIZE, ARENA_SIZE] per axis.

    Returns:
        The flat cell index. Positions outside the expected range are mapped
        to the nearest border cell, so the result always lies in
        [0, NUM_CELLS_PER_AXIS ** 2 - 1].
    """
    # position is in [-arena size, arena size], first convert it to [0, 2 * arena size]
    shifted_position = position + ARENA_SIZE
    # Then convert it to [0, 1); the small offset keeps a position exactly on
    # the upper border from landing in a non-existent cell.
    normalized_position = shifted_position / (2 * ARENA_SIZE + 0.001)

    cell = (NUM_CELLS_PER_AXIS * normalized_position).astype(jnp.int32)
    # Clamp per axis: without this, an out-of-arena position would alias into
    # a different cell or fall outside the state range.
    cell = jnp.clip(cell, 0, NUM_CELLS_PER_AXIS - 1)

    x, y = cell
    return x + NUM_CELLS_PER_AXIS * y


@jax.jit
def state_indexer(
        env_state: MJXEnvState
        ) -> int:
    """Return the discrete state index for an environment state.

    Reads the first two components of the ``"disk_position"`` observation and
    converts them to a grid-cell index with ``position_to_state_index``.
    Because the function is jitted, the result is a JAX array rather than a
    Python ``int``.
    """
    return position_to_state_index(
        position=env_state.observations["disk_position"][:2]
    )

In [None]:
from tensorneat.pipeline import Pipeline
from tensorneat.algorithm.neat import NEAT
from tensorneat.genome import DefaultGenome, BiasNode
from tensorneat.problem.rl import BraxEnv
from tensorneat.problem import BaseProblem
from tensorneat.common import ACT, AGG
from tensorneat.common import State

# Placeholder for the network input definitions; still empty at this point.
inputs = [
    
]

# Attempt to write as custom non-jittable function
# NOTE(review): only `num_inputs` is passed here -- confirm whether
# DefaultGenome also requires `num_outputs` before running this cell.
genome = DefaultGenome(
    num_inputs=1
)
