In [1]:
%env MUJOCO_GL=egl

import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".20"  #restrict JAX to 20% of available GPU RAM; useful for loading MyoSuite and UitB policies simultaneously

env: MUJOCO_GL=egl


In [2]:
import mujoco
from mujoco import mjx

import numpy as np
from brax import envs
# from brax.training.acme.running_statistics import normalize
# from brax.training.acme import running_statistics
# from brax.io import model
# from etils import epath
import functools

import myosuite
# from myosuite.envs.myo.myouser.utils import update_target_visuals
from myosuite.train.utils.train import train_or_load_checkpoint
from myosuite.envs.myo.myouser.evaluate import evaluate_policy
from myosuite.envs.myo.myouser.utils import render_traj
import mediapy as media

import jax
import jax.numpy as jp

from mujoco_playground import wrapper, registry
from ml_collections import ConfigDict
import json

myo_path = os.path.dirname(os.path.dirname(myosuite.__file__))
os.chdir(myo_path)

MyoSuite:> Registering Myo Envs
MyoSuite:> Registering MyoUser Envs


In [3]:
# # if module functions need to be reimported...

# import sys, importlib
# importlib.reload(sys.modules['myosuite.envs.myo.myouser.evaluate'])
# from myosuite.envs.myo.myouser.evaluate import evaluate_policy
# from myosuite.envs.myo.myouser.utils import render_traj

In [3]:
# @title Check if MuJoCo installation was successful

import distutils.util
import os
import subprocess

if subprocess.run('nvidia-smi').returncode:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU Colab runtime. '
      'Go to the Runtime menu and select Choose runtime type.'
  )

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

# Configure MuJoCo to use the EGL rendering backend (requires GPU)
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl

try:
  print('Checking that the installation succeeded:')
  import mujoco

  mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
  raise e from RuntimeError(
      'Something went wrong during installation. Check the shell output above '
      'for more information.\n'
      'If using a hosted Colab runtime, make sure you enable GPU acceleration '
      'by going to the Runtime menu and selecting "Choose runtime type".'
  )

print('Installation successful.')

# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags

Wed Aug 20 00:39:37 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.158.01             Driver Version: 570.158.01     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 5090        On  |   00000000:01:00.0 Off |                  N/A |
|  0%   45C    P8              8W /  600W |     169MiB /  32607MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
%reload_ext tensorboard
%tensorboard --logdir ./logs/ --port 6006

In [7]:
# params[1]["params"]["hidden_3"]["kernel"].shape

Load Environment

In [4]:
# env_name = "MyoUserPointing"
env_name = "MyoUserSteering"

print(f"Current backend: {jax.default_backend()}")

Current backend: gpu


In [None]:
# ## LOADING DEPRECATED CHECKPOINTS (NOT SAFE!!!):

# checkpoint_path = "logs/MyoUserSteering-20250816-020256-Testv002hNoCollision80and50fixedoffset0.5Collision/checkpoints"

# # env_cfg = registry.get_default_config(env_name)
# with open(os.path.join(checkpoint_path, "config.json"), "r") as f:
#     env_cfg = ConfigDict(json.load(f))

# with env_cfg.unlocked():
#     env_cfg.vision.enabled = False
    
#     env_cfg.task_config.min_width: float = 0.3
#     env_cfg.task_config.min_height: float = 0.1
#     env_cfg.task_config.bottom: float = -0.3
#     env_cfg.task_config.top: float = 0.3
#     env_cfg.task_config.left: float = 0.3
#     env_cfg.task_config.right: float = -0.3

# _env = registry.load(env_name, config=env_cfg)
# rng_init = jax.random.PRNGKey(0)
# init_state = _env.reset(rng_init)

No vision, so adding ['screen_pos', 'start_line', 'end_line', 'top_line', 'bottom_line', 'completed_phase_0_arr', 'target'] to obs_keys
Obs keys: ['qpos', 'qvel', 'qacc', 'fingertip', 'act', 'screen_pos', 'start_line', 'end_line', 'top_line', 'bottom_line', 'completed_phase_0_arr', 'target']


<mujoco._structs.MjModel at 0x7f1e9d08a630>

In [5]:
## VARIANT 1: load env, and load params from latest run 
# Note: load_checkpoint always loads latest checkpoint from this run

# checkpoint_path = "logs/MyoUserPointing-20250718-200508/checkpoints"
# checkpoint_path = "logs/MyoUserPointing-20250721-221417/checkpoints"
# checkpoint_path = "logs/MyoUserPointing-20250723-150450-MyoUserBase/checkpoints"
# checkpoint_path = "logs/MyoUserSteering-20250726-224307-FixedRewardv2/checkpoints"
# checkpoint_path = "logs/MyoUserSteering-20250815-230528-Testv002hNoCollision80and50fixedoffset0.5/checkpoints"
# checkpoint_path = "logs/MyoUserSteering-20250815-232641-Testv002hNoCollision80and50fixedoffset0.5Collision/checkpoints"
# checkpoint_path = "logs/MyoUserSteering-20250816-020256-Testv002hNoCollision80and50fixedoffset0.5Collision/checkpoints"
# checkpoint_path = get_latest_run_path("logs/")
checkpoint_path = "logs/MyoUserSteering-20250819-123318/checkpoints"

# env_cfg = registry.get_default_config(env_name)
with open(os.path.join(checkpoint_path, "config.json"), "r") as f:
    env_cfg = ConfigDict(json.load(f))
env, make_inference_fn, params = train_or_load_checkpoint(env_name, env_cfg, eval_mode=True, checkpoint_path=checkpoint_path)
jit_inference_fn = jax.jit(make_inference_fn(params, deterministic=True))
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

Restoring from: /scratch/fjf33/myouser/logs/MyoUserSteering-20250819-123318/checkpoints/15206400
Checkpoint path: /scratch/fjf33/myouser/logs/MyoUserSteering-20250819-123318/checkpoints
No vision, so adding ['screen_pos', 'start_line', 'end_line', 'top_line', 'bottom_line', 'completed_phase_0_arr', 'target'] to obs_keys
Obs keys: ['qpos', 'qvel', 'qacc', 'fingertip', 'act', 'screen_pos', 'start_line', 'end_line', 'top_line', 'bottom_line', 'completed_phase_0_arr', 'target']
No vision, so adding ['screen_pos', 'start_line', 'end_line', 'top_line', 'bottom_line', 'completed_phase_0_arr', 'target'] to obs_keys
Obs keys: ['qpos', 'qvel', 'qacc', 'fingertip', 'act', 'screen_pos', 'start_line', 'end_line', 'top_line', 'bottom_line', 'completed_phase_0_arr', 'target']


In [7]:
## Replay stored video
rollout_stored = np.load(os.path.join(checkpoint_path, "../traj.pickle"), allow_pickle=True)
render_traj(rollout_stored, eval_env=env)  #, camera=None)

# frames = env.render(rollout_stored) #, camera="fixed-eye") #, modify_scene_fns=mod_fns)
# rewards = [s.reward for s in rollout_stored]
# print(sum(rewards))
# media.show_video(frames, fps=1.0 / env.dt)

100%|██████████| 1822/1822 [00:02<00:00, 765.85it/s]


0
This browser does not support the video tag.


In [6]:
# ## VARIANT 2: load params from pickle file

# env_cfg = registry.get_default_config(env_name)
# env = registry.load(env_name, config=env_cfg)

# #param_path = os.path.join(os.path.dirname(checkpoint_path), 'playground_params.pickle')
# param_path = "/scratch/fjf33/myosuite_ankit/myosuite-mjx-policies/myouserbase_params"

# normalize_observations = True
# ppo_network = custom_network_factory(
#     obs_shape=get_observation_size(),  # Example observation shape
#     action_size=26,  # Example action size,
#     policy_hidden_layer_sizes=(256, 256),  #(50, 50, 50),
#     value_hidden_layer_sizes=(256, 256),  #(50, 50, 50),
#     preprocess_observations_fn=running_statistics.normalize if normalize_observations else lambda x, y: x  # Identity function for preprocessing
# )

# params = model.load_params(param_path)
# def deterministic_policy(input_data):
#     logits = ppo_network.policy_network.apply(*params[:2], input_data)
#     brax_result = ppo_network.parametric_action_distribution.mode(logits)
#     return brax_result
# params

### Internal Testing

In [None]:
from brax.training import acting

eval_env = wrapper.wrap_for_brax_training(env, episode_length=200)

rng = jax.random.PRNGKey(123)
rng, eval_key = jax.random.split(rng)

# key = jax.random.PRNGKey(123)
# global_key, local_key = jax.random.split(key)
# del key
# local_key = jax.random.fold_in(local_key, jax.process_index())
# local_key, key_env, eval_key = jax.random.split(local_key, 3)

evaluator = acting.Evaluator(
    eval_env,
    functools.partial(make_inference_fn, deterministic=True),
    num_eval_envs=1,
    episode_length=200,
    action_repeat=1,
    key=eval_key,
)

metrics = evaluator.run_evaluation(params, {})

In [9]:
evaluator._key, unroll_key = jax.random.split(evaluator._key)
eval_state = evaluator._generate_eval_unroll(params, unroll_key)
eval_state.info["eval_metrics"].episode_metrics

{'reach_dist': Array([0.02351627], dtype=float32),
 'reward': Array([6.330341], dtype=float32),
 'success_rate': Array([1.], dtype=float32),
 'target_area_dynamic_width_scale': Array([0.], dtype=float32)}

In [55]:
from brax.training.acting import generate_unroll
_eval_env = envs.training.EvalWrapper(eval_env)
def _generate_unroll(params, key, num_eval_envs=1):
    reset_keys = jax.random.split(key, num_eval_envs)
    eval_first_state = _eval_env.reset(reset_keys)
    return generate_unroll(
    _eval_env, eval_first_state,
    make_inference_fn(params, deterministic=True),
    key=key,
    unroll_length=200)
generate_unroll_jit = jax.jit(_generate_unroll)

In [56]:
# _eval_env = envs.training.EvalWrapper(eval_env)
evaluator._key, unroll_key = jax.random.split(evaluator._key)
# reset_keys = jax.random.split(unroll_key, 1)
# eval_first_state = _eval_env.reset(reset_keys)
# eval_final_state, eval_data = generate_unroll(
#     _eval_env, eval_first_state,
#     make_inference_fn(params, deterministic=True),
#     key=unroll_key,
#     unroll_length=200)
eval_final_state, eval_data = generate_unroll_jit(
    params, key=unroll_key)

In [57]:
eval_data.reward.sum(), eval_final_state.info["eval_metrics"].episode_metrics

(Array(-10.9054, dtype=float32),
 {'reach_dist': Array([0.02216392], dtype=float32),
  'reward': Array([6.478419], dtype=float32),
  'success_rate': Array([1.], dtype=float32),
  'target_area_dynamic_width_scale': Array([0.], dtype=float32)})

### Visualize Rollouts

In [7]:
rng_init = jax.random.PRNGKey(0)
init_state = env.reset(rng_init)

In [33]:
init_state.data.qpos

from mujoco_playground._src import mjx_env
reset_qpos, reset_qvel, reset_act = env._reset_zero(rng_init)
data = mjx_env.init(env.mjx_model, qpos=reset_qpos, qvel=reset_qvel, act=reset_act)
reset_qpos, reset_qvel, reset_act = env._reset_range_uniform(rng_init, data)
jnt_range = env._mj_model.jnt_range[env._independent_joints]
nqi = len(env._independent_qpos)
rng, rng1, rng2, rng3 = jax.random.split(rng_init, 4)
qpos = jax.random.uniform(
            rng1, shape=(nqi,), minval=jnt_range[:, 0], maxval=jnt_range[:, 1]
        )
