# `JAXsim` Showcase: Parallel Simulation of a free-falling body

First, we install the necessary packages and import them.

In [None]:
# @title Imports and setup
import sys
import os
import pathlib

# Deactivate GPU to avoid out of memory errors: with no visible CUDA devices,
# JAX should fall back to the CPU (the devices actually used are logged below).
os.environ["CUDA_VISIBLE_DEVICES"] = ""

from IPython.display import HTML, clear_output, display

# True when the notebook runs inside Google Colab (its module is loaded).
IS_COLAB = "google.colab" in sys.modules

# Colab only: install JaxSim with pip, register the OSRF Gazebo apt repository
# (signing key + sources list), then install the SDFormat library and gz tools.
if IS_COLAB:
    !{sys.executable} -m pip install -qU jaxsim
    !apt install -qq lsb-release wget gnupg
    !wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg
    !echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null
    !apt -qq update
    !apt install -qq --no-install-recommends libsdformat13 gz-tools2

# Set environment variable to avoid GPU out of memory errors: it disables the
# up-front preallocation of GPU memory by XLA. It is set here, before `jax`
# is imported below.
%env XLA_PYTHON_CLIENT_MEM_PREALLOCATE=false

import time
from typing import Dict, Tuple

import jax
import jax.numpy as jnp
import jax_dataclasses
import rod
from rod.builder.primitives import SphereBuilder, BoxBuilder

import jaxsim.typing as jtp
from jaxsim import logging
from jaxsim.api.common import VelRepr

from jaxsim.mujoco import (
    MujocoVideoRecorder,
    MujocoModelHelper,
    RodModelToMjcf,
    SdfToMjcf,
    UrdfToMjcf,
)

logging.set_logging_level(logging.LoggingLevel.INFO)
# Report which devices JAX is going to use.
logging.info(f"Running on {jax.devices()}")

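We will use a simple sphere model to simulate a free-falling body. The set of spheres is arranged on an `envs_per_row` × `envs_per_row` grid (e.g. 9 spheres for a 3×3 grid), each sphere at a different position. The spheres are meant to be simulated in parallel, and the simulation runs for `integration_time` steps of `dt` seconds each (with the values set below: 1500 steps of 1 ms, i.e. 1.5 seconds of simulated time).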

**Note**: Parallel simulations are independent of each other; the different positions are imposed only to show the parallelization visually.

In [None]:
# @title Create a sphere model
# Build an SDF (version 1.7) description of a single sphere of radius 0.15 and
# mass 1.0 with inertial, visual and collision elements, and serialize it to a
# string that is used below both by JaxSim and by the MuJoCo converter.
# To simulate a box instead, replace the SphereBuilder call with:
#     BoxBuilder(x=0.30, y=0.30, z=0.30, mass=1.0, name="box")
model_sdf_string = rod.Sdf(
    version="1.7",
    model=SphereBuilder(radius=0.15, mass=1.0, name="sphere")
    .build_model()
    .add_link()
    .add_inertial()
    .add_visual()
    .add_collision()
    .build(),
).serialize(pretty=True)

JAXsim offers a simple functional API to interact with the simulation in a memory-efficient way. Four main elements are used to define a simulation:

- `model`: an object that defines the dynamics of the system.
- `data`: an object that contains the state of the system.
- `integrator`: an object that defines the integration method.
- `integrator_state`: an object that contains the state of the integrator.

In [None]:
import jaxsim.api as js
from jaxsim import integrators
import jaxsim

# Integration step, in seconds.
dt = 0.001
# Number of integration steps (NOT seconds): the simulated time is
# `integration_time * dt`.
integration_time = 1500

# `model_sdf_string` is an SDF document generated by rod, so it must be parsed
# as SDF rather than URDF.
model = js.model.JaxSimModel.build_from_model_description(
    model_description=model_sdf_string,
    contact_model=js.rigid_contacts.RigidContacts(),
    is_urdf=False,
)

# Keep no joints at all: the model is simulated as a single floating base.
model = js.model.reduce(model=model, considered_joints=tuple())

data = js.data.JaxSimModelData.build(
    model=model, velocity_representation=VelRepr.Inertial
)

# Fixed-step `RungeKutta4SO3` integrator wrapping the model's system dynamics.
integrator = integrators.fixed_step.RungeKutta4SO3.build(
    dynamics=js.ode.wrap_system_dynamics_for_integration(
        model=model,
        data=data,
        system_dynamics=js.ode.system_dynamics,
    ),
)
integrator_state = integrator.init(x0=data.state, t0=0.0, dt=dt)

In [None]:
# Convert the SDF description into MJCF so that MuJoCo can render the model;
# the returned `assets` are forwarded unchanged to the model helper.
mjcf_string, assets = SdfToMjcf.convert(sdf=model_sdf_string)
mj_helper = MujocoModelHelper.build_from_xml(
    mjcf_description=mjcf_string, assets=assets
)
# One video frame per simulation step. `round` avoids `int` truncating a
# value such as 999.999... down to 999 for other choices of `dt`.
recorder = MujocoVideoRecorder(
    model=mj_helper.model,
    data=mj_helper.data,
    fps=round(1 / dt),
    width=640,
    height=480,
)

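It is possible to automatically choose a good set of soft-contact parameters for the terrain. Note that this applies only to soft contacts: the model above is built with `RigidContacts`, so the cell below is left commented out.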

By default, in JaxSim a sphere primitive has 250 collision points. This can be modified by setting the `JAXSIM_COLLISION_SPHERE_POINTS` environment variable.

Given that, at steady state, the sphere will rest on two or three of these points, we can estimate the ground parameters by explicitly setting the number of active points to one of these values.

In [None]:
# NOTE: disabled. `soft_contacts_params` configures soft contacts, whereas the
# model above is built with `js.rigid_contacts.RigidContacts()`.
# data = data.replace(
#     soft_contacts_params=js.contact.estimate_good_soft_contacts_parameters(
#         model, number_of_active_collidable_points_steady_state=3
#     )
# )

Let's create the initial poses for an `envs_per_row` × `envs_per_row` grid (e.g. 3x3). Every sphere is placed at a different (x, y) position, all at the same initial height `initial_height`.

In [None]:
# Primary Calculations
# Number of environments along each side of the square grid
# (envs_per_row**2 environments in total).
# NOTE(review): 1 lies outside the slider range (2, 10); presumably kept at 1
# because `jax.vmap` is disabled in the simulation cell - confirm.
envs_per_row = 1  # @slider(2, 10, 1)
# Initial base height shared by every environment.
initial_height = 0.7
env_spacing = 0.5
# Half-width of the grid: `grid` spaces x and y in [-edge_len, edge_len].
edge_len = env_spacing * (2 * envs_per_row - 1)


# Create Grid
def grid(edge_len, envs_per_row, height=None):
    """Return the initial poses of an ``envs_per_row`` x ``envs_per_row`` grid.

    Args:
        edge_len: Half-width of the grid; x and y coordinates are evenly spaced
            in ``[-edge_len, edge_len]``.
        envs_per_row: Number of environments along each side of the grid.
        height: Initial z coordinate of every environment. When ``None``, the
            notebook-level ``initial_height`` is used.

    Returns:
        An array of shape ``(envs_per_row**2, 2, 3)``. For each environment,
        row 0 is the position ``[x, y, height]`` and row 1 is a zero vector.
        Environments are ordered row-major over the x/y meshgrid.
    """
    if height is None:
        height = initial_height

    edge = jnp.linspace(-edge_len, edge_len, envs_per_row)
    xx, yy = jnp.meshgrid(edge, edge)

    # Build all positions at once instead of indexing JAX arrays scalar by
    # scalar in Python (each such index is a separate device operation).
    positions = jnp.stack(
        [xx.ravel(), yy.ravel(), jnp.full(xx.size, height, dtype=xx.dtype)],
        axis=-1,
    )
    orientations = jnp.zeros_like(positions)

    return jnp.stack([positions, orientations], axis=1)


logging.info(f"Simulating {envs_per_row**2} environments")
# One initial pose per environment, laid out by `grid`.
poses = grid(edge_len, envs_per_row)

In order to parallelize the simulation, we first need to define a function `simulate` for a single element of the batch.

In [None]:
# Define a function to simulate a single model instance
def simulate(
    data: js.data.JaxSimModelData, integrator_state: dict, pose: jtp.VectorLike
) -> tuple:
    """Simulate one model instance for ``integration_time`` steps.

    At every step a 6D wrench ``F`` is computed and applied to link 0 so that
    the acceleration of link 0 vanishes, i.e. ``J @ nu_dot + Jdot_nu = 0``
    with ``nu_dot = M^-1 (S @ tau - h + J.T @ F)`` and zero joint torques.

    Args:
        data: Initial model data; its base position is reset to ``pose``.
        integrator_state: Initial state of the global ``integrator``.
        pose: Initial base position.

    Returns:
        A tuple ``(x_t_i, forces)`` of two lists with one entry per step: the
        base position after the step and the wrench ``F`` applied during it.

    Note:
        Reads the notebook globals ``model``, ``integrator``, ``dt`` and
        ``integration_time``.
    """
    data = data.reset_base_position(base_position=pose)
    x_t_i = []
    forces = []

    # Loop-invariant quantities.
    n_dofs = model.dofs()
    # Selection matrix mapping joint torques to generalized forces (zero rows
    # for the 6 base DoFs).
    S = jnp.block([jnp.zeros(shape=(n_dofs, 6)), jnp.eye(n_dofs)]).T
    tau = jnp.zeros(n_dofs)
    S_tau = S @ tau
    no_link_forces = jnp.zeros((model.number_of_links(), 6))

    for _ in range(integration_time):
        h = js.model.free_floating_bias_forces(model=model, data=data)
        M = js.model.free_floating_mass_matrix(model=model, data=data)
        Jdot_nu = js.model.link_bias_accelerations(model=model, data=data)

        # NOTE(review): per the jaxsim API, M, h, the link bias accelerations
        # and the `link_forces` given to `js.model.step` are all expressed in
        # `data.velocity_representation` (Inertial here). The Jacobian must use
        # the same representation for the computed wrench to be consistent;
        # the previous Mixed output mixed two representations - confirm.
        J = js.link.jacobian(
            model=model,
            data=data,
            link_index=0,
            output_vel_repr=data.velocity_representation,
        ).squeeze()

        # Linear solves instead of explicit inverses (better conditioned).
        M_inv_JT = jnp.linalg.solve(M, J.T)
        M_inv_rhs = jnp.linalg.solve(M, S_tau - h)
        F = -jnp.linalg.solve(J @ M_inv_JT, Jdot_nu[0] + J @ M_inv_rhs)

        # `.at[].set` returns a new array; `no_link_forces` stays all zeros.
        link_forces = no_link_forces.at[0].set(F)

        data, integrator_state = js.model.step(
            dt=dt,
            model=model,
            data=data,
            integrator=integrator,
            integrator_state=integrator_state,
            joint_forces=None,
            link_forces=link_forces,
        )

        x_t_i.append(data.base_position())
        forces.append(F)

    return x_t_i, forces

We will make use of `jax.vmap` to simulate multiple models in parallel. This is a very powerful feature of JAX that allows writing code that is very similar to the single-model case, but can be executed in parallel on multiple models.
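In order to do so, we need to first apply `jax.vmap` to the `simulate` function, and then call the resulting function with the batch of different poses as input. (In the cell below, the `jax.vmap` call is currently commented out and `simulate` is called directly.)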

Note that in our case we are vectorizing over the `pose` argument of the function `simulate`; this corresponds to the value assigned to the `in_axes` parameter of `jax.vmap`:

`in_axes=(None, None, 0)` means that the first two arguments of `simulate` are not vectorized, while the third argument is vectorized over the zeroth dimension.

In [None]:
# Define a function to simulate multiple model instances.
# NOTE: batching with `jax.vmap` is currently disabled; `simulate` is called
# directly with the base positions (row 0 of every pose) of the grid entries.
# simulate_vectorized = jax.vmap(simulate, in_axes=(None, None, 0))

# Run and time the simulation
now = time.perf_counter()

# x_t = simulate_vectorized(data, integrator_state, poses[:, 0]).
# JAX dispatches work asynchronously: block until every result is computed so
# that the measured time covers the actual computation.
x_t, forces = jax.block_until_ready(simulate(data, integrator_state, poses[:, 0]))

comp_time = time.perf_counter() - now

logging.info(
    f"Running simulation with {envs_per_row**2} models took {comp_time} seconds."
)
# RTF = simulated time / wall-clock time. `integration_time` is a number of
# steps, so it is multiplied by `dt` to obtain the simulated seconds.
logging.info(
    f"This corresponds to an RTF (Real Time Factor) of {(envs_per_row**2 * integration_time * dt / comp_time):.2f}"
)

In [None]:
import datetime
from pathlib import Path

import mediapy as media

# Replay the simulated base trajectory in MuJoCo, one frame per stored step.
for pose in x_t:
    mj_helper.set_base_position(pose)
    recorder.record_frame()

media.show_video(recorder.frames, fps=1 / dt)

# Filesystem-safe timestamp: `str(datetime.now())` contains spaces and ':'
# characters, and ':' is not allowed in Windows file names.
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
recorder.write_video(path=Path.cwd() / f"video_{timestamp}.mp4")

Now let's extract the data from the simulation and plot it. The plot shows the height time series of the model's base; note that all spheres start from the same `initial_height`, and only their x-y positions differ.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Time stamps of the stored samples (one per integration step) and the stacked
# base positions returned by the simulation.
time_axis = np.arange(len(x_t)) * dt
base_positions = np.array(x_t)

# Column 2 of each base position is plotted as the height.
plt.plot(time_axis, base_positions[:, 2])
plt.grid(True)
plt.xlabel("Time [s]")
plt.ylabel("Height [m]")
plt.title("Trajectory of the model's base")
plt.show()

In [None]:
# Stack the per-step wrenches into one array (one row per simulation step);
# the identity comprehension was redundant.
forces = np.array(forces)

In [None]:
# Display the shape of the stacked force array (one row per simulation step).
forces.shape

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Number of initial simulation steps to plot.
n_plot_steps = 600

# Each row of `forces` is the 6D wrench applied to link 0. Linear components
# are assumed to come first, then angular ones (jaxsim's [force; torque]
# ordering - confirm), so the y axis mixes forces and torques.
plt.plot(
    np.arange(len(forces[:n_plot_steps])) * dt,
    forces[:n_plot_steps],
    label=["Fx", "Fy", "Fz", "τx", "τy", "τz"],
)
plt.grid(True)
plt.xlabel("Time [s]")
plt.ylabel("Force [N] / Torque [Nm]")
plt.title("Contact forces")
plt.legend()
plt.show()