In [2]:
# Set the MUJOCO_GL environment variable to "egl" for this kernel (the %env magic echoes the value).
%env MUJOCO_GL=egl

import os
# Limit JAX to 20% of the available GPU RAM, which leaves room for loading
# MyoSuite and UitB policies simultaneously.
os.environ.update({"XLA_PYTHON_CLIENT_MEM_FRACTION": ".20"})

env: MUJOCO_GL=egl


In [None]:
import mujoco
from mujoco import mjx

import numpy as np
from brax import envs
# from brax.training.acme.running_statistics import normalize
# from brax.training.acme import running_statistics
# from brax.io import model
# from etils import epath
import functools

import myosuite
# from myosuite.envs.myo.myouser.utils import update_target_visuals
from myosuite.train.utils.train import train_or_load_checkpoint
# from myosuite.envs.myo.mjx.train_jax_ppo_steering import get_observation_size
from myosuite.envs.myo.myouser.myouser_envs import get_observation_size
import mediapy as media

import jax
import jax.numpy as jp

from mujoco_playground import wrapper, registry
from ml_collections import ConfigDict
import json

# Make the parent directory of the installed `myosuite` package the working
# directory, so relative paths used further down (e.g. "logs/...") resolve there.
package_dir = os.path.dirname(myosuite.__file__)
myo_path = os.path.dirname(package_dir)
os.chdir(myo_path)

MyoSuite:> Registering Myo Envs
MyoSuite:> Registering MyoUser Envs


In [4]:
# @title Check if MuJoCo installation was successful

import distutils.util
import os
import subprocess

if subprocess.run('nvidia-smi').returncode:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU Colab runtime. '
      'Go to the Runtime menu and select Choose runtime type.'
  )

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

# Configure MuJoCo to use the EGL rendering backend (requires GPU)
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl

try:
  print('Checking that the installation succeeded:')
  import mujoco

  mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
  raise e from RuntimeError(
      'Something went wrong during installation. Check the shell output above '
      'for more information.\n'
      'If using a hosted Colab runtime, make sure you enable GPU acceleration '
      'by going to the Runtime menu and selecting "Choose runtime type".'
  )

print('Installation successful.')

# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags

Fri Aug 15 00:03:45 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.158.01             Driver Version: 570.158.01     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 5090        On  |   00000000:01:00.0 Off |                  N/A |
| 34%   56C    P1            168W /  600W |    1937MiB /  32607MiB |     82%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
# (Re)load the TensorBoard notebook extension and start it on port 6006,
# reading from ./logs/ relative to the current working directory.
%reload_ext tensorboard
%tensorboard --logdir ./logs/ --port 6006

In [7]:
# params[1]["params"]["hidden_3"]["kernel"].shape

### Load Environment

In [8]:
# Environment to evaluate; swap in the commented line for the pointing task.
# env_name = "MyoUserPointing"
env_name = "MyoUserSteering"

# Show which JAX backend is active before anything is loaded onto it.
print("Current backend:", jax.default_backend())

Current backend: gpu


In [24]:
## VARIANT 1: load env, and load params from latest run
# Note: load_checkpoint always loads latest checkpoint from this run

# Checkpoint directories of earlier runs (uncomment one to inspect it instead):
# checkpoint_path = "logs/MyoUserPointing-20250718-200508/checkpoints"
# checkpoint_path = "logs/MyoUserPointing-20250721-221417/checkpoints"
# checkpoint_path = "logs/MyoUserPointing-20250723-150450-MyoUserBase/checkpoints"
# checkpoint_path = "logs/MyoUserSteering-20250726-224307-FixedRewardv2/checkpoints"
checkpoint_path = "logs/MyoUserSteering-20250814-233508-Testv002e/checkpoints"

# checkpoint_path = get_latest_run_path("logs/")

# Use the config stored next to the checkpoints rather than the registry defaults:
# env_cfg = registry.get_default_config(env_name)
config_file = os.path.join(checkpoint_path, "config.json")
with open(config_file, "r") as config_fh:
    env_cfg = ConfigDict(json.load(config_fh))

# eval_mode=True together with checkpoint_path: the recorded output shows the
# parameters being restored from the checkpoint directory.
env, make_inference_fn, params = train_or_load_checkpoint(
    env_name,
    env_cfg,
    ppo_params=env_cfg.ppo_config,
    eval_mode=True,
    checkpoint_path=checkpoint_path,
)

