Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consider measurements when resolving to jax-jit, fix indexing issues with DeviceArray in QubitDevice #2427

Merged
merged 15 commits into from
Apr 8, 2022

Conversation

antalszava
Copy link
Contributor

@antalszava antalszava commented Apr 7, 2022

Context:

import pennylane as qml
import jax

from pennylane import numpy as np

dev = qml.device('default.qubit.jax', wires=2, shots=10)
projector = np.zeros((2 ** 2, 2 ** 2))
projector[0][0] = 1

@jax.jit
@qml.qnode(dev, interface='jax')
def circ(projector):
    qml.PauliZ(wires=0)
    return qml.expval(qml.Hermitian(projector, wires=range(2)))

print(circ(projector))

Raises errors due to:

  1. Indexing into the samples generated by QubitDevice and stores under the hood with a DeviceArray instead of a NumPy array. We have a DeviceArray because of using the default.qubit.jax device.
  2. Once the 1. issue is solved, issues with concretization arise from JAX due to not using the jax-jit interface. This is an issue in the automatic change to jax-jit when interface="jax" is specified.

Description of the Change:

  1. Casts the indices used to a NumPy array;
  2. Changes the logic that determines the jax-jit interface to be used such that observables are considered too.

Benefits:

The original snippet executes without issues.

Possible Drawbacks:
N/A

Related GitHub Issues:
Closes #2406

@github-actions
Copy link
Contributor

github-actions bot commented Apr 7, 2022

Hello. You may have forgotten to update the changelog!
Please edit doc/releases/changelog-dev.md with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.

@codecov
Copy link

codecov bot commented Apr 7, 2022

Codecov Report

Merging #2427 (26e9388) into master (5c7c77b) will increase coverage by 0.00%.
The diff coverage is 100.00%.

@@           Coverage Diff           @@
##           master    #2427   +/-   ##
=======================================
  Coverage   99.45%   99.45%           
=======================================
  Files         244      244           
  Lines       18976    18978    +2     
=======================================
+ Hits        18872    18874    +2     
  Misses        104      104           
Impacted Files Coverage Δ
pennylane/_qubit_device.py 98.80% <100.00%> (+<0.01%) ⬆️
pennylane/interfaces/jax.py 100.00% <100.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 5c7c77b...26e9388. Read the comment docs.

@antalszava antalszava changed the title Fix 2406 Consider measurements too when determining jax-jit, fix indexing issues with DeviceArray in QubitDevice Apr 8, 2022
@antalszava antalszava changed the title Consider measurements too when determining jax-jit, fix indexing issues with DeviceArray in QubitDevice Consider measurements when resolving to jax-jit, fix indexing issues with DeviceArray in QubitDevice Apr 8, 2022
@antalszava antalszava requested a review from rmoyard April 8, 2022 15:07
Copy link
Contributor

@rmoyard rmoyard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @antalszava nice solutions! It looks good to me 💯

@@ -928,6 +928,7 @@ def sample(self, observable, shot_range=None, bin_size=None):
] # Add np.array here for Jax support.
powers_of_two = 2 ** np.arange(samples.shape[-1])[::-1]
indices = samples @ powers_of_two
indices = np.array(indices) # Add np.array here for Jax support.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very good catch @antalszava 🕵️‍♂️

@antalszava
Copy link
Contributor Author

[sc-17102]

@antalszava antalszava merged commit e8e6faa into master Apr 8, 2022
@antalszava antalszava deleted the fix_2406 branch April 8, 2022 16:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[BUG] Error with JAX-JIT, finite shot, qml.expval(qml.Hermitian(matrix))
2 participants