diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 291e91c1780..363d9493818 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -92,6 +92,9 @@

Bug fixes 🐛

+* Finite shot circuits with a `qml.probs` measurement, both with a `wires` or `op` argument, can now be compiled with `jax.jit`. + [(#5619)](https://github.com/PennyLaneAI/pennylane/pull/5619) + * `param_shift`, `finite_diff`, `compile`, `merge_rotations`, and `transpile` now all work with circuits with non-commuting measurements. [(#5424)](https://github.com/PennyLaneAI/pennylane/pull/5424) diff --git a/pennylane/measurements/probs.py b/pennylane/measurements/probs.py index 5163daabbc8..a0ceb468f80 100644 --- a/pennylane/measurements/probs.py +++ b/pennylane/measurements/probs.py @@ -259,23 +259,33 @@ def _count_samples(indices, batch_size, dim): """Count the occurrences of sampled indices and convert them to relative counts in order to estimate their occurrence probability.""" num_bins, bin_size = indices.shape[-2:] - if batch_size is None: - prob = qml.math.zeros((dim, num_bins), dtype="float64") - # count the basis state occurrences, and construct the probability vector for each bin - for b, idx in enumerate(indices): - basis_states, counts = qml.math.unique(idx, return_counts=True) - prob[basis_states, b] = counts / bin_size + interface = qml.math.get_deep_interface(indices) - return prob + if qml.math.is_abstract(indices): - prob = qml.math.zeros((batch_size, dim, num_bins), dtype="float64") - indices = indices.reshape((batch_size, num_bins, bin_size)) + def _count_samples_core(indices, dim, interface): + return qml.math.array( + [[qml.math.sum(idx == p) for idx in indices] for p in range(dim)], + like=interface, + ) + + else: + + def _count_samples_core(indices, dim, *_): + probabilities = qml.math.zeros((dim, num_bins), dtype="float64") + for b, idx in enumerate(indices): + basis_states, counts = qml.math.unique(idx, return_counts=True) + probabilities[basis_states, b] = counts + return probabilities + + if batch_size is None: + return _count_samples_core(indices, dim, interface) / bin_size # count the basis state occurrences, and construct the probability vector # for each bin and broadcasting index - for i, _indices in enumerate(indices): # First iterate over broadcasting dimension - for b, idx in enumerate(_indices): # Then iterate over bins dimension - basis_states, counts = qml.math.unique(idx, return_counts=True) - prob[i, basis_states, b] = counts / bin_size - - return prob + indices = indices.reshape((batch_size, num_bins, bin_size)) + probabilities = qml.math.array( + [_count_samples_core(_indices, dim, interface) for _indices in indices], + like=interface, + ) + return probabilities / bin_size diff --git a/tests/devices/default_qubit/test_default_qubit_native_mcm.py b/tests/devices/default_qubit/test_default_qubit_native_mcm.py index 1aff8c1f0b6..691dcf3f69b 100644 --- a/tests/devices/default_qubit/test_default_qubit_native_mcm.py +++ b/tests/devices/default_qubit/test_default_qubit_native_mcm.py @@ -667,11 +667,12 @@ def test_jax_jit(diff_method, postselect, reset): def func(x, y, z): m0, m1 = obs_tape(x, y, z, reset=reset, postselect=postselect) return ( - # qml.probs(wires=[1]), # JAX cannot compile code calling qml.math.unique + qml.probs(wires=[1]), + qml.probs(wires=[0, 1]), qml.sample(wires=[1]), qml.sample(wires=[0, 1]), qml.expval(obs), - # qml.probs(obs), # JAX cannot compile code calling qml.math.unique + qml.probs(obs), qml.sample(obs), qml.var(obs), qml.expval(op=m0 + 2 * m1), @@ -695,11 +696,12 @@ def func(x, y, z): results2 = func2(*jax.numpy.array(params)) measures = [ - # qml.probs, + qml.probs, + qml.probs, qml.sample, qml.sample, qml.expval, - # qml.probs, + qml.probs, qml.sample, qml.var, qml.expval, diff --git a/tests/measurements/test_probs.py b/tests/measurements/test_probs.py index d666bf469fb..014465652dc 100644 --- a/tests/measurements/test_probs.py +++ b/tests/measurements/test_probs.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Unit tests for the probs module""" +from typing import Sequence + import numpy as np import pytest @@ -226,6 +228,36 @@ def circuit(): expected = np.array([0.5, 0.5, 0, 0]) assert np.allclose(res, expected, atol=tol, rtol=0) + @pytest.mark.jax + @pytest.mark.parametrize("shots", (None, 500)) + @pytest.mark.parametrize("obs", ([0, 1], qml.PauliZ(0) @ qml.PauliZ(1))) + @pytest.mark.parametrize("params", ([np.pi / 2], [np.pi / 2, np.pi / 2, np.pi / 2])) + def test_integration_jax(self, tol_stochastic, shots, obs, params): + """Test the probability is correct for a known state preparation when jitted with JAX.""" + jax = pytest.importorskip("jax") + + dev = qml.device("default.qubit", wires=2, shots=shots, seed=jax.random.PRNGKey(0)) + params = jax.numpy.array(params) + + @qml.qnode(dev, diff_method=None) + def circuit(x): + qml.PhaseShift(x, wires=1) + qml.RX(x, wires=1) + qml.PhaseShift(x, wires=1) + qml.CNOT(wires=[0, 1]) + if isinstance(obs, Sequence): + return qml.probs(wires=obs) + return qml.probs(op=obs) + + # expected probability, using [00, 01, 10, 11] + # ordering, is [0.5, 0.5, 0, 0] + + assert "pure_callback" not in str(jax.make_jaxpr(circuit)(params)) + + res = jax.jit(circuit)(params) + expected = np.array([0.5, 0.5, 0, 0]) + assert np.allclose(res, expected, atol=tol_stochastic, rtol=0) + @pytest.mark.parametrize("shots", [100, [1, 10, 100]]) def test_integration_analytic_false(self, tol, shots): """Test the probability is correct for a known state preparation when the