In [None]:
#@title Install pre-requisites
!pip install mujoco
!pip install mujoco_mjx
!pip install brax

In [1]:
# @title Check if MuJoCo installation was successful

import distutils.util
import os
import subprocess

if subprocess.run('nvidia-smi').returncode:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU Colab runtime. '
      'Go to the Runtime menu and select Choose runtime type.'
  )

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

# Configure MuJoCo to use the EGL rendering backend (requires GPU)
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl

try:
  print('Checking that the installation succeeded:')
  import mujoco

  mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
  raise e from RuntimeError(
      'Something went wrong during installation. Check the shell output above '
      'for more information.\n'
      'If using a hosted Colab runtime, make sure you enable GPU acceleration '
      'by going to the Runtime menu and selecting "Choose runtime type".'
  )

print('Installation successful.')

# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags

Thu Sep 11 04:57:47 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.153.02             Driver Version: 570.153.02     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce GTX 1650        Off |   00000000:01:00.0 Off |                  N/A |
| N/A   55C    P8              4W /   50W |       9MiB /   4096MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
# @title Import packages for plotting and creating graphics
import json
import itertools
import time
from typing import Callable, List, NamedTuple, Optional, Union
import numpy as np

# Graphics and plotting.
print("Installing mediapy:")
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy
import mediapy as media
import matplotlib.pyplot as plt

# More legible printing from numpy.
np.set_printoptions(precision=3, suppress=True, linewidth=100)

Installing mediapy:


In [3]:
# @title Import MuJoCo, MJX, and Brax
from datetime import datetime
import functools
import os
from typing import Any, Dict, Sequence, Tuple, Union
from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.base import State as PipelineState
from brax.envs.base import Env, PipelineEnv, State
from brax.io import html, mjcf, model
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo
from brax.training.agents.sac import networks as sac_networks
from brax.training.agents.sac import train as sac
from etils import epath
from flax import struct
from flax.training import orbax_utils
from IPython.display import HTML, clear_output
import jax
from jax import numpy as jp
from matplotlib import pyplot as plt
import mediapy as media
from ml_collections import config_dict
import mujoco
from mujoco import mjx
import numpy as np
from orbax import checkpoint as ocp

In [7]:
%ls

In [8]:
!git clone https://github.com/Maker-Rat/mujoco_playground_unitree.git 

