# Walking in JaxSim

This notebook demonstrates how to use comodo with the JaxSim simulator to control the walking motion of a humanoid robot.

Here's a list of acronyms and symbols used in the notebook:
- `lf`, `rf`: left foot, right foot
- `js`: JaxSim
- `tsid`: Task Space Inverse Dynamics
- `mpc`: Model Predictive Control
- `sfp`: Swing Foot Planner
- `mj`: MuJoCo
- `s`: joint positions
- `ds`: joint velocities
- `τ` (`tau` in code): joint torques
- `b`: base
- `com`: center of mass
- `dcom`: center of mass velocity

In [None]:
# ==== Imports ====
from __future__ import annotations
import xml.etree.ElementTree as ET
import jax.numpy as jnp
import numpy as np
import tempfile
import urllib.request
import time
import os
import matplotlib.pyplot as plt
import datetime
import pathlib
from pathlib import Path
import traceback

# Here we set some environment variables
# Flag to solve MUMPS hanging
os.environ["OMP_NUM_THREADS"] = "1"

from comodo.jaxsimSimulator import JaxsimSimulator, JaxsimContactModelEnum
from comodo.robotModel.robotModel import RobotModel
from comodo.robotModel.createUrdf import createUrdf
from comodo.centroidalMPC.centroidalMPC import CentroidalMPC
from comodo.centroidalMPC.mpcParameterTuning import MPCParameterTuning
from comodo.TSIDController.TSIDParameterTuning import TSIDParameterTuning
from comodo.TSIDController.TSIDController import TSIDController

In [2]:
# ==== Load the stickbot model ====

# Getting stickbot urdf file and convert it to string.
# The NamedTemporaryFile object is kept referenced at module level: with the
# default delete=True the file is removed once the object is closed, and
# createUrdf below still needs to read it through its path.
urdf_robot_file = tempfile.NamedTemporaryFile(mode="w+")
url = "https://raw.githubusercontent.com/icub-tech-iit/ergocub-gazebo-simulations/master/models/stickBot/model.urdf"
urllib.request.urlretrieve(url, urdf_robot_file.name)
# Load the URDF file
tree = ET.parse(urdf_robot_file.name)
root = tree.getroot()

# Convert the XML tree to a string
# NOTE(review): robot_urdf_string_original is not used in the visible notebook.
robot_urdf_string_original = ET.tostring(root)

create_urdf_instance = createUrdf(
    original_urdf_path=urdf_robot_file.name, save_gazebo_plugin=False
)

# Controlled joints. The trailing index comments give each joint's position in
# this list, which is also the column index used for logged joint quantities
# (e.g. `s_js[:, idx]` in the plots).
joint_names = [
    "r_shoulder_pitch",  # 0
    "r_shoulder_roll",  # 1
    "r_shoulder_yaw",  # 2
    "r_elbow",  # 3
    "l_shoulder_pitch",  # 4
    "l_shoulder_roll",  # 5
    "l_shoulder_yaw",  # 6
    "l_elbow",  # 7
    "r_hip_pitch",  # 8
    "r_hip_roll",  # 9
    "r_hip_yaw",  # 10
    "r_knee",  # 11
    "r_ankle_pitch",  # 12
    "r_ankle_roll",  # 13
    "l_hip_pitch",  # 14
    "l_hip_roll",  # 15
    "l_hip_yaw",  # 16
    "l_knee",  # 17
    "l_ankle_pitch",  # 18
    "l_ankle_roll",  # 19
]

urdf_robot_string = create_urdf_instance.write_urdf_to_file()
robot_model_init = RobotModel(urdf_robot_string, "stickBot", joint_names)
# Four foot-sole corners spanning a 0.2 x 0.1 rectangle centered at the origin
# (presumably expressed in the foot frame, in meters — confirm).
robot_model_init.set_foot_corner(
    np.asarray([0.1, 0.05, 0.0]),
    np.asarray([0.1, -0.05, 0.0]),
    np.asarray([-0.1, -0.05, 0.0]),
    np.asarray([-0.1, 0.05, 0.0]),
)

In [3]:
# ==== Set simulation parameters ====

# Time [s] at which the closed-loop simulation stops, and simulator time step [s].
T: float = 3.0
js_dt: float = 0.001

In [None]:
# ==== Compute initial configuration ====

# Desired initial walking posture: joint positions, base pose stored as
# [x, y, z, roll, pitch, yaw], and the corresponding base transform.
s_0, xyz_rpy_0, H_b_0 = robot_model_init.compute_desired_position_walking()

print(
    "Initial configuration:\n"
    f"Base position: {xyz_rpy_0[:3]}\n"
    f"Base orientation: {xyz_rpy_0[3:]}\n"
    f"Joint positions: {s_0}"
)

In [None]:
# ==== Define JaxSim simulator and set initial position ====

# Simulator with time step `js_dt` using the relaxed-rigid contact model.
js = JaxsimSimulator(dt=js_dt, contact_model_type=JaxsimContactModelEnum.RELAXED_RIGID)
js.load_model(
    robot_model=robot_model_init,
    s=s_0,
    xyz_rpy=xyz_rpy_0,
    # Possible choices are "record", "interactive" or None (no visualization).
    # "record" is presumably what makes `js._save_video` at the end of the
    # notebook possible — confirm.
    visualization_mode="record",
)

# Initial state read back from the simulator; these values initialize the
# TSID controller in the next cell.
s_js, ds_js, tau_js = js.get_state()
t = 0.0
H_b = js.base_transform
w_b = js.base_velocity

# Diagnostics. NOTE: `_model` is a private attribute of the simulator wrapper.
print(f"Contact model in use: {js._model.contact_model}")
print(f"Link names:\n{js.link_names}")
print(f"Frame names:\n{js.frame_names}")

In [None]:
# ==== Define the controller parameters and instantiate the controller ====

# Controller Parameters
# TSID tuning: only the attributes assigned below differ from the values set by
# the TSIDParameterTuning constructor.
tsid_parameter = TSIDParameterTuning()
tsid_parameter.foot_tracking_task_kp_lin = 150.0
tsid_parameter.foot_tracking_task_kd_lin = 40.0
# NOTE(review): presumably one weight per Cartesian axis — confirm.
tsid_parameter.root_tracking_task_weight = np.ones(3) * 50.0

mpc_parameters = MPCParameterTuning()

# TSID Instance
# Despite its name, `frequency` is used as a control period in seconds in this
# notebook: `simulate` computes `tsid.frequency / js_dt` simulator steps per
# TSID iteration.
tsid = TSIDController(frequency=0.01, robot_model=robot_model_init)
tsid.define_tasks(tsid_parameter)
tsid.set_state_with_base(s_js, ds_js, H_b, w_b, t)

# MPC Instance
# NOTE(review): step length presumably in meters — confirm.
step_length = 0.1
mpc = CentroidalMPC(robot_model=robot_model_init, step_length=step_length)
mpc.intialize_mpc(mpc_parameters=mpc_parameters, scale=0.5)

# Set desired quantities
mpc.configure(s_init=s_0, H_b_init=H_b_0)
# The CoM computed by TSID from the current state is handed to the MPC as the
# input of its test CoM trajectory.
tsid.compute_com_position()
mpc.define_test_com_traj(tsid.COM.toNumPy())

In [7]:
# Run a single dry-run simulator step and time it. The elapsed time is reported
# later as "Step compilation time"; presumably it is dominated by JIT
# compilation of the simulator step — confirm.

tic = time.perf_counter()

js.step(dry_run=True)

step_compilation_time_s = time.perf_counter() - tic

In [8]:
# Reading the state after the dry-run step
s_js, ds_js, tau_js = js.get_state()
H_b = js.base_transform
w_b = js.base_velocity

# MPC: set its state, initialize its centroidal integrator from the same state,
# and plan a first trajectory before entering the control loop.
# NOTE(review): `mpc_output` is not used in the visible notebook.
mpc.set_state_with_base(s=s_js, s_dot=ds_js, H_b=H_b, w_b=w_b, t=t)
mpc.initialize_centroidal_integrator(s=s_js, s_dot=ds_js, H_b=H_b, w_b=w_b, t=t)
mpc_output = mpc.plan_trajectory()

In [9]:
# ==== Define the simulation loop ====


def simulate(
    T: float,
    js: JaxsimSimulator,
    tsid: TSIDController,
    mpc: CentroidalMPC,
    s_ref: list[float],
) -> dict[str, np.array]:
    # Logging
    s_js_log = []
    ds_js_log = []
    W_p_CoM_js_log = []
    W_p_lf_js_log = []
    W_p_rf_js_log = []
    W_p_CoM_mpc_log = []
    W_p_lf_sfp_log = []
    W_p_rf_sfp_log = []
    f_lf_mpc_log = []
    f_rf_mpc_log = []
    f_lf_js_log = []
    f_rf_js_log = []
    tau_tsid_log = []
    W_p_CoM_tsid_log = []
    t_log = []
    wall_time_step_log = []

    # Define number of steps
    n_step_tsid_js = int(tsid.frequency / js_dt)
    n_step_mpc_tsid = int(mpc.get_frequency_seconds() / tsid.frequency)
    print(f"{n_step_mpc_tsid=}, {n_step_tsid_js=}")
    counter = 0
    mpc_success = True
    succeded_controller = True
    contact_model_type = js.contact_model_type

    t = 0.0

    while t < T:
        try:
            print(f"==== Time: {t:.4f}s ====", flush=True, end="\r")

            # Reading robot state from simulator
            s_js, ds_js, tau_js = js.get_state()
            H_b = js.base_transform
            w_b = js.base_velocity
            t = js.simulation_time

            # Update TSID
            tsid.set_state_with_base(s=s_js, s_dot=ds_js, H_b=H_b, w_b=w_b, t=t)

            # MPC plan
            if counter == 0:
                mpc.set_state_with_base(s=s_js, s_dot=ds_js, H_b=H_b, w_b=w_b, t=t)
                mpc.update_references()
                mpc_success = mpc.plan_trajectory()
                mpc.contact_planner.advance_swing_foot_planner()
                if not (mpc_success):
                    print("MPC failed")
                    break

            # Reading new references
            com_mpc, dcom_mpc, f_lf_mpc, f_rf_mpc, ang_mom_mpc = mpc.get_references()
            lf_sfp, rf_sfp = mpc.contact_planner.get_references_swing_foot_planner()

            tsid.compute_com_position()

            # Update references TSID
            tsid.update_task_references_mpc(
                com=com_mpc,
                dcom=dcom_mpc,
                ddcom=np.zeros(3),
                left_foot_desired=lf_sfp,
                right_foot_desired=rf_sfp,
                s_desired=np.array(s_ref),
                wrenches_left=np.hstack([f_lf_mpc, np.zeros(3)]),
                wrenches_right=np.hstack([f_rf_mpc, np.zeros(3)]),
            )

            # Run control
            succeded_controller = tsid.run()

            if not (succeded_controller):
                print("Controller failed")
                break

            tau_tsid = tsid.get_torque()

            # Step the simulator
            js.set_input(tau_tsid)

            tic = time.perf_counter()
            js.step(n_step=n_step_tsid_js)
            toc = time.perf_counter() - tic

            counter = counter + 1

            if counter == n_step_mpc_tsid:
                counter = 0

            # Stop the simulation if the robot fell down
            if js._data.base_position[2] < 0.5:
                print(f"Robot fell down at t={t:.4f}s.")
                break

            # Log data
            # TODO transform mpc contact forces to wrenches to be compared with jaxsim ones
            t_log.append(t)
            tau_tsid_log.append(tau_tsid)
            s_js_log.append(s_js)
            ds_js_log.append(ds_js)
            W_p_CoM_js_log.append(js.com_position)
            W_p_lf_js, W_p_rf_js = js.feet_positions
            W_p_lf_js_log.append(W_p_lf_js)
            W_p_rf_js_log.append(W_p_rf_js)
            W_p_CoM_mpc_log.append(com_mpc)
            f_lf_mpc_log.append(f_lf_mpc)
            f_rf_mpc_log.append(f_rf_mpc)
            W_p_lf_sfp_log.append(lf_sfp.transform.translation())
            W_p_rf_sfp_log.append(rf_sfp.transform.translation())
            W_p_CoM_tsid_log.append(tsid.COM.toNumPy())
            wall_time_step_log.append(toc * 1e3)
            if contact_model_type != JaxsimContactModelEnum.VISCO_ELASTIC:
                f_lf_js, f_rf_js = js.feet_wrench
                f_lf_js_log.append(f_lf_js)
                f_rf_js_log.append(f_rf_js)

        except Exception as e:
            print(f"Exception during simulation at time{t}: {e}")
            traceback.print_exc()
            break

    logs = {
        "t": np.array(t_log),
        "s_js": np.array(s_js_log),
        "ds_js": np.array(ds_js_log),
        "tau_tsid": np.array(tau_tsid_log),
        "W_p_CoM_js": np.array(W_p_CoM_js_log),
        "W_p_lf_js": np.array(W_p_lf_js_log),
        "W_p_rf_js": np.array(W_p_rf_js_log),
        "W_p_CoM_mpc": np.array(W_p_CoM_mpc_log),
        "f_lf_mpc": np.array(f_lf_mpc_log),
        "f_rf_mpc": np.array(f_rf_mpc_log),
        "W_p_lf_sfp": np.array(W_p_lf_sfp_log),
        "W_p_rf_sfp": np.array(W_p_rf_sfp_log),
        "W_p_CoM_tsid": np.array(W_p_CoM_tsid_log),
        "wall_time_step": np.array(wall_time_step_log),
    }
    if contact_model_type != JaxsimContactModelEnum.VISCO_ELASTIC:
        logs["f_lf_js"] = np.array(f_lf_js_log)
        logs["f_rf_js"] = np.array(f_rf_js_log)

    return logs