Restoring from: /scratch/fjf33/myouser/logs/MyoUserSteering-20250814-233508-Testv002e/checkpoints/15001600
Checkpoint path: /scratch/fjf33/myouser/logs/MyoUserSteering-20250814-233508-Testv002e/checkpoints
No vision, so adding ['screen_pos', 'start_line', 'end_line', 'top_line', 'bottom_line', 'completed_phase_0_arr', 'target'] to obs_keys
Obs keys: ['qpos', 'qvel', 'qacc', 'fingertip', 'act', 'screen_pos', 'start_line', 'end_line', 'top_line', 'bottom_line', 'completed_phase_0_arr', 'target']
No vision, so adding ['screen_pos', 'start_line', 'end_line', 'top_line', 'bottom_line', 'completed_phase_0_arr', 'target'] to obs_keys
Obs keys: ['qpos', 'qvel', 'qacc', 'fingertip', 'act', 'screen_pos', 'start_line', 'end_line', 'top_line', 'bottom_line', 'completed_phase_0_arr', 'target']


In [6]:
# ## VARIANT 2: load params from pickle file

# env_cfg = registry.get_default_config(env_name)
# env = registry.load(env_name, config=env_cfg)

# #param_path = os.path.join(os.path.dirname(checkpoint_path), 'playground_params.pickle')
# param_path = "/scratch/fjf33/myosuite_ankit/myosuite-mjx-policies/myouserbase_params"

# normalize_observations = True
# ppo_network = custom_network_factory(
#     obs_shape=get_observation_size(),  # Example observation shape
#     action_size=26,  # Example action size,
#     policy_hidden_layer_sizes=(256, 256),  #(50, 50, 50),
#     value_hidden_layer_sizes=(256, 256),  #(50, 50, 50),
#     preprocess_observations_fn=running_statistics.normalize if normalize_observations else lambda x, y: x  # Identity function for preprocessing
# )

# params = model.load_params(param_path)
# def deterministic_policy(input_data):
#     logits = ppo_network.policy_network.apply(*params[:2], input_data)
#     brax_result = ppo_network.parametric_action_distribution.mode(logits)
#     return brax_result
# params

### Internal Testing

In [None]:
from brax.training import acting

# One horizon value shared by the training wrapper and the evaluator.
eval_episode_length = 200
eval_env = wrapper.wrap_for_brax_training(env, episode_length=eval_episode_length)

# Fixed seed; the first half of the split is kept in `rng`, the second drives evaluation.
rng, eval_key = jax.random.split(jax.random.PRNGKey(123))

# key = jax.random.PRNGKey(123)
# global_key, local_key = jax.random.split(key)
# del key
# local_key = jax.random.fold_in(local_key, jax.process_index())
# local_key, key_env, eval_key = jax.random.split(local_key, 3)

# The evaluator receives a policy factory with deterministic actions pre-selected.
deterministic_policy_factory = functools.partial(make_inference_fn, deterministic=True)
evaluator = acting.Evaluator(
    eval_env,
    deterministic_policy_factory,
    num_eval_envs=1,
    episode_length=eval_episode_length,
    action_repeat=1,
    key=eval_key,
)

# Second argument is an empty dict (presumably training metrics to merge in — confirm against brax).
metrics = evaluator.run_evaluation(params, {})

In [8]:
# Display the dictionary returned by evaluator.run_evaluation.
metrics

{'eval/walltime': 41.33294916152954,
 'eval/episode_reach_dist': Array(0.022375, dtype=float32),
 'eval/episode_reward': Array(6.5148783, dtype=float32),
 'eval/episode_success_rate': Array(1., dtype=float32),
 'eval/episode_target_area_dynamic_width_scale': Array(0., dtype=float32),
 'eval/episode_reach_dist_std': Array(0., dtype=float32),
 'eval/episode_reward_std': Array(0., dtype=float32),
 'eval/episode_success_rate_std': Array(0., dtype=float32),
 'eval/episode_target_area_dynamic_width_scale_std': Array(0., dtype=float32),
 'eval/avg_episode_length': Array(26., dtype=float32),
 'eval/std_episode_length': Array(0., dtype=float32),
 'eval/epoch_eval_time': 41.33294916152954,
 'eval/sps': 4.838754651123446}

