Skip to content

Commit

Permalink
Fix cases for qml.sample(..., counts=True) (#2839)
Browse files Browse the repository at this point in the history
* draft

* reworked tests

* apply same fix as for qutrit dev PR; extend tests

* no prints

* linting; reverting len compute that caused issues

* changelog

* changes

* multi-measure cases with Counts

* no need for different sample handling case

* update object name

* Update pennylane/_qubit_device.py

Co-authored-by: Christina Lee <christina@xanadu.ai>

* Apply suggestions from code review

Co-authored-by: Christina Lee <christina@xanadu.ai>

* counts_exist; revert to making ret_types a list because the generator approach pops items on first use

* format

Co-authored-by: Romain Moyard <rmoyard@gmail.com>
Co-authored-by: Christina Lee <christina@xanadu.ai>
  • Loading branch information
3 people committed Jul 26, 2022
1 parent 52d656d commit cad746f
Show file tree
Hide file tree
Showing 9 changed files with 216 additions and 177 deletions.
4 changes: 2 additions & 2 deletions doc/releases/changelog-dev.md
Expand Up @@ -219,6 +219,7 @@ of operators. [(#2622)](https://github.com/PennyLaneAI/pennylane/pull/2622)

* Samples can be grouped into counts by passing the `counts=True` flag to `qml.sample`.
[(#2686)](https://github.com/PennyLaneAI/pennylane/pull/2686)
[(#2839)](https://github.com/PennyLaneAI/pennylane/pull/2839)

Note that the change included creating a new `Counts` measurement type in `measurements.py`.

Expand Down Expand Up @@ -250,8 +251,7 @@ of operators. [(#2622)](https://github.com/PennyLaneAI/pennylane/pull/2622)
... return qml.sample(qml.PauliZ(0), counts=True), qml.sample(qml.PauliZ(1), counts=True)
>>> result = circuit()
>>> print(result)
[tensor({-1: 526, 1: 474}, dtype=object, requires_grad=True)
tensor({-1: 526, 1: 474}, dtype=object, requires_grad=True)]
({-1: 470, 1: 530}, {-1: 470, 1: 530})
```

* The `qml.state` and `qml.density_matrix` measurements now support custom wire
Expand Down
40 changes: 28 additions & 12 deletions pennylane/_qubit_device.py
Expand Up @@ -265,6 +265,8 @@ def execute(self, circuit, **kwargs):
self._samples = self.generate_samples()

multiple_sampled_jobs = circuit.is_sampled and self._has_partitioned_shots()
ret_types = [m.return_type for m in circuit.measurements]
counts_exist = any(ret is qml.measurements.Counts for ret in ret_types)

# compute the required statistics
if not self.analytic and self._shot_vector is not None:
Expand All @@ -280,20 +282,29 @@ def execute(self, circuit, **kwargs):

if qml.math._multi_dispatch(r) == "jax": # pylint: disable=protected-access
r = r[0]
elif not isinstance(r[0], dict):
elif not counts_exist:
# Measurement types except for Counts
r = qml.math.squeeze(r)
if isinstance(r, (np.ndarray, list)) and r.shape and isinstance(r[0], dict):
# This happens when measurement type is Counts
results.append(r)

if counts_exist:

# This happens when at least one measurement type is Counts
for result_group in r:
if isinstance(result_group, list):
# List that contains one or more dictionaries
results.extend(result_group)
else:
# Other measurement results
results.append(result_group.T)

elif shot_tuple.copies > 1:
results.extend(r.T)
else:
results.append(r.T)

s1 = s2

if not multiple_sampled_jobs:
if not multiple_sampled_jobs and not counts_exist:
# Can only stack single element outputs
results = qml.math.stack(results)

Expand All @@ -302,8 +313,6 @@ def execute(self, circuit, **kwargs):

if not circuit.is_sampled:

ret_types = [m.return_type for m in circuit.measurements]

if len(circuit.measurements) == 1:
if circuit.measurements[0].return_type is qml.measurements.State:
# State: assumed to only be allowed if it's the only measurement
Expand All @@ -318,15 +327,18 @@ def execute(self, circuit, **kwargs):
):
# Measurements with expval or var
results = self._asarray(results, dtype=self.R_DTYPE)
elif any(ret is not qml.measurements.Counts for ret in ret_types):
# all the other cases except all counts
elif not counts_exist:
# all the other cases except any counts
results = self._asarray(results)

elif circuit.all_sampled and not self._has_partitioned_shots():

results = self._asarray(results)
else:
results = tuple(self._asarray(r) for r in results)
results = tuple(
qml.math.squeeze(self._asarray(r)) if not isinstance(r, dict) else r
for r in results
)

# increment counter for number of executions of qubit device
self._num_executions += 1
Expand Down Expand Up @@ -1012,6 +1024,7 @@ def _samples_to_counts(samples, no_observable_provided):
# Before converting to str, we need to extract elements from arrays
# to satisfy the case of jax interface, as jax arrays do not support str.
samples = ["".join([str(s.item()) for s in sample]) for sample in samples]

states, counts = np.unique(samples, return_counts=True)
return dict(zip(states, counts))

Expand Down Expand Up @@ -1055,14 +1068,17 @@ def _samples_to_counts(samples, no_observable_provided):
if counts:
return _samples_to_counts(samples, no_observable_provided)
return samples

num_wires = len(device_wires) if len(device_wires) > 0 else self.num_wires
if counts:
shape = (-1, bin_size, 3) if no_observable_provided else (-1, bin_size)
shape = (-1, bin_size, num_wires) if no_observable_provided else (-1, bin_size)
return [
_samples_to_counts(bin_sample, no_observable_provided)
for bin_sample in samples.reshape(shape)
]

return (
samples.reshape((3, bin_size, -1))
samples.reshape((num_wires, bin_size, -1))
if no_observable_provided
else samples.reshape((bin_size, -1))
)
Expand Down
9 changes: 3 additions & 6 deletions pennylane/interfaces/autograd.py
Expand Up @@ -111,15 +111,12 @@ def _execute(

for i, r in enumerate(res):

if any(m.return_type is qml.measurements.Counts for m in tapes[i].measurements):
continue

if isinstance(r, np.ndarray):
# For backwards compatibility, we flatten ragged tape outputs
# when there is no sampling
try:
if isinstance(r[0][0], dict):
# This happens when measurement type is Counts and shot vector is passed
continue
except (IndexError, KeyError):
pass
r = np.hstack(r) if r.dtype == np.dtype("object") else r
res[i] = np.tensor(r)

Expand Down
14 changes: 12 additions & 2 deletions pennylane/interfaces/jax.py
Expand Up @@ -169,16 +169,26 @@ def cp_tape(t, a):
tc.set_parameters(a)
return tc

def array_if_not_counts(tape, r):
"""Auxiliary function to convert the result of a tape to an array,
unless the tape had Counts measurements that are represented with
dictionaries. JAX NumPy arrays don't support dictionaries."""
return (
jnp.array(r)
if not any(m.return_type is qml.measurements.Counts for m in tape.measurements)
else r
)

@jax.custom_vjp
def wrapped_exec(params):
new_tapes = [cp_tape(t, a) for t, a in zip(tapes, params)]
with qml.tape.Unwrap(*new_tapes):
res, _ = execute_fn(new_tapes, **gradient_kwargs)

if len(tapes) > 1:
res = [jnp.array(r) for r in res]
res = [array_if_not_counts(tape, r) for tape, r in zip(tapes, res)]
else:
res = jnp.array(res)
res = array_if_not_counts(tapes[0], res)

return res

Expand Down
3 changes: 3 additions & 0 deletions pennylane/interfaces/tensorflow.py
Expand Up @@ -91,6 +91,9 @@ def execute(tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_d
for i, tape in enumerate(tapes):
# convert output to TensorFlow tensors

if any(m.return_type is qml.measurements.Counts for m in tape.measurements):
continue

if isinstance(res[i], np.ndarray):
# For backwards compatibility, we flatten ragged tape outputs
# when there is no sampling
Expand Down
3 changes: 3 additions & 0 deletions pennylane/interfaces/torch.py
Expand Up @@ -97,6 +97,9 @@ def forward(ctx, kwargs, *parameters): # pylint: disable=arguments-differ
# For backwards compatibility, we flatten ragged tape outputs
r = np.hstack(r)

if any(m.return_type is qml.measurements.Counts for m in ctx.tapes[i].measurements):
continue

if isinstance(r, (list, tuple)):
res[i] = [torch.as_tensor(t) for t in r]

Expand Down
11 changes: 5 additions & 6 deletions pennylane/measurements.py
Expand Up @@ -654,21 +654,20 @@ def circuit(x):
.. code-block:: python3
dev = qml.device('default.qubit', wires=3, shots=10)
dev = qml.device("default.qubit", wires=3, shots=10)
@qml.qnode(dev)
def my_circ():
qml.Hadamard(wires=0)
qml.CNOT(wires=[0,1])
qml.CNOT(wires=[0, 1])
qml.PauliX(wires=2)
return qml.sample(qml.PauliZ(0), counts = True), qml.sample(counts=True)
return qml.sample(qml.PauliZ(0), counts=True), qml.sample(counts=True)
Executing this QNode:
>>> my_circ()
tensor([tensor({-1: 5, 1: 5}, dtype=object, requires_grad=True),
tensor({'001': 5, '111': 5}, dtype=object, requires_grad=True)],
dtype=object, requires_grad=True)
({-1: 3, 1: 7}, {'001': 7, '111': 3})
.. note::
Expand Down
23 changes: 20 additions & 3 deletions pennylane/qnode.py
Expand Up @@ -640,6 +640,26 @@ def __call__(self, *args, **kwargs):

res = res[0]

if (
not isinstance(self._qfunc_output, Sequence)
and self._qfunc_output.return_type is qml.measurements.Counts
):

if not self.device._has_partitioned_shots():
# return a dictionary with counts not as a single-element array
return res[0]

return tuple(res)

if isinstance(self._qfunc_output, Sequence) and any(
m.return_type is qml.measurements.Counts for m in self._qfunc_output
):

# If Counts was returned with other measurements, then apply the
# data structure used in the qfunc
qfunc_output_type = type(self._qfunc_output)
return qfunc_output_type(res)

if override_shots is not False:
# restore the initialization gradient function
self.gradient_fn, self.gradient_kwargs, self.device = original_grad_fn
Expand All @@ -650,9 +670,6 @@ def __call__(self, *args, **kwargs):
self.tape.is_sampled and self.device._has_partitioned_shots()
):
return res
if self._qfunc_output.return_type is qml.measurements.Counts:
# return a dictionary with counts not as a single-element array
return res[0]

return qml.math.squeeze(res)

Expand Down

0 comments on commit cad746f

Please sign in to comment.