Skip to content

Commit

Permalink
Update vmap example to the functional API
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Apr 3, 2024
1 parent 92f2998 commit 9a0e49f
Showing 1 changed file with 71 additions and 62 deletions.
133 changes: 71 additions & 62 deletions examples/Parallel_computing.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we can create a simulator instance and load the model into it."
"JAXsim offers a simple functional API in order to interact in a memory-efficient way with the simulation. Four main elements are used to define a simulation:\n",
"\n",
"- `model`: an object that defines the dynamics of the system.\n",
"- `data`: an object that contains the state of the system.\n",
"- `integrator`: an object that defines the integration method.\n",
"- `integrator_state`: an object that contains the state of the integrator."
]
},
{
Expand All @@ -97,29 +102,44 @@
"metadata": {},
"outputs": [],
"source": [
"from jaxsim.high_level.model import VelRepr\n",
"from jaxsim.physics.algos.soft_contacts import SoftContactsParams\n",
"from jaxsim.simulation.ode_integration import IntegratorType\n",
"from jaxsim.simulation.simulator import JaxSim, SimulatorData, StepData\n",
"import jaxsim.api as js\n",
"from jaxsim import integrators\n",
"\n",
"# Simulation Step Parameters\n",
"integration_time = 3.0 # seconds\n",
"step_size = 0.001\n",
"steps_per_run = 1\n",
"dt = 0.001\n",
"integration_time = 1500\n",
"\n",
"simulator = JaxSim.build(\n",
" step_size=step_size,\n",
" steps_per_run=steps_per_run,\n",
" velocity_representation=VelRepr.Body,\n",
" integrator_type=IntegratorType.EulerSemiImplicit,\n",
" simulator_data=SimulatorData(\n",
" contact_parameters=SoftContactsParams(K=1e6, D=2e3, mu=0.5),\n",
"model = js.model.JaxSimModel.build_from_model_description(\n",
" model_description=model_sdf_string\n",
")\n",
"data = js.data.JaxSimModelData.build(model=model)\n",
"integrator = integrators.fixed_step.RungeKutta4SO3.build(\n",
" dynamics=js.ode.wrap_system_dynamics_for_integration(\n",
" model=model,\n",
" data=data,\n",
" system_dynamics=js.ode.system_dynamics,\n",
" ),\n",
")\n",
"\n",
"\n",
"# Add model to simulator\n",
"model = simulator.insert_model_from_description(model_description=model_sdf_string)"
"integrator_state = integrator.init(x0=data.state, t0=0.0, dt=dt)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It is possible to automatically choose a good set of parameters for the terrain. By default, in JaxSim a sphere primitive has 250 collision points."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data = data.replace(\n",
" soft_contacts_params=js.contact.estimate_good_soft_contacts_parameters(\n",
" model, number_of_active_collidable_points_steady_state=250\n",
" )\n",
")"
]
},
{
Expand All @@ -136,8 +156,9 @@
"outputs": [],
"source": [
"# Primary Calculations\n",
"envs_per_row = 4 # @slider(2, 10, 1)\n",
"\n",
"env_spacing = 0.5\n",
"envs_per_row = 3\n",
"edge_len = env_spacing * (2 * envs_per_row - 1)\n",
"\n",
"\n",
Expand All @@ -155,16 +176,15 @@
" return jnp.array(poses)\n",
"\n",
"\n",
"logging.info(f\"Simulating {envs_per_row**2} environments\")\n",
"poses = grid(edge_len, envs_per_row)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In order to parallelize the simulation, we first need to define a function `simulate` for a single element of the batch.\n",
"\n",
"**Note:** [`step_over_horizon`](https://github.com/ami-iit/jaxsim/blob/427b1e646297495f6b33e4c0bb2273ca89bd5ae2/src/jaxsim/simulation/simulator.py#L432C1-L529C10) is useful only in open-loop simulations and where the horizon is known in advance. Please checkout [`step`](https://github.com/ami-iit/jaxsim/blob/427b1e646297495f6b33e4c0bb2273ca89bd5ae2/src/jaxsim/simulation/simulator.py#L384C10-L425) for closed-loop simulations."
"In order to parallelize the simulation, we first need to define a function `simulate` for a single element of the batch."
]
},
{
Expand All @@ -173,35 +193,27 @@
"metadata": {},
"outputs": [],
"source": [
"from jaxsim.simulation import simulator_callbacks\n",
"\n",
"\n",
"# Create a logger to store simulation data\n",
"@jax_dataclasses.pytree_dataclass\n",
"class SimulatorLogger(simulator_callbacks.PostStepCallback):\n",
" def post_step(\n",
" self, sim: JaxSim, step_data: Dict[str, StepData]\n",
" ) -> Tuple[JaxSim, jtp.PyTree]:\n",
" \"\"\"Return the StepData object of each simulated model\"\"\"\n",
" return sim, step_data\n",
"\n",
"\n",
"# Define a function to simulate a single model instance\n",
"def simulate(sim: JaxSim, pose) -> JaxSim:\n",
" model.zero()\n",
" model.reset_base_position(position=jnp.array(pose))\n",
"\n",
" with sim.editable(validate=True) as sim:\n",
" m = sim.get_model(model.name())\n",
" m.data = model.data\n",
"\n",
" sim, (cb, (_, step_data)) = simulator.step_over_horizon(\n",
" horizon_steps=integration_time // step_size,\n",
" callback_handler=SimulatorLogger(),\n",
" clear_inputs=True,\n",
" )\n",
"\n",
" return step_data"
"def simulate(\n",
" data: js.data.JaxSimModelData, integrator_state: dict, pose: jnp.array\n",
") -> tuple:\n",
"\n",
" data = data.reset_base_position(base_position=pose)\n",
" x_t_i = []\n",
"\n",
" for _ in range(integration_time):\n",
" data, integrator_state = js.model.step(\n",
" dt=dt,\n",
" model=model,\n",
" data=data,\n",
" integrator=integrator,\n",
" integrator_state=integrator_state,\n",
" joint_forces=None,\n",
" link_forces=None,\n",
" )\n",
" x_t_i.append(data.base_position())\n",
"\n",
" return x_t_i"
]
},
{
Expand All @@ -213,7 +225,7 @@
"\n",
"Note that in our case we are vectorizing over the `pose` argument of the function `simulate`, this correspond to the value assigned to the `in_axes` parameter of `jax.vmap`:\n",
"\n",
"`in_axes=(None, 0)` means that the first argument of `simulate` is not vectorized, while the second argument is vectorized over the zero-th dimension."
"`in_axes=(None, None, 0)` means that the first two arguments of `simulate` are not vectorized, while the third argument is vectorized over the zero-th dimension."
]
},
{
Expand All @@ -223,20 +235,20 @@
"outputs": [],
"source": [
"# Define a function to simulate multiple model instances\n",
"simulate_vectorized = jax.vmap(simulate, in_axes=(None, 0))\n",
"simulate_vectorized = jax.vmap(simulate, in_axes=(None, None, 0))\n",
"\n",
"# Run and time the simulation\n",
"now = time.perf_counter()\n",
"\n",
"time_history = simulate_vectorized(simulator, poses[:, 0])\n",
"x_t = simulate_vectorized(data, integrator_state, poses[:, 0])\n",
"\n",
"comp_time = time.perf_counter() - now\n",
"\n",
"logging.info(\n",
" f\"Running simulation with {envs_per_row**2} models took {comp_time} seconds.\"\n",
")\n",
"logging.info(\n",
" f\"This corresponds to an RTF (Real Time Factor) of {envs_per_row**2 *integration_time/comp_time}\"\n",
" f\"This corresponds to an RTF (Real Time Factor) of {(envs_per_row**2 *integration_time/comp_time):.2f}\"\n",
")"
]
},
Expand All @@ -253,13 +265,10 @@
"metadata": {},
"outputs": [],
"source": [
"time_history: Dict[str, StepData]\n",
"x_t = time_history[model.name()].tf_model_state\n",
"\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"plt.plot(time_history[model.name()].tf[0], x_t.base_position[:, :, 2].T)\n",
"plt.plot(np.arange(len(x_t)) * dt, np.array(x_t)[:, :, 2])\n",
"plt.grid(True)\n",
"plt.xlabel(\"Time [s]\")\n",
"plt.ylabel(\"Height [m]\")\n",
Expand Down Expand Up @@ -297,7 +306,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.1"
"version": "3.11.8"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 9a0e49f

Please sign in to comment.