# `JAXsim` Showcase: Parallel Simulation of a free-falling body

<a target="_blank" href="https://colab.research.google.com/github/flferretti/jaxsim/blob/example/parallel/examples/Parallel_computing.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

First, we install the necessary packages and import them.

In [None]:
# @title Imports and setup
from IPython.display import clear_output, HTML, display
import sys

# True when this notebook is being executed inside Google Colab.
IS_COLAB = "google.colab" in sys.modules

# On Colab only: install jaxsim (and its Python dependencies) with pip and
# Gazebo with apt, then clear the verbose installation output from the cell.
if IS_COLAB:
    !{sys.executable} -m pip install -U -q jaxsim
    !apt -qq update && apt install -qq --no-install-recommends gazebo
    clear_output()

import jax
import jax.numpy as jnp
import rod
from rod.builder.primitives import SphereBuilder
from jaxsim import logging

# Make jaxsim log INFO messages, then report which JAX devices are available.
logging.set_logging_level(logging.LoggingLevel.INFO)
devices = jax.devices()
logging.info(f"Running on {devices}")

We will use a simple sphere model for this example. The sphere is a single free-floating rigid body (radius 0.10, mass 1.0) with no joints, so its state is fully described by the position, orientation and velocity of its base. We will drop many such spheres from different initial positions and simulate them in parallel.

In [None]:
# @title Create a sphere model
# Describe a single-link sphere (radius 0.10, mass 1.0) with ROD: the link gets
# inertial, visual and collision elements. The result is serialized to an
# SDF (version 1.7) XML string.
sphere_builder = SphereBuilder(radius=0.10, mass=1.0, name="sphere")
sphere_model = (
    sphere_builder.build_model()
    .add_link()
    .add_inertial()
    .add_visual()
    .add_collision()
    .build()
)
sdf_document = rod.Sdf(version="1.7", model=sphere_model)
model_sdf_string = sdf_document.serialize(pretty=True)

JAXsim offers a simple high-level API to build models and extract the quantities needed in most robotics applications.

In [None]:
from jaxsim.high_level.model import Model

# `model_sdf_string` is an SDF document (serialized with `rod.Sdf` above), so it
# must not be parsed as URDF.
model = Model.build_from_model_description(
    model_description=model_sdf_string, is_urdf=False
)

Now, we can create a simulator instance and load the model into it.

In [None]:
from jaxsim.simulation.ode_integration import IntegratorType
from jaxsim.simulation.simulator import JaxSim, SimulatorData, StepData
from jaxsim.high_level.model import VelRepr
from jaxsim.physics.algos.soft_contacts import SoftContactsParams

# Simulation step parameters
integration_time = 3.0  # total simulated time, seconds
step_size = 0.001  # integration step size
steps_per_run = 1

# Soft-contact parameters (presumably stiffness K, damping D and friction
# coefficient mu of jaxsim's soft-contact model).
contact_parameters = SoftContactsParams(K=1e6, D=2e3, mu=0.5)

# Build the simulator with body-fixed velocity representation and the
# semi-implicit Euler integrator, as a mutable (editable in place) object.
simulator = JaxSim.build(
    step_size=step_size,
    steps_per_run=steps_per_run,
    velocity_representation=VelRepr.Body,
    integrator_type=IntegratorType.EulerSemiImplicit,
    simulator_data=SimulatorData(contact_parameters=contact_parameters),
).mutable(validate=False)

# Insert the sphere into the simulator and keep a mutable handle to it.
model = simulator.insert_model_from_description(
    model_description=model_sdf_string
).mutable(validate=True)

Let's create an array of initial base poses for the spheres, arranged on an 8x8 grid.

In [None]:
# Primary Calculations
radius = 0.1
envs_per_row = 8
num_envs = envs_per_row**2
# Scale factor for the extent of the grid (was previously undefined).
env_spacing = 0.5
edge_len = env_spacing * envs_per_row + env_spacing * (envs_per_row - 1)


