diff --git a/.github/CHANGELOG.md b/.github/CHANGELOG.md
index 41ab0a209de..98050c024a0 100644
--- a/.github/CHANGELOG.md
+++ b/.github/CHANGELOG.md
@@ -2,6 +2,69 @@
New features since last release
+* Adds a decorator `@qml.qfunc_transform` to easily create a transformation
+ that modifies the behaviour of a quantum function.
+ [(#1315)](https://github.com/PennyLaneAI/pennylane/pull/1315)
+
+ For example, consider the following transform, which scales the parameter of
+ all `RX` gates by :math:`x \rightarrow \sin(a) \sqrt{x}`, and the parameters
+ of all `RY` gates by :math:`y \rightarrow \cos(a * b) y`:
+
+ ```python
+ @qml.qfunc_transform
+ def my_transform(tape, a, b):
+ for op in tape.operations + tape.measurements:
+ if op.name == "RX":
+ x = op.parameters[0]
+ qml.RX(qml.math.sin(a) * qml.math.sqrt(x), wires=op.wires)
+ elif op.name == "RY":
+ y = op.parameters[0]
+ qml.RX(qml.math.cos(a * b) * y, wires=op.wires)
+ else:
+ op.queue()
+ ```
+
+ We can now apply this transform to any quantum function:
+
+ ```python
+ dev = qml.device("default.qubit", wires=2)
+
+ def ansatz(x):
+ qml.Hadamard(wires=0)
+ qml.RX(x[0], wires=0)
+ qml.RY(x[1], wires=1)
+ qml.CNOT(wires=[0, 1])
+
+ @qml.qnode(dev)
+ def circuit(params, transform_weights):
+ qml.RX(0.1, wires=0)
+
+ # apply the transform to the ansatz
+ my_transform(*transform_weights)(ansatz)(params)
+
+ return qml.expval(qml.PauliZ(1))
+ ```
+
+ We can print this QNode to show that the qfunc transform is taking place:
+
+ ```pycon
+ >>> x = np.array([0.5, 0.3], requires_grad=True)
+ >>> transform_weights = np.array([0.1, 0.6], requires_grad=True)
+ >>> print(qml.draw(circuit)(x, transform_weights))
+ 0: ──RX(0.1)────H──RX(0.0706)──╭C──┤
+ 1: ──RX(0.299)─────────────────╰X──┤ ⟨Z⟩
+ ```
+
+ Evaluating the QNode, as well as the derivative, with respect to the gate
+ parameter *and* the transform weights:
+
+ ```pycon
+ >>> circuit(x, transform_weights)
+ 0.006728293438238053
+ >>> qml.grad(circuit)(x, transform_weights)
+ (array([ 0.00671711, -0.00207359]), array([6.69695008e-02, 3.73694364e-06]))
+ ```
+
* Added validation for noise channel parameters. Invalid noise parameters now
raise a `ValueError`. [(#1357)](https://github.com/PennyLaneAI/pennylane/pull/1357)
diff --git a/pennylane/__init__.py b/pennylane/__init__.py
index 8d05c46956e..061c59836bd 100644
--- a/pennylane/__init__.py
+++ b/pennylane/__init__.py
@@ -49,6 +49,8 @@
ctrl,
measurement_grouping,
metric_tensor,
+ qfunc_transform,
+ single_tape_transform,
)
from pennylane.utils import inv
from pennylane.vqe import ExpvalCost, Hamiltonian, VQECost
diff --git a/pennylane/transforms/__init__.py b/pennylane/transforms/__init__.py
index a1d9a3a2c37..07dc9015953 100644
--- a/pennylane/transforms/__init__.py
+++ b/pennylane/transforms/__init__.py
@@ -17,10 +17,13 @@
.. currentmodule:: pennylane
-QNode transforms
-----------------
+Transforms
+----------
-The following transforms act on QNodes. They return new transformed functions
+Transforms that act on QNodes
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Thes transforms accept QNodes, and return new transformed functions
that compute the desired quantity.
.. autosummary::
@@ -30,11 +33,11 @@
~draw
~metric_tensor
-Quantum function transforms
----------------------------
+Transforms that act on quantum functions
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-The following transforms act on quantum functions (Python functions
-containing quantum operations) that are used *inside* QNodes.
+These transforms accept quantum functions (Python functions
+containing quantum operations) that are used to construct QNodes.
.. autosummary::
:toctree: api
@@ -43,10 +46,10 @@
~ctrl
~transforms.invisible
-Tape transforms
----------------
+Transforms that act on tapes
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-The following transforms act on quantum tapes, and return one or
+These transforms accept quantum tapes, and return one or
more tapes as well as a classical processing function.
.. autosummary::
@@ -55,12 +58,26 @@
~transforms.measurement_grouping
~transforms.metric_tensor_tape
~transforms.hamiltonian_expand
+
+Decorators and utility functions
+--------------------------------
+
+The following decorators and convenience functions are provided
+to help build custom QNode, quantum function, and tape transforms:
+
+.. autosummary::
+ :toctree: api
+
+ ~single_tape_transform
+ ~qfunc_transform
+ ~transforms.make_tape
"""
from .adjoint import adjoint
from .classical_jacobian import classical_jacobian
from .control import ControlledOperation, ctrl
from .draw import draw
+from .hamiltonian_expand import hamiltonian_expand
from .invisible import invisible
from .measurement_grouping import measurement_grouping
from .metric_tensor import metric_tensor, metric_tensor_tape
-from .hamiltonian_expand import hamiltonian_expand
+from .qfunc_transforms import make_tape, single_tape_transform, qfunc_transform
diff --git a/pennylane/transforms/qfunc_transforms.py b/pennylane/transforms/qfunc_transforms.py
new file mode 100644
index 00000000000..9581bae67aa
--- /dev/null
+++ b/pennylane/transforms/qfunc_transforms.py
@@ -0,0 +1,386 @@
+# Copyright 2018-2021 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains tools and decorators for registering qfunc transforms."""
+# pylint: disable=too-few-public-methods
+import functools
+import inspect
+
+import pennylane as qml
+
+
+def make_tape(fn):
+ """Returns a function that generates the tape from a quantum function without any
+ operation queuing taking place.
+
+ This is useful when you would like to manipulate or transform
+ the tape created by a quantum function without evaluating it.
+
+ Args:
+ fn (function): the quantum function to generate the tape from
+
+ Returns:
+ function: The returned function takes the same arguments as the quantum
+ function. When called, it returns the generated quantum tape
+ without any queueing occuring.
+
+ **Example**
+
+ Consider the following quantum function:
+
+ .. code-block:: python
+
+ def qfunc(x):
+ qml.Hadamard(wires=0)
+ qml.CNOT(wires=[0, 1])
+ qml.RX(x, wires=0)
+
+ We can use ``make_tape`` to extract the tape generated by this
+ quantum function, without any of the operations being queued by
+ any existing queuing contexts:
+
+ >>> with qml.tape.QuantumTape() as active_tape:
+ ... qml.RY(1.0, wires=0)
+ ... tape = make_tape(qfunc)(0.5)
+ >>> tape.operations
+ [Hadamard(wires=[0]), CNOT(wires=[0, 1]), RX(0.5, wires=[0])]
+
+ Note that the currently recording tape did not queue any of these quantum operations:
+
+ >>> active_tape.operations
+ [RY(1.0, wires=[0])]
+ """
+
+ def wrapper(*args, **kwargs):
+ active_tape = qml.tape.get_active_tape()
+
+ if active_tape is not None:
+ with active_tape.stop_recording(), active_tape.__class__() as tape:
+ fn(*args, **kwargs)
+ else:
+ with qml.tape.QuantumTape() as tape:
+ fn(*args, **kwargs)
+ return tape
+
+ return wrapper
+
+
+class NonQueuingTape(qml.queuing.AnnotatedQueue):
+ """Mixin class that creates a tape that does not queue
+ itself to the current queuing context."""
+
+ def _process_queue(self):
+ super()._process_queue()
+
+ for obj, info in self._queue.items():
+ qml.queuing.QueuingContext.append(obj, **info)
+
+ qml.queuing.QueuingContext.remove(self)
+
+
+class single_tape_transform:
+ """For registering a tape transform that takes a tape and outputs a single new tape.
+
+ Examples of such transforms include circuit compilation.
+
+ Args:
+ transform_fn (function): The function to register as the single tape transform.
+ It can have an arbitrary number of arguments, but the first argument
+ **must** be the input tape.
+
+ **Example**
+
+ A valid single tape transform is a quantum function that satisfies the following:
+
+ - The first argument must be an input tape
+
+ - Depending on the structure of this input tape, various quantum operations, functions,
+ and templates may be called.
+
+ - Any internal classical processing should use the ``qml.math`` module to ensure
+ the transform is differentiable.
+
+ - There is no return statement.
+
+ For example:
+
+ .. code-block:: python
+
+ @qml.single_tape_transform
+ def my_transform(tape, x, y):
+ # loop through all operations on the input tape
+ for op in tape.operations + tape.measurements:
+ if op.name == "CRX":
+ wires = op.wires
+ param = op.parameters[0]
+
+ qml.RX(x * qml.math.abs(param), wires=wires[1])
+ qml.RY(y * qml.math.abs(param), wires=wires[1])
+ qml.CZ(wires=[wires[1], wires[0]])
+ else:
+ op.queue()
+
+ This transform iterates through the input tape, and replaces any :class:`~.CRX` operation with
+ two single qubit rotations and a :class:`~.CZ` operation. These newly queued operations will
+ form the output transformed tape.
+
+ We can apply this transform to a quantum tape:
+
+ >>> with qml.tape.JacobianTape() as tape:
+ ... qml.Hadamard(wires=0)
+ ... qml.CRX(-0.5, wires=[0, 1])
+ >>> new_tape = my_transform(tape, 1., 2.)
+ >>> print(new_tape.draw())
+ 0: ──H───────────────╭Z──┤
+ 1: ──RX(0.5)──RY(1)──╰C──┤
+ """
+
+ def __init__(self, transform_fn):
+
+ if not callable(transform_fn):
+ raise ValueError(
+ f"The tape transform function to register, {transform_fn}, "
+ "does not appear to be a valid Python function or callable."
+ )
+
+ self.transform_fn = transform_fn
+ functools.update_wrapper(self, transform_fn)
+
+ def __call__(self, tape, *args, **kwargs):
+ tape_class = type(tape.__class__.__name__, (NonQueuingTape, tape.__class__), {})
+
+ # new_tape, when first created, is of the class (NonQueuingTape, tape.__class__), so that it
+ # doesn't result in a nested tape
+ with tape_class() as new_tape:
+ self.transform_fn(tape, *args, **kwargs)
+
+ # Once we're done, revert it back to be simply an instance of tape.__class__.
+ new_tape.__class__ = tape.__class__
+ return new_tape
+
+
+def _create_qfunc_internal_wrapper(fn, tape_transform, transform_args, transform_kwargs):
+ """Convenience function to create the internal wrapper function
+ generated by the qfunc_transform decorator"""
+ if not callable(fn):
+ raise ValueError(
+ f"The qfunc to transform, {fn}, does not appear "
+ "to be a valid Python function or callable."
+ )
+
+ @functools.wraps(fn)
+ def internal_wrapper(*args, **kwargs):
+ tape = make_tape(fn)(*args, **kwargs)
+ tape = tape_transform(tape, *transform_args, **transform_kwargs)
+ return tape.measurements
+
+ return internal_wrapper
+
+
+def qfunc_transform(tape_transform):
+ """Given a function which defines a tape transform, convert the function into
+ one that applies the tape transform to quantum functions (qfuncs).
+
+ Args:
+ tape_transform (function or single_tape_transform): the single tape transform
+ to convert into the qfunc transform.
+
+ Returns:
+ function: A qfunc transform, that acts on any qfunc, and returns a *new*
+ qfunc as per the tape transform. Note that if ``tape_transform`` takes
+ additional parameters beyond a single tape, then the created qfunc transform
+ will take the *same* parameters, prior to being applied to the qfunc.
+
+ **Example**
+
+ Given a single tape transform ``my_transform(tape, x, y)``, you can use
+ this function to convert it into a qfunc transform:
+
+ >>> my_qfunc_transform = qfunc_transform(my_transform)
+
+ It can then be used to transform an existing qfunc:
+
+ >>> new_qfunc = my_qfunc_transform(0.6, 0.7)(old_qfunc)
+ >>> new_qfunc(params)
+
+ It can also be used as a decorator:
+
+ .. code-block:: python
+
+ @qml.qfunc_transform
+ def my_transform(tape, x, y):
+ for op in tape.operations + tape.measurements:
+ if op.name == "CRX":
+ wires = op.wires
+ param = op.parameters[0]
+ qml.RX(x * param, wires=wires[1])
+ qml.RY(y * qml.math.sqrt(param), wires=wires[1])
+ qml.CZ(wires=[wires[1], wires[0]])
+ else:
+ op.queue()
+
+ @my_transform(0.6, 0.1)
+ def qfunc(x):
+ qml.Hadamard(wires=0)
+ qml.CRX(x, wires=[0, 1])
+
+ >>> dev = qml.device("default.qubit", wires=2)
+ >>> qnode = qml.QNode(qfunc, dev)
+ >>> print(qml.draw(qnode)(2.5))
+ 0: ──H───────────────────╭Z──┤
+ 1: ──RX(1.5)──RY(0.158)──╰C──┤
+
+ The transform weights provided to a qfunc transform are fully differentiable,
+ allowing the transform itself to be differentiated and trained. For more details,
+ see the Differentiability section under Usage Details.
+
+ .. UsageDetails::
+
+ **Inline usage**
+
+ qfunc transforms, when used inline (that is, not as a decorator), take the following form:
+
+ >>> my_transform(transform_weights)(ansatz)(param)
+
+ or
+
+ >>> my_transform(ansatz)(param)
+
+ if they do not permit any parameters. We can break this down into distinct steps,
+ to show what is happening with each new function call:
+
+ 0. Create a transform defined by the transform weights:
+
+ >>> specific_transform = my_transform(transform_weights)
+
+ Note that this step is skipped if the transform does not provide any
+ weights/parameters that can be modified!
+
+ 1. Apply the transform to the qfunc. A qfunc transform always acts on
+ a qfunc, returning a new qfunc:
+
+ >>> new_qfunc = specific_transform(ansatz)
+
+ 2. Finally, we evaluate the new, transformed, qfunc:
+
+ >>> new_qfunc(params)
+
+ So the syntax
+
+ >>> my_transform(transform_weights)(ansatz)(param)
+
+ simply 'chains' these three steps together, into a single call.
+
+ **Differentiability**
+
+ When applying a qfunc transform, not only is the newly transformed qfunc fully
+ differentiable, but the qfunc transform parameters *themselves* are differentiable.
+ This allows us to train both the quantum function, as well as the transform
+ that created it.
+
+ Consider the following example, where a pre-defined ansatz is transformed
+ within a QNode:
+
+ .. code-block:: python
+
+ dev = qml.device("default.qubit", wires=2)
+
+ def ansatz(x):
+ qml.Hadamard(wires=0)
+ qml.CRX(x, wires=[0, 1])
+
+ @qml.qnode(dev)
+ def circuit(param, transform_weights):
+ qml.RX(0.1, wires=0)
+
+ # apply the transform to the ansatz
+ my_transform(*transform_weights)(ansatz)(param)
+
+ return qml.expval(qml.PauliZ(1))
+
+ We can print this QNode to show that the qfunc transform is taking place:
+
+ >>> x = np.array(0.5, requires_grad=True)
+ >>> y = np.array([0.1, 0.2], requires_grad=True)
+ >>> print(qml.draw(circuit)(x, y))
+ 0: ──RX(0.1)───H──────────╭Z──┤
+ 1: ──RX(0.05)──RY(0.141)──╰C──┤ ⟨Z⟩
+
+ Evaluating the QNode, as well as the derivative, with respect to the gate
+ parameter *and* the transform weights:
+
+ >>> circuit(x, y)
+ 0.9887793925354269
+ >>> qml.grad(circuit)(x, y)
+ (array(-0.02485651), array([-0.02474011, -0.09954244]))
+
+ **Implementation details**
+
+ Internally, the qfunc transform works as follows:
+
+ .. code-block:: python
+
+ def transform(old_qfunc, params):
+ def new_qfunc(*args, **kwargs):
+ # 1. extract the tape from the old qfunc, being
+ # careful *not* to have it queued.
+ tape = make_tape(old_qfunc)(*args, **kwargs)
+
+ # 2. transform the tape
+ new_tape = tape_transform(tape, params)
+
+ # 3. queue the *new* tape to the active queuing context
+ new_tape.queue()
+ return new_qfunc
+
+ *Note: this is pseudocode; the actual implementation is significantly more complicated!*
+
+ Steps (1) and (3) are identical for all qfunc transforms; it is only step (2),
+ ``tape_transform`` and the corresponding tape transform parameters, that define the qfunc
+ transformation.
+
+ That is, given a tape transform that **defines the qfunc transformation**, the
+ decorator **elevates** the tape transform to one that works on quantum functions
+ rather than tapes. This decorator therefore automates the process of adding in
+ the queueing logic required under steps (1) and (3), so that it does not need to be
+ repeated and tested for every new qfunc transform.
+ """
+ if not callable(tape_transform):
+ raise ValueError(
+ "The qfunc_transform decorator can only be applied "
+ "to single tape transform functions."
+ )
+
+ if not isinstance(tape_transform, single_tape_transform):
+ tape_transform = single_tape_transform(tape_transform)
+
+ sig = inspect.signature(tape_transform)
+ params = sig.parameters
+
+ if len(params) > 1:
+
+ @functools.wraps(tape_transform)
+ def make_qfunc_transform(*targs, **tkwargs):
+ def wrapper(fn):
+ return _create_qfunc_internal_wrapper(fn, tape_transform, targs, tkwargs)
+
+ return wrapper
+
+ elif len(params) == 1:
+
+ @functools.wraps(tape_transform)
+ def make_qfunc_transform(fn):
+ return _create_qfunc_internal_wrapper(fn, tape_transform, [], {})
+
+ make_qfunc_transform.tape_fn = tape_transform
+ return make_qfunc_transform
diff --git a/tests/transforms/test_qfunc_transform.py b/tests/transforms/test_qfunc_transform.py
new file mode 100644
index 00000000000..f7d157de5ae
--- /dev/null
+++ b/tests/transforms/test_qfunc_transform.py
@@ -0,0 +1,395 @@
+# Copyright 2018-2021 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Unit tests for the qfunc transform decorators.
+"""
+import pytest
+
+import pennylane as qml
+from pennylane import numpy as np
+
+
+class TestSingleTapeTransform:
+ """Tests for the single_tape_transform decorator"""
+
+ def test_error_invalid_callable(self):
+ """Test that an error is raised if the transform
+ is applied to an invalid function"""
+
+ with pytest.raises(ValueError, match="does not appear to be a valid Python function"):
+ qml.single_tape_transform(5)
+
+ def test_parametrized_transform(self):
+ """Test that a parametrized transform can be applied
+ to a tape"""
+
+ @qml.single_tape_transform
+ def my_transform(tape, a, b):
+ for op in tape.operations + tape.measurements:
+ if op.name == "CRX":
+ wires = op.wires
+ param = op.parameters[0]
+ qml.RX(a, wires=wires[1])
+ qml.RY(qml.math.sum(b) * param / 2, wires=wires[1])
+ qml.CZ(wires=[wires[1], wires[0]])
+ else:
+ op.queue()
+
+ a = 0.1
+ b = np.array([0.2, 0.3])
+ x = 0.543
+
+ with qml.tape.QuantumTape() as tape:
+ qml.Hadamard(wires=0)
+ qml.CRX(x, wires=[0, 1])
+
+ ops = my_transform(tape, a, b).operations
+ assert len(ops) == 4
+ assert ops[0].name == "Hadamard"
+
+ assert ops[1].name == "RX"
+ assert ops[1].parameters == [a]
+
+ assert ops[2].name == "RY"
+ assert ops[2].parameters == [np.sum(b) * x / 2]
+
+ assert ops[3].name == "CZ"
+
+
+class TestQFuncTransforms:
+ """Tests for the qfunc_transform decorator"""
+
+ def test_error_invalid_transform_callable(self):
+ """Test that an error is raised if the transform
+ is applied to an invalid function"""
+
+ with pytest.raises(
+ ValueError, match="can only be applied to single tape transform functions"
+ ):
+ qml.qfunc_transform(5)
+
+ def test_error_invalid_qfunc(self):
+ """Test that an error is raised if the transform
+ is applied to an invalid function"""
+
+ def identity_transform(tape):
+ for op in tape.operations + tape.measurements:
+ op.queue()
+
+ my_transform = qml.qfunc_transform(identity_transform)
+
+ with pytest.raises(ValueError, match="does not appear to be a valid Python function"):
+ my_transform(5)
+
+ def test_unparametrized_transform(self):
+ """Test that an unparametrized transform can be applied
+ to a quantum function"""
+
+ def my_transform(tape):
+ for op in tape.operations + tape.measurements:
+ if op.name == "CRX":
+ wires = op.wires
+ param = op.parameters[0]
+ qml.RX(param, wires=wires[1])
+ qml.RY(param / 2, wires=wires[1])
+ qml.CZ(wires=[wires[1], wires[0]])
+ else:
+ op.queue()
+
+ my_transform = qml.qfunc_transform(my_transform)
+
+ def qfunc(x):
+ qml.Hadamard(wires=0)
+ qml.CRX(x, wires=[0, 1])
+
+ new_qfunc = my_transform(qfunc)
+ x = 0.543
+
+ ops = qml.transforms.make_tape(new_qfunc)(x).operations
+ assert len(ops) == 4
+ assert ops[0].name == "Hadamard"
+
+ assert ops[1].name == "RX"
+ assert ops[1].parameters == [x]
+
+ assert ops[2].name == "RY"
+ assert ops[2].parameters == [x / 2]
+
+ assert ops[3].name == "CZ"
+
+ def test_unparametrized_transform_decorator(self):
+ """Test that an unparametrized transform can be applied
+ to a quantum function via a decorator"""
+
+ @qml.qfunc_transform
+ def my_transform(tape):
+ for op in tape.operations + tape.measurements:
+ if op.name == "CRX":
+ wires = op.wires
+ param = op.parameters[0]
+ qml.RX(param, wires=wires[1])
+ qml.RY(param / 2, wires=wires[1])
+ qml.CZ(wires=[wires[1], wires[0]])
+ else:
+ op.queue()
+
+ @my_transform
+ def qfunc(x):
+ qml.Hadamard(wires=0)
+ qml.CRX(x, wires=[0, 1])
+
+ x = 0.543
+ ops = qml.transforms.make_tape(qfunc)(x).operations
+ assert len(ops) == 4
+ assert ops[0].name == "Hadamard"
+
+ assert ops[1].name == "RX"
+ assert ops[1].parameters == [x]
+
+ assert ops[2].name == "RY"
+ assert ops[2].parameters == [x / 2]
+
+ assert ops[3].name == "CZ"
+
+ def test_parametrized_transform(self):
+ """Test that a parametrized transform can be applied
+ to a quantum function"""
+
+ def my_transform(tape, a, b):
+ for op in tape.operations + tape.measurements:
+ if op.name == "CRX":
+ wires = op.wires
+ param = op.parameters[0]
+ qml.RX(a, wires=wires[1])
+ qml.RY(qml.math.sum(b) * param / 2, wires=wires[1])
+ qml.CZ(wires=[wires[1], wires[0]])
+ else:
+ op.queue()
+
+ my_transform = qml.qfunc_transform(my_transform)
+
+ def qfunc(x):
+ qml.Hadamard(wires=0)
+ qml.CRX(x, wires=[0, 1])
+
+ a = 0.1
+ b = np.array([0.2, 0.3])
+ x = 0.543
+ new_qfunc = my_transform(a, b)(qfunc)
+
+ ops = qml.transforms.make_tape(new_qfunc)(x).operations
+ assert len(ops) == 4
+ assert ops[0].name == "Hadamard"
+
+ assert ops[1].name == "RX"
+ assert ops[1].parameters == [a]
+
+ assert ops[2].name == "RY"
+ assert ops[2].parameters == [np.sum(b) * x / 2]
+
+ assert ops[3].name == "CZ"
+
+ def test_parametrized_transform_decorator(self):
+ """Test that a parametrized transform can be applied
+ to a quantum function via a decorator"""
+
+ @qml.qfunc_transform
+ def my_transform(tape, a, b):
+ for op in tape.operations + tape.measurements:
+ if op.name == "CRX":
+ wires = op.wires
+ param = op.parameters[0]
+ qml.RX(a, wires=wires[1])
+ qml.RY(qml.math.sum(b) * param / 2, wires=wires[1])
+ qml.CZ(wires=[wires[1], wires[0]])
+ else:
+ op.queue()
+
+ a = 0.1
+ b = np.array([0.2, 0.3])
+ x = 0.543
+
+ @my_transform(a, b)
+ def qfunc(x):
+ qml.Hadamard(wires=0)
+ qml.CRX(x, wires=[0, 1])
+
+ ops = qml.transforms.make_tape(qfunc)(x).operations
+ assert len(ops) == 4
+ assert ops[0].name == "Hadamard"
+
+ assert ops[1].name == "RX"
+ assert ops[1].parameters == [a]
+
+ assert ops[2].name == "RY"
+ assert ops[2].parameters == [np.sum(b) * x / 2]
+
+ assert ops[3].name == "CZ"
+
+ def test_nested_transforms(self):
+ """Test that nesting multiple transforms works as expected"""
+
+ @qml.qfunc_transform
+ def convert_cnots(tape):
+ for op in tape.operations + tape.measurements:
+ if op.name == "CNOT":
+ wires = op.wires
+ qml.Hadamard(wires=wires[0])
+ qml.CZ(wires=[wires[0], wires[1]])
+ else:
+ op.queue()
+
+ @qml.qfunc_transform
+ def expand_hadamards(tape, x):
+ for op in tape.operations + tape.measurements:
+ if op.name == "Hadamard":
+ qml.RZ(x, wires=op.wires)
+ else:
+ op.queue()
+
+ x = 0.5
+
+ @expand_hadamards(x)
+ @convert_cnots
+ def ansatz():
+ qml.CNOT(wires=[0, 1])
+
+ ops = qml.transforms.make_tape(ansatz)().operations
+ assert len(ops) == 2
+ assert ops[0].name == "RZ"
+ assert ops[0].parameters == [x]
+ assert ops[1].name == "CZ"
+
+
+############################################
+# Test transform, ansatz, and qfunc function
+
+
+@pytest.mark.parametrize("diff_method", ["parameter-shift", "backprop"])
+class TestQFuncTransformGradients:
+ """Tests for the qfunc_transform decorator differentiability"""
+
+ @staticmethod
+ @qml.qfunc_transform
+ def my_transform(tape, a, b):
+ """Test transform"""
+ for op in tape.operations + tape.measurements:
+ if op.name == "CRX":
+ wires = op.wires
+ param = op.parameters[0]
+ qml.RX(a * param, wires=wires[1])
+ qml.RY(qml.math.sum(b) * qml.math.sqrt(param), wires=wires[1])
+ qml.CZ(wires=[wires[1], wires[0]])
+ else:
+ op.queue()
+
+ @staticmethod
+ def ansatz(x):
+ """Test ansatz"""
+ qml.Hadamard(wires=0)
+ qml.CRX(x, wires=[0, 1])
+
+ @staticmethod
+ def circuit(param, *transform_weights):
+ """Test QFunc"""
+ qml.RX(0.1, wires=0)
+ TestQFuncTransformGradients.my_transform(*transform_weights)(
+ TestQFuncTransformGradients.ansatz
+ )(param)
+ return qml.expval(qml.PauliZ(1))
+
+ @staticmethod
+ def expval(x, a, b):
+ """Analytic expectation value of the above circuit qfunc"""
+ return np.cos(np.sum(b) * np.sqrt(x)) * np.cos(a * x)
+
+ def test_differentiable_qfunc_autograd(self, diff_method):
+ """Test that a qfunc transform is differentiable when using
+ autograd"""
+ dev = qml.device("default.qubit", wires=2)
+ qnode = qml.QNode(self.circuit, dev, interface="autograd", diff_method=diff_method)
+
+ a = np.array(0.5, requires_grad=True)
+ b = np.array([0.1, 0.2], requires_grad=True)
+ x = np.array(0.543, requires_grad=True)
+
+ res = qnode(x, a, b)
+ assert np.allclose(res, self.expval(x, a, b))
+
+ grad = qml.grad(qnode)(x, a, b)
+ expected = qml.grad(self.expval)(x, a, b)
+ assert all(np.allclose(g, e) for g, e in zip(grad, expected))
+
+ def test_differentiable_qfunc_tf(self, diff_method):
+ """Test that a qfunc transform is differentiable when using
+ TensorFlow"""
+ tf = pytest.importorskip("tensorflow")
+ dev = qml.device("default.qubit", wires=2)
+ qnode = qml.QNode(self.circuit, dev, interface="tf", diff_method=diff_method)
+
+ a = tf.Variable(0.5, dtype=tf.float64)
+ b = tf.Variable([0.1, 0.2], dtype=tf.float64)
+ x = tf.Variable(0.543, dtype=tf.float64)
+
+ with tf.GradientTape() as tape:
+ res = qnode(x, a, b)
+
+ assert np.allclose(res, self.expval(x, a, b))
+
+ grad = tape.gradient(res, [x, a, b])
+ expected = qml.grad(self.expval)(x.numpy(), a.numpy(), b.numpy())
+ assert all(np.allclose(g, e) for g, e in zip(grad, expected))
+
+ def test_differentiable_qfunc_torch(self, diff_method):
+ """Test that a qfunc transform is differentiable when using
+ PyTorch"""
+ if diff_method == "backprop":
+ pytest.skip("Does not support backprop mode")
+
+ torch = pytest.importorskip("torch")
+ dev = qml.device("default.qubit", wires=2)
+ qnode = qml.QNode(self.circuit, dev, interface="torch", diff_method=diff_method)
+
+ a = torch.tensor(0.5, requires_grad=True)
+ b = torch.tensor([0.1, 0.2], requires_grad=True)
+ x = torch.tensor(0.543, requires_grad=True)
+
+ res = qnode(x, a, b)
+ expected = self.expval(x.detach().numpy(), a.detach().numpy(), b.detach().numpy())
+ assert np.allclose(res.detach().numpy(), expected)
+
+ res.backward()
+ expected = qml.grad(self.expval)(x.detach().numpy(), a.detach().numpy(), b.detach().numpy())
+ assert np.allclose(x.grad, expected[0])
+ assert np.allclose(a.grad, expected[1])
+ assert np.allclose(b.grad, expected[2])
+
+ def test_differentiable_qfunc_jax(self, diff_method):
+ """Test that a qfunc transform is differentiable when using
+ jax"""
+ jax = pytest.importorskip("jax")
+ dev = qml.device("default.qubit", wires=2)
+ qnode = qml.QNode(self.circuit, dev, interface="jax", diff_method=diff_method)
+
+ a = jax.numpy.array(0.5)
+ b = jax.numpy.array([0.1, 0.2])
+ x = jax.numpy.array(0.543)
+
+ res = qnode(x, a, b)
+ assert np.allclose(res, self.expval(x, a, b))
+
+ grad = jax.grad(qnode, argnums=[0, 1, 2])(x, a, b)
+ expected = qml.grad(self.expval)(np.array(x), np.array(a), np.array(b))
+ print(grad, expected)
+ assert all(np.allclose(g, e) for g, e in zip(grad, expected))