Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove norm check for jax.jit functions in QubitStateVector #1683

Merged
merged 18 commits into from Sep 28, 2021
Merged

Conversation

rmoyard
Copy link
Contributor

@rmoyard rmoyard commented Sep 23, 2021

Context:
The actual behavior of the QubitStateVector checks that the norm of the given vector is one. Unfortunately it is not possible to condition on the value when we use jax.jit. Therefore the check is removed when the decorator is used.

Description of the Change:
We verify the type of the norm, if this is a jax tracer then we do not apply the norm check.

Benefits:
We can now use jax.jit when using QubitStateVector

Possible Drawbacks:
The norm is not checked, it can lead to some errors if the user did a mistake in the state vector.

Related GitHub Issues:
#1670

@github-actions
Copy link
Contributor

Hello. You may have forgotten to update the changelog!
Please edit doc/releases/changelog-dev.md with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.

@rmoyard rmoyard added the WIP 🚧 Work-in-progress label Sep 23, 2021
Copy link
Member

@josh146 josh146 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice solution @rmoyard! It might be best to find a way of doing it without import JAX however, since JAX is a soft dependency.

Perhaps we could use an approach that checks for a JAX-specific attribute or string?

pennylane/devices/default_qubit.py Outdated Show resolved Hide resolved
pennylane/devices/default_qubit.py Outdated Show resolved Hide resolved
tests/devices/test_default_qubit_jax.py Outdated Show resolved Hide resolved
@rmoyard rmoyard removed the WIP 🚧 Work-in-progress label Sep 24, 2021
@codecov
Copy link

codecov bot commented Sep 24, 2021

Codecov Report

Merging #1683 (55ad9ef) into master (4df8dc5) will increase coverage by 0.00%.
The diff coverage is 100.00%.

Impacted file tree graph

@@           Coverage Diff           @@
##           master    #1683   +/-   ##
=======================================
  Coverage   99.20%   99.20%           
=======================================
  Files         201      201           
  Lines       15121    15122    +1     
=======================================
+ Hits        15001    15002    +1     
  Misses        120      120           
Impacted Files Coverage Δ
pennylane/devices/default_qubit.py 100.00% <100.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 4df8dc5...55ad9ef. Read the comment docs.

@@ -626,12 +626,9 @@ def _apply_state_vector(self, state, device_wires):
raise ValueError("State vector must be of length 2**wires.")

norm_error_message = "Sum of amplitudes-squared does not equal one."
if qml.math.get_interface(state) == "torch":
if not hasattr(qml.math.linalg.norm(state, ord=2), "full_lower"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does full_lower mean :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rmoyard I have a worry that this might not be explicit enough. Maybe something like

if qml.math.get_interface(state) == "jax" and not hasattr(...):

might be safer, in that the if statement will only explicitly apply in the JAX case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From Jax documentation, full_lower is an optional optimization so that we unbox values out of Tracers as much as possible.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that it is not clear at all, I will separate the cases!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From Jax documentation, full_lower is an optional optimization so that we unbox values out of Tracers as much as possible.

Nice, I wasn't aware of this function!

Copy link
Member

@josh146 josh146 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @rmoyard! Even though it was just a minor change to the logic, it reads much clearer now 💯

return qml.expval(qml.PauliZ(wires=0))

with pytest.raises(ValueError, match="Sum of amplitudes-squared does not equal one."):
circuit(0.1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

really nice tests @rmoyard!

Comment on lines +633 to +636
# Case for jax without jit, full_lower is an attribute for abstract tracers
if not hasattr(qml.math.linalg.norm(state, ord=2), "full_lower"):
if not qml.math.allclose(qml.math.linalg.norm(state, ord=2), 1.0, atol=tolerance):
raise ValueError(norm_error_message)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is much clearer now 👍

@antalszava
Copy link
Contributor

[ch9188]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants