Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate fancy decorator syntax in batch transforms #4457

Merged
merged 17 commits into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions doc/development/deprecations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,27 @@ Pending deprecations
- Deprecated in v0.32
- Will be removed in v0.33

* The following decorator syntax for transforms has been deprecated:

.. code-block:: python

@transform_fn(*transform_args)
@qml.qnode(dev)
def circuit():
...

Please call the transform directly using ``circuit = transform_fn(circuit, *transform_args)``,
eddddddy marked this conversation as resolved.
Show resolved Hide resolved
or use ``functools.partial``:

.. code-block:: python

@functools.partial(transform_fn, *transform_args)
eddddddy marked this conversation as resolved.
Show resolved Hide resolved
@qml.qnode(dev)
def circuit():
...

- Deprecated in v0.32
eddddddy marked this conversation as resolved.
Show resolved Hide resolved
- Will be removed in v0.34

Completed deprecation cycles
----------------------------
Expand Down
17 changes: 17 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,23 @@ array([False, False])
changes to operator equality and hashing.
[(#4144)](https://github.com/PennyLaneAI/pennylane/pull/4144)

* The following decorator syntax for transforms has been deprecated and will raise a warning:
```python
@transform_fn(*transform_args)
@qml.qnode(dev)
def circuit():
...
```
Please call the transform directly using `circuit = transform_fn(circuit, *transform_args)`,
or use `functools.partial`:
```python
@functools.partial(transform_fn, *transform_args)
@qml.qnode(dev)
def circuit():
...
```
[]()
eddddddy marked this conversation as resolved.
Show resolved Hide resolved

<h3>Documentation 📝</h3>

* The `qml.pulse.transmon_interaction` and `qml.pulse.transmon_drive` documentation has been updated.
Expand Down
8 changes: 8 additions & 0 deletions pennylane/transforms/batch_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,14 @@ def _construct(args, kwargs):
# ...
# result = circuit(*qnode_args)

warnings.warn(
"The decorator syntax transform_fn(*transform_args)(qnode) has been "
"deprecated and will be removed in a future version. Please use either "
"transform_fn(qnode, *transform_args) or "
"functools.partial(transform_fn, *transform_args)(qnode) instead.",
UserWarning,
)

# Prepend the input to the transform args,
# and create a wrapper function.
if qnode is not None:
Expand Down
23 changes: 21 additions & 2 deletions pennylane/transforms/core/transform_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
This module contains the transform function, the transform dispatcher and the transform container.
"""
import copy
import warnings

import pennylane as qml


Expand Down Expand Up @@ -68,10 +70,27 @@ def __call__(self, *targs, **tkwargs):
if callable(obj):
return self._qfunc_transform(obj, targs, tkwargs)

raise TransformError(
"The object on which the transform is applied is not valid. It can only be a tape, a QNode or a qfunc."
# Input is not a QNode nor a quantum tape nor a device.
# Assume Python decorator syntax:
#
# result = some_transform(*transform_args)(qnode)(*qnode_args)

warnings.warn(
"The decorator syntax transform_fn(*transform_args)(qnode) has been "
eddddddy marked this conversation as resolved.
Show resolved Hide resolved
"deprecated and will be removed in a future version. Please use either "
"transform_fn(qnode, *transform_args) or "
eddddddy marked this conversation as resolved.
Show resolved Hide resolved
"functools.partial(transform_fn, *transform_args)(qnode) instead.",
UserWarning,
)

if obj is not None:
targs = (obj, *targs)

def wrapper(obj):
return self(obj, *targs, **tkwargs)

return wrapper
eddddddy marked this conversation as resolved.
Show resolved Hide resolved

@property
def transform(self):
"""Return the quantum transform."""
Expand Down
20 changes: 12 additions & 8 deletions tests/transforms/test_batch_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,8 @@ def test_parametrized_transform_tape_decorator(self):
qml.expval(qml.PauliX(0))

tape = qml.tape.QuantumScript.from_queue(q)
tapes, _ = self.my_transform(a, b)(tape) # pylint: disable=no-value-for-parameter
with pytest.warns(UserWarning, match="The decorator syntax"):
tapes, _ = self.my_transform(a, b)(tape) # pylint: disable=no-value-for-parameter

assert len(tapes[0].operations) == 2
assert tapes[0].operations[0].name == "Hadamard"
Expand Down Expand Up @@ -372,7 +373,8 @@ def test_parametrized_transform_device_decorator(self, mocker):
x = 0.543

dev = qml.device("default.qubit", wires=1)
dev = self.my_transform(a, b)(dev) # pylint: disable=no-value-for-parameter
with pytest.warns(UserWarning, match="The decorator syntax"):
dev = self.my_transform(a, b)(dev) # pylint: disable=no-value-for-parameter

@qml.qnode(dev, interface="autograd")
def circuit(x):
Expand Down Expand Up @@ -441,12 +443,14 @@ def test_parametrized_transform_qnode_decorator(self, mocker):

dev = qml.device("default.qubit", wires=2)

@self.my_transform(a, b) # pylint: disable=no-value-for-parameter
@qml.qnode(dev)
def circuit(x):
qml.Hadamard(wires=0)
qml.RX(x, wires=0)
return qml.expval(qml.PauliX(0))
with pytest.warns(UserWarning, match="The decorator syntax"):

@self.my_transform(a, b) # pylint: disable=no-value-for-parameter
@qml.qnode(dev)
def circuit(x):
qml.Hadamard(wires=0)
qml.RX(x, wires=0)
return qml.expval(qml.PauliX(0))

spy = mocker.spy(self.my_transform, "construct")
res = circuit(x)
Expand Down
39 changes: 27 additions & 12 deletions tests/transforms/test_experimental/test_transform_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def qnode_circuit(a):
assert not dispatched_transform.is_informative

@pytest.mark.parametrize("valid_transform", valid_transforms)
def test_integration_dispatcher_with_valid_transform_decorator(self, valid_transform):
def test_integration_dispatcher_with_valid_transform_decorator_partial(self, valid_transform):
timmysilv marked this conversation as resolved.
Show resolved Hide resolved
"""Test that no error is raised with the transform function and that the transform dispatcher returns
the right object."""

Expand All @@ -197,6 +197,32 @@ def qnode_circuit(a):
qnode_circuit.transform_program.pop_front(), qml.transforms.core.TransformContainer
)

@pytest.mark.parametrize("valid_transform", valid_transforms)
def test_integration_dispatcher_with_valid_transform_decorator(self, valid_transform):
"""Test that a warning is raised with the transform function and that the transform dispatcher returns
the right object."""

dispatched_transform = transform(valid_transform)
targs = [0]

with pytest.warns(UserWarning, match="The decorator syntax"):

@dispatched_transform(targs)
@qml.qnode(device=dev)
def qnode_circuit(a):
"""QNode circuit."""
qml.Hadamard(wires=0)
qml.CNOT(wires=[0, 1])
qml.PauliX(wires=0)
qml.RZ(a, wires=1)
return qml.expval(qml.PauliZ(wires=0))

assert isinstance(qnode_circuit, qml.QNode)
assert isinstance(qnode_circuit.transform_program, qml.transforms.core.TransformProgram)
assert isinstance(
qnode_circuit.transform_program.pop_front(), qml.transforms.core.TransformContainer
)

def test_queuing_qfunc_transform(self):
"""Test that queuing works with the transformed quantum function."""

Expand Down Expand Up @@ -339,17 +365,6 @@ def test_cotransform_not_implemented(self):
):
transform(first_valid_transform, classical_cotransform=non_callable)

def test_apply_dispatched_transform_non_valid_obj(self):
"""Test that applying a dispatched function on a non-valid object raises an error."""
dispatched_transform = transform(first_valid_transform)
obj = qml.RX(0.1, wires=0)
with pytest.raises(
TransformError,
match="The object on which the transform is applied is not valid. It can only be a tape, a QNode or a "
"qfunc.",
):
dispatched_transform(obj)

def test_qfunc_transform_multiple_tapes(self):
"""Test that quantum function is not compatible with multiple tapes."""
dispatched_transform = transform(second_valid_transform)
Expand Down
Loading