[WIP] Add a framework for batch reduction of multiple tapes #1362
Conversation
Co-authored-by: Olivia Di Matteo <2068515+glassnotes@users.noreply.github.com>
Co-authored-by: David Wierichs <davidwierichs@gmail.com>
Hello. You may have forgotten to update the changelog!
Codecov Report

```
@@            Coverage Diff             @@
##           master    #1362      +/-   ##
==========================================
- Coverage   98.23%   98.10%   -0.14%
==========================================
  Files         160      161       +1
  Lines       11966    12031      +65
==========================================
+ Hits        11755    11803      +48
- Misses        211      228      +17
```

Continue to review full report at Codecov.
```python
def batch_reduce(fn):
    """Register a new batch reduce transform.

    A batch reduction transform takes a QNode as input, and:
```

Suggested change:

```diff
-    A batch reduction transform takes a QNode as input, and:
+    A batch reduction transform takes a QNode or quantum tape as input, and:
```
```python
    qml.CRX(x, wires=[0, 1])
    return qml.expval(qml.PauliZ(1))
```

```python
>>> circuit(0.6)
```
This still feels a bit weird to me. If I run `circuit`, even with the `batch_reduce` decorator, I would still instinctively expect to get the output of `circuit`, and not the processing function. Not sure if this makes sense, but is there a way to enable a distinction between the original QNode and the processing function? I guess just not using the decorator for `batch_reduce` would give the list of tapes / function explicitly?
I think it essentially comes down to Python syntax:

- Use the decorator if you don't want to keep the original QNode; you want to replace it.
- If you want both the original QNode and the transform, don't use it as a decorator:

  ```python
  >>> fn = transform(qnode)
  >>> print(fn(params), qnode(params))
  ```

I think this is more of a documentation thing? E.g., `qml.grad` would work equally well as a decorator, where it would simply 'replace' the QNode:

```python
@qml.grad
def circuit(params):
    ...
```

Calling `circuit` would now return the gradient of the QNode.
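To illustrate the 'replacement' point in plain Python (the names `twice`, `circuit`, and `circuit2` are purely illustrative, not PennyLane API): applying a decorator rebinds the name, so the undecorated function is no longer reachable under it, while calling the transform explicitly keeps both.

```python
def twice(f):
    """Toy transform standing in for a QNode transform like qml.grad."""
    def wrapper(x):
        return 2 * f(x)
    return wrapper

# Decorator syntax: the name `circuit` now refers to the transformed
# function; the original is no longer directly reachable.
@twice
def circuit(x):
    return x + 1

# Explicit syntax: both the original function and the transform
# remain available under separate names.
def circuit2(x):
    return x + 1

fn = twice(circuit2)
```

Here `circuit(3)` returns the transformed value, while `circuit2(3)` still gives the original result and `fn(3)` the transformed one.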
So maybe in the documentation we just avoid the decorator syntax?
I think we should show both (since the decorators are super convenient). We'll just have to be clear about the replacement aspect.
Context:

There is a transformation pattern that repeats often throughout the PennyLane codebase: a transformation that maps a sequence of tapes to another sequence of *independent* tapes (tapes whose inputs do not depend on the results of other tapes), together with a classical processing function that reduces the tape execution results.
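The pattern can be sketched schematically in plain Python. Integers stand in for tapes here, and every name (`split_transform`, `batch_execute`, `processing_fn`) is an illustrative placeholder, not the actual PennyLane implementation.

```python
def split_transform(tape):
    """Map one input tape to several independent tapes, plus a classical
    processing function that reduces their execution results."""
    tapes = [tape + 1, tape + 2, tape + 3]      # placeholder expansion
    def processing_fn(results):
        return sum(results)                      # placeholder reduction
    return tapes, processing_fn

def batch_execute(tapes):
    """Stand-in for device.batch_execute: the tapes are independent,
    so they can all be run in a single batch (or in parallel)."""
    return [t * 10 for t in tapes]               # placeholder execution

tapes, fn = split_transform(1)
result = fn(batch_execute(tapes))                # reduce the batch results
```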
Current logic that adheres to this pattern includes:
Because the tapes are independent, they can be executed in a single 'batch', potentially via `device.batch_execute` or Dask parallelism, which can lead to a speedup.

Description of changes:
This PR adds a batch reduce decorator, `@qml.batch_reduce`, that serves the following purposes:

- It serves to 'register' functions that contain transforms of the above form, creating transformations that act directly on QNodes.
- It abstracts away the execution logic, allowing batch reduction transforms to be easily and quickly created which make use of `device.batch_execute` and Dask parallelism.

Rendered docs and code examples available here.
For example, the existing `qml.transforms.hamiltonian_expand` function maps `tape -> tapes, reduction_function`. We can register this tape transform as a batch-reduce operation.

Benefits
- Existing tape transforms that fit this 'batch-reduce' pattern can be easily converted into high-level QNode transforms that take advantage of `device.batch_execute` and Dask parallelism.
- Execution logic is contained in a single place, leading to better long-term maintainability. Future batch-reduce transforms only need to implement the transformation logic, leaving the execution logic to the decorator.
- Enforces a standard API for this batch-reduce pattern going forward.
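To make the registration idea concrete, here is a hedged, plain-Python sketch of what a `batch_reduce`-style decorator could do. The names `execute` and `my_expand`, and the use of integers as 'tapes', are hypothetical stand-ins, not the actual PennyLane implementation.

```python
def execute(tape):
    """Stand-in for executing a single tape on a device."""
    return tape * 2

def batch_reduce(transform):
    """Hypothetical stand-in for @qml.batch_reduce: wraps a
    tape -> (tapes, processing_fn) transform so that calling the result
    executes the expanded tapes and reduces their outputs."""
    def wrapper(tape):
        tapes, processing_fn = transform(tape)
        results = [execute(t) for t in tapes]   # could use batch_execute
        return processing_fn(results)
    return wrapper

@batch_reduce
def my_expand(tape):
    """Illustrative transform: one tape becomes two, summed afterwards."""
    tapes = [tape, tape + 1]
    def processing_fn(results):
        return sum(results)
    return tapes, processing_fn

value = my_expand(5)    # executes both expanded tapes, then reduces
```

The key design point is that `my_expand` only specifies the expansion and reduction; all execution logic lives in the decorator.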
Potential drawbacks:

- Currently, `device.batch_execute` does not support differentiation when using the parameter-shift or finite-diff differentiation methods. Fixing this will require a much larger device/interface refactor.
- Dask parallelism can sometimes result in race conditions when used with a local Python simulator (such as `default.qubit`).
- The decorator enforces the syntax `my_transform(qnode)(transform_args)(qnode_args)`, which differs from some existing transforms, such as `qml.metric_tensor(qnode, transform_args)(qnode_args)`.
- Having the decorator take arguments results in an extra set of parentheses: `my_transform(qnode)(batch_execute=True)(transform_args)(qnode_args)`. Instead, we could simply split the decorator into two, `@batch_reduce` and `@parallel_reduce`, to avoid this.

Finally, does supporting Dask parallelism here make sense? Is this something that should instead be added directly to `QubitDevice.batch_execute`?
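To make the syntax difference in the drawbacks above concrete, here is a plain-Python sketch of the curried call chain `my_transform(qnode)(transform_args)(qnode_args)`. All names and the arithmetic are illustrative placeholders.

```python
def my_transform(qnode):
    """Illustrative transform following the proposed call chain
    my_transform(qnode)(transform_args)(qnode_args)."""
    def with_transform_args(*transform_args):
        def run(*qnode_args):
            # placeholder: combine the qnode result with the transform args
            return qnode(*qnode_args) + sum(transform_args)
        return run
    return with_transform_args

def qnode(x):
    """Placeholder for a QNode."""
    return x * x

# Each set of parentheses peels off one layer of the curried call chain.
out = my_transform(qnode)(1, 2)(3)
```

By contrast, the `qml.metric_tensor`-style syntax would pass `qnode` and the transform arguments in a single call, removing one layer of nesting.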