diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 291e91c1780..363d9493818 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -92,6 +92,9 @@
Bug fixes 🐛
+* Finite shot circuits with a `qml.probs` measurement, both with a `wires` or `op` argument, can now be compiled with `jax.jit`.
+ [(#5619)](https://github.com/PennyLaneAI/pennylane/pull/5619)
+
* `param_shift`, `finite_diff`, `compile`, `merge_rotations`, and `transpile` now all work
with circuits with non-commuting measurements.
[(#5424)](https://github.com/PennyLaneAI/pennylane/pull/5424)
diff --git a/pennylane/measurements/probs.py b/pennylane/measurements/probs.py
index 5163daabbc8..a0ceb468f80 100644
--- a/pennylane/measurements/probs.py
+++ b/pennylane/measurements/probs.py
@@ -259,23 +259,33 @@ def _count_samples(indices, batch_size, dim):
"""Count the occurrences of sampled indices and convert them to relative
counts in order to estimate their occurrence probability."""
num_bins, bin_size = indices.shape[-2:]
- if batch_size is None:
- prob = qml.math.zeros((dim, num_bins), dtype="float64")
- # count the basis state occurrences, and construct the probability vector for each bin
- for b, idx in enumerate(indices):
- basis_states, counts = qml.math.unique(idx, return_counts=True)
- prob[basis_states, b] = counts / bin_size
+ interface = qml.math.get_deep_interface(indices)
- return prob
+ if qml.math.is_abstract(indices):
- prob = qml.math.zeros((batch_size, dim, num_bins), dtype="float64")
- indices = indices.reshape((batch_size, num_bins, bin_size))
+ def _count_samples_core(indices, dim, interface):
+ return qml.math.array(
+ [[qml.math.sum(idx == p) for idx in indices] for p in range(dim)],
+ like=interface,
+ )
+
+ else:
+
+ def _count_samples_core(indices, dim, *_):
+ probabilities = qml.math.zeros((dim, num_bins), dtype="float64")
+ for b, idx in enumerate(indices):
+ basis_states, counts = qml.math.unique(idx, return_counts=True)
+ probabilities[basis_states, b] = counts
+ return probabilities
+
+ if batch_size is None:
+ return _count_samples_core(indices, dim, interface) / bin_size
# count the basis state occurrences, and construct the probability vector
# for each bin and broadcasting index
- for i, _indices in enumerate(indices): # First iterate over broadcasting dimension
- for b, idx in enumerate(_indices): # Then iterate over bins dimension
- basis_states, counts = qml.math.unique(idx, return_counts=True)
- prob[i, basis_states, b] = counts / bin_size
-
- return prob
+ indices = indices.reshape((batch_size, num_bins, bin_size))
+ probabilities = qml.math.array(
+ [_count_samples_core(_indices, dim, interface) for _indices in indices],
+ like=interface,
+ )
+ return probabilities / bin_size
diff --git a/tests/devices/default_qubit/test_default_qubit_native_mcm.py b/tests/devices/default_qubit/test_default_qubit_native_mcm.py
index 1aff8c1f0b6..691dcf3f69b 100644
--- a/tests/devices/default_qubit/test_default_qubit_native_mcm.py
+++ b/tests/devices/default_qubit/test_default_qubit_native_mcm.py
@@ -667,11 +667,12 @@ def test_jax_jit(diff_method, postselect, reset):
def func(x, y, z):
m0, m1 = obs_tape(x, y, z, reset=reset, postselect=postselect)
return (
- # qml.probs(wires=[1]), # JAX cannot compile code calling qml.math.unique
+ qml.probs(wires=[1]),
+ qml.probs(wires=[0, 1]),
qml.sample(wires=[1]),
qml.sample(wires=[0, 1]),
qml.expval(obs),
- # qml.probs(obs), # JAX cannot compile code calling qml.math.unique
+ qml.probs(obs),
qml.sample(obs),
qml.var(obs),
qml.expval(op=m0 + 2 * m1),
@@ -695,11 +696,12 @@ def func(x, y, z):
results2 = func2(*jax.numpy.array(params))
measures = [
- # qml.probs,
+ qml.probs,
+ qml.probs,
qml.sample,
qml.sample,
qml.expval,
- # qml.probs,
+ qml.probs,
qml.sample,
qml.var,
qml.expval,
diff --git a/tests/measurements/test_probs.py b/tests/measurements/test_probs.py
index d666bf469fb..014465652dc 100644
--- a/tests/measurements/test_probs.py
+++ b/tests/measurements/test_probs.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for the probs module"""
+from typing import Sequence
+
import numpy as np
import pytest
@@ -226,6 +228,36 @@ def circuit():
expected = np.array([0.5, 0.5, 0, 0])
assert np.allclose(res, expected, atol=tol, rtol=0)
+ @pytest.mark.jax
+ @pytest.mark.parametrize("shots", (None, 500))
+ @pytest.mark.parametrize("obs", ([0, 1], qml.PauliZ(0) @ qml.PauliZ(1)))
+ @pytest.mark.parametrize("params", ([np.pi / 2], [np.pi / 2, np.pi / 2, np.pi / 2]))
+ def test_integration_jax(self, tol_stochastic, shots, obs, params):
+ """Test the probability is correct for a known state preparation when jitted with JAX."""
+ jax = pytest.importorskip("jax")
+
+ dev = qml.device("default.qubit", wires=2, shots=shots, seed=jax.random.PRNGKey(0))
+ params = jax.numpy.array(params)
+
+ @qml.qnode(dev, diff_method=None)
+ def circuit(x):
+ qml.PhaseShift(x, wires=1)
+ qml.RX(x, wires=1)
+ qml.PhaseShift(x, wires=1)
+ qml.CNOT(wires=[0, 1])
+ if isinstance(obs, Sequence):
+ return qml.probs(wires=obs)
+ return qml.probs(op=obs)
+
+ # expected probability, using [00, 01, 10, 11]
+ # ordering, is [0.5, 0.5, 0, 0]
+
+ assert "pure_callback" not in str(jax.make_jaxpr(circuit)(params))
+
+ res = jax.jit(circuit)(params)
+ expected = np.array([0.5, 0.5, 0, 0])
+ assert np.allclose(res, expected, atol=tol_stochastic, rtol=0)
+
@pytest.mark.parametrize("shots", [100, [1, 10, 100]])
def test_integration_analytic_false(self, tol, shots):
"""Test the probability is correct for a known state preparation when the