In [None]:
# ==== Run the simulation ====


# End-to-end wall-clock timing of the closed-loop run (controllers + simulator).
now = time.perf_counter()

logs = simulate(T=T, js=js, tsid=tsid, mpc=mpc, s_ref=s_0)

wall_time = time.perf_counter() - now
# Average wall-clock time [ms] per simulator step, assuming T / js_dt steps.
# NOTE(review): `avg_iter_time_ms` is not printed or used in the visible
# notebook, and both it and the message below assume the loop reached T even
# if `simulate` stopped early (e.g. the robot fell or a controller failed).
avg_iter_time_ms = (wall_time / (T / js_dt)) * 1000

print(
    f"\nSimulation done.\nRunning simulation took {wall_time:.2f}s for {T:.3f}s simulated time."
)

In [None]:
# Extract logged variables
t = logs["t"]
s_js = logs["s_js"]
ds_js = logs["ds_js"]
tau_tsid = logs["tau_tsid"]
W_p_CoM_js = logs["W_p_CoM_js"]
W_p_lf_js = logs["W_p_lf_js"]
W_p_rf_js = logs["W_p_rf_js"]
W_p_CoM_mpc = logs["W_p_CoM_mpc"]
f_lf_mpc = logs["f_lf_mpc"]
f_rf_mpc = logs["f_rf_mpc"]
W_p_lf_sfp = logs["W_p_lf_sfp"]
W_p_rf_sfp = logs["W_p_rf_sfp"]
W_p_CoM_tsid = logs["W_p_CoM_tsid"]
wall_time_step = logs["wall_time_step"]
# `simulate` only logs the simulated feet wrenches when the contact model is
# not VISCO_ELASTIC; use None instead of failing with a KeyError otherwise.
f_lf_js = logs.get("f_lf_js")
f_rf_js = logs.get("f_rf_js")

# Number of logged TSID iterations
print(t.shape)

In [None]:
# Compute simulator step runtime statistics and RTF

# `wall_time_step` holds the wall-clock duration [ms] of each logged
# `js.step(...)` call in `simulate`. Each call advances the simulation by one
# TSID control period (`tsid.frequency` seconds, as a whole number of
# simulator steps).
if wall_time_step.size == 0:
    raise RuntimeError("No simulator steps were logged: cannot compute step statistics.")

min_step_time = np.min(wall_time_step)
max_step_time = np.max(wall_time_step)
avg_step_time = np.mean(wall_time_step)
std_step_time = np.std(wall_time_step)
total_step_time = np.sum(wall_time_step)

# Simulated time [ms] actually covered by the logged steps. Using this rather
# than `T` keeps the RTF correct when the simulation stopped early.
simulated_time_ms = wall_time_step.size * tsid.frequency * 1e3
rtf = simulated_time_ms / total_step_time * 100

