Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JAXsim Notebook for Parallel Simulation #65

Merged
merged 7 commits into from
Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
307 changes: 307 additions & 0 deletions examples/Parallel_computing.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,307 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# `JAXsim` Showcase: Parallel Simulation of a free-falling body\n",
"\n",
"<a target=\"_blank\" href=\"https://colab.research.google.com/github/ami-iit/jaxsim/blob/main/examples/Parallel_computing.ipynb\">\n",
" <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
"</a>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, we install the necessary packages and import them."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#@title Imports and setup\n",
"import sys\n",
"\n",
"from IPython.display import HTML, clear_output, display\n",
"\n",
"IS_COLAB = \"google.colab\" in sys.modules\n",
"\n",
"# Install JAX and Gazebo\n",
"if IS_COLAB:\n",
" !{sys.executable} -m pip install -U -q jaxsim\n",
" !apt -qq update && apt install -qq --no-install-recommends gazebo\n",
" clear_output()\n",
"else:\n",
" # Set environment variable to avoid GPU out of memory errors\n",
" %env XLA_PYTHON_CLIENT_MEM_PREALLOCATE=false\n",
"\n",
"import time\n",
"from typing import Dict, Tuple\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import jax_dataclasses\n",
"import rod\n",
"from rod.builder.primitives import SphereBuilder\n",
"\n",
"import jaxsim.typing as jtp\n",
"from jaxsim import logging\n",
"\n",
"logging.set_logging_level(logging.LoggingLevel.INFO)\n",
"logging.info(f\"Running on {jax.devices()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will use a simple sphere model to simulate a free-falling body. The spheres set will be composed of 9 spheres, each with a different position. The spheres will be simulated in parallel, and the simulation will be run for 3000 steps corresponding to 3 seconds of simulation.\n",
"\n",
"**Note**: Parallel simulations are independent of each other, the different position is imposed only to show the parallelization visually."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#@title Create a sphere model\n",
"model_sdf_string = rod.Sdf(\n",
" version=\"1.7\",\n",
" model=SphereBuilder(radius=0.10, mass=1.0, name=\"sphere\")\n",
" .build_model()\n",
" .add_link()\n",
" .add_inertial()\n",
" .add_visual()\n",
" .add_collision()\n",
" .build(),\n",
").serialize(pretty=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we can create a simulator instance and load the model into it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from jaxsim.high_level.model import VelRepr\n",
"from jaxsim.physics.algos.soft_contacts import SoftContactsParams\n",
"from jaxsim.simulation.ode_integration import IntegratorType\n",
"from jaxsim.simulation.simulator import JaxSim, SimulatorData, StepData\n",
"\n",
"# Simulation Step Parameters\n",
"integration_time = 3.0 # seconds\n",
"step_size = 0.001\n",
"steps_per_run = 1\n",
"\n",
"simulator = JaxSim.build(\n",
" step_size=step_size,\n",
" steps_per_run=steps_per_run,\n",
" velocity_representation=VelRepr.Body,\n",
" integrator_type=IntegratorType.EulerSemiImplicit,\n",
" simulator_data=SimulatorData(\n",
" contact_parameters=SoftContactsParams(K=1e6, D=2e3, mu=0.5),\n",
" ),\n",
")\n",
"\n",
"\n",
"# Add model to simulator\n",
"model = simulator.insert_model_from_description(\n",
" model_description=model_sdf_string\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's create a position vector for a 3x3 grid. Every sphere will be placed at a different height."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Primary Calculations\n",
"env_spacing = 0.5\n",
"envs_per_row = 3\n",
"edge_len = env_spacing * (2 * envs_per_row - 1)\n",
"\n",
"\n",
"# Create Grid\n",
"def grid(edge_len, envs_per_row):\n",
" edge = jnp.linspace(-edge_len, edge_len, envs_per_row)\n",
" xx, yy = jnp.meshgrid(edge, edge)\n",
"\n",
" poses = [\n",
" [[xx[i, j], yy[i, j], 0.2 + 0.1 * (i * envs_per_row + j)], [0, 0, 0]]\n",
" for i in range(xx.shape[0])\n",
" for j in range(yy.shape[0])\n",
" ]\n",
"\n",
" return jnp.array(poses)\n",
"\n",
"\n",
"poses = grid(edge_len, envs_per_row)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In order to parallelize the simulation, we first need to define a function `simulate` for a single element of the batch.\n",
"\n",
"**Note:** [`step_over_horizon`](https://github.com/ami-iit/jaxsim/blob/427b1e646297495f6b33e4c0bb2273ca89bd5ae2/src/jaxsim/simulation/simulator.py#L432C1-L529C10) is useful only in open-loop simulations and where the horizon is known in advance. Please checkout [`step`](https://github.com/ami-iit/jaxsim/blob/427b1e646297495f6b33e4c0bb2273ca89bd5ae2/src/jaxsim/simulation/simulator.py#L384C10-L425) for closed-loop simulations."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from jaxsim.simulation import simulator_callbacks\n",
"\n",
"\n",
"# Create a logger to store simulation data\n",
"@jax_dataclasses.pytree_dataclass\n",
"class SimulatorLogger(simulator_callbacks.PostStepCallback):\n",
" def post_step(\n",
" self, sim: JaxSim, step_data: Dict[str, StepData]\n",
" ) -> Tuple[JaxSim, jtp.PyTree]:\n",
" \"\"\"Return the StepData object of each simulated model\"\"\"\n",
" return sim, step_data\n",
"\n",
"\n",
"# Define a function to simulate a single model instance\n",
"def simulate(sim: JaxSim, pose) -> JaxSim:\n",
" model.zero()\n",
" model.reset_base_position(position=jnp.array(pose))\n",
"\n",
" with sim.editable(validate=True) as sim:\n",
" m = sim.get_model(model.name())\n",
" m.data = model.data\n",
"\n",
" sim, (cb, (_, step_data)) = simulator.step_over_horizon(\n",
" horizon_steps=integration_time // step_size,\n",
" callback_handler=SimulatorLogger(),\n",
" clear_inputs=True,\n",
" )\n",
"\n",
" return step_data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will make use of `jax.vmap` to simulate multiple models in parallel. This is a very powerful feature of JAX that allows to write code that is very similar to the single-model case, but can be executed in parallel on multiple models.\n",
"In order to do so, we need to first apply `jax.vmap` to the `simulate` function, and then call the resulting function with the batch of different poses as input.\n",
"\n",
"Note that in our case we are vectorizing over the `pose` argument of the function `simulate`, this correspond to the value assigned to the `in_axes` parameter of `jax.vmap`:\n",
"\n",
"`in_axes=(None, 0)` means that the first argument of `simulate` is not vectorized, while the second argument is vectorized over the zero-th dimension."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Define a function to simulate multiple model instances\n",
"simulate_vectorized = jax.vmap(simulate, in_axes=(None, 0))\n",
"\n",
"# Run and time the simulation\n",
"now = time.perf_counter()\n",
"\n",
"time_history = simulate_vectorized(simulator, poses[:, 0])\n",
"\n",
"comp_time = time.perf_counter() - now\n",
"\n",
"logging.info(\n",
" f\"Running simulation with {envs_per_row**2} models took {comp_time} seconds.\"\n",
")\n",
"logging.info(\n",
" f\"This corresponds to an RTF (Real Time Factor) of {envs_per_row**2 *integration_time/comp_time}\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's extract the data from the simulation and plot it. We expect to see the height time series of each sphere starting from a different value."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"time_history: Dict[str, StepData]\n",
"x_t = time_history[model.name()].tf_model_state\n",
"\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"plt.plot(time_history[model.name()].tf[0], x_t.base_position[:, :, 2].T)\n",
"plt.grid(True)\n",
"plt.xlabel(\"Time [s]\")\n",
"plt.ylabel(\"Height [m]\")\n",
"plt.title(\"Trajectory of the model's base\")\n",
"plt.show()"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuClass": "premium",
"gpuType": "V100",
"private_outputs": true,
"provenance": [
{
"file_id": "1QsuS7EJhdPEHxxAu9XwozvA7eb4ZnlAb",
"timestamp": 1701993737024
}
],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.1"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
1 change: 1 addition & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ This folder includes a Jupyter Notebook demonstrating practical usage of JAXsim
### Examples

- [PD_controller](https://colab.research.google.com/github/ami-iit/jaxsim/blob/main/examples/PD_controller.ipynb) - A simple example demonstrating the use of JAXsim to simulate a PD controller with gravity compensation for a 2-DOF cartpole.
- [Parallel_computing](https://colab.research.google.com/github/ami-iit/jaxsim/blob/main/examples/Parallel_computing.ipynb) - An example demonstrating how to simulate vectorized models in parallel using JAXsim.

> [!TIP]
> Stay tuned for more examples!
Expand Down
1 change: 1 addition & 0 deletions examples/pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ unix = true

[tasks]
PD_controller = {cmd = "jupyter notebook PD_controller.ipynb", depends_on = ["install"]}
Parallel_computing = {cmd = "jupyter notebook Parallel_computing.ipynb", depends_on = ["install"]}
install = "python -m pip install git+https://github.com/ami-iit/jaxsim.git"

[dependencies]
Expand Down
Loading