In [1]:
%env MUJOCO_GL=egl

import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".20"  #restrict JAX to 20% of available GPU RAM; useful for loading MyoSuite and UitB policies simultaneously

env: MUJOCO_GL=egl


In [2]:
import mujoco
from mujoco import mjx

import numpy as np
from brax import envs
from brax.training.acme.running_statistics import normalize
# from brax.training.agents.ppo import networks as ppo_networks
from brax.training.acme import running_statistics
from brax.io import model
from etils import epath
import functools

from myosuite.envs.myo.myouser.utils import custom_network_factory, get_observation_size, update_target_visuals
from myosuite.envs.myo.mjx.utils import get_latest_run_path, load_checkpoint
import mediapy as media

import jax
import jax.numpy as jp

from mujoco_playground import wrapper, registry

from playground_myoUser_pointing import PlaygroundArmPointing, default_config
registry.manipulation.register_environment("MyoUserPointing", PlaygroundArmPointing, default_config)

MyoSuite:> Registering Myo Envs
MyoSuite:> Registering MyoUser Envs


In [3]:
# @title Check if MuJoCo installation was successful

import distutils.util
import os
import subprocess

if subprocess.run('nvidia-smi').returncode:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU Colab runtime. '
      'Go to the Runtime menu and select Choose runtime type.'
  )

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

# Configure MuJoCo to use the EGL rendering backend (requires GPU)
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl

try:
  print('Checking that the installation succeeded:')
  import mujoco

  mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
  raise e from RuntimeError(
      'Something went wrong during installation. Check the shell output above '
      'for more information.\n'
      'If using a hosted Colab runtime, make sure you enable GPU acceleration '
      'by going to the Runtime menu and selecting "Choose runtime type".'
  )

print('Installation successful.')

# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags

Mon Jul 21 22:08:42 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.158.01             Driver Version: 570.158.01     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 5090        On  |   00000000:01:00.0 Off |                  N/A |
|  0%   41C    P1             24W /  600W |   12081MiB /  32607MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [3]:
%tensorboard --logdir ./logs/ --port 6006

UsageError: Line magic function `%tensorboard` not found.


In [7]:
# params[1]["params"]["hidden_3"]["kernel"].shape

Load Environment

In [4]:
env_name = "MyoUserPointing"

print(f"Current backend: {jax.default_backend()}")

Current backend: gpu


In [None]:
## VARIANT 1: load env, and load params from latest run 
# Note: load_checkpoint always loads latest checkpoint from this run

# checkpoint_path = "logs/MyoUserPointing-20250718-200508/checkpoints"
checkpoint_path = "logs/MyoUserPointing-20250721-221417/checkpoints"

# checkpoint_path = get_latest_run_path("logs/")

env_cfg = registry.get_default_config(env_name)
env, make_inference_fn, params = load_checkpoint(env_name, checkpoint_path)

Restoring from: /scratch/fjf33/myouser/myosuite/envs/myo/mjx/logs/MyoUserPointing-20250721-215656/checkpoints/1003520


In [6]:
checkpoint_path = "logs/MyoUserPointing-20250718-200508/checkpoints"

from myosuite.train.myouser.custom_ppo import checkpoint
ckpt_path = epath.Path(checkpoint_path).resolve()
if ckpt_path.is_dir():
    latest_ckpts = list(ckpt_path.glob("*"))
    latest_ckpts = [ckpt for ckpt in latest_ckpts if ckpt.is_dir()]
    latest_ckpts.sort(key=lambda x: int(x.name))
    latest_ckpt = latest_ckpts[-1]
    restore_checkpoint_path = latest_ckpt
ckpt = checkpoint.load(os.path.join(restore_checkpoint_path))
ckpt[2]["params"]["hidden_0"].keys()

dict_keys(['bias', 'kernel'])

In [7]:
checkpoint_path = "logs/MyoUserPointing-20250721-194636/checkpoints"

from myosuite.train.myouser.custom_ppo import checkpoint
ckpt_path = epath.Path(checkpoint_path).resolve()
if ckpt_path.is_dir():
    latest_ckpts = list(ckpt_path.glob("*"))
    latest_ckpts = [ckpt for ckpt in latest_ckpts if ckpt.is_dir()]
    latest_ckpts.sort(key=lambda x: int(x.name))
    latest_ckpt = latest_ckpts[-1]
    restore_checkpoint_path = latest_ckpt
ckpt2 = checkpoint.load(os.path.join(restore_checkpoint_path))
ckpt2[1]["params"]["value_network"]["layers"]["0"].keys()


dict_keys(['bias', 'kernel'])

In [None]:
checkpoint_path = "logs/MyoUserPointing-20250721-221417/checkpoints"

from myosuite.train.myouser.custom_ppo import checkpoint
ckpt_path = epath.Path(checkpoint_path).resolve()
if ckpt_path.is_dir():
    latest_ckpts = list(ckpt_path.glob("*"))
    latest_ckpts = [ckpt for ckpt in latest_ckpts if ckpt.is_dir()]
    latest_ckpts.sort(key=lambda x: int(x.name))
    latest_ckpt = latest_ckpts[-1]
    restore_checkpoint_path = latest_ckpt
