Skip to content

Commit

Permalink
Merge pull request #65 from flferretti/example/parallel
Browse files Browse the repository at this point in the history
Add JAXsim Notebook for Parallel Simulation
  • Loading branch information
flferretti committed Jan 15, 2024
2 parents 427b1e6 + 145573c commit d329544
Show file tree
Hide file tree
Showing 3 changed files with 309 additions and 0 deletions.
307 changes: 307 additions & 0 deletions examples/Parallel_computing.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,307 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# `JAXsim` Showcase: Parallel Simulation of a free-falling body\n",
"\n",
"<a target=\"_blank\" href=\"https://colab.research.google.com/github/ami-iit/jaxsim/blob/main/examples/Parallel_computing.ipynb\">\n",
" <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
"</a>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, we install the necessary packages and import them."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#@title Imports and setup\n",
"import sys\n",
"\n",
"from IPython.display import HTML, clear_output, display\n",
"\n",
"IS_COLAB = \"google.colab\" in sys.modules\n",
"\n",
"# Install JAX and Gazebo\n",
"if IS_COLAB:\n",
" !{sys.executable} -m pip install -U -q jaxsim\n",
" !apt -qq update && apt install -qq --no-install-recommends gazebo\n",
" clear_output()\n",
"else:\n",
" # Set environment variable to avoid GPU out of memory errors\n",
" %env XLA_PYTHON_CLIENT_MEM_PREALLOCATE=false\n",
"\n",
"import time\n",
"from typing import Dict, Tuple\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import jax_dataclasses\n",
"import rod\n",
"from rod.builder.primitives import SphereBuilder\n",
"\n",
"import jaxsim.typing as jtp\n",
"from jaxsim import logging\n",
"\n",
"logging.set_logging_level(logging.LoggingLevel.INFO)\n",
"logging.info(f\"Running on {jax.devices()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will use a simple sphere model to simulate a free-falling body. The spheres set will be composed of 9 spheres, each with a different position. The spheres will be simulated in parallel, and the simulation will be run for 3000 steps corresponding to 3 seconds of simulation.\n",
"\n",
"**Note**: Parallel simulations are independent of each other, the different position is imposed only to show the parallelization visually."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#@title Create a sphere model\n",
"model_sdf_string = rod.Sdf(\n",
" version=\"1.7\",\n",
" model=SphereBuilder(radius=0.10, mass=1.0, name=\"sphere\")\n",
" .build_model()\n",
" .add_link()\n",
" .add_inertial()\n",
" .add_visual()\n",
" .add_collision()\n",
" .build(),\n",
").serialize(pretty=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we can create a simulator instance and load the model into it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from jaxsim.high_level.model import VelRepr\n",
"from jaxsim.physics.algos.soft_contacts import SoftContactsParams\n",
"from jaxsim.simulation.ode_integration import IntegratorType\n",
"from jaxsim.simulation.simulator import JaxSim, SimulatorData, StepData\n",
"\n",
"# Simulation Step Parameters\n",
"integration_time = 3.0 # seconds\n",
"step_size = 0.001\n",
"steps_per_run = 1\n",
"\n",
"simulator = JaxSim.build(\n",
" step_size=step_size,\n",
" steps_per_run=steps_per_run,\n",
" velocity_representation=VelRepr.Body,\n",
" integrator_type=IntegratorType.EulerSemiImplicit,\n",
" simulator_data=SimulatorData(\n",
" contact_parameters=SoftContactsParams(K=1e6, D=2e3, mu=0.5),\n",
" ),\n",
")\n",
"\n",
"\n",
"# Add model to simulator\n",
"model = simulator.insert_model_from_description(\n",
" model_description=model_sdf_string\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's create a position vector for a 3x3 grid. Every sphere will be placed at a different height."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Primary Calculations\n",
"env_spacing = 0.5\n",
"envs_per_row = 3\n",
"edge_len = env_spacing * (2 * envs_per_row - 1)\n",
"\n",
"\n",
"# Create Grid\n",
"def grid(edge_len, envs_per_row):\n",
" edge = jnp.linspace(-edge_len, edge_len, envs_per_row)\n",
" xx, yy = jnp.meshgrid(edge, edge)\n",
"\n",
" poses = [\n",
" [[xx[i, j], yy[i, j], 0.2 + 0.1 * (i * envs_per_row + j)], [0, 0, 0]]\n",
" for i in range(xx.shape[0])\n",
" for j in range(yy.shape[0])\n",
" ]\n",
"\n",
" return jnp.array(poses)\n",
"\n",
"\n",
"poses = grid(edge_len, envs_per_row)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In order to parallelize the simulation, we first need to define a function `simulate` for a single element of the batch.\n",
"\n",
"**Note:** [`step_over_horizon`](https://github.com/ami-iit/jaxsim/blob/427b1e646297495f6b33e4c0bb2273ca89bd5ae2/src/jaxsim/simulation/simulator.py#L432C1-L529C10) is useful only in open-loop simulations and where the horizon is known in advance. Please checkout [`step`](https://github.com/ami-iit/jaxsim/blob/427b1e646297495f6b33e4c0bb2273ca89bd5ae2/src/jaxsim/simulation/simulator.py#L384C10-L425) for closed-loop simulations."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from jaxsim.simulation import simulator_callbacks\n",
"\n",
"\n",
"# Create a logger to store simulation data\n",
"@jax_dataclasses.pytree_dataclass\n",
"class SimulatorLogger(simulator_callbacks.PostStepCallback):\n",
" def post_step(\n",
" self, sim: JaxSim, step_data: Dict[str, StepData]\n",
" ) -> Tuple[JaxSim, jtp.PyTree]:\n",
" \"\"\"Return the StepData object of each simulated model\"\"\"\n",
" return sim, step_data\n",
"\n",
"\n",
"# Define a function to simulate a single model instance\n",
"def simulate(sim: JaxSim, pose) -> JaxSim:\n",
" model.zero()\n",
" model.reset_base_position(position=jnp.array(pose))\n",
"\n",
" with sim.editable(validate=True) as sim:\n",
" m = sim.get_model(model.name())\n",
" m.data = model.data\n",
"\n",
" sim, (cb, (_, step_data)) = simulator.step_over_horizon(\n",
" horizon_steps=integration_time // step_size,\n",
" callback_handler=SimulatorLogger(),\n",
" clear_inputs=True,\n",
" )\n",
"\n",
" return step_data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will make use of `jax.vmap` to simulate multiple models in parallel. This is a very powerful feature of JAX that allows to write code that is very similar to the single-model case, but can be executed in parallel on multiple models.\n",
"In order to do so, we need to first apply `jax.vmap` to the `simulate` function, and then call the resulting function with the batch of different poses as input.\n",
"\n",
"Note that in our case we are vectorizing over the `pose` argument of the function `simulate`, this correspond to the value assigned to the `in_axes` parameter of `jax.vmap`:\n",
"\n",
"`in_axes=(None, 0)` means that the first argument of `simulate` is not vectorized, while the second argument is vectorized over the zero-th dimension."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Define a function to simulate multiple model instances\n",
"simulate_vectorized = jax.vmap(simulate, in_axes=(None, 0))\n",
"\n",
"# Run and time the simulation\n",
"now = time.perf_counter()\n",
"\n",
"time_history = simulate_vectorized(simulator, poses[:, 0])\n",
"\n",
"comp_time = time.perf_counter() - now\n",
"\n",
"logging.info(\n",
" f\"Running simulation with {envs_per_row**2} models took {comp_time} seconds.\"\n",
")\n",
"logging.info(\n",
" f\"This corresponds to an RTF (Real Time Factor) of {envs_per_row**2 *integration_time/comp_time}\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's extract the data from the simulation and plot it. We expect to see the height time series of each sphere starting from a different value."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"time_history: Dict[str, StepData]\n",
"x_t = time_history[model.name()].tf_model_state\n",
"\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"plt.plot(time_history[model.name()].tf[0], x_t.base_position[:, :, 2].T)\n",
"plt.grid(True)\n",
"plt.xlabel(\"Time [s]\")\n",
"plt.ylabel(\"Height [m]\")\n",
"plt.title(\"Trajectory of the model's base\")\n",
"plt.show()"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuClass": "premium",
"gpuType": "V100",
"private_outputs": true,
"provenance": [
{
"file_id": "1QsuS7EJhdPEHxxAu9XwozvA7eb4ZnlAb",
"timestamp": 1701993737024
}
],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.1"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
1 change: 1 addition & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ This folder includes a Jupyter Notebook demonstrating practical usage of JAXsim
### Examples

- [PD_controller](https://colab.research.google.com/github/ami-iit/jaxsim/blob/main/examples/PD_controller.ipynb) - A simple example demonstrating the use of JAXsim to simulate a PD controller with gravity compensation for a 2-DOF cartpole.
- [Parallel_computing](https://colab.research.google.com/github/ami-iit/jaxsim/blob/main/examples/Parallel_computing.ipynb) - An example demonstrating how to simulate vectorized models in parallel using JAXsim.

> [!TIP]
> Stay tuned for more examples!
Expand Down
1 change: 1 addition & 0 deletions examples/pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ unix = true

[tasks]
PD_controller = {cmd = "jupyter notebook PD_controller.ipynb", depends_on = ["install"]}
Parallel_computing = {cmd = "jupyter notebook Parallel_computing.ipynb", depends_on = ["install"]}
install = "python -m pip install git+https://github.com/ami-iit/jaxsim.git"

[dependencies]
Expand Down

0 comments on commit d329544

Please sign in to comment.