Performing a simulation enforcing JAX to run in 32-bit precision overflows after 4.29 seconds #136
xref ami-iit/bipedal-locomotion-framework#630. Note that JaxSim uses nanoseconds since the beginning (#1). The problem here is not handling time as nanoseconds, but handling time as nanoseconds stored in 32-bit types.
The Python Standard Library has a
A MWE of a
I was thinking of using something that leverages `@jax_dataclasses.pytree_dataclass`:

```python
from functools import total_ordering

import jax.numpy as jnp


@total_ordering
class Time:
    nanoseconds: jnp.int64

    def __init__(self, nanoseconds):
        self.nanoseconds = jnp.int64(nanoseconds)

    def __add__(self, other):
        return Time(self.nanoseconds + other.nanoseconds)

    def __sub__(self, other):
        return Time(self.nanoseconds - other.nanoseconds)

    def __mod__(self, other):
        return Time(self.nanoseconds % other.nanoseconds)

    def __eq__(self, other):
        return self.nanoseconds == other.nanoseconds

    def __lt__(self, other):
        return self.nanoseconds < other.nanoseconds

    def advance(self, nanoseconds):
        return Time(self.nanoseconds + jnp.int64(nanoseconds))

    def __repr__(self):
        return f"Time({self.nanoseconds} ns)"
```
This is a good idea, I'll have a look. However, the main difficulty currently is properly handling the overflow when JAX runs in 32-bit precision. I have a kind-of-working local implementation, started probably 2 months ago, in which seconds and nanoseconds are treated separately. Then the priorities shifted towards contact-related improvements; I'll get back to this as soon as I find some spare time. Not sure if we have use cases on TPUs in the near future.
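The split seconds/nanoseconds representation mentioned above can be sketched as follows. This is a minimal sketch with plain Python integers and a hypothetical `advance` helper; in JaxSim the two fields would be 32-bit JAX arrays.

```python
NS_PER_S = 1_000_000_000


def advance(seconds: int, nanoseconds: int, dt_ns: int) -> tuple[int, int]:
    """Advance a (seconds, nanoseconds) clock by dt_ns nanoseconds.

    Keeping the two fields separate means each stays well within int32
    range: `seconds` overflows only after ~68 simulated years, and
    `nanoseconds` is always below 1e9 (provided dt_ns itself fits in
    int32, i.e. steps shorter than ~1 s, which holds for typical dt).
    """
    total_ns = nanoseconds + dt_ns
    return seconds + total_ns // NS_PER_S, total_ns % NS_PER_S
```

For example, `advance(4, 999_999_999, 1)` carries over into `(5, 0)`, instead of overflowing a single 32-bit nanosecond counter.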
In these days I realized that any option that replaces a single float with a class to handle the simulation time would introduce major challenges in applications that need to compute gradients with AD against |
JaxSim, by default, forces JAX to run in 64-bit precision. This is necessary to execute our RBDAs and compare their results to third-party C++ libraries.
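For reference, JAX defaults to 32-bit precision; enabling 64-bit precision is done through the `jax_enable_x64` configuration flag (a configuration fragment, assuming JAX is installed):

```python
import jax

# Enable 64-bit precision so that RBDA results can be compared against
# 64-bit third-party C++ reference implementations.
jax.config.update("jax_enable_x64", True)
```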
As explained in #109, executing JaxSim with 32-bit precision might result in a significant speedup, especially on hardware accelerators like GPUs. Furthermore, TPUs only support 32-bit precision (google/jax#9862), therefore we need to ensure that JaxSim runs with this precision if we want to exploit this hardware.
JaxSim is almost 32-bit compatible; the only caveats originate from how we handle the simulation time:
jaxsim/src/jaxsim/api/data.py
Lines 42 to 44 in 10c7199
When running in 32-bit precision, `time_ns` is downcast to `jnp.uint32`. Unfortunately, the range of this data type is problematic when it stores nanoseconds: the maximum time we can simulate without overflowing is approximately 4.29 seconds. When running a simulation longer than that, the simulator hangs, requiring Ctrl+C to be pressed.
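The 4.29-second limit follows directly from the `uint32` range. A small demonstration using NumPy, whose fixed-width integers wrap around just like JAX's 32-bit types:

```python
import numpy as np

# A uint32 nanosecond counter can represent at most 2**32 - 1 ns.
max_ns = np.iinfo(np.uint32).max
print(max_ns / 1e9)  # ≈ 4.294967295 seconds

# One more nanosecond silently wraps the counter back to zero.
t = np.array(max_ns, dtype=np.uint32)
print(t + np.array(1, dtype=np.uint32))  # 0
```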
Something else to notice is that a time in seconds represented as a 32-bit float has limited resolution: a float32 has a 24-bit significand, so near t = 1 s consecutive representable values are about 119 ns apart, and the spacing doubles every time t doubles.
Therefore, in this case we cannot achieve a granularity of 1 nanosecond (going below that is pointless for rigid body simulations anyway; 10-100 ns is already small enough).
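The resolution of a float32 time expressed in seconds can be inspected with `np.spacing`, which returns the distance to the next representable value (NumPy used here as a stand-in for JAX's float32 arithmetic):

```python
import numpy as np

# Time resolution (in seconds) of a float32 clock around each instant t:
# the gap between t and the next representable float32 value.
for t in [1.0, 4.29, 60.0, 3600.0]:
    print(f"t = {t:7.2f} s -> resolution ≈ {np.spacing(np.float32(t)):.3g} s")
```

At 1 s the resolution is about 119 ns; after one simulated hour it has already degraded to roughly a quarter of a millisecond.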
We should investigate alternative representations of the simulation time.
As a consequence, all functions accepting the current time should be ported to the new representation, which will be something more complex than a plain `jax.Array`:

- `data.JaxSimModelData.time`
- `ode.wrap_system_dynamics_for_integration`
- `integrators.common.Integrator.step|init|__call__`
- `integrators.variable_step.EmbeddedRungeKutta.__call__` (the adaptive integrator)
- `model.step`