Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transforms shallow copy qnode #4736

Merged
merged 9 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/releases/changelog-0.33.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,10 @@
decomposition.
[(#4675)](https://github.com/PennyLaneAI/pennylane/pull/4675)

* Shallow copies of the `QNode` now also copy the `execute_kwargs` and transform program. When applying
a transform to a `QNode`, the new qnode is only a shallow copy of the original and thus keeps the same
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
device.

<h3>Breaking changes 💔</h3>

* ``qml.defer_measurements`` now raises an error if a transformed circuit measures ``qml.probs``,
Expand Down
12 changes: 12 additions & 0 deletions pennylane/qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import inspect
import warnings
from collections.abc import Sequence
from copy import copy
from typing import Union
import logging

Expand Down Expand Up @@ -482,6 +483,17 @@ def __init__(
functools.update_wrapper(self, func)
self._transform_program = qml.transforms.core.TransformProgram()

def __copy__(self):
copied_qnode = QNode.__new__(QNode)
for attr, value in vars(self).items():
if attr not in {"execute_kwargs", "_transform_program"}:
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
setattr(copied_qnode, attr, value)

copied_qnode.execute_kwargs = dict(self.execute_kwargs)
copied_qnode._transform_program = qml.transforms.core.TransformProgram(self.transform_program) # pylint: disable=protected-access
copied_qnode.gradient_kwrags = dict(self.gradient_kwargs)
return copied_qnode

def __repr__(self):
"""String representation."""
if isinstance(self.device, qml.devices.Device):
Expand Down
2 changes: 1 addition & 1 deletion pennylane/transforms/core/transform_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def default_qnode_transform(self, qnode, targs, tkwargs):
qnode._original_device.reset()
qnode.device.reset()
albi3ro marked this conversation as resolved.
Show resolved Hide resolved

qnode = copy.deepcopy(qnode)
qnode = copy.copy(qnode)

if self.expand_transform:
qnode.add_transform(TransformContainer(self._expand_transform, targs, tkwargs))
Expand Down
19 changes: 19 additions & 0 deletions tests/test_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,25 @@ def compute_derivatives(self, circuits, execution_config=None):
"""Device defines its own method to compute derivatives"""
return 0

def test_copy():
"""Test that a shallow copy also copies the execute kwargs, gradient kwargs, and transform program."""
dev = CustomDevice()

qn = qml.QNode(dummyfunc, dev)
copied_qn = copy.copy(qn)
assert copied_qn is not qn
assert copied_qn.execute_kwargs == qn.execute_kwargs
assert copied_qn.execute_kwargs is not qn.execute_kwargs
assert copied_qn.transform_program == qn.transform_program
assert copied_qn.transform_program is not qn.transform_program
assert copied_qn.gradient_kwargs == qn.gradient_kwargs
assert copied_qn.gradient_kwargs is not qn.gradient_kwargs

assert copied_qn.func is qn.func
assert copied_qn.device is qn.device
assert copied_qn.interface is qn.interface
assert copied_qn.diff_method == qn.diff_method
assert copied_qn.expansion_strategy == qn.expansion_strategy

# pylint: disable=too-many-public-methods
class TestValidation:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,12 @@ def qnode_circuit(a):
qnode_transformed = dispatched_transform(qnode_circuit, 0)
assert not qnode_circuit.transform_program

assert qnode_transformed.device is qnode_circuit.device

with dev.tracker:
qnode_circuit(0.1)
assert dev.tracker.totals['executions'] == 1

assert isinstance(qnode_transformed, qml.QNode)
assert isinstance(qnode_transformed.transform_program, qml.transforms.core.TransformProgram)
assert isinstance(
Expand Down
Loading