qvel = jax.random.uniform(
    rng2,
    shape=(nqi,),
    minval=jp.ones((nqi,)) * -0.05,
    maxval=jp.ones((nqi,)) * 0.05,
)
reset_qpos = jp.zeros((env._mj_model.nq,))
reset_qvel = jp.zeros((env._mj_model.nv,))
reset_act = jax.random.uniform(
    rng3,
    shape=env._na,
    minval=jp.zeros((env._na,)),
    maxval=jp.ones((env._na,)),
)

# Set qpos and qvel
reset_qpos = reset_qpos.at[env._independent_qpos].set(qpos)
# reset_qpos[self._dependent_qpos] = 0
reset_qvel = reset_qvel.at[env._independent_dofs].set(qvel)

_joint_constraints = env.mjx_model.eq_type == 2
_active_eq_constraints = data.eq_active == 1

eq_dep, eq_indep, poly_coefs = jp.array(env.mjx_model.eq_obj1id), \
    jp.array(env.mjx_model.eq_obj2id), \
    jp.array(env.mjx_model.eq_data[:, 4::-1])

reset_qpos_new = jp.where(jp.array([jp.any((eq_dep == i) & _joint_constraints & _active_eq_constraints) for i in range(env.mjx_model.njnt)]), 
                            jp.array([jp.polyval(poly_coefs[jp.argwhere(eq_dep == i, size=1).flatten(), :].flatten(), reset_qpos[eq_indep[jp.argwhere(eq_dep == i, size=1).flatten()]]) for i in range(env.mjx_model.njnt)]).flatten(),
                            reset_qpos)

reset_qpos, data.qpos, jnt_range, reset_qpos_new

(Array([ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        , -1.5427936 ,  0.06563155,  0.        , -0.4542042 ,
         0.820985  , -0.8701047 ], dtype=float32),
 Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],      dtype=float32),
 array([[-1.5708  ,  2.26893 ],
        [ 0.      ,  3.14159 ],
        [-1.57    ,  0.349066],
        [ 0.      ,  2.26893 ],
        [-1.5708  ,  1.5708  ]]),
 Array([ 0.        , -0.0158829 ,  0.00672723, -0.00672723,  0.0158829 ,
        -0.00321596,  0.02599023,  0.01168242, -0.01168242, -0.02599023,
         0.00321596, -1.5427936 ,  0.06563155,  1.5427936 , -0.4542042 ,
         0.820985  , -0.8701047 ], dtype=float32))

In [6]:
## Variant A (reload policy from checkpoint):
# rollout = evaluate_policy(checkpoint_path=checkpoint_path, env_name=env_name,
#                           seed=123, n_episodes=2)

## Variant B (directly pass preloaded policy):
rollout = evaluate_policy(eval_env=env, jit_inference_fn=jit_inference_fn, jit_reset=jit_reset, jit_step=jit_step,
                          seed=123, n_episodes=10)

In [7]:
render_traj(rollout, eval_env=env)

100%|██████████| 4010/4010 [00:09<00:00, 428.22it/s]


0
This browser does not support the video tag.


In [10]:
[(r.obs["proprioception"][15:18], r.obs["proprioception"][47:50]) for r in rollout]