ckpt2 = checkpoint.load(os.path.join(restore_checkpoint_path))
# ckpt2[1]["params"]["value_network"]["layers"]["0"].keys()
ckpt2[2]["params"]["hidden_0"].keys()


IndexError: list index out of range

In [None]:
## TODO: delete this code

# param_path = "logs/MyoUserPointing-20250721-221417/playground_params.pickle"

# env_cfg = registry.get_default_config(env_name)
# env = registry.load(env_name, config=env_cfg)

# normalize_observations = True

# from myosuite.train.myouser.custom_ppo import networks_vision_unified as networks
# from myosuite.envs.myo.mjx.train_jax_ppo import get_observation_size
# network_factory = functools.partial(networks.custom_network_factory, vision=False, get_observation_size=functools.partial(get_observation_size, vision=False))
# ppo_network = network_factory(
#     get_observation_size(), env.action_size, preprocess_observations_fn=normalize
# )
# params = model.load_params(param_path)
# def deterministic_policy(input_data):
#     print(ppo_network.policy_network)
#     logits = ppo_network.policy_network.apply(*params[:2], input_data)
#     brax_result = ppo_network.parametric_action_distribution.mode(logits)
#     return brax_result
# # params

In [None]:
# ## VARIANT 2: load params from pickle file

# env_cfg = registry.get_default_config(env_name)
# env = registry.load(env_name, config=env_cfg)

# #param_path = os.path.join(os.path.dirname(checkpoint_path), 'playground_params.pickle')
# param_path = "/scratch/fjf33/myosuite_ankit/myosuite-mjx-policies/myouserbase_params"

# normalize_observations = True
# ppo_network = custom_network_factory(
#     obs_shape=get_observation_size(),  # Example observation shape
#     action_size=26,  # Example action size,
#     policy_hidden_layer_sizes=(256, 256),  #(50, 50, 50),
#     value_hidden_layer_sizes=(256, 256),  #(50, 50, 50),
#     preprocess_observations_fn=running_statistics.normalize if normalize_observations else lambda x, y: x  # Identity function for preprocessing
# )

# params = model.load_params(param_path)
# def deterministic_policy(input_data):
#     logits = ppo_network.policy_network.apply(*params[:2], input_data)
#     brax_result = ppo_network.parametric_action_distribution.mode(logits)
#     return brax_result
# params

