diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 363d9493818..73aa52d3d6e 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -11,6 +11,9 @@

Mid-circuit measurements and dynamic circuits

+* The `dynamic_one_shot` transform uses a single auxiliary tape with a shot vector and `default.qubit` implements the loop over shots with `jax.vmap`. + [(#5617)](https://github.com/PennyLaneAI/pennylane/pull/5617) + * The `dynamic_one_shot` transform can be compiled with `jax.jit`. [(#5557)](https://github.com/PennyLaneAI/pennylane/pull/5557) diff --git a/pennylane/_qubit_device.py b/pennylane/_qubit_device.py index 532d406f956..897863053c1 100644 --- a/pennylane/_qubit_device.py +++ b/pennylane/_qubit_device.py @@ -275,8 +275,19 @@ def execute(self, circuit, **kwargs): self.check_validity(circuit.operations, circuit.observables) has_mcm = any(isinstance(op, MidMeasureMP) for op in circuit.operations) - if has_mcm: - kwargs["mid_measurements"] = {} + if has_mcm and "mid_measurements" not in kwargs: + results = [] + aux_circ = qml.tape.QuantumScript( + circuit.operations, + circuit.measurements, + shots=[1], + trainable_params=circuit.trainable_params, + ) + for _ in circuit.shots: + kwargs["mid_measurements"] = {} + self.reset() + results.append(self.execute(aux_circ, **kwargs)) + return tuple(results) # apply all circuit operations self.apply( circuit.operations, diff --git a/pennylane/devices/qubit/simulate.py b/pennylane/devices/qubit/simulate.py index 2095aff3d6d..f1929cd83db 100644 --- a/pennylane/devices/qubit/simulate.py +++ b/pennylane/devices/qubit/simulate.py @@ -279,9 +279,33 @@ def simulate( has_mcm = any(isinstance(op, MidMeasureMP) for op in circuit.operations) if circuit.shots and has_mcm: - return simulate_one_shot_native_mcm( - circuit, debugger=debugger, rng=rng, prng_key=prng_key, interface=interface + results = [] + aux_circ = qml.tape.QuantumScript( + circuit.operations, + circuit.measurements, + shots=[1], + trainable_params=circuit.trainable_params, ) + keys = jax_random_split(prng_key, num=circuit.shots.total_shots) + if qml.math.get_deep_interface(circuit.data) == "jax": + # pylint: disable=import-outside-toplevel + import jax + + def simulate_partial(k): + return simulate_one_shot_native_mcm( + aux_circ, debugger=debugger, rng=rng, prng_key=k, interface=interface + ) + + results = jax.vmap(simulate_partial, in_axes=(0,))(keys) + results = tuple(zip(*results)) + else: + for i in range(circuit.shots.total_shots): + results.append( + simulate_one_shot_native_mcm( + aux_circ, debugger=debugger, rng=rng, prng_key=keys[i], interface=interface + ) + ) + return tuple(results) ops_key, meas_key = jax_random_split(prng_key) state, is_state_batched = get_final_state( diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index 86b57ca27a2..109f43db30b 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -87,8 +87,7 @@ def func(x, y): few-shots several-mid-circuit-measurement limit, whereas ``qml.defer_measurements`` is favorable in the opposite limit. """ - - if not any(isinstance(o, MidMeasureMP) for o in tape.operations): + if not any(is_mcm(o) for o in tape.operations): return (tape,), null_postprocessing for m in tape.measurements: @@ -103,9 +102,7 @@ def func(x, y): raise qml.QuantumFunctionError("dynamic_one_shot is only supported with finite shots.") samples_present = any(isinstance(mp, SampleMP) for mp in tape.measurements) - postselect_present = any( - op.postselect is not None for op in tape.operations if isinstance(op, MidMeasureMP) - ) + postselect_present = any(op.postselect is not None for op in tape.operations if is_mcm(op)) if postselect_present and samples_present and tape.batch_size is not None: raise ValueError( "Returning qml.sample is not supported when postselecting mid-circuit " @@ -119,28 +116,21 @@ def func(x, y): broadcast_fn = None aux_tapes = [init_auxiliary_tape(t) for t in tapes] - # Shape of output_tapes is (batch_size * total_shots,) with broadcasting, - # and (total_shots,) otherwise - output_tapes = [at for at in aux_tapes for _ in range(tape.shots.total_shots)] def processing_fn(results, has_partitioned_shots=None, batched_results=None): if batched_results is None and batch_size is not None: # If broadcasting, recursively process the results for each batch. For each batch # there are tape.shots.total_shots results. The length of the first axis of final_results # will be batch_size. - results = list(results) final_results = [] - for _ in range(batch_size): - final_results.append( - processing_fn(results[0 : tape.shots.total_shots], batched_results=False) - ) - del results[0 : tape.shots.total_shots] + for result in results: + final_results.append(processing_fn((result,), batched_results=False)) return broadcast_fn(final_results) if has_partitioned_shots is None and tape.shots.has_partitioned_shots: # If using shot vectors, recursively process the results for each shot bin. The length # of the first axis of final_results will be the length of the shot vector. - results = list(results) + results = list(results[0]) final_results = [] for s in tape.shots: final_results.append( @@ -148,9 +138,11 @@ def processing_fn(results, has_partitioned_shots=None, batched_results=None): ) del results[0:s] return tuple(final_results) + if not tape.shots.has_partitioned_shots: + results = results[0] return parse_native_mid_circuit_measurements(tape, aux_tapes, results) - return output_tapes, processing_fn + return aux_tapes, processing_fn @dynamic_one_shot.custom_qnode_transform @@ -177,6 +169,12 @@ def _dynamic_one_shot_qnode(self, qnode, targs, tkwargs): return self.default_qnode_transform(qnode, targs, tkwargs) +def is_mcm(operation): + """Returns True if the operation is a mid-circuit measurement and False otherwise.""" + mcm = isinstance(operation, MidMeasureMP) + return mcm or "MidCircuitMeasure" in str(type(operation)) + + def init_auxiliary_tape(circuit: qml.tape.QuantumScript): """Creates an auxiliary circuit to perform one-shot mid-circuit measurement calculations. @@ -197,11 +195,14 @@ def init_auxiliary_tape(circuit: qml.tape.QuantumScript): else: new_measurements.append(m) for op in circuit: - if isinstance(op, MidMeasureMP): + if is_mcm(op): new_measurements.append(qml.sample(MeasurementValue([op], lambda res: res))) return qml.tape.QuantumScript( - circuit.operations, new_measurements, shots=1, trainable_params=circuit.trainable_params + circuit.operations, + new_measurements, + shots=[1] * circuit.shots.total_shots, + trainable_params=circuit.trainable_params, ) @@ -228,7 +229,7 @@ def measurement_with_no_shots(measurement): interface = qml.math.get_deep_interface(circuit.data) - all_mcms = [op for op in aux_tapes[0].operations if isinstance(op, MidMeasureMP)] + all_mcms = [op for op in aux_tapes[0].operations if is_mcm(op)] n_mcms = len(all_mcms) post_process_tape = qml.tape.QuantumScript( aux_tapes[0].operations, @@ -248,7 +249,7 @@ def measurement_with_no_shots(measurement): ).reshape((1, -1)) is_valid = qml.math.all(mcm_samples * has_postselect == postselect, axis=1) has_valid = qml.math.any(is_valid) - mid_meas = [op for op in circuit.operations if isinstance(op, MidMeasureMP)] + mid_meas = [op for op in circuit.operations if is_mcm(op)] mcm_samples = [mcm_samples[:, i : i + 1] for i in range(n_mcms)] mcm_samples = dict((k, v) for k, v in zip(mid_meas, mcm_samples)) diff --git a/tests/devices/default_qubit/test_default_qubit_native_mcm.py b/tests/devices/default_qubit/test_default_qubit_native_mcm.py index 691dcf3f69b..4848ea4f6e3 100644 --- a/tests/devices/default_qubit/test_default_qubit_native_mcm.py +++ b/tests/devices/default_qubit/test_default_qubit_native_mcm.py @@ -693,7 +693,7 @@ def func(x, y, z): assert "pure_callback" not in jaxpr func2 = jax.jit(func) - results2 = func2(*jax.numpy.array(params)) + results2 = func2(*params) measures = [ qml.probs, diff --git a/tests/transforms/test_dynamic_one_shot.py b/tests/transforms/test_dynamic_one_shot.py index f65f2b805a5..3ca186171fe 100644 --- a/tests/transforms/test_dynamic_one_shot.py +++ b/tests/transforms/test_dynamic_one_shot.py @@ -88,7 +88,7 @@ def test_len_tapes(n_shots): """Test that the transform produces the correct number of tapes.""" tape = qml.tape.QuantumScript([MidMeasureMP(0)], [qml.expval(qml.PauliZ(0))], shots=n_shots) tapes, _ = qml.dynamic_one_shot(tape) - assert len(tapes) == n_shots + assert len(tapes) == 1 @pytest.mark.parametrize("n_batch", range(1, 4)) @@ -102,7 +102,7 @@ def test_len_tape_batched(n_batch, n_shots): shots=n_shots, ) tapes, _ = qml.dynamic_one_shot(tape) - assert len(tapes) == n_shots * n_batch + assert len(tapes) == n_batch @pytest.mark.parametrize( @@ -123,7 +123,7 @@ def test_len_measurements_obs(measure, aux_measure, n_meas): [qml.Hadamard(0)] + [MidMeasureMP(0)] * n_mcms, [measure(op=qml.PauliZ(0))], shots=n_shots ) tapes, _ = qml.dynamic_one_shot(tape) - assert len(tapes) == n_shots + assert len(tapes) == 1 aux_tape = tapes[0] assert len(aux_tape.measurements) == n_meas + n_mcms assert isinstance(aux_tape.measurements[0], aux_measure) @@ -150,7 +150,7 @@ def test_len_measurements_mcms(measure, aux_measure, n_meas): shots=n_shots, ) tapes, _ = qml.dynamic_one_shot(tape) - assert len(tapes) == n_shots + assert len(tapes) == 1 aux_tape = tapes[0] assert len(aux_tape.measurements) == n_meas + n_mcms assert isinstance(aux_tape.measurements[0], aux_measure)