Skip to content

Commit

Permalink
Fixed bug with adding a shot vector (#2028)
Browse files Browse the repository at this point in the history
* Fixed bug with adding a shot vector

* changelog

* temp commit, while I work on something else

* fixed?

* lint

* added check for list type

* minor type

* using try-except instead to deal with lst wrapper issue

* fixed jax import

* Update doc/releases/changelog-dev.md

Co-authored-by: antalszava <antalszava@gmail.com>

* moved logic to qml.math

* New fix

* clean up

* try this?

* Apply suggestions from code review

* Update pennylane/_qubit_device.py

* lint to fix formatting

Co-authored-by: antalszava <antalszava@gmail.com>
  • Loading branch information
Jaybsoni and antalszava committed Jan 10, 2022
1 parent 208095a commit c9211f8
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 19 deletions.
7 changes: 6 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,11 @@
through the `unitary_to_rot` optimization transform.
[(#2015)](https://github.com/PennyLaneAI/pennylane/pull/2015)

* Fixes a bug which allows using `jax.jit` to be compatible with circuits
which return `qml.probs` when the `default.qubit.jax` is provided with a custom shot
vector.
[(#2028)](https://github.com/PennyLaneAI/pennylane/pull/2028)

<h3>Documentation</h3>

* Fixes an error in the signs of equations in the `DoubleExcitation` page.
Expand All @@ -245,4 +250,4 @@

This release contains contributions from (in alphabetical order):

Juan Miguel Arrazola, Ali Asadi, Esther Cruz, Olivia Di Matteo, Diego Guala, Ankit Khandelwal, Antal Száva, David Wierichs, Shaoming Zhang
Juan Miguel Arrazola, Ali Asadi, Esther Cruz, Olivia Di Matteo, Diego Guala, Ankit Khandelwal, Jay Soni, Antal Száva, David Wierichs, Shaoming Zhang
6 changes: 5 additions & 1 deletion pennylane/_qubit_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,11 @@ def execute(self, circuit, **kwargs):
r = self.statistics(
circuit.observables, shot_range=[s1, s2], bin_size=shot_tuple.shots
)
r = qml.math.squeeze(r)

if qml.math._multi_dispatch(r) == "jax": # pylint: disable=protected-access
r = r[0]
else:
r = qml.math.squeeze(r)

if shot_tuple.copies > 1:
results.extend(r.T)
Expand Down
42 changes: 32 additions & 10 deletions pennylane/devices/default_qubit_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,18 +264,40 @@ def estimate_probability(self, wires=None, shot_range=None, bin_size=None):
indices = samples @ powers_of_two

if bin_size is not None:
raise ValueError(
"The default.qubit.jax device doesn't support getting probabilities when using a shot vector."
bins = len(samples) // bin_size

indices = indices.reshape((bins, -1))
prob = np.zeros(
[2 ** num_wires + 1, bins], dtype=jnp.float64
) # extend it to store 'filled values'
prob = qml.math.convert_like(prob, indices)

# count the basis state occurrences, and construct the probability vector
for b, idx in enumerate(indices):
idx = qml.math.convert_like(idx, indices)
basis_states, counts = qml.math.unique(
idx, return_counts=True, size=2 ** num_wires, fill_value=-1
)

for state, count in zip(basis_states, counts):
prob = prob.at[state, b].set(count / bin_size)

prob = jnp.resize(
prob, (2 ** num_wires, bins)
) # resize prob which discards the 'filled values'

else:
basis_states, counts = qml.math.unique(
indices, return_counts=True, size=2 ** num_wires, fill_value=-1
)
prob = np.zeros([2 ** num_wires + 1], dtype=jnp.float64)
prob = qml.math.convert_like(prob, indices)

basis_states, counts = qml.math.unique(
indices, return_counts=True, size=2 ** num_wires, fill_value=-1
)
prob = np.zeros([2 ** num_wires + 1], dtype=np.float64)
prob = qml.math.convert_like(prob, indices)
for state, count in zip(basis_states, counts):
prob = prob.at[state].set(count / len(samples))

for state, count in zip(basis_states, counts):
prob = prob.at[state].set(count / len(samples))
prob = jnp.resize(
prob, 2 ** num_wires
) # resize prob which discards the 'filled values'

prob = jnp.resize(prob, 2 ** num_wires) # resize prob which discards the 'filled values'
return self._asarray(prob, dtype=self.R_DTYPE)
12 changes: 5 additions & 7 deletions tests/devices/test_default_qubit_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,21 +192,19 @@ def circuit():
result = circuit()
assert jnp.allclose(result, expected, atol=tol)

def test_probs_jax_jit(self, tol):
"""Test that returning probs works with jax and jit"""
def test_custom_shots_probs_jax_jit(self, tol):
"""Test that returning probs works with jax and jit when using custom shot vector"""
dev = qml.device("default.qubit.jax", wires=1, shots=(2, 2))
expected = jnp.array([0.0, 1.0])
expected = jnp.array([[0.0, 1.0], [0.0, 1.0]])

@jax.jit
@qml.qnode(dev, interface="jax")
def circuit():
qml.PauliX(wires=0)
return qml.probs()

with pytest.raises(
ValueError, match="doesn't support getting probabilities when using a shot vector"
):
result = circuit()
result = circuit()
assert jnp.allclose(result, expected, atol=tol)

def test_sampling_with_jit(self):
"""Test that sampling works with a jax.jit"""
Expand Down

0 comments on commit c9211f8

Please sign in to comment.