Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dynamic_one_shot uses tapes with shot-vectors and jitting takes advantage of it #5617

Merged
merged 103 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from 97 commits
Commits
Show all changes
103 commits
Select commit Hold shift + click to select a range
2d6d336
Added rng and prng_key to get_final_state, apply_operation
mudit2812 Mar 7, 2024
7560cdf
Use rng in apply_operation args; linting
mudit2812 Mar 7, 2024
11290a7
WIP
vincentmr Apr 4, 2024
1313af8
Fix measure_with_samples' handling of mid_measurements.
vincentmr Apr 9, 2024
412374e
Remove comments.
vincentmr Apr 9, 2024
bd32a35
Merge branch 'master' into feature/dynamic_samples
vincentmr Apr 9, 2024
19d7f73
update changelog
vincentmr Apr 9, 2024
8aea51e
Fix legacy node native mcm test.
vincentmr Apr 9, 2024
3fb1758
Fix old device API
vincentmr Apr 9, 2024
57386e6
Merge branch 'master' into feature/dynamic_samples
vincentmr Apr 9, 2024
b9d8996
Fill out mid_measurements in mock device.
vincentmr Apr 10, 2024
5f7a33f
Refactor using masks.
vincentmr Apr 10, 2024
146baef
Merge remote-tracking branch 'origin/master' into feature/dynamic_sam…
vincentmr Apr 10, 2024
c975493
Update rng use; fix docs
mudit2812 Apr 10, 2024
6899b2c
[skip ci] Skip CI
mudit2812 Apr 10, 2024
87f34b2
Always compute all results, even if mv = -1.
vincentmr Apr 11, 2024
721c284
[skip ci] testing changes to native MCM tests
mudit2812 Apr 11, 2024
4aabfe5
Make post-processing jax-ready. WIP
vincentmr Apr 11, 2024
2614d53
Update pennylane/devices/qubit/sampling.py
vincentmr Apr 12, 2024
c707d46
Implement Christina's suggestions.
vincentmr Apr 12, 2024
bd800bb
Merge branch 'master' into feature/dynamic_samples
vincentmr Apr 12, 2024
f89139a
Add rng and jit mid_measure. WARN: regression
vincentmr Apr 12, 2024
f9c7b10
Merge branch 'master' into feature/dynamic_samples
vincentmr Apr 12, 2024
f7d7fd2
WIP
vincentmr Apr 12, 2024
a836a51
Implement apply_mid_meas with norm of a branch @mudit.
vincentmr Apr 12, 2024
5b1b732
Merge remote-tracking branch 'origin/simulate-rng' into feature/dynam…
vincentmr Apr 12, 2024
34b513d
Add jax-jit support in apply_cond/mid_meas.
vincentmr Apr 12, 2024
2562b94
Introduce prng
vincentmr Apr 12, 2024
3db8767
WIP
vincentmr Apr 12, 2024
ceaa339
Merge branch 'master' into feature/dynamic_samples
vincentmr Apr 15, 2024
ccafa89
Merge branch 'master' into feature/dynamic_samples
vincentmr Apr 15, 2024
e1804a8
Remove jax branch in apply_mid_meas.
vincentmr Apr 15, 2024
f7c65bc
Fix single wire probs.
vincentmr Apr 15, 2024
5aec928
Bug fix MV lists.
vincentmr Apr 15, 2024
11d6d43
Add tests for math.all/any
vincentmr Apr 17, 2024
2753b26
Merge remote-tracking branch 'origin/master' into feature/dynamic_sam…
vincentmr Apr 17, 2024
62ef926
qml.math.all doesn't need implementation.
vincentmr Apr 17, 2024
6c15204
Merge remote-tracking branch 'origin/master' into feature/dynamic_sam…
vincentmr Apr 17, 2024
97cf86e
Fix error message @albi3ro
vincentmr Apr 18, 2024
ce9a4a1
Merge branch 'master' into feature/dynamic_samples
vincentmr Apr 18, 2024
027c65a
Test measure_final_state raises.
vincentmr Apr 18, 2024
85334e7
Add tests in tests/transforms/test_dynamic_one_shot.py
vincentmr Apr 19, 2024
de47630
Merge branch 'master' into feature/dynamic_samples
vincentmr Apr 19, 2024
0c295ac
Fix lint.
vincentmr Apr 19, 2024
eeed3d4
Add error test.
vincentmr Apr 19, 2024
387d898
Merge branch 'master' into feature/dynamic_samples
vincentmr Apr 19, 2024
88e9eca
Merge branch 'master' into simulate-rng
mudit2812 Apr 19, 2024
985775e
Merge branch 'master' into simulate-rng
mudit2812 Apr 19, 2024
e1d8e4b
Merge remote-tracking branch 'origin/master' into feature/dynamic_sam…
vincentmr Apr 22, 2024
7f7be2d
Merge remote-tracking branch 'origin/master' into feature/dynamic_sam…
vincentmr Apr 22, 2024
d7a91fc
Add test for batched dynamic_one_shot
vincentmr Apr 22, 2024
7125646
Merge branch 'master' into feature/dynamic_samples
vincentmr Apr 22, 2024
c46a46a
Merge remote-tracking branch 'origin/feature/dynamic_samples' into fe…
vincentmr Apr 22, 2024
ae5abe7
Refactor here and there.
vincentmr Apr 22, 2024
b6b5e68
Sort imports
vincentmr Apr 22, 2024
8162704
Revert isort changes.
vincentmr Apr 22, 2024
8f3e558
Merge branch 'master' into simulate-rng
mudit2812 Apr 22, 2024
e091e86
Merge remote-tracking branch 'origin/master' into feature/dynamic_sam…
vincentmr Apr 22, 2024
b0a1ec9
Correctly propagating PRNGKey and RNG to apply_operation
mudit2812 Apr 22, 2024
6ad8352
Fix _sample_state_jax cond
vincentmr Apr 23, 2024
c641af4
Fix test_parse_native_mid_circuit_measurements_unsupported_meas
vincentmr Apr 23, 2024
42a649d
DQ.execute distributes kwargs with multithreading
mudit2812 Apr 23, 2024
148d051
\Merge remote-tracking branch 'origin/simulate-rng' into feature/dyna…
vincentmr Apr 23, 2024
bec05b7
Add device as kwargs to dyn_one_shot
vincentmr Apr 24, 2024
b301554
Merge remote-tracking branch 'origin/master' into feature/dynamic_sam…
vincentmr Apr 29, 2024
0565c61
Filter postselected values in post-processing.
vincentmr Apr 29, 2024
bb49932
Add jax.jit tests.
vincentmr Apr 29, 2024
a0b6703
Merge remote-tracking branch 'origin/master' into feature/dynamic_sam…
vincentmr Apr 29, 2024
fabfe8f
Merge remote-tracking branch 'origin/master' into feature/dynamic_sam…
vincentmr Apr 30, 2024
3c05878
Fix flaky test's seed in test_broadcast_expand.
vincentmr Apr 30, 2024
9ea624e
jax.numpy.array(params)
vincentmr Apr 30, 2024
11f369f
Merge branch 'master' into feature/dynamic_samples_jit
vincentmr Apr 30, 2024
5398ce7
Update pennylane/devices/qubit/apply_operation.py
vincentmr Apr 30, 2024
f140b5c
Update pennylane/devices/qubit/apply_operation.py
vincentmr Apr 30, 2024
a858e0e
Update pennylane/devices/qubit/sampling.py
vincentmr Apr 30, 2024
16a1852
WIP
vincentmr Apr 30, 2024
303faaf
Implement Mudit's suggestions.
vincentmr Apr 30, 2024
e491b63
Merge remote-tracking branch 'origin/feature/dynamic_samples_jit' int…
vincentmr Apr 30, 2024
28ed449
Clean up ImageTape.
vincentmr May 1, 2024
d948334
Merge remote-tracking branch 'origin/master' into feature/dynamic_sam…
vincentmr May 1, 2024
91945be
Use more robust QubitUnitary for the time being.
vincentmr May 1, 2024
457bcae
Merge remote-tracking branch 'origin/master' into feature/dynamic_sam…
vincentmr May 1, 2024
ac42ef0
Make sure we're not using Python callbacks with jaxpr.
vincentmr May 1, 2024
7cc9472
Fix reset matrix.
vincentmr May 1, 2024
b13961e
Remove unused where.
vincentmr May 1, 2024
268e4e6
Merge branch 'master' into feature/dynamic_samples_jit
vincentmr May 1, 2024
24d4c59
Merge remote-tracking branch 'origin/feature/dynamic_samples_jit' int…
vincentmr May 2, 2024
bbe4f7e
Fix lint/tests.
vincentmr May 2, 2024
969c630
Fix legacy.
vincentmr May 2, 2024
0af62d2
Remove obsolete change [skip ci]
vincentmr May 2, 2024
2b1db5f
Add dev notes and remove useless importorskip.
vincentmr May 2, 2024
6c3bd0b
Merge branch 'master' into feature/dynamic_samples_jit
vincentmr May 2, 2024
0c3d5d3
Update pennylane/devices/qubit/apply_operation.py
vincentmr May 2, 2024
77bdf4c
Update pennylane/devices/qubit/apply_operation.py
vincentmr May 2, 2024
13fc37b
Merge branch 'master' into feature/dynamic_samples_jit
vincentmr May 6, 2024
c08a672
Enabe other class of MCMs with is_mcm.
vincentmr May 6, 2024
7f56309
Merge remote-tracking branch 'origin/feature/dynamic_samples_jit' int…
vincentmr May 6, 2024
d4e6940
Merge remote-tracking branch 'origin/master' into feature/batched_tape
vincentmr May 9, 2024
da197a9
Update pennylane/devices/qubit/simulate.py
vincentmr May 10, 2024
5a57363
Christina's suggestions.
vincentmr May 10, 2024
ec701bd
Revert prng_key=keys
vincentmr May 10, 2024
a18e8cf
vv => simulate_partial
vincentmr May 10, 2024
d01f453
Merge branch 'master' into feature/batched_tape
vincentmr May 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@

