Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow vector-valued QNodes with JAX using host_callback.call #2034

Merged
merged 174 commits into from
Jun 10, 2022

Conversation

antalszava
Copy link
Contributor

@antalszava antalszava commented Dec 15, 2021

Context:

Vector-valued QNodes may include:

  1. Multiple scalar return types: e.g., return qml.expval(qml.PauliZ(0), qml.expval(qml.PauliZ(1);
  2. QNodes with return types qml.probs, qml.state or qml.density_matrix.

The JAX interface doesn't support these return types. The main reason is that host_callback.call, the underlying function that is being used requires the output shape to be passed. The current logic always considers tapes with scalar outputs.

Uses the machinery introduced in #2044.

Description of the Change:

  • Updates the MeasurementProcess class such that shape and numeric_type are methods because there were uncovered edge cases of not having pre-set shapes/numeric types (e.g., self._shape was None);
  • Changes the shape that is being passed to host_callback.call.

Benefits:

Vector-valued QNodes can be evaluated using the JAX JIT interface that uses host_callback.call; host_callback.call is jittable.

Possible Drawbacks:

The JAX JIT interface doesn't support jax.jacobian (see discussion in #2163).

Related GitHub Issues:
Closes #1208, #2404


Testing:

Testing categories:

  • Shape
  • Correctness (comparison with the expected result)

Multiple scalar outputs:

  • qml.expval, qml.var

Single vector valued outputs:

  • qml.probs
  • qml.state
  • qml.density_matrix

Mixed vector and scalar-valued outputs:

  • Scalar valued gradient functions using vector-valued QNodes
  • Explicitly consider multi-tape cases
  • Check what's wrong with the xfail test case (independent param)
  • Double-check if output dimensions are the same as with other interfaces

@codecov
Copy link

codecov bot commented Dec 15, 2021

Codecov Report

Merging #2034 (cb485c3) into master (1b788a6) will increase coverage by 0.00%.
The diff coverage is 100.00%.

@@           Coverage Diff           @@
##           master    #2034   +/-   ##
=======================================
  Coverage   99.61%   99.61%           
=======================================
  Files         251      251           
  Lines       20553    20569   +16     
=======================================
+ Hits        20473    20489   +16     
  Misses         80       80           
Impacted Files Coverage Δ
pennylane/interfaces/autograd.py 100.00% <ø> (ø)
pennylane/interfaces/jax.py 100.00% <100.00%> (ø)
pennylane/interfaces/jax_jit.py 100.00% <100.00%> (ø)
pennylane/measurements.py 99.62% <100.00%> (+0.01%) ⬆️
pennylane/tape/tape.py 98.89% <100.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 1b788a6...cb485c3. Read the comment docs.

@antalszava antalszava marked this pull request as ready for review June 2, 2022 22:34
Copy link
Member

@maliasadi maliasadi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great job @antalszava! 😍 It's awesome to see the JAX support for vector-value QNodes. I only have a few minor suggestions; but I am happy to approve.

doc/releases/changelog-dev.md Outdated Show resolved Hide resolved
pennylane/interfaces/jax_jit.py Outdated Show resolved Hide resolved
pennylane/interfaces/jax_jit.py Outdated Show resolved Hide resolved
pennylane/interfaces/jax_jit.py Show resolved Hide resolved
pennylane/measurements.py Outdated Show resolved Hide resolved
tests/interfaces/test_jax.py Outdated Show resolved Hide resolved
tests/interfaces/test_jax.py Show resolved Hide resolved
tests/interfaces/test_jax_qnode.py Show resolved Hide resolved
Copy link
Contributor

@rmoyard rmoyard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @antalszava ! Thanks for this awesome work 💯 No major blockers, but I have some questions. I think you can also update the documentation (chart with configurations) :)

doc/introduction/interfaces/jax.rst Show resolved Hide resolved
pennylane/measurements.py Outdated Show resolved Hide resolved
pennylane/measurements.py Show resolved Hide resolved
tests/interfaces/test_jax.py Show resolved Hide resolved
pennylane/interfaces/jax_jit.py Outdated Show resolved Hide resolved
@antalszava antalszava requested a review from rmoyard June 9, 2022 14:29
@antalszava
Copy link
Contributor Author

Thank you @maliasadi, @rmoyard for the reviews! The comments should be addressed now.

(One thing todo for me will be to double-check an edge case of multiple tapes with multiple parameters for forward mode.)

Copy link
Contributor

@rmoyard rmoyard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome work @antalszava 💯 I've left a comment about the problem you mentioned, we can always come back to it 👍 Thank you for changing VnEntropy and MutualInformation measurements process!

pennylane/interfaces/jax_jit.py Show resolved Hide resolved
@antalszava antalszava merged commit 29557dd into master Jun 10, 2022
@antalszava antalszava deleted the qnode_vector_val_jax branch June 10, 2022 19:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Allow the JAX interface to work with additional return types
4 participants