Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix jax deprecations #1346

Merged
merged 18 commits into from
Apr 6, 2024
Merged

fix jax deprecations #1346

merged 18 commits into from
Apr 6, 2024

Conversation

FFroehlich
Copy link
Contributor

@FFroehlich FFroehlich commented Mar 28, 2024

  • Replaces experimental host_callback calls in jax, which are now deprecated, with pure_callback calls.
  • Replaces caching with appropriate calls to the inner objective as caching results in technically non-pure. (Caching the way it was implemented might work anyways, but that's a bit iffy). This required removal of second order code, but I am doubtful that part was functional anyways.
  • Removes superfluous calls to jax.jit in Jax.Objective.
  • Added compatibility with jax.vmap. This required quite a bit of refactoring as the base objective assumes inputs/outputs to be numpy arrays, which is incompatible with jax batch tracers.
  • Refactored tests to actually test interoperability with jax.grad and jax.vmap.

@codecov-commenter
Copy link

codecov-commenter commented Mar 28, 2024

Codecov Report

Attention: Patch coverage is 75.86207% with 7 lines in your changes are missing coverage. Please review.

Project coverage is 84.43%. Comparing base (38d91c4) to head (67ec64f).

Files Patch % Lines
pypesto/objective/jax/base.py 75.86% 7 Missing ⚠️

❗ Your organization needs to install the Codecov GitHub app to enable full functionality.

Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #1346      +/-   ##
===========================================
- Coverage    84.50%   84.43%   -0.07%     
===========================================
  Files          157      157              
  Lines        12906    12881      -25     
===========================================
- Hits         10906    10876      -30     
- Misses        2000     2005       +5     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Collaborator

@PaulJonasJost PaulJonasJost left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, not clear on the very details of the implementation though, but the tests checks out 👌🏼 thanks for the update

pypesto/objective/jax/base.py Outdated Show resolved Hide resolved
pypesto/objective/jax/base.py Outdated Show resolved Hide resolved
test/base/test_objective.py Outdated Show resolved Hide resolved
FFroehlich and others added 3 commits March 29, 2024 11:43
Co-authored-by: Paul Jonas Jost <70631928+PaulJonasJost@users.noreply.github.com>
@FFroehlich
Copy link
Contributor Author

Now no longer requires passing of input function. Extended tests to demonstrate jax transformations of inputs and outputs.

@FFroehlich FFroehlich merged commit 3163687 into develop Apr 6, 2024
18 checks passed
@FFroehlich FFroehlich deleted the fix_jax_callback branch April 6, 2024 15:45
This was referenced Apr 9, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants