Skip to content

Commit

Permalink
Make qml.simplify a transform (#4949)
Browse files Browse the repository at this point in the history
**Context:**
`qml.simplify` can be applied to operations, measurements, tapes,
qnodes, or qfuncs. It seems appropriate for it to be a transform. So I
made it one.

**Description of the Change:**
* Added `_simplify_transform` with a `transform` decorator to handle
transforming tapes, qnodes, and callables. The public `simplify`
function directly handles operators and measurements, and dispatches to
`_simplify_transform` for other valid inputs.

**Benefits:**
`qml.simplify` is consistent with other op functions/transforms.

**Possible Drawbacks:**
When given a tape, `qml.simplify` used to return a tape, but now it
returns a list with one tape and a processing function. Small breaking
change, but this was also a breaking change for the other op transforms,
so I don't think this is a major deal.

**Related GitHub Issues:**

[sc-52122]

---------

Co-authored-by: Romain Moyard <rmoyard@gmail.com>
  • Loading branch information
mudit2812 and rmoyard committed Jan 19, 2024
1 parent a6bc314 commit def8f86
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 44 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,9 @@

<h4>Other improvements</h4>

* `qml.simplify` now uses the new transforms API.
[(#4949)](https://github.com/PennyLaneAI/pennylane/pull/4949)

* The formal requirement that type hinting be providing when using
the `qml.transform` decorator has been removed. Type hinting can still
be used, but is now optional. Please use a type checker such as
Expand Down
68 changes: 29 additions & 39 deletions pennylane/ops/functions/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,29 @@
This module contains the qml.simplify function.
"""
from copy import copy
from functools import wraps
from typing import Callable, Union
from typing import Callable, Union, Sequence

import pennylane as qml
from pennylane.measurements import MeasurementProcess
from pennylane.operation import Operator
from pennylane.qnode import QNode
from pennylane.queuing import QueuingManager
from pennylane.tape import QuantumScript, make_qscript, QuantumTape
from pennylane.tape import QuantumScript, QuantumTape


def simplify(input: Union[Operator, MeasurementProcess, QuantumTape, QNode, Callable]):
"""Simplifies an operator, tape, qnode or quantum function by reducing its arithmetic depth
or number of rotation parameters.
Args:
input (.Operator, pennylane.QNode, .QuantumTape, or Callable): an operator, quantum node,
tape or function that applies quantum operations
input (.Operator, .MeasurementProcess, pennylane.QNode, .QuantumTape, or Callable): an
operator, quantum node, tape or function that applies quantum operations
Returns:
(.Operator, pennylane.QNode, .QuantumTape, or Callable): Simplified input.
(Operator or MeasurementProcess or qnode (QNode) or quantum function (Callable)
or tuple[List[QuantumTape], function]): Simplified input. If an operator or measurement
process is provided as input, the simplified input is returned directly. Otherwise, the
transformed circuit is returned as described in :func:`qml.transform <pennylane.transform>`.
**Example**
Expand Down Expand Up @@ -70,11 +72,11 @@ def simplify(input: Union[Operator, MeasurementProcess, QuantumTape, QNode, Call
Moreover, ``qml.simplify`` can be used to simplify QNodes or quantum functions:
>>> dev = qml.device("default.qubit", wires=2)
>>> @qml.simplify
@qml.qnode(dev)
def circuit():
qml.adjoint(qml.prod(qml.RX(1, 0) ** 1, qml.RY(1, 0), qml.RZ(1, 0)))
return qml.probs(wires=0)
>>> @qml.qnode(dev)
... @qml.simplify
... def circuit():
... qml.adjoint(qml.prod(qml.RX(1, 0) ** 1, qml.RY(1, 0), qml.RZ(1, 0)))
... return qml.probs(wires=0)
>>> circuit()
tensor([0.64596329, 0.35403671], requires_grad=True)
>>> list(circuit.tape)
Expand All @@ -89,33 +91,21 @@ def circuit():
return qml.apply(new_op)
return input.simplify()

if isinstance(input, QuantumScript):
return input.__class__(
[op.simplify() for op in input.operations],
[m.simplify() for m in input.measurements],
shots=input.shots,
)

if callable(input):
old_qfunc = input.func if isinstance(input, QNode) else input

@wraps(old_qfunc)
def qfunc(*args, **kwargs):
qs = make_qscript(old_qfunc)(*args, **kwargs)
_ = [qml.simplify(op) for op in qs.operations]
m = tuple(qml.simplify(m) for m in qs.measurements)
return m[0] if len(m) == 1 else m

if isinstance(input, QNode):
return QNode(
func=qfunc,
device=input.device,
interface=input.interface,
diff_method=input.diff_method,
expansion_strategy=input.expansion_strategy,
**input.execute_kwargs,
**input.gradient_kwargs,
)
return qfunc
if isinstance(input, QuantumScript) or callable(input):
return _simplify_transform(input)

raise ValueError(f"Cannot simplify the object {input} of type {type(input)}.")


@qml.transform
def _simplify_transform(tape: QuantumTape) -> (Sequence[QuantumTape], Callable):
with qml.QueuingManager.stop_recording():
new_operations = [op.simplify() for op in tape.operations]
new_measurements = [m.simplify() for m in tape.measurements]

new_tape = type(tape)(new_operations, new_measurements, shots=tape.shots)

def null_processing_fn(res):
return res[0]

return [new_tape], null_processing_fn
13 changes: 8 additions & 5 deletions tests/ops/functions/test_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def test_simplify_tape(self, shots):
build_op()

tape = QuantumScript.from_queue(q_tape, shots=shots)
s_tape = qml.simplify(tape)
[s_tape], _ = qml.simplify(tape)
assert len(s_tape) == 1
s_op = s_tape[0]
assert isinstance(s_op, qml.ops.Prod) # pylint: disable=no-member
Expand All @@ -111,7 +111,7 @@ def test_execute_simplified_tape(self):

tape = QuantumScript.from_queue(q_tape)
simplified_tape_op = qml.PauliZ(1)
s_tape = qml.simplify(tape)
[s_tape], _ = qml.simplify(tape)
s_op = s_tape.operations[0]
assert isinstance(s_op, qml.PauliZ)
assert s_op.data == simplified_tape_op.data
Expand All @@ -138,9 +138,12 @@ def qnode():

s_qnode = qml.simplify(qnode)
assert s_qnode() == qnode()
assert len(s_qnode.tape) == 2
s_op = s_qnode.tape.operations[0]
s_obs = s_qnode.tape.observables[0]

[s_tape], _ = s_qnode.transform_program([s_qnode.tape])
assert len(s_tape) == 2

s_op = s_tape.operations[0]
s_obs = s_tape.observables[0]
assert isinstance(s_op, qml.PauliZ)
assert s_op.data == simplified_tape_op.data
assert s_op.wires == simplified_tape_op.wires
Expand Down

0 comments on commit def8f86

Please sign in to comment.