diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 1a240756f3b..320196f0a74 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -115,6 +115,7 @@ * `qml.transforms.split_non_commuting` can now handle circuits containing measurements of multi-term observables. [(#5729)](https://github.com/PennyLaneAI/pennylane/pull/5729) [(#5853)](https://github.com/PennyLaneAI/pennylane/pull/5838) + [(#5869)](https://github.com/PennyLaneAI/pennylane/pull/5869) * The qchem module has dedicated functions for calling `pyscf` and `openfermion` backends. [(#5553)](https://github.com/PennyLaneAI/pennylane/pull/5553) diff --git a/pennylane/transforms/split_non_commuting.py b/pennylane/transforms/split_non_commuting.py index 8c23061e852..71d7798f309 100644 --- a/pennylane/transforms/split_non_commuting.py +++ b/pennylane/transforms/split_non_commuting.py @@ -28,6 +28,13 @@ from pennylane.typing import Result, ResultBatch +def null_postprocessing(results): + """A postprocessing function returned by a transform that only converts the batch of results + into a result for a single ``QuantumTape``. + """ + return results[0] + + @transform def split_non_commuting( tape: qml.tape.QuantumScript, @@ -243,6 +250,8 @@ def circuit(x): [[expval(X(0)), probs(wires=[1])], [probs(wires=[0, 1])]] """ + if len(tape.measurements) == 0: + return [tape], null_postprocessing # Special case for a single measurement of a Sum or Hamiltonian, in which case # the grouping information can be computed and cached in the observable. diff --git a/tests/transforms/test_split_non_commuting.py b/tests/transforms/test_split_non_commuting.py index 709a0ed8e2a..5af275bc17a 100644 --- a/tests/transforms/test_split_non_commuting.py +++ b/tests/transforms/test_split_non_commuting.py @@ -568,6 +568,17 @@ def test_tape_with_non_pauli_obs(self, non_pauli_obs): fn([[0.1, 0.2], [0.3, 0.6], 0.4, 0.5, 0.7]), [0.01, 0.06, 0.06, 0.16, 0.25, 0.36, 0.49] ) + @pytest.mark.parametrize("grouping_strategy", [None, "default", "qwc", "wires"]) + def test_no_measurements(self, grouping_strategy): + """Test that if the tape contains no measurements, the transform doesn't + modify it""" + + tape = qml.tape.QuantumScript([qml.X(0)]) + tapes, post_processing_fn = split_non_commuting(tape, grouping_strategy=grouping_strategy) + assert len(tapes) == 1 + assert tapes[0] == tape + assert post_processing_fn(tapes) == tape + class TestIntegration: """Tests the ``split_non_commuting`` transform performed on a QNode"""