diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index f62236c1f76..3c49a70f52d 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -232,6 +232,11 @@ through the `unitary_to_rot` optimization transform. [(#2015)](https://github.com/PennyLaneAI/pennylane/pull/2015) +* Fixes a bug which allows using `jax.jit` to be compatible with circuits + which return `qml.probs` when the `default.qubit.jax` is provided with a custom shot + vector. + [(#2028)](https://github.com/PennyLaneAI/pennylane/pull/2028) +

Documentation

* Fixes an error in the signs of equations in the `DoubleExcitation` page. @@ -245,4 +250,4 @@ This release contains contributions from (in alphabetical order): -Juan Miguel Arrazola, Ali Asadi, Esther Cruz, Olivia Di Matteo, Diego Guala, Ankit Khandelwal, Antal Száva, David Wierichs, Shaoming Zhang +Juan Miguel Arrazola, Ali Asadi, Esther Cruz, Olivia Di Matteo, Diego Guala, Ankit Khandelwal, Jay Soni, Antal Száva, David Wierichs, Shaoming Zhang diff --git a/pennylane/_qubit_device.py b/pennylane/_qubit_device.py index 0c2b7be4f43..f7f6a09e2b4 100644 --- a/pennylane/_qubit_device.py +++ b/pennylane/_qubit_device.py @@ -210,7 +210,11 @@ def execute(self, circuit, **kwargs): r = self.statistics( circuit.observables, shot_range=[s1, s2], bin_size=shot_tuple.shots ) - r = qml.math.squeeze(r) + + if qml.math._multi_dispatch(r) == "jax": # pylint: disable=protected-access + r = r[0] + else: + r = qml.math.squeeze(r) if shot_tuple.copies > 1: results.extend(r.T) diff --git a/pennylane/devices/default_qubit_jax.py b/pennylane/devices/default_qubit_jax.py index 1a218f077cf..933a5dadc4f 100644 --- a/pennylane/devices/default_qubit_jax.py +++ b/pennylane/devices/default_qubit_jax.py @@ -264,18 +264,40 @@ def estimate_probability(self, wires=None, shot_range=None, bin_size=None): indices = samples @ powers_of_two if bin_size is not None: - raise ValueError( - "The default.qubit.jax device doesn't support getting probabilities when using a shot vector." + bins = len(samples) // bin_size + + indices = indices.reshape((bins, -1)) + prob = np.zeros( + [2 ** num_wires + 1, bins], dtype=jnp.float64 + ) # extend it to store 'filled values' + prob = qml.math.convert_like(prob, indices) + + # count the basis state occurrences, and construct the probability vector + for b, idx in enumerate(indices): + idx = qml.math.convert_like(idx, indices) + basis_states, counts = qml.math.unique( + idx, return_counts=True, size=2 ** num_wires, fill_value=-1 + ) + + for state, count in zip(basis_states, counts): + prob = prob.at[state, b].set(count / bin_size) + + prob = jnp.resize( + prob, (2 ** num_wires, bins) + ) # resize prob which discards the 'filled values' + + else: + basis_states, counts = qml.math.unique( + indices, return_counts=True, size=2 ** num_wires, fill_value=-1 ) + prob = np.zeros([2 ** num_wires + 1], dtype=jnp.float64) + prob = qml.math.convert_like(prob, indices) - basis_states, counts = qml.math.unique( - indices, return_counts=True, size=2 ** num_wires, fill_value=-1 - ) - prob = np.zeros([2 ** num_wires + 1], dtype=np.float64) - prob = qml.math.convert_like(prob, indices) + for state, count in zip(basis_states, counts): + prob = prob.at[state].set(count / len(samples)) - for state, count in zip(basis_states, counts): - prob = prob.at[state].set(count / len(samples)) + prob = jnp.resize( + prob, 2 ** num_wires + ) # resize prob which discards the 'filled values' - prob = jnp.resize(prob, 2 ** num_wires) # resize prob which discards the 'filled values' return self._asarray(prob, dtype=self.R_DTYPE) diff --git a/tests/devices/test_default_qubit_jax.py b/tests/devices/test_default_qubit_jax.py index 4c2221268d3..48c308d3e33 100644 --- a/tests/devices/test_default_qubit_jax.py +++ b/tests/devices/test_default_qubit_jax.py @@ -192,10 +192,10 @@ def circuit(): result = circuit() assert jnp.allclose(result, expected, atol=tol) - def test_probs_jax_jit(self, tol): - """Test that returning probs works with jax and jit""" + def test_custom_shots_probs_jax_jit(self, tol): + """Test that returning probs works with jax and jit when using custom shot vector""" dev = qml.device("default.qubit.jax", wires=1, shots=(2, 2)) - expected = jnp.array([0.0, 1.0]) + expected = jnp.array([[0.0, 1.0], [0.0, 1.0]]) @jax.jit @qml.qnode(dev, interface="jax") @@ -203,10 +203,8 @@ def circuit(): qml.PauliX(wires=0) return qml.probs() - with pytest.raises( - ValueError, match="doesn't support getting probabilities when using a shot vector" - ): - result = circuit() + result = circuit() + assert jnp.allclose(result, expected, atol=tol) def test_sampling_with_jit(self): """Test that sampling works with a jax.jit"""