Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate fancy decorator syntax in batch transforms #4457

Merged
merged 17 commits into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions doc/development/deprecations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,29 @@ Completed deprecation cycles
- Deprecated in v0.32
- Removed in v0.33

* The following decorator syntax for transforms has been deprecated:

.. code-block:: python

@transform_fn(**transform_kwargs)
@qml.qnode(dev)
def circuit():
...

If you are using a transform that has supporting ``transform_kwargs``, please call the
transform directly using ``circuit = transform_fn(circuit, **transform_kwargs)``,
or use ``functools.partial``:

.. code-block:: python

@functools.partial(transform_fn, **transform_kwargs)
@qml.qnode(dev)
def circuit():
...

- Deprecated in v0.33
- Will be removed in v0.34

* The ``mode`` keyword argument in ``QNode`` has been removed, as it was only used in the old return
system (which has also been removed). Please use ``grad_on_execution`` instead.

Expand Down
18 changes: 18 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,24 @@
``StatePrepBase`` operations should be placed at the beginning of the `ops` list instead.
[(#4554)](https://github.com/PennyLaneAI/pennylane/pull/4554)

* The following decorator syntax for transforms has been deprecated and will raise a warning:
```python
@transform_fn(**transform_kwargs)
@qml.qnode(dev)
def circuit():
...
```
If you are using a transform that has supporting `transform_kwargs`, please call the
transform directly using `circuit = transform_fn(circuit, **transform_kwargs)`,
or use `functools.partial`:
```python
@functools.partial(transform_fn, **transform_kwargs)
@qml.qnode(dev)
def circuit():
...
```
[(#4457)](https://github.com/PennyLaneAI/pennylane/pull/4457/)

<h3>Documentation 📝</h3>

* Minor documentation improvements to the new device API. The documentation now correctly states that interface-specific
Expand Down
26 changes: 24 additions & 2 deletions pennylane/transforms/core/transform_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
This module contains the transform function, the transform dispatcher and the transform container.
"""
import copy
import warnings
import types

import pennylane as qml
Expand Down Expand Up @@ -72,10 +73,31 @@ def __call__(self, *targs, **tkwargs):
if callable(obj):
return self._qfunc_transform(obj, targs, tkwargs)

raise TransformError(
"The object on which the transform is applied is not valid. It can only be a tape, a QNode or a qfunc."
# Input is not a QNode nor a quantum tape nor a device.
# Assume Python decorator syntax:
#
# result = some_transform(*transform_args)(qnode)(*qnode_args)

warnings.warn(
"Decorating a QNode with @transform_fn(**transform_kwargs) has been "
"deprecated and will be removed in a future version. Please decorate "
"with @functools.partial(transform_fn, **transform_kwargs) instead, "
"or call the transform directly using qnode = transform_fn(qnode, **transform_kwargs)",
UserWarning,
)

if obj is not None:
targs = (obj, *targs)

def wrapper(obj):
return self(obj, *targs, **tkwargs)

wrapper.__doc__ = (
f"Partial of transform {self._transform} with bound arguments and keyword arguments."
)

return wrapper

@property
def transform(self):
"""Return the quantum transform."""
Expand Down
40 changes: 28 additions & 12 deletions tests/transforms/test_experimental/test_transform_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def qnode_circuit(a):
assert not dispatched_transform.is_informative

@pytest.mark.parametrize("valid_transform", valid_transforms)
def test_integration_dispatcher_with_valid_transform_decorator(self, valid_transform):
def test_integration_dispatcher_with_valid_transform_decorator_partial(self, valid_transform):
timmysilv marked this conversation as resolved.
Show resolved Hide resolved
"""Test that no error is raised with the transform function and that the transform dispatcher returns
the right object."""

Expand All @@ -197,6 +197,33 @@ def qnode_circuit(a):
qnode_circuit.transform_program.pop_front(), qml.transforms.core.TransformContainer
)

@pytest.mark.parametrize("valid_transform", valid_transforms)
def test_integration_dispatcher_with_valid_transform_decorator(self, valid_transform):
"""Test that a warning is raised with the transform function and that the transform dispatcher returns
the right object."""

dispatched_transform = transform(valid_transform)
targs = [0]

msg = r"Decorating a QNode with @transform_fn\(\*\*transform_kwargs\) has been deprecated"
with pytest.warns(UserWarning, match=msg):

@dispatched_transform(targs)
@qml.qnode(device=dev)
def qnode_circuit(a):
"""QNode circuit."""
qml.Hadamard(wires=0)
qml.CNOT(wires=[0, 1])
qml.PauliX(wires=0)
qml.RZ(a, wires=1)
return qml.expval(qml.PauliZ(wires=0))

assert isinstance(qnode_circuit, qml.QNode)
assert isinstance(qnode_circuit.transform_program, qml.transforms.core.TransformProgram)
assert isinstance(
qnode_circuit.transform_program.pop_front(), qml.transforms.core.TransformContainer
)

def test_queuing_qfunc_transform(self):
"""Test that queuing works with the transformed quantum function."""

Expand Down Expand Up @@ -339,17 +366,6 @@ def test_cotransform_not_implemented(self):
):
transform(first_valid_transform, classical_cotransform=non_callable)

def test_apply_dispatched_transform_non_valid_obj(self):
"""Test that applying a dispatched function on a non-valid object raises an error."""
dispatched_transform = transform(first_valid_transform)
obj = qml.RX(0.1, wires=0)
with pytest.raises(
TransformError,
match="The object on which the transform is applied is not valid. It can only be a tape, a QNode or a "
"qfunc.",
):
dispatched_transform(obj)

def test_qfunc_transform_multiple_tapes(self):
"""Test that quantum function is not compatible with multiple tapes."""
dispatched_transform = transform(second_valid_transform)
Expand Down
Loading