
Bringing templates to feature parity in tape mode #873

Merged (100 commits into master, Jan 5, 2021)

Conversation

mariaschuld
Contributor

Work in progress: Make the template integration tests pass in tape mode.

@josh146
Member

josh146 commented Dec 14, 2020

With the latest commit, all tests should now pass! As a result, all templates, including the notorious Mottonen state prep, are fully differentiable:

import pennylane as qml
from pennylane import numpy as np

n_wires = 3
dev = qml.device("default.qubit", wires=n_wires)

qml.enable_tape()

@qml.qnode(dev)
def circuit(weights, init_state):
    qml.templates.MottonenStatePreparation(init_state, wires=range(n_wires))
    qml.templates.StronglyEntanglingLayers(weights, wires=range(n_wires))
    return qml.expval(qml.PauliX(0) @ qml.PauliZ(1)), qml.expval(qml.PauliZ(1))

init_state = np.array([1, 1, 1, 1, 1, 1, 1, 1], requires_grad=True)
init_state = init_state / np.linalg.norm(init_state)
weights = qml.init.strong_ent_layers_normal(n_wires=n_wires, n_layers=3)

def cost(weights, init_state):
    return np.sum(circuit(weights, init_state))

res = cost(weights, init_state)

print("Cost:", res)

grad_fn = qml.grad(cost)
print("Gradient:", grad_fn(weights, init_state))

@josh146
Member

josh146 commented Dec 14, 2020

Edit: there are two JAX tests failing, because the TensorBox class does not yet have JAX support 😢 This is probably a separate PR

Comment on lines +64 to +66
if not qml.math.allclose(norm, 1.0, atol=TOLERANCE):
    if normalize or pad_with:
        features = features / np.sqrt(norm)
Contributor

As discussed, JAX abstract types will break on conditional branching based on tensor values.

Contributor Author

We then need a general strategy for checking numerical properties of inputs, for example that inputs to BasisEmbedding are 0/1, that inputs to AmplitudeEmbedding are normalised, etc...

Member

  • For basis embedding, since it is not differentiable, what if we restricted it to only NumPy arrays? Would the computational graph just treat it as a constant?

  • For amplitude embedding, one solution is to simply always normalize (e.g., just get rid of the if statement).
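The always-normalize option can be sketched without any value-dependent branching (plain NumPy for illustration; the helper name `preprocess_amplitudes` is hypothetical, not PennyLane API):

```python
import numpy as np

def preprocess_amplitudes(features):
    # Hypothetical helper: always divide by the norm, so no `if` on the
    # tensor's value is needed and a JAX tracer would pass straight through.
    norm = np.sqrt(np.sum(np.abs(features) ** 2))
    return features / norm

state = preprocess_amplitudes(np.ones(8))
# state is unit-norm regardless of the input's original norm
```

Already-normalized inputs are unchanged (dividing by a norm of 1), so dropping the `if` statement only costs a redundant division in that case.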


if qml.tape_mode_active():
    arg = qml.math.angle(arg)
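For context, `qml.math.angle` extracts the complex phase of its argument; a plain-NumPy sketch of the same computation:

```python
import numpy as np

# np.angle returns the complex phase of each entry; real positive inputs
# have phase 0, which suggests why an `else` branch may not be needed here.
arg = np.angle(np.array([1.0, 1j, -1.0]))
# phases are [0, pi/2, pi]
```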

Contributor

Else?

Member

Looking at the 'before', it doesn't seem to be needed 🤔

@mariaschuld, can you remember why we need this in tape mode, but not in non-tape mode?

@@ -64,6 +66,79 @@ def wires_all_to_all(wires):
###################


def _preprocess(parameters, pattern, wires):
    """Validate and pre-process inputs."""
Contributor

Add function argument documentation?

Contributor Author

True, that would be good. I was sloppy here...

Contributor Author

@josh146 you want me to add them?

Contributor Author

We could also be more precise about what preprocessing does in each case...

Member

Yeah, actually, it might be a good idea. Sorry I'm so used to this PR, I didn't even notice the docstrings were missing!

Contributor
@chaserileyroberts left a comment

LGTM. My only comments concern the breakages in TF graph mode or under jax.jit.

    return qml.expval(qml.PauliZ(0))

qnode = qml.QNode(circuit, dev, interface="jax", diff_method="parameter-shift")
weights = jnp.array(qml.init.strong_ent_layers_normal(n_wires=2, n_layers=2))
Contributor Author

Interesting. Why this change? Just to remove a dependency?

Member
josh146 commented Jan 4, 2021

Yep, this test is testing JAX differentiation, so it shouldn't need to depend on extra template code. Narrowing the test ensures it passes whenever JAX differentiation works as intended; this way, if the test fails, we know for sure why (as opposed to it potentially being a template issue).

Whether templates can be differentiated is instead covered by the template integration tests.
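The narrowing principle can be sketched in isolation (names hypothetical; plain NumPy standing in for the autodiff machinery): check the gradient of a function whose derivative is known analytically, rather than routing it through unrelated template code:

```python
import numpy as np

def f(x):
    # function under test, with a known analytic derivative: d/dx cos(x) = -sin(x)
    return np.cos(x)

def numerical_grad(fn, x, eps=1e-6):
    # central finite difference; stands in for the differentiation machinery
    return (fn(x + eps) - fn(x - eps)) / (2 * eps)

g = numerical_grad(f, 0.5)
# if g disagrees with -sin(0.5), the differentiation itself is at fault,
# not some template the test happened to route through
```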

@josh146 josh146 marked this pull request as ready for review January 4, 2021 07:08
@github-actions
Contributor

github-actions bot commented Jan 4, 2021

Hello. You may have forgotten to update the changelog!
Please edit .github/CHANGELOG.md with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.

@josh146 josh146 changed the title [WIP] Bringing templates to feature parity in tape mode Bringing templates to feature parity in tape mode Jan 4, 2021
@mariaschuld mariaschuld merged commit 09a6239 into master Jan 5, 2021
@mariaschuld mariaschuld deleted the templates_tape_feature_parity branch January 5, 2021 06:26