-
Notifications
You must be signed in to change notification settings - Fork 575
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[OpRefactor] Add custom decomposition context manager to device #1900
Changes from 27 commits
dfdd2f1
79f4acb
38abe2d
c5f6701
1293802
5644463
577878d
99092a1
507e610
2ddd724
5ae8c62
2e2239e
e3b10e6
0b8985a
bd52d19
19d1366
ea432ac
4cd4915
c21ed79
5824a91
043b125
3d7e455
f55dfc5
54d08ad
7bd73fc
e551e34
68f265a
003a266
4154b04
8005f03
a4e7aec
1a1d7cc
fa13a2f
bff8e48
8a4cd64
831e7ec
201218f
dc01d2a
71016bc
ed24c2c
f61b82a
e9cdab5
a399eb9
d072957
e5769b8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -168,6 +168,17 @@ def device(name, *args, **kwargs): | |
the `available plugins <https://pennylane.ai/plugins.html>`_ for more | ||
details. | ||
|
||
Args: | ||
name (str): the name of the device to load | ||
wires (int): the number of wires (subsystems) to initialise | ||
the device with | ||
|
||
Keyword Args: | ||
config (pennylane.Configuration): a PennyLane configuration object | ||
that contains global and/or device specific configurations. | ||
custom_decomps (Dict[Union(str, qml.Operator), Callable]): Custom | ||
decompositions to be applied by the device at runtime. | ||
Comment on lines
+172
to
+181
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks for moving this! |
||
|
||
All devices must be loaded by specifying their **short-name** as listed above, | ||
followed by the **wires** (subsystems) you wish to initialize. The *wires* | ||
argument can be an integer, in which case the wires of the device are addressed | ||
|
@@ -217,22 +228,56 @@ def circuit(a): | |
>>> circuit(0.8) # back to default of 10 samples | ||
[ 1 1 1 -1 -1 1 1 1 1 1] | ||
|
||
When constructing a device, we may optionally pass a dictionary of custom | ||
decompositions to be applied to certain operations upon device execution. | ||
This is useful for enabling support of gates on devices where they would normally | ||
be unsupported. | ||
|
||
For example, suppose we are running on an ion trap device which does not | ||
natively implement the CNOT gate, but we would still like to write our | ||
circuits in terms of CNOTs. On a ion trap device, CNOT can be implemented | ||
using the ``IsingXX`` gate. We first define a decomposition function | ||
(such functions have the signature ``decomposition(*params, wires)``): | ||
|
||
.. code-block:: python | ||
|
||
def ion_trap_cnot(wires): | ||
return [ | ||
qml.RY(np.pi/2, wires=wires[0]), | ||
qml.IsingXX(np.pi/2, wires=wires), | ||
qml.RX(-np.pi/2, wires=wires[0]), | ||
qml.RY(-np.pi/2, wires=wires[0]), | ||
qml.RY(-np.pi/2, wires=wires[1]) | ||
] | ||
|
||
Next, we create a device, and a QNode for testing. When constructing the | ||
QNode, we can set the expansion strategy to ``"device"`` to ensure the | ||
decomposition is applied and will be viewable when we draw the circuit. | ||
|
||
.. code-block:: python | ||
|
||
# As the CNOT gate normally has no decomposition, we can use default.qubit | ||
# here for expository purposes. | ||
dev = qml.device( | ||
'default.qubit', wires=2, custom_decomps={"CNOT" : ion_trap_cnot} | ||
) | ||
Comment on lines
+264
to
+266
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this feels very natural and intuitive 🙌 |
||
|
||
@qml.qnode(dev, expansion_strategy="device") | ||
def run_cnot(): | ||
qml.CNOT(wires=[0, 1]) | ||
return qml.expval(qml.PauliX(wires=1)) | ||
|
||
>>> print(qml.draw(run_cnot)()) | ||
0: ──RY(1.57)──╭IsingXX(1.57)──RX(-1.57)──RY(-1.57)──┤ | ||
1: ────────────╰IsingXX(1.57)──RY(-1.57)─────────────┤ ⟨X⟩ | ||
|
||
Some devices may accept additional arguments. For instance, | ||
``default.gaussian`` accepts the keyword argument ``hbar``, to set | ||
the convention used in the commutation relation :math:`[\x,\p]=i\hbar` | ||
(by default set to 2). | ||
|
||
Please refer to the documentation for the individual devices to see any | ||
additional arguments that might be required or supported. | ||
|
||
Args: | ||
name (str): the name of the device to load | ||
wires (int): the number of wires (subsystems) to initialise | ||
the device with | ||
|
||
Keyword Args: | ||
config (pennylane.Configuration): a PennyLane configuration object | ||
that contains global and/or device specific configurations. | ||
""" | ||
if name not in plugin_devices: | ||
# Device does not exist in the loaded device list. | ||
|
@@ -254,6 +299,10 @@ def circuit(a): | |
options.update(config[name.split(".")[0] + ".global"]) | ||
options.update(config[name]) | ||
|
||
# Pop the custom decomposition keyword argument; we will use it here | ||
# only and not pass it to the device. | ||
custom_decomps = kwargs.pop("custom_decomps", None) | ||
|
||
kwargs.pop("config", None) | ||
options.update(kwargs) | ||
|
||
|
@@ -268,8 +317,18 @@ def circuit(a): | |
) | ||
) | ||
|
||
# load device | ||
return plugin_device_class(*args, **options) | ||
# Construct the device | ||
dev = plugin_device_class(*args, **options) | ||
|
||
# Once the device is constructed, we set its custom expansion function if | ||
# any custom decompositions were specified. | ||
if custom_decomps: | ||
glassnotes marked this conversation as resolved.
Show resolved
Hide resolved
|
||
custom_decomp_expand_fn = pennylane.transforms.create_decomp_expand_fn( | ||
custom_decomps, dev | ||
) | ||
dev.custom_expand(custom_decomp_expand_fn) | ||
|
||
return dev | ||
|
||
raise DeviceError("Device does not exist. Make sure the required plugin is installed.") | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,8 @@ | |
# limitations under the License. | ||
"""This module contains tape expansion functions and stopping criteria to | ||
generate such functions from.""" | ||
# pylint: disable=unused-argument | ||
import contextlib | ||
|
||
import pennylane as qml | ||
from pennylane.operation import ( | ||
|
@@ -25,6 +27,11 @@ | |
not_tape, | ||
) | ||
|
||
# Needed for custom decomposition context manager | ||
from pennylane.transforms.qfunc_transforms import NonQueuingTape | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I only learnt about this recently :) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
|
||
NonQueuingTape = type("NonQueuingTape", (NonQueuingTape, qml.tape.QuantumTape), {}) | ||
mariaschuld marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
def _update_trainable_params(tape): | ||
params = tape.get_parameters(trainable_only=False) | ||
|
@@ -180,3 +187,79 @@ def expand_fn(tape, _depth=depth, **kwargs): | |
stop_at=not_tape | is_measurement | (~is_trainable) | has_grad_method, | ||
docstring=_expand_invalid_trainable_doc, | ||
) | ||
|
||
|
||
@contextlib.contextmanager | ||
def _custom_decomp_context(custom_decomps): | ||
"""A context manager for applying custom decompositions of operations.""" | ||
|
||
# Creates an individual context | ||
@contextlib.contextmanager | ||
def _custom_decomposition(obj, fn): | ||
# Covers the case where the user passes a string to indicate the Operator | ||
if isinstance(obj, str): | ||
obj = getattr(qml, obj) | ||
|
||
original_decomp_method = obj.decompose | ||
|
||
# This is the method that will override the operations .decompose method | ||
def new_decomp_method(self): | ||
with NonQueuingTape() as tape: | ||
if self.num_params == 0: | ||
return fn(self.wires) | ||
return fn(*self.parameters, self.wires) | ||
return tape | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am momentarily confused, will this line ever be used? Since there is a return function above without an if fork? Maybe this is why the coverage complains? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Whoa!! No, it won't be, removing it has absolutely no effect. I swear, when I checked the coverage yesterday everything was passing 😕 |
||
|
||
try: | ||
# Explicitly set the new .decompose method | ||
obj.decompose = new_decomp_method | ||
yield | ||
|
||
finally: | ||
obj.decompose = original_decomp_method | ||
|
||
# Loop through the decomposition dictionary and create all the contexts | ||
try: | ||
with contextlib.ExitStack() as stack: | ||
for obj, fn in custom_decomps.items(): | ||
# We enter a new context for each decomposition the user passes | ||
stack.enter_context(_custom_decomposition(obj, fn)) | ||
|
||
stack = stack.pop_all() | ||
|
||
yield | ||
|
||
finally: | ||
stack.close() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ingenious, these context managers! |
||
|
||
|
||
def create_decomp_expand_fn(custom_decomps, dev): | ||
"""Creates a custom expansion function for a device that applies | ||
a set of specified custom decompositions. | ||
|
||
Args: | ||
custom_decomps (Dict[Union(str, qml.operation.Operation), Callable]): Custom | ||
decompositions to be applied by the device at runtime. | ||
dev (qml.Device): A quantum device. | ||
|
||
Returns: | ||
Callable: A custom expansion function that a device can call to expand | ||
its tapes within a context manager that applies custom decompositions. | ||
mariaschuld marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
custom_op_names = [op if isinstance(op, str) else op.__name__ for op in custom_decomps.keys()] | ||
|
||
# Create a new expansion function; stop at things that do not have | ||
# custom decompositions, or that satisfy the regular device stopping criteria | ||
custom_fn = qml.transforms.create_expand_fn( | ||
depth=10, | ||
stop_at=qml.BooleanFn(lambda obj: obj.name not in custom_op_names), | ||
device=dev, | ||
) | ||
|
||
# Finally, we set the device's custom_expand_fn to a new one that | ||
# runs in a context where the decompositions have been replaced. | ||
def custom_decomp_expand(self, circuit, max_expansion=10): | ||
with _custom_decomp_context(custom_decomps): | ||
return custom_fn(circuit, max_expansion) | ||
mariaschuld marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
return custom_decomp_expand |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I spent a long time looking at this, but it works out the recursive relationship between the custom CNOT and the custom Hadamard really well! 🎉
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's so wonderful @glassnotes, it feels right to throw this into the device!