(RunningStatisticsState(mean=Array([-0.2615254 , -0.54561627,  0.61962676,  0.83070046, -0.34872645,
        -0.03588274, -0.05668309,  0.03538842,  0.06029191, -0.03726191,
        -0.7860558 , -0.8074278 ,  0.49745247, -0.7142116 ,  0.31403428,
         0.2688805 , -0.17935699,  0.9940682 , -0.9701431 , -0.6329077 ,
         0.19914536, -0.20567195, -0.20970385, -0.07764046, -0.502418  ,
        -0.09197175, -0.8806276 , -0.9722189 , -0.8965455 ,  0.0283843 ,
        -0.5821786 , -0.12825386, -0.45874986, -0.32036504, -0.9201705 ,
        -0.8188699 , -0.9456196 , -0.9243711 ,  0.24019895,  0.17352438,
        -0.49581355,  0.30509523, -0.5208441 , -0.9418352 ,  0.27023974,
        -0.16672991,  1.0077051 ,  0.05      ], dtype=float32), std=Array([3.5880351e-01, 2.0395048e-01, 2.9089022e-01, 2.0171501e-01,
        1.4141162e-01, 1.0120587e+00, 1.0257189e+00, 1.3978747e+00,
        1.1899329e+00, 9.2838669e-01, 5.2053886e+01, 1.6014919e+01,
        7.9966331e+01, 4.5478992e+01, 1.2146

### Internal Testing

In [None]:
from brax.training import acting

eval_env = wrapper.wrap_for_brax_training(env, episode_length=200)

rng = jax.random.PRNGKey(123)
rng, eval_key = jax.random.split(rng)

# key = jax.random.PRNGKey(123)
# global_key, local_key = jax.random.split(key)
# del key
# local_key = jax.random.fold_in(local_key, jax.process_index())
# local_key, key_env, eval_key = jax.random.split(local_key, 3)

evaluator = acting.Evaluator(
    eval_env,
    functools.partial(make_inference_fn, deterministic=True),
    num_eval_envs=1,
    episode_length=200,
    action_repeat=1,
    key=eval_key,
)

metrics = evaluator.run_evaluation(params, {})

In [8]:
metrics

{'eval/walltime': 41.33294916152954,
 'eval/episode_reach_dist': Array(0.022375, dtype=float32),
 'eval/episode_reward': Array(6.5148783, dtype=float32),
 'eval/episode_success_rate': Array(1., dtype=float32),
 'eval/episode_target_area_dynamic_width_scale': Array(0., dtype=float32),
 'eval/episode_reach_dist_std': Array(0., dtype=float32),
 'eval/episode_reward_std': Array(0., dtype=float32),
 'eval/episode_success_rate_std': Array(0., dtype=float32),
 'eval/episode_target_area_dynamic_width_scale_std': Array(0., dtype=float32),
 'eval/avg_episode_length': Array(26., dtype=float32),
 'eval/std_episode_length': Array(0., dtype=float32),
 'eval/epoch_eval_time': 41.33294916152954,
 'eval/sps': 4.838754651123446}

In [9]:
evaluator._key, unroll_key = jax.random.split(evaluator._key)
eval_state = evaluator._generate_eval_unroll(params, unroll_key)
eval_state.info["eval_metrics"].episode_metrics

{'reach_dist': Array([0.02351627], dtype=float32),
 'reward': Array([6.330341], dtype=float32),
 'success_rate': Array([1.], dtype=float32),
 'target_area_dynamic_width_scale': Array([0.], dtype=float32)}

In [55]:
from brax.training.acting import generate_unroll
_eval_env = envs.training.EvalWrapper(eval_env)
def _generate_unroll(params, key, num_eval_envs=1):
    reset_keys = jax.random.split(key, num_eval_envs)
    eval_first_state = _eval_env.reset(reset_keys)
    return generate_unroll(
    _eval_env, eval_first_state,
    make_inference_fn(params, deterministic=True),
    key=key,
    unroll_length=200)
generate_unroll_jit = jax.jit(_generate_unroll)

In [56]:
# _eval_env = envs.training.EvalWrapper(eval_env)
evaluator._key, unroll_key = jax.random.split(evaluator._key)
# reset_keys = jax.random.split(unroll_key, 1)
# eval_first_state = _eval_env.reset(reset_keys)
# eval_final_state, eval_data = generate_unroll(
#     _eval_env, eval_first_state,
#     make_inference_fn(params, deterministic=True),
#     key=unroll_key,
#     unroll_length=200)
eval_final_state, eval_data = generate_unroll_jit(
    params, key=unroll_key)

In [57]:
eval_data.reward.sum(), eval_final_state.info["eval_metrics"].episode_metrics

(Array(-10.9054, dtype=float32),
 {'reach_dist': Array([0.02216392], dtype=float32),
  'reward': Array([6.478419], dtype=float32),
  'success_rate': Array([1.], dtype=float32),
  'target_area_dynamic_width_scale': Array([0.], dtype=float32)})

### Visualize Rollouts

In [34]:
eval_env = env
# eval_env = wrapper.wrap_for_brax_training(env, episode_length=200)
# eval_env = envs.training.EvalWrapper(_eval_env)

# Hide model target sphere for visualization of policies trained without vision
eval_env.unwrapped.mj_model.geom("target_sphere").rgba = np.zeros(4, dtype=np.float32)

In [35]:
jit_reset = jax.jit(eval_env.reset)
jit_step = jax.jit(eval_env.step)

# # For VARIANT 1:
# jit_inference_fn = jax.jit(make_inference_fn(params, deterministic=True))

In [36]:
eval_key = jax.random.PRNGKey(1)
# reset_keys = jax.random.split(eval_key, 1)  #only required if wrapper.wrap_for_brax_training is used
eval_key, reset_keys = jax.random.split(eval_key)
rollout = []
modify_scene_fns = []
n_episodes = 1

for _ in range(n_episodes):
  state = jit_reset(reset_keys)
  rollout.append(state)
  modify_scene_fns.append(functools.partial(update_target_visuals, target_pos=state.info["target_pos"].flatten(), target_size=state.info["target_radius"].flatten()))
  for i in range(env_cfg.ppo_config.episode_length):
    eval_key, key = jax.random.split(eval_key)
    # ctrl, _ = jit_inference_fn(state.obs, key)  #VARIANT 1
    ctrl = deterministic_policy(state.obs)  #VARIANT 2
    state = jit_step(state, ctrl)
    # print(f"Step {i}, ee_pos: {state.obs[:, 15:18]}")
    # print(f"Target {i}, target_pos: {state.obs[:, -4:-1]}")
    # print(f"Step {i}, steps_since_last_hit: {state.info['steps_since_last_hit']}")
    rollout.append(state)
    modify_scene_fns.append(functools.partial(update_target_visuals, target_pos=state.info["target_pos"].flatten(), target_size=state.info["target_radius"].flatten()))
    if state.done.all():
      break

render_every = 1
traj = rollout[::render_every]
mod_fns = modify_scene_fns[::render_every]
frames = env.render(traj, modify_scene_fns=mod_fns)
rewards = [s.reward for s in rollout]
media.show_video(frames, fps=1.0 / env.dt / render_every)

TypeError: argument of type 'RunningStatisticsState' is not iterable

In [None]:
# [i.info["target_pos"] for i in rollout]  #[_i.obs[:, -4:-1] for _i in rollout]

In [30]:
np.sum(rewards)

np.float32(-17.954895)