In [9]:
# Run one more evaluation unroll by hand and show the per-episode metrics
# stored in the final state's info.
# NOTE(review): this relies on private Evaluator attributes (_key,
# _generate_eval_unroll) and advances evaluator._key, so every re-run of this
# cell uses a different key.
evaluator._key, unroll_key = jax.random.split(evaluator._key)
eval_state = evaluator._generate_eval_unroll(params, unroll_key)
eval_state.info["eval_metrics"].episode_metrics

{'reach_dist': Array([0.02351627], dtype=float32),
 'reward': Array([6.330341], dtype=float32),
 'success_rate': Array([1.], dtype=float32),
 'target_area_dynamic_width_scale': Array([0.], dtype=float32)}

In [55]:
from brax.training.acting import generate_unroll

# Build the wrapped evaluation env from the raw `env` rather than from the
# notebook-global `eval_env`: that name is rebound to the unwrapped env in the
# visualization section, so wrapping it here would depend on execution order.
_UNROLL_EPISODE_LENGTH = 200
_eval_env = envs.training.EvalWrapper(
    wrapper.wrap_for_brax_training(env, episode_length=_UNROLL_EPISODE_LENGTH)
)


def _generate_unroll(params, key, num_eval_envs=1, unroll_length=_UNROLL_EPISODE_LENGTH):
    """Reset the evaluation env and unroll the deterministic policy.

    Args:
        params: Policy parameters passed to `make_inference_fn`.
        key: PRNG key. It is split into `num_eval_envs` reset keys and is also
            passed unchanged to `generate_unroll`.
        num_eval_envs: Number of reset keys (one per evaluation env) to generate.
        unroll_length: Number of steps to unroll (defaults to 200, as before).

    Returns:
        The result of `brax.training.acting.generate_unroll`; the calling cell
        unpacks it as `(final_state, data)`.
    """
    reset_keys = jax.random.split(key, num_eval_envs)
    eval_first_state = _eval_env.reset(reset_keys)
    return generate_unroll(
        _eval_env,
        eval_first_state,
        make_inference_fn(params, deterministic=True),
        key=key,
        unroll_length=unroll_length,
    )


# Both size arguments determine array shapes / loop lengths, so they must be
# static under jit; without this, passing `num_eval_envs` explicitly fails.
generate_unroll_jit = jax.jit(
    _generate_unroll, static_argnames=("num_eval_envs", "unroll_length")
)

In [56]:
# Draw a fresh unroll key from the evaluator's private key and run the jitted
# unroll defined above. The commented lines are the equivalent un-jitted calls.
# NOTE(review): evaluator._key is advanced on every run, so this cell is not idempotent.
# _eval_env = envs.training.EvalWrapper(eval_env)
evaluator._key, unroll_key = jax.random.split(evaluator._key)
# reset_keys = jax.random.split(unroll_key, 1)
# eval_first_state = _eval_env.reset(reset_keys)
# eval_final_state, eval_data = generate_unroll(
#     _eval_env, eval_first_state,
#     make_inference_fn(params, deterministic=True),
#     key=unroll_key,
#     unroll_length=200)
eval_final_state, eval_data = generate_unroll_jit(
    params, key=unroll_key)

In [57]:
# Compare the reward summed over all unroll steps with the episode metrics
# tracked in the final state's info.
# NOTE(review): the two reward totals differ in the recorded output (-10.9 vs
# 6.48), presumably because the 200-step unroll keeps running after the first
# episode has ended — confirm.
eval_data.reward.sum(), eval_final_state.info["eval_metrics"].episode_metrics

(Array(-10.9054, dtype=float32),
 {'reach_dist': Array([0.02216392], dtype=float32),
  'reward': Array([6.478419], dtype=float32),
  'success_rate': Array([1.], dtype=float32),
  'target_area_dynamic_width_scale': Array([0.], dtype=float32)})

### Visualize Rollouts

In [25]:
# Use the env exactly as returned by train_or_load_checkpoint (no brax training
# wrapper) for the step-by-step rollout and rendering below.
# NOTE(review): this rebinds `eval_env`, which the "Internal Testing" cells use
# for the wrapped env, so those cells depend on execution order.
eval_env = env
# eval_env = wrapper.wrap_for_brax_training(env, episode_length=200)
# eval_env = envs.training.EvalWrapper(_eval_env)

