# `JAXsim` Showcase: Parallel Simulation of a Free-Falling Body

<a target="_blank" href="https://colab.research.google.com/github/ami-iit/jaxsim/blob/main/examples/Parallel_computing.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

First, we install the necessary packages and import them.

In [None]:
# @title Imports and setup
import sys

from IPython.display import HTML, clear_output, display

IS_COLAB = "google.colab" in sys.modules

# On Colab, install JAXsim (which also installs JAX) and Gazebo
if IS_COLAB:
    !{sys.executable} -m pip install -U -q jaxsim
    !apt -qq update && apt install -qq --no-install-recommends gazebo
    clear_output()
else:
    # Stop JAX from preallocating most of the GPU memory at startup,
    # which can cause out-of-memory errors when the GPU is shared
    %env XLA_PYTHON_CLIENT_MEM_PREALLOCATE=false

import time
from typing import Dict, Tuple

import jax
import jax.numpy as jnp
import jax_dataclasses
import rod
from rod.builder.primitives import SphereBuilder

import jaxsim.typing as jtp
from jaxsim import logging

logging.set_logging_level(logging.LoggingLevel.INFO)
logging.info(f"Running on {jax.devices()}")
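
The `%env` line magic is available only inside IPython. When running the same code as a plain Python script, a minimal equivalent, assuming the variable is set before JAX is imported, is to use `os.environ`:

```python
import os

# Must run before JAX initializes its GPU backend, hence before `import jax`
os.environ["XLA_PYTHON_CLIENT_MEM_PREALLOCATE"] = "false"

import jax  # noqa: E402

print(jax.devices())
```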

We will use a simple sphere model to simulate a free-falling body. The set will consist of 9 spheres, each starting from a different initial position. The spheres are simulated in parallel for 3000 steps of 1 ms each, corresponding to 3 seconds of simulated time.

**Note**: The parallel simulations are independent of each other; the different initial positions are used only to make the parallelization visible.
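
A detail worth checking when turning a duration into a step count: `0.001` is not exactly representable in binary floating point, so Python's floor division `//` can land one step short. A minimal sketch, using only the standard library, that compares floor division with rounding the quotient:

```python
integration_time = 3.0  # seconds
step_size = 0.001  # seconds

# 0.001 is stored as a double slightly larger than 1/1000, so the exact
# quotient is slightly below 3000 and floor division rounds it down
steps_floor = integration_time // step_size

# Rounding the true-division result recovers the intended integer count
steps_round = int(round(integration_time / step_size))

print(steps_floor, steps_round)  # 2999.0 3000
```

Rounding is a reasonable choice whenever the duration is meant to be an integer multiple of the step size.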

In [None]:
# @title Create a sphere model
model_sdf_string = rod.Sdf(
    version="1.7",
    model=SphereBuilder(radius=0.10, mass=1.0, name="sphere")
    .build_model()
    .add_link()
    .add_inertial()
    .add_visual()
    .add_collision()
    .build(),
).serialize(pretty=True)

Now, we can create a simulator instance and load the model into it.

In [None]:
from jaxsim.high_level.model import VelRepr
from jaxsim.physics.algos.soft_contacts import SoftContactsParams
from jaxsim.simulation.ode_integration import IntegratorType
from jaxsim.simulation.simulator import JaxSim, SimulatorData, StepData

# Simulation Step Parameters
integration_time = 3.0  # seconds
step_size = 0.001  # seconds
steps_per_run = 1

simulator = JaxSim.build(
    step_size=step_size,
    steps_per_run=steps_per_run,
    velocity_representation=VelRepr.Body,
    integrator_type=IntegratorType.EulerSemiImplicit,
    simulator_data=SimulatorData(
        contact_parameters=SoftContactsParams(K=1e6, D=2e3, mu=0.5),
    ),
)


# Add model to simulator
model = simulator.insert_model_from_description(model_description=model_sdf_string)

Let's create an array of initial poses arranged on a 3x3 grid in the xy plane. Every sphere is also placed at a different height.

In [None]:
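# Grid parameters
env_spacing = 0.5
envs_per_row = 3
# Half-length of the grid edge: x and y range over [-edge_len, edge_len]
# (here 2.5 m), so neighbouring spheres end up 2.5 m apart
edge_len = env_spacing * (2 * envs_per_row - 1)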


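# Create a grid of poses: each entry is [position, orientation]
def grid(edge_len, envs_per_row):
    edge = jnp.linspace(-edge_len, edge_len, envs_per_row)
    xx, yy = jnp.meshgrid(edge, edge)

    # Sphere k = i * envs_per_row + j starts at height 0.2 + 0.1 * k meters,
    # with a placeholder orientation [0, 0, 0] (unused below)
    poses = [
        [[xx[i, j], yy[i, j], 0.2 + 0.1 * (i * envs_per_row + j)], [0, 0, 0]]
        for i in range(xx.shape[0])
        for j in range(xx.shape[1])
    ]

    # Shape: (envs_per_row**2, 2, 3)
    return jnp.array(poses)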


poses = grid(edge_len, envs_per_row)
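
The list comprehension in `grid` can also be written in vectorized form. A minimal sketch, with a hypothetical helper `grid_positions`, that builds only the `(envs_per_row**2, 3)` array of initial positions, which is the part of `poses` that the simulation actually uses (`poses[:, 0]`):

```python
import jax.numpy as jnp


def grid_positions(edge_len: float, envs_per_row: int) -> jnp.ndarray:
    """Hypothetical vectorized equivalent of poses[:, 0] from grid()."""
    edge = jnp.linspace(-edge_len, edge_len, envs_per_row)
    xx, yy = jnp.meshgrid(edge, edge)  # default "xy" indexing, as in grid()

    # Row-major flattening visits (i, j) in the same order as the nested loops,
    # so sphere k = i * envs_per_row + j gets height 0.2 + 0.1 * k
    zz = 0.2 + 0.1 * jnp.arange(envs_per_row**2)

    return jnp.stack([xx.ravel(), yy.ravel(), zz], axis=-1)


positions = grid_positions(edge_len=2.5, envs_per_row=3)
print(positions.shape)  # (9, 3)
```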

To parallelize the simulation, we first define a function `simulate` that simulates a single element of the batch.

**Note:** [`step_over_horizon`](https://github.com/ami-iit/jaxsim/blob/427b1e646297495f6b33e4c0bb2273ca89bd5ae2/src/jaxsim/simulation/simulator.py#L432C1-L529C10) is suitable only for open-loop simulations whose horizon is known in advance. For closed-loop simulations, please check out [`step`](https://github.com/ami-iit/jaxsim/blob/427b1e646297495f6b33e4c0bb2273ca89bd5ae2/src/jaxsim/simulation/simulator.py#L384C10-L425).

In [None]:
from jaxsim.simulation import simulator_callbacks


# Create a logger to store simulation data
@jax_dataclasses.pytree_dataclass
class SimulatorLogger(simulator_callbacks.PostStepCallback):
    def post_step(
        self, sim: JaxSim, step_data: Dict[str, StepData]
    ) -> Tuple[JaxSim, jtp.PyTree]:
        """Return the StepData object of each simulated model"""
        return sim, step_data


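# Define a function to simulate a single model instance
def simulate(sim: JaxSim, pose) -> Dict[str, StepData]:
    # Reset the model state and set the initial base position of this instance
    model.zero()
    model.reset_base_position(position=jnp.array(pose))

    # Copy the updated model data into an editable copy of the simulator
    with sim.editable(validate=True) as sim:
        m = sim.get_model(model.name())
        m.data = model.data

    # Integrate the edited simulator over the whole horizon.
    # Round the quotient: 3.0 // 0.001 evaluates to 2999.0, not 3000
    sim, (cb, (_, step_data)) = sim.step_over_horizon(
        horizon_steps=int(round(integration_time / step_size)),
        callback_handler=SimulatorLogger(),
        clear_inputs=True,
    )

    # Per-step data collected through SimulatorLogger
    return step_data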
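
To make the shape of the logged data concrete without relying on the JAXsim API, here is a toy stand-in for `simulate`. It is a sketch under strong simplifying assumptions, not JAXsim's implementation: the body is a point mass, the only force is gravity ($g = 9.81$ m/s², assumed), there is no ground contact, and the hypothetical function `simulate_point_mass` integrates it with semi-implicit Euler inside `jax.lax.scan`, which stacks the value logged at each step along a leading time axis.

```python
import jax
import jax.numpy as jnp

GRAVITY = 9.81  # m/s^2, assumed standard gravity
STEP_SIZE = 0.001  # seconds
N_STEPS = 3000


def simulate_point_mass(position: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical toy: semi-implicit Euler free fall, no contacts.

    Returns the logged height at every step, shape (N_STEPS,).
    """

    def step(carry, _):
        z, vz = carry
        # Semi-implicit Euler: update the velocity first, then use the
        # new velocity to update the position
        vz = vz - GRAVITY * STEP_SIZE
        z = z + vz * STEP_SIZE
        # (new carry, value logged for this step)
        return (z, vz), z

    z0 = position[2]
    _, heights = jax.lax.scan(step, (z0, jnp.zeros_like(z0)), xs=None, length=N_STEPS)
    return heights


heights = simulate_point_mass(jnp.array([0.0, 0.0, 1.0]))
print(heights.shape)  # (3000,)
```

With semi-implicit Euler, the height after $n$ steps is $z_0 - g h^2 n (n+1) / 2$, which sits below the analytic free fall $z_0 - g t^2 / 2$ (with $t = n h$) by $g h t / 2$, about 1.5 cm after 3 s with $h = 1$ ms.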

We will use `jax.vmap` to simulate multiple models in parallel. This JAX transformation lets us write code for the single-model case and then execute it on a whole batch of models at once.
To do so, we first apply `jax.vmap` to the `simulate` function, and then call the resulting function with the batch of initial positions (`poses[:, 0]`) as input.

In our case we vectorize over the `pose` argument of `simulate`, which is expressed through the `in_axes` parameter of `jax.vmap`:

`in_axes=(None, 0)` means that the first argument of `simulate` (the simulator) is not vectorized and is shared by all instances, while the second argument is mapped over its leading axis (axis 0).
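
A minimal, self-contained illustration of `in_axes=(None, 0)`, independent of JAXsim: the hypothetical function `drop_height` takes a shared parameter dictionary (not mapped) and a single position (mapped over axis 0).

```python
import jax
import jax.numpy as jnp


def drop_height(params: dict, position: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical per-instance function: height of the sphere's lowest point."""
    # params is shared by every instance; position is a single (3,) vector
    return position[2] - params["radius"]


params = {"radius": 0.1}
positions = jnp.array([[0.0, 0.0, 0.2], [0.5, 0.0, 0.3], [1.0, 0.0, 0.4]])

# None: pass params unchanged to every instance; 0: map over the rows of positions
batched = jax.vmap(drop_height, in_axes=(None, 0))
out = batched(params, positions)
print(out)  # approximately [0.1 0.2 0.3]
```

In the notebook, `simulator` plays the role of `params` and `poses[:, 0]` that of `positions`.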

In [None]:
# Define a function to simulate multiple model instances
simulate_vectorized = jax.vmap(simulate, in_axes=(None, 0))

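# Run and time the simulation. The measured time may include the
# compilation of JAX functions triggered by this first call
now = time.perf_counter()

# JAX dispatches work asynchronously: wait for the results before stopping the timer
time_history = jax.block_until_ready(simulate_vectorized(simulator, poses[:, 0]))

comp_time = time.perf_counter() - now

logging.info(
    f"Running simulation with {envs_per_row**2} models took {comp_time:.3f} seconds."
)
logging.info(
    f"This corresponds to an RTF (Real Time Factor) of {envs_per_row**2 * integration_time / comp_time:.2f}"
)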
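
Since the first call may include compilation, timing a second call after a warm-up gives a fairer steady-state estimate. A minimal sketch of that pattern, using a hypothetical helper `time_steady_state` and a placeholder jitted workload instead of `simulate_vectorized` (timing the real simulation this way would run it twice); the RTF follows the same aggregate definition as the cell above:

```python
import time
from typing import Any, Callable

import jax
import jax.numpy as jnp


def time_steady_state(fn: Callable[..., Any], *args: Any) -> float:
    """Hypothetical helper: wall-clock seconds of one call to fn, after a warm-up."""
    # Warm-up: the first call may trace and compile, so it is not timed
    jax.block_until_ready(fn(*args))

    start = time.perf_counter()
    # Wait for the asynchronously dispatched computation to finish
    jax.block_until_ready(fn(*args))
    return time.perf_counter() - start


def real_time_factor(n_instances: int, simulated_time: float, wall_time: float) -> float:
    # Aggregate RTF: total simulated seconds per wall-clock second
    return n_instances * simulated_time / wall_time


# Placeholder workload standing in for simulate_vectorized: 9 instances x 3000 steps
workload = jax.jit(jax.vmap(lambda z0: jnp.cumsum(jnp.full(3000, -1e-5)) + z0))
z0s = 0.2 + 0.1 * jnp.arange(9.0)

wall_time = time_steady_state(workload, z0s)
print(f"RTF: {real_time_factor(9, 3.0, wall_time):.1f}")
```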

Now let's extract the logged data and plot it. We expect the height time series of each sphere to start from a different initial value.

In [None]:
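import matplotlib.pyplot as plt

time_history: Dict[str, StepData]
x_t = time_history[model.name()].tf_model_state

# tf has shape (n_models, n_steps): all instances share the same time vector,
# so take the first one. The base height has shape (n_models, n_steps) and is
# transposed to (n_steps, n_models) so that matplotlib draws one curve per sphere
plt.plot(time_history[model.name()].tf[0], x_t.base_position[:, :, 2].T)
plt.grid(True)
plt.xlabel("Time [s]")
plt.ylabel("Height [m]")
plt.title("Trajectory of the model's base")
plt.show()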
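
As a rough sanity check on the plot, the time at which each sphere first touches the ground can be estimated from the free-fall formula $t_c = \sqrt{2 (z_0 - r) / g}$. This sketch assumes that the base position is the sphere's center, $r = 0.1$ m, standard gravity $g = 9.81$ m/s², and no contact compliance, for the nine initial heights produced by `grid`:

```python
import jax.numpy as jnp

GRAVITY = 9.81  # m/s^2, assumed
RADIUS = 0.10  # m, sphere radius from the SDF model

# Initial heights from grid(): 0.2 + 0.1 * k for k = 0, ..., 8
z0 = 0.2 + 0.1 * jnp.arange(9)

# Time to fall from z0 until the lowest point of the sphere reaches z = 0
t_contact = jnp.sqrt(2.0 * (z0 - RADIUS) / GRAVITY)
print(t_contact)  # roughly 0.14 s for the lowest sphere, 0.43 s for the highest
```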