Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

defer_measurements raises errors for unsupported measurements #4701

Merged
merged 158 commits into from
Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from 154 commits
Commits
Show all changes
158 commits
Select commit Hold shift + click to select a range
64a3deb
Updated hash/eq
mudit2812 Aug 28, 2023
ed9d1f9
Updated pytest.ini
mudit2812 Aug 28, 2023
1c9a256
Updated changelog
mudit2812 Aug 28, 2023
593115d
Update doc/releases/changelog-dev.md
mudit2812 Aug 28, 2023
4380982
Removed warning imports
mudit2812 Aug 28, 2023
d7ab544
Merge branch 'master' into eq-hash
mudit2812 Aug 28, 2023
12e3fb4
Update tests
mudit2812 Aug 28, 2023
a052b7e
Fixed test
mudit2812 Aug 29, 2023
6647729
Added single MV support for MPs
mudit2812 Aug 29, 2023
b4ef716
Merge branch 'master' into eq-hash
mudit2812 Aug 29, 2023
31bef6e
Added docs; updated defer_measurements
mudit2812 Aug 29, 2023
855a194
Added mp tests
mudit2812 Aug 30, 2023
a003eb3
Added qnode test
mudit2812 Aug 30, 2023
29f93d9
Added deferred_measurements test
mudit2812 Aug 30, 2023
b8d1932
Merge branch 'master' into mcm-stats-1
mudit2812 Aug 30, 2023
e9f5200
Added changelog entry
mudit2812 Aug 30, 2023
096f90f
forgot to add name to changelog
mudit2812 Aug 30, 2023
6baf8ea
Fixed docs
mudit2812 Aug 30, 2023
362ac1c
Fixed test validation function
mudit2812 Aug 30, 2023
83fdd66
Fixed defer_measurementes
mudit2812 Aug 30, 2023
f0226df
Merge branch 'master' into mcm-stats-1
mudit2812 Aug 30, 2023
08d6dc2
Refactoring
mudit2812 Aug 30, 2023
e46e720
Reverted QNode changes
mudit2812 Aug 30, 2023
23d8fd3
Added example to changelog
mudit2812 Aug 30, 2023
dc418ab
Fixed example to changelog
mudit2812 Aug 30, 2023
10bde2e
Merge branch 'master' into eq-hash
mudit2812 Aug 30, 2023
655748a
Updated Hamiltonian.compare
mudit2812 Aug 31, 2023
37667e7
Merge branch 'master' into eq-hash
mudit2812 Aug 31, 2023
5e92495
Updated doc
mudit2812 Aug 31, 2023
b1f3896
Changed QNode back
mudit2812 Aug 31, 2023
a491511
Merge branch 'eq-hash' into mcm-stats-1
mudit2812 Aug 31, 2023
a4e8074
Fixed sample tests
mudit2812 Aug 31, 2023
802c7a0
Removed decimal
mudit2812 Aug 31, 2023
708baa5
Fixed counts
mudit2812 Sep 1, 2023
b82e49b
Merge branch 'master' into mcm-stats-1
mudit2812 Sep 1, 2023
6c2a232
Update changelog
mudit2812 Sep 5, 2023
5ebbc2e
separate DensityMatrixMP into its own measurement process (#4558)
timmysilv Sep 1, 2023
f10bfe4
add default.qubit.legacy and rename some test devices to prepare for …
timmysilv Sep 5, 2023
da51a83
StateMP accepts wires (#4570)
timmysilv Sep 5, 2023
361caf3
Deprecate fancy decorator syntax in batch transforms (#4457)
eddddddy Sep 6, 2023
126bd7b
Batch transforms are updated (#4440)
rmoyard Sep 6, 2023
ebb4d91
Update measurements tests for DQ2; create legacy tests folder (#4574)
timmysilv Sep 6, 2023
bcd09c4
Use fermi sentence in qchem dipole and number functions (#4546)
soranjh Sep 6, 2023
33ec1e6
fix np.random.seed usage (#4581)
timmysilv Sep 7, 2023
58fb45f
Remove qchem deprecated functionality for fermionic observables (#4556)
soranjh Sep 7, 2023
603074f
Update `DecompositionUndefinedError` for `Exp` class (#4571)
Jaybsoni Sep 8, 2023
8816391
Fix expand_tape_state_prep. (#4564)
vincentmr Sep 8, 2023
cc1cb0f
Merge branch 'master' into mcm-stats-1
mudit2812 Sep 8, 2023
2f48b93
Added counts legacy tests
mudit2812 Sep 8, 2023
f0d7ff8
Merge branch 'master' into mcm-stats-1
mudit2812 Sep 8, 2023
929b463
Fixing sample legacy tests
mudit2812 Sep 8, 2023
a6ce9aa
Reverted Hamiltonian.compare; linting
mudit2812 Sep 11, 2023
de5f3b1
Removed legacy tests
mudit2812 Sep 11, 2023
ffca318
Reverted sample legacy tests
mudit2812 Sep 11, 2023
8b22ca0
Update test to measure 3 wires
mudit2812 Sep 11, 2023
d86b085
Merge branch 'master' into mcm-stats-1
mudit2812 Sep 11, 2023
5bf5061
Updated MP eigvals
mudit2812 Sep 12, 2023
a7d0ebb
reformatting
mudit2812 Sep 12, 2023
d3b3057
Merge branch 'master' into mcm-stats-1
mudit2812 Sep 12, 2023
031b2d2
Reverted counts changes
mudit2812 Sep 12, 2023
74e896e
Removed error supression
mudit2812 Sep 13, 2023
150304d
Merging
mudit2812 Sep 13, 2023
53d7db5
Merge branch 'master' into mcm-stats-1
mudit2812 Sep 13, 2023
55c3a38
Updated measurements eigvals test
mudit2812 Sep 13, 2023
bd64ea9
AmplitudeEmbedding inherits from StatePrep (#4583)
timmysilv Sep 13, 2023
ba35f2a
Add support for offset in `qml.MPS` template (#4531)
obliviateandsurrender Sep 13, 2023
4c14d70
Fix copying the Select template (#4551)
DSGuala Sep 13, 2023
788fa67
Various changes for DQ2 to work (#4534)
timmysilv Sep 14, 2023
4ef3913
Update `op_transforms` (#4573)
mudit2812 Sep 14, 2023
ac42ff9
Make the device API not-experimental (#4594)
timmysilv Sep 15, 2023
c9c9fe3
get op batch size at runtime to match DQL (#4592)
timmysilv Sep 15, 2023
115b7bd
Removed unnecessary checks
mudit2812 Sep 15, 2023
74a8f75
Merge branch 'master' into mcm-stats-1
mudit2812 Sep 15, 2023
6a31c34
Adding postselection support
mudit2812 Sep 15, 2023
a63c802
Fixed qnode
mudit2812 Sep 15, 2023
34cb028
Added MeasurementValue to DQL supported observables
mudit2812 Sep 18, 2023
33bfaa5
update measure.py; add changelog for eigvals
mudit2812 Sep 18, 2023
0c32223
Fixed qnode
mudit2812 Sep 15, 2023
e7b332b
Added MeasurementValue to DQL supported observables
mudit2812 Sep 18, 2023
67d9ede
update measure.py; add changelog for eigvals
mudit2812 Sep 18, 2023
1e8c0de
Updated error message
mudit2812 Sep 18, 2023
c0ef20b
Added tests
mudit2812 Sep 18, 2023
b93563b
Added changelog
mudit2812 Sep 18, 2023
e8087a2
Fixed error
mudit2812 Sep 19, 2023
9117c8a
Updated simulate; added docs
mudit2812 Sep 19, 2023
129b563
Merge branch 'master' into mcm-stats-1
mudit2812 Sep 19, 2023
1d12308
Added docs to measurements.rst
mudit2812 Sep 19, 2023
b3b490d
fix torch take with axis=-1 (#4605)
timmysilv Sep 18, 2023
8321065
[CI] Bump jax version to v0.4.16 (#4612)
albi3ro Sep 19, 2023
8e421f5
TmpPauliRot still decomposes if the theta of zero is trainable (#4585)
timmysilv Sep 19, 2023
bf4dbb2
Update shot_adaptive to not mutate device shots (#4599)
timmysilv Sep 19, 2023
d8d269f
Tests always enable jax float64 (#4613)
albi3ro Sep 19, 2023
de83a59
`process_state` assumes flat, `state()` is always complex (#4602)
timmysilv Sep 19, 2023
6eb3d57
Register ParameterizedEvolution (#4598)
lillian542 Sep 19, 2023
a7c1294
Added docs to measurements.rst
mudit2812 Sep 19, 2023
4bfcac5
Added docs for nan values
mudit2812 Sep 19, 2023
01011d6
Merge branch 'mcm-stats-1' into mcm-post
mudit2812 Sep 19, 2023
25b9eff
Fixed doc error
mudit2812 Sep 19, 2023
cdf53d4
Updated measurement process; updated tests
mudit2812 Sep 20, 2023
1b4a7ae
Update tests
mudit2812 Sep 20, 2023
d1ad001
Update doc/introduction/measurements.rst
mudit2812 Sep 20, 2023
f0baf87
Update doc/releases/changelog-dev.md
mudit2812 Sep 20, 2023
227899b
Update pennylane/tape/qscript.py
mudit2812 Sep 20, 2023
4728585
Apply suggestions from code review
mudit2812 Sep 20, 2023
b857dc0
Addressing PR review
mudit2812 Sep 20, 2023
275b355
Add prng_key kwarg to new device API (#4596)
lillian542 Sep 19, 2023
c36f08c
Fix `qnn.TorchLayer` attributes (#4611)
albi3ro Sep 19, 2023
c782f78
bump jax version for docs (#4618)
albi3ro Sep 20, 2023
6c56a61
No qnode post processing on informative transforms (#4616)
rmoyard Sep 20, 2023
569e1aa
Removed unused imports
mudit2812 Sep 20, 2023
b5d9a5f
`default.qubit` returns the new `DefaultQubit` device (#4436)
timmysilv Sep 20, 2023
0a1b8c9
[skip ci] Added qml.math.is_nan
mudit2812 Sep 21, 2023
9cf3588
Update pennylane/transforms/defer_measurements.py
mudit2812 Sep 21, 2023
57d77bd
Make Measurements Pytrees (#4607)
albi3ro Sep 20, 2023
a4ae4c0
Converter for MPS DMRG wavefunctions (#4523)
Chiffafox Sep 20, 2023
d0240d8
Converter for SHCI wavefunction (#4524)
Chiffafox Sep 20, 2023
534601b
Updated measurements pytree stuff
mudit2812 Sep 21, 2023
e856749
Added tape splitting to defer_measurements
mudit2812 Sep 21, 2023
7ee28f7
Updated measure.py
mudit2812 Sep 21, 2023
240fbe2
Update returns.rst (#4617)
trbromley Sep 21, 2023
09baad0
fix and enable pulse gradient with broadcasting on new device (#4620)
timmysilv Sep 21, 2023
b0a2b16
Register `QuantumScript` and `QuantumTape` as pytrees (#4608)
albi3ro Sep 21, 2023
8996b4c
Merge branch 'master' into mcm-post
mudit2812 Sep 22, 2023
83613e0
Merge branch 'master' into mcm-post
mudit2812 Oct 2, 2023
c6f48e7
Added nan measurement
mudit2812 Oct 2, 2023
b5ade01
Fixed simulate; added tests
mudit2812 Oct 3, 2023
8579cce
Fixed mp errors
mudit2812 Oct 3, 2023
1ebc58d
Updated postselection measurement
mudit2812 Oct 4, 2023
4243dc3
Merge branch 'master' into mcm-post
mudit2812 Oct 4, 2023
e862d49
Updated defer_measurements tests
mudit2812 Oct 5, 2023
c1c763b
Reverted validation for zero probability. Need to update tests
mudit2812 Oct 6, 2023
beec97e
Merge branch 'master' into mcm-post
mudit2812 Oct 6, 2023
b6279f0
Lots of changes
mudit2812 Oct 16, 2023
05a32ea
Merge branch 'master' into mcm-post
mudit2812 Oct 16, 2023
adecaed
Merge branch 'master' into mcm-post
mudit2812 Oct 17, 2023
ddb00d0
Updated docs; added sampling tests
mudit2812 Oct 17, 2023
03807ac
[skip ci] adding tests; defer_measurements bug
mudit2812 Oct 17, 2023
160a771
Removed testing with shot vectors
mudit2812 Oct 17, 2023
d336ef0
Updated tests; fixed docs
mudit2812 Oct 17, 2023
94c1d70
Merge branch 'master' into mcm-post
mudit2812 Oct 17, 2023
54be0ca
Fixing/adding tests
mudit2812 Oct 18, 2023
dbd2a58
Fixed docs; linting
mudit2812 Oct 18, 2023
675a5f7
Fixed qnode tests; linting
mudit2812 Oct 18, 2023
6045241
Added coverage
mudit2812 Oct 18, 2023
9e149fe
More tests
mudit2812 Oct 18, 2023
d73a56d
[skip ci] Updated measurement docs
mudit2812 Oct 19, 2023
4f49f2b
Addressed PR review; tweaked docs
mudit2812 Oct 19, 2023
4fe2e25
Updated tests
mudit2812 Oct 19, 2023
e069ed4
Merge branch 'master' into mcm-post
mudit2812 Oct 19, 2023
f90a889
[skip ci] Fixing docs
mudit2812 Oct 20, 2023
8034037
Merge branch 'master' into mcm-post
mudit2812 Oct 20, 2023
d287798
[skip ci] Added errors
mudit2812 Oct 20, 2023
5066c28
[skip ci] Removed test case with `qml.state()`
mudit2812 Oct 20, 2023
1c5724c
[skip ci] Added link to changelog
mudit2812 Oct 20, 2023
0c53d59
[skip ci] Reordered preprocessing steps
mudit2812 Oct 20, 2023
48f56bf
[skip ci] Apply suggestions from code review
mudit2812 Oct 23, 2023
877eeca
Merge branch 'master' into mcm-state
mudit2812 Oct 23, 2023
a4cf441
Merge branch 'master' into mcm-state
mudit2812 Oct 23, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 76 additions & 29 deletions doc/introduction/measurements.rst
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,9 @@ outcome of such mid-circuit measurements:
qml.cond(m_0, qml.RY)(y, wires=0)
return qml.probs(wires=[0])

Deferred measurements
*********************

A quantum function with mid-circuit measurements (defined using
:func:`~.pennylane.measure`) and conditional operations (defined using
:func:`~.pennylane.cond`) can be executed by applying the `deferred measurement
Expand All @@ -269,8 +272,12 @@ measurement on qubit 1 yielded ``1`` as an outcome, otherwise doing nothing
for the ``0`` measurement outcome.

PennyLane implements the deferred measurement principle to transform
conditional operations with the :func:`~.defer_measurements` quantum
function transform.
conditional operations with the :func:`~.pennylane.defer_measurements` quantum
function transform. The deferred measurement principle provides a natural method
to simulate the application of mid-circuit measurements and conditional operations
in a differentiable and device-independent way. Performing true mid-circuit
measurements and conditional operations is dependent on the quantum hardware and
PennyLane device capabilities.

.. code-block:: python

Expand All @@ -290,7 +297,35 @@ The decorator syntax applies equally well:
def qnode(x, y):
(...)

Note that we can also specify an outcome when defining a conditional operation:
Resetting wires
***************

Wires can be reused as normal after making mid-circuit measurements. Moreover, a measured wire can also be
reset to the :math:`|0 \rangle` state by setting the ``reset`` keyword argument of :func:`~.pennylane.measure`
to ``True``.

.. code-block:: python3

dev = qml.device("default.qubit", wires=3)

@qml.qnode(dev)
def func():
qml.PauliX(1)
m_0 = qml.measure(1, reset=True)
qml.PauliX(1)
return qml.probs(wires=[1])

Executing this QNode:

>>> func()
tensor([0., 1.], requires_grad=True)

Conditional operators
*********************

Users can create conditional operators controlled on mid-circuit measurements using
:func:`~.pennylane.cond`. We can also specify an outcome when defining a conditional
operation:

.. code-block:: python

Expand All @@ -309,30 +344,50 @@ Note that we can also specify an outcome when defining a conditional operation:
>>> qnode_conditional_op_on_zero(*pars)
tensor([0.88660045, 0.11339955], requires_grad=True)

Wires can be reused as normal after making mid-circuit measurements. Moreover, a measured wire can also be
reset to the :math:`|0 \rangle` state by setting the ``reset`` keyword argument of :func:`~.pennylane.measure`
to ``True``.
For more examples on applying quantum functions conditionally, refer to the
:func:`~.pennylane.cond` documentation.

Postselecting mid-circuit measurements
**************************************

PennyLane also supports postselecting on mid-circuit measurement outcomes by specifying the ``postselect``
keyword argument of :func:`~.pennylane.measure`. Postselection discards outcomes that do not meet the
criteria provided by the ``postselect`` argument. For example, specifying ``postselect=1`` on wire 0 would
be equivalent to projecting the state vector onto the :math:`|1\rangle` state on wire 0:

.. code-block:: python3

dev = qml.device("default.qubit", wires=3)
dev = qml.device("default.qubit")

@qml.qnode(dev)
def func():
qml.PauliX(1)
m_0 = qml.measure(1, reset=True)
qml.PauliX(1)
return qml.probs(wires=[1])
def func(x):
qml.RX(x, wires=0)
m0 = qml.measure(0, postselect=1)
qml.cond(m0, qml.PauliX)(wires=1)
return qml.sample(wires=1)

Executing this QNode:
By postselecting on ``1``, we only consider the ``1`` measurement outcome on wire 0. So, the probability of
measuring ``1`` on wire 1 after postselection should also be 1. Executing this QNode with 10 shots:

>>> func()
tensor([0., 1.], requires_grad=True)
>>> func(np.pi / 2, shots=10)
array([1, 1, 1, 1, 1, 1, 1])

Note that only 7 samples are returned. This is because samples that do not meet the postselection criteria are
discarded.

.. note::

Statistics can also be collected on mid-circuit measurements along with terminal measurement statistics.
Currently, postselection support is only available on :class:`~.pennylane.devices.DefaultQubit`. Using
postselection on other devices will raise an error.

Mid-circuit measurement statistics
**********************************

Statistics can be collected on mid-circuit measurements along with terminal measurement statistics.
Currently, ``qml.probs``, ``qml.sample``, ``qml.expval``, ``qml.var``, and ``qml.counts`` are supported,
and can be requested along with other measurements. The devices that currently support collecting such
statistics are ``"default.qubit"``, ``"default.mixed"``, and ``"default.qubit.legacy"``.
statistics are :class:`~.pennylane.devices.DefaultQubit`, :class:`~.pennylane.devices.DefaultMixed`, and
:class:`~.pennylane.devices.DefaultQubitLegacy`.

.. code-block:: python3

Expand All @@ -351,19 +406,11 @@ Executing this QNode:
(tensor([0.9267767, 0.0732233], requires_grad=True),
tensor([0.5, 0.5], requires_grad=True))

Currently, statistics can only be collected for single mid-circuit measurement values. Moreover, any
measurement values manipulated using boolean or arithmetic operators cannot be used. These can lead to
unexpected/incorrect behaviour.

The deferred measurement principle provides a natural method to simulate the
application of mid-circuit measurements and conditional operations in a
differentiable and device-independent way. Performing true mid-circuit
measurements and conditional operations is dependent on the
quantum hardware and PennyLane device capabilities.

For more examples on applying quantum functions conditionally, refer to the
:func:`~.pennylane.cond` transform.
.. warning::

Currently, statistics can only be collected for single mid-circuit measurement values. Moreover, any
measurement values manipulated using boolean or arithmetic operators cannot be used. These can lead to
unexpected/incorrect behaviour.

Changing the number of shots
----------------------------
Expand Down
30 changes: 28 additions & 2 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,30 @@
(array(0.6), array([1, 1, 1, 0, 1]))
```

* Users can now request postselection after making mid-circuit measurements. They can do so
by specifying the `postselect` keyword argument for `qml.measure` as either `0` or `1`,
corresponding to the basis states.
[(#4604)](https://github.com/PennyLaneAI/pennylane/pull/4604)

```python
dev = qml.device("default.qubit", wires=3)

@qml.qnode(dev)
def circuit(phi):
qml.RX(phi, wires=0)
m = qml.measure(0, postselect=1)
qml.cond(m, qml.PauliX)(wires=1)
return qml.probs(wires=1)
```
```pycon
>>> circuit(np.pi)
tensor([0., 1.], requires_grad=True)
```

Here, we measure a probability of one on wire 1 as we postselect on the $|1\rangle$ state on wire
0, thus resulting in the circuit being projected onto the state corresponding to the measurement
outcome $|1\rangle$ on wire 0.

* Operator transforms `qml.matrix`, `qml.eigvals`, `qml.generator`, and `qml.transforms.to_zx` are updated
to the new transform program system.
[(#4573)](https://github.com/PennyLaneAI/pennylane/pull/4573)
Expand Down Expand Up @@ -272,10 +296,12 @@
decomposition.
[(#4675)](https://github.com/PennyLaneAI/pennylane/pull/4675)



<h3>Breaking changes 💔</h3>

* ``qml.defer_measurements`` now raises an error if a transformed circuit measures ``qml.probs``,
``qml.sample``, or ``qml.counts`` without any wires or obsrvable, or if it measures ``qml.state``.
[(#4701)](https://github.com/PennyLaneAI/pennylane/pull/4701)

* The device test suite now converts device kwargs to integers or floats if they can be converted to integers or floats.
[(#4640)](https://github.com/PennyLaneAI/pennylane/pull/4640)

Expand Down
3 changes: 3 additions & 0 deletions pennylane/_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,6 +977,9 @@ def check_validity(self, queue, observables):
"simulate the application of mid-circuit measurements on this device."
)

if isinstance(o, qml.Projector):
raise ValueError(f"Postselection is not supported on the {self.name} device.")

if not self.stopping_condition(o):
raise DeviceError(
f"Gate {operation_name} not supported on device {self.short_name}"
Expand Down
2 changes: 1 addition & 1 deletion pennylane/devices/default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def preprocess(
config = self._setup_execution_config(execution_config)
transform_program = TransformProgram()

transform_program.add_transform(qml.defer_measurements)
transform_program.add_transform(qml.defer_measurements, device=self)
transform_program.add_transform(validate_device_wires, self.wires, name=self.name)
transform_program.add_transform(
decompose, stopping_condition=stopping_condition, name=self.name
Expand Down
47 changes: 29 additions & 18 deletions pennylane/devices/qubit/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def measure_with_samples(
Returns:
List[TensorLike[Any]]: Sample measurement results
"""

groups, indices = _group_measurements(mps)

all_res = []
Expand Down Expand Up @@ -264,27 +265,37 @@ def _process_single_shot(samples):
# currently we call sample_state for each shot entry, but it may be
# better to call sample_state just once with total_shots, then use
# the shot_range keyword argument
samples = sample_state(
state,
shots=s,
is_state_batched=is_state_batched,
wires=wires,
rng=rng,
prng_key=prng_key,
)
try:
samples = sample_state(
state,
shots=s,
is_state_batched=is_state_batched,
wires=wires,
rng=rng,
prng_key=prng_key,
)
except ValueError as e:
if str(e) != "probabilities contain NaN":
raise e
samples = qml.math.full((s, len(wires)), 0)

processed_samples.append(_process_single_shot(samples))

return tuple(zip(*processed_samples))

samples = sample_state(
state,
shots=shots.total_shots,
is_state_batched=is_state_batched,
wires=wires,
rng=rng,
prng_key=prng_key,
)
try:
samples = sample_state(
state,
shots=shots.total_shots,
is_state_batched=is_state_batched,
wires=wires,
rng=rng,
prng_key=prng_key,
)
except ValueError as e:
if str(e) != "probabilities contain NaN":
raise e
samples = qml.math.full((shots.total_shots, len(wires)), 0)

return _process_single_shot(samples)

Expand Down Expand Up @@ -352,7 +363,7 @@ def _sum_for_single_shot(s):
)
return sum(c * res for c, res in zip(mp.obs.terms()[0], results))

unsqueezed_results = tuple(_sum_for_single_shot(Shots(s)) for s in shots)
unsqueezed_results = tuple(_sum_for_single_shot(type(shots)(s)) for s in shots)
return [unsqueezed_results] if shots.has_partitioned_shots else [unsqueezed_results[0]]


Expand Down Expand Up @@ -380,7 +391,7 @@ def _sum_for_single_shot(s):
)
return sum(results)

unsqueezed_results = tuple(_sum_for_single_shot(Shots(s)) for s in shots)
unsqueezed_results = tuple(_sum_for_single_shot(type(shots)(s)) for s in shots)
return [unsqueezed_results] if shots.has_partitioned_shots else [unsqueezed_results[0]]


Expand Down
55 changes: 55 additions & 0 deletions pennylane/devices/qubit/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Simulate a quantum script."""
# pylint: disable=protected-access
from numpy.random import default_rng
import numpy as np

import pennylane as qml
from pennylane.typing import Result
Expand Down Expand Up @@ -46,6 +47,20 @@
}


class _FlexShots(qml.measurements.Shots):
"""Shots class that allows zero shots."""

# pylint: disable=super-init-not-called
def __init__(self, shots=None):
if isinstance(shots, int):
self.total_shots = shots
self.shot_vector = (qml.measurements.ShotCopies(shots, 1),)
else:
self.__all_tuple_init__([s if isinstance(s, tuple) else (s, 1) for s in shots])

self._frozen = True


def expand_state_over_wires(state, state_wires, all_wires, is_state_batched):
"""
Expand and re-order a state given some initial and target wire orders, setting
Expand Down Expand Up @@ -83,6 +98,39 @@ def expand_state_over_wires(state, state_wires, all_wires, is_state_batched):
return qml.math.transpose(state, desired_axes)


def _postselection_postprocess(state, is_state_batched, shots):
"""Update state after projector is applied."""
if is_state_batched:
raise ValueError(
"Cannot postselect on circuits with broadcasting. Use the "
"qml.transforms.broadcast_expand transform to split a broadcasted "
"tape into multiple non-broadcasted tapes before executing if "
"postselection is used."
)

# The floor function is being used here so that a norm very close to zero becomes exactly
# equal to zero so that the state can become invalid. This way, execution can continue, and
# bad postselection gives results that are invalid rather than results that look valid but
# are incorrect.
norm = qml.math.floor(qml.math.real(qml.math.norm(state)) * 1e15) * 1e-15

if shots:
# Clip the number of shots using a binomial distribution using the probability of
# measuring the postselected state.
postselected_shots = (
[np.random.binomial(s, float(norm)) for s in shots]
if not qml.math.is_abstract(norm)
else shots
)

# _FlexShots is used here since the binomial distribution could result in zero
# valid samples
shots = _FlexShots(postselected_shots)

state = state / qml.math.cast_like(norm, state)
return state, shots


def get_final_state(circuit, debugger=None, interface=None):
"""
Get the final state that results from executing the given quantum script.
Expand Down Expand Up @@ -112,6 +160,12 @@ def get_final_state(circuit, debugger=None, interface=None):
for op in circuit.operations[bool(prep) :]:
state = apply_operation(op, state, is_state_batched=is_state_batched, debugger=debugger)

# Handle postselection on mid-circuit measurements
if isinstance(op, qml.Projector):
state, circuit._shots = _postselection_postprocess(
state, is_state_batched, circuit.shots
)

# new state is batched if i) the old state is batched, or ii) the new op adds a batch dim
is_state_batched = is_state_batched or op.batch_size is not None

Expand Down Expand Up @@ -147,6 +201,7 @@ def measure_final_state(circuit, state, is_state_batched, rng=None, prng_key=Non
Returns:
Tuple[TensorLike]: The measurement results
"""

circuit = circuit.map_to_standard_wires()

if not circuit.shots:
Expand Down
Loading