<h3>Improvements 🛠</h3>

<h4>Mid-circuit measurements and dynamic circuits</h4>

* The `dynamic_one_shot` transform can be compiled with `jax.jit`.
[(#5557)](https://github.com/PennyLaneAI/pennylane/pull/5557)

* When using `defer_measurements` with postselecting mid-circuit measurements, operations
that will never be active due to the postselected state are skipped in the transformed
quantum circuit. In addition, postselected controls are skipped, as they are evaluated
Expand Down Expand Up @@ -77,4 +82,5 @@ This release contains contributions from (in alphabetical order):

Pietropaolo Frisoni,
Christina Lee,
David Wierichs.
Vincent Michaud-Rioux,
David Wierichs.
15 changes: 13 additions & 2 deletions pennylane/_qubit_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,19 @@ def execute(self, circuit, **kwargs):
self.check_validity(circuit.operations, circuit.observables)

has_mcm = any(isinstance(op, MidMeasureMP) for op in circuit.operations)
if has_mcm:
kwargs["mid_measurements"] = {}
if has_mcm and "mid_measurements" not in kwargs:
vincentmr marked this conversation as resolved.
Show resolved Hide resolved
results = []
aux_circ = qml.tape.QuantumScript(
circuit.operations,
circuit.measurements,
shots=[1],
trainable_params=circuit.trainable_params,
)
for _ in circuit.shots:
kwargs["mid_measurements"] = {}
self.reset()
results.append(self.execute(aux_circ, **kwargs))
return tuple(results)
# apply all circuit operations
self.apply(
circuit.operations,
Expand Down
84 changes: 61 additions & 23 deletions pennylane/devices/qubit/apply_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions to apply an operation to a state vector."""
# pylint: disable=unused-argument
# pylint: disable=unused-argument, too-many-arguments

from functools import singledispatch
from string import ascii_letters as alphabet
Expand All @@ -21,7 +21,7 @@

import pennylane as qml
from pennylane import math
from pennylane.measurements import MidMeasureMP, Shots
from pennylane.measurements import MidMeasureMP
from pennylane.ops import Conditional

SQRT2INV = 1 / math.sqrt(2)
Expand Down Expand Up @@ -238,14 +238,36 @@ def apply_conditional(
ndarray: output state
"""
mid_measurements = execution_kwargs.get("mid_measurements", None)

rng = execution_kwargs.get("rng", None)
prng_key = execution_kwargs.get("prng_key", None)
interface = qml.math.get_deep_interface(state)
if interface == "jax":
# pylint: disable=import-outside-toplevel
from jax.lax import cond

return cond(
op.meas_val.concretize(mid_measurements),
lambda x: apply_operation(
op.then_op,
x,
is_state_batched=is_state_batched,
debugger=debugger,
mid_measurements=mid_measurements,
rng=rng,
prng_key=prng_key,
),
lambda x: x,
state,
)
if op.meas_val.concretize(mid_measurements):
return apply_operation(
op.then_op,
state,
is_state_batched=is_state_batched,
debugger=debugger,
mid_measurements=mid_measurements,
rng=rng,
prng_key=prng_key,
)
return state

Expand Down Expand Up @@ -273,31 +295,47 @@ def apply_mid_measure(
mid_measurements = execution_kwargs.get("mid_measurements", None)
rng = execution_kwargs.get("rng", None)
prng_key = execution_kwargs.get("prng_key", None)

if is_state_batched:
raise ValueError("MidMeasureMP cannot be applied to batched states.")
if not np.allclose(np.linalg.norm(state), 1.0):
mid_measurements[op] = -1
return np.zeros_like(state)
wire = op.wires
sample = qml.devices.qubit.sampling.measure_with_samples(
[qml.sample(wires=wire)], state, Shots(1), rng=rng, prng_key=prng_key
)
sample = int(sample[0])
mid_measurements[op] = sample
if op.postselect is not None and sample != op.postselect:
mid_measurements[op] = -1
return np.zeros_like(state)
axis = wire.toarray()[0]
slices = [slice(None)] * qml.math.ndim(state)
slices[axis] = int(not sample)
state[tuple(slices)] = 0.0
state_norm = np.linalg.norm(state)
state = state / state_norm
if op.reset and sample == 1:
state = apply_operation(
qml.X(wire), state, is_state_batched=is_state_batched, debugger=debugger
)
slices[axis] = 0
prob0 = qml.math.norm(state[tuple(slices)]) ** 2
interface = qml.math.get_deep_interface(state)
if prng_key is not None:
# pylint: disable=import-outside-toplevel
from jax.random import binomial

def binomial_fn(n, p):
return binomial(prng_key, n, p).astype(int)

else:
binomial_fn = np.random.binomial if rng is None else rng.binomial
sample = binomial_fn(1, 1 - prob0)
mid_measurements[op] = sample

# Using apply_operation(qml.QubitUnitary,...) instead of apply_operation(qml.Projector([sample], wire),...)
# to select the sample branch enables jax.jit and prevents it from using Python callbacks
matrix = qml.math.array([[(sample + 1) % 2, 0.0], [0.0, (sample) % 2]], like=interface)
state = apply_operation(
qml.QubitUnitary(matrix, wire),
state,
is_state_batched=is_state_batched,
debugger=debugger,
)
state = state / qml.math.norm(state)

# Using apply_operation(qml.QubitUnitary,...) instead of apply_operation(qml.X(wire), ...)
# to reset enables jax.jit and prevents it from using Python callbacks
element = op.reset and sample == 1
matrix = qml.math.array(
[[(element + 1) % 2, (element) % 2], [(element) % 2, (element + 1) % 2]], like=interface
).astype(float)
state = apply_operation(
qml.QubitUnitary(matrix, wire), state, is_state_batched=is_state_batched, debugger=debugger
)

return state


Expand Down
15 changes: 5 additions & 10 deletions pennylane/devices/qubit/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
def jax_random_split(prng_key, num: int = 2):
"""Get a new key with ``jax.random.split``."""
if prng_key is None:
return [None] * num
return (None,) * num
# pylint: disable=import-outside-toplevel
from jax.random import split

Expand Down Expand Up @@ -213,15 +213,11 @@ def measure_with_samples(
"""
# last N measurements are sampling MCMs in ``dynamic_one_shot`` execution mode
mps = measurements[0 : -len(mid_measurements)] if mid_measurements else measurements
skip_measure = any(v == -1 for v in mid_measurements.values()) if mid_measurements else False

groups, indices = _group_measurements(mps)

all_res = []
for group in groups:
if skip_measure:
all_res.extend([None] * len(group))
continue
if isinstance(group[0], ExpectationMP) and isinstance(
group[0].obs, (Hamiltonian, LinearCombination)
):
Expand Down Expand Up @@ -477,11 +473,10 @@ def sample_state(
# probabilities must be renormalized as they may not sum to one
# see https://github.com/PennyLaneAI/pennylane/issues/5444
norm = qml.math.sum(probs, axis=-1)
abs_diff = np.abs(norm - 1.0)
abs_diff = qml.math.abs(norm - 1.0)
cutoff = 1e-07

if is_state_batched:

normalize_condition = False

for s in abs_diff:
Expand All @@ -497,9 +492,9 @@ def sample_state(
# rng.choice doesn't support broadcasting
samples = np.stack([rng.choice(basis_states, shots, p=p) for p in probs])
else:

if 0 < abs_diff < cutoff:
probs /= norm
if not 0 < abs_diff < cutoff:
norm = 1.0
probs = probs / norm

samples = rng.choice(basis_states, shots, p=probs)

Expand Down
28 changes: 26 additions & 2 deletions pennylane/devices/qubit/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,9 +279,33 @@ def simulate(

has_mcm = any(isinstance(op, MidMeasureMP) for op in circuit.operations)
if circuit.shots and has_mcm:
return simulate_one_shot_native_mcm(
circuit, debugger=debugger, rng=rng, prng_key=prng_key, interface=interface
results = []
aux_circ = qml.tape.QuantumScript(
circuit.operations,
circuit.measurements,
shots=[1],
trainable_params=circuit.trainable_params,
)
keys = jax_random_split(prng_key, num=circuit.shots.total_shots)
if qml.math.get_deep_interface(circuit.data) == "jax":
# pylint: disable=import-outside-toplevel
import jax

def vv(k):
vincentmr marked this conversation as resolved.
Show resolved Hide resolved
return simulate_one_shot_native_mcm(
aux_circ, debugger=debugger, rng=rng, prng_key=k, interface=interface
)

results = jax.vmap(vv, in_axes=(0,))(keys)
vincentmr marked this conversation as resolved.
Show resolved Hide resolved
vincentmr marked this conversation as resolved.
Show resolved Hide resolved
results = list(zip(*results))
vincentmr marked this conversation as resolved.
Show resolved Hide resolved
else:
for i, _ in enumerate(circuit.shots):
vincentmr marked this conversation as resolved.
Show resolved Hide resolved
results.append(
simulate_one_shot_native_mcm(
aux_circ, debugger=debugger, rng=rng, prng_key=keys[i], interface=interface
vincentmr marked this conversation as resolved.
Show resolved Hide resolved
)
)
return tuple(results)

ops_key, meas_key = jax_random_split(prng_key)
state, is_state_batched = get_final_state(
Expand Down
4 changes: 2 additions & 2 deletions pennylane/measurements/mid_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,10 +443,10 @@ def __ge__(self, other):
return self._transform_bin_op(lambda a, b: a >= b, other)

def __and__(self, other):
return self._transform_bin_op(lambda a, b: a and b, other)
return self._transform_bin_op(qml.math.logical_and, other)

def __or__(self, other):
return self._transform_bin_op(lambda a, b: a or b, other)
return self._transform_bin_op(qml.math.logical_or, other)

def _apply(self, fn):
"""Apply a post computation to this measurement"""
Expand Down
Loading
Loading