print("===========================================")
print(f"Step compilation time: {step_compilation_time_s:.2f} s")
print(f"Simulated time: {simulated_time_ms / 1e3:.3f} s")
print(f"Min step time: {min_step_time:.2f} ms")
print(f"Max step time: {max_step_time:.2f} ms")
print(f"Average step time: {avg_step_time:.2f} ms")
print(f"Std deviation step time: {std_step_time:.2f} ms")
print(f"RTF: {rtf:.1f}%")
print("===========================================")

In [None]:
# ==== Plot the results ====


n_sim_steps = s_js.shape[0]
# The joint reference is the constant initial posture s_0, repeated for every
# logged sample so it can be plotted against time.
s_0_plot = np.full_like(a=s_js, fill_value=s_0)


def _joint_axes_grid(figsize):
    """Create a 2-column subplot grid with one axis per entry of `joint_names`."""
    return plt.subplots(
        nrows=int(np.ceil(len(joint_names) / 2)), ncols=2, sharex=True, figsize=figsize
    )


# Joint tracking
fig, axs = _joint_axes_grid(figsize=(12, 16))
for idx, name in enumerate(joint_names):
    # `.flat` walks the grid row by row, i.e. axis (idx // 2, idx % 2).
    ax = axs.flat[idx]
    ax.set_title(name)
    ax.plot(t, np.rad2deg(s_js[:, idx]), label="Simulated")
    ax.plot(t, np.rad2deg(s_0_plot[:, idx]), linestyle="--", label="Reference")
    ax.grid()
    ax.set_ylabel("[deg]")
    ax.legend()
plt.suptitle("Joint tracking")
plt.show()

# Joint tracking error
fig, axs = _joint_axes_grid(figsize=(12, 16))
for idx, name in enumerate(joint_names):
    ax = axs.flat[idx]
    ax.set_title(name)
    ax.plot(t, np.rad2deg(s_js[:, idx] - s_0_plot[:, idx]))
    ax.grid()
    ax.set_ylabel("[deg]")
plt.suptitle("Joint tracking error (simulated - reference)")
plt.tight_layout()
plt.show()

# Feet height
fig, axs = plt.subplots(nrows=1, ncols=2, sharey=True)
feet_heights = (
    ("Left foot sole height", W_p_lf_js, W_p_lf_sfp),
    ("Right foot sole height", W_p_rf_js, W_p_rf_sfp),
)
for ax, (title, W_p_js, W_p_sfp) in zip(axs, feet_heights):
    ax.set_title(title)
    ax.plot(t, W_p_js[:, 2], label="Simulated")
    ax.plot(t, W_p_sfp[:, 2], label="Swing Foot Planner reference")
    ax.legend()
    ax.grid()
axs[0].set_ylabel("Height [m]")
plt.show()

# COM tracking
fig, com_axes = plt.subplots(nrows=3, ncols=1, sharex=True)
for component, (ax, axis_name) in enumerate(zip(com_axes, "xyz")):
    ax.set_title(f"Center of mass: {axis_name} component")
    ax.plot(t, W_p_CoM_js[:, component], label="Simulated")
    ax.plot(t, W_p_CoM_mpc[:, component], linestyle="--", label="MPC References")
    ax.legend()
    ax.grid()
com_axes[-1].set_xlabel("Time [s]")
plt.show()

# Torques
fig, axs = _joint_axes_grid(figsize=(12, 12))
for idx, name in enumerate(joint_names):
    ax = axs.flat[idx]
    ax.set_title(name)
    ax.plot(t, tau_tsid[:, idx], label="TSID References")
    ax.legend()
    ax.grid()
    ax.set_ylabel("[Nm]")
plt.suptitle("Joint torques")
plt.tight_layout()
plt.show()

