Allow downstream projects to disable the enforcement of 64 bit precision #109

Closed
diegoferigo opened this issue Mar 13, 2024 · 1 comment

@diegoferigo

So far, in JaxSim we have always aimed for results comparable with those computed by third-party multibody dynamics libraries like robotology/idyntree. For this reason, we force JAX to run with 64 bit precision.

def _jnp_options() -> None:
    import os

    import jax

    # Enable by default
    if not ("JAX_ENABLE_X64" in os.environ and os.environ["JAX_ENABLE_X64"] == "0"):
        logging.info("Enabling JAX to use 64bit precision")
        jax.config.update("jax_enable_x64", True)

        import jax.numpy as jnp
        import numpy as np

        if jnp.empty(0, dtype=float).dtype != jnp.empty(0, dtype=np.float64).dtype:
            logging.warning("Failed to enable 64bit precision in JAX")

This is necessary to compute quantities accurately enough to compare them, within small tolerances, against third-party libraries (check #106 for more details).

However, especially on GPUs, using 64 bit precision may lead to excessive memory usage and/or could significantly decrease runtime performance. Therefore, we should:

  • Assess whether simulating with 32 bit precision maintains the same stability as using 64 bit.
  • While maintaining 64 bit precision as the default choice, allow downstream projects to disable the 64 bit enforcement.
@diegoferigo

While maintaining 64 bit precision as default choice, allow downstream projects to disable the 64 bit enforcement.

I just realized that with the current logic, 32 bit precision can already be enforced downstream by exporting JAX_ENABLE_X64=0. If this environment variable is set to 0, JaxSim does not configure JAX to run with 64 bit precision when it is first imported in an application.
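
For reference, here is a minimal sketch of how a downstream project could opt out from Python rather than from the shell. The helper `x64_requested` is hypothetical (it is not part of JaxSim) and only mirrors the condition in the snippet quoted above; the variable has to be set before JaxSim is imported for the first time.

```python
import os
from typing import Mapping


def x64_requested(environ: Mapping[str, str]) -> bool:
    """Hypothetical helper mirroring the condition of the quoted snippet.

    Returns True unless JAX_ENABLE_X64 is present and equal to "0".
    """
    return not ("JAX_ENABLE_X64" in environ and environ["JAX_ENABLE_X64"] == "0")


# Downstream opt-out: must happen before the first `import jaxsim`.
os.environ["JAX_ENABLE_X64"] = "0"

# import jaxsim  # with the variable set to "0", the 64 bit update is skipped

print(x64_requested(os.environ))  # False: 64 bit would not be enforced
```

Note that, according to the quoted condition, only the exact string "0" opts out: an unset variable, or a value such as "1" or "false", still results in 64 bit precision being enabled.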

Assess whether simulating with 32 bit precision maintains the same stability as using 64 bit.

As expected, the default pytest tolerances are too strict when comparing quantities computed with iDynTree against their JaxSim equivalents computed in 32 bit, so our test suite fails in this case. This doesn't mean that code running with 32 bit precision is buggy; it only means that, as expected, we lose accuracy. Many applications, especially highly parallel ones running on GPUs and TPUs, would value much lower memory usage over simulation fidelity (which already relies on many approximations of real-world physics).
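
To illustrate why strict tolerances fail in 32 bit, here is a small sketch that uses only the Python standard library, so it runs without JAX or iDynTree. It rounds a 64 bit value to 32 bit precision and compares the two under a strict and a relaxed relative tolerance. Both tolerance values are illustrative assumptions, not the ones used in the JaxSim test suite.

```python
import math
import struct


def round_to_float32(x: float) -> float:
    """Round a Python float (64 bit) to the nearest 32 bit float value."""
    return struct.unpack("f", struct.pack("f", x))[0]


value64 = 0.1
value32 = round_to_float32(value64)

# Relative error introduced by storing the value in 32 bit.
rel_error = abs(value32 - value64) / abs(value64)

# Illustrative tolerances (assumptions, not JaxSim's actual settings).
STRICT_RTOL = 1e-12
RELAXED_RTOL = 1e-6

passes_strict = math.isclose(value32, value64, rel_tol=STRICT_RTOL)
passes_relaxed = math.isclose(value32, value64, rel_tol=RELAXED_RTOL)

print(f"relative error: {rel_error:.3e}")
print(f"passes strict tolerance:  {passes_strict}")
print(f"passes relaxed tolerance: {passes_relaxed}")
```

A single rounding already produces a relative error on the order of 1e-8, and errors accumulate through a chain of dynamics computations, so a 32 bit comparison would likely need looser tolerances rather than a code fix.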


This being said, I think we can close this issue.
