Skip to content

Commit

Permalink
Merge pull request #272 from QunaSys/check_operator_estimatable
Browse files Browse the repository at this point in the history
Check operator estimatable
  • Loading branch information
rykojima committed Nov 30, 2023
2 parents 5744451 + af19e69 commit 5e97c2e
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 0 deletions.
5 changes: 5 additions & 0 deletions packages/core/quri_parts/core/estimator/sampling/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
general_pauli_sum_expectation_estimator,
general_pauli_sum_sample_variance,
)
from quri_parts.core.estimator.utils import is_estimatable
from quri_parts.core.measurement import (
CommutablePauliSetMeasurementFactory,
PauliReconstructorFactory,
Expand Down Expand Up @@ -109,6 +110,10 @@ def sampling_estimate(
The estimated value (can be accessed with :attr:`.value`) with standard error
of estimation (can be accessed with :attr:`.error`).
"""
assert is_estimatable(
op, state
), "Number of qubits of the operator is too large to estimate."

if not isinstance(op, Operator):
op = Operator({op: 1.0})

Expand Down
33 changes: 33 additions & 0 deletions packages/core/quri_parts/core/estimator/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from quri_parts.core.estimator import Estimatable
from quri_parts.core.operator import PAULI_IDENTITY, Operator, PauliLabel, zero
from quri_parts.core.state import QuantumState


def is_estimatable(observable: Estimatable, state: QuantumState) -> bool:
"""Check if the qubit count of the observable is larger than that of the
state."""
if observable == PAULI_IDENTITY or observable == zero():
return True

elif isinstance(observable, PauliLabel):
min_state_qubit_required = max(observable.qubit_indices()) + 1
return min_state_qubit_required <= state.qubit_count

elif isinstance(observable, Operator):
for op in observable:
if not is_estimatable(op, state):
return False
return True

else:
assert False, "Observable should be either a PauliLabel or an Operator."
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,32 @@ def test_zero_op(self) -> None:
assert estimate.value == 0.0
assert estimate.error == 0.0

def test_raise_when_not_estimatable(self) -> None:
with pytest.raises(
AssertionError,
match="Number of qubits of the operator is too large to estimate.",
):
sampling_estimate(
pauli_label("Z3"),
initial_state(),
1000,
sampler,
measurement_factory,
allocator,
)
with pytest.raises(
AssertionError,
match="Number of qubits of the operator is too large to estimate.",
):
sampling_estimate(
Operator({pauli_label("Z3"): 3, pauli_label("X0 X1"): 3}),
initial_state(),
1000,
sampler,
measurement_factory,
allocator,
)

def test_const_op(self) -> None:
op = Operator({PAULI_IDENTITY: 3.0})
estimate = sampling_estimate(
Expand Down Expand Up @@ -213,6 +239,26 @@ def test_sampling_estimator(self) -> None:
assert_sampler_args(s)
assert_sample(estimate)

def test_raises_when_not_estimatable(self) -> None:
s = mock_sampler()
estimator = create_sampling_estimator(
total_shots(), s, measurement_factory, allocator
)
with pytest.raises(
AssertionError,
match="Number of qubits of the operator is too large to estimate.",
):
estimator(pauli_label("Z3"), initial_state())

with pytest.raises(
AssertionError,
match="Number of qubits of the operator is too large to estimate.",
):
estimator(
Operator({pauli_label("Z3"): 3, pauli_label("X0 X1"): 3}),
initial_state(),
)


class TestConcurrentSamplingEstimate:
def test_invalid_arguments(self) -> None:
Expand All @@ -238,6 +284,21 @@ def test_invalid_arguments(self) -> None:
allocator,
)

with pytest.raises(
AssertionError,
match="Number of qubits of the operator is too large to estimate.",
):
obs1 = pauli_label("Z3")
obs2 = Operator({pauli_label("Z3"): 3, pauli_label("X0 X1"): 3})
concurrent_sampling_estimate(
[obs1, obs2],
[initial_state()] * 2,
total_shots(),
s,
measurement_factory,
allocator,
)

def test_concurrent_estimate(self) -> None:
s = mock_sampler()

Expand Down
78 changes: 78 additions & 0 deletions packages/core/tests/core/estimator/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from quri_parts.circuit import UnboundParametricQuantumCircuit
from quri_parts.core.estimator import Estimatable
from quri_parts.core.estimator.utils import is_estimatable
from quri_parts.core.operator import PAULI_IDENTITY, Operator, pauli_label, zero
from quri_parts.core.state import (
ComputationalBasisState,
GeneralCircuitQuantumState,
ParametricCircuitQuantumState,
ParametricQuantumStateVector,
QuantumStateVector,
)


def test_is_estimatable() -> None:
states = [
GeneralCircuitQuantumState(3),
ComputationalBasisState(3),
QuantumStateVector(3),
ParametricCircuitQuantumState(3, circuit=UnboundParametricQuantumCircuit(3)),
ParametricQuantumStateVector(3, circuit=UnboundParametricQuantumCircuit(3)),
]

valid_pauli_labels: list[Estimatable] = [
zero(),
PAULI_IDENTITY,
pauli_label("X0"),
pauli_label("Y1"),
pauli_label("Z2"),
pauli_label("X0 Y2"),
pauli_label("Y1 Z2"),
pauli_label("Z2"),
]

for op in valid_pauli_labels:
for state in states:
assert is_estimatable(op, state)

invalid_pauli_labels = [pauli_label("X3"), pauli_label("Z50000")]
for op in invalid_pauli_labels:
for state in states:
assert not is_estimatable(op, state)

valid_operator = Operator(
{
pauli_label("Z0 Z1 Z2"): 1,
pauli_label("X0 Z1 Y2"): 2,
pauli_label("X0 Z1 X2"): 2,
pauli_label("Z1"): 2,
PAULI_IDENTITY: 3,
}
)

for state in states:
assert is_estimatable(valid_operator, state)

invalid_operator = Operator(
{
pauli_label("Z50000"): 1,
pauli_label("Z3"): 1,
pauli_label("Z0 Z1 Z2"): 1,
pauli_label("X0 Z1 Y2"): 2,
pauli_label("X0 Z1 X2"): 2,
pauli_label("Z1"): 2,
PAULI_IDENTITY: 3,
}
)
for state in states:
assert not is_estimatable(invalid_operator, state)

1 comment on commit 5e97c2e

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.