Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove norm check for jax.jit functions in QubitStateVector #1683

Merged
merged 18 commits into from Sep 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 6 additions & 2 deletions doc/releases/changelog-dev.md
Expand Up @@ -302,6 +302,9 @@

<h3>Bug fixes</h3>

* Fix a bug where it was not possible to use `jax.jit` on a `QNode` when using `QubitStateVector`.
[(#1683)](https://github.com/PennyLaneAI/pennylane/pull/1683)

* The device suite tests can now execute successfully if no shots configuration variable is given.
[(#1641)](https://github.com/PennyLaneAI/pennylane/pull/1641)

Expand All @@ -321,5 +324,6 @@

This release contains contributions from (in alphabetical order):

Utkarsh Azad, Olivia Di Matteo, Andrew Gardhouse, Josh Izaac, Christina Lee, Maria Schuld,
Ingrid Strandberg, Antal Száva, David Wierichs.

Utkarsh Azad, Olivia Di Matteo, Andrew Gardhouse, Josh Izaac, Christina Lee, Romain Moyard,
Maria Schuld, Ingrid Strandberg, Antal Száva, David Wierichs.
8 changes: 5 additions & 3 deletions pennylane/devices/default_qubit.py
Expand Up @@ -626,12 +626,14 @@ def _apply_state_vector(self, state, device_wires):
raise ValueError("State vector must be of length 2**wires.")

norm_error_message = "Sum of amplitudes-squared does not equal one."
if qml.math.get_interface(state) == "torch":
if qml.math.get_interface(state) != "jax":
if not qml.math.allclose(qml.math.linalg.norm(state, ord=2), 1.0, atol=tolerance):
raise ValueError(norm_error_message)
else:
if not np.allclose(np.linalg.norm(state, ord=2), 1.0, atol=tolerance):
raise ValueError(norm_error_message)
# Case for jax without jit, full_lower is an attribute for abstract tracers
if not hasattr(qml.math.linalg.norm(state, ord=2), "full_lower"):
if not qml.math.allclose(qml.math.linalg.norm(state, ord=2), 1.0, atol=tolerance):
raise ValueError(norm_error_message)
Comment on lines +633 to +636
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is much clearer now 👍


if len(device_wires) == self.num_wires and sorted(device_wires) == device_wires:
# Initialize the entire wires with the state
Expand Down