diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 363d9493818..73aa52d3d6e 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -11,6 +11,9 @@
Mid-circuit measurements and dynamic circuits
+* The `dynamic_one_shot` transform uses a single auxiliary tape with a shot vector and `default.qubit` implements the loop over shots with `jax.vmap`.
+ [(#5617)](https://github.com/PennyLaneAI/pennylane/pull/5617)
+
* The `dynamic_one_shot` transform can be compiled with `jax.jit`.
[(#5557)](https://github.com/PennyLaneAI/pennylane/pull/5557)
diff --git a/pennylane/_qubit_device.py b/pennylane/_qubit_device.py
index 532d406f956..897863053c1 100644
--- a/pennylane/_qubit_device.py
+++ b/pennylane/_qubit_device.py
@@ -275,8 +275,19 @@ def execute(self, circuit, **kwargs):
self.check_validity(circuit.operations, circuit.observables)
has_mcm = any(isinstance(op, MidMeasureMP) for op in circuit.operations)
- if has_mcm:
- kwargs["mid_measurements"] = {}
+ if has_mcm and "mid_measurements" not in kwargs:
+ results = []
+ aux_circ = qml.tape.QuantumScript(
+ circuit.operations,
+ circuit.measurements,
+ shots=[1],
+ trainable_params=circuit.trainable_params,
+ )
+ for _ in circuit.shots:
+ kwargs["mid_measurements"] = {}
+ self.reset()
+ results.append(self.execute(aux_circ, **kwargs))
+ return tuple(results)
# apply all circuit operations
self.apply(
circuit.operations,
diff --git a/pennylane/devices/qubit/simulate.py b/pennylane/devices/qubit/simulate.py
index 2095aff3d6d..f1929cd83db 100644
--- a/pennylane/devices/qubit/simulate.py
+++ b/pennylane/devices/qubit/simulate.py
@@ -279,9 +279,33 @@ def simulate(
has_mcm = any(isinstance(op, MidMeasureMP) for op in circuit.operations)
if circuit.shots and has_mcm:
- return simulate_one_shot_native_mcm(
- circuit, debugger=debugger, rng=rng, prng_key=prng_key, interface=interface
+ results = []
+ aux_circ = qml.tape.QuantumScript(
+ circuit.operations,
+ circuit.measurements,
+ shots=[1],
+ trainable_params=circuit.trainable_params,
)
+ keys = jax_random_split(prng_key, num=circuit.shots.total_shots)
+ if qml.math.get_deep_interface(circuit.data) == "jax":
+ # pylint: disable=import-outside-toplevel
+ import jax
+
+ def simulate_partial(k):
+ return simulate_one_shot_native_mcm(
+ aux_circ, debugger=debugger, rng=rng, prng_key=k, interface=interface
+ )
+
+ results = jax.vmap(simulate_partial, in_axes=(0,))(keys)
+ results = tuple(zip(*results))
+ else:
+ for i in range(circuit.shots.total_shots):
+ results.append(
+ simulate_one_shot_native_mcm(
+ aux_circ, debugger=debugger, rng=rng, prng_key=keys[i], interface=interface
+ )
+ )
+ return tuple(results)
ops_key, meas_key = jax_random_split(prng_key)
state, is_state_batched = get_final_state(
diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py
index 86b57ca27a2..109f43db30b 100644
--- a/pennylane/transforms/dynamic_one_shot.py
+++ b/pennylane/transforms/dynamic_one_shot.py
@@ -87,8 +87,7 @@ def func(x, y):
few-shots several-mid-circuit-measurement limit, whereas ``qml.defer_measurements`` is favorable
in the opposite limit.
"""
-
- if not any(isinstance(o, MidMeasureMP) for o in tape.operations):
+ if not any(is_mcm(o) for o in tape.operations):
return (tape,), null_postprocessing
for m in tape.measurements:
@@ -103,9 +102,7 @@ def func(x, y):
raise qml.QuantumFunctionError("dynamic_one_shot is only supported with finite shots.")
samples_present = any(isinstance(mp, SampleMP) for mp in tape.measurements)
- postselect_present = any(
- op.postselect is not None for op in tape.operations if isinstance(op, MidMeasureMP)
- )
+ postselect_present = any(op.postselect is not None for op in tape.operations if is_mcm(op))
if postselect_present and samples_present and tape.batch_size is not None:
raise ValueError(
"Returning qml.sample is not supported when postselecting mid-circuit "
@@ -119,28 +116,21 @@ def func(x, y):
broadcast_fn = None
aux_tapes = [init_auxiliary_tape(t) for t in tapes]
- # Shape of output_tapes is (batch_size * total_shots,) with broadcasting,
- # and (total_shots,) otherwise
- output_tapes = [at for at in aux_tapes for _ in range(tape.shots.total_shots)]
def processing_fn(results, has_partitioned_shots=None, batched_results=None):
if batched_results is None and batch_size is not None:
# If broadcasting, recursively process the results for each batch. For each batch
# there are tape.shots.total_shots results. The length of the first axis of final_results
# will be batch_size.
- results = list(results)
final_results = []
- for _ in range(batch_size):
- final_results.append(
- processing_fn(results[0 : tape.shots.total_shots], batched_results=False)
- )
- del results[0 : tape.shots.total_shots]
+ for result in results:
+ final_results.append(processing_fn((result,), batched_results=False))
return broadcast_fn(final_results)
if has_partitioned_shots is None and tape.shots.has_partitioned_shots:
# If using shot vectors, recursively process the results for each shot bin. The length
# of the first axis of final_results will be the length of the shot vector.
- results = list(results)
+ results = list(results[0])
final_results = []
for s in tape.shots:
final_results.append(
@@ -148,9 +138,11 @@ def processing_fn(results, has_partitioned_shots=None, batched_results=None):
)
del results[0:s]
return tuple(final_results)
+ if not tape.shots.has_partitioned_shots:
+ results = results[0]
return parse_native_mid_circuit_measurements(tape, aux_tapes, results)
- return output_tapes, processing_fn
+ return aux_tapes, processing_fn
@dynamic_one_shot.custom_qnode_transform
@@ -177,6 +169,12 @@ def _dynamic_one_shot_qnode(self, qnode, targs, tkwargs):
return self.default_qnode_transform(qnode, targs, tkwargs)
+def is_mcm(operation):
+ """Returns True if the operation is a mid-circuit measurement and False otherwise."""
+ mcm = isinstance(operation, MidMeasureMP)
+ return mcm or "MidCircuitMeasure" in str(type(operation))
+
+
def init_auxiliary_tape(circuit: qml.tape.QuantumScript):
"""Creates an auxiliary circuit to perform one-shot mid-circuit measurement calculations.
@@ -197,11 +195,14 @@ def init_auxiliary_tape(circuit: qml.tape.QuantumScript):
else:
new_measurements.append(m)
for op in circuit:
- if isinstance(op, MidMeasureMP):
+ if is_mcm(op):
new_measurements.append(qml.sample(MeasurementValue([op], lambda res: res)))
return qml.tape.QuantumScript(
- circuit.operations, new_measurements, shots=1, trainable_params=circuit.trainable_params
+ circuit.operations,
+ new_measurements,
+ shots=[1] * circuit.shots.total_shots,
+ trainable_params=circuit.trainable_params,
)
@@ -228,7 +229,7 @@ def measurement_with_no_shots(measurement):
interface = qml.math.get_deep_interface(circuit.data)
- all_mcms = [op for op in aux_tapes[0].operations if isinstance(op, MidMeasureMP)]
+ all_mcms = [op for op in aux_tapes[0].operations if is_mcm(op)]
n_mcms = len(all_mcms)
post_process_tape = qml.tape.QuantumScript(
aux_tapes[0].operations,
@@ -248,7 +249,7 @@ def measurement_with_no_shots(measurement):
).reshape((1, -1))
is_valid = qml.math.all(mcm_samples * has_postselect == postselect, axis=1)
has_valid = qml.math.any(is_valid)
- mid_meas = [op for op in circuit.operations if isinstance(op, MidMeasureMP)]
+ mid_meas = [op for op in circuit.operations if is_mcm(op)]
mcm_samples = [mcm_samples[:, i : i + 1] for i in range(n_mcms)]
mcm_samples = dict((k, v) for k, v in zip(mid_meas, mcm_samples))
diff --git a/tests/devices/default_qubit/test_default_qubit_native_mcm.py b/tests/devices/default_qubit/test_default_qubit_native_mcm.py
index 691dcf3f69b..4848ea4f6e3 100644
--- a/tests/devices/default_qubit/test_default_qubit_native_mcm.py
+++ b/tests/devices/default_qubit/test_default_qubit_native_mcm.py
@@ -693,7 +693,7 @@ def func(x, y, z):
assert "pure_callback" not in jaxpr
func2 = jax.jit(func)
- results2 = func2(*jax.numpy.array(params))
+ results2 = func2(*params)
measures = [
qml.probs,
diff --git a/tests/transforms/test_dynamic_one_shot.py b/tests/transforms/test_dynamic_one_shot.py
index f65f2b805a5..3ca186171fe 100644
--- a/tests/transforms/test_dynamic_one_shot.py
+++ b/tests/transforms/test_dynamic_one_shot.py
@@ -88,7 +88,7 @@ def test_len_tapes(n_shots):
"""Test that the transform produces the correct number of tapes."""
tape = qml.tape.QuantumScript([MidMeasureMP(0)], [qml.expval(qml.PauliZ(0))], shots=n_shots)
tapes, _ = qml.dynamic_one_shot(tape)
- assert len(tapes) == n_shots
+ assert len(tapes) == 1
@pytest.mark.parametrize("n_batch", range(1, 4))
@@ -102,7 +102,7 @@ def test_len_tape_batched(n_batch, n_shots):
shots=n_shots,
)
tapes, _ = qml.dynamic_one_shot(tape)
- assert len(tapes) == n_shots * n_batch
+ assert len(tapes) == n_batch
@pytest.mark.parametrize(
@@ -123,7 +123,7 @@ def test_len_measurements_obs(measure, aux_measure, n_meas):
[qml.Hadamard(0)] + [MidMeasureMP(0)] * n_mcms, [measure(op=qml.PauliZ(0))], shots=n_shots
)
tapes, _ = qml.dynamic_one_shot(tape)
- assert len(tapes) == n_shots
+ assert len(tapes) == 1
aux_tape = tapes[0]
assert len(aux_tape.measurements) == n_meas + n_mcms
assert isinstance(aux_tape.measurements[0], aux_measure)
@@ -150,7 +150,7 @@ def test_len_measurements_mcms(measure, aux_measure, n_meas):
shots=n_shots,
)
tapes, _ = qml.dynamic_one_shot(tape)
- assert len(tapes) == n_shots
+ assert len(tapes) == 1
aux_tape = tapes[0]
assert len(aux_tape.measurements) == n_meas + n_mcms
assert isinstance(aux_tape.measurements[0], aux_measure)