-
Notifications
You must be signed in to change notification settings - Fork 575
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Bringing templates to feature parity in tape mode #873
Conversation
…ennylane into dense_matrices_in_mottonen
…matrices_in_mottonen
…ennylane into dense_matrices_in_mottonen
With the latest commit, all tests should now pass! As a result, all templates, including the notorious Mottonen state prep, are fully differentiable: import pennylane as qml
from pennylane import numpy as np
n_wires = 3
dev = qml.device("default.qubit", wires=n_wires)
qml.enable_tape()
@qml.qnode(dev)
def circuit(weights, init_state):
qml.templates.MottonenStatePreparation(init_state, wires=range(n_wires))
qml.templates.StronglyEntanglingLayers(weights, wires=range(n_wires))
return qml.expval(qml.PauliX(0) @ qml.PauliZ(1)), qml.expval(qml.PauliZ(1))
init_state = np.array([1, 1, 1, 1, 1, 1, 1, 1], requires_grad=True)
init_state = init_state / np.linalg.norm(init_state)
weights = qml.init.strong_ent_layers_normal(n_wires=n_wires, n_layers=3)
def cost(weights, init_state):
return np.sum(circuit(weights, init_state))
res = cost(weights, init_state)
print("Cost:", res)
grad_fn = qml.grad(cost)
print("Gradient:", grad_fn(weights, init_state)) |
Edit: there are two JAX tests failing, because the |
if not qml.math.allclose(norm, 1.0, atol=TOLERANCE): | ||
if normalize or pad_with: | ||
features = features / np.sqrt(norm) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As discussed, JAX abstract types will break on conditional branching based on tensor values.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need a general strategy how to check numerical properties of inputs then, for example that inputs of BasisEmbedding are 0/1, that inputs to AmplitudeEmbedding are normalised etc...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
-
For basis embedding, since it is not differentiable, what if we restricted it to only NumPy arrays? Would the computational graph just treat is as a constant?
-
For amplitude embedding, one solution is to simply always normalize (e.g., just get rid of the if statement).
|
||
if qml.tape_mode_active(): | ||
arg = qml.math.angle(arg) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Else?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking at the 'before', it doesn't seem to be needed 🤔
@mariaschuld, can you remember why we need this in tape mode, but not in non-tape mode?
pennylane/templates/broadcast.py
Outdated
@@ -64,6 +66,79 @@ def wires_all_to_all(wires): | |||
################### | |||
|
|||
|
|||
def _preprocess(parameters, pattern, wires): | |||
"""Validate and pre-process inputs.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add function argument documentation?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
True, that would be good. I was sloppy here...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@josh146 you want me to add them?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could also be more precise about what preprocessing does in each case...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, actually, it might be a good idea. Sorry I'm so used to this PR, I didn't even notice the docstrings were missing!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Only comments are the breakages in tf graph mode or in a jax.jit.
return qml.expval(qml.PauliZ(0)) | ||
|
||
qnode = qml.QNode(circuit, dev, interface="jax", diff_method="parameter-shift") | ||
weights = jnp.array(qml.init.strong_ent_layers_normal(n_wires=2, n_layers=2)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting. Why this change? Just to remove a dependency?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep, this test is testing JAX differentiation, so shouldn't need to depend on extra template code. This narrows the test to ensure that it passes if JAX differentiation works as intended. This way, if the test fails, we know for sure why (as opposed to it potentially being a template issue).
Testing if templates can be differentiated is instead tested by the template integration tests
Hello. You may have forgotten to update the changelog!
|
…I/pennylane into templates_tape_feature_parity
…I/pennylane into templates_tape_feature_parity
Work in progress: Make the template integration tests pass in tape mode.