Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cases for qml.sample(..., counts=True) #2839

Merged
merged 21 commits into from Jul 26, 2022
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions doc/releases/changelog-dev.md
Expand Up @@ -180,6 +180,7 @@

* Samples can be grouped into counts by passing the `counts=True` flag to `qml.sample`.
[(#2686)](https://github.com/PennyLaneAI/pennylane/pull/2686)
[(#2839)](https://github.com/PennyLaneAI/pennylane/pull/2839)

Note that the change included creating a new `Counts` measurement type in `measurements.py`.

Expand Down Expand Up @@ -211,8 +212,7 @@
... return qml.sample(qml.PauliZ(0), counts=True), qml.sample(qml.PauliZ(1), counts=True)
>>> result = circuit()
>>> print(result)
[tensor({-1: 526, 1: 474}, dtype=object, requires_grad=True)
tensor({-1: 526, 1: 474}, dtype=object, requires_grad=True)]
({-1: 470, 1: 530}, {-1: 470, 1: 530})
rmoyard marked this conversation as resolved.
Show resolved Hide resolved
```

* The `qml.state` and `qml.density_matrix` measurements now support custom wire
Expand Down
40 changes: 28 additions & 12 deletions pennylane/_qubit_device.py
Expand Up @@ -264,6 +264,8 @@ def execute(self, circuit, **kwargs):
self._samples = self.generate_samples()

multiple_sampled_jobs = circuit.is_sampled and self._has_partitioned_shots()
ret_types = [m.return_type for m in circuit.measurements]
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
no_counts = all(ret is not qml.measurements.Counts for ret in ret_types)
antalszava marked this conversation as resolved.
Show resolved Hide resolved

# compute the required statistics
if not self.analytic and self._shot_vector is not None:
Expand All @@ -279,20 +281,29 @@ def execute(self, circuit, **kwargs):

if qml.math._multi_dispatch(r) == "jax": # pylint: disable=protected-access
r = r[0]
elif not isinstance(r[0], dict):
antalszava marked this conversation as resolved.
Show resolved Hide resolved
elif no_counts:
# Measurement types except for Counts
r = qml.math.squeeze(r)
if isinstance(r, (np.ndarray, list)) and r.shape and isinstance(r[0], dict):
# This happens when measurement type is Counts
results.append(r)

if not no_counts:

# This happens when at least one measurement type is Counts
for result_group in r:
if isinstance(result_group, list):
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
# List that contains one or more dictionaries
results.extend(result_group)
else:
# Other measurement results
results.append(result_group.T)

elif shot_tuple.copies > 1:
results.extend(r.T)
else:
results.append(r.T)

s1 = s2

if not multiple_sampled_jobs:
if not multiple_sampled_jobs and no_counts:
# Can only stack single element outputs
results = qml.math.stack(results)

Expand All @@ -301,8 +312,6 @@ def execute(self, circuit, **kwargs):

if not circuit.is_sampled:

ret_types = [m.return_type for m in circuit.measurements]

if len(circuit.measurements) == 1:
if circuit.measurements[0].return_type is qml.measurements.State:
# State: assumed to only be allowed if it's the only measurement
Expand All @@ -317,15 +326,18 @@ def execute(self, circuit, **kwargs):
):
# Measurements with expval or var
results = self._asarray(results, dtype=self.R_DTYPE)
elif any(ret is not qml.measurements.Counts for ret in ret_types):
# all the other cases except all counts
elif no_counts:
antalszava marked this conversation as resolved.
Show resolved Hide resolved
# all the other cases except any counts
results = self._asarray(results)

elif circuit.all_sampled and not self._has_partitioned_shots():

results = self._asarray(results)
else:
results = tuple(self._asarray(r) for r in results)
results = tuple(
rmoyard marked this conversation as resolved.
Show resolved Hide resolved
qml.math.squeeze(self._asarray(r)) if not isinstance(r, dict) else r
for r in results
)

# increment counter for number of executions of qubit device
self._num_executions += 1
Expand Down Expand Up @@ -1011,6 +1023,7 @@ def _samples_to_counts(samples, no_observable_provided):
# Before converting to str, we need to extract elements from arrays
# to satisfy the case of jax interface, as jax arrays do not support str.
samples = ["".join([str(s.item()) for s in sample]) for sample in samples]

states, counts = np.unique(samples, return_counts=True)
return dict(zip(states, counts))

Expand Down Expand Up @@ -1054,14 +1067,17 @@ def _samples_to_counts(samples, no_observable_provided):
if counts:
return _samples_to_counts(samples, no_observable_provided)
return samples

num_wires = len(device_wires) if len(device_wires) > 0 else self.num_wires
if counts:
shape = (-1, bin_size, 3) if no_observable_provided else (-1, bin_size)
shape = (-1, bin_size, num_wires) if no_observable_provided else (-1, bin_size)
rmoyard marked this conversation as resolved.
Show resolved Hide resolved
return [
_samples_to_counts(bin_sample, no_observable_provided)
for bin_sample in samples.reshape(shape)
]

return (
samples.reshape((3, bin_size, -1))
samples.reshape((num_wires, bin_size, -1))
if no_observable_provided
else samples.reshape((bin_size, -1))
)
Expand Down
9 changes: 3 additions & 6 deletions pennylane/interfaces/autograd.py
Expand Up @@ -111,15 +111,12 @@ def _execute(

for i, r in enumerate(res):

if any(m.return_type is qml.measurements.Counts for m in tapes[i].measurements):
antalszava marked this conversation as resolved.
Show resolved Hide resolved
continue

if isinstance(r, np.ndarray):
# For backwards compatibility, we flatten ragged tape outputs
# when there is no sampling
try:
if isinstance(r[0][0], dict):
# This happens when measurement type is Counts and shot vector is passed
continue
except (IndexError, KeyError):
antalszava marked this conversation as resolved.
Show resolved Hide resolved
pass
r = np.hstack(r) if r.dtype == np.dtype("object") else r
res[i] = np.tensor(r)

Expand Down
14 changes: 12 additions & 2 deletions pennylane/interfaces/jax.py
Expand Up @@ -169,16 +169,26 @@ def cp_tape(t, a):
tc.set_parameters(a)
return tc

def array_if_not_counts(tape, r):
"""Auxiliary function to convert the result of a tape to an array,
unless the tape had Counts measurements that are represented with
dictionaries. JAX NumPy arrays don't support dictionaries."""
return (
jnp.array(r)
if not any(m.return_type is qml.measurements.Counts for m in tape.measurements)
else r
)

@jax.custom_vjp
def wrapped_exec(params):
new_tapes = [cp_tape(t, a) for t, a in zip(tapes, params)]
with qml.tape.Unwrap(*new_tapes):
res, _ = execute_fn(new_tapes, **gradient_kwargs)

if len(tapes) > 1:
res = [jnp.array(r) for r in res]
res = [array_if_not_counts(tape, r) for tape, r in zip(tapes, res)]
else:
res = jnp.array(res)
res = array_if_not_counts(tapes[0], res)

return res

Expand Down
3 changes: 3 additions & 0 deletions pennylane/interfaces/tensorflow.py
Expand Up @@ -91,6 +91,9 @@ def execute(tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_d
for i, tape in enumerate(tapes):
# convert output to TensorFlow tensors

if any(m.return_type is qml.measurements.Counts for m in tape.measurements):
continue

if isinstance(res[i], np.ndarray):
# For backwards compatibility, we flatten ragged tape outputs
# when there is no sampling
Expand Down
3 changes: 3 additions & 0 deletions pennylane/interfaces/torch.py
Expand Up @@ -97,6 +97,9 @@ def forward(ctx, kwargs, *parameters): # pylint: disable=arguments-differ
# For backwards compatibility, we flatten ragged tape outputs
r = np.hstack(r)

if any(m.return_type is qml.measurements.Counts for m in ctx.tapes[i].measurements):
continue

if isinstance(r, (list, tuple)):
res[i] = [torch.as_tensor(t) for t in r]

Expand Down
11 changes: 5 additions & 6 deletions pennylane/measurements.py
Expand Up @@ -654,21 +654,20 @@ def circuit(x):

.. code-block:: python3

dev = qml.device('default.qubit', wires=3, shots=10)
dev = qml.device("default.qubit", wires=3, shots=10)


@qml.qnode(dev)
def my_circ():
qml.Hadamard(wires=0)
qml.CNOT(wires=[0,1])
qml.CNOT(wires=[0, 1])
qml.PauliX(wires=2)
return qml.sample(qml.PauliZ(0), counts = True), qml.sample(counts=True)
return qml.sample(qml.PauliZ(0), counts=True), qml.sample(counts=True)
rmoyard marked this conversation as resolved.
Show resolved Hide resolved

Executing this QNode:

>>> my_circ()
tensor([tensor({-1: 5, 1: 5}, dtype=object, requires_grad=True),
tensor({'001': 5, '111': 5}, dtype=object, requires_grad=True)],
dtype=object, requires_grad=True)
({-1: 3, 1: 7}, {'001': 7, '111': 3})

.. note::

Expand Down
22 changes: 19 additions & 3 deletions pennylane/qnode.py
Expand Up @@ -640,6 +640,25 @@ def __call__(self, *args, **kwargs):

res = res[0]

if (
not isinstance(self._qfunc_output, Sequence)
and self._qfunc_output.return_type is qml.measurements.Counts
):

if not self.device._has_partitioned_shots():
# return a dictionary with counts not as a single-element array
return res[0]

return tuple(res)

if isinstance(self._qfunc_output, Sequence) and qml.measurements.Counts in set(
out.return_type for out in self._qfunc_output
):
antalszava marked this conversation as resolved.
Show resolved Hide resolved
# If Counts was returned with other measurements, then apply the
# data structure used in the qfunc
qfunc_output_type = type(self._qfunc_output)
return qfunc_output_type(res)

if override_shots is not False:
# restore the initialization gradient function
self.gradient_fn, self.gradient_kwargs, self.device = original_grad_fn
Expand All @@ -650,9 +669,6 @@ def __call__(self, *args, **kwargs):
self.tape.is_sampled and self.device._has_partitioned_shots()
):
return res
if self._qfunc_output.return_type is qml.measurements.Counts:
# return a dictionary with counts not as a single-element array
return res[0]

return qml.math.squeeze(res)

Expand Down