Skip to content

Commit

Permalink
Remove norm check for jax.jit functions in QubitStateVector (#1683)
Browse files Browse the repository at this point in the history
* remove newlines in docstring (#1647)

* Differentiate jit jax and jax

* Update

* Add tests.

* Black

* avoid import jax

* Update comments and tests.

* Add changelog.

* jnp to np

* jnp.array for jnp.allclose

* jnp.array

* Differentiate the case more explicitly.

Co-authored-by: Theodor <theodor@xanadu.ai>
  • Loading branch information
rmoyard and thisac committed Sep 28, 2021
1 parent 4df8dc5 commit 283030d
Show file tree
Hide file tree
Showing 3 changed files with 678 additions and 581 deletions.
8 changes: 6 additions & 2 deletions doc/releases/changelog-dev.md
Expand Up @@ -302,6 +302,9 @@

<h3>Bug fixes</h3>

* Fix a bug where it was not possible to use `jax.jit` on a `QNode` when using `QubitStateVector`.
[(#1683)](https://github.com/PennyLaneAI/pennylane/pull/1683)

* The device suite tests can now execute successfully if no shots configuration variable is given.
[(#1641)](https://github.com/PennyLaneAI/pennylane/pull/1641)

Expand All @@ -321,5 +324,6 @@

This release contains contributions from (in alphabetical order):

Utkarsh Azad, Olivia Di Matteo, Andrew Gardhouse, Josh Izaac, Christina Lee, Maria Schuld,
Ingrid Strandberg, Antal Száva, David Wierichs.

Utkarsh Azad, Olivia Di Matteo, Andrew Gardhouse, Josh Izaac, Christina Lee, Romain Moyard,
Maria Schuld, Ingrid Strandberg, Antal Száva, David Wierichs.
8 changes: 5 additions & 3 deletions pennylane/devices/default_qubit.py
Expand Up @@ -626,12 +626,14 @@ def _apply_state_vector(self, state, device_wires):
raise ValueError("State vector must be of length 2**wires.")

norm_error_message = "Sum of amplitudes-squared does not equal one."
if qml.math.get_interface(state) == "torch":
if qml.math.get_interface(state) != "jax":
if not qml.math.allclose(qml.math.linalg.norm(state, ord=2), 1.0, atol=tolerance):
raise ValueError(norm_error_message)
else:
if not np.allclose(np.linalg.norm(state, ord=2), 1.0, atol=tolerance):
raise ValueError(norm_error_message)
# Case for jax without jit, full_lower is an attribute for abstract tracers
if not hasattr(qml.math.linalg.norm(state, ord=2), "full_lower"):
if not qml.math.allclose(qml.math.linalg.norm(state, ord=2), 1.0, atol=tolerance):
raise ValueError(norm_error_message)

if len(device_wires) == self.num_wires and sorted(device_wires) == device_wires:
# Initialize the entire wires with the state
Expand Down

0 comments on commit 283030d

Please sign in to comment.