diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 812e35faafd..71c11cbab8c 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -382,9 +382,12 @@ decomposition. [(#4675)](https://github.com/PennyLaneAI/pennylane/pull/4675) -

Breaking changes 💔

+* ``qml.defer_measurements`` now raises an error if a transformed circuit measures ``qml.probs``, + ``qml.sample``, or ``qml.counts`` without any wires or obsrvable, or if it measures ``qml.state``. + [(#4701)](https://github.com/PennyLaneAI/pennylane/pull/4701) + * The device test suite now converts device kwargs to integers or floats if they can be converted to integers or floats. [(#4640)](https://github.com/PennyLaneAI/pennylane/pull/4640) diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py index 52bcd90470b..c1a21cbb4fe 100644 --- a/pennylane/devices/default_qubit.py +++ b/pennylane/devices/default_qubit.py @@ -387,8 +387,8 @@ def preprocess( config = self._setup_execution_config(execution_config) transform_program = TransformProgram() - transform_program.add_transform(qml.defer_measurements, device=self) transform_program.add_transform(validate_device_wires, self.wires, name=self.name) + transform_program.add_transform(qml.defer_measurements, device=self) transform_program.add_transform( decompose, stopping_condition=stopping_condition, name=self.name ) diff --git a/pennylane/transforms/defer_measurements.py b/pennylane/transforms/defer_measurements.py index 1acf9ec29ec..c39bb081957 100644 --- a/pennylane/transforms/defer_measurements.py +++ b/pennylane/transforms/defer_measurements.py @@ -14,7 +14,7 @@ """Code for the tape transform implementing the deferred measurement principle.""" from typing import Sequence, Callable import pennylane as qml -from pennylane.measurements import MidMeasureMP +from pennylane.measurements import MidMeasureMP, ProbabilityMP, SampleMP, CountsMP from pennylane.ops.op_math import ctrl from pennylane.tape import QuantumTape @@ -23,7 +23,66 @@ from pennylane.wires import Wires from pennylane.queuing import QueuingManager -# pylint: disable=too-many-branches, too-many-statements +# pylint: disable=too-many-branches, protected-access + + +def _check_tape_validity(tape: QuantumTape): + """Helper function to check that the tape is valid.""" + cv_types = (qml.operation.CVOperation, qml.operation.CVObservable) + ops_cv = any(isinstance(op, cv_types) and op.name != "Identity" for op in tape.operations) + obs_cv = any( + isinstance(getattr(op, "obs", None), cv_types) + and not isinstance(getattr(op, "obs", None), qml.Identity) + for op in tape.measurements + ) + if ops_cv or obs_cv: + raise ValueError("Continuous variable operations and observables are not supported.") + + for mp in tape.measurements: + if isinstance(mp, (CountsMP, ProbabilityMP, SampleMP)) and not ( + mp.obs or mp._wires or mp.mv + ): + raise ValueError( + f"Cannot use {mp.__class__.__name__} as a measurement without specifying wires " + "when using qml.defer_measurements. Deferred measurements can occur " + "automatically when using mid-circuit measurements on a device that does not " + "support them." + ) + + if mp.__class__.__name__ == "StateMP": + raise ValueError( + "Cannot use StateMP as a measurement when using qml.defer_measurements. " + "Deferred measurements can occur automatically when using mid-circuit " + "measurements on a device that does not support them." + ) + + +def _collect_mid_measure_info(tape: QuantumTape): + """Helper function to collect information related to mid-circuit measurements in the tape.""" + + # Find wires that are reused after measurement + measured_wires = [] + reused_measurement_wires = set() + any_repeated_measurements = False + is_postselecting = False + + for op in tape.operations: + if isinstance(op, MidMeasureMP): + if op.postselect is not None: + is_postselecting = True + if op.reset: + reused_measurement_wires.add(op.wires[0]) + + if op.wires[0] in measured_wires: + any_repeated_measurements = True + measured_wires.append(op.wires[0]) + + else: + reused_measurement_wires = reused_measurement_wires.union( + set(measured_wires).intersection(op.wires.toset()) + ) + + return measured_wires, reused_measurement_wires, any_repeated_measurements, is_postselecting def null_postprocessing(results): @@ -68,12 +127,18 @@ def defer_measurements(tape: QuantumTape, **kwargs) -> (Sequence[QuantumTape], C .. note:: - When applying the transform on a quantum function that returns - :func:`~pennylane.state` as the terminal measurement or contains the + When applying the transform on a quantum function that contains the :class:`~.Snapshot` instruction, state information corresponding to simulating the transformed circuit will be obtained. No post-measurement states are considered. + .. warning:: + + :func:`~.pennylane.state` is not supported with the ``defer_measurements`` transform. + Additionally, :func:`~.pennylane.probs`, :func:`~.pennylane.sample` and + :func:`~.pennylane.counts` can only be used with ``defer_measurements`` if wires + or an observable are explicitly specified. + Args: tape (.QuantumTape): a quantum tape @@ -136,45 +201,25 @@ def func(x, y): >>> func(*pars) tensor([0.76960924, 0.13204407, 0.08394415, 0.01440254], requires_grad=True) """ - # pylint: disable=protected-access + if not any(isinstance(o, MidMeasureMP) for o in tape.operations): return (tape,), null_postprocessing - cv_types = (qml.operation.CVOperation, qml.operation.CVObservable) - ops_cv = any(isinstance(op, cv_types) and op.name != "Identity" for op in tape.operations) - obs_cv = any( - isinstance(getattr(op, "obs", None), cv_types) - and not isinstance(getattr(op, "obs", None), qml.Identity) - for op in tape.measurements - ) - if ops_cv or obs_cv: - raise ValueError("Continuous variable operations and observables are not supported.") + _check_tape_validity(tape) + + device = kwargs.get("device", None) device = kwargs.get("device", None) new_operations = [] # Find wires that are reused after measurement - measured_wires = [] - reused_measurement_wires = set() - repeated_measurement_wire = False - is_postselecting = False - - for op in tape.operations: - if isinstance(op, MidMeasureMP): - if op.postselect is not None: - is_postselecting = True - if op.reset: - reused_measurement_wires.add(op.wires[0]) - - if op.wires[0] in measured_wires: - repeated_measurement_wire = True - measured_wires.append(op.wires[0]) - - else: - reused_measurement_wires = reused_measurement_wires.union( - set(measured_wires).intersection(op.wires.toset()) - ) + ( + measured_wires, + reused_measurement_wires, + any_repeated_measurements, + is_postselecting, + ) = _collect_mid_measure_info(tape) if is_postselecting and device is not None and not isinstance(device, qml.devices.DefaultQubit): raise ValueError(f"Postselection is not supported on the {device} device.") @@ -183,7 +228,7 @@ def func(x, y): # classically controlled operations control_wires = {} cur_wire = ( - max(tape.wires) + 1 if reused_measurement_wires or repeated_measurement_wire else None + max(tape.wires) + 1 if reused_measurement_wires or any_repeated_measurements else None ) for op in tape.operations: @@ -257,7 +302,7 @@ def _add_control_gate(op, control_wires): control = [control_wires[m.id] for m in op.meas_val.measurements] new_ops = [] - for branch, value in op.meas_val._items(): # pylint: disable=protected-access + for branch, value in op.meas_val._items(): if value: qscript = qml.tape.make_qscript( ctrl( diff --git a/tests/devices/default_qubit/test_default_qubit.py b/tests/devices/default_qubit/test_default_qubit.py index c56966d4942..257d5708984 100644 --- a/tests/devices/default_qubit/test_default_qubit.py +++ b/tests/devices/default_qubit/test_default_qubit.py @@ -1650,7 +1650,6 @@ class TestPostselection: qml.expval(qml.PauliZ(0)), qml.var(qml.PauliZ(0)), qml.probs(wires=[0, 1]), - qml.state(), qml.density_matrix(wires=0), qml.purity(0), qml.vn_entropy(0), @@ -1800,7 +1799,6 @@ def circ_postselect(theta): (qml.expval(qml.PauliZ(0)), "autograd", True), (qml.var(qml.PauliZ(0)), "autograd", True), (qml.probs(wires=[0, 1]), "autograd", True), - (qml.state(), "autograd", True), (qml.density_matrix(wires=0), "autograd", True), (qml.purity(0), "numpy", True), (qml.vn_entropy(0), "numpy", False), diff --git a/tests/transforms/test_defer_measurements.py b/tests/transforms/test_defer_measurements.py index f779c5a94cf..ec81062f610 100644 --- a/tests/transforms/test_defer_measurements.py +++ b/tests/transforms/test_defer_measurements.py @@ -80,6 +80,23 @@ def circ(): _ = circ() +@pytest.mark.parametrize( + "mp, err_msg", + [ + (qml.state(), "Cannot use StateMP as a measurement when"), + (qml.probs(), "Cannot use ProbabilityMP as a measurement without"), + (qml.sample(), "Cannot use SampleMP as a measurement without"), + (qml.counts(), "Cannot use CountsMP as a measurement without"), + ], +) +def test_unsupported_measurements(mp, err_msg): + """Test that using unsupported measurements raises an error.""" + tape = qml.tape.QuantumScript([MidMeasureMP(0)], [mp]) + + with pytest.raises(ValueError, match=err_msg): + _, _ = qml.defer_measurements(tape) + + class TestQNode: """Test that the transform integrates well with QNodes.""" @@ -1467,7 +1484,7 @@ def test_custom_wire_labels_allowed_without_reset(): qml.Hadamard("a") ma = qml.measure("a", reset=False) qml.cond(ma, qml.PauliX)("b") - qml.state() + qml.probs(wires="a") tape = qml.tape.QuantumScript.from_queue(q) tapes, _ = qml.defer_measurements(tape) @@ -1476,7 +1493,7 @@ def test_custom_wire_labels_allowed_without_reset(): assert len(tape) == 3 assert qml.equal(tape[0], qml.Hadamard("a")) assert qml.equal(tape[1], qml.CNOT(["a", "b"])) - assert qml.equal(tape[2], qml.state()) + assert qml.equal(tape[2], qml.probs(wires="a")) def test_custom_wire_labels_fails_with_reset(): @@ -1485,7 +1502,7 @@ def test_custom_wire_labels_fails_with_reset(): qml.Hadamard("a") ma = qml.measure("a", reset=True) qml.cond(ma, qml.PauliX)("b") - qml.state() + qml.probs(wires="a") tape = qml.tape.QuantumScript.from_queue(q) with pytest.raises(TypeError, match="can only concatenate str"):