New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
do the same in jax as in jax-jit for getting results in execute_fwd #4190
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, it looks good to me 👍
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, still need to fix the
TypeError: Custom JVP rule must produce primal and tangent outputs with equal shapes and dtypes, but got:
primal float32[] for tangent float32[1,0]
that occurs whenever we take a derivative.
Codecov Report
@@ Coverage Diff @@
## master #4190 +/- ##
==========================================
- Coverage 99.77% 99.77% -0.01%
==========================================
Files 342 342
Lines 30704 30696 -8
==========================================
- Hits 30634 30626 -8
Misses 70 70
|
Co-authored-by: Christina Lee <christina@xanadu.ai>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this quick fix.
We can keep trying to track down the a fix for the differentiation, but that can be a later fix.
Context:
The jax code was wrongly iterating over
[[res1, res2, ...], jacs]
instead of[res1, res2, ...]
but it slipped by undetected because you'd need 3 or more tapes to be executed to hit an IndexError.Description of the Change:
Remove the unnecessary tracing stuff and match what jax-jit does.
Benefits:
Code now works as expected
Possible Drawbacks:
It's a bit hacky, but it gets the job done.