Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix passing different PyTrees to JIT-compiled functions #165

Merged
merged 10 commits into from
Jun 4, 2024

Conversation

diegoferigo
Copy link
Member

@diegoferigo diegoferigo commented Jun 3, 2024

This PR should finally solve #103 also for models with joints, in particular:

  • Properly implements the __eq__ magic method of all classes that are used in JIT compiled functions.
  • Introduces a new class jaxsim.utils.wrappers.HashedNumpyArray that can wrap NumPy and JAX NumPy array to provide them a hash and equality operators. This is necessary for all jax_dataclasses.Static fields storing arrays.
  • Updates the pattern to hide the objects wrapped with HashedNumpyArray and HashlessObject exposing only the underlying object through a class property.
  • Fixed JointModel.joint_axis attribute to properly contain JointGenericAxis objects. Before this PR (from #153 I guess), this attribute was typed to hold a JointGenericAxis but instead it stored plain JAX numpy arrays with the 3D vector of the axis.

馃摎 Documentation preview 馃摎: https://jaxsim--165.org.readthedocs.build//165/

@diegoferigo diegoferigo force-pushed the fix_static_pytree_attributes branch from fd6cde5 to 6420495 Compare June 3, 2024 16:30
@diegoferigo diegoferigo changed the title Fix passing different PyTree to JIT-compiled functions Fix passing different PyTrees to JIT-compiled functions Jun 3, 2024
@diegoferigo
Copy link
Member Author

diegoferigo commented Jun 3, 2024

This PR also fixes the M(N)WE provided by @flferretti in #103 (comment) (hopefully, without side effects).

@diegoferigo diegoferigo self-assigned this Jun 3, 2024
@diegoferigo diegoferigo marked this pull request as ready for review June 4, 2024 07:09
Copy link
Collaborator

@flferretti flferretti left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Diego, LGTM!

src/jaxsim/api/data.py Show resolved Hide resolved
src/jaxsim/api/data.py Show resolved Hide resolved
src/jaxsim/api/model.py Outdated Show resolved Hide resolved
src/jaxsim/api/ode_data.py Show resolved Hide resolved
src/jaxsim/rbda/soft_contacts.py Show resolved Hide resolved
tests/conftest.py Show resolved Hide resolved
Co-authored-by: Filippo Luca Ferretti <102977828+flferretti@users.noreply.github.com>
@diegoferigo diegoferigo merged commit 1594f4a into main Jun 4, 2024
29 checks passed
@diegoferigo diegoferigo deleted the fix_static_pytree_attributes branch June 4, 2024 09:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Error when passing two different pytrees with the same structure to a JIT-compiled function
2 participants