Skip to content

Commit

Permalink
Improve validation in the cutting reconstruction function (#581)
Browse files Browse the repository at this point in the history
* Improve validation in the cutting reconstruction function

* Fix lint by adding a type annotation

* Bring coverage to 100%
  • Loading branch information
garrison committed Jun 3, 2024
1 parent 3162cfd commit 67f24cb
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 19 deletions.
56 changes: 39 additions & 17 deletions circuit_knitting/cutting/cutting_reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from __future__ import annotations

from collections.abc import Sequence, Hashable
from collections.abc import Sequence, Hashable, Mapping

import numpy as np
from qiskit.quantum_info import PauliList
Expand Down Expand Up @@ -75,40 +75,62 @@ def reconstruct_expectation_values(
ValueError: ``observables`` and ``results`` are of incompatible types.
ValueError: An input observable has a phase not equal to 1.
"""
if isinstance(observables, PauliList) and not isinstance(
results, (SamplerResult, PrimitiveResult)
):
raise ValueError(
"If observables is a PauliList, results must be a SamplerResult or PrimitiveResult instance."
)
if isinstance(observables, dict) and not isinstance(results, dict):
raise ValueError(
"If observables is a dictionary, results must also be a dictionary."
)

# If circuit was not separated, transform input data structures to dictionary format
# If circuit was not separated, transform input data structures to
# dictionary format. Perform some input validation in either case.
if isinstance(observables, PauliList):
if not isinstance(results, (SamplerResult, PrimitiveResult)):
raise ValueError(
"If observables is a PauliList, results must be a SamplerResult or PrimitiveResult instance."
)
if any(obs.phase != 0 for obs in observables):
raise ValueError("An input observable has a phase not equal to 1.")
subobservables_by_subsystem = decompose_observables(
observables, "A" * len(observables[0])
subobservables_by_subsystem: Mapping[Hashable, PauliList] = (
decompose_observables(observables, "A" * len(observables[0]))
)
results_dict: dict[Hashable, SamplerResult | PrimitiveResult] = {"A": results}
results_dict: Mapping[Hashable, SamplerResult | PrimitiveResult] = {
"A": results
}
expvals = np.zeros(len(observables))

else:
elif isinstance(observables, Mapping):
if not isinstance(results, Mapping):
raise ValueError(
"If observables is a dictionary, results must also be a dictionary."
)
if observables.keys() != results.keys():
raise ValueError(
"The subsystem labels of the observables and results do not match."
)
results_dict = results
for label, subobservable in observables.items():
if any(obs.phase != 0 for obs in subobservable):
raise ValueError("An input observable has a phase not equal to 1.")
subobservables_by_subsystem = observables
expvals = np.zeros(len(list(observables.values())[0]))

else:
raise ValueError("observables must be either a PauliList or dict.")

subsystem_observables = {
label: ObservableCollection(subobservables)
for label, subobservables in subobservables_by_subsystem.items()
}

# Validate that the number of subexperiments executed is consistent with
# the number of coefficients and observable groups.
for label, so in subsystem_observables.items():
current_result = results_dict[label]
if isinstance(current_result, SamplerResult):
# SamplerV1 provides a SamplerResult
current_result = current_result.quasi_dists
if len(current_result) != len(coefficients) * len(so.groups):
raise ValueError(
f"The number of subexperiments performed in subsystem '{label}' "
f"({len(current_result)}) should equal the number of coefficients "
f"({len(coefficients)}) times the number of mutually commuting "
f"subobservable groups ({len(so.groups)}), but it does not."
)

# Reconstruct the expectation values
for i, coeff in enumerate(coefficients):
current_expvals = np.ones((len(expvals),))
Expand Down
42 changes: 40 additions & 2 deletions test/cutting/test_cutting_reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
BitArray,
)
from qiskit.primitives.containers import make_data_bin
from qiskit.quantum_info import Pauli, PauliList
from qiskit.quantum_info import Pauli, PauliList, SparsePauliOp

from circuit_knitting.utils.observable_grouping import CommutingObservableGroup
from circuit_knitting.cutting.qpd import WeightType
Expand All @@ -48,7 +48,7 @@ def test_cutting_reconstruction(self):
observables = PauliList(["ZZ"])
expvals = reconstruct_expectation_values(results, weights, observables)
self.assertEqual([1.0], expvals)
with self.subTest("Test mismatching inputs"):
with self.subTest("Test mismatching input types"):
results = SamplerResult(
quasi_dists=[QuasiDistribution({"0": 1.0})], metadata=[{}]
)
Expand All @@ -68,6 +68,32 @@ def test_cutting_reconstruction(self):
e_info.value.args[0]
== "If observables is a PauliList, results must be a SamplerResult or PrimitiveResult instance."
)
with self.subTest("Test invalid observables type"):
results = SamplerResult(
quasi_dists=[QuasiDistribution({"0": 1.0})], metadata=[{}]
)
weights = [(1.0, WeightType.EXACT)]
observables = [SparsePauliOp(["ZZ"])]
with pytest.raises(ValueError) as e_info:
reconstruct_expectation_values(results, weights, observables)
assert (
e_info.value.args[0]
== "observables must be either a PauliList or dict."
)
with self.subTest("Test mismatching subsystem labels"):
results = {
"A": SamplerResult(
quasi_dists=[QuasiDistribution({"0": 1.0})], metadata=[{}]
)
}
weights = [(1.0, WeightType.EXACT)]
observables = {"B": [PauliList("ZZ")]}
with pytest.raises(ValueError) as e_info:
reconstruct_expectation_values(results, weights, observables)
assert (
e_info.value.args[0]
== "The subsystem labels of the observables and results do not match."
)
with self.subTest("Test unsupported phase"):
results = SamplerResult(
quasi_dists=[QuasiDistribution({"0": 1.0})], metadata=[{}]
Expand Down Expand Up @@ -110,6 +136,18 @@ def test_cutting_reconstruction(self):
observables = PauliList(["II", "IZ", "ZI", "ZZ"])
expvals = reconstruct_expectation_values(results, weights, observables)
assert expvals == pytest.approx([0.0, -0.6, 0.0, -0.2])
with self.subTest("Test inconsistent number of subexperiment results provided"):
results = SamplerResult(
quasi_dists=[QuasiDistribution({"0": 1.0})], metadata=[{}]
)
weights = [(1.0, WeightType.EXACT)]
observables = PauliList(["ZZ", "XX"])
with pytest.raises(ValueError) as e_info:
reconstruct_expectation_values(results, weights, observables)
assert (
e_info.value.args[0]
== "The number of subexperiments performed in subsystem 'A' (1) should equal the number of coefficients (1) times the number of mutually commuting subobservable groups (2), but it does not."
)

@data(
("000", [1, 1, 1]),
Expand Down

0 comments on commit 67f24cb

Please sign in to comment.