# Create Grid
def grid(num_envs, edge_len, envs_per_row):
    """Return initial base poses for spheres laid out on a square grid.

    The grid spans ``[-edge_len, edge_len]`` along both x and y with
    ``envs_per_row`` points per side. Every sphere starts at height 1.

    Args:
        num_envs: Maximum number of poses to return (the grid itself has
            ``envs_per_row**2`` cells, enumerated row by row).
        edge_len: Half-extent of the grid along x and y.
        envs_per_row: Number of grid points along each side.

    Returns:
        Array of shape ``(min(num_envs, envs_per_row**2), 2, 3)``.
        ``[:, 0]`` holds the ``(x, y, z)`` positions; ``[:, 1]`` is all zeros
        (presumably an orientation; only ``[:, 0]`` is used for simulation).
    """
    # Build the meshgrid once (it used to be rebuilt `num_envs` times).
    axis = jnp.linspace(-edge_len, edge_len, envs_per_row)
    xx, yy = jnp.meshgrid(axis, axis)

    # Row-major flattening keeps the original ordering: outer loop over rows (y),
    # inner loop over columns (x).
    positions = jnp.stack([xx.ravel(), yy.ravel(), jnp.ones(xx.size)], axis=-1)
    orientations = jnp.zeros_like(positions)
    poses = jnp.stack([positions, orientations], axis=1)

    return poses[:num_envs]


poses = grid(num_envs, edge_len, envs_per_row)
# The sphere has no joints, so there are no joint positions to reset; each
# environment's base position is set from `poses` when simulating.

The visualization is done with the MuJoCo package, so that the animations can also be rendered easily on Google Colab. If you are not interested in the animation, simply execute this cell without worrying about its details.

In [None]:
# @title Set up MuJoCo renderer
# Install MuJoCo (used here only for rendering) and mediapy (video display).
!{sys.executable} -m pip install -U -q mujoco
!{sys.executable} -m pip install -q mediapy

import mediapy as media
import tempfile
import xml.etree.ElementTree as ET
import numpy as np

# NOTE(review): `distutils` is not referenced anywhere in the visible notebook
# and was removed from the standard library in Python 3.12; consider dropping it.
import distutils.util
import os
import subprocess

# Colab-only preparation for headless (EGL) GPU rendering.
if IS_COLAB:
    # Install ffmpeg when `ffmpeg -version` fails (presumably needed by mediapy
    # to encode the videos shown later).
    if subprocess.run("ffmpeg -version", shell=True).returncode:
        !command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
        clear_output()

    # Abort early if the GPU cannot be reached.
    # NOTE(review): if `nvidia-smi` is not installed at all, subprocess.run raises
    # FileNotFoundError instead of returning a non-zero code.
    if subprocess.run("nvidia-smi").returncode:
        raise RuntimeError(
            "Cannot communicate with GPU. "
            "Make sure you are using a GPU Colab runtime. "
            "Go to the Runtime menu and select Choose runtime type."
        )

    # Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
    # This is usually installed as part of an Nvidia driver package, but the Colab
    # kernel doesn't install its driver via APT, and as a result the ICD is missing.
    # (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
    NVIDIA_ICD_CONFIG_PATH = "/usr/share/glvnd/egl_vendor.d/10_nvidia.json"
    if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
        with open(NVIDIA_ICD_CONFIG_PATH, "w") as f:
            f.write(
                """{
      "file_format_version" : "1.0.0",
      "ICD" : {
         "library_path" : "libEGL_nvidia.so.0"
     }
  }
  """
            )

# Select MuJoCo's EGL rendering backend (off-screen rendering without a display).
%env MUJOCO_GL=egl

# Import MuJoCo. On failure, raise an actionable error that keeps the original
# exception as its cause (previously the chain was inverted: the original
# exception was re-raised with the explanatory message hidden as its cause).
try:
    import mujoco
