Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dynamic_one_shot uses tapes with shot-vectors and jitting takes advantage of it #5617

Merged
merged 103 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from 98 commits
Commits
Show all changes
103 commits
Select commit Hold shift + click to select a range
2d6d336
Added rng and prng_key to get_final_state, apply_operation
mudit2812 Mar 7, 2024
7560cdf
Use rng in apply_operation args; linting
mudit2812 Mar 7, 2024
11290a7
WIP
vincentmr Apr 4, 2024
1313af8
Fix measure_with_samples' handling of mid_measurements.
vincentmr Apr 9, 2024
412374e
Remove comments.
vincentmr Apr 9, 2024
bd32a35
Merge branch 'master' into feature/dynamic_samples
vincentmr Apr 9, 2024
19d7f73
update changelog
vincentmr Apr 9, 2024
8aea51e
Fix legacy node native mcm test.
vincentmr Apr 9, 2024
3fb1758
Fix old device API
vincentmr Apr 9, 2024
57386e6
Merge branch 'master' into feature/dynamic_samples
vincentmr Apr 9, 2024
b9d8996
Fill out mid_measurements in mock device.
vincentmr Apr 10, 2024
5f7a33f
Refactor using masks.
vincentmr Apr 10, 2024
146baef
Merge remote-tracking branch 'origin/master' into feature/dynamic_sam…
vincentmr Apr 10, 2024
c975493
Update rng use; fix docs
mudit2812 Apr 10, 2024
6899b2c
[skip ci] Skip CI
mudit2812 Apr 10, 2024
87f34b2
Always compute all results, even if mv = -1.
vincentmr Apr 11, 2024
721c284
[skip ci] testing changes to native MCM tests
mudit2812 Apr 11, 2024
4aabfe5
Make post-processing jax-ready. WIP
vincentmr Apr 11, 2024
2614d53
Update pennylane/devices/qubit/sampling.py
vincentmr Apr 12, 2024
c707d46
Implement Christina's suggestions.
vincentmr Apr 12, 2024
bd800bb
Merge branch 'master' into feature/dynamic_samples
vincentmr Apr 12, 2024
f89139a
Add rng and jit mid_measure. WARN: regression
vincentmr Apr 12, 2024
f9c7b10
Merge branch 'master' into feature/dynamic_samples
vincentmr Apr 12, 2024
f7d7fd2
WIP
vincentmr Apr 12, 2024
a836a51
Implement apply_mid_meas with norm of a branch @mudit.
vincentmr Apr 12, 2024
5b1b732
Merge remote-tracking branch 'origin/simulate-rng' into feature/dynam…
vincentmr Apr 12, 2024
34b513d
Add jax-jit support in apply_cond/mid_meas.
vincentmr Apr 12, 2024
2562b94
Introduce prng
vincentmr Apr 12, 2024
3db8767
WIP
vincentmr Apr 12, 2024
ceaa339
Merge branch 'master' into feature/dynamic_samples
vincentmr Apr 15, 2024
ccafa89
Merge branch 'master' into feature/dynamic_samples
vincentmr Apr 15, 2024
e1804a8
Remove jax branch in apply_mid_meas.
vincentmr Apr 15, 2024
f7c65bc
Fix single wire probs.
vincentmr Apr 15, 2024
5aec928
Bug fix MV lists.
vincentmr Apr 15, 2024
11d6d43
Add tests for math.all/any
vincentmr Apr 17, 2024
2753b26
Merge remote-tracking branch 'origin/master' into feature/dynamic_sam…
vincentmr Apr 17, 2024
62ef926
qml.math.all doesn't need implementation.
vincentmr Apr 17, 2024
6c15204
Merge remote-tracking branch 'origin/master' into feature/dynamic_sam…
vincentmr Apr 17, 2024
97cf86e
Fix error message @albi3ro
vincentmr Apr 18, 2024
ce9a4a1
Merge branch 'master' into feature/dynamic_samples
vincentmr Apr 18, 2024
027c65a
Test measure_final_state raises.
vincentmr Apr 18, 2024
85334e7
Add tests in tests/transforms/test_dynamic_one_shot.py
vincentmr Apr 19, 2024
de47630
Merge branch 'master' into feature/dynamic_samples
vincentmr Apr 19, 2024
0c295ac
Fix lint.
vincentmr Apr 19, 2024
eeed3d4
Add error test.
vincentmr Apr 19, 2024
387d898
Merge branch 'master' into feature/dynamic_samples
vincentmr Apr 19, 2024
88e9eca
Merge branch 'master' into simulate-rng
mudit2812 Apr 19, 2024
985775e
Merge branch 'master' into simulate-rng
mudit2812 Apr 19, 2024
e1d8e4b
Merge remote-tracking branch 'origin/master' into feature/dynamic_sam…
vincentmr Apr 22, 2024
7f7be2d
Merge remote-tracking branch 'origin/master' into feature/dynamic_sam…
vincentmr Apr 22, 2024
d7a91fc
Add test for batched dynamic_one_shot
vincentmr Apr 22, 2024
7125646
Merge branch 'master' into feature/dynamic_samples
vincentmr Apr 22, 2024
c46a46a
Merge remote-tracking branch 'origin/feature/dynamic_samples' into fe…
vincentmr Apr 22, 2024
ae5abe7
Refactor here and there.
vincentmr Apr 22, 2024
b6b5e68
Sort imports
vincentmr Apr 22, 2024
8162704
Revert isort changes.
vincentmr Apr 22, 2024
8f3e558
Merge branch 'master' into simulate-rng
mudit2812 Apr 22, 2024
e091e86
Merge remote-tracking branch 'origin/master' into feature/dynamic_sam…
vincentmr Apr 22, 2024
b0a1ec9
Correctly propagating PRNGKey and RNG to apply_operation
mudit2812 Apr 22, 2024
6ad8352
Fix _sample_state_jax cond
vincentmr Apr 23, 2024
c641af4
Fix test_parse_native_mid_circuit_measurements_unsupported_meas
vincentmr Apr 23, 2024
42a649d
DQ.execute distributes kwargs with multithreading
mudit2812 Apr 23, 2024
148d051
\Merge remote-tracking branch 'origin/simulate-rng' into feature/dyna…
vincentmr Apr 23, 2024
bec05b7
Add device as kwargs to dyn_one_shot
vincentmr Apr 24, 2024
b301554
Merge remote-tracking branch 'origin/master' into feature/dynamic_sam…
vincentmr Apr 29, 2024
0565c61
Filter postselected values in post-processing.
vincentmr Apr 29, 2024
bb49932
Add jax.jit tests.
vincentmr Apr 29, 2024
a0b6703
Merge remote-tracking branch 'origin/master' into feature/dynamic_sam…
vincentmr Apr 29, 2024
fabfe8f
Merge remote-tracking branch 'origin/master' into feature/dynamic_sam…
vincentmr Apr 30, 2024
3c05878
Fix flaky test's seed in test_broadcast_expand.
vincentmr Apr 30, 2024
9ea624e
jax.numpy.array(params)
vincentmr Apr 30, 2024
11f369f
Merge branch 'master' into feature/dynamic_samples_jit
vincentmr Apr 30, 2024
5398ce7
Update pennylane/devices/qubit/apply_operation.py
vincentmr Apr 30, 2024
f140b5c
Update pennylane/devices/qubit/apply_operation.py
vincentmr Apr 30, 2024
a858e0e
Update pennylane/devices/qubit/sampling.py
vincentmr Apr 30, 2024
16a1852
WIP
vincentmr Apr 30, 2024
303faaf
Implement Mudit's suggestions.
vincentmr Apr 30, 2024
e491b63
Merge remote-tracking branch 'origin/feature/dynamic_samples_jit' int…
vincentmr Apr 30, 2024
28ed449
Clean up ImageTape.
vincentmr May 1, 2024
d948334
Merge remote-tracking branch 'origin/master' into feature/dynamic_sam…
vincentmr May 1, 2024
91945be
Use more robust QubitUnitary for the time being.
vincentmr May 1, 2024
457bcae
Merge remote-tracking branch 'origin/master' into feature/dynamic_sam…
vincentmr May 1, 2024
ac42ef0
Make sure we're not using Python callbacks with jaxpr.
vincentmr May 1, 2024
7cc9472
Fix reset matrix.
vincentmr May 1, 2024
b13961e
Remove unused where.
vincentmr May 1, 2024
268e4e6
Merge branch 'master' into feature/dynamic_samples_jit
vincentmr May 1, 2024
24d4c59
Merge remote-tracking branch 'origin/feature/dynamic_samples_jit' int…
vincentmr May 2, 2024
bbe4f7e
Fix lint/tests.
vincentmr May 2, 2024
969c630
Fix legacy.
vincentmr May 2, 2024
0af62d2
Remove obsolete change [skip ci]
vincentmr May 2, 2024
2b1db5f
Add dev notes and remove useless importorskip.
vincentmr May 2, 2024
6c3bd0b
Merge branch 'master' into feature/dynamic_samples_jit
vincentmr May 2, 2024
0c3d5d3
Update pennylane/devices/qubit/apply_operation.py
vincentmr May 2, 2024
77bdf4c
Update pennylane/devices/qubit/apply_operation.py
vincentmr May 2, 2024
13fc37b
Merge branch 'master' into feature/dynamic_samples_jit
vincentmr May 6, 2024
c08a672
Enabe other class of MCMs with is_mcm.
vincentmr May 6, 2024
7f56309
Merge remote-tracking branch 'origin/feature/dynamic_samples_jit' int…
vincentmr May 6, 2024
d4e6940
Merge remote-tracking branch 'origin/master' into feature/batched_tape
vincentmr May 9, 2024
da197a9
Update pennylane/devices/qubit/simulate.py
vincentmr May 10, 2024
5a57363
Christina's suggestions.
vincentmr May 10, 2024
ec701bd
Revert prng_key=keys
vincentmr May 10, 2024
a18e8cf
vv => simulate_partial
vincentmr May 10, 2024
d01f453
Merge branch 'master' into feature/batched_tape
vincentmr May 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions pennylane/_qubit_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,19 @@ def execute(self, circuit, **kwargs):
self.check_validity(circuit.operations, circuit.observables)

has_mcm = any(isinstance(op, MidMeasureMP) for op in circuit.operations)
if has_mcm:
kwargs["mid_measurements"] = {}
if has_mcm and "mid_measurements" not in kwargs:
vincentmr marked this conversation as resolved.
Show resolved Hide resolved
results = []
aux_circ = qml.tape.QuantumScript(
circuit.operations,
circuit.measurements,
shots=[1],
trainable_params=circuit.trainable_params,
)
for _ in circuit.shots:
kwargs["mid_measurements"] = {}
self.reset()
results.append(self.execute(aux_circ, **kwargs))
return tuple(results)
# apply all circuit operations
self.apply(
circuit.operations,
Expand Down
28 changes: 26 additions & 2 deletions pennylane/devices/qubit/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,9 +279,33 @@ def simulate(

has_mcm = any(isinstance(op, MidMeasureMP) for op in circuit.operations)
if circuit.shots and has_mcm:
return simulate_one_shot_native_mcm(
circuit, debugger=debugger, rng=rng, prng_key=prng_key, interface=interface
results = []
aux_circ = qml.tape.QuantumScript(
circuit.operations,
circuit.measurements,
shots=[1],
trainable_params=circuit.trainable_params,
)
keys = jax_random_split(prng_key, num=circuit.shots.total_shots)
if qml.math.get_deep_interface(circuit.data) == "jax":
# pylint: disable=import-outside-toplevel
import jax

def vv(k):
vincentmr marked this conversation as resolved.
Show resolved Hide resolved
return simulate_one_shot_native_mcm(
aux_circ, debugger=debugger, rng=rng, prng_key=k, interface=interface
)

results = jax.vmap(vv, in_axes=(0,))(keys)
vincentmr marked this conversation as resolved.
Show resolved Hide resolved
vincentmr marked this conversation as resolved.
Show resolved Hide resolved
results = list(zip(*results))
vincentmr marked this conversation as resolved.
Show resolved Hide resolved
else:
for i, _ in enumerate(circuit.shots):
vincentmr marked this conversation as resolved.
Show resolved Hide resolved
results.append(
simulate_one_shot_native_mcm(
aux_circ, debugger=debugger, rng=rng, prng_key=keys[i], interface=interface
vincentmr marked this conversation as resolved.
Show resolved Hide resolved
)
)
return tuple(results)

ops_key, meas_key = jax_random_split(prng_key)
state, is_state_batched = get_final_state(
Expand Down
41 changes: 21 additions & 20 deletions pennylane/transforms/dynamic_one_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,7 @@ def func(x, y):
few-shots several-mid-circuit-measurement limit, whereas ``qml.defer_measurements`` is favorable
in the opposite limit.
"""

if not any(isinstance(o, MidMeasureMP) for o in tape.operations):
if not any(is_mcm(o) for o in tape.operations):
return (tape,), null_postprocessing

for m in tape.measurements:
Expand All @@ -103,9 +102,7 @@ def func(x, y):
raise qml.QuantumFunctionError("dynamic_one_shot is only supported with finite shots.")

samples_present = any(isinstance(mp, SampleMP) for mp in tape.measurements)
postselect_present = any(
op.postselect is not None for op in tape.operations if isinstance(op, MidMeasureMP)
)
postselect_present = any(op.postselect is not None for op in tape.operations if is_mcm(op))
if postselect_present and samples_present and tape.batch_size is not None:
raise ValueError(
"Returning qml.sample is not supported when postselecting mid-circuit "
Expand All @@ -119,38 +116,33 @@ def func(x, y):
broadcast_fn = None

aux_tapes = [init_auxiliary_tape(t) for t in tapes]
# Shape of output_tapes is (batch_size * total_shots,) with broadcasting,
# and (total_shots,) otherwise
output_tapes = [at for at in aux_tapes for _ in range(tape.shots.total_shots)]

def processing_fn(results, has_partitioned_shots=None, batched_results=None):
if batched_results is None and batch_size is not None:
# If broadcasting, recursively process the results for each batch. For each batch
# there are tape.shots.total_shots results. The length of the first axis of final_results
# will be batch_size.
results = list(results)
final_results = []
for _ in range(batch_size):
final_results.append(
processing_fn(results[0 : tape.shots.total_shots], batched_results=False)
)
del results[0 : tape.shots.total_shots]
for result in results:
final_results.append(processing_fn((result,), batched_results=False))
return broadcast_fn(final_results)

if has_partitioned_shots is None and tape.shots.has_partitioned_shots:
# If using shot vectors, recursively process the results for each shot bin. The length
# of the first axis of final_results will be the length of the shot vector.
results = list(results)
results = list(results[0])
final_results = []
for s in tape.shots:
final_results.append(
processing_fn(results[0:s], has_partitioned_shots=False, batched_results=False)
)
del results[0:s]
return tuple(final_results)
if not tape.shots.has_partitioned_shots:
results = results[0]
return parse_native_mid_circuit_measurements(tape, aux_tapes, results)

return output_tapes, processing_fn
return aux_tapes, processing_fn


@dynamic_one_shot.custom_qnode_transform
Expand All @@ -177,6 +169,12 @@ def _dynamic_one_shot_qnode(self, qnode, targs, tkwargs):
return self.default_qnode_transform(qnode, targs, tkwargs)


def is_mcm(operation):
"""Returns True if the operation is a mid-circuit measurement and False otherwise."""
mcm = isinstance(operation, MidMeasureMP)
return mcm or "MidCircuitMeasure" in str(type(operation))


def init_auxiliary_tape(circuit: qml.tape.QuantumScript):
"""Creates an auxiliary circuit to perform one-shot mid-circuit measurement calculations.

Expand All @@ -197,11 +195,14 @@ def init_auxiliary_tape(circuit: qml.tape.QuantumScript):
else:
new_measurements.append(m)
for op in circuit:
if isinstance(op, MidMeasureMP):
if is_mcm(op):
new_measurements.append(qml.sample(MeasurementValue([op], lambda res: res)))

return qml.tape.QuantumScript(
circuit.operations, new_measurements, shots=1, trainable_params=circuit.trainable_params
circuit.operations,
new_measurements,
shots=[1] * circuit.shots.total_shots,
trainable_params=circuit.trainable_params,
)


Expand All @@ -228,7 +229,7 @@ def measurement_with_no_shots(measurement):

interface = qml.math.get_deep_interface(circuit.data)

all_mcms = [op for op in aux_tapes[0].operations if isinstance(op, MidMeasureMP)]
all_mcms = [op for op in aux_tapes[0].operations if is_mcm(op)]
n_mcms = len(all_mcms)
post_process_tape = qml.tape.QuantumScript(
aux_tapes[0].operations,
Expand All @@ -248,7 +249,7 @@ def measurement_with_no_shots(measurement):
).reshape((1, -1))
is_valid = qml.math.all(mcm_samples * has_postselect == postselect, axis=1)
has_valid = qml.math.any(is_valid)
mid_meas = [op for op in circuit.operations if isinstance(op, MidMeasureMP)]
mid_meas = [op for op in circuit.operations if is_mcm(op)]
mcm_samples = [mcm_samples[:, i : i + 1] for i in range(n_mcms)]
mcm_samples = dict((k, v) for k, v in zip(mid_meas, mcm_samples))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,7 @@ def func(x, y, z):
assert "pure_callback" not in jaxpr

func2 = jax.jit(func)
results2 = func2(*jax.numpy.array(params))
results2 = func2(*params)

measures = [
# qml.probs,
Expand Down
8 changes: 4 additions & 4 deletions tests/transforms/test_dynamic_one_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_len_tapes(n_shots):
"""Test that the transform produces the correct number of tapes."""
tape = qml.tape.QuantumScript([MidMeasureMP(0)], [qml.expval(qml.PauliZ(0))], shots=n_shots)
tapes, _ = qml.dynamic_one_shot(tape)
assert len(tapes) == n_shots
assert len(tapes) == 1


@pytest.mark.parametrize("n_batch", range(1, 4))
Expand All @@ -102,7 +102,7 @@ def test_len_tape_batched(n_batch, n_shots):
shots=n_shots,
)
tapes, _ = qml.dynamic_one_shot(tape)
assert len(tapes) == n_shots * n_batch
assert len(tapes) == n_batch


@pytest.mark.parametrize(
Expand All @@ -123,7 +123,7 @@ def test_len_measurements_obs(measure, aux_measure, n_meas):
[qml.Hadamard(0)] + [MidMeasureMP(0)] * n_mcms, [measure(op=qml.PauliZ(0))], shots=n_shots
)
tapes, _ = qml.dynamic_one_shot(tape)
assert len(tapes) == n_shots
assert len(tapes) == 1
aux_tape = tapes[0]
assert len(aux_tape.measurements) == n_meas + n_mcms
assert isinstance(aux_tape.measurements[0], aux_measure)
Expand All @@ -150,7 +150,7 @@ def test_len_measurements_mcms(measure, aux_measure, n_meas):
shots=n_shots,
)
tapes, _ = qml.dynamic_one_shot(tape)
assert len(tapes) == n_shots
assert len(tapes) == 1
aux_tape = tapes[0]
assert len(aux_tape.measurements) == n_meas + n_mcms
assert isinstance(aux_tape.measurements[0], aux_measure)
Expand Down
Loading