Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce resources to streamline the combination of jax transformation with OOP pattern #44

Merged
merged 13 commits into from
Oct 9, 2023

Conversation

diegoferigo
Copy link
Member

@diegoferigo diegoferigo commented Jul 27, 2023

JAX endorses a functional programming pattern in which algorithms are executed stateless over some data structure that is passed as input and returned (possibly updated) as output. This approach has many benefits, but often a OOP might result more user friendly.

We already adopted such OOP pattern by using jax_dataclasses to create pytrees. These pytrees can have dataclass attributes (fields) containing static/dynamic data, and algorithms can be implemented as dataclass methods. If these methods were staticmethods, the functional pattern would be preserved, however it'd be users' responsibility to correctly propagating the state. If the methods, instead, are not static, they might update the data that, if not done correctly, could trigger jit recompilations and other problems like trace leaks (#43).

We already apply such OOP pattern to our classes and so far it has worked decently well, but recent jax versions became more picky about trace leaks when this pattern is used in conjunction with jit compilation.
The main problem seems that jax thinks we have trace leaks because instead of having dataclass methods returning a tuple (output, state), they just return output and update state directly from the method. This breaks the functional pattern that jax expects.

Furthermore, applying algorithms of parallel objects was not straightforward since it required creating lambdas/closures and pass them to jax.vmap. This is quite confusing, particularly for new users.


This PR tries to solve these limitations. The idea is to keep using OOP, but handle transparently some internal jax details to simplify developers' and users' life. In particular, the following is the desiderata:

  • Streamline the combination of OOP with jax.jit (addressing also leaked traces).
  • Promote the usage of jax_dataclasses to create custom PyTrees.
  • Allow pytrees to have:
    • static attributes (that trigger jit recompilation);
    • normal attributes (considered as leafs);
  • Enforce mutable/frozen property with optional checks for preserving the pytree structure (minimizing the occurrence of jit recompilations in case type/shape/weakness of data changes without notice).
  • Introduce resources to parallelize the dynamic attributes of pytrees by automatically adding a new batch dimension.
  • Class methods are jit-compiled as they would be static methods (self is a pytree so it's ok), therefore different objects can re-use the first jit-compiled method.
  • Class methods are automatically paralellized with jax.vmap if they have been parallelized (they have all fields with the batch dimension as first axis).
  • Algorithms can be developed with OOP as dataclass methods, and they can be marked as either read-only or read-write in case they alter the pytree attributes.

There are few caveat to consider when such approach is implemented. In this PR, I only introduce the tooling and a test to achieve these goals. I'll update all the jaxsim modules in a new PR.

@diegoferigo diegoferigo self-assigned this Jul 27, 2023
@diegoferigo diegoferigo linked an issue Jul 27, 2023 that may be closed by this pull request
@diegoferigo diegoferigo force-pushed the fix/trace_leak branch 4 times, most recently from 12b4bbb to 5ee6189 Compare August 3, 2023 22:13
@diegoferigo diegoferigo changed the title Fix trace leak by introducing a new pattern for jit-compiling class methods Introduce resources to streamline the combination of jax transformation with OOP pattern Aug 4, 2023
@diegoferigo diegoferigo marked this pull request as ready for review August 4, 2023 12:22
@diegoferigo diegoferigo changed the base branch from main to new_api October 9, 2023 14:10
@diegoferigo
Copy link
Member Author

Given that the OOP decorators alter considerably downstream code and we're not yet sure they are 100% compatible with our long-term goals, I'll proceed with caution by merging this and next PRs into the new_api feature branch.

If everything keeps looking good, the new APIs will become the default ones.

@diegoferigo diegoferigo merged commit cf3575c into new_api Oct 9, 2023
16 checks passed
@diegoferigo diegoferigo deleted the fix/trace_leak branch October 9, 2023 14:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants