diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index f62236c1f76..3c49a70f52d 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -232,6 +232,11 @@
through the `unitary_to_rot` optimization transform.
[(#2015)](https://github.com/PennyLaneAI/pennylane/pull/2015)
+* Fixes a bug which allows using `jax.jit` to be compatible with circuits
+ which return `qml.probs` when the `default.qubit.jax` is provided with a custom shot
+ vector.
+ [(#2028)](https://github.com/PennyLaneAI/pennylane/pull/2028)
+
Documentation
* Fixes an error in the signs of equations in the `DoubleExcitation` page.
@@ -245,4 +250,4 @@
This release contains contributions from (in alphabetical order):
-Juan Miguel Arrazola, Ali Asadi, Esther Cruz, Olivia Di Matteo, Diego Guala, Ankit Khandelwal, Antal Száva, David Wierichs, Shaoming Zhang
+Juan Miguel Arrazola, Ali Asadi, Esther Cruz, Olivia Di Matteo, Diego Guala, Ankit Khandelwal, Jay Soni, Antal Száva, David Wierichs, Shaoming Zhang
diff --git a/pennylane/_qubit_device.py b/pennylane/_qubit_device.py
index 0c2b7be4f43..f7f6a09e2b4 100644
--- a/pennylane/_qubit_device.py
+++ b/pennylane/_qubit_device.py
@@ -210,7 +210,11 @@ def execute(self, circuit, **kwargs):
r = self.statistics(
circuit.observables, shot_range=[s1, s2], bin_size=shot_tuple.shots
)
- r = qml.math.squeeze(r)
+
+ if qml.math._multi_dispatch(r) == "jax": # pylint: disable=protected-access
+ r = r[0]
+ else:
+ r = qml.math.squeeze(r)
if shot_tuple.copies > 1:
results.extend(r.T)
diff --git a/pennylane/devices/default_qubit_jax.py b/pennylane/devices/default_qubit_jax.py
index 1a218f077cf..933a5dadc4f 100644
--- a/pennylane/devices/default_qubit_jax.py
+++ b/pennylane/devices/default_qubit_jax.py
@@ -264,18 +264,40 @@ def estimate_probability(self, wires=None, shot_range=None, bin_size=None):
indices = samples @ powers_of_two
if bin_size is not None:
- raise ValueError(
- "The default.qubit.jax device doesn't support getting probabilities when using a shot vector."
+ bins = len(samples) // bin_size
+
+ indices = indices.reshape((bins, -1))
+ prob = np.zeros(
+ [2 ** num_wires + 1, bins], dtype=jnp.float64
+ ) # extend it to store 'filled values'
+ prob = qml.math.convert_like(prob, indices)
+
+ # count the basis state occurrences, and construct the probability vector
+ for b, idx in enumerate(indices):
+ idx = qml.math.convert_like(idx, indices)
+ basis_states, counts = qml.math.unique(
+ idx, return_counts=True, size=2 ** num_wires, fill_value=-1
+ )
+
+ for state, count in zip(basis_states, counts):
+ prob = prob.at[state, b].set(count / bin_size)
+
+ prob = jnp.resize(
+ prob, (2 ** num_wires, bins)
+ ) # resize prob which discards the 'filled values'
+
+ else:
+ basis_states, counts = qml.math.unique(
+ indices, return_counts=True, size=2 ** num_wires, fill_value=-1
)
+ prob = np.zeros([2 ** num_wires + 1], dtype=jnp.float64)
+ prob = qml.math.convert_like(prob, indices)
- basis_states, counts = qml.math.unique(
- indices, return_counts=True, size=2 ** num_wires, fill_value=-1
- )
- prob = np.zeros([2 ** num_wires + 1], dtype=np.float64)
- prob = qml.math.convert_like(prob, indices)
+ for state, count in zip(basis_states, counts):
+ prob = prob.at[state].set(count / len(samples))
- for state, count in zip(basis_states, counts):
- prob = prob.at[state].set(count / len(samples))
+ prob = jnp.resize(
+ prob, 2 ** num_wires
+ ) # resize prob which discards the 'filled values'
- prob = jnp.resize(prob, 2 ** num_wires) # resize prob which discards the 'filled values'
return self._asarray(prob, dtype=self.R_DTYPE)
diff --git a/tests/devices/test_default_qubit_jax.py b/tests/devices/test_default_qubit_jax.py
index 4c2221268d3..48c308d3e33 100644
--- a/tests/devices/test_default_qubit_jax.py
+++ b/tests/devices/test_default_qubit_jax.py
@@ -192,10 +192,10 @@ def circuit():
result = circuit()
assert jnp.allclose(result, expected, atol=tol)
- def test_probs_jax_jit(self, tol):
- """Test that returning probs works with jax and jit"""
+ def test_custom_shots_probs_jax_jit(self, tol):
+ """Test that returning probs works with jax and jit when using custom shot vector"""
dev = qml.device("default.qubit.jax", wires=1, shots=(2, 2))
- expected = jnp.array([0.0, 1.0])
+ expected = jnp.array([[0.0, 1.0], [0.0, 1.0]])
@jax.jit
@qml.qnode(dev, interface="jax")
@@ -203,10 +203,8 @@ def circuit():
qml.PauliX(wires=0)
return qml.probs()
- with pytest.raises(
- ValueError, match="doesn't support getting probabilities when using a shot vector"
- ):
- result = circuit()
+ result = circuit()
+ assert jnp.allclose(result, expected, atol=tol)
def test_sampling_with_jit(self):
"""Test that sampling works with a jax.jit"""