Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

do the same in jax as in jax-jit for getting results in execute_fwd #4190

Merged
merged 5 commits into from May 31, 2023

Conversation

timmysilv
Copy link
Contributor

Context:
The jax code was wrongly iterating over [[res1, res2, ...], jacs] instead of [res1, res2, ...] but it slipped by undetected because you'd need 3 or more tapes to be executed to hit an IndexError.

Description of the Change:
Remove the unnecessary tracing stuff and match what jax-jit does.

Benefits:
Code now works as expected

Possible Drawbacks:
It's a bit hacky, but it gets the job done.

@timmysilv timmysilv changed the title do the same as jax-jit for getting results in execute_fwd do the same in jax as in jax-jit for getting results in execute_fwd May 30, 2023
@timmysilv timmysilv requested review from albi3ro and rmoyard May 30, 2023 20:16
Copy link
Contributor

@rmoyard rmoyard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, it looks good to me 👍

Copy link
Contributor

@albi3ro albi3ro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, still need to fix the

TypeError: Custom JVP rule must produce primal and tangent outputs with equal shapes and dtypes, but got:
  primal float32[] for tangent float32[1,0]

that occurs whenever we take a derivative.

@codecov
Copy link

codecov bot commented May 30, 2023

Codecov Report

Merging #4190 (4bf86d5) into master (7aa9ce3) will decrease coverage by 0.01%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master    #4190      +/-   ##
==========================================
- Coverage   99.77%   99.77%   -0.01%     
==========================================
  Files         342      342              
  Lines       30704    30696       -8     
==========================================
- Hits        30634    30626       -8     
  Misses         70       70              
Impacted Files Coverage Δ
pennylane/interfaces/jax.py 99.51% <100.00%> (-0.02%) ⬇️
pennylane/interfaces/jax_jit_tuple.py 100.00% <100.00%> (ø)

Copy link
Contributor

@albi3ro albi3ro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this quick fix.

We can keep trying to track down the a fix for the differentiation, but that can be a later fix.

@timmysilv timmysilv merged commit 48303c5 into master May 31, 2023
43 checks passed
@timmysilv timmysilv deleted the fix-jax-grad-on-exec branch May 31, 2023 14:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants