So far, in JaxSim we have always aimed to obtain results comparable with those computed by third-party multibody dynamics libraries like robotology/idyntree. For this reason, we force JAX to run with 64-bit precision.
Excerpt from jaxsim/src/jaxsim/__init__.py (lines 6 to 20 at commit 4fd2032):
logging.warning("Failed to enable 64bit precision in JAX")
This is necessary to compute quantities accurately enough to compare them against third-party libraries with small tolerances (check #106 for more details).
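The import-time behavior discussed in this issue (enable 64-bit precision unless the downstream project opted out with JAX_ENABLE_X64=0, and log a warning if enabling fails) can be sketched as follows. This is a minimal sketch under assumptions, not the actual content of jaxsim/src/jaxsim/__init__.py: the helper name `enable_x64_unless_opted_out` is hypothetical, and treating only the exact value "0" as an opt-out is an assumption based on the comment below.

```python
import logging
import os

import jax
import jax.numpy as jnp


def enable_x64_unless_opted_out() -> bool:
    """Hypothetical sketch of an import-time precision setup.

    Enables 64-bit precision in JAX unless the user exported
    JAX_ENABLE_X64=0. Returns True if JAX now defaults to float64.
    """
    if os.environ.get("JAX_ENABLE_X64") == "0":
        # Downstream opt-out (assumed semantics): leave JAX in 32-bit mode.
        return bool(jnp.array(1.0).dtype == jnp.float64)
    try:
        jax.config.update("jax_enable_x64", True)
    except Exception:
        logging.warning("Failed to enable 64bit precision in JAX")
    # The default floating-point dtype reflects the active precision mode.
    return bool(jnp.array(1.0).dtype == jnp.float64)
```

Checking the default dtype of a freshly created array, rather than trusting the config call, makes the return value reflect what JAX actually does afterwards.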
However, especially on GPUs, 64-bit precision may lead to excessive memory usage and could significantly decrease runtime performance. Therefore, we should:
Assess whether simulating with 32-bit precision is as stable as simulating with 64-bit precision.
While keeping 64-bit precision as the default choice, allow downstream projects to disable the 64-bit enforcement.
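To put the memory concern in numbers: a float64 value takes 8 bytes and a float32 value takes 4, so a batch of simulation states stored in 64-bit precision needs twice the memory. The sketch below uses NumPy only for the arithmetic; the helper `batch_state_bytes`, the batch size, and the state dimension are hypothetical and are not JaxSim measurements.

```python
import numpy as np


def batch_state_bytes(n_envs: int, state_dim: int, dtype) -> int:
    """Bytes needed to store n_envs state vectors of length state_dim."""
    return n_envs * state_dim * np.dtype(dtype).itemsize


# Hypothetical workload: 4096 parallel environments, 64 state entries each.
n_envs, state_dim = 4096, 64
bytes_f64 = batch_state_bytes(n_envs, state_dim, np.float64)
bytes_f32 = batch_state_bytes(n_envs, state_dim, np.float32)
print(f"float64: {bytes_f64 / 2**20:.1f} MiB, float32: {bytes_f32 / 2**20:.1f} MiB")
```

This only counts the storage of the states themselves; intermediate buffers allocated during a simulation step, and any runtime impact, are not captured by this estimate.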
Regarding allowing downstream projects to disable the 64-bit enforcement:
I just realized that, with the current logic, 32-bit precision can already be enforced downstream by exporting JAX_ENABLE_X64=0. If this environment variable is set to 0, JaxSim does not configure JAX to run with 64-bit precision when it is first imported in an application.
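A downstream application can export the variable in the shell (for example, `JAX_ENABLE_X64=0 python my_app.py`, where `my_app.py` is a placeholder) or set it from Python before JAX, and therefore JaxSim, is imported for the first time, since JAX reads JAX_ENABLE_X64 at import time. A minimal sketch of the second option; the `import jaxsim` line is left as a comment so the snippet runs without JaxSim installed.

```python
import os

# Must happen before the first `import jax` anywhere in the process,
# because JAX reads JAX_ENABLE_X64 when it is imported.
os.environ["JAX_ENABLE_X64"] = "0"

import jax.numpy as jnp  # noqa: E402

# import jaxsim  # JaxSim would now skip enabling 64-bit precision.

x = jnp.ones(3)
print(x.dtype)  # float32 in 32-bit mode
```

Setting the variable after JAX has already been imported has no effect on the default precision, so the shell export is the more robust choice for applications with several entry points.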
Regarding the stability assessment of 32-bit precision:
As expected, the default pytest tolerances are too strict when comparing quantities computed with iDynTree against their equivalents computed with JaxSim in 32-bit precision, and our test suite fails in this case. This doesn't mean that running code with 32-bit precision is buggy; it only means that, as expected, we lose accuracy. Many applications, especially highly parallel ones running on GPUs and TPUs, value much lower memory usage more than simulation fidelity, which already relies on many approximations of real-world physics.
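One way to keep comparisons against iDynTree meaningful in both modes is to derive the tolerances from the machine epsilon of the dtype under test instead of using fixed values. The sketch below is an assumption, not the JaxSim test suite: the helpers `tolerances_for` and `assert_close_for_dtype` are hypothetical, and the safety factor of 100 on the epsilon is arbitrary.

```python
import numpy as np


def tolerances_for(dtype, factor: float = 100.0) -> tuple[float, float]:
    """Hypothetical heuristic: scale rtol and atol with the dtype's machine epsilon."""
    eps = float(np.finfo(dtype).eps)
    return factor * eps, factor * eps


def assert_close_for_dtype(actual, desired, dtype) -> None:
    """Compare actual against desired using dtype-dependent tolerances."""
    rtol, atol = tolerances_for(dtype)
    np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol)


# Reference in float64 (e.g. from a third-party library) and the same
# quantity rounded to float32, as a 32-bit simulation would store it.
reference = np.array([0.1, 0.2, 0.3], dtype=np.float64)
computed_f32 = reference.astype(np.float32)

# Passes: float32 rounding error is well below 100 * eps(float32).
assert_close_for_dtype(computed_f32, reference, np.float32)

# Fails: the same rounding error is far above 100 * eps(float64).
try:
    assert_close_for_dtype(computed_f32, reference, np.float64)
except AssertionError:
    print("float64 tolerances are too strict for float32 results")
```

Real quantities such as mass matrices or contact forces accumulate more error than a single rounding, so the factor would likely need tuning per quantity.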