except Exception as e:
    raise RuntimeError(
        "Something went wrong during installation. Check the shell output above "
        "for more information.\n"
        "If using a hosted Colab runtime, make sure you enable GPU acceleration "
        'by going to the Runtime menu and selecting "Choose runtime type".'
    ) from e


def to_mjcf_string(values):
    """Format a sequence of numbers as a space-separated MJCF attribute value."""
    return " ".join(map(str, values))


def add_fixed_camera_to_mjcf(mjcf_string, camera_pos, camera_xyaxes, camera_name="side"):
    """Return ``mjcf_string`` with a fixed camera appended to its ``<worldbody>``.

    Args:
        mjcf_string: MJCF XML document, as ``str`` or ``bytes``.
        camera_pos: Numbers written to the camera's ``pos`` attribute.
        camera_xyaxes: Numbers written to the camera's ``xyaxes`` attribute.
        camera_name: Value of the camera's ``name`` attribute.

    Returns:
        The modified MJCF document as a ``str``.

    Raises:
        ValueError: If the document has no ``<worldbody>`` element.
    """
    root = ET.fromstring(mjcf_string)
    worldbody = next(root.iter("worldbody"), None)
    if worldbody is None:
        raise ValueError("MJCF document has no <worldbody> element")

    ET.SubElement(
        worldbody,
        "camera",
        {
            "name": camera_name,
            "pos": to_mjcf_string(camera_pos),
            "xyaxes": to_mjcf_string(camera_xyaxes),
            "mode": "fixed",
        },
    )
    return ET.tostring(root, encoding="unicode")


def load_mujoco_model_with_camera(xml_string, camera_pos, camera_xyaxes):
    """Load ``xml_string`` into MuJoCo and return a model with a fixed "side" camera.

    The input (any format ``MjModel.from_xml_string`` accepts, e.g. URDF or MJCF)
    is compiled once, re-exported as MJCF with ``mj_saveLastXML``, extended with a
    fixed camera named ``"side"`` and compiled again.

    Args:
        xml_string: Model description to load.
        camera_pos: Camera position (MJCF ``pos``).
        camera_xyaxes: Camera orientation as six numbers (MJCF ``xyaxes``).

    Returns:
        The ``mujoco.MjModel`` including the camera.
    """
    # Use the argument (previously ignored in favour of a global).
    mj_model_raw = mujoco.MjModel.from_xml_string(xml_string)

    # mj_saveLastXML can only write to a path: round-trip through a temporary
    # directory that is always cleaned up.
    with tempfile.TemporaryDirectory() as tmp_dir:
        xml_path = os.path.join(tmp_dir, "model.xml")
        mujoco.mj_saveLastXML(xml_path, mj_model_raw)
        with open(xml_path, "rb") as f:
            mjcf_bytes = f.read()

    mjcf_with_camera = add_fixed_camera_to_mjcf(mjcf_bytes, camera_pos, camera_xyaxes)
    return mujoco.MjModel.from_xml_string(mjcf_with_camera)


def from_jaxsim_to_mujoco_pos(jaxsim_jointpos, mjmodel, jaxsimmodel):
    """Reorder joint positions from JAXsim's joint order to MuJoCo's ``qpos`` layout.

    Args:
        jaxsim_jointpos: Joint positions in JAXsim order.
        mjmodel: MuJoCo model; ``mjmodel.joint(name).qposadr[0]`` is the ``qpos``
            address of the joint called ``name``.
        jaxsimmodel: JAXsim model whose ``joints()`` expose ``name()`` and ``index()``.

    Returns:
        A new array whose i-th entry is the JAXsim position of the joint stored at
        MuJoCo ``qpos[i]``. The input is left unmodified.

    Raises:
        KeyError: If some address in ``range(len(jaxsim_jointpos))`` has no
            matching JAXsim joint.
    """
    # MuJoCo qpos address -> position in the JAXsim vector (`index() - 1` is the
    # offset convention used for JAXsim joints here).
    qposadr_to_jax_index = {
        mjmodel.joint(joint.name()).qposadr[0]: joint.index() - 1
        for joint in jaxsimmodel.joints()
    }

    # Gather into a fresh array. Writing into the input while reading from it
    # (as before) overwrote entries that were still needed whenever the two
    # orders differ, and silently mutated the caller's array.
    source = np.asarray(jaxsim_jointpos)
    order = np.array(
        [qposadr_to_jax_index[i] for i in range(len(source))], dtype=int
    )
    return source[order]