# Contact forces: rows are the x/y/z force components, columns the feet.
# The simulated JaxSim forces are not plotted: per the TODO in `simulate`, they
# are not yet converted to a quantity comparable with the MPC forces.
fig, axs = plt.subplots(nrows=3, ncols=2, figsize=(12, 12))
feet_mpc_forces = (("Left", f_lf_mpc), ("Right", f_rf_mpc))
for component, axis_name in enumerate("xyz"):
    for col, (foot, force_mpc) in enumerate(feet_mpc_forces):
        ax = axs[component, col]
        # Index by the force component (row), not by the foot (column).
        ax.plot(t, force_mpc[:, component], linestyle="--", label="MPC References")
        ax.set_title(f"{foot} foot force - {axis_name} component")
        ax.set_ylabel("Force [N]")
        ax.set_xlabel("Time [s]")
        ax.grid()
        ax.legend()

plt.suptitle("Feet contact forces")
plt.tight_layout()
plt.show()

In [None]:
# ==== Generate video ====
# Create results folder if not existing
def get_repo_root(
    current_path: Path = Path(os.path.abspath("jaxsim_walking.ipynb")).parent,
) -> Path:
    """Return the closest directory at or above ``current_path`` containing ``.git``.

    Args:
        current_path: Directory where the search starts. The default is the
            current working directory at the time this function was defined.

    Returns:
        The resolved path of the repository root.

    Raises:
        RuntimeError: If neither ``current_path`` nor any of its ancestors
            contains a ``.git`` entry.
    """
    current_path = current_path.resolve()

    # `Path.parents` excludes the path itself, so check the starting directory
    # explicitly; otherwise running from the repository root would fail.
    for candidate in (current_path, *current_path.parents):
        if (candidate / ".git").exists():
            return candidate

    raise RuntimeError("No .git directory found, not a Git repository.")


def create_output_dir(directory: Path):
    """Make sure ``directory`` exists, creating missing parents as needed.

    Calling it on an already existing directory is a no-op.
    """
    mkdir_options = {"parents": True, "exist_ok": True}
    directory.mkdir(**mkdir_options)


# Usage
# Search for the repository root starting from `get_repo_root`'s default path.
repo_root = get_repo_root()

# Define the results directory
results_dir = repo_root / "examples" / "results"

# Create the results directory if it doesn't exist
create_output_dir(results_dir)
# Timestamped output file. Note that the timestamp is immediately followed by
# "simulation_comodo.mp4" (no separator in between).
now = datetime.datetime.now()
current_time = now.strftime("%Y-%m-%d_%H-%M-%S")
filepath = results_dir / pathlib.Path(current_time + "simulation_comodo.mp4")
# NOTE: `_save_video` is a private method of the simulator wrapper.
js._save_video(filepath)
print(f"Video saved at: {filepath}")

## Perform a batch of simulations on GPU

In [None]:
import jax

# Number of parallel simulations in the batch.
BATCH_SIZE = 2
# Number of controlled joints (one logged torque column per joint name).
DOFS = len(joint_names)
# Number of simulator steps done for each TSID control reference. Derived from
# the same periods used by `simulate` (TSID period / simulator step) so that
# the replay reproduces the closed-loop timing; rounded to avoid truncation.
N_SIM_STEPS_JS_TSID = max(1, round(tsid.frequency / js_dt))
# Number of TSID references to replay (one per logged closed-loop iteration).
N_SIM_TIME_STEPS = t.shape[0]

# Logged TSID torques, one row per TSID iteration, replicated across the batch
# and reordered to (N_SIM_TIME_STEPS, BATCH_SIZE, DOFS).
torques = jnp.array(logs["tau_tsid"])
batched_torques = jnp.repeat(torques[None, ...], BATCH_SIZE, axis=0).swapaxes(0, 1)

print(f"Running on {jax.devices()}")

In [None]:
# Sanity check: batch element 0 of the replicated torques equals the logged
# torque sequence (the cell displays the result).
np.all(batched_torques[:, 0] == torques)

In [None]:
# Recompute the initial walking configuration (same call as at the start of the
# notebook); `s_0` and `xyz_rpy_0` initialize the batched simulations below.
s_0, xyz_rpy_0, H_b_0 = robot_model_init.compute_desired_position_walking()

In [None]:
# Set the collision flag before (re)importing jaxsim and before building the
# model below. NOTE(review): when jaxsim reads this variable is not visible
# here — confirm it takes effect (jaxsim is already imported via comodo).
os.environ["JAXSIM_COLLISION_USE_BOTTOM_ONLY"] = "1"

