-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #65 from flferretti/example/parallel
Add JAXsim Notebook for Parallel Simulation
- Loading branch information
Showing
3 changed files
with
309 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,307 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# `JAXsim` Showcase: Parallel Simulation of a free-falling body\n", | ||
"\n", | ||
"<a target=\"_blank\" href=\"https://colab.research.google.com/github/ami-iit/jaxsim/blob/main/examples/Parallel_computing.ipynb\">\n", | ||
" <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n", | ||
"</a>" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"First, we install the necessary packages and import them." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title Imports and setup\n", | ||
"import sys\n", | ||
"\n", | ||
"from IPython.display import HTML, clear_output, display\n", | ||
"\n", | ||
"IS_COLAB = \"google.colab\" in sys.modules\n", | ||
"\n", | ||
"# Install JAX and Gazebo\n", | ||
"if IS_COLAB:\n", | ||
" !{sys.executable} -m pip install -U -q jaxsim\n", | ||
" !apt -qq update && apt install -qq --no-install-recommends gazebo\n", | ||
" clear_output()\n", | ||
"else:\n", | ||
" # Set environment variable to avoid GPU out of memory errors\n", | ||
" %env XLA_PYTHON_CLIENT_MEM_PREALLOCATE=false\n", | ||
"\n", | ||
"import time\n", | ||
"from typing import Dict, Tuple\n", | ||
"\n", | ||
"import jax\n", | ||
"import jax.numpy as jnp\n", | ||
"import jax_dataclasses\n", | ||
"import rod\n", | ||
"from rod.builder.primitives import SphereBuilder\n", | ||
"\n", | ||
"import jaxsim.typing as jtp\n", | ||
"from jaxsim import logging\n", | ||
"\n", | ||
"logging.set_logging_level(logging.LoggingLevel.INFO)\n", | ||
"logging.info(f\"Running on {jax.devices()}\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"We will use a simple sphere model to simulate a free-falling body. The spheres set will be composed of 9 spheres, each with a different position. The spheres will be simulated in parallel, and the simulation will be run for 3000 steps corresponding to 3 seconds of simulation.\n", | ||
"\n", | ||
"**Note**: Parallel simulations are independent of each other, the different position is imposed only to show the parallelization visually." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title Create a sphere model\n", | ||
"model_sdf_string = rod.Sdf(\n", | ||
" version=\"1.7\",\n", | ||
" model=SphereBuilder(radius=0.10, mass=1.0, name=\"sphere\")\n", | ||
" .build_model()\n", | ||
" .add_link()\n", | ||
" .add_inertial()\n", | ||
" .add_visual()\n", | ||
" .add_collision()\n", | ||
" .build(),\n", | ||
").serialize(pretty=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Now, we can create a simulator instance and load the model into it." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from jaxsim.high_level.model import VelRepr\n", | ||
"from jaxsim.physics.algos.soft_contacts import SoftContactsParams\n", | ||
"from jaxsim.simulation.ode_integration import IntegratorType\n", | ||
"from jaxsim.simulation.simulator import JaxSim, SimulatorData, StepData\n", | ||
"\n", | ||
"# Simulation Step Parameters\n", | ||
"integration_time = 3.0 # seconds\n", | ||
"step_size = 0.001\n", | ||
"steps_per_run = 1\n", | ||
"\n", | ||
"simulator = JaxSim.build(\n", | ||
" step_size=step_size,\n", | ||
" steps_per_run=steps_per_run,\n", | ||
" velocity_representation=VelRepr.Body,\n", | ||
" integrator_type=IntegratorType.EulerSemiImplicit,\n", | ||
" simulator_data=SimulatorData(\n", | ||
" contact_parameters=SoftContactsParams(K=1e6, D=2e3, mu=0.5),\n", | ||
" ),\n", | ||
")\n", | ||
"\n", | ||
"\n", | ||
"# Add model to simulator\n", | ||
"model = simulator.insert_model_from_description(\n", | ||
" model_description=model_sdf_string\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Let's create a position vector for a 3x3 grid. Every sphere will be placed at a different height." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Primary Calculations\n", | ||
"env_spacing = 0.5\n", | ||
"envs_per_row = 3\n", | ||
"edge_len = env_spacing * (2 * envs_per_row - 1)\n", | ||
"\n", | ||
"\n", | ||
"# Create Grid\n", | ||
"def grid(edge_len, envs_per_row):\n", | ||
" edge = jnp.linspace(-edge_len, edge_len, envs_per_row)\n", | ||
" xx, yy = jnp.meshgrid(edge, edge)\n", | ||
"\n", | ||
" poses = [\n", | ||
" [[xx[i, j], yy[i, j], 0.2 + 0.1 * (i * envs_per_row + j)], [0, 0, 0]]\n", | ||
" for i in range(xx.shape[0])\n", | ||
" for j in range(yy.shape[0])\n", | ||
" ]\n", | ||
"\n", | ||
" return jnp.array(poses)\n", | ||
"\n", | ||
"\n", | ||
"poses = grid(edge_len, envs_per_row)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"In order to parallelize the simulation, we first need to define a function `simulate` for a single element of the batch.\n", | ||
"\n", | ||
"**Note:** [`step_over_horizon`](https://github.com/ami-iit/jaxsim/blob/427b1e646297495f6b33e4c0bb2273ca89bd5ae2/src/jaxsim/simulation/simulator.py#L432C1-L529C10) is useful only in open-loop simulations and where the horizon is known in advance. Please checkout [`step`](https://github.com/ami-iit/jaxsim/blob/427b1e646297495f6b33e4c0bb2273ca89bd5ae2/src/jaxsim/simulation/simulator.py#L384C10-L425) for closed-loop simulations." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from jaxsim.simulation import simulator_callbacks\n", | ||
"\n", | ||
"\n", | ||
"# Create a logger to store simulation data\n", | ||
"@jax_dataclasses.pytree_dataclass\n", | ||
"class SimulatorLogger(simulator_callbacks.PostStepCallback):\n", | ||
" def post_step(\n", | ||
" self, sim: JaxSim, step_data: Dict[str, StepData]\n", | ||
" ) -> Tuple[JaxSim, jtp.PyTree]:\n", | ||
" \"\"\"Return the StepData object of each simulated model\"\"\"\n", | ||
" return sim, step_data\n", | ||
"\n", | ||
"\n", | ||
"# Define a function to simulate a single model instance\n", | ||
"def simulate(sim: JaxSim, pose) -> JaxSim:\n", | ||
" model.zero()\n", | ||
" model.reset_base_position(position=jnp.array(pose))\n", | ||
"\n", | ||
" with sim.editable(validate=True) as sim:\n", | ||
" m = sim.get_model(model.name())\n", | ||
" m.data = model.data\n", | ||
"\n", | ||
" sim, (cb, (_, step_data)) = simulator.step_over_horizon(\n", | ||
" horizon_steps=integration_time // step_size,\n", | ||
" callback_handler=SimulatorLogger(),\n", | ||
" clear_inputs=True,\n", | ||
" )\n", | ||
"\n", | ||
" return step_data" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"We will make use of `jax.vmap` to simulate multiple models in parallel. This is a very powerful feature of JAX that allows to write code that is very similar to the single-model case, but can be executed in parallel on multiple models.\n", | ||
"In order to do so, we need to first apply `jax.vmap` to the `simulate` function, and then call the resulting function with the batch of different poses as input.\n", | ||
"\n", | ||
"Note that in our case we are vectorizing over the `pose` argument of the function `simulate`, this correspond to the value assigned to the `in_axes` parameter of `jax.vmap`:\n", | ||
"\n", | ||
"`in_axes=(None, 0)` means that the first argument of `simulate` is not vectorized, while the second argument is vectorized over the zero-th dimension." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Define a function to simulate multiple model instances\n", | ||
"simulate_vectorized = jax.vmap(simulate, in_axes=(None, 0))\n", | ||
"\n", | ||
"# Run and time the simulation\n", | ||
"now = time.perf_counter()\n", | ||
"\n", | ||
"time_history = simulate_vectorized(simulator, poses[:, 0])\n", | ||
"\n", | ||
"comp_time = time.perf_counter() - now\n", | ||
"\n", | ||
"logging.info(\n", | ||
" f\"Running simulation with {envs_per_row**2} models took {comp_time} seconds.\"\n", | ||
")\n", | ||
"logging.info(\n", | ||
" f\"This corresponds to an RTF (Real Time Factor) of {envs_per_row**2 *integration_time/comp_time}\"\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Now let's extract the data from the simulation and plot it. We expect to see the height time series of each sphere starting from a different value." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"time_history: Dict[str, StepData]\n", | ||
"x_t = time_history[model.name()].tf_model_state\n", | ||
"\n", | ||
"\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"\n", | ||
"plt.plot(time_history[model.name()].tf[0], x_t.base_position[:, :, 2].T)\n", | ||
"plt.grid(True)\n", | ||
"plt.xlabel(\"Time [s]\")\n", | ||
"plt.ylabel(\"Height [m]\")\n", | ||
"plt.title(\"Trajectory of the model's base\")\n", | ||
"plt.show()" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"accelerator": "GPU", | ||
"colab": { | ||
"gpuClass": "premium", | ||
"gpuType": "V100", | ||
"private_outputs": true, | ||
"provenance": [ | ||
{ | ||
"file_id": "1QsuS7EJhdPEHxxAu9XwozvA7eb4ZnlAb", | ||
"timestamp": 1701993737024 | ||
} | ||
], | ||
"toc_visible": true | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.12.1" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 4 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters