Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jax jit qnode integration with DefaultQubit2 #4352

Merged
merged 14 commits into from Jul 17, 2023
Merged

Conversation

timmysilv
Copy link
Contributor

@timmysilv timmysilv commented Jul 11, 2023

Context:
All interfaces should work with the new device, and that includes jax-jit!

Description of the Change:
Add jax-jit QNode support. This involved a few little things:

  1. a certain simulation was missing the is_state_batched arg when calling measure, so I fixed that
  2. jax-jit expected all result types to match, so I cast a bunch of things to tuples (because those are better than lists 馃挭)
  3. The shape_dtype thing is a weird thing that I can't explain easily in words, but I will give the low-down. The helper function that generates shape objects says "if only 1 measurement and only 1 tape, return shapes[0]", but the function that interpreted it assumed that it was always a list of shape objects. This was only found because a) the new devices does some pre-processing for adjoint differentiation (including setting tape.trainable_parameters) but wrongly - see [BUG] validate_and_expand_adjoint does not handle new parameters correctly聽#4351.
  4. Change QuantumScript.shape to support the new device interface. Basically, device won't provide anything we need, but self (a tape) will! It shouldn't be a dangerous change because jax-jit is the only place that uses QuantumScript.shape afaik (pls correct me if I'm wrong)

Benefits:
jax-jit works with QNodes and DefaultQubit2

Possible Drawbacks:
Changes 2-4 from above all seem to suggest some cleanup could be necessary for various parts of our code (tape.shape, jax_jit_tuple.py in general)

[sc-41087]

@codecov
Copy link

codecov bot commented Jul 11, 2023

Codecov Report

Merging #4352 (bd81b3a) into master (e7f4b82) will increase coverage by 0.00%.
The diff coverage is 100.00%.

@@           Coverage Diff           @@
##           master    #4352   +/-   ##
=======================================
  Coverage   99.77%   99.77%           
=======================================
  Files         351      351           
  Lines       32143    32150    +7     
=======================================
+ Hits        32070    32077    +7     
  Misses         73       73           
Impacted Files Coverage 螖
pennylane/devices/qubit/simulate.py 100.00% <100.00%> (酶)
pennylane/gradients/jvp.py 100.00% <100.00%> (酶)
pennylane/interfaces/jax.py 99.51% <100.00%> (酶)
pennylane/interfaces/jax_jit_tuple.py 100.00% <100.00%> (酶)
pennylane/tape/qscript.py 99.04% <100.00%> (+<0.01%) 猬嗭笍

pennylane/devices/qubit/simulate.py Show resolved Hide resolved
pennylane/tape/qscript.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@eddddddy eddddddy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good 馃挴

Copy link
Contributor

@albi3ro albi3ro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great ! 馃憤

@timmysilv timmysilv enabled auto-merge (squash) July 17, 2023 17:56
@timmysilv timmysilv merged commit 1a4d720 into master Jul 17, 2023
43 checks passed
@timmysilv timmysilv deleted the jax-jit-qnode-dq2 branch July 17, 2023 18:22
mudit2812 pushed a commit that referenced this pull request Jul 18, 2023
* copy-paste jax-jit file

* replace old dev with new one

* make tests pass for jax-jit

* changelog; fix copyright year

* black

* Update doc/releases/changelog-dev.md

* convert to tuple everywhere in jax

* fix pesky test

* also pass is_state_batched to finite shot case

* add more detailed comment to tape.shape change
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants