Add support for @tf.function on non-TF devices #1886

Merged · 27 commits · Nov 18, 2021

Conversation

@josh146 josh146 commented Nov 10, 2021

Context: Currently, TensorFlow's AutoGraph mode (enabled by decorating QNodes with @tf.function) only works with backpropagation mode. This is because, for devices that do not support TensorFlow natively, we define a custom gradient interface that calls arbitrary Python code for execution.

Description of the Change:

  • Adds a new interface, tf-autograph, that is rewritten using tf.numpy_function to execute quantum devices. This converts the Python execution into a graph node, allowing the computation to be converted to a TF graph (a minimal sketch of this pattern follows this list).

  • A few changes to vjp.py are needed to make it graph-compatible. In particular, when TF traces the QNode execution, tensors have no concrete value or shape, so the size of the dy vector must be passed explicitly to compute_vjp.
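
As an illustration of the mechanism, here is a minimal, self-contained sketch of how tf.numpy_function turns arbitrary Python code into a graph node. This is not the actual PennyLane implementation; run_device is a hypothetical stand-in for the device execution:

import numpy as np
import tensorflow as tf

def run_device(params):
    # Hypothetical stand-in for a quantum device execution: arbitrary
    # Python/NumPy code that TensorFlow cannot trace directly.
    return np.cos(params)

@tf.function
def execute(params):
    # tf.numpy_function wraps the Python call in a graph node. The output
    # dtype must be declared ahead of time (here float64), which is why only
    # float64 measurement statistics are currently supported.
    res = tf.numpy_function(func=run_device, inp=[params], Tout=tf.float64)
    # During tracing the result has no shape, so it is set explicitly -
    # analogous to passing the dy size to compute_vjp.
    res.set_shape(params.shape)
    return res

print(execute(tf.constant([0.543, -0.654], dtype=tf.float64)))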

Benefits:

  • When decorating a QNode (or cost function) with @tf.function, PennyLane automatically detects this and applies the new tf-autograph interface, so there are no special requirements from the user's perspective (see the short sketch after this list).

  • Jacobians and nth-order derivatives continue to work as expected.
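
One way this kind of detection can work in TensorFlow is sketched below; resolve_tf_interface is a hypothetical helper, not necessarily the exact logic PennyLane uses:

import tensorflow as tf

def resolve_tf_interface():
    # While a @tf.function is being traced, eager execution is disabled,
    # which signals that the graph-compatible interface is required.
    if not tf.executing_eagerly():
        return "tf-autograph"
    return "tf"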

For example, consider the following QNode:

import tensorflow as tf
import pennylane as qml

dev = qml.device("default.qubit", wires=2)
x = tf.Variable(0.543, dtype=tf.float64)
y = tf.Variable(-0.654, dtype=tf.float64)

@qml.beta.qnode(dev, diff_method="parameter-shift", max_diff=2, interface="tf")
def circuit(x, y):
    qml.RX(x, wires=[0])
    qml.RY(y, wires=[1])
    qml.CNOT(wires=[0, 1])
    return qml.probs(wires=[0]), qml.probs(wires=[1])

We can autograph it:

circuit_autograph = tf.function(circuit)

Comparing execution times:

>>> %timeit circuit(x, y)
1.72 ms ± 88.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
>>> %timeit circuit_autograph(x, y)
879 µs ± 68.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Comparing gradient execution times for a given cost function:

>>> def cost(circuit, x, y):
...     res = circuit(x, y)
...     return res[0, 0] - res[1, 1]
>>> %%timeit
... with tf.GradientTape() as tape:
...     loss = cost(circuit, x, y)
... grad = tape.gradient(loss, [x, y])
12.3 ms ± 527 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
>>> %%timeit
... with tf.GradientTape() as tape:
...     loss = cost(circuit_autograph, x, y)
... grad = tape.gradient(loss, [x, y])
4 ms ± 218 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Note that, while a large improvement, native backprop with autograph still beats both:

>>> @tf.function
... @qml.beta.qnode(dev, diff_method="backprop", interface="tf")
... def circuit_backprop_autograph(x, y):
...     qml.RX(x, wires=[0])
...     qml.RY(y, wires=[1])
...     qml.CNOT(wires=[0, 1])
...     return qml.probs(wires=[0]), qml.probs(wires=[1])
>>> %%timeit
... with tf.GradientTape() as tape:
...     loss = cost(circuit_backprop_autograph, x, y)
... grad = tape.gradient(loss, [x, y])
2.47 ms ± 459 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Possible Drawbacks:

  • The initial tracing stage is significantly slower.

  • QNodes that return combinations of samples and other measurement types are not supported.

  • numpy_function must know the returned output types ahead of time, which means that only float64 returns for measurement statistics are currently supported.

  • @tf.function(jit_compile=True) is not supported.

Related GitHub Issues: n/a

@github-actions

Hello. You may have forgotten to update the changelog!
Please edit doc/releases/changelog-dev.md with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.

codecov bot commented Nov 11, 2021

Codecov Report

Merging #1886 (62832bf) into master (4078456) will increase coverage by 0.00%.
The diff coverage is 100.00%.

@@           Coverage Diff           @@
##           master    #1886   +/-   ##
=======================================
  Coverage   98.84%   98.84%           
=======================================
  Files         220      221    +1     
  Lines       16843    16937   +94     
=======================================
+ Hits        16648    16742   +94     
  Misses        195      195           
Impacted Files Coverage Δ
pennylane/gradients/vjp.py 100.00% <100.00%> (ø)
pennylane/interfaces/batch/__init__.py 100.00% <100.00%> (ø)
pennylane/interfaces/batch/autograd.py 100.00% <100.00%> (ø)
pennylane/interfaces/batch/jax.py 100.00% <100.00%> (ø)
pennylane/interfaces/batch/tensorflow.py 100.00% <100.00%> (ø)
pennylane/interfaces/batch/tensorflow_autograph.py 100.00% <100.00%> (ø)
pennylane/interfaces/batch/torch.py 100.00% <100.00%> (ø)

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@josh146 changed the title from [WIP] Add support for @tf.function on non-TF devices to Add support for @tf.function on non-TF devices Nov 11, 2021
@josh146 added the review-ready 👌 label Nov 12, 2021
@AmintorDusko AmintorDusko self-requested a review November 12, 2021 13:38

AmintorDusko commented Nov 15, 2021

Hi @josh146. I checked the code changes and everything looks OK.
Now I'm working on some benchmark tests.

I'm defining a QNode in the same way you did in your example, and I'm getting:

%timeit circuit(x, y)
1.33 ms ± 15 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit circuit_autograph(x, y)
1.48 ms ± 35.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In the last one I'm getting a warning:

WARNING:tensorflow:@custom_gradient grad_fn has 'variables' in signature, but no ResourceVariables were used on the forward pass.

> Comparing gradient execution times for a given cost function:
>
> >>> def cost(circuit, x, y):
> ...     res = circuit(x, y)
> ...     return res[0, 0] - res[1, 1]
> >>> %%timeit
> ... with tf.GradientTape() as tape:
> ...     loss = cost(circuit, x, y)
> ... grad = tape.gradient(loss, [x, y])
> 12.3 ms ± 527 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Here I'm getting:

8.81 ms ± 298 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

> >>> %%timeit
> ... with tf.GradientTape() as tape:
> ...     loss = cost(circuit_autograph, x, y)
> ... grad = tape.gradient(loss, [x, y])
> 4 ms ± 218 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

But here I'm getting an error:

TypeError: 'Operation' object is not subscriptable

I can provide you a traceback of this error if you like.

> Note that, while a large improvement, native backprop with autograph still beats both:
>
> >>> @tf.function
> ... @qml.beta.qnode(dev, diff_method="backprop", interface="tf")
> ... def circuit_backprop_autograph(x, y):
> ...     qml.RX(x, wires=[0])
> ...     qml.RY(y, wires=[1])
> ...     qml.CNOT(wires=[0, 1])
> ...     return qml.probs(wires=[0]), qml.probs(wires=[1])
> >>> %%timeit
> ... with tf.GradientTape() as tape:
> ...     loss = cost(circuit_backprop_autograph, x, y)
> ... grad = tape.gradient(loss, [x, y])
> 2.47 ms ± 459 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

If, on the other hand, I repeat these benchmark tests with diff_method="backprop", no error or warning appears, and I observe a nice performance increase when decorating the QNode with @tf.function:

%timeit -n 1000 circuit(x, y)
3.82 ms ± 52.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit -n 1000 circuit_autograph(x, y)
196 µs ± 13.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
20 times faster!

%%timeit
    with tf.GradientTape() as tape:
        loss = cost(circuit, x, y)
    grad = tape.gradient(loss, [x, y])

8.87 ms ± 28 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%%timeit
    with tf.GradientTape() as tape:
        loss = cost(circuit_autograph, x, y)
    grad = tape.gradient(loss, [x, y])

2.07 ms ± 383 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

@tf.function
@qml.beta.qnode(dev, diff_method="backprop", interface="tf")
def circuit_backprop_autograph(x, y):
    qml.RX(x, wires=[0])
    qml.RY(y, wires=[1])
    qml.CNOT(wires=[0, 1])
    return qml.probs(wires=[0]), qml.probs(wires=[1])
%%timeit
    with tf.GradientTape() as tape:
        loss = cost(circuit_backprop_autograph, x, y)
    grad = tape.gradient(loss, [x, y])

2.08 ms ± 502 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
These last two benchmarks are equivalent, as one would expect.

Let me know your thoughts about it.
Maybe there is something I need to set to make full use of the new tf-autograph interface.

@mlxd mlxd requested a review from maliasadi November 15, 2021 17:54

josh146 commented Nov 16, 2021

Thanks @AmintorDusko for taking a look!

> I'm defining a QNode in the same way you did in your example, and I'm getting:

That is odd, I cannot recreate that 🤔 For me, autograph is always faster. When you are performing the timing, are you including the initial compilation time or not?

> In the last one I'm getting a warning:

If you know how to get rid of this warning, that would be very much appreciated! I could not find a way to remove it.

> But here I'm getting an error:
>
> TypeError: 'Operation' object is not subscriptable

Ah, I also received this at one point. This could be because you are first calling/compiling circuit_autograph outside a GradientTape, and later calling it again inside a GradientTape. I discovered that you need to compile the function within a GradientTape if you want it to be differentiable later.
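
A minimal sketch of what this means, reusing cost and circuit_autograph = tf.function(circuit) from the PR description (illustration only, not part of this PR):

with tf.GradientTape() as tape:
    # The first call triggers tracing/compilation *inside* the tape,
    # so the resulting graph remains differentiable on later calls.
    loss = cost(circuit_autograph, x, y)
grad = tape.gradient(loss, [x, y])

# Calling circuit_autograph(x, y) once outside any GradientTape before
# differentiating is what can lead to the
# "TypeError: 'Operation' object is not subscriptable" error above.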

> Note that, while a large improvement, native backprop with autograph still beats both:

Yep! This is expected, I think.

@maliasadi maliasadi left a comment

Great PR! 🚀 From my experience with tf.Graph, I think it should be documented somewhere that using @tf.function may not necessarily bring any speed-up: it tends to be faster than eager execution for graphs with many small operations, but not for graphs with only a few expensive operations. For this PR, I only have a few minor comments.

pennylane/gradients/vjp.py (resolved)
@@ -23,7 +23,9 @@
 from pennylane import numpy as np


-def execute(tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_diff=2):
+def execute(
+    tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_diff=2, mode="backward"
@maliasadi commented:

It seems that mode is never used in this function, so I guess you can remove it from the list of Args. Oh, wait! This is a required argument for execute in ./tensorflow_autograph.py, and so you added this unused arg to the execute function of all the other interfaces. There are two concerns here:

  • Is this really the best approach to tackle this incompatibility issue?
  • Why is the default mode always "backward" for all autograd, jax, tensorflow, and torch?

I suppose the mode description is also missing from the list of Args.

@josh146 (author) replied:

Good catch @maliasadi, I've made the following changes:

  • mode is documented in all interfaces
  • mode=None is the new default.

Is this really the best approach to tackle this incompatibility issue?

I thought about this a lot. I think it was something that was missed from the original interfaces, and should be included: the existing interfaces currently query the shape of the forward-pass result to implicitly determine the mode, which is not ideal.
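
As a rough sketch of the updated signature (based on the diff above; the docstring wording is illustrative, not the actual PennyLane text):

def execute(
    tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_diff=2, mode=None
):
    """Execute a batch of tapes with a given gradient function.

    Args:
        ...
        mode (str): whether gradients are computed during the forward pass
            ("forward") or on the backward pass ("backward"); if None, an
            interface-appropriate default is chosen.
    """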

@josh146 josh146 merged commit 8b313e4 into master Nov 18, 2021
@josh146 josh146 deleted the tf-pyfunc branch November 18, 2021 13:59
Labels: review-ready 👌 (PRs which are ready for review by someone from the core team)

3 participants