# To get a good camera location, you can use "Copy camera" functionality in MuJoCo GUI
# Build the MuJoCo model used only for rendering, with a fixed "side" camera at
# the given position and xyaxes orientation.
# NOTE(review): `model_urdf_string` is not defined in the visible notebook (the
# model is built above as `model_sdf_string`). MuJoCo loads MJCF or URDF, not SDF,
# so a URDF/MJCF export of the sphere must be provided under this name — TODO.
mj_model = load_mujoco_model_with_camera(
    model_urdf_string,
    [3.954, 3.533, 2.343],
    [-0.594, 0.804, -0.000, -0.163, -0.120, 0.979],
)
# Off-screen renderer producing 640x480 frames of `mj_model`.
renderer = mujoco.Renderer(mj_model, height=480, width=640)


def get_image(camera, mujocojointpos) -> np.ndarray:
    """Render one frame of the global ``mj_model`` using the global ``renderer``.

    Args:
        camera: Camera to render from (this notebook passes ``"side"``).
        mujocojointpos: MuJoCo ``qpos`` vector describing the configuration to show.

    Returns:
        The image returned by ``renderer.render()``.
    """
    frame_data = mujoco.MjData(mj_model)
    frame_data.qpos = mujocojointpos
    # Forward kinematics: propagate qpos to body poses before drawing the scene.
    mujoco.mj_forward(mj_model, frame_data)
    renderer.update_scene(frame_data, camera=camera)
    image = renderer.render()
    return image

To parallelize the simulation, we first define a function that simulates a single element of the batch. This function is then vectorized over the chosen dimension or parameter, so that all elements of the batch are simulated in parallel.

In [None]:
# Names used by the callback below; they were never imported in this notebook.
from typing import Dict, Tuple

import jax_dataclasses
import jaxsim.typing as jtp
from jaxsim.simulation import simulator_callbacks


# Create a logger to store simulation data
@jax_dataclasses.pytree_dataclass
class SimulatorLogger(simulator_callbacks.PostStepCallback):
    """Post-step callback that outputs each step's ``StepData``.

    Returning the data from ``post_step`` lets ``step_over_horizon`` collect it
    for every integration step.
    """

    def post_step(
        self, sim: JaxSim, step_data: Dict[str, StepData]
    ) -> Tuple[JaxSim, jtp.PyTree]:
        """Return the StepData object of each simulated model"""
        return sim, step_data


# Define a function to simulate a single model instance
def simulate(sim: JaxSim, pose) -> dict[str, StepData]:
    """Simulate the sphere for `integration_time` seconds from base position `pose`.

    Written for a single batch element so it can be vectorized with `jax.vmap`
    over `pose`.

    Args:
        sim: Simulator containing the sphere; an edited copy is stepped.
        pose: Initial base position of the sphere.

    Returns:
        The per-step `StepData` (keyed by model name) emitted by `SimulatorLogger`.
    """
    # Reset the simulator's own copy of the model rather than the global `model`:
    # mutating a global object inside `jax.vmap` leaks tracers into it.
    with sim.editable(validate=True) as sim:
        m = sim.get_model(model.name())
        m.zero()
        m.reset_base_position(position=jnp.array(pose))

    # Integer step count: `integration_time // step_size` evaluates to 2999.0
    # (a float, one step short) for 3.0 // 0.001 because of rounding.
    horizon_steps = int(round(integration_time / step_size))

    # Step the *edited* simulator (the global `simulator` does not contain the
    # per-pose reset). Assumes the result is laid out as
    # (sim, (callback, (step outputs, post-step outputs))) — TODO confirm for the
    # installed jaxsim version.
    sim, (_, (_, step_data)) = sim.step_over_horizon(
        horizon_steps=horizon_steps,
        callback_handler=SimulatorLogger(),
        clear_inputs=True,
    )

    return step_data


