Add support for @tf.function on non-TF devices #1886

Merged · 27 commits · Nov 18, 2021
45 changes: 43 additions & 2 deletions doc/releases/changelog-dev.md
@@ -4,6 +4,46 @@

<h3>New features since last release</h3>

* It is now possible to use TensorFlow's [AutoGraph
mode](https://www.tensorflow.org/guide/function) with QNodes on all devices and with arbitrary
differentiation methods. Previously, AutoGraph mode only supported `diff_method="backprop"`. This
change results in significantly more performant model execution, at the cost of a more expensive
initial compilation. [(#1886)](https://github.com/PennyLaneAI/pennylane/pull/1886)

Use AutoGraph to convert your QNodes or cost functions into TensorFlow
graphs by decorating them with `@tf.function`:

```python
import pennylane as qml
import tensorflow as tf

dev = qml.device("lightning.qubit", wires=2)

@qml.beta.qnode(dev, diff_method="adjoint", interface="tf", max_diff=1)
def circuit(x):
qml.RX(x[0], wires=0)
qml.RY(x[1], wires=1)
return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1)), qml.expval(qml.PauliZ(0))

@tf.function
def cost(x):
return tf.reduce_sum(circuit(x))

x = tf.Variable([0.5, 0.7], dtype=tf.float64)

with tf.GradientTape() as tape:
loss = cost(x)

grad = tape.gradient(loss, x)
```

The initial execution may take slightly longer than when executing the circuit in
eager mode; this is because TensorFlow is tracing the function to create the graph.
Subsequent executions will be much more performant.

Note that using AutoGraph with backprop-enabled devices, such as `default.qubit`,
will yield the best performance.

For more details, please see the [TensorFlow AutoGraph
documentation](https://www.tensorflow.org/guide/function).

* `qml.math.scatter_element_add` now supports adding multiple values at
multiple indices with a single function call, in all interfaces
[(#1864)](https://github.com/PennyLaneAI/pennylane/pull/1864)
@@ -231,5 +271,6 @@

This release contains contributions from (in alphabetical order):

Guillermo Alonso-Linaje, Benjamin Cordier, Olivia Di Matteo, Jalani Kanem, Ankit Khandelwal, Shumpei Kobayashi,
Christina Lee, Alejandro Montanez, Romain Moyard, Maria Schuld, Jay Soni, David Wierichs
Guillermo Alonso-Linaje, Benjamin Cordier, Olivia Di Matteo, Josh Izaac,
Jalani Kanem, Ankit Khandelwal, Shumpei Kobayashi, Christina Lee, Alejandro Montanez,
Romain Moyard, Maria Schuld, Jay Soni, David Wierichs
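
As a side note on the changelog remark above about backprop-enabled devices giving the best AutoGraph performance, here is a minimal sketch of that variant. It assumes `default.qubit` with `diff_method="backprop"` and the standard `qml.qnode` decorator (not part of this diff); the first call traces the graph, later calls reuse it.

```python
import pennylane as qml
import tensorflow as tf

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev, diff_method="backprop", interface="tf")
def circuit(x):
    qml.RX(x[0], wires=0)
    qml.RY(x[1], wires=1)
    return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))

@tf.function
def cost(x):
    return circuit(x)

x = tf.Variable([0.5, 0.7], dtype=tf.float64)

with tf.GradientTape() as tape:
    loss = cost(x)          # first call: TensorFlow traces the function (slower)
grad = tape.gradient(loss, x)

with tf.GradientTape() as tape:
    loss = cost(x)          # later calls reuse the traced graph (faster)
```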
25 changes: 17 additions & 8 deletions pennylane/gradients/vjp.py
@@ -20,7 +20,7 @@
from pennylane import math


def compute_vjp(dy, jac):
def compute_vjp(dy, jac, num=None):
"""Convenience function to compute the vector-Jacobian product for a given
vector of gradient outputs and a Jacobian.

@@ -29,6 +29,9 @@ def compute_vjp(dy, jac):
jac (tensor_like): Jacobian matrix. For an n-dimensional ``dy``
vector, the first n-dimensions of ``jac`` should match
the shape of ``dy``.
num (int): The length of the flattened ``dy`` argument. This is an
optional argument, but can be useful to provide if ``dy`` potentially
has no shape (for example, due to tracing or just-in-time compilation).

Returns:
tensor_like: the vector-Jacobian product
@@ -38,10 +41,13 @@ def compute_vjp(dy, jac):

dy_row = math.reshape(dy, [-1])

if num is None:
num = math.shape(dy_row)[0]

if not isinstance(dy_row, np.ndarray):
jac = math.convert_like(jac, dy_row)

jac = math.reshape(jac, [dy_row.shape[0], -1])
jac = math.reshape(jac, [num, -1])

try:
if math.allclose(dy, 0):
@@ -156,23 +162,23 @@ def vjp(tape, dy, gradient_fn, gradient_kwargs=None):
if num_params == 0:
# The tape has no trainable parameters; the VJP
# is simply none.
return [], lambda _: None
return [], lambda _, num=None: None

try:
if math.allclose(dy, 0):
# If the dy vector is zero, then the
# corresponding element of the VJP will be zero,
# and we can avoid a quantum computation.
return [], lambda _: math.convert_like(np.zeros([num_params]), dy)
return [], lambda _, num=None: math.convert_like(np.zeros([num_params]), dy)
except (AttributeError, TypeError):
pass

gradient_tapes, fn = gradient_fn(tape, **gradient_kwargs)

def processing_fn(results):
def processing_fn(results, num=None):
# postprocess results to compute the Jacobian
jac = fn(results)
return compute_vjp(dy, jac)
return compute_vjp(dy, jac, num=num)

return gradient_tapes, processing_fn

@@ -304,18 +310,21 @@ def ansatz(x):
processing_fns.append(fn)
gradient_tapes.extend(g_tapes)

def processing_fn(results):
def processing_fn(results, nums=None):
vjps = []
start = 0

if nums is None:
nums = [None] * len(tapes)

for t_idx in range(len(tapes)):
# extract the correct results from the flat list
res_len = reshape_info[t_idx]
res_t = results[start : start + res_len]
start += res_len

# postprocess results to compute the VJP
vjp_ = processing_fns[t_idx](res_t)
vjp_ = processing_fns[t_idx](res_t, num=nums[t_idx])

if vjp_ is None:
if reduction == "append":
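
A small usage sketch of the new `num` keyword on `compute_vjp`, shown here with plain NumPy inputs. The import path mirrors the file above; the printed result is an assumption based on the standard vector-Jacobian product `dy @ jac`.

```python
import numpy as np
from pennylane.gradients.vjp import compute_vjp

dy = np.array([1.0, 2.0])            # gradient of the cost w.r.t. two tape outputs
jac = np.array([[0.1, 0.2, 0.3],
                [0.4, 0.5, 0.6]])    # Jacobian of shape (num outputs, num params)

# Passing num explicitly mirrors the tf.function case, where dy may be a
# traced tensor whose static shape is unknown at graph-construction time.
vjp_val = compute_vjp(dy, jac, num=2)
print(vjp_val)  # expected: dy @ jac = [0.9, 1.2, 1.5]
```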
16 changes: 13 additions & 3 deletions pennylane/interfaces/batch/__init__.py
@@ -31,7 +31,7 @@
"Autograd": ("autograd", "numpy"), # for backwards compatibility
"JAX": ("jax", "JAX"),
"PyTorch": ("torch", "pytorch"),
"TensorFlow": ("tf", "tensorflow"),
"TensorFlow": ("tf", "tensorflow", "tensorflow-autograph", "tf-autograph"),
}
"""dict[str, str]: maps allowed interface strings to the name of the interface"""

@@ -322,6 +322,7 @@ def cost_fn(params, x):

# the default execution function is batch_execute
execute_fn = cache_execute(batch_execute, cache, expand_fn=expand_fn)
_mode = "backward"

if gradient_fn == "device":
# gradient function is a device method
@@ -338,6 +339,7 @@ def cost_fn(params, x):
# both results and gradients
execute_fn = set_shots(device, override_shots)(device.execute_and_gradients)
gradient_fn = None
_mode = "forward"

elif mode == "backward":
# disable caching on the forward pass
@@ -361,7 +363,13 @@ def cost_fn(params, x):
if interface in INTERFACE_NAMES["Autograd"]:
from .autograd import execute as _execute
elif interface in INTERFACE_NAMES["TensorFlow"]:
from .tensorflow import execute as _execute
import tensorflow as tf

if not tf.executing_eagerly() or "autograph" in interface:
from .tensorflow_autograph import execute as _execute
else:
from .tensorflow import execute as _execute

elif interface in INTERFACE_NAMES["PyTorch"]:
from .torch import execute as _execute
elif interface in INTERFACE_NAMES["JAX"]:
@@ -379,6 +387,8 @@ def cost_fn(params, x):
f"version of {interface_name} to enable the '{interface}' interface."
) from e

res = _execute(tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_diff=max_diff)
res = _execute(
tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_diff=max_diff, mode=_mode
)

return res
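
To illustrate the dispatch condition added above, here is a standalone sketch; `chosen_executor` is a hypothetical helper that mirrors the `if not tf.executing_eagerly() or "autograph" in interface` branch, and it only demonstrates that eager execution is disabled while a `tf.function` is being traced.

```python
import tensorflow as tf

def chosen_executor(interface="tf"):
    # Mirror of the branch above: pick the autograph-aware executor when we
    # are not running eagerly (e.g. inside a tf.function trace), or when the
    # interface string explicitly asks for it.
    if not tf.executing_eagerly() or "autograph" in interface:
        return "tensorflow_autograph.execute"
    return "tensorflow.execute"

print(chosen_executor())                # eager mode -> tensorflow.execute
print(chosen_executor("tf-autograph"))  # explicit request -> tensorflow_autograph.execute

selected = []

@tf.function
def traced():
    # tf.executing_eagerly() returns False while this body is being traced.
    selected.append(chosen_executor())

traced()
print(selected)                         # ['tensorflow_autograph.execute']
```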
5 changes: 4 additions & 1 deletion pennylane/interfaces/batch/autograd.py
@@ -23,7 +23,7 @@
from pennylane import numpy as np


def execute(tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_diff=2):
def execute(tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_diff=2, mode=None):
"""Execute a batch of tapes with Autograd parameters on a device.

Args:
@@ -44,11 +44,14 @@ def execute(tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_d
the maximum order of derivatives to support. Increasing this value allows
for higher order derivatives to be extracted, at the cost of additional
(classical) computational overhead during the backwards pass.
mode (str): Whether the gradients should be computed on the forward
pass (``forward``) or the backward pass (``backward``).

Returns:
list[list[float]]: A nested list of tape results. Each element in
the returned list corresponds in order to the provided tapes.
"""
# pylint: disable=unused-argument
for tape in tapes:
# set the trainable parameters
params = tape.get_parameters(trainable_only=False)
5 changes: 4 additions & 1 deletion pennylane/interfaces/batch/jax.py
@@ -27,7 +27,7 @@
dtype = jnp.float64


def execute(tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_diff=1):
def execute(tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_diff=1, mode=None):
"""Execute a batch of tapes with JAX parameters on a device.

Args:
@@ -48,11 +48,14 @@ def execute(tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_d
the maximum order of derivatives to support. Increasing this value allows
for higher order derivatives to be extracted, at the cost of additional
(classical) computational overhead during the backwards pass.
mode (str): Whether the gradients should be computed on the forward
pass (``forward``) or the backward pass (``backward``).

Returns:
list[list[float]]: A nested list of tape results. Each element in
the returned list corresponds in order to the provided tapes.
"""
# pylint: disable=unused-argument
if max_diff > 1:
raise ValueError("The JAX interface only supports first order derivatives.")

5 changes: 4 additions & 1 deletion pennylane/interfaces/batch/tensorflow.py
@@ -39,7 +39,7 @@ def _compute_vjp(dy, jacs):
return vjps


def execute(tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_diff=2):
def execute(tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_diff=2, mode=None):
"""Execute a batch of tapes with TensorFlow parameters on a device.

Args:
@@ -60,11 +60,14 @@ def execute(tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_d
the maximum number of derivatives to support. Increasing this value allows
for higher order derivatives to be extracted, at the cost of additional
(classical) computational overhead during the backwards pass.
mode (str): Whether the gradients should be computed on the forward
pass (``forward``) or the backward pass (``backward``).

Returns:
list[list[tf.Tensor]]: A nested list of tape results. Each element in
the returned list corresponds in order to the provided tapes.
"""
# pylint: disable=unused-argument

parameters = []
params_unwrapped = []
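
The remaining interface files only add an unused `mode` keyword (with `# pylint: disable=unused-argument`). A hypothetical sketch of why that is useful: keeping every interface's `execute` signature-compatible lets the dispatcher pass `mode` unconditionally, whether or not a given interface acts on it.

```python
def execute_autograd(tapes, mode=None):  # pylint: disable=unused-argument
    # This interface ignores `mode`, but still accepts it.
    return [f"autograd result for {t}" for t in tapes]

def execute_tf_autograph(tapes, mode=None):
    # This interface actually changes behaviour based on `mode`.
    return [f"tf-autograph ({mode}) result for {t}" for t in tapes]

EXECUTORS = {"autograd": execute_autograd, "tf-autograph": execute_tf_autograph}

def dispatch(interface, tapes, mode="backward"):
    # The dispatcher never needs to know which executors use `mode`.
    return EXECUTORS[interface](tapes, mode=mode)

print(dispatch("autograd", ["tape0"]))
print(dispatch("tf-autograph", ["tape0"], mode="forward"))
```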