
Performing a simulation enforcing JAX to run in 32-bit precision overflows after 4.29 seconds #136

Open
diegoferigo opened this issue Apr 8, 2024 · 6 comments

@diegoferigo (Member) commented Apr 8, 2024

JaxSim, by default, enforces JAX to run in 64-bit precision. This is necessary to execute our RBDAs and compare their results against those of third-party C++ libraries.

As explained in #109, executing JaxSim with 32-bit precision might result in a significant speedup, especially on hardware accelerators like GPUs. Furthermore, TPUs only support 32-bit precision (google/jax#9862), therefore we need to ensure that JaxSim runs with this precision if we want to exploit this hardware.

JaxSim is almost fully 32-bit compatible; the only caveat originates from how we handle the simulation time:

time_ns: jtp.Int = dataclasses.field(
default_factory=lambda: jnp.array(0, dtype=jnp.uint64)
)

When running in 32-bit precision, time_ns is downcast to jnp.uint32. Unfortunately, the range of this data type is problematic when it describes nanoseconds. In fact:

In [1]: jnp.iinfo(jnp.uint32)
Out[1]: iinfo(min=0, max=4294967295, dtype=uint32)

The maximum time we can simulate without overflowing is approximately 4.29 seconds. When running a simulation longer than that, the simulator hangs, requiring the user to press Ctrl+C.
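
As a quick sanity check, the downcast and the resulting wrap-around can be reproduced with the following standalone toy snippet (not JaxSim code):

import jax

# 64-bit types are only available when x64 is explicitly enabled.
jax.config.update("jax_enable_x64", False)

import jax.numpy as jnp

t = jnp.array(0, dtype=jnp.uint64)
print(t.dtype)  # uint32: the uint64 request is silently downgraded

# Advancing a time close to the uint32 limit wraps around instead of growing.
t = jnp.array(4_294_000_000, dtype=jnp.uint64)  # ~4.294 s expressed in ns
print(t + 1_000_000)  # 32704, i.e. the counter overflowed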

Something else to note is that a time in seconds represented as a 32-bit float has the following limits:

In [2]: jnp.finfo(jnp.float32)
Out[2]: finfo(resolution=1e-06, min=-3.4028235e+38, max=3.4028235e+38, dtype=float32)

Therefore, in this case we cannot achieve granularity down to 1 nanosecond (going below that is pointless for rigid-body simulations anyway; 10-100 ns is already small enough).
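
For instance, the spacing between adjacent float32 values already exceeds 100 ns around t = 1 s and keeps growing with the simulated time (shown here with NumPy, whose float32 matches JAX's):

import numpy as np

# Distance to the next representable float32 value around a given time [s].
print(np.spacing(np.float32(1.0)))     # ~1.2e-07 s  (~119 ns)
print(np.spacing(np.float32(100.0)))   # ~7.6e-06 s  (~7.6 µs)
print(np.spacing(np.float32(3600.0)))  # ~2.4e-04 s  (~244 µs)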


We should investigate alternative representations of the simulation time.

As a consequence, all functions accepting the current time should be ported to this new representation, which will be something more complex than a plain jax.Array.

@diegoferigo (Member Author)

xref ami-iit/bipedal-locomotion-framework#630. Note that JaxSim has used nanoseconds since the beginning (#1). The problem here is not handling time as nanoseconds, but handling time as nanoseconds stored in 32-bit types.

diegoferigo self-assigned this Apr 8, 2024

@diegoferigo (Member Author)

The Python Standard Library has a decimal module that is pretty interesting for this kind of problem. Unfortunately, there is nothing comparable for either NumPy or JAX. In order to be included in JaxSim and support jax.jit, any solution needs to manipulate PyTree objects.

@diegoferigo (Member Author) commented Apr 8, 2024

A MWE of a Time class should at least have the following features:

  • Implemented as a PyTree (therefore, supporting jax.jit).
  • Have a resolution of 1-10 nanoseconds, and at least 1 hour of simulation time.
  • Support the >, <, == operators.
  • Support the % operator.
  • Support being advanced/stepped by a $\Delta t$ in a functional way.

@flferretti (Collaborator)

I was thinking of using something that leverages functools.total_ordering for the comparison operators. We could also use jax_dataclasses.pytree_dataclass to maintain immutability as follows:

import functools

import jax
import jax.numpy as jnp
import jax_dataclasses


@jax_dataclasses.pytree_dataclass
@functools.total_ordering
class Time:
    # Stored as a JAX array. Note that jnp.int64 silently becomes int32
    # when 64-bit precision is disabled.
    nanoseconds: jax.Array

    # pytree dataclasses are frozen, so instead of a custom __init__ we
    # convert the input in a dedicated builder.
    @staticmethod
    def build(nanoseconds) -> "Time":
        return Time(nanoseconds=jnp.array(nanoseconds, dtype=jnp.int64))

    def __add__(self, other: "Time") -> "Time":
        return Time(nanoseconds=self.nanoseconds + other.nanoseconds)

    def __sub__(self, other: "Time") -> "Time":
        return Time(nanoseconds=self.nanoseconds - other.nanoseconds)

    def __mod__(self, other: "Time") -> "Time":
        return Time(nanoseconds=self.nanoseconds % other.nanoseconds)

    # total_ordering derives __le__, __gt__, __ge__ from __eq__ and __lt__
    # (the derived operators rely on Python boolean logic, so they only work
    # on concrete values, not under jit tracing).
    def __eq__(self, other: "Time") -> jax.Array:
        return self.nanoseconds == other.nanoseconds

    def __lt__(self, other: "Time") -> jax.Array:
        return self.nanoseconds < other.nanoseconds

    def advance(self, nanoseconds) -> "Time":
        return Time(nanoseconds=self.nanoseconds + jnp.array(nanoseconds, dtype=jnp.int64))

    def __repr__(self) -> str:
        return f"Time({self.nanoseconds} ns)"

@diegoferigo (Member Author)

This is a good idea; I'll have a look at functools.total_ordering for the comparison operators.

However, the main difficulty at the moment is properly handling the overflow when JAX runs in 32-bit precision. I have a kind-of-working local implementation, started roughly 2 months ago, in which seconds and nanoseconds are treated separately. Then priorities shifted towards contact-related improvements; I'll get back to this as soon as I find some spare time. I'm not sure we have use cases on TPUs in the near future.
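
That local implementation is not shown in this thread; just to fix ideas, a minimal sketch of the seconds/nanoseconds split (hypothetical names, assuming a single step never advances time by more than 1 s) could look like this:

import jax
import jax.numpy as jnp
import jax_dataclasses

NS_PER_S = 1_000_000_000


@jax_dataclasses.pytree_dataclass
class SplitTime:
    # Both fields stay well within int32: seconds grows slowly, and
    # nanoseconds is kept in [0, 1e9) by the carry in advance_ns().
    seconds: jax.Array
    nanoseconds: jax.Array

    @staticmethod
    def build(seconds=0, nanoseconds=0) -> "SplitTime":
        return SplitTime(
            seconds=jnp.array(seconds, dtype=jnp.int32),
            nanoseconds=jnp.array(nanoseconds, dtype=jnp.int32),
        )

    def advance_ns(self, delta_ns) -> "SplitTime":
        # Assumes 0 <= delta_ns < 1e9 so that the sum below fits in int32.
        total_ns = self.nanoseconds + jnp.array(delta_ns, dtype=jnp.int32)
        carry = total_ns // NS_PER_S
        return SplitTime(
            seconds=self.seconds + carry,
            nanoseconds=total_ns - carry * NS_PER_S,
        )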

@diegoferigo (Member Author)

These days I realized that any option replacing a single float with a class to handle the simulation time would introduce major challenges in applications that need to compute gradients with AD against $t$. I don't yet have a solution that makes me happy enough and does not introduce side effects.
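
A toy illustration of the AD concern (not JaxSim code): jax.grad differentiates with respect to a plain floating-point t out of the box, but it refuses integer inputs such as a nanosecond counter, so an integer-based Time class would need an explicit floating-point view of t to remain differentiable:

import jax
import jax.numpy as jnp


def height(t: jax.Array) -> jax.Array:
    # Toy trajectory: free fall from rest.
    return -0.5 * 9.81 * t**2


# Differentiating w.r.t. a float time works out of the box:
print(jax.grad(height)(jnp.float32(1.0)))  # ~ -9.81

# Differentiating w.r.t. an integer nanosecond counter does not:
# jax.grad(lambda t_ns: height(t_ns / 1e9))(jnp.int32(1_000_000_000))
# raises a TypeError, since jax.grad requires inexact (floating-point) inputs.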