# Preview the single (non-vectorized) model: render 300 frames, integrating the
# global `model` forward by `timestep` seconds between frames.
# `sim_images` and `timestep` were previously used before being defined.
sim_images = []
timestep = 0.01

for _ in range(300):
    sim_images.append(
        get_image(
            "side",
            from_jaxsim_to_mujoco_pos(
                np.array(model.joint_positions()), mj_model, model
            ),
        )
    )
    model.integrate(
        t0=0.0, tf=timestep, integrator_type=IntegratorType.EulerSemiImplicit
    )

media.show_video(sim_images, fps=1 / timestep)

We will make use of `jax.vmap` to simulate multiple models in parallel. This is a very powerful feature of JAX that allows us to write code that is very similar to the single-model case, but can be executed in parallel on multiple models.

Note that in our case we are vectorizing over the `pose` argument of the function `simulate`; this corresponds to the value assigned to the `in_axes` parameter of `jax.vmap`:

`in_axes=(None, 0)` means that the first argument of `simulate` is not vectorized (the same simulator is shared by all instances), while the second argument is vectorized over its leading (zeroth) dimension.

In [None]:
import time

# Define a function to simulate multiple model instances: `simulator` is shared
# by all instances (in_axes=None) while the poses are split along axis 0.
simulate_vectorized = jax.vmap(simulate, in_axes=(None, 0))

# One row per environment; column 0 of each pose holds the base position.
initial_positions = poses[:, 0]
logging.info(f"Running simulation with {initial_positions.shape[0]} models")

# Run and time the simulation. JAX dispatches work asynchronously, so wait for
# the results before stopping the clock.
now = time.perf_counter()
time_history = simulate_vectorized(simulator, initial_positions)
jax.block_until_ready(time_history)
logging.info(f"Simulation took {time.perf_counter() - now:.3f} s")

Now let's extract the data from the simulation and plot it.

In [None]:
from typing import Dict

import matplotlib.pyplot as plt

# Index of the environment (grid cell) whose trajectory is plotted
# (was previously undefined).
obj_id = 0

time_history: Dict[str, StepData]
# After vmap, the leading axis of every field indexes the environment
# (presumably followed by the time-step axis from the horizon scan).
step_data = time_history[model.name()]
x_t = step_data.tf_model_state

plt.plot(step_data.tf[obj_id], x_t.base_position[obj_id], label=["x", "y", "z"])
plt.grid(True)
plt.legend()
plt.xlabel("Time [s]")
plt.ylabel("Position [m]")
plt.title("Trajectory of the model's base")
plt.show()

In [None]:
sim_images = []
timestep = 0.01

# NOTE(review): only joint positions are passed to MuJoCo. The sphere has no
# joints, so its base pose is not transferred to the rendering — TODO confirm.
for _ in range(300):
    sim_images.append(
        get_image(
            "side",
            from_jaxsim_to_mujoco_pos(
                np.array(model.joint_positions()), mj_model, model
            ),
        )
    )
    # The sphere has no joints to actuate, so no joint force targets are set
    # (the former PD-control call used an undefined `pd_controller` together
    # with two-joint targets).
    model.integrate(
        t0=0.0, tf=timestep, integrator_type=IntegratorType.EulerSemiImplicit
    )

media.show_video(sim_images, fps=1 / timestep)