diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 812e35faafd..71c11cbab8c 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -382,9 +382,12 @@
decomposition.
[(#4675)](https://github.com/PennyLaneAI/pennylane/pull/4675)
-
Breaking changes 💔
+* ``qml.defer_measurements`` now raises an error if a transformed circuit measures ``qml.probs``,
+ ``qml.sample``, or ``qml.counts`` without any wires or obsrvable, or if it measures ``qml.state``.
+ [(#4701)](https://github.com/PennyLaneAI/pennylane/pull/4701)
+
* The device test suite now converts device kwargs to integers or floats if they can be converted to integers or floats.
[(#4640)](https://github.com/PennyLaneAI/pennylane/pull/4640)
diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py
index 52bcd90470b..c1a21cbb4fe 100644
--- a/pennylane/devices/default_qubit.py
+++ b/pennylane/devices/default_qubit.py
@@ -387,8 +387,8 @@ def preprocess(
config = self._setup_execution_config(execution_config)
transform_program = TransformProgram()
- transform_program.add_transform(qml.defer_measurements, device=self)
transform_program.add_transform(validate_device_wires, self.wires, name=self.name)
+ transform_program.add_transform(qml.defer_measurements, device=self)
transform_program.add_transform(
decompose, stopping_condition=stopping_condition, name=self.name
)
diff --git a/pennylane/transforms/defer_measurements.py b/pennylane/transforms/defer_measurements.py
index 1acf9ec29ec..c39bb081957 100644
--- a/pennylane/transforms/defer_measurements.py
+++ b/pennylane/transforms/defer_measurements.py
@@ -14,7 +14,7 @@
"""Code for the tape transform implementing the deferred measurement principle."""
from typing import Sequence, Callable
import pennylane as qml
-from pennylane.measurements import MidMeasureMP
+from pennylane.measurements import MidMeasureMP, ProbabilityMP, SampleMP, CountsMP
from pennylane.ops.op_math import ctrl
from pennylane.tape import QuantumTape
@@ -23,7 +23,66 @@
from pennylane.wires import Wires
from pennylane.queuing import QueuingManager
-# pylint: disable=too-many-branches, too-many-statements
+# pylint: disable=too-many-branches, protected-access
+
+
+def _check_tape_validity(tape: QuantumTape):
+ """Helper function to check that the tape is valid."""
+ cv_types = (qml.operation.CVOperation, qml.operation.CVObservable)
+ ops_cv = any(isinstance(op, cv_types) and op.name != "Identity" for op in tape.operations)
+ obs_cv = any(
+ isinstance(getattr(op, "obs", None), cv_types)
+ and not isinstance(getattr(op, "obs", None), qml.Identity)
+ for op in tape.measurements
+ )
+ if ops_cv or obs_cv:
+ raise ValueError("Continuous variable operations and observables are not supported.")
+
+ for mp in tape.measurements:
+ if isinstance(mp, (CountsMP, ProbabilityMP, SampleMP)) and not (
+ mp.obs or mp._wires or mp.mv
+ ):
+ raise ValueError(
+ f"Cannot use {mp.__class__.__name__} as a measurement without specifying wires "
+ "when using qml.defer_measurements. Deferred measurements can occur "
+ "automatically when using mid-circuit measurements on a device that does not "
+ "support them."
+ )
+
+ if mp.__class__.__name__ == "StateMP":
+ raise ValueError(
+ "Cannot use StateMP as a measurement when using qml.defer_measurements. "
+ "Deferred measurements can occur automatically when using mid-circuit "
+ "measurements on a device that does not support them."
+ )
+
+
+def _collect_mid_measure_info(tape: QuantumTape):
+ """Helper function to collect information related to mid-circuit measurements in the tape."""
+
+ # Find wires that are reused after measurement
+ measured_wires = []
+ reused_measurement_wires = set()
+ any_repeated_measurements = False
+ is_postselecting = False
+
+ for op in tape.operations:
+ if isinstance(op, MidMeasureMP):
+ if op.postselect is not None:
+ is_postselecting = True
+ if op.reset:
+ reused_measurement_wires.add(op.wires[0])
+
+ if op.wires[0] in measured_wires:
+ any_repeated_measurements = True
+ measured_wires.append(op.wires[0])
+
+ else:
+ reused_measurement_wires = reused_measurement_wires.union(
+ set(measured_wires).intersection(op.wires.toset())
+ )
+
+ return measured_wires, reused_measurement_wires, any_repeated_measurements, is_postselecting
def null_postprocessing(results):
@@ -68,12 +127,18 @@ def defer_measurements(tape: QuantumTape, **kwargs) -> (Sequence[QuantumTape], C
.. note::
- When applying the transform on a quantum function that returns
- :func:`~pennylane.state` as the terminal measurement or contains the
+ When applying the transform on a quantum function that contains the
:class:`~.Snapshot` instruction, state information corresponding to
simulating the transformed circuit will be obtained. No
post-measurement states are considered.
+ .. warning::
+
+ :func:`~.pennylane.state` is not supported with the ``defer_measurements`` transform.
+ Additionally, :func:`~.pennylane.probs`, :func:`~.pennylane.sample` and
+ :func:`~.pennylane.counts` can only be used with ``defer_measurements`` if wires
+ or an observable are explicitly specified.
+
Args:
tape (.QuantumTape): a quantum tape
@@ -136,45 +201,25 @@ def func(x, y):
>>> func(*pars)
tensor([0.76960924, 0.13204407, 0.08394415, 0.01440254], requires_grad=True)
"""
- # pylint: disable=protected-access
+
if not any(isinstance(o, MidMeasureMP) for o in tape.operations):
return (tape,), null_postprocessing
- cv_types = (qml.operation.CVOperation, qml.operation.CVObservable)
- ops_cv = any(isinstance(op, cv_types) and op.name != "Identity" for op in tape.operations)
- obs_cv = any(
- isinstance(getattr(op, "obs", None), cv_types)
- and not isinstance(getattr(op, "obs", None), qml.Identity)
- for op in tape.measurements
- )
- if ops_cv or obs_cv:
- raise ValueError("Continuous variable operations and observables are not supported.")
+ _check_tape_validity(tape)
+
+ device = kwargs.get("device", None)
device = kwargs.get("device", None)
new_operations = []
# Find wires that are reused after measurement
- measured_wires = []
- reused_measurement_wires = set()
- repeated_measurement_wire = False
- is_postselecting = False
-
- for op in tape.operations:
- if isinstance(op, MidMeasureMP):
- if op.postselect is not None:
- is_postselecting = True
- if op.reset:
- reused_measurement_wires.add(op.wires[0])
-
- if op.wires[0] in measured_wires:
- repeated_measurement_wire = True
- measured_wires.append(op.wires[0])
-
- else:
- reused_measurement_wires = reused_measurement_wires.union(
- set(measured_wires).intersection(op.wires.toset())
- )
+ (
+ measured_wires,
+ reused_measurement_wires,
+ any_repeated_measurements,
+ is_postselecting,
+ ) = _collect_mid_measure_info(tape)
if is_postselecting and device is not None and not isinstance(device, qml.devices.DefaultQubit):
raise ValueError(f"Postselection is not supported on the {device} device.")
@@ -183,7 +228,7 @@ def func(x, y):
# classically controlled operations
control_wires = {}
cur_wire = (
- max(tape.wires) + 1 if reused_measurement_wires or repeated_measurement_wire else None
+ max(tape.wires) + 1 if reused_measurement_wires or any_repeated_measurements else None
)
for op in tape.operations:
@@ -257,7 +302,7 @@ def _add_control_gate(op, control_wires):
control = [control_wires[m.id] for m in op.meas_val.measurements]
new_ops = []
- for branch, value in op.meas_val._items(): # pylint: disable=protected-access
+ for branch, value in op.meas_val._items():
if value:
qscript = qml.tape.make_qscript(
ctrl(
diff --git a/tests/devices/default_qubit/test_default_qubit.py b/tests/devices/default_qubit/test_default_qubit.py
index c56966d4942..257d5708984 100644
--- a/tests/devices/default_qubit/test_default_qubit.py
+++ b/tests/devices/default_qubit/test_default_qubit.py
@@ -1650,7 +1650,6 @@ class TestPostselection:
qml.expval(qml.PauliZ(0)),
qml.var(qml.PauliZ(0)),
qml.probs(wires=[0, 1]),
- qml.state(),
qml.density_matrix(wires=0),
qml.purity(0),
qml.vn_entropy(0),
@@ -1800,7 +1799,6 @@ def circ_postselect(theta):
(qml.expval(qml.PauliZ(0)), "autograd", True),
(qml.var(qml.PauliZ(0)), "autograd", True),
(qml.probs(wires=[0, 1]), "autograd", True),
- (qml.state(), "autograd", True),
(qml.density_matrix(wires=0), "autograd", True),
(qml.purity(0), "numpy", True),
(qml.vn_entropy(0), "numpy", False),
diff --git a/tests/transforms/test_defer_measurements.py b/tests/transforms/test_defer_measurements.py
index f779c5a94cf..ec81062f610 100644
--- a/tests/transforms/test_defer_measurements.py
+++ b/tests/transforms/test_defer_measurements.py
@@ -80,6 +80,23 @@ def circ():
_ = circ()
+@pytest.mark.parametrize(
+ "mp, err_msg",
+ [
+ (qml.state(), "Cannot use StateMP as a measurement when"),
+ (qml.probs(), "Cannot use ProbabilityMP as a measurement without"),
+ (qml.sample(), "Cannot use SampleMP as a measurement without"),
+ (qml.counts(), "Cannot use CountsMP as a measurement without"),
+ ],
+)
+def test_unsupported_measurements(mp, err_msg):
+ """Test that using unsupported measurements raises an error."""
+ tape = qml.tape.QuantumScript([MidMeasureMP(0)], [mp])
+
+ with pytest.raises(ValueError, match=err_msg):
+ _, _ = qml.defer_measurements(tape)
+
+
class TestQNode:
"""Test that the transform integrates well with QNodes."""
@@ -1467,7 +1484,7 @@ def test_custom_wire_labels_allowed_without_reset():
qml.Hadamard("a")
ma = qml.measure("a", reset=False)
qml.cond(ma, qml.PauliX)("b")
- qml.state()
+ qml.probs(wires="a")
tape = qml.tape.QuantumScript.from_queue(q)
tapes, _ = qml.defer_measurements(tape)
@@ -1476,7 +1493,7 @@ def test_custom_wire_labels_allowed_without_reset():
assert len(tape) == 3
assert qml.equal(tape[0], qml.Hadamard("a"))
assert qml.equal(tape[1], qml.CNOT(["a", "b"]))
- assert qml.equal(tape[2], qml.state())
+ assert qml.equal(tape[2], qml.probs(wires="a"))
def test_custom_wire_labels_fails_with_reset():
@@ -1485,7 +1502,7 @@ def test_custom_wire_labels_fails_with_reset():
qml.Hadamard("a")
ma = qml.measure("a", reset=True)
qml.cond(ma, qml.PauliX)("b")
- qml.state()
+ qml.probs(wires="a")
tape = qml.tape.QuantumScript.from_queue(q)
with pytest.raises(TypeError, match="can only concatenate str"):