Add caching to the autograd batch interface #1508

Merged · 66 commits · Aug 20, 2021
Changes from 64 commits
Commits (66)
0c57919
Added differentiable VJP transform
josh146 Aug 4, 2021
674604b
linting
josh146 Aug 4, 2021
688f4a2
more tests
josh146 Aug 4, 2021
9a8476b
linting
josh146 Aug 5, 2021
0413307
add tests
josh146 Aug 5, 2021
6b44284
add comment
josh146 Aug 5, 2021
35e1848
fix
josh146 Aug 5, 2021
67e216a
more
josh146 Aug 5, 2021
d0e40f8
typos
josh146 Aug 5, 2021
89bdd8d
Apply suggestions from code review
josh146 Aug 6, 2021
f415d9f
fixes
josh146 Aug 9, 2021
a4592da
Merge branch 'master' into vjp-transform
josh146 Aug 9, 2021
9153c45
merge
josh146 Aug 9, 2021
6c5dc72
add tests
josh146 Aug 9, 2021
11f20b3
more tests
josh146 Aug 9, 2021
122194c
renamed
josh146 Aug 9, 2021
e98c835
typo
josh146 Aug 9, 2021
5956967
Add caching to the autograd backend
josh146 Aug 9, 2021
8e3159f
more
josh146 Aug 9, 2021
3cbfc22
Merge branch 'master' into vjp-transform
mariaschuld Aug 9, 2021
b36ec30
more
josh146 Aug 9, 2021
3bd36bf
more
josh146 Aug 10, 2021
d644228
more
josh146 Aug 10, 2021
81bd371
caching
josh146 Aug 10, 2021
9a19ce2
fix
josh146 Aug 10, 2021
44ca01d
fix
josh146 Aug 10, 2021
b4bb9d2
fix tests
josh146 Aug 10, 2021
102d551
final
josh146 Aug 10, 2021
55be8f2
update changelog
josh146 Aug 11, 2021
49412da
update
josh146 Aug 11, 2021
b5e4665
merge master
josh146 Aug 11, 2021
11ebfe1
Merge branch 'batch-autograd' into autograd-caching
josh146 Aug 11, 2021
ff2ecb0
more
josh146 Aug 11, 2021
efa7c49
revert formatting
josh146 Aug 11, 2021
4f8342a
more
josh146 Aug 11, 2021
0818184
add tests
josh146 Aug 11, 2021
815e1f3
linting
josh146 Aug 11, 2021
59572bf
Merge branch 'master' into vjp-transform
josh146 Aug 11, 2021
378bcd4
merge master
josh146 Aug 11, 2021
1ca227a
merge master
josh146 Aug 11, 2021
1942afb
Merge branch 'batch-autograd' into autograd-caching
josh146 Aug 11, 2021
96b567e
Apply suggestions from code review
josh146 Aug 12, 2021
2e5e9a9
fix
josh146 Aug 12, 2021
0ad93f6
Merge branch 'autograd-caching' of github.com:PennyLaneAI/pennylane i…
josh146 Aug 12, 2021
2057b86
Apply suggestions from code review
josh146 Aug 15, 2021
6d77f3e
more
josh146 Aug 16, 2021
c1ccb0d
linting
josh146 Aug 16, 2021
6aebd37
linting
josh146 Aug 16, 2021
3e0c909
Merge branch 'batch-autograd' into autograd-caching
josh146 Aug 16, 2021
2f7aeac
merge master
josh146 Aug 16, 2021
c540c53
linting
josh146 Aug 16, 2021
5fb9a4b
Merge branch 'master' into autograd-caching
josh146 Aug 17, 2021
cbbb5f0
remove pass
josh146 Aug 17, 2021
9e749ef
Merge branch 'autograd-caching' of github.com:PennyLaneAI/pennylane i…
josh146 Aug 17, 2021
57b747a
Merge branch 'master' into autograd-caching
josh146 Aug 17, 2021
029e5bf
Merge branch 'master' into autograd-caching
josh146 Aug 17, 2021
64e0dd1
changelog
josh146 Aug 17, 2021
b7a58cf
Merge branch 'autograd-caching' of github.com:PennyLaneAI/pennylane i…
josh146 Aug 17, 2021
77e5df1
Apply suggestions from code review
josh146 Aug 18, 2021
3d2b9b6
Update pennylane/interfaces/batch/__init__.py
josh146 Aug 18, 2021
34b379b
Merge branch 'master' into autograd-caching
josh146 Aug 18, 2021
aec3cc0
Merge branch 'master' into autograd-caching
josh146 Aug 18, 2021
9e0eb7a
Add hashing tests
josh146 Aug 18, 2021
142a662
Merge branch 'master' into autograd-caching
josh146 Aug 19, 2021
ec0bf60
Merge branch 'master' into autograd-caching
josh146 Aug 20, 2021
354aec9
Apply suggestions from code review
josh146 Aug 20, 2021
11 changes: 9 additions & 2 deletions .github/CHANGELOG.md
@@ -57,12 +57,15 @@
```

* Support for differentiable execution of batches of circuits has been
added, via the beta `pennylane.batch` module.
added, via the beta `pennylane.interfaces.batch` module.
[(#1501)](https://github.com/PennyLaneAI/pennylane/pull/1501)
[(#1508)](https://github.com/PennyLaneAI/pennylane/pull/1508)

For example:

```python
from pennylane.interfaces.batch import execute

def cost_fn(x):
with qml.tape.JacobianTape() as tape1:
qml.RX(x[0], wires=[0])
@@ -76,7 +79,11 @@
qml.CNOT(wires=[0, 1])
qml.probs(wires=1)

result = execute([tape1, tape2], dev, gradient_fn=param_shift)
result = execute(
[tape1, tape2], dev,
gradient_fn=qml.gradients.param_shift,
interface="autograd"
)
return result[0] + result[1][0, 0]

res = qml.grad(cost_fn)(params)
160 changes: 152 additions & 8 deletions pennylane/interfaces/batch/__init__.py
@@ -15,13 +15,137 @@
This subpackage defines functions for interfacing devices' batch execution
capabilities with different machine learning libraries.
"""
# pylint: disable=import-outside-toplevel,too-many-arguments
# pylint: disable=import-outside-toplevel,too-many-arguments,too-many-branches
from functools import wraps

from cachetools import LRUCache
import numpy as np

import pennylane as qml

from .autograd import execute as execute_autograd


def execute(tapes, device, gradient_fn, interface="autograd", mode="best", gradient_kwargs=None):
def cache_execute(fn, cache, pass_kwargs=False, return_tuple=True):
"""Decorator that adds caching to a function that executes
multiple tapes on a device.

This decorator makes use of :attr:`.QuantumTape.hash` to identify
unique tapes.

- If a tape does not match a hash in the cache, then the tape
has not been previously executed. It is executed, and the result
added to the cache.

- If a tape matches a hash in the cache, then the tape has been previously
executed. The corresponding cached result is
extracted, and the tape is not passed to the execution function.

- Finally, there might be the case where one or more tapes in the current
set of tapes to be executed are identical and thus share a hash. If this is the case,
duplicates are removed, to avoid redundant evaluations.

Args:
fn (callable): The execution function to add caching to.
This function should have the signature ``fn(tapes, **kwargs)``,
and it should return ``list[tensor_like]``, with the
same length as the input ``tapes``.
cache (None or dict or Cache): The cache to use. If ``None``,
caching will not occur.
pass_kwargs (bool): If ``True``, keyword arguments passed to the
wrapped function will be passed directly to ``fn``. If ``False``,
they will be ignored.
return_tuple (bool): If ``True``, the output of ``fn`` is returned
as a tuple ``(fn_output, [])``, to match the output of execution functions
that also return gradients.

Returns:
function: a wrapped version of the execution function ``fn`` with caching
support
"""

@wraps(fn)
def wrapper(tapes, **kwargs):

if not pass_kwargs:
kwargs = {}

if cache is None or (isinstance(cache, bool) and not cache):
# No caching. Simply execute the execution function
# and return the results.
res = fn(tapes, **kwargs)
return (res, []) if return_tuple else res

execution_tapes = {}
cached_results = {}
hashes = {}
repeated = {}

for i, tape in enumerate(tapes):
h = tape.hash

if h in hashes.values():
# Tape already exists within ``tapes``. Determine the
# index of the first occurrence of the tape, store this,
# and continue to the next iteration.
idx = list(hashes.keys())[list(hashes.values()).index(h)]
repeated[i] = idx
continue

hashes[i] = h

if hashes[i] in cache:
# Tape exists within the cache, store the cached result
cached_results[i] = cache[hashes[i]]
else:
# Tape does not exist within the cache, store the tape
# for execution via the execution function.
execution_tapes[i] = tape

# if there are no execution tapes, simply return!
if not execution_tapes:
if not repeated:
res = list(cached_results.values())
return (res, []) if return_tuple else res

else:
# execute all unique tapes that do not exist in the cache
res = fn(execution_tapes.values(), **kwargs)

final_res = []

for i, tape in enumerate(tapes):
if i in cached_results:
# insert cached results into the results vector
final_res.append(cached_results[i])

elif i in repeated:
# insert repeated results into the results vector
final_res.append(final_res[repeated[i]])

else:
# insert evaluated results into the results vector
r = res.pop(0)
final_res.append(r)
cache[hashes[i]] = r
Contributor: Clever recombination :)

Member Author: Clever, but this ends up being slow if there is no tape duplication/caching 🤔

It seems that there is always a tug of war between reducing quantum evals, and reducing classical compute time


return (final_res, []) if return_tuple else final_res

wrapper.fn = fn
return wrapper
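
For illustration, a minimal sketch of how ``cache_execute`` might be used directly, assuming it is importable from ``pennylane.interfaces.batch`` as defined in this diff; the device, tapes, and plain ``dict`` cache below are illustrative only and not part of the PR:

```python
# Sketch only: wrap a device's batch_execute with the caching decorator above,
# using a plain dict as the cache.
import pennylane as qml
from pennylane.interfaces.batch import cache_execute

dev = qml.device("default.qubit", wires=1)

with qml.tape.JacobianTape() as tape1:
    qml.RX(0.1, wires=0)
    qml.expval(qml.PauliZ(0))

# tape2 is structurally identical to tape1, so it shares the same hash
with qml.tape.JacobianTape() as tape2:
    qml.RX(0.1, wires=0)
    qml.expval(qml.PauliZ(0))

cache = {}
cached_execute = cache_execute(dev.batch_execute, cache)

# Only the one unique tape is sent to the device; the duplicate's result is
# recombined from the hash bookkeeping, and the second element is the empty
# Jacobian placeholder returned when return_tuple=True (the default).
results, jacs = cached_execute([tape1, tape2])
```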


def execute(
tapes,
device,
gradient_fn,
interface="autograd",
mode="best",
gradient_kwargs=None,
cache=True,
cachesize=10000,
Contributor: Do we have an idea of the memory implications of this? 🤔

Member Author: Assuming you do not pass a cache object manually to the execute function, the cache will be created inside execute. What this means is that, as soon as execute has exited, the cache is out of scope and will be garbage collected by Python.

I am 99.99% sure of this, but don't know how to sanity check 😖

This is from the last time I tried to explore this: #1131 (comment)

Do you have any ideas on how to double check that the cache is deleted after execution?
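
One possible way to sanity-check the garbage-collection claim above (a sketch, not part of this PR): keep only a weak reference to a cache created inside a function scope and confirm it is dead once the function returns.

```python
# Sketch: verify that a cache created inside a function scope is collected
# once the function exits, observing it only through a weak reference.
import gc
import weakref

from cachetools import LRUCache


def run_with_internal_cache():
    cache = LRUCache(maxsize=10000, getsizeof=len)
    cache["some-hash"] = [0.5]   # stand-in for a cached tape result
    return weakref.ref(cache)    # only a weak reference escapes the scope


ref = run_with_internal_cache()
gc.collect()
assert ref() is None  # the cache has been garbage collected
```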

max_diff=2,
):
"""Execute a batch of tapes on a device in an autodifferentiable-compatible manner.

Args:
@@ -42,6 +166,13 @@ def execute(tapes, device, gradient_fn, interface="autograd", mode="best", gradi
pass.
gradient_kwargs (dict): dictionary of keyword arguments to pass when
determining the gradients of tapes
cache (bool): Whether to cache evaluations. This can result in
a significant reduction in quantum evaluations during gradient computations.
cachesize (int): the size of the cache
max_diff (int): If ``gradient_fn`` is a gradient transform, this option specifies
the maximum number of derivatives to support. Increasing this value allows
for higher order derivatives to be extracted, at the cost of additional
(classical) computational overhead during the backwards pass.
Contributor: well-explained!


Returns:
list[list[float]]: A nested list of tape results. Each element in
@@ -101,11 +232,15 @@ def cost_fn(params, x):
[ 0.01983384, -0.97517033, 0. ],
[ 0. , 0. , -0.95533649]])
"""
# Default execution function; simply call device.batch_execute
# and return no Jacobians.
execute_fn = lambda tapes, **kwargs: (device.batch_execute(tapes), [])
gradient_kwargs = gradient_kwargs or {}

if isinstance(cache, bool) and cache:
# cache=True: create a LRUCache object
cache = LRUCache(maxsize=cachesize, getsizeof=len)

# the default execution function is device.batch_execute
execute_fn = cache_execute(device.batch_execute, cache)
Contributor: Nice


if gradient_fn == "device":
# gradient function is a device method

@@ -116,8 +251,13 @@ def cost_fn(params, x):
gradient_fn = None

elif mode == "backward":
# disable caching on the forward pass
execute_fn = cache_execute(device.batch_execute, cache=None)

# replace the backward gradient computation
gradient_fn = device.gradients
gradient_fn = cache_execute(
device.gradients, cache, pass_kwargs=True, return_tuple=False
)
Comment on lines +254 to +260
Contributor: Probably my unfamiliarity with the recent changes, but do we expect to need caching for device-based gradients? I thought this was mainly for parameter shift.

Member Author: Caching is only needed for device-based gradients if mode="backward". Backward mode essentially means:

  • On the forward pass, only the cost function is computed
  • The gradients are only requested during backpropagation

This means that there will always be 1 additional eval required -- caching therefore reduces the number of evals by 1 😆

Worth it?

I mean, I'd expect 99% of users to use device gradients with mode="forward".

Contributor: Sounds good!

Contributor: Does this supersede #1341?

Member Author: No, this complements it for now 🙂

#1341 added the use_device_state keyword argument which instructs QubitDevice.adjoint_jacobian() to use the existing device state and avoid a redundant forward pass.

When mode="forward", we can pass this option:

execute(
    tapes,
    dev,
    gradient_fn="device",
    interface="torch",
    gradient_kwargs={"method": "adjoint_jacobian", "use_device_state": True},
    mode="forward"
)


elif mode == "forward":
# In "forward" mode, gradients are automatically handled
@@ -126,6 +266,10 @@
raise ValueError("Gradient transforms cannot be used with mode='forward'")

if interface == "autograd":
return execute_autograd(tapes, device, execute_fn, gradient_fn, gradient_kwargs)
res = execute_autograd(
tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_diff=max_diff
)
else:
raise ValueError(f"Unknown interface {interface}")

raise ValueError(f"Unknown interface {interface}")
return res
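
To tie the new keyword arguments together, a hedged end-to-end sketch of calling the extended ``execute`` with caching enabled and ``max_diff=2`` for second derivatives; the circuit and parameter values are illustrative only:

```python
# Sketch: differentiable batch execution with caching and second derivatives.
import pennylane as qml
from pennylane import numpy as np
from pennylane.interfaces.batch import execute

dev = qml.device("default.qubit", wires=2)


def cost_fn(x):
    with qml.tape.JacobianTape() as tape:
        qml.RX(x[0], wires=0)
        qml.RY(x[1], wires=1)
        qml.CNOT(wires=[0, 1])
        qml.expval(qml.PauliZ(1))

    # cache=True (the default) deduplicates the shifted tapes generated
    # while evaluating parameter-shift gradients.
    (res,) = execute(
        [tape],
        dev,
        gradient_fn=qml.gradients.param_shift,
        interface="autograd",
        cache=True,
        max_diff=2,
    )
    return res[0]


x = np.array([0.1, 0.2], requires_grad=True)
grad = qml.grad(cost_fn)(x)                 # first-order derivative
hess = qml.jacobian(qml.grad(cost_fn))(x)   # second order, enabled by max_diff=2
```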
62 changes: 49 additions & 13 deletions pennylane/interfaces/batch/autograd.py
@@ -25,7 +25,7 @@
from pennylane import numpy as np


def execute(tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1):
def execute(tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_diff=2):
"""Execute a batch of tapes with Autograd parameters on a device.

Args:
@@ -42,6 +42,10 @@ def execute(tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1):
gradient_fn (callable): the gradient function to use to compute quantum gradients
_n (int): a positive integer used to track nesting of derivatives, for example
if the nth-order derivative is requested.
max_diff (int): If ``gradient_fn`` is a gradient transform, this option specifies
the maximum order of derivatives to support. Increasing this value allows
for higher order derivatives to be extracted, at the cost of additional
(classical) computational overhead during the backwards pass.

Returns:
list[list[float]]: A nested list of tape results. Each element in
@@ -64,6 +68,7 @@ def execute(tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1):
gradient_fn=gradient_fn,
gradient_kwargs=gradient_kwargs,
_n=_n,
max_diff=max_diff,
)[0]


@@ -76,6 +81,7 @@ def _execute(
gradient_fn=None,
gradient_kwargs=None,
_n=1,
max_diff=2,
): # pylint: disable=dangerous-default-value,unused-argument
"""Autodifferentiable wrapper around ``Device.batch_execute``.

@@ -119,6 +125,7 @@ def vjp(
gradient_fn=None,
gradient_kwargs=None,
_n=1,
max_diff=2,
): # pylint: disable=dangerous-default-value,unused-argument
"""Returns the vector-Jacobian product operator for a batch of quantum tapes.

@@ -139,6 +146,10 @@
determining the gradients of tapes
_n (int): a positive integer used to track nesting of derivatives, for example
if the nth-order derivative is requested.
max_diff (int): If ``gradient_fn`` is a gradient transform, this option specifies
the maximum number of derivatives to support. Increasing this value allows
for higher order derivatives to be extracted, at the cost of additional
(classical) computational overhead during the backwards pass.

Returns:
function: this function accepts the backpropagation
@@ -169,18 +180,43 @@ def grad_fn(dy):
if "pennylane.gradients" in module_name:

# Generate and execute the required gradient tapes
vjp_tapes, processing_fn = qml.gradients.batch_vjp(
tapes, dy, gradient_fn, reduction="append", gradient_kwargs=gradient_kwargs
)

# This is where the magic happens. Note that we call ``execute``.
# This recursion, coupled with the fact that the gradient transforms
# are differentiable, allows for arbitrary order differentiation.
vjps = processing_fn(
execute(vjp_tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=_n + 1)
)

elif inspect.ismethod(gradient_fn) and gradient_fn.__self__ is device:
if _n == max_diff:
with qml.tape.Unwrap(*tapes):
vjp_tapes, processing_fn = qml.gradients.batch_vjp(
tapes,
dy,
gradient_fn,
reduction="append",
gradient_kwargs=gradient_kwargs,
)

vjps = processing_fn(execute_fn(vjp_tapes)[0])

else:
vjp_tapes, processing_fn = qml.gradients.batch_vjp(
tapes, dy, gradient_fn, reduction="append", gradient_kwargs=gradient_kwargs
)

# This is where the magic happens. Note that we call ``execute``.
# This recursion, coupled with the fact that the gradient transforms
# are differentiable, allows for arbitrary order differentiation.
vjps = processing_fn(
execute(
vjp_tapes,
device,
execute_fn,
gradient_fn,
gradient_kwargs,
_n=_n + 1,
max_diff=max_diff,
)
)

elif (
hasattr(gradient_fn, "fn")
and inspect.ismethod(gradient_fn.fn)
and gradient_fn.fn.__self__ is device
):
# Gradient function is a device method.
# Note that unlike the previous branch:
#
20 changes: 20 additions & 0 deletions pennylane/measure.py
@@ -202,6 +202,26 @@ def queue(self, context=qml.QueuingContext):

return self

@property
def hash(self):
"""int: returns an integer hash uniquely representing the measurement process"""
if self.obs is None:
fingerprint = (
str(self.name),
tuple(self.wires.tolist()),
str(self.data),
self.return_type,
)
else:
fingerprint = (
str(self.obs.name),
tuple(self.wires.tolist()),
str(self.obs.data),
self.return_type,
)

return hash(fingerprint)
Contributor: Oh, surprised this wasn't there yet!

Member Author: It was on the CircuitGraph, but not the tape!
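
As a quick illustration of what the new ``hash`` properties provide (a sketch; the tapes below are illustrative and rely on the ``QuantumTape.hash`` attribute referenced in the ``cache_execute`` docstring):

```python
# Sketch: structurally identical tapes and measurements share a hash,
# which is exactly what cache_execute keys its cache on.
import pennylane as qml

with qml.tape.JacobianTape() as tape1:
    qml.RX(0.4, wires=0)
    qml.expval(qml.PauliZ(0))

with qml.tape.JacobianTape() as tape2:
    qml.RX(0.4, wires=0)
    qml.expval(qml.PauliZ(0))

assert tape1.measurements[0].hash == tape2.measurements[0].hash
assert tape1.hash == tape2.hash
```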



def expval(op):
r"""Expectation value of the supplied observable.
15 changes: 15 additions & 0 deletions pennylane/operation.py
@@ -233,6 +233,16 @@
# =============================================================================


def _process_data(op):
if op.name in ("RX", "RY", "RZ", "PhaseShift", "Rot"):
return str([d % (2 * np.pi) for d in op.data])

if op.name in ("CRX", "CRY", "CRZ", "CRot"):
return str([d % (4 * np.pi) for d in op.data])

return str(op.data)
Contributor: Also needed to compute hashes?
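
For context on the question above: ``_process_data`` reduces rotation parameters modulo their period before they enter the operator fingerprint, so operations that are equivalent up to a full period hash identically. A brief sketch (angles illustrative):

```python
# Sketch: Operator.hash fingerprints the name, wires, and processed parameters.
import numpy as np
import pennylane as qml

# Identical operations hash identically.
assert qml.RX(0.3, wires=0).hash == qml.RX(0.3, wires=0).hash

# Rotation angles are reduced modulo 2*pi (4*pi for controlled rotations),
# so a full-period shift leaves the hash unchanged.
assert qml.RX(0.0, wires=0).hash == qml.RX(2 * np.pi, wires=0).hash

# A different wire changes the fingerprint.
assert qml.RX(0.3, wires=0).hash != qml.RX(0.3, wires=1).hash
```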



class Operator(abc.ABC):
r"""Base class for quantum operators supported by a device.

@@ -282,6 +292,11 @@ def __deepcopy__(self, memo):
setattr(copied_op, attribute, copy.deepcopy(value, memo))
return copied_op

@property
def hash(self):
"""int: returns an integer hash uniquely representing the operator"""
return hash((str(self.name), tuple(self.wires.tolist()), _process_data(self)))

@classmethod
def _matrix(cls, *params):
"""Matrix representation of the operator