From 0d597728e29c26fb47a7c1eca2ed570e500389ed Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 20 Oct 2023 14:03:32 +0200 Subject: [PATCH 01/11] Make all fields of simulator.JaxSim static, excluding SimulatorData Fixes trace leaks when jitting vectorized simulations --- src/jaxsim/simulation/simulator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jaxsim/simulation/simulator.py b/src/jaxsim/simulation/simulator.py index 6b7e9b77c..fa96c3650 100644 --- a/src/jaxsim/simulation/simulator.py +++ b/src/jaxsim/simulation/simulator.py @@ -63,7 +63,7 @@ class JaxSim(Vmappable): """The JaxSim simulator.""" # Step size stored in ns in order to prevent floats approximation - step_size_ns: jtp.Int = dataclasses.field( + step_size_ns: Static[jtp.Int] = dataclasses.field( default_factory=lambda: jnp.array(1_000_000, dtype=jnp.uint64) ) From 21d9b9c5e3bbafa184ad115bd30786c1cf56cd92 Mon Sep 17 00:00:00 2001 From: Daniele Pucci Date: Tue, 31 Oct 2023 18:04:01 +0100 Subject: [PATCH 02/11] Added Filippo Ferretti as maintainer As per the title --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 9bcfc3c6e..248307071 100644 --- a/README.md +++ b/README.md @@ -90,8 +90,10 @@ For major changes, please open an issue first to discuss what you would like to | [][df] | [@diegoferigo][df] | |:---------------------------------------------------------------:|:------------------:| +| [][ff] | [@flferretti][ff] | [df]: https://github.com/diegoferigo +[ff]: https://github.com/flferretti ## License From 0f40789dca50ae6e5d3a3d210e2c45ddcdd47ac6 Mon Sep 17 00:00:00 2001 From: Diego Ferigo Date: Thu, 9 Nov 2023 09:18:57 +0100 Subject: [PATCH 03/11] Update README.md --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 248307071..a2a033eab 100644 --- a/README.md +++ b/README.md @@ -86,11 +86,11 @@ For major changes, please open an issue first to discuss what you would like to } ``` -## Maintainers +## People -| [][df] | [@diegoferigo][df] | -|:---------------------------------------------------------------:|:------------------:| -| [][ff] | [@flferretti][ff] | +| Author | Maintainers | +|:------:|:-----------:| +| [][df] | [][ff] [][df] | [df]: https://github.com/diegoferigo [ff]: https://github.com/flferretti From 1e0fafafccf39a8f45106a6e9666f8b21513f72b Mon Sep 17 00:00:00 2001 From: Diego Ferigo Date: Thu, 9 Nov 2023 10:18:18 +0100 Subject: [PATCH 04/11] Update README.md --- README.md | 56 ++++++++++++++++++++++++++++--------------------------- 1 file changed, 29 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index a2a033eab..1415d041f 100644 --- a/README.md +++ b/README.md @@ -1,33 +1,33 @@ # JAXsim -**A scalable physics engine implemented with JAX. With JIT batteries 🔋** +**A scalable physics engine and multibody dynamics library implemented with JAX. With JIT batteries 🔋** -⚠ This project is still experimental, APIs could change without notice. ️⚠ +> [!WARNING] +> This project is still experimental, APIs could change without notice. -⚠ This simulator currently focuses on locomotion applications. Only contacts with ground are supported. ️⚠ +> [!NOTE] +> This simulator currently focuses on locomotion applications. Only contacts with ground are supported. ## Features - Physics engine in reduced coordinates implemented with [JAX][jax] in Python. -- Supported JIT compilation of Python code for increased performance. -- Transparent support to execute the simulation on CPUs, GPUs, and TPUs. -- Possibility to run parallel multi-body simulations on hardware accelerators for significantly increased throughput. -- Support of SDF models (and, upon conversion, URDF models). +- JIT compilation of Python code for increased performance. +- Transparent support to execute logic on CPUs, GPUs, and TPUs. +- Parallel multi-body simulations on hardware accelerators for significantly increased throughput. +- Support for SDF models (and, upon conversion, URDF models). - Collision detection between bodies and uneven ground surface. -- Continuous soft contacts model with no friction cone approximations. -- Full support of inertial properties of bodies. +- Soft contacts model supporting full friction cone and sticking / slipping transition. +- Complete support for inertial properties of rigid bodies. - Revolute, prismatic, and fixed joints support. - Integrators: forward Euler, semi-implicit Euler, Runge-Kutta 4. +- High-level classes for object-oriented programming. - High-level classes to compute multi-body dynamics quantities from simulation state. -- High-level classes supporting both object-oriented and functional programming. -- Optional validation of JAX pytrees to prevent JIT re-compilation. - -Planned features: - -- Reinforcement Learning module developed in JAX. -- Finalization of differentiable physics through the simulation. +- High-level classes wrapping the low-level functional RBDAs with support of [multiple velocities representations][notation]. +- Default validation of JAX pytrees to prevent JIT re-compilations. +- Preliminary support for automatic differentiation of RBDAs. [jax]: https://github.com/google/jax/ +[notation]: https://research.tue.nl/en/publications/multibody-dynamics-notation-version-2 ## Installation @@ -37,10 +37,10 @@ You can install the project with [`pypa/pip`][pip], preferably in a [virtual env pip install jaxsim ``` -Have a look to [`setup.cfg`](setup.cfg) for a complete list of optional dependencies. -You can install all of them by specifying `jaxsim[all]`. +Check [`setup.cfg`](setup.cfg) for the complete list of optional dependencies. +Install all of them with `jaxsim[all]`. -**Note:** if you need GPU support, please follow the official [installation instruction][jax_gpu] of JAX. +**Note:** For GPU support, follow the official [installation instruction][jax_gpu] of JAX. [pip]: https://github.com/pypa/pip/ [venv]: https://docs.python.org/3.8/tutorial/venv.html @@ -49,25 +49,27 @@ You can install all of them by specifying `jaxsim[all]`. ## Credits The physics module of JAXsim is based on the theory of the [Rigid Body Dynamics Algorithms][RBDA] -book authored by Roy Featherstone. +book by Roy Featherstone. We structured part of our logic following its accompanying [code][spatial_v2]. The physics engine is developed entirely in Python using [JAX][jax]. [RBDA]: https://link.springer.com/book/10.1007/978-1-4899-7560-7 [spatial_v2]: http://royfeatherstone.org/spatial/index.html#spatial-software -The inspiration of developing JAXsim stems from [`google/brax`][brax]. +The inspiration for developing JAXsim originally stemmed from early versions of [`google/brax`][brax]. Here below we summarize the differences between the projects: -- JAXsim simulates multibody dynamics in reduced coordinates, while `brax` uses maximal coordinates. -- The rigid body algorithms used in JAXsim allow to efficiently compute quantities based on the Euler-Poincarè +- JAXsim simulates multibody dynamics in reduced coordinates, while brax v1 uses maximal coordinates. +- The new v2 APIs of brax (and the new [MJX][mjx]) were then implemented in reduced coordinates, following an approach comparable to JAXsim, with major differences in contact handling. +- The rigid-body algorithms used in JAXsim allow to efficiently compute quantities based on the Euler-Poincarè formulation of the equations of motion, necessary for model-based robotics research. -- JAXsim supports SDF (and, indirectly, URDF) models, under the assumption that the model is described with the +- JAXsim supports SDF (and, indirectly, URDF) models, assuming the model is described with the recent [Pose Frame Semantics][PFS]. -- Contrarily to `brax`, JAXsim only supports collision detection between bodies and a compliant ground surface. -- While supported thanks to the usage of JAX, differentiating through the simulator has not yet been studied. +- Contrarily to brax, JAXsim only supports collision detection between bodies and a compliant ground surface. +- The RBDAs of JAXsim support automatic differentiation, but this functionality has not being thoroughly tested. [brax]: https://github.com/google/brax +[mjx]: https://mujoco.readthedocs.io/en/3.0.0/mjx.html [PFS]: http://sdformat.org/tutorials?tut=pose_frame_semantics ## Contributing @@ -80,7 +82,7 @@ For major changes, please open an issue first to discuss what you would like to ```bibtex @software{ferigo_jaxsim_2022, author = {Diego Ferigo and Silvio Traversaro and Daniele Pucci}, - title = {{JAXsim}: A Physics Engine in Reduced Coordinates for Control and Robot Learning}, + title = {{JAXsim}: A Physics Engine in Reduced Coordinates and Multibody Dynamics Library for Control and Robot Learning}, url = {http://github.com/ami-iit/jaxsin}, year = {2022}, } From ae0b9ae0cf1f8c73c13db5fb434e3260feba1c2f Mon Sep 17 00:00:00 2001 From: Diego Ferigo Date: Thu, 9 Nov 2023 10:28:28 +0100 Subject: [PATCH 05/11] Update README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 1415d041f..a8ff58edd 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ Install all of them with `jaxsim[all]`. **Note:** For GPU support, follow the official [installation instruction][jax_gpu] of JAX. [pip]: https://github.com/pypa/pip/ -[venv]: https://docs.python.org/3.8/tutorial/venv.html +[venv]: https://docs.python.org/3/tutorial/venv.html [jax_gpu]: https://github.com/google/jax/#installation ## Credits @@ -83,7 +83,7 @@ For major changes, please open an issue first to discuss what you would like to @software{ferigo_jaxsim_2022, author = {Diego Ferigo and Silvio Traversaro and Daniele Pucci}, title = {{JAXsim}: A Physics Engine in Reduced Coordinates and Multibody Dynamics Library for Control and Robot Learning}, - url = {http://github.com/ami-iit/jaxsin}, + url = {http://github.com/ami-iit/jaxsim}, year = {2022}, } ``` From 1403014ed4d6a0f0129507aa9187d488a09c1f5d Mon Sep 17 00:00:00 2001 From: Diego Ferigo Date: Thu, 9 Nov 2023 10:40:23 +0100 Subject: [PATCH 06/11] Fix CI/CD --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 398189d00..76f9511d0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -53,7 +53,7 @@ package_dir = python_requires = >=3.10 install_requires = coloredlogs - jax >= 0.4.1 + jax >= 0.4.1, <0.4.11 jaxlib jaxlie jax_dataclasses >= 1.4.0 From 3b995fbb6b1829190f4f68b2db758ee0ea2551a0 Mon Sep 17 00:00:00 2001 From: Diego Ferigo Date: Thu, 9 Nov 2023 10:45:33 +0100 Subject: [PATCH 07/11] Pin also jaxlib --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 76f9511d0..9023b2659 100644 --- a/setup.cfg +++ b/setup.cfg @@ -54,7 +54,7 @@ python_requires = >=3.10 install_requires = coloredlogs jax >= 0.4.1, <0.4.11 - jaxlib + jaxlib < 0.4.11 jaxlie jax_dataclasses >= 1.4.0 pptree From 0dba223c16abf00627f03b2561aac47fa7ba49d0 Mon Sep 17 00:00:00 2001 From: Diego Ferigo Date: Thu, 9 Nov 2023 10:51:01 +0100 Subject: [PATCH 08/11] Pin ml-dytypes The pin of this transitional dependency can be removed as soon as we support newer jax versions --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index 9023b2659..d3aa7c991 100644 --- a/setup.cfg +++ b/setup.cfg @@ -57,6 +57,7 @@ install_requires = jaxlib < 0.4.11 jaxlie jax_dataclasses >= 1.4.0 + ml-dtypes < 0.3.0 pptree rod scipy From 458eeac04125475aba576d8df88d21fd0e1686fe Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Wed, 6 Dec 2023 05:12:27 +0100 Subject: [PATCH 09/11] Remove pinnings necessary on old jax versions --- setup.cfg | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/setup.cfg b/setup.cfg index ef0fa8207..b6eed46e7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -53,11 +53,10 @@ package_dir = python_requires = >=3.10 install_requires = coloredlogs - jax >= 0.4.1, <0.4.11 - jaxlib < 0.4.11 + jax >= 0.4.1 + jaxlib jaxlie jax_dataclasses >= 1.4.0 - ml-dtypes < 0.3.0 pptree rod typing_extensions; python_version < "3.11" From 50aea065d4ed3825b8022b7d622d7f7762270e2a Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Wed, 6 Dec 2023 05:18:25 +0100 Subject: [PATCH 10/11] Update bool typing --- src/jaxsim/typing.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/jaxsim/typing.py b/src/jaxsim/typing.py index 82355cf33..94b9508be 100644 --- a/src/jaxsim/typing.py +++ b/src/jaxsim/typing.py @@ -1,7 +1,6 @@ from typing import Any, Dict, Hashable, List, NamedTuple, Tuple, Union import jax.numpy as jnp -import numpy as np import numpy.typing as npt # JAX types @@ -35,6 +34,6 @@ Tensor = Union[npt.NDArray, ArrayJax] Vector = Array Matrix = Array -Bool = bool +Bool = Union[bool, ArrayJax] Int = Union[int, IntJax] Float = Union[float, FloatJax] From b18b930f0e2fc13a128a558e91c03c5234d43950 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Wed, 6 Dec 2023 05:19:22 +0100 Subject: [PATCH 11/11] Update typing of random key --- src/jaxsim/high_level/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jaxsim/high_level/model.py b/src/jaxsim/high_level/model.py index b7131da2d..db64e4294 100644 --- a/src/jaxsim/high_level/model.py +++ b/src/jaxsim/high_level/model.py @@ -491,7 +491,7 @@ def joint_positions(self, joint_names: tuple[str, ...] = None) -> jtp.Vector: @functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"]) def joint_random_positions( - self, joint_names: tuple[str, ...] = None, key: jax.random.PRNGKeyArray = None + self, joint_names: tuple[str, ...] = None, key: jax.Array = None ) -> jtp.Vector: """"""