Skip to content

Commit

Permalink
dynamic_one_shot uses tapes with shot-vectors and jitting takes adv…
Browse files Browse the repository at this point in the history
…antage of it (#5617)

### Before submitting

Please complete the following checklist when submitting a PR:

- [x] All new features must include a unit test.
If you've fixed a bug or added code that should be tested, add a test to
the
      test directory!

- [x] All new functions and code must be clearly commented and
documented.
If you do make documentation changes, make sure that the docs build and
      render correctly by running `make docs`.

- [ ] Ensure that the test suite passes, by running `make test`.

- [x] Add a new entry to the `doc/releases/changelog-dev.md` file,
summarizing the
      change, and including a link back to the PR.

- [x] The PennyLane source code conforms to
      [PEP8 standards](https://www.python.org/dev/peps/pep-0008/).
We check all of our code against [Pylint](https://www.pylint.org/).
      To lint modified files, simply `pip install pylint`, and then
      run `pylint pennylane/path/to/file.py`.

When all the above are checked, delete everything above the dashed
line and fill in the pull request template.


------------------------------------------------------------------------------------------------------------

**Context:**
`dynamic_one_shot` creates n-shots tapes which is wasteful.

**Description of the Change:**
Create a single tape with a shot-vector which indicates to the device
how many times to repeat the tape execution.

**Benefits:**
For a tape like
```
dev = qml.device("default.qubit", shots=2000, seed=jax.random.PRNGKey(123))

@qml.qnode(dev, diff_method=None)
def func(x, y):
    qml.RX(x, wires=0)
    m0 = qml.measure(0, reset=False, postselect=1)
    qml.cond(m0, qml.RY)(y, wires=1)
    return qml.expval(qml.PauliZ(0))


params = np.pi / 4 * np.ones(2)
```
The execution times are as follows (Latitude laptop):
- 12.7 s : vanilla Python
- 8.2 s : jax.vmap
- 11.1 s : jax.jit + jax.vmap + compilation
- 6.89 ms : jax.jit + jax.vmap

**Possible Drawbacks:**

**Related GitHub Issues:**
[sc-62097]

---------

Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
Co-authored-by: Christina Lee <christina@xanadu.ai>
  • Loading branch information
3 people committed May 10, 2024
1 parent 1d34de9 commit 9c9b6ba
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 29 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@

<h4>Mid-circuit measurements and dynamic circuits</h4>

* The `dynamic_one_shot` transform uses a single auxiliary tape with a shot vector and `default.qubit` implements the loop over shots with `jax.vmap`.
[(#5617)](https://github.com/PennyLaneAI/pennylane/pull/5617)

* The `dynamic_one_shot` transform can be compiled with `jax.jit`.
[(#5557)](https://github.com/PennyLaneAI/pennylane/pull/5557)

Expand Down
15 changes: 13 additions & 2 deletions pennylane/_qubit_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,19 @@ def execute(self, circuit, **kwargs):
self.check_validity(circuit.operations, circuit.observables)

has_mcm = any(isinstance(op, MidMeasureMP) for op in circuit.operations)
if has_mcm:
kwargs["mid_measurements"] = {}
if has_mcm and "mid_measurements" not in kwargs:
results = []
aux_circ = qml.tape.QuantumScript(
circuit.operations,
circuit.measurements,
shots=[1],
trainable_params=circuit.trainable_params,
)
for _ in circuit.shots:
kwargs["mid_measurements"] = {}
self.reset()
results.append(self.execute(aux_circ, **kwargs))
return tuple(results)
# apply all circuit operations
self.apply(
circuit.operations,
Expand Down
28 changes: 26 additions & 2 deletions pennylane/devices/qubit/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,9 +279,33 @@ def simulate(

has_mcm = any(isinstance(op, MidMeasureMP) for op in circuit.operations)
if circuit.shots and has_mcm:
return simulate_one_shot_native_mcm(
circuit, debugger=debugger, rng=rng, prng_key=prng_key, interface=interface
results = []
aux_circ = qml.tape.QuantumScript(
circuit.operations,
circuit.measurements,
shots=[1],
trainable_params=circuit.trainable_params,
)
keys = jax_random_split(prng_key, num=circuit.shots.total_shots)
if qml.math.get_deep_interface(circuit.data) == "jax":
# pylint: disable=import-outside-toplevel
import jax

def simulate_partial(k):
return simulate_one_shot_native_mcm(
aux_circ, debugger=debugger, rng=rng, prng_key=k, interface=interface
)

results = jax.vmap(simulate_partial, in_axes=(0,))(keys)
results = tuple(zip(*results))
else:
for i in range(circuit.shots.total_shots):
results.append(
simulate_one_shot_native_mcm(
aux_circ, debugger=debugger, rng=rng, prng_key=keys[i], interface=interface
)
)
return tuple(results)

ops_key, meas_key = jax_random_split(prng_key)
state, is_state_batched = get_final_state(
Expand Down
41 changes: 21 additions & 20 deletions pennylane/transforms/dynamic_one_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,7 @@ def func(x, y):
few-shots several-mid-circuit-measurement limit, whereas ``qml.defer_measurements`` is favorable
in the opposite limit.
"""

if not any(isinstance(o, MidMeasureMP) for o in tape.operations):
if not any(is_mcm(o) for o in tape.operations):
return (tape,), null_postprocessing

for m in tape.measurements:
Expand All @@ -103,9 +102,7 @@ def func(x, y):
raise qml.QuantumFunctionError("dynamic_one_shot is only supported with finite shots.")

samples_present = any(isinstance(mp, SampleMP) for mp in tape.measurements)
postselect_present = any(
op.postselect is not None for op in tape.operations if isinstance(op, MidMeasureMP)
)
postselect_present = any(op.postselect is not None for op in tape.operations if is_mcm(op))
if postselect_present and samples_present and tape.batch_size is not None:
raise ValueError(
"Returning qml.sample is not supported when postselecting mid-circuit "
Expand All @@ -119,38 +116,33 @@ def func(x, y):
broadcast_fn = None

aux_tapes = [init_auxiliary_tape(t) for t in tapes]
# Shape of output_tapes is (batch_size * total_shots,) with broadcasting,
# and (total_shots,) otherwise
output_tapes = [at for at in aux_tapes for _ in range(tape.shots.total_shots)]

def processing_fn(results, has_partitioned_shots=None, batched_results=None):
if batched_results is None and batch_size is not None:
# If broadcasting, recursively process the results for each batch. For each batch
# there are tape.shots.total_shots results. The length of the first axis of final_results
# will be batch_size.
results = list(results)
final_results = []
for _ in range(batch_size):
final_results.append(
processing_fn(results[0 : tape.shots.total_shots], batched_results=False)
)
del results[0 : tape.shots.total_shots]
for result in results:
final_results.append(processing_fn((result,), batched_results=False))
return broadcast_fn(final_results)

if has_partitioned_shots is None and tape.shots.has_partitioned_shots:
# If using shot vectors, recursively process the results for each shot bin. The length
# of the first axis of final_results will be the length of the shot vector.
results = list(results)
results = list(results[0])
final_results = []
for s in tape.shots:
final_results.append(
processing_fn(results[0:s], has_partitioned_shots=False, batched_results=False)
)
del results[0:s]
return tuple(final_results)
if not tape.shots.has_partitioned_shots:
results = results[0]
return parse_native_mid_circuit_measurements(tape, aux_tapes, results)

return output_tapes, processing_fn
return aux_tapes, processing_fn


@dynamic_one_shot.custom_qnode_transform
Expand All @@ -177,6 +169,12 @@ def _dynamic_one_shot_qnode(self, qnode, targs, tkwargs):
return self.default_qnode_transform(qnode, targs, tkwargs)


def is_mcm(operation):
"""Returns True if the operation is a mid-circuit measurement and False otherwise."""
mcm = isinstance(operation, MidMeasureMP)
return mcm or "MidCircuitMeasure" in str(type(operation))


def init_auxiliary_tape(circuit: qml.tape.QuantumScript):
"""Creates an auxiliary circuit to perform one-shot mid-circuit measurement calculations.
Expand All @@ -197,11 +195,14 @@ def init_auxiliary_tape(circuit: qml.tape.QuantumScript):
else:
new_measurements.append(m)
for op in circuit:
if isinstance(op, MidMeasureMP):
if is_mcm(op):
new_measurements.append(qml.sample(MeasurementValue([op], lambda res: res)))

return qml.tape.QuantumScript(
circuit.operations, new_measurements, shots=1, trainable_params=circuit.trainable_params
circuit.operations,
new_measurements,
shots=[1] * circuit.shots.total_shots,
trainable_params=circuit.trainable_params,
)


Expand All @@ -228,7 +229,7 @@ def measurement_with_no_shots(measurement):

interface = qml.math.get_deep_interface(circuit.data)

all_mcms = [op for op in aux_tapes[0].operations if isinstance(op, MidMeasureMP)]
all_mcms = [op for op in aux_tapes[0].operations if is_mcm(op)]
n_mcms = len(all_mcms)
post_process_tape = qml.tape.QuantumScript(
aux_tapes[0].operations,
Expand All @@ -248,7 +249,7 @@ def measurement_with_no_shots(measurement):
).reshape((1, -1))
is_valid = qml.math.all(mcm_samples * has_postselect == postselect, axis=1)
has_valid = qml.math.any(is_valid)
mid_meas = [op for op in circuit.operations if isinstance(op, MidMeasureMP)]
mid_meas = [op for op in circuit.operations if is_mcm(op)]
mcm_samples = [mcm_samples[:, i : i + 1] for i in range(n_mcms)]
mcm_samples = dict((k, v) for k, v in zip(mid_meas, mcm_samples))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,7 @@ def func(x, y, z):
assert "pure_callback" not in jaxpr

func2 = jax.jit(func)
results2 = func2(*jax.numpy.array(params))
results2 = func2(*params)

measures = [
qml.probs,
Expand Down
8 changes: 4 additions & 4 deletions tests/transforms/test_dynamic_one_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_len_tapes(n_shots):
"""Test that the transform produces the correct number of tapes."""
tape = qml.tape.QuantumScript([MidMeasureMP(0)], [qml.expval(qml.PauliZ(0))], shots=n_shots)
tapes, _ = qml.dynamic_one_shot(tape)
assert len(tapes) == n_shots
assert len(tapes) == 1


@pytest.mark.parametrize("n_batch", range(1, 4))
Expand All @@ -102,7 +102,7 @@ def test_len_tape_batched(n_batch, n_shots):
shots=n_shots,
)
tapes, _ = qml.dynamic_one_shot(tape)
assert len(tapes) == n_shots * n_batch
assert len(tapes) == n_batch


@pytest.mark.parametrize(
Expand All @@ -123,7 +123,7 @@ def test_len_measurements_obs(measure, aux_measure, n_meas):
[qml.Hadamard(0)] + [MidMeasureMP(0)] * n_mcms, [measure(op=qml.PauliZ(0))], shots=n_shots
)
tapes, _ = qml.dynamic_one_shot(tape)
assert len(tapes) == n_shots
assert len(tapes) == 1
aux_tape = tapes[0]
assert len(aux_tape.measurements) == n_meas + n_mcms
assert isinstance(aux_tape.measurements[0], aux_measure)
Expand All @@ -150,7 +150,7 @@ def test_len_measurements_mcms(measure, aux_measure, n_meas):
shots=n_shots,
)
tapes, _ = qml.dynamic_one_shot(tape)
assert len(tapes) == n_shots
assert len(tapes) == 1
aux_tape = tapes[0]
assert len(aux_tape.measurements) == n_meas + n_mcms
assert isinstance(aux_tape.measurements[0], aux_measure)
Expand Down

0 comments on commit 9c9b6ba

Please sign in to comment.