Cloning into 'mujoco_playground_unitree'...
remote: Enumerating objects: 1122, done.[K
remote: Counting objects: 100% (437/437), done.[K
remote: Compressing objects: 100% (140/140), done.[K
remote: Total 1122 (delta 358), reused 313 (delta 297), pack-reused 685 (from 1)[K
Receiving objects: 100% (1122/1122), 23.57 MiB | 19.94 MiB/s, done.
Resolving deltas: 100% (726/726), done.


In [None]:
%cd mujoco_playground_unitree
!pip install -e ".[all]"

In [4]:
#@title Import The Playground
from mujoco_playground import wrapper
from mujoco_playground import registry

In [48]:
env_name = 'SpiderbotStairsClimbing'
env = registry.load(env_name)
env_cfg = registry.get_default_config(env_name)

from mujoco_playground.config import locomotion_params
ppo_params = locomotion_params.brax_ppo_config(env_name)
ppo_params.entropy_cost = 0.025
ppo_params.learning_rate = 0.00015
ppo_params.num_timesteps = int(1)
ppo_params.batch_size = 1
ppo_params.episode_length = 10
ppo_params.num_minibatches = 2
ppo_params.unroll_length = 2
ppo_params.num_envs = 1


ppo_params

action_repeat: 1
batch_size: 1
discounting: 0.97
entropy_cost: 0.025
episode_length: 10
learning_rate: 0.00015
max_grad_norm: 1.0
network_factory:
  policy_hidden_layer_sizes: &id001 !!python/tuple
  - 512
  - 256
  - 128
  policy_obs_key: state
  value_hidden_layer_sizes: *id001
  value_obs_key: privileged_state
normalize_observations: true
num_envs: 1
num_evals: 10
num_minibatches: 2
num_resets_per_eval: 10
num_timesteps: 1
num_updates_per_batch: 4
reward_scaling: 1.0
unroll_length: 2

In [49]:
x_data, y_data, y_dataerr = [], [], []
times = [datetime.now()]


def progress(num_steps, metrics):
  clear_output(wait=True)

  times.append(datetime.now())
  x_data.append(num_steps)
  y_data.append(metrics["eval/episode_reward"])
  y_dataerr.append(metrics["eval/episode_reward_std"])

  plt.xlim([0, ppo_params["num_timesteps"] * 1.25])
  plt.xlabel("# environment steps")
  plt.ylabel("reward per episode")
  plt.title(f"y={y_data[-1]:.3f}")
  plt.errorbar(x_data, y_data, yerr=y_dataerr, color="blue")

  display(plt.gcf())

randomizer = registry.get_domain_randomizer(env_name)
ppo_training_params = dict(ppo_params)
network_factory = ppo_networks.make_ppo_networks
if "network_factory" in ppo_params:
  del ppo_training_params["network_factory"]
  network_factory = functools.partial(
      ppo_networks.make_ppo_networks,
      **ppo_params.network_factory
  )

train_fn = functools.partial(
    ppo.train, **dict(ppo_training_params),
    network_factory=network_factory,
    randomization_fn=randomizer,
    progress_fn=progress
)

Env 'SpiderbotStairsClimbing' does not have a domain randomizer in the locomotion registry.


In [51]:
make_inference_fn, params, metrics = train_fn(
    environment=env,
    eval_env=registry.load(env_name, config=env_cfg),
    wrap_env_fn=wrapper.wrap_for_brax_training,
)
print(f"time to jit: {times[1] - times[0]}")
print(f"time to train: {times[-1] - times[1]}")

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x73e1e9749890>>
Traceback (most recent call last):
  File "/home/ritwik/anaconda3/envs/playground/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 781, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 


KeyboardInterrupt: 

In [53]:
# Enable perturbation in the eval env.
env_cfg = registry.get_default_config(env_name)
env_cfg.pert_config.enable = True
env_cfg.pert_config.velocity_kick = [3.0, 6.0]
env_cfg.pert_config.kick_wait_times = [5.0, 15.0]
env_cfg.command_config.a = [1.5, 0.8, 2*jp.pi]
eval_env = registry.load(env_name, config=env_cfg)
velocity_kick_range = [0.0, 0.0]  # Disable velocity kick.
kick_duration_range = [0.05, 0.2]

jit_reset = jax.jit(eval_env.reset)
jit_step = jax.jit(eval_env.step)
jit_inference_fn = jax.jit(make_inference_fn(params, deterministic=True))

In [54]:
#@title Rollout and Render
from mujoco_playground._src.gait import draw_joystick_command

x_vel = -0.2  #@param {type: "number"}
y_vel = -0.2  #@param {type: "number"}
yaw_vel = 0.0  #@param {type: "number"}


def sample_pert(rng):
  rng, key1, key2 = jax.random.split(rng, 3)
  pert_mag = jax.random.uniform(
      key1, minval=velocity_kick_range[0], maxval=velocity_kick_range[1]
  )
  duration_seconds = jax.random.uniform(
      key2, minval=kick_duration_range[0], maxval=kick_duration_range[1]
  )
  duration_steps = jp.round(duration_seconds / eval_env.dt).astype(jp.int32)
  state.info["pert_mag"] = pert_mag
  state.info["pert_duration"] = duration_steps
  state.info["pert_duration_seconds"] = duration_seconds
  return rng


rng = jax.random.PRNGKey(0)
rollout = []
modify_scene_fns = []

swing_peak = []
rewards = []
linvel = []
angvel = []
track = []
foot_vel = []
rews = []
contact = []
command = jp.array([x_vel, y_vel, yaw_vel])

state = jit_reset(rng)
if state.info["steps_since_last_pert"] < state.info["steps_until_next_pert"]:
  rng = sample_pert(rng)
state.info["command"] = command
for i in range(env_cfg.episode_length//10):
  if state.info["steps_since_last_pert"] < state.info["steps_until_next_pert"]:
    rng = sample_pert(rng)
  act_rng, rng = jax.random.split(rng)
#   ctrl, _ = jit_inference_fn(state.obs, act_rng)
  ctrl = jp.zeros(env.action_size)  
    
  state = jit_step(state, ctrl)
  state.info["command"] = command
  rews.append(
      {k: v for k, v in state.metrics.items() if k.startswith("reward/")}
  )
  rollout.append(state)
  swing_peak.append(state.info["swing_peak"])
  rewards.append(
      {k[7:]: v for k, v in state.metrics.items() if k.startswith("reward/")}
  )
  linvel.append(env.get_global_linvel(state.data))
  angvel.append(env.get_gyro(state.data))
  track.append(
      env._reward_tracking_lin_vel(
          state.info["command"], env.get_local_linvel(state.data)
      )
  )

  feet_vel = state.data.sensordata[env._foot_linvel_sensor_adr]
  vel_xy = feet_vel[..., :2]
  vel_norm = jp.sqrt(jp.linalg.norm(vel_xy, axis=-1))
  foot_vel.append(vel_norm)

  contact.append(state.info["last_contact"])

  xyz = np.array(state.data.xpos[env._torso_body_id])
  xyz += np.array([0, 0, 0.2])
  x_axis = state.data.xmat[env._torso_body_id, 0]
  yaw = -np.arctan2(x_axis[1], x_axis[0])
  modify_scene_fns.append(
      functools.partial(
          draw_joystick_command,
          cmd=state.info["command"],
          xyz=xyz,
          theta=yaw,
          scl=abs(state.info["command"][0])
          / env_cfg.command_config.a[0],
      )
  )


render_every = 2
fps = 1.0 / eval_env.dt / render_every
traj = rollout[::render_every]
mod_fns = modify_scene_fns[::render_every]

scene_option = mujoco.MjvOption()
scene_option.geomgroup[2] = True
scene_option.geomgroup[3] = False
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True
scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = False
scene_option.flags[mujoco.mjtVisFlag.mjVIS_PERTFORCE] = True

# Create camera object to specify our custom camera settings
camera = mujoco.MjvCamera()
camera.type = mujoco.mjtCamera.mjCAMERA_FREE
camera.lookat = [0.25, -0.25, 0.25]  # Point camera at robot's height
camera.distance = 3.0  # Set camera distance
camera.azimuth = 45  # Set camera angle around z-axis
camera.elevation = -20  # Set camera angle above ground

frames = eval_env.render(
    traj,
    scene_option=scene_option,
    camera=camera,  # Add custom camera settings
    width=640,
    height=480,
    modify_scene_fns=mod_fns,
)
media.show_video(frames, fps=fps, loop=False)

100%|██████████| 50/50 [00:00<00:00, 178.50it/s]


0
This browser does not support the video tag.


In [28]:
import numpy as np
from brax.io import model

# Load parameters
path = "hexapod_PPO.npz"
params = model.load_params(path)

# Extract only the policy parameters (assumed to be in Element 1)
policy_params = params[1]  # Second element of the tuple

# Function to recursively convert JAX tensors to NumPy
def jax_to_numpy(params):
    if isinstance(params, dict):
        return {k: jax_to_numpy(v) for k, v in params.items()}
    elif isinstance(params, (list, tuple)):
        return [jax_to_numpy(v) for v in params]
    else:
        return np.array(params)  # Convert JAX tensor to NumPy

numpy_params = jax_to_numpy(policy_params)

# Save the converted NumPy parameters
np.savez("hexapod_PPO_numpy.npz", **numpy_params)
print("Converted policy parameters saved as hexapod_PPO_numpy.npz")

Converted policy parameters saved as hexapod_PPO_numpy.npz


In [None]:
#@title Rollout and Render
from mujoco_playground._src.gait import draw_joystick_command
    
x_vel = -0.2  #@param {type: "number"}
y_vel = -0.2  #@param {type: "number"}
yaw_vel = 0.0  #@param {type: "number"}

def sample_pert(rng):
  rng, key1, key2 = jax.random.split(rng, 3)
  pert_mag = jax.random.uniform(
      key1, minval=velocity_kick_range[0], maxval=velocity_kick_range[1]
  )
  duration_seconds = jax.random.uniform(
      key2, minval=kick_duration_range[0], maxval=kick_duration_range[1]
  )
  duration_steps = jp.round(duration_seconds / eval_env.dt).astype(jp.int32)
  state.info["pert_mag"] = pert_mag
  state.info["pert_duration"] = duration_steps
  state.info["pert_duration_seconds"] = duration_seconds
  return rng


rng = jax.random.PRNGKey(0)
rollout = []
modify_scene_fns = []

swing_peak = []
rewards = []
linvel = []
angvel = []
track = []
foot_vel = []
rews = []
contact = []
command = jp.array([x_vel, y_vel, yaw_vel])

state = jit_reset(rng)
if state.info["steps_since_last_pert"] < state.info["steps_until_next_pert"]:
  rng = sample_pert(rng)
state.info["command"] = command
for i in range(env_cfg.episode_length):
  if state.info["steps_since_last_pert"] < state.info["steps_until_next_pert"]:
    rng = sample_pert(rng)

  
  act_rng, rng = jax.random.split(rng)
  ctrl, _ = jit_inference_fn(state.obs, act_rng)
  if i % 50 == 0:  # Print every 10 steps
  #     # print(f"Step {i}:")
      print(f"State: {state.obs['state']}")
  #     # print(np.array(state.obs['state']))
      print(f"Action: {ctrl}")
      # print(f"Action Numpy: {action_np}")
    
  # ctrl = jp.zeros((12,))
  state = jit_step(state, ctrl)
  state.info["command"] = command
  rews.append(
      {k: v for k, v in state.metrics.items() if k.startswith("reward/")}
  )
  rollout.append(state)
  swing_peak.append(state.info["swing_peak"])
  rewards.append(
      {k[7:]: v for k, v in state.metrics.items() if k.startswith("reward/")}
  )
  linvel.append(env.get_global_linvel(state.data))
  angvel.append(env.get_gyro(state.data))
  track.append(
      env._reward_tracking_lin_vel(
          state.info["command"], env.get_local_linvel(state.data)
      )
  )

  feet_vel = state.data.sensordata[env._foot_linvel_sensor_adr]
  vel_xy = feet_vel[..., :2]
  vel_norm = jp.sqrt(jp.linalg.norm(vel_xy, axis=-1))
  foot_vel.append(vel_norm)

  contact.append(state.info["last_contact"])

  xyz = np.array(state.data.xpos[env._torso_body_id])
  xyz += np.array([0, 0, 0.2])
  x_axis = state.data.xmat[env._torso_body_id, 0]
  yaw = -np.arctan2(x_axis[1], x_axis[0])
  modify_scene_fns.append(
      functools.partial(
          draw_joystick_command,
          cmd=state.info["command"],
          xyz=xyz,
          theta=yaw,
          scl=abs(state.info["command"][0])
          / env_cfg.command_config.a[0],
      )
  )


render_every = 2
fps = 1.0 / eval_env.dt / render_every
traj = rollout[::render_every]
mod_fns = modify_scene_fns[::render_every]

scene_option = mujoco.MjvOption()
scene_option.geomgroup[2] = True
scene_option.geomgroup[3] = False
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True
scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = False
scene_option.flags[mujoco.mjtVisFlag.mjVIS_PERTFORCE] = True

# Create camera object to specify our custom camera settings
camera = mujoco.MjvCamera()
camera.type = mujoco.mjtCamera.mjCAMERA_FREE
camera.lookat = [0, 0, 0.5]  # Point camera at robot's height
camera.distance = 2.0  # Set camera distance
camera.azimuth = 45  # Set camera angle around z-axis
camera.elevation = -20  # Set camera angle above ground

frames = eval_env.render(
    traj,
    scene_option=scene_option,
    camera=camera,  # Add custom camera settings
    width=640,
    height=480,
    modify_scene_fns=mod_fns,
)
media.show_video(frames, fps=fps, loop=False)

  and should_run_async(code)


Parameter keys: No keys (not a dict)
[789897092 923717985]
Input shape: (36,)
Final network output shape: (36,)
Mean shape: (18,), first values: [-0.084  0.34  -0.266  0.112 -0.145]
Final action shape: (18,)
State: [-0.084  0.34  -0.266  0.112 -0.145 -0.194  0.15   0.04  -0.985  0.011  0.009  0.015 -0.022 -0.028
  0.012  0.016 -0.014 -0.013  0.011  0.012 -0.029  0.     0.     0.     0.     0.     0.     0.
  0.     0.     0.     0.     0.    -0.495 -0.346  0.062]
Action: [ 0.086 -0.424 -0.11  -0.407 -0.146 -0.682 -0.485 -0.206 -0.122 -0.054 -0.609 -0.482]
[3111037091  652627534]
Input shape: (36,)
Final network output shape: (36,)
Mean shape: (18,), first values: [-0.285  0.22  -0.041  0.077 -2.569]
Final action shape: (18,)
[3389846981 1335233766]
Input shape: (36,)
Final network output shape: (36,)
Mean shape: (18,), first values: [-0.226  0.202 -0.204  0.004 -2.657]
Final action shape: (18,)
[2277789199 3854972894]
Input shape: (36,)
Final network output shape: (36,)
Mean shape: (18

KeyboardInterrupt: 

In [60]:
params

NpzFile 'hexapod_numpy_model.npz' with keys: params

In [59]:
from brax import envs
from brax.io import model
import jax
import jax.numpy as jnp
import flax

path = "hexapod_PPO.npz"
model.save_params(path, params)

print(f"Model saved to {path}")

  and should_run_async(code)


TypeError: cannot pickle '_io.BufferedReader' object

In [21]:
new_env_name = "SpiderbotJoystickFlatTerrain"
new_env = registry.load(new_env_name)


from mujoco_playground.config import locomotion_params
ppo_params = locomotion_params.brax_ppo_config(new_env_name)

ppo_training_params = dict(ppo_params)
network_factory = ppo_networks.make_ppo_networks
if "network_factory" in ppo_params:
  del ppo_training_params["network_factory"]
  network_factory = functools.partial(
      ppo_networks.make_ppo_networks,
      **ppo_params.network_factory
  )

ppo_mod = network_factory()

inference_fn_new = ppo_networks.make_inference_fn(ppo_mod)
params = model.load_params("hexapod_PPO.npz")

jit_inference_fn_new = jax.jit(inference_fn_new(params, deterministic=True))

TypeError: make_ppo_networks() missing 2 required positional arguments: 'observation_size' and 'action_size'

In [None]:
#@title Plot each foot in a 2x2 grid.

swing_peak = jp.array(swing_peak)
names = ["foot_1", "foot_2", "foot_3", "foot_4"]
colors = ["r", "g", "b", "y"]
fig, axs = plt.subplots(2, 2)
for i, ax in enumerate(axs.flat):
  ax.plot(swing_peak[:, i], color=colors[i])
  ax.set_ylim([0, env_cfg.reward_config.max_foot_height * 1.25])
  ax.axhline(env_cfg.reward_config.max_foot_height, color="k", linestyle="--")
  ax.set_title(names[i])
  ax.set_xlabel("time")
  ax.set_ylabel("height")
plt.tight_layout()
plt.show()

linvel_x = jp.array(linvel)[:, 0]
linvel_y = jp.array(linvel)[:, 1]
angvel_yaw = jp.array(angvel)[:, 2]

# Plot whether velocity is within the command range.
linvel_x = jp.convolve(linvel_x, jp.ones(10) / 10, mode="same")
linvel_y = jp.convolve(linvel_y, jp.ones(10) / 10, mode="same")
angvel_yaw = jp.convolve(angvel_yaw, jp.ones(10) / 10, mode="same")

fig, axes = plt.subplots(3, 1, figsize=(10, 10))
axes[0].plot(linvel_x)
axes[1].plot(linvel_y)
axes[2].plot(angvel_yaw)

axes[0].set_ylim(
    -env_cfg.command_config.a[0], env_cfg.command_config.a[0]
)
axes[1].set_ylim(
    -env_cfg.command_config.a[1], env_cfg.command_config.a[1]
)
axes[2].set_ylim(
    -env_cfg.command_config.a[2], env_cfg.command_config.a[2]
)

for i, ax in enumerate(axes):
  ax.axhline(state.info["command"][i], color="red", linestyle="--")

labels = ["dx", "dy", "dyaw"]
for i, ax in enumerate(axes):
  ax.set_ylabel(labels[i])

In [None]:
#@title Slowly increase linvel commands

rng = jax.random.PRNGKey(0)
rollout = []
modify_scene_fns = []
swing_peak = []
linvel = []
angvel = []

x = -0.25
command = jp.array([x, 0, 0])
state = jit_reset(rng)
for i in range(1_400):
  # Increase the forward velocity by 0.25 m/s every 200 steps.
  if i % 20 == 0:
    x += 0.5
    print(f"Setting x to {x}")
    command = jp.array([x, 0, 0])
  state.info["command"] = command
  if state.info["steps_since_last_pert"] < state.info["steps_until_next_pert"]:
    rng = sample_pert(rng)
  act_rng, rng = jax.random.split(rng)
  ctrl, _ = jit_inference_fn(state.obs, act_rng)
  state = jit_step(state, ctrl)
  rollout.append(state)
  swing_peak.append(state.info["swing_peak"])
  linvel.append(env.get_global_linvel(state.data))
  angvel.append(env.get_gyro(state.data))
  xyz = np.array(state.data.xpos[env._torso_body_id])
  xyz += np.array([0, 0, 0.2])
  x_axis = state.data.xmat[env._torso_body_id, 0]
  yaw = -np.arctan2(x_axis[1], x_axis[0])
  modify_scene_fns.append(
      functools.partial(
          draw_joystick_command,
          cmd=command,
          xyz=xyz,
          theta=yaw,
          scl=abs(command[0]) / env_cfg.command_config.a[0],
      )
  )


# Plot each foot in a 2x2 grid.
swing_peak = jp.array(swing_peak)
names = ["FR", "FL", "RR", "RL"]
colors = ["r", "g", "b", "y"]
fig, axs = plt.subplots(2, 2)
for i, ax in enumerate(axs.flat):
  ax.plot(swing_peak[:, i], color=colors[i])
  ax.set_ylim([0, env_cfg.reward_config.max_foot_height * 1.25])
  ax.axhline(env_cfg.reward_config.max_foot_height, color="k", linestyle="--")
  ax.set_title(names[i])
  ax.set_xlabel("time")
  ax.set_ylabel("height")
plt.tight_layout()
plt.show()

linvel_x = jp.array(linvel)[:, 0]
linvel_y = jp.array(linvel)[:, 1]
angvel_yaw = jp.array(angvel)[:, 2]

# Plot whether velocity is within the command range.
linvel_x = jp.convolve(linvel_x, jp.ones(10) / 10, mode="same")
linvel_y = jp.convolve(linvel_y, jp.ones(10) / 10, mode="same")
angvel_yaw = jp.convolve(angvel_yaw, jp.ones(10) / 10, mode="same")

fig, axes = plt.subplots(3, 1, figsize=(10, 10))
axes[0].plot(linvel_x)
axes[1].plot(linvel_y)
axes[2].plot(angvel_yaw)

axes[0].set_ylim(
    -env_cfg.command_config.a[0], env_cfg.command_config.a[0]
)
axes[1].set_ylim(
    -env_cfg.command_config.a[1], env_cfg.command_config.a[1]
)
axes[2].set_ylim(
    -env_cfg.command_config.a[2], env_cfg.command_config.a[2]
)

for i, ax in enumerate(axes):
  ax.axhline(state.info["command"][i], color="red", linestyle="--")

labels = ["dx", "dy", "dyaw"]
for i, ax in enumerate(axes):
  ax.set_ylabel(labels[i])


render_every = 2
fps = 1.0 / eval_env.dt / render_every
print(f"fps: {fps}")

traj = rollout[::render_every]
mod_fns = modify_scene_fns[::render_every]
assert len(traj) == len(mod_fns)

scene_option = mujoco.MjvOption()
scene_option.geomgroup[2] = True
scene_option.geomgroup[3] = False
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True
scene_option.flags[mujoco.mjtVisFlag.mjVIS_PERTFORCE] = True

frames = eval_env.render(
    traj,
    height=480,
    width=640,
    modify_scene_fns=mod_fns,
    scene_option=scene_option,
)
media.show_video(frames, fps=fps, loop=False)

In [48]:
import jax
import numpy as np
from brax.io import model

def convert_jax_to_numpy_model(jax_model_path, numpy_model_path):
    """
    Converts a JAX model to a NumPy-compatible format.
    
    Args:
        jax_model_path: Path to the JAX model file
        numpy_model_path: Path to save the NumPy model file
    """
    print(f"Loading JAX model from {jax_model_path}")
    
    # Load the JAX model parameters
    params = model.load_params(jax_model_path)
    
    # Extract the policy parameters
    if 'inference_fn' in params:
        policy_params = params['inference_fn']
    else:
        policy_params = params
    
    # Convert JAX parameters to NumPy
    numpy_params = jax.tree_util.tree_map(lambda x: np.array(x), policy_params)
    
    # Save as a NumPy file
    np.savez(numpy_model_path, params=numpy_params)
    
    print(f"NumPy model saved to {numpy_model_path}")
    return numpy_params

if __name__ == "__main__":
    # Convert the model
    convert_jax_to_numpy_model("hexapod_PPO.npz", "hexapod_numpy_model.npz")
    
    # Print model structure for debugging
    loaded_data = np.load("hexapod_numpy_model.npz", allow_pickle=True)
    # print(loaded_data["params"])
    params = loaded_data
    print(params)
    # 
    print("\nModel structure:")
    for key in params:
        if isinstance(params[key], dict):
            print(f"{key}:")
            for subkey, value in params[key].items():
                print(f"  {subkey}: shape={value.shape}, dtype={value.dtype}")
        else:
            print(f"{key}: shape={params[key].shape}, dtype={params[key].dtype}")
    
    print("\nConversion complete! You can now use the NumPy model without JAX.")

Loading JAX model from hexapod_PPO.npz
NumPy model saved to hexapod_numpy_model.npz
NpzFile 'hexapod_numpy_model.npz' with keys: params

Model structure:
params: shape=(3,), dtype=object

Conversion complete! You can now use the NumPy model without JAX.


  and should_run_async(code)
