# Visualize a MuJoCo simulation driven by an `.onnx` neural-network policy

This Jupyter notebook does two small things:

1. Load the simple model set up in the MJX tutorial (from the MuJoCo repository) and show that it renders.
2. Load the `.onnx` file obtained by converting the policy trained in the MJX tutorial, and use it to make the humanoid used for training run forward.

   This second step is exactly the same as what the script *eval.py* does; the conversion itself is done by the script *convert_to_onnx.py*.


In [None]:
!pip install mujoco
!pip install mujoco_mjx
!pip install brax

In [None]:
#@title Check if MuJoCo installation was successful

import os
import subprocess

# Fail early with an actionable message when the GPU cannot be reached. A
# missing `nvidia-smi` binary is treated like a non-zero exit code instead of
# surfacing as a bare FileNotFoundError.
try:
  gpu_check_failed = subprocess.run('nvidia-smi').returncode != 0
except FileNotFoundError:
  gpu_check_failed = True
if gpu_check_failed:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU Colab runtime. '
      'Go to the Runtime menu and select Choose runtime type.')

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  # The vendor directory itself may be absent on a bare runtime.
  os.makedirs(os.path.dirname(NVIDIA_ICD_CONFIG_PATH), exist_ok=True)
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs.
# Only append the flag once so re-running this cell stays idempotent.
xla_flags = os.environ.get('XLA_FLAGS', '')
if '--xla_gpu_triton_gemm_any' not in xla_flags:
  xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags

# Configure MuJoCo to use the EGL rendering backend (requires GPU). This must
# be set before `mujoco` is imported below.
print('Setting environment variable to use GPU rendering:')
os.environ['MUJOCO_GL'] = 'egl'

try:
  print('Checking that the installation succeeded:')
  import mujoco
  mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
  # Raise the explanatory error and keep the original failure as its cause.
  raise RuntimeError(
      'Something went wrong during installation. Check the shell output above '
      'for more information.\n'
      'If using a hosted Colab runtime, make sure you enable GPU acceleration '
      'by going to the Runtime menu and selecting "Choose runtime type".') from e

print('Installation successful.')

In [None]:
#@title Import packages for plotting and creating graphics
import time
import itertools
import numpy as np
from typing import Callable, NamedTuple, Optional, Union, List

# Graphics and plotting.
print('Installing mediapy:')
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy
import mediapy as media
import matplotlib.pyplot as plt

# More legible printing from numpy.
np.set_printoptions(precision=3, suppress=True, linewidth=100)

In [None]:
#@title Import MuJoCo, MJX, and Brax


from datetime import datetime
import functools
from IPython.display import HTML
import jax
from jax import numpy as jp
import numpy as np
from typing import Any, Dict, Sequence, Tuple, Union

from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model

from etils import epath
from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
from ml_collections import config_dict
import mujoco
from mujoco import mjx


In [None]:
xml = """
<mujoco>
  <worldbody>
    <light name="top" pos="0 0 1"/>
    <body name="box_and_sphere" euler="0 0 -30">
      <joint name="swing" type="hinge" axis="1 -1 0" pos="-.2 -.2 -.2"/>
      <geom name="red_box" type="box" size=".2 .2 .2" rgba="1 0 0 1"/>
      <geom name="green_sphere" pos=".2 .2 .2" size=".1" rgba="0 1 0 1"/>
    </body>
  </worldbody>
</mujoco>
"""

# Make model, data, and renderer
mj_model = mujoco.MjModel.from_xml_string(xml)
mj_data = mujoco.MjData(mj_model)
renderer = mujoco.Renderer(mj_model)

In [None]:
# Convert the compiled model and its state into MJX structures; this only shows
# the conversion works — neither object is used later in this notebook.
mjx_model, mjx_data = mjx.put_model(mj_model), mjx.put_data(mj_model, mj_data)

In [None]:
# Draw joint axes in the rendered frames.
scene_option = mujoco.MjvOption()
scene_option.flags[mujoco.mjtVisFlag.mjVIS_JOINT] = True

duration = 3.8  # simulated time to roll out (seconds)
framerate = 60  # video frame rate (Hz)

frames = []
mujoco.mj_resetData(mj_model, mj_data)
while mj_data.time < duration:
  mujoco.mj_step(mj_model, mj_data)
  # Capture a frame only once simulated time has passed the next frame slot.
  frame_due = len(frames) < mj_data.time * framerate
  if not frame_due:
    continue
  renderer.update_scene(mj_data, scene_option=scene_option)
  frames.append(renderer.render())

# Display the captured frames as a video.
media.show_video(frames, fps=framerate)

# Eval

Once we have checked that we can load and render a MuJoCo model, we can render an environment driven by the policy produced by the ONNX conversion.

In [None]:
import onnxruntime

print("Importing...")
import numpy as np
from brax.training.acme import running_statistics

np.set_printoptions(precision=3, suppress=True, linewidth=100)

import jax
from jax import numpy as jp
import jax.numpy as jnp

from brax import envs
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model
import mujoco
from mujoco import mjx
from video import show_video
from humanoid import Humanoid

print("register env...")
envs.register_environment('humanoid', Humanoid)

# Instantiate the environment.
env_name = 'humanoid'
env = envs.get_environment(env_name)

eval_env = envs.get_environment(env_name)

rng = jax.random.PRNGKey(0)
n_steps = 1000  # number of policy (control) steps to roll out
render_every = 2  # keep one frame every `render_every` control steps

# Load the exported policy and look up its input/output tensor names.
onnx_model_path = 'feedforward_model.onnx'
onnx_session = onnxruntime.InferenceSession(onnx_model_path)
input_name = onnx_session.get_inputs()[0].name
output_names = [output.name for output in onnx_session.get_outputs()]

# MJX policy in MuJoCo: perform the physics step with the original MuJoCo
# Python bindings to show that the policy trained in MJX works in MuJoCo.
mj_model = eval_env.sys.mj_model
mj_data = mujoco.MjData(mj_model)
mujoco.mj_resetData(mj_model, mj_data)

renderer = mujoco.Renderer(mj_model)
# Most recently applied action; passed to `_get_obs` as its action argument.
ctrl = jp.zeros(mj_model.nu)
# NOTE(review): forces one physics step per policy query; confirm this matches
# eval.py and the control rate the policy was trained at.
eval_env._n_frames = 1
images = []
for i in range(n_steps):
  obs = eval_env._get_obs(mjx.put_data(mj_model, mj_data), ctrl)
  onnx_result = onnx_session.run(output_names, {input_name: np.array([obs])})
  # NOTE(review): output #1 is assumed to be the action batch — confirm against
  # the output order produced by convert_to_onnx.py.
  onnx_ctrl = onnx_result[1][0]

  mj_data.ctrl = onnx_ctrl
  # Keep the action argument of the next observation in sync with what was
  # actually applied (it would otherwise stay at zeros for the whole rollout).
  ctrl = onnx_ctrl
  for _ in range(eval_env._n_frames):
    mujoco.mj_step(mj_model, mj_data)  # Physics step using MuJoCo mj_step.

  if i % render_every == 0:
    renderer.update_scene(mj_data, camera='side')
    images.append(renderer.render())


In [None]:

# Play the rollout back in real time. Consecutive frames are `render_every`
# control steps apart, and each control step advances `_n_frames` physics
# steps of `opt.timestep` simulated seconds.
frame_period = (eval_env.sys.mj_model.opt.timestep * eval_env._n_frames
                * render_every)
media.show_video(images, fps=1.0 / frame_period)