Apply new OOP pattern compatible with jax transformations #48

Merged
merged 27 commits into from
Oct 9, 2023

Conversation

diegoferigo

@diegoferigo diegoferigo commented Aug 4, 2023

This PR applies the resources of #44 to the jaxsim classes, solving #43, the original problem that triggered the development of this new approach.

Beyond fixing trace leaks, this PR:

  • Enforces pytrees to be frozen, but enables developing r/w methods that succeed only if the pytree structure is not altered (unless manually allowed).
  • Allows developers to catch early mistakes that would lead to cryptic jit recompilations.
  • Allows users to develop applications with traditional OOP, with a much simplified approach to apply jax.jit and jax.vmap.
  • Updates all returned types to be jax.numpy arrays, even when they are scalar quantities (in this case they are 0-dimensional arrays).
  • Removes the confusing usage of Mutability from the methods, now the mutable/frozen context is enforced in the new decorators.
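As a minimal sketch of the first and fourth points, assuming a plain frozen dataclass registered as a pytree (the `PointMass` class is hypothetical and is not part of jaxsim; the actual decorators of this PR are not shown here):

```python
import dataclasses

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class PointMass:
    """Hypothetical frozen pytree, used only for illustration."""

    mass: jax.Array
    velocity: jax.Array

    def tree_flatten(self):
        # Both fields are dynamic leaves; there is no static metadata.
        return (self.mass, self.velocity), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    @jax.jit
    def kinetic_energy(self) -> jax.Array:
        # `self` is a pytree, so the method can be jit-compiled directly.
        return 0.5 * self.mass * jnp.dot(self.velocity, self.velocity)


p = PointMass(mass=jnp.array(2.0), velocity=jnp.array([1.0, 2.0, 2.0]))

# The scalar result is a 0-dimensional jax array, not a Python float.
energy = p.kinetic_energy()
```

Assigning to a field of `p` (for example `p.mass = ...`) raises `dataclasses.FrozenInstanceError`, which is the kind of early failure a frozen pytree gives.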

The key resources of this pattern are new method decorators and a Vmappable inheritance (introduced in #44). There are, however, some caveats to consider:

  • Methods with arguments or returned objects of types not supported by jax transformations (e.g. str) can be neither jit-compiled nor vectorized.
  • Default arguments of vmappable methods cannot be scalar values. They have to be initialized as None, and the default value has to be set inside the method.
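Both caveats can be reproduced with plain `jax.jit` and `jax.vmap`. The sketch below uses hypothetical free functions (`scale`, `raises`) rather than the decorators of this PR, and assumes that every argument is mapped along axis 0:

```python
import jax
import jax.numpy as jnp


def scale(x, gain=None):
    # The default is built inside the function so that it has the same
    # shape as `x` instead of being a scalar in the signature.
    gain = jnp.ones_like(x) if gain is None else gain
    return gain * x


def raises(exc_type, fn, *args):
    """Return True if calling fn(*args) raises exc_type."""
    try:
        fn(*args)
    except exc_type:
        return True
    return False


xs = jnp.arange(3.0)
gains = jnp.array([1.0, 2.0, 3.0])

# Both arguments are mapped over their leading axis.
batched = jax.vmap(scale)(xs, gains)

# With `gain` omitted, the None default is resolved inside the function.
defaulted = jax.vmap(scale)(xs)

# A scalar has no axis 0 to map over, so vmap rejects it.
scalar_fails = raises(ValueError, jax.vmap(scale), xs, 2.0)

# A str is not a valid jax type, so it cannot be a traced argument.
str_fails = raises(TypeError, jax.jit(lambda name: name), "base_link")
```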

@diegoferigo

I got back to working on this PR. The new OOP decorator already works pretty well for r/o methods; however, the r/w logic fails miserably.

Our OOP strategy for the high-level resources (Model, Link, Joints) consists of the following design:

  • The entire state is stored inside Model.data.
  • Links and joints objects are created and returned on the fly by jit-compiled methods.
  • Links and joint objects contain a reference to their parent model so that they can read -and alter- the state (i.e. data).

The main problem of this first version of the design is that a jit-compiled (decorated) r/w method of links and joints (like Link.add_external_force) applies a simple logic that copies the pytree's dynamic leaves from the object returned by the associated functional execution (obj) to the original OO instance.

When this logic runs, the obj._parent_model attribute has a different id than instance._parent_model, and the reference gets lost. It seems to me that discusses similar problems and use cases.
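The lost reference can be reproduced outside jaxsim. The following is a minimal sketch with hypothetical stand-in classes (`Model`, `Link`) and a simplified copy-back step; it is not the decorator logic of this PR:

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Model:
    """Hypothetical stand-in for the object that owns the state."""

    def __init__(self, data):
        self.data = data

    def tree_flatten(self):
        return (self.data,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.tree_util.register_pytree_node_class
class Link:
    """Hypothetical child object holding a reference to its parent model."""

    def __init__(self, parent_model):
        self._parent_model = parent_model

    def tree_flatten(self):
        return (self._parent_model,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def add_external_force(link, force):
    # Functional execution: returns a new Link wrapping a new Model.
    return Link(Model(link._parent_model.data + force))


model = Model(jnp.zeros(3))
link = Link(model)
obj = add_external_force(link, jnp.ones(3))

# jit rebuilds its outputs with tree_unflatten, so the parent is a new object.
same_parent = obj._parent_model is model

# Simplified copy-back to the original instance: the link is updated,
# but the original model never sees the new data.
link._parent_model = obj._parent_model
```

After the copy-back, `link._parent_model.data` holds the new force while `model.data` is still zero.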

I've been experimenting with a few solutions similar to those proposed in google/jax#7919 and google/jax#17341, but the logic seems complicated to maintain and quite error-prone.

In a context of data sharing (that could be state, as in this case, but also more generic r/w parameters), probably the simplest solution is not to allow any r/w methods on child classes like Link and Joint. We have just a few of them, and they can be moved to Model for now. I'll follow with interest possible upstream development of ideas from other jaxsim users.

Note that we have a similar relationship also between JaxSim and Model, but in this case there's no parameter sharing. Therefore, having r/w methods in Model -that operate directly on its data attribute- should be ok.
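A minimal sketch of why the owner-side case is unproblematic: the instance that runs the r/w method is the same one that owns the data, so copying the result back needs no shared reference. The `Model` class, the `set_velocity` method and the `_with_velocity` helper are hypothetical and only mimic the idea, not the decorators of this PR:

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Model:
    """Hypothetical model that owns its own state dictionary."""

    def __init__(self, data):
        self.data = data

    def tree_flatten(self):
        return (self.data,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def set_velocity(self, velocity):
        # Run the functional update under jit, then copy the result back.
        updated = _with_velocity(self, velocity)

        # Refuse updates that would alter the pytree structure.
        if jax.tree_util.tree_structure(updated) != jax.tree_util.tree_structure(self):
            raise ValueError("The r/w method altered the pytree structure")

        self.data = updated.data


@jax.jit
def _with_velocity(model, velocity):
    # Pure function: builds a new Model instead of mutating the input.
    return Model({**model.data, "velocity": velocity})


model = Model({"position": jnp.zeros(3), "velocity": jnp.zeros(3)})
model.set_velocity(jnp.ones(3))
```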

@diegoferigo diegoferigo force-pushed the feature/oop_with_jax branch 2 times, most recently from 872d4e7 to 74bd6fa Compare October 9, 2023 07:53
Otherwise the pytree structure changes after the first applied jax transformation
Static arguments must be hashable, therefore lists cannot be passed
The link force can either override or be summed with previously set forces. The default behavior is to sum it.
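The last two commit notes can be illustrated together. This is a sketch under assumptions: `set_link_forces`, its `indices` and `additive` parameters, and the array shapes are hypothetical and do not reproduce the jaxsim API.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("indices", "additive"))
def set_link_forces(forces, force, indices, additive=True):
    # `indices` and `additive` are static, so Python control flow is allowed.
    idx = jnp.array(indices)
    if additive:
        # Default behavior: sum with the previously set forces.
        return forces.at[idx].add(force)
    # Otherwise override the previously set forces.
    return forces.at[idx].set(force)


forces = jnp.zeros((3, 3))
force = jnp.ones(3)

# A tuple is hashable, so it is accepted as a static argument.
summed_once = set_link_forces(forces, force, indices=(0, 2))
summed_twice = set_link_forces(summed_once, force, indices=(0, 2))
overridden = set_link_forces(summed_twice, force, indices=(0, 2), additive=False)

# A list is not hashable, so it is rejected as a static argument.
try:
    set_link_forces(forces, force, indices=[0, 2])
    list_rejected = False
except (TypeError, ValueError):
    list_rejected = True
```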
@samskiter

Happy to help adapting my class as shown in google/jax#7919 (comment)
I was able to convert an OOP code base with this and have it work in a metagradient descent scenario. Key is providing a flatten/unflatten function for every class but the class ReffableTreeNode does a lot of the complex work to make this super easy to implement...

@diegoferigo

Thank you @samskiter for chiming in! I'm definitely interested in your approach using ReffableTreeNode. For the moment, I worked around the limitation described in #48 (comment) by removing r/w methods of classes operating on data not owned by them. I found that r/o methods are, of course, working fine, but problems occur in case of jitted r/w methods.

Luckily, we only have two such methods, and I moved them for now to the parent class (5a975d4).

I'll keep in mind to try out your solution; it could be an excellent companion to our JaxsimDataclass and Vmappable. Is the project where you use it open source? If yes, I'd like to have a look at your use cases. If not (yet), I'll be following your projects.

@diegoferigo diegoferigo changed the base branch from main to new_api October 9, 2023 14:39
@diegoferigo diegoferigo marked this pull request as ready for review October 9, 2023 14:40

@flferretti flferretti left a comment

LGTM

@samskiter

Afraid not open source, but I'm happy for you to use ReffableTreeNode - the toy usage in my other comment still stands. Essentially this stores a UUID of every item in the tree structure. There are 4 methods each class must implement:

  • tree_flatten_adia
  • tree_unflatten_adia
  • get_refs
  • set_refs

You don't have to implement any recursion yourself - the super class will call get_refs to find other nodes.

The only change from the toy usage code is to store the UUIDs in the aux_data:

aux_data["gradient"] = self.gradient.get_uuid()
self.gradient = aux_data["gradient"]
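The general idea of carrying an identifier in `aux_data` can be sketched with the standard `jax.tree_util` API alone. This is not `ReffableTreeNode`: the `Node` class, its `node_id` field and the `double` function are hypothetical, and the four methods listed above are not reproduced.

```python
import uuid

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Node:
    """Hypothetical node whose identity is a UUID stored in aux_data."""

    def __init__(self, value, node_id=None):
        self.value = value
        # A fresh UUID is generated only when the node is first created.
        self.node_id = node_id if node_id is not None else uuid.uuid4().hex

    def tree_flatten(self):
        # The value is a dynamic leaf; the UUID is static auxiliary data.
        return (self.value,), self.node_id

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], node_id=aux_data)


@jax.jit
def double(node):
    # The returned node keeps the identifier of the input node.
    return Node(node.value * 2, node_id=node.node_id)


a = Node(jnp.array(1.0))
b = double(a)

# `b` is a different Python object, but it can be matched to `a` by UUID.
same_object = b is a
same_identity = b.node_id == a.node_id
```

One design consequence of this sketch: since the UUID is part of the tree structure, nodes with different UUIDs have different structures, so a jit-compiled function is traced again for each distinct identifier.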

@diegoferigo

Yeah, I didn't find anything in your repos, so I assumed it was still closed source. I opened #52 to collect resources and attempts we might make in the future. For now, thanks a lot for the details!

@diegoferigo diegoferigo merged commit 309582f into new_api Oct 9, 2023
16 checks passed
@diegoferigo diegoferigo deleted the feature/oop_with_jax branch October 9, 2023 16:15
3 participants