Skip to content

Commit

Permalink
Merge pull request #321 from QunaSys/general_sampling_estimator
Browse files Browse the repository at this point in the history
General sampling estimator
  • Loading branch information
toru4838 committed Mar 18, 2024
2 parents 6180782 + 0d75127 commit 776ec7c
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 9 deletions.
11 changes: 6 additions & 5 deletions packages/core/quri_parts/core/estimator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,8 +480,8 @@ class GeneralQuantumEstimator(Generic[_StateT, _ParametricStateT]):
- Act as :class:`ConcurrentParametricQuantumEstimator`:
- Estimatable, _ParametricStateT, [[float, ...], ...] -> [Estimate, ...]
When a :class:`GeneralEstimator` is called directly with one of the combinations
above, it needs to parse the input arguments to figure out which of
When a :class:`GeneralQuantumEstimator` is called directly with one of the
combinations above, it needs to parse the input arguments to figure out which of
:class:`QuantumEstimator`, :class:`ConcurrentQuantumEstimator`,
:class:`ParametricQuantumEstimator`, or :class:`ConcurrentParametricEstimator`
is required to perform the estimation. To avoid such performance penalty, you may
Expand Down Expand Up @@ -597,12 +597,13 @@ def create_general_estimator_from_estimator(
def create_general_estimator_from_estimator(
estimator: QuantumEstimator[_StateT],
) -> GeneralQuantumEstimator[_StateT, _ParametricStateT]:
"""Creates a :class:`GeneralEstimator` from a :class:`QuantumEstimator`.
"""Creates a :class:`GeneralQuantumEstimator` from a
:class:`QuantumEstimator`.
Note:
- The concurrencies of the :class:`ConcurrentQuantumEstimaror` and
`ConcurrentParametricQuantumEstimaror` will be set to 1 when a
:class:`GeneralEstimator` is created with this function.
:class:`GeneralQuantumEstimator` is created with this function.
- When circuit conversion is involved in the estimator execution, the
parametric estimator created from this function will bind the parameter
first, and then convert the bound circuit every time the patametric estimator
Expand Down Expand Up @@ -645,7 +646,7 @@ def create_general_estimator_from_concurrent_estimator(
def create_general_estimator_from_concurrent_estimator(
concurrent_estimator: ConcurrentQuantumEstimator[_StateT],
) -> GeneralQuantumEstimator[_StateT, _ParametricStateT]:
"""Creates a :class:`GeneralEstimator` from a
"""Creates a :class:`GeneralQuantumEstimator` from a
:class:`ConcurrentQuantumEstimator`.
Note:
Expand Down
2 changes: 2 additions & 0 deletions packages/core/quri_parts/core/estimator/sampling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from .estimator import (
concurrent_sampling_estimate,
create_general_sampling_estimator,
create_sampling_concurrent_estimator,
create_sampling_estimator,
get_estimate_from_sampling_result,
Expand Down Expand Up @@ -81,4 +82,5 @@
"CircuitShotPairPreparationFunction",
"get_sampling_circuits_and_shots",
"distribute_shots_among_pauli_sets",
"create_general_sampling_estimator",
]
27 changes: 26 additions & 1 deletion packages/core/quri_parts/core/estimator/sampling/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
ConcurrentQuantumEstimator,
Estimatable,
Estimate,
GeneralQuantumEstimator,
QuantumEstimator,
create_general_estimator_from_estimator,
)
from quri_parts.core.estimator.sampling.pauli import (
general_pauli_sum_expectation_estimator,
Expand All @@ -35,7 +37,7 @@
MeasurementCounts,
PauliSamplingShotsAllocator,
)
from quri_parts.core.state import CircuitQuantumState
from quri_parts.core.state import CircuitQuantumState, ParametricCircuitQuantumState

from .estimator_helpers import (
CircuitShotPairPreparationFunction,
Expand Down Expand Up @@ -286,3 +288,26 @@ def estimator(
)

return estimator


def create_general_sampling_estimator(
total_shots: int,
sampler: ConcurrentSampler,
measurement_factory: CommutablePauliSetMeasurementFactory,
shots_allocator: PauliSamplingShotsAllocator,
) -> GeneralQuantumEstimator[CircuitQuantumState, ParametricCircuitQuantumState]:
"""Creates a :class:`GeneralQuantumEstimator` that performs sampling
estimation.
Args:
total_shots: Total number of shots available for sampling measurements.
sampler: A Sampler that actually performs the sampling measurements.
measurement_factory: A function that performs Pauli grouping and returns
a measurement scheme for Pauli operators constituting the original operator.
shots_allocator: A function that allocates the total shots to Pauli groups to
be measured.
"""
sampling_estimator = create_sampling_estimator(
total_shots, sampler, measurement_factory, shots_allocator
)
return create_general_estimator_from_estimator(sampling_estimator)
121 changes: 119 additions & 2 deletions packages/core/tests/core/estimator/sampling/test_sampling_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from collections.abc import Collection, Iterable
from math import sqrt
from typing import Any, Union, cast
from unittest.mock import Mock

import numpy as np
import pytest

from quri_parts.circuit import H, NonParametricQuantumCircuit, QuantumCircuit, X
from quri_parts.circuit import (
H,
NonParametricQuantumCircuit,
QuantumCircuit,
UnboundParametricQuantumCircuit,
X,
)
from quri_parts.core.estimator import Estimate
from quri_parts.core.estimator.sampling import (
concurrent_sampling_estimate,
create_general_sampling_estimator,
create_sampling_concurrent_estimator,
create_sampling_estimator,
get_estimate_from_sampling_result,
Expand All @@ -44,7 +53,11 @@
from quri_parts.core.sampling.shots_allocator import (
create_equipartition_shots_allocator,
)
from quri_parts.core.state import CircuitQuantumState, ComputationalBasisState
from quri_parts.core.state import (
CircuitQuantumState,
ComputationalBasisState,
ParametricCircuitQuantumState,
)

n_qubits = 3

Expand Down Expand Up @@ -510,3 +523,107 @@ def test_sampling_concurrent_estimator(self) -> None:
assert len(estimate_list) == 2
assert_sample(estimate_list[0])
assert estimate_list[1].value == (1 - 1 + 2 - 4) / 8


class GeneralSamplingEstimator(unittest.TestCase):
def setUp(self) -> None:
s = mock_sampler()
self.general_estimator = create_general_sampling_estimator(
total_shots(),
s,
bitwise_commuting_pauli_measurement,
allocator,
)

def test_general_quantum_estimator(self) -> None:
estimate = self.general_estimator(operator(), initial_state())
assert_sample(estimate)

def test_concurrent_estimate(self) -> None:
estimates = self.general_estimator(
operator(),
[initial_state(), ComputationalBasisState(3, bits=0b001)],
)

estimate_list = list(estimates)
assert len(estimate_list) == 2
assert_sample(estimate_list[0])
assert_sample(estimate_list[1])

estimates = self.general_estimator(
[operator()],
[initial_state(), ComputationalBasisState(3, bits=0b001)],
)

estimate_list = list(estimates)
assert len(estimate_list) == 2
assert_sample(estimate_list[0])
assert_sample(estimate_list[1])

estimates = self.general_estimator(
[operator(), pauli_label("Z0")],
ComputationalBasisState(3, bits=0b001),
)

estimate_list = list(estimates)
assert len(estimate_list) == 2
assert_sample(estimate_list[0])
assert estimate_list[1].value == (1 - 1 + 2 - 4) / 8

estimates = self.general_estimator(
[operator(), pauli_label("Z0")],
[ComputationalBasisState(3, bits=0b001)],
)

estimate_list = list(estimates)
assert len(estimate_list) == 2
assert_sample(estimate_list[0])
assert estimate_list[1].value == (1 - 1 + 2 - 4) / 8

estimates = self.general_estimator(
[operator(), pauli_label("Z0")],
[initial_state(), ComputationalBasisState(3, bits=0b001)],
)

estimate_list = list(estimates)
assert len(estimate_list) == 2
assert_sample(estimate_list[0])
assert estimate_list[1].value == (1 - 1 + 2 - 4) / 8

def test_parametric_estimate(self) -> None:
circuit = UnboundParametricQuantumCircuit(n_qubits)
circuit.add_X_gate(0)
circuit.add_ParametricRX_gate(0)
circuit.add_ParametricRY_gate(1)
circuit.add_ParametricRZ_gate(2)

state = ParametricCircuitQuantumState(n_qubits, circuit)

estimate = self.general_estimator(operator(), state, [0, 1, 2])
assert_sample(estimate)

estimate = self.general_estimator(operator(), state, np.array([0, 1, 2]))
assert_sample(estimate)

def test_concurrent_parametric_estimate(self) -> None:
circuit = UnboundParametricQuantumCircuit(n_qubits)
circuit.add_X_gate(0)
circuit.add_ParametricRX_gate(0)
circuit.add_ParametricRY_gate(1)
circuit.add_ParametricRZ_gate(2)

state = ParametricCircuitQuantumState(n_qubits, circuit)

estimates = self.general_estimator(operator(), state, [[0, 1, 2], [4, 5, 6]])
estimate_list = list(estimates)
assert len(estimate_list) == 2
assert_sample(estimate_list[0])
assert_sample(estimate_list[1])

estimates = self.general_estimator(
operator(), state, np.array([[0, 1, 2], [4, 5, 6]])
)
estimate_list = list(estimates)
assert len(estimate_list) == 2
assert_sample(estimate_list[0])
assert_sample(estimate_list[1])
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def test_with_vector(self) -> None:
assert np.isclose(estimate_list[1].value, 12 * np.sqrt(2))


class TestGeneralEstimator(unittest.TestCase):
class TestGeneralQuantumEstimator(unittest.TestCase):
def setUp(self) -> None:
self.op_0 = PAULI_IDENTITY
self.op_1 = Operator({pauli_label("X0"): 1, pauli_label("Y0"): 1})
Expand Down

1 comment on commit 776ec7c

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.