import jaxlie
import jaxsim

# Import the low-level API under its own alias: `js` is the JaxsimSimulator
# instance created earlier, and rebinding it would break re-running earlier
# cells (e.g. `js._save_video`).
import jaxsim.api as jaxsim_api

model = jaxsim_api.model.JaxSimModel.build_from_model_description(
    model_description=robot_model_init.urdf_string,
    time_step=js_dt,
    considered_joints=joint_names,
    # NOTE(review): mu=0.001 is a very low friction coefficient for walking
    # (the closed-loop simulation above does not set it) — confirm intent.
    contact_params=jaxsim.rbda.contacts.RelaxedRigidContactsParams.build(mu=0.001),
)


def get_joint_map(from_, to):
    """Return, for each name in ``to``, its index in the sequence ``from_``."""
    return np.array([from_.index(name) for name in to])


# `user_array[_to_js]` reorders a user-ordered joint array into JaxSim order;
# `js_array[_to_user]` does the opposite.
_to_js = get_joint_map(from_=robot_model_init.joint_name_list, to=model.joint_names())
_to_user = get_joint_map(from_=model.joint_names(), to=robot_model_init.joint_name_list)

data = jaxsim_api.data.JaxSimModelData.build(
    model=model, base_position=xyz_rpy_0[:3], joint_positions=s_0[_to_js]
)

# One initial state per batch element.
# NOTE(review): 0.5 is added to x, y and z of the base position, unlike `data`
# above — confirm this offset is intended.
batched_data = jax.vmap(
    lambda xyz_rpy, s: jaxsim_api.data.JaxSimModelData.build(
        model=model,
        base_position=xyz_rpy[0:3] + 0.5,
        joint_positions=s,
        base_quaternion=jaxlie.SO3.from_rpy_radians(*xyz_rpy[3:6]).wxyz,
    ),
    in_axes=(0, 0),
)(
    jnp.repeat(xyz_rpy_0[None, ...], BATCH_SIZE, axis=0),
    jnp.repeat(s_0[None, ...], BATCH_SIZE, axis=0)[:, _to_js],
)

# Torques reordered to JaxSim joint order: (N_SIM_TIME_STEPS, BATCH_SIZE, DOFS).
mapped_batched_torques = batched_torques[:, :, _to_js]


@jax.vmap
def step(data_0, torques):
    """Advance one batch element by ``N_SIM_STEPS_JS_TSID`` simulator steps.

    The same joint torques are applied at every step (one TSID reference held
    constant). ``jax.vmap`` maps over the leading batch axis of both arguments.
    """

    def step_js_loop(data, _) -> tuple[jaxsim_api.data.JaxSimModelData, None]:
        new_data = jaxsim_api.model.step(
            data=data, model=model, joint_force_references=torques
        )

        return new_data, None

    data_tf, _ = jax.lax.scan(
        step_js_loop,
        data_0,
        length=N_SIM_STEPS_JS_TSID,
    )
    return data_tf


jit_step = jax.jit(step)

In [None]:
# Display the initial base pose: position [:3] followed by roll/pitch/yaw [3:].
xyz_rpy_0

In [19]:
# Compile the simulation loop
# A first call with the first torque reference triggers JIT compilation. The
# result is discarded, so `batched_data` is left unchanged.
_ = jit_step(batched_data, mapped_batched_torques[0, :])

In [None]:
# Build a MuJoCo model from the JaxSim model; here it is only used to render
# frames of the batched results.
from jaxsim.mujoco import MujocoVideoRecorder, ModelToMjcf, MujocoModelHelper

mjcf_string, assets = ModelToMjcf.convert(model.built_from)

_mj_model_helper = MujocoModelHelper.build_from_xml(
    mjcf_description=mjcf_string, assets=assets
)

_recorder = MujocoVideoRecorder(
    model=_mj_model_helper.model,
    data=_mj_model_helper.data,
    fps=30,
    width=320 * 4,
    height=240 * 4,
)


# Initialize the MuJoCo state from batch element 0 of the initial batched data.
# Joint positions are stored in `model.joint_names()` order.
_mj_model_helper.set_base_position(
    position=np.array(batched_data.base_position[0]),
)
_mj_model_helper.set_base_orientation(
    orientation=np.array(batched_data.base_quaternion[0]),
)
_mj_model_helper.set_joint_positions(
    positions=np.array(batched_data.joint_positions[0]),
    joint_names=model.joint_names(),
)