# # Hide model target sphere for visualization of policies trained without vision
# eval_env.unwrapped.mj_model.geom("target_sphere").rgba = np.array([0., 1., 0., 1])  #np.zeros(4, dtype=np.float32)

In [26]:
# Compile the env's reset/step functions once so the Python rollout loop is fast.
jit_reset, jit_step = jax.jit(eval_env.reset), jax.jit(eval_env.step)

# For VARIANT 1: deterministic policy built from the restored parameters.
policy_fn = make_inference_fn(params, deterministic=True)
jit_inference_fn = jax.jit(policy_fn)

In [46]:
# Roll out the deterministic policy step by step, then render the trajectory.
eval_key = jax.random.PRNGKey(123)
rollout = []
# modify_scene_fns = []
n_episodes = 1
render_every = 1
slow_motion_factor = 5  # video is played this many times slower than real time

for _ in range(n_episodes):
  # Draw a fresh reset key for every episode. Splitting once outside the loop
  # (as before) reuses the same key, so every episode would start identically.
  # For n_episodes == 1 the key sequence is unchanged.
  # reset_keys = jax.random.split(eval_key, 1)  #only required if wrapper.wrap_for_brax_training is used
  eval_key, reset_keys = jax.random.split(eval_key)
  state = jit_reset(reset_keys)
  rollout.append(state)
  # modify_scene_fns.append(functools.partial(update_target_visuals, target_pos=state.info["target_pos"].flatten(), target_size=state.info["target_radius"].flatten()))
  for i in range(env_cfg.ppo_config.episode_length):
    eval_key, key = jax.random.split(eval_key)
    ctrl, _ = jit_inference_fn(state.obs, key)  #VARIANT 1
    # ctrl = deterministic_policy(state.obs)  #VARIANT 2
    # ctrl = jax.random.uniform(act_rng, shape=eval_env._na)  #BASELINE: random control
    state = jit_step(state, ctrl)
    # touch_detected = any([({eval_env.mj_model.geom(con_geom[0]).name, eval_env.mj_model.geom(con_geom[1]).name} == {"fingertip_contact", "screen"}) and (con_dist < 0) for con_geom, con_dist in zip(state.data._impl.contact.geom, state.data._impl.contact.dist)])
    # if touch_detected:
    #   print(f"Step {i}, touch detected. {state.data._impl.contact.dist}")
    # print(f"Step {i}, ee_pos: {state.obs[:, 15:18]}")
    # print(f"Target {i}, target_pos: {state.obs[:, -4:-1]}")
    # print(f"Step {i}, steps_since_last_hit: {state.info['steps_since_last_hit']}")
    rollout.append(state)
    # modify_scene_fns.append(functools.partial(update_target_visuals, target_pos=state.info["target_pos"].flatten(), target_size=state.info["target_radius"].flatten()))
    # Stop as soon as the env reports the episode as done.
    if state.done.all():
      break

# Render every `render_every`-th recorded state from the "fixed-eye" camera.
traj = rollout[::render_every]
# mod_fns = modify_scene_fns[::render_every]
frames = eval_env.render(traj, camera="fixed-eye") #, modify_scene_fns=mod_fns)
rewards = [s.reward for s in rollout]
media.show_video(frames, fps=1.0 / eval_env.dt / render_every / slow_motion_factor)

100%|██████████| 24/24 [00:00<00:00, 856.38it/s]


0
This browser does not support the video tag.


In [47]:
# Fingertip site position next to the "start_line" info entry, one pair per recorded step.
[
    (step_state.data.site_xpos[eval_env.fingertip_id], step_state.info["start_line"])
    for step_state in rollout
]