[(Array([ 0.03493699, -0.198526  ,  0.28504306], dtype=float32),
  Array([ 0.5       , -0.05966789,  0.95312196], dtype=float32)),
 (Array([ 0.04321903, -0.20694444,  0.2865013 ], dtype=float32),
  Array([ 0.5       , -0.05966789,  0.95312196], dtype=float32)),
 (Array([ 0.09450396, -0.22791187,  0.298051  ], dtype=float32),
  Array([ 0.5       , -0.05966789,  0.95312196], dtype=float32)),
 (Array([ 0.19809219, -0.24103165,  0.34398323], dtype=float32),
  Array([ 0.5       , -0.05966789,  0.95312196], dtype=float32)),
 (Array([ 0.33882892, -0.21474227,  0.48161364], dtype=float32),
  Array([ 0.5       , -0.05966789,  0.95312196], dtype=float32)),
 (Array([ 0.40011418, -0.13481002,  0.71028936], dtype=float32),
  Array([ 0.5       , -0.05966789,  0.95312196], dtype=float32)),
 (Array([ 0.3625064 , -0.07929441,  0.8808954 ], dtype=float32),
  Array([ 0.5       , -0.05966789,  0.95312196], dtype=float32)),
 (Array([ 0.367423  , -0.06623739,  0.921308  ], dtype=float32),
  Array([ 0.5     

In [11]:
env.mj_model.site_pos[env.end_line_id, :], env.mj_model.site_pos[env.screen_id, :], env.mj_model.body("screen").pos

(array([-0.00999999, -0.13367838,  0.10312194]),
 array([0., 0., 0.]),
 array([ 0.5 , -0.35,  0.85]))

In [8]:
rollout_stored[0].info["end_line"], rollout_stored[0].info["screen_pos"] , rollout_stored[0].data.site_xpos[env.end_line_id], rollout_stored[0].data.site_xpos[env.screen_id]

(Array([ 0.6       , -0.48367837,  0.95312196], dtype=float32),
 Array([ 0.6 , -0.35,  0.85], dtype=float32),
 Array([ 0.59000003, -0.7       ,  0.875     ], dtype=float32),
 Array([ 0.6 , -0.35,  0.85], dtype=float32))

In [9]:
rollout_stored[0].info["end_line"], rollout_stored[0].info["screen_pos"] , rollout_stored[0].data.site_xpos[env.end_line_id], rollout_stored[0].data.site_xpos[env.screen_id]

(Array([ 0.6       , -0.48367837,  0.95312196], dtype=float32),
 Array([ 0.6 , -0.35,  0.85], dtype=float32),
 Array([ 0.59000003, -0.7       ,  0.875     ], dtype=float32),
 Array([ 0.6 , -0.35,  0.85], dtype=float32))

In [13]:
env.mj_model.site_pos[env.end_line_id, :], env.mj_model.site_pos[env.screen_id, :], env.mj_model.body("screen").pos

(array([ 0.        , -0.13367838,  0.10312194]),
 array([0., 0., 0.]),
 array([ 0.6 , -0.35,  0.85]))

In [26]:
init_state.info["end_line"] - init_state.info["screen_pos"] , init_state.data.site_xpos[env.end_line_id] - init_state.data.site_xpos[env.screen_id]

(Array([ 0.        , -0.13367838,  0.10312194], dtype=float32),
 Array([-0.00999999, -0.35      ,  0.02499998], dtype=float32))

In [8]:
rollout[0].info["end_line"] - rollout[0].info["screen_pos"] , rollout[0].data.site_xpos[env.end_line_id] - rollout[0].data.site_xpos[env.screen_id]

(Array([ 0.        , -0.13367838,  0.10312194], dtype=float32),
 Array([-0.00999999, -0.35      ,  0.02499998], dtype=float32))

In [8]:
env.mj_model.site_pos[env.end_line_id], init_state.data.site_xpos[env.end_line_id], env.mj_model.site_pos[env.screen_id], init_state.data.site_xpos[env.screen_id]

(array([-1.01327896e-06, -5.03678381e-01,  1.03121936e-01]),
 Array([ 0.59000003, -0.7       ,  0.875     ], dtype=float32),
 array([0., 0., 0.]),
 Array([ 0.6 , -0.35,  0.85], dtype=float32))

In [11]:
env.mj_model.site_pos[env.end_line_id], init_state.data.site_xpos[env.end_line_id], env.mj_model.site_pos[env.screen_id], init_state.data.site_xpos[env.screen_id]

(array([ 0.        , -0.13367838,  0.10312194]),
 Array([ 0.59000003, -0.7       ,  0.875     ], dtype=float32),
 array([0., 0., 0.]),
 Array([ 0.6 , -0.35,  0.85], dtype=float32))

In [12]:
render_traj(rollout, eval_env=env)

  0%|          | 0/48 [00:00<?, ?it/s]

100%|██████████| 48/48 [00:00<00:00, 688.32it/s]


0
This browser does not support the video tag.


In [22]:
[(r.data.site_xpos[env.fingertip_id], r.info["end_line"]) for r in rollout]
([(r.reward, r.metrics["phase_1_x_dist"], r.metrics["con_1_crossed_line_y"]) for r in rollout[:80]])#, [r.done for r in rollout[:80]]
# rollout[0].metrics.keys()

[(Array(0., dtype=float32),
  Array(0., dtype=float32, weak_type=True),
  Array(False, dtype=bool)),
 (Array(-4.745425, dtype=float32),
  Array(0.45678097, dtype=float32),
  Array(False, dtype=bool)),
 (Array(-4.712586, dtype=float32),
  Array(0.40549603, dtype=float32),
  Array(False, dtype=bool)),
 (Array(-4.627638, dtype=float32),
  Array(0.3019078, dtype=float32),
  Array(False, dtype=bool)),
 (Array(-4.4458766, dtype=float32),
  Array(0.16117108, dtype=float32),
  Array(False, dtype=bool)),
 (Array(-4.1971245, dtype=float32),
  Array(0.09988582, dtype=float32),
  Array(False, dtype=bool)),
 (Array(-4.0805554, dtype=float32),
  Array(0.13749361, dtype=float32),
  Array(False, dtype=bool)),
 (Array(-4.060509, dtype=float32),
  Array(0.132577, dtype=float32),
  Array(False, dtype=bool)),
 (Array(-4.011795, dtype=float32),
  Array(0.07837212, dtype=float32),
  Array(False, dtype=bool)),
 (Array(-3.9762263, dtype=float32),
  Array(0.02909961, dtype=float32),
  Array(False, dtype=bool))

In [23]:
[(r.metrics['distance_phase_0'], r.metrics['distance_phase_1'], r.metrics['dist']) for r in rollout]
[(r.obs['proprioception'][15:18], r.obs['proprioception'][50:53], r.metrics['con_1_touching_screen']) for r in rollout]


[(Array([ 0.03493699, -0.198526  ,  0.28504306], dtype=float32),
  Array([ 0.5       , -0.48367837,  0.95312196], dtype=float32),
  Array(False, dtype=bool)),
 (Array([ 0.04321903, -0.20694444,  0.2865013 ], dtype=float32),
  Array([ 0.5       , -0.48367837,  0.95312196], dtype=float32),
  Array(False, dtype=bool)),
 (Array([ 0.09450396, -0.22791187,  0.298051  ], dtype=float32),
  Array([ 0.5       , -0.48367837,  0.95312196], dtype=float32),
  Array(False, dtype=bool)),
 (Array([ 0.19809219, -0.24103165,  0.34398323], dtype=float32),
  Array([ 0.5       , -0.48367837,  0.95312196], dtype=float32),
  Array(False, dtype=bool)),
 (Array([ 0.33882892, -0.21474227,  0.48161364], dtype=float32),
  Array([ 0.5       , -0.48367837,  0.95312196], dtype=float32),
  Array(False, dtype=bool)),
 (Array([ 0.40011418, -0.13481002,  0.71028936], dtype=float32),
  Array([ 0.5       , -0.48367837,  0.95312196], dtype=float32),
  Array(False, dtype=bool)),
 (Array([ 0.3625064 , -0.07929441,  0.8808954 

In [49]:
touch_detected = any([{eval_env.mj_model.geom(con[0]).name, eval_env.mj_model.geom(con[1]).name} == {"fingertip_contact", "screen"} for con in state.data._impl.contact.geom])
touches = [{eval_env.mj_model.geom(con[0]).name, eval_env.mj_model.geom(con[1]).name} for con in state.data._impl.contact.geom]

touch_detected, touches

(True, [{'fingertip_contact', 'screen'}])

In [51]:
state.data._impl.contact.dist

Array([1.], dtype=float32)

In [41]:
eval_env.mj_model.geom

<bound method PyCapsule.geom of <mujoco._structs.MjModel object at 0x7749d83c61b0>>

In [26]:
# eval_key, reset_keys = jax.random.split(eval_key)
# state = jit_reset(reset_keys)
# eval_key, key = jax.random.split(eval_key)
# ctrl, _ = jit_inference_fn(state.obs, key)  #VARIANT 1
state = jit_step(state, ctrl)
state.metrics["dist_combined"]
(np.exp(-state.metrics["dist_combined"]*10.) - 1.)/10., (np.exp(-1.5*10.) - 1.)/10.

(np.float32(-0.09992059), np.float64(-0.09999996940976795))

In [33]:
np.exp(-0.7*2), (np.exp(-1.5*2))

(np.float64(0.2465969639416065), np.float64(0.049787068367863944))

In [17]:
np.where([r.info["phase_0_done"] for r in rollout])[0][0]*env.unwrapped._ctrl_dt
[(r.info["phase_0_done"], r.reward) for r in rollout]

[(Array(False, dtype=bool), Array(0., dtype=float32)),
 (Array(False, dtype=bool), Array(-0.09999998, dtype=float32)),
 (Array(False, dtype=bool), Array(-0.09999997, dtype=float32)),
 (Array(False, dtype=bool), Array(-0.09999994, dtype=float32)),
 (Array(False, dtype=bool), Array(-0.09999982, dtype=float32)),
 (Array(False, dtype=bool), Array(-0.09999924, dtype=float32)),
 (Array(False, dtype=bool), Array(-0.09999751, dtype=float32)),
 (Array(False, dtype=bool), Array(-0.09999336, dtype=float32)),
 (Array(False, dtype=bool), Array(-0.09998853, dtype=float32)),
 (Array(False, dtype=bool), Array(-0.09998377, dtype=float32)),
 (Array(False, dtype=bool), Array(-0.0999779, dtype=float32)),
 (Array(False, dtype=bool), Array(-0.09997098, dtype=float32)),
 (Array(False, dtype=bool), Array(-0.0999607, dtype=float32)),
 (Array(False, dtype=bool), Array(-0.09995576, dtype=float32)),
 (Array(False, dtype=bool), Array(-0.0999537, dtype=float32)),
 (Array(False, dtype=bool), Array(-0.0999264, dtype=

In [22]:
[i.info["target_pos"] for i in rollout], [_i.obs["proprioception"][15:18] for _i in rollout]

([Array([ 0.26876748, -0.24839334,  0.79923654], dtype=float32),
  Array([ 0.26876748, -0.24839334,  0.79923654], dtype=float32),
  Array([ 0.26876748, -0.24839334,  0.79923654], dtype=float32),
  Array([ 0.26876748, -0.24839334,  0.79923654], dtype=float32),
  Array([ 0.26876748, -0.24839334,  0.79923654], dtype=float32),
  Array([ 0.26876748, -0.24839334,  0.79923654], dtype=float32),
  Array([ 0.26876748, -0.24839334,  0.79923654], dtype=float32),
  Array([ 0.26876748, -0.24839334,  0.79923654], dtype=float32),
  Array([ 0.26876748, -0.24839334,  0.79923654], dtype=float32),
  Array([ 0.26876748, -0.24839334,  0.79923654], dtype=float32),
  Array([ 0.26876748, -0.24839334,  0.79923654], dtype=float32),
  Array([ 0.26876748, -0.24839334,  0.79923654], dtype=float32),
  Array([ 0.26876748, -0.24839334,  0.79923654], dtype=float32)],
 [Array([ 0.03493699, -0.198526  ,  0.28504306], dtype=float32),
  Array([ 0.05467008, -0.20556903,  0.2879694 ], dtype=float32),
  Array([ 0.12839839, -0

In [11]:
np.sum(rewards)

np.float32(7.423525)