# Optional preview of the initial frame:
# import mediapy as media

# media.show_image(_recorder.render_frame(), width=640, height=480)

In [None]:
# Perform the batched simulation on GPU

import time

# Histories of the batched state after each replayed TSID reference, indexed
# as [reference index, batch element, ...]. Joint positions are stored in
# `model.joint_names()` order.
batched_base_positions = np.zeros((N_SIM_TIME_STEPS, BATCH_SIZE, 3))
batched_base_quaternions = np.zeros((N_SIM_TIME_STEPS, BATCH_SIZE, 4))
batched_joint_positions = np.zeros((N_SIM_TIME_STEPS, BATCH_SIZE, DOFS))

now = time.perf_counter()
# Replay the logged TSID torques (previously this iterated over an all-zero
# array of the same shape, which discarded them). Each reference is held for
# N_SIM_STEPS_JS_TSID simulator steps inside `jit_step`.
for idx, tau in enumerate(mapped_batched_torques):
    print(f"{idx + 1} / {mapped_batched_torques.shape[0]}", end="\r")
    batched_data = jit_step(batched_data, tau)
    # Copying into the NumPy buffers needs the computed values, so each
    # iteration waits for its step to finish before `wall_time` is measured.
    batched_base_positions[idx] = batched_data.base_position
    batched_base_quaternions[idx] = batched_data.base_quaternion
    batched_joint_positions[idx] = batched_data.joint_positions

wall_time = time.perf_counter() - now

In [None]:
# y base position of batch element 0 after the first replayed torque reference.
batched_base_positions[0, 0, 1]

In [None]:
# Real-time factor of the batched replay: simulated time summed over all batch
# elements divided by wall-clock time. The per-element simulated time is
# derived from the steps actually taken instead of assuming it equals `T`.
batched_sim_time = N_SIM_TIME_STEPS * N_SIM_STEPS_JS_TSID * js_dt
rtf = batched_sim_time / wall_time * BATCH_SIZE * 100
print(f"RTF: {rtf:.1f}%")

In [None]:
import pickle

# Bundle the batched state histories and store them as a pickle in the current
# working directory.
batched_data_log = dict(
    base_positions=batched_base_positions,
    base_quaternions=batched_base_quaternions,
    joint_positions=batched_joint_positions,
)
with open("batched_data.pkl", "wb") as pickle_file:
    pickle.dump(batched_data_log, pickle_file)

print("batched_data has been saved to batched_data.pkl")

In [None]:
# Expected shape: (N_SIM_TIME_STEPS, BATCH_SIZE, DOFS).
batched_joint_positions.shape

In [None]:
# Expected shape: (N_SIM_TIME_STEPS, BATCH_SIZE, 4).
batched_base_quaternions.shape

In [27]:
# Release the recorder created earlier; a fresh one is created for the batched video.
del _recorder

In [None]:
# Check the results from one of the batched simulations

# Fresh recorder (same settings as the first one) for the batched video.
_recorder = MujocoVideoRecorder(
    model=_mj_model_helper.model,
    data=_mj_model_helper.data,
    fps=30,
    width=320 * 4,
    height=240 * 4,
)
current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

try:
    # At most the first 200 stored states of batch element 0.
    frames = zip(
        batched_base_positions[:200, 0],
        batched_base_quaternions[:200, 0],
        batched_joint_positions[:200, 0],
    )
    for base_pos, base_quat, joint_pos in frames:
        _mj_model_helper.set_base_position(position=base_pos)
        _mj_model_helper.set_base_orientation(orientation=base_quat)
        _mj_model_helper.set_joint_positions(
            positions=joint_pos, joint_names=model.joint_names()
        )
        _recorder.record_frame()
finally:
    # Write whatever was recorded, even if rendering stopped with an error.
    _recorder.write_video(results_dir / pathlib.Path(current_time + "batched.mp4"))

In [None]:
# Base height (z) of batch element 0 after each replayed torque reference.
batched_base_positions[:, 0, 2]