[(Array([ 0.03493699, -0.198526  ,  0.28504306], dtype=float32),
  Array([-0.053655  ,  0.14310211,  1.0893918 ], dtype=float32)),
 (Array([ 0.04069119, -0.20599592,  0.28613693], dtype=float32),
  Array([-0.053655  ,  0.14310211,  1.0893918 ], dtype=float32)),
 (Array([ 0.08228442, -0.2298862 ,  0.29534462], dtype=float32),
  Array([-0.053655  ,  0.14310211,  1.0893918 ], dtype=float32)),
 (Array([ 0.17071193, -0.25748423,  0.33155507], dtype=float32),
  Array([-0.053655  ,  0.14310211,  1.0893918 ], dtype=float32)),
 (Array([ 0.30330464, -0.27855498,  0.44532046], dtype=float32),
  Array([-0.053655  ,  0.14310211,  1.0893918 ], dtype=float32)),
 (Array([ 0.40173364, -0.2501375 ,  0.6785252 ], dtype=float32),
  Array([-0.053655  ,  0.14310211,  1.0893918 ], dtype=float32)),
 (Array([ 0.36892766, -0.19578382,  0.9501913 ], dtype=float32),
  Array([-0.053655  ,  0.14310211,  1.0893918 ], dtype=float32)),
 (Array([ 0.26211452, -0.15788034,  1.1548858 ], dtype=float32),
  Array([-0.053655

In [34]:
# Per-step distance metrics: phase 0, phase 1 and the "dist" entry.
[
    (
        step_state.metrics['distance_phase_0'],
        step_state.metrics['distance_phase_1'],
        step_state.metrics['dist'],
    )
    for step_state in rollout
]

[(Array(0., dtype=float32, weak_type=True),
  Array(0., dtype=float32, weak_type=True),
  Array(0., dtype=float32, weak_type=True)),
 (Array(1.3049129, dtype=float32),
  Array(0., dtype=float32),
  Array(1.3049129, dtype=float32)),
 (Array(1.3117665, dtype=float32),
  Array(0., dtype=float32),
  Array(1.3117665, dtype=float32)),
 (Array(1.3100841, dtype=float32),
  Array(0., dtype=float32),
  Array(1.3100841, dtype=float32)),
 (Array(1.2725639, dtype=float32),
  Array(0., dtype=float32),
  Array(1.2725639, dtype=float32)),
 (Array(1.1525897, dtype=float32),
  Array(0., dtype=float32),
  Array(1.1525897, dtype=float32)),
 (Array(0.9832926, dtype=float32),
  Array(0., dtype=float32),
  Array(0.9832926, dtype=float32)),
 (Array(0.8651339, dtype=float32),
  Array(0., dtype=float32),
  Array(0.8651339, dtype=float32)),
 (Array(0.78354716, dtype=float32),
  Array(0., dtype=float32),
  Array(0.78354716, dtype=float32)),
 (Array(0.69642377, dtype=float32),
  Array(0., dtype=float32),
  Array(0

In [49]:
# Names of the two geoms in every contact slot of the last state.
touches = [
    {eval_env.mj_model.geom(pair[0]).name, eval_env.mj_model.geom(pair[1]).name}
    for pair in state.data._impl.contact.geom
]
# True when any contact slot pairs the fingertip with the screen.
# NOTE(review): only the geom pairing is checked, not contact.dist; the
# commented-out check in the rollout cell additionally requires con_dist < 0 —
# confirm which one is intended.
touch_detected = any(names == {"fingertip_contact", "screen"} for names in touches)

touch_detected, touches

(True, [{'fingertip_contact', 'screen'}])

In [51]:
# Contact distances of the last state (accessed through the private _impl attribute).
state.data._impl.contact.dist

Array([1.], dtype=float32)

In [41]:
# Inspect the `geom` accessor of the MuJoCo model; the recorded output shows it is a bound method.
# NOTE(review): looks like leftover debugging — consider removing.
eval_env.mj_model.geom

<bound method PyCapsule.geom of <mujoco._structs.MjModel object at 0x7749d83c61b0>>

In [26]:
# Apply the last control once more and evaluate (exp(-10 * d) - 1) / 10 at the
# resulting "dist_combined" metric, next to the same expression at d = 1.5.
# NOTE(review): looks like a manual check of an exponential distance term of the
# reward — confirm against the env's reward function.
# NOTE(review): `state` is overwritten on every run, so this cell is not idempotent.
# eval_key, reset_keys = jax.random.split(eval_key)
# state = jit_reset(reset_keys)
# eval_key, key = jax.random.split(eval_key)
# ctrl, _ = jit_inference_fn(state.obs, key)  #VARIANT 1
state = jit_step(state, ctrl)
# The bare expression below is not the cell's last line, so its value is not displayed.
state.metrics["dist_combined"]
(np.exp(-state.metrics["dist_combined"]*10.) - 1.)/10., (np.exp(-1.5*10.) - 1.)/10.

(np.float32(-0.09992059), np.float64(-0.09999996940976795))

In [33]:
# Compare exp(-2 * d) at d = 0.7 and d = 1.5.
# NOTE(review): presumably explores a smaller exponent scale (2 instead of 10) for the distance term — confirm.
np.exp(-0.7*2), (np.exp(-1.5*2))

(np.float64(0.2465969639416065), np.float64(0.049787068367863944))

In [17]:
# Time at which phase 0 is first flagged as done: index of the first recorded
# step with phase_0_done set, multiplied by env.unwrapped._ctrl_dt.
# Previously this value was computed on a non-final line (so never displayed)
# and raised IndexError when phase 0 was never completed; now it is shown
# alongside the per-step list and is None in that case.
phase_0_done_steps = np.where([r.info["phase_0_done"] for r in rollout])[0]
phase_0_done_time = (
    phase_0_done_steps[0] * env.unwrapped._ctrl_dt if phase_0_done_steps.size else None
)
phase_0_done_time, [(r.info["phase_0_done"], r.reward) for r in rollout]

[(Array(False, dtype=bool), Array(0., dtype=float32)),
 (Array(False, dtype=bool), Array(-0.09999998, dtype=float32)),
 (Array(False, dtype=bool), Array(-0.09999997, dtype=float32)),
 (Array(False, dtype=bool), Array(-0.09999994, dtype=float32)),
 (Array(False, dtype=bool), Array(-0.09999982, dtype=float32)),
 (Array(False, dtype=bool), Array(-0.09999924, dtype=float32)),
 (Array(False, dtype=bool), Array(-0.09999751, dtype=float32)),
 (Array(False, dtype=bool), Array(-0.09999336, dtype=float32)),
 (Array(False, dtype=bool), Array(-0.09998853, dtype=float32)),
 (Array(False, dtype=bool), Array(-0.09998377, dtype=float32)),
 (Array(False, dtype=bool), Array(-0.0999779, dtype=float32)),
 (Array(False, dtype=bool), Array(-0.09997098, dtype=float32)),
 (Array(False, dtype=bool), Array(-0.0999607, dtype=float32)),
 (Array(False, dtype=bool), Array(-0.09995576, dtype=float32)),
 (Array(False, dtype=bool), Array(-0.0999537, dtype=float32)),
 (Array(False, dtype=bool), Array(-0.0999264, dtype=

In [22]:
# Target position per recorded step, and entries 15:18 of the "proprioception" observation.
# NOTE(review): entries 15:18 are presumably the end-effector position (the
# rollout cell's debug print labels the same slice "ee_pos") — confirm.
(
    [step_state.info["target_pos"] for step_state in rollout],
    [step_state.obs["proprioception"][15:18] for step_state in rollout],
)

([Array([ 0.26876748, -0.24839334,  0.79923654], dtype=float32),
  Array([ 0.26876748, -0.24839334,  0.79923654], dtype=float32),
  Array([ 0.26876748, -0.24839334,  0.79923654], dtype=float32),
  Array([ 0.26876748, -0.24839334,  0.79923654], dtype=float32),
  Array([ 0.26876748, -0.24839334,  0.79923654], dtype=float32),
  Array([ 0.26876748, -0.24839334,  0.79923654], dtype=float32),
  Array([ 0.26876748, -0.24839334,  0.79923654], dtype=float32),
  Array([ 0.26876748, -0.24839334,  0.79923654], dtype=float32),
  Array([ 0.26876748, -0.24839334,  0.79923654], dtype=float32),
  Array([ 0.26876748, -0.24839334,  0.79923654], dtype=float32),
  Array([ 0.26876748, -0.24839334,  0.79923654], dtype=float32),
  Array([ 0.26876748, -0.24839334,  0.79923654], dtype=float32),
  Array([ 0.26876748, -0.24839334,  0.79923654], dtype=float32)],
 [Array([ 0.03493699, -0.198526  ,  0.28504306], dtype=float32),
  Array([ 0.05467008, -0.20556903,  0.2879694 ], dtype=float32),
  Array([ 0.12839839, -0

In [11]:
# Total reward over all states collected in `rewards` by the rollout cell.
np.sum(rewards)

np.float32(7.423525)