Skip to content

Commit

Permalink
jax.jit works with qml.probs and finite shots (#5619)
Browse files Browse the repository at this point in the history
### Before submitting

Please complete the following checklist when submitting a PR:

- [x] All new features must include a unit test.
If you've fixed a bug or added code that should be tested, add a test to
the
      test directory!

- [x] All new functions and code must be clearly commented and
documented.
If you do make documentation changes, make sure that the docs build and
      render correctly by running `make docs`.

- [x] Ensure that the test suite passes, by running `make test`.

- [x] Add a new entry to the `doc/releases/changelog-dev.md` file,
summarizing the
      change, and including a link back to the PR.

- [x] The PennyLane source code conforms to
      [PEP8 standards](https://www.python.org/dev/peps/pep-0008/).
We check all of our code against [Pylint](https://www.pylint.org/).
      To lint modified files, simply `pip install pylint`, and then
      run `pylint pennylane/path/to/file.py`.

When all the above are checked, delete everything above the dashed
line and fill in the pull request template.


------------------------------------------------------------------------------------------------------------

**Context:**
`qml.math.unique` cannot be jitted with JAX because the output shape of
`unique` depends on the input values.

**Description of the Change:**
Use an input agnostic way to update the probs.

**Benefits:**
`qml.probs` can be jitted (without Python callbacks).

**Possible Drawbacks:**

**Related GitHub Issues:**

---------

Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
Co-authored-by: Christina Lee <christina@xanadu.ai>
Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>
  • Loading branch information
4 people committed May 10, 2024
1 parent ebcb29f commit 1d34de9
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 19 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@

<h3>Bug fixes 🐛</h3>

* Finite shot circuits with a `qml.probs` measurement, both with a `wires` or `op` argument, can now be compiled with `jax.jit`.
[(#5619)](https://github.com/PennyLaneAI/pennylane/pull/5619)

* `param_shift`, `finite_diff`, `compile`, `merge_rotations`, and `transpile` now all work
with circuits with non-commuting measurements.
[(#5424)](https://github.com/PennyLaneAI/pennylane/pull/5424)
Expand Down
40 changes: 25 additions & 15 deletions pennylane/measurements/probs.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,23 +259,33 @@ def _count_samples(indices, batch_size, dim):
"""Count the occurrences of sampled indices and convert them to relative
counts in order to estimate their occurrence probability."""
num_bins, bin_size = indices.shape[-2:]
if batch_size is None:
prob = qml.math.zeros((dim, num_bins), dtype="float64")
# count the basis state occurrences, and construct the probability vector for each bin
for b, idx in enumerate(indices):
basis_states, counts = qml.math.unique(idx, return_counts=True)
prob[basis_states, b] = counts / bin_size
interface = qml.math.get_deep_interface(indices)

return prob
if qml.math.is_abstract(indices):

prob = qml.math.zeros((batch_size, dim, num_bins), dtype="float64")
indices = indices.reshape((batch_size, num_bins, bin_size))
def _count_samples_core(indices, dim, interface):
return qml.math.array(
[[qml.math.sum(idx == p) for idx in indices] for p in range(dim)],
like=interface,
)

else:

def _count_samples_core(indices, dim, *_):
probabilities = qml.math.zeros((dim, num_bins), dtype="float64")
for b, idx in enumerate(indices):
basis_states, counts = qml.math.unique(idx, return_counts=True)
probabilities[basis_states, b] = counts
return probabilities

if batch_size is None:
return _count_samples_core(indices, dim, interface) / bin_size

# count the basis state occurrences, and construct the probability vector
# for each bin and broadcasting index
for i, _indices in enumerate(indices): # First iterate over broadcasting dimension
for b, idx in enumerate(_indices): # Then iterate over bins dimension
basis_states, counts = qml.math.unique(idx, return_counts=True)
prob[i, basis_states, b] = counts / bin_size

return prob
indices = indices.reshape((batch_size, num_bins, bin_size))
probabilities = qml.math.array(
[_count_samples_core(_indices, dim, interface) for _indices in indices],
like=interface,
)
return probabilities / bin_size
10 changes: 6 additions & 4 deletions tests/devices/default_qubit/test_default_qubit_native_mcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,11 +667,12 @@ def test_jax_jit(diff_method, postselect, reset):
def func(x, y, z):
m0, m1 = obs_tape(x, y, z, reset=reset, postselect=postselect)
return (
# qml.probs(wires=[1]), # JAX cannot compile code calling qml.math.unique
qml.probs(wires=[1]),
qml.probs(wires=[0, 1]),
qml.sample(wires=[1]),
qml.sample(wires=[0, 1]),
qml.expval(obs),
# qml.probs(obs), # JAX cannot compile code calling qml.math.unique
qml.probs(obs),
qml.sample(obs),
qml.var(obs),
qml.expval(op=m0 + 2 * m1),
Expand All @@ -695,11 +696,12 @@ def func(x, y, z):
results2 = func2(*jax.numpy.array(params))

measures = [
# qml.probs,
qml.probs,
qml.probs,
qml.sample,
qml.sample,
qml.expval,
# qml.probs,
qml.probs,
qml.sample,
qml.var,
qml.expval,
Expand Down
32 changes: 32 additions & 0 deletions tests/measurements/test_probs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for the probs module"""
from typing import Sequence

import numpy as np
import pytest

Expand Down Expand Up @@ -226,6 +228,36 @@ def circuit():
expected = np.array([0.5, 0.5, 0, 0])
assert np.allclose(res, expected, atol=tol, rtol=0)

@pytest.mark.jax
@pytest.mark.parametrize("shots", (None, 500))
@pytest.mark.parametrize("obs", ([0, 1], qml.PauliZ(0) @ qml.PauliZ(1)))
@pytest.mark.parametrize("params", ([np.pi / 2], [np.pi / 2, np.pi / 2, np.pi / 2]))
def test_integration_jax(self, tol_stochastic, shots, obs, params):
"""Test the probability is correct for a known state preparation when jitted with JAX."""
jax = pytest.importorskip("jax")

dev = qml.device("default.qubit", wires=2, shots=shots, seed=jax.random.PRNGKey(0))
params = jax.numpy.array(params)

@qml.qnode(dev, diff_method=None)
def circuit(x):
qml.PhaseShift(x, wires=1)
qml.RX(x, wires=1)
qml.PhaseShift(x, wires=1)
qml.CNOT(wires=[0, 1])
if isinstance(obs, Sequence):
return qml.probs(wires=obs)
return qml.probs(op=obs)

# expected probability, using [00, 01, 10, 11]
# ordering, is [0.5, 0.5, 0, 0]

assert "pure_callback" not in str(jax.make_jaxpr(circuit)(params))

res = jax.jit(circuit)(params)
expected = np.array([0.5, 0.5, 0, 0])
assert np.allclose(res, expected, atol=tol_stochastic, rtol=0)

@pytest.mark.parametrize("shots", [100, [1, 10, 100]])
def test_integration_analytic_false(self, tol, shots):
"""Test the probability is correct for a known state preparation when the
Expand Down

0 comments on commit 1d34de9

Please sign in to comment.