New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Remove norm check for jax.jit functions in QubitStateVector
#1683
Conversation
Hello. You may have forgotten to update the changelog!
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice solution @rmoyard! It might be best to find a way of doing it without import JAX however, since JAX is a soft dependency.
Perhaps we could use an approach that checks for a JAX-specific attribute or string?
Codecov Report
@@ Coverage Diff @@
## master #1683 +/- ##
=======================================
Coverage 99.20% 99.20%
=======================================
Files 201 201
Lines 15121 15122 +1
=======================================
+ Hits 15001 15002 +1
Misses 120 120
Continue to review full report at Codecov.
|
pennylane/devices/default_qubit.py
Outdated
@@ -626,12 +626,9 @@ def _apply_state_vector(self, state, device_wires): | |||
raise ValueError("State vector must be of length 2**wires.") | |||
|
|||
norm_error_message = "Sum of amplitudes-squared does not equal one." | |||
if qml.math.get_interface(state) == "torch": | |||
if not hasattr(qml.math.linalg.norm(state, ord=2), "full_lower"): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does full_lower
mean :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@rmoyard I have a worry that this might not be explicit enough. Maybe something like
if qml.math.get_interface(state) == "jax" and not hasattr(...):
might be safer, in that the if statement will only explicitly apply in the JAX case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From Jax documentation, full_lower is an optional optimization so that we unbox values out of Tracers as much as possible.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree that it is not clear at all, I will separate the cases!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From Jax documentation, full_lower is an optional optimization so that we unbox values out of Tracers as much as possible.
Nice, I wasn't aware of this function!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @rmoyard! Even though it was just a minor change to the logic, it reads much clearer now 💯
return qml.expval(qml.PauliZ(wires=0)) | ||
|
||
with pytest.raises(ValueError, match="Sum of amplitudes-squared does not equal one."): | ||
circuit(0.1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
really nice tests @rmoyard!
# Case for jax without jit, full_lower is an attribute for abstract tracers | ||
if not hasattr(qml.math.linalg.norm(state, ord=2), "full_lower"): | ||
if not qml.math.allclose(qml.math.linalg.norm(state, ord=2), 1.0, atol=tolerance): | ||
raise ValueError(norm_error_message) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is much clearer now 👍
[ch9188] |
Context:
The actual behavior of the
QubitStateVector
checks that the norm of the given vector is one. Unfortunately it is not possible to condition on the value when we use jax.jit. Therefore the check is removed when the decorator is used.Description of the Change:
We verify the type of the norm, if this is a jax tracer then we do not apply the norm check.
Benefits:
We can now use jax.jit when using
QubitStateVector
Possible Drawbacks:
The norm is not checked, it can lead to some errors if the user did a mistake in the state vector.
Related GitHub Issues:
#1670