Skip to content

Commit

Permalink
Fix primitives mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
Randl committed Mar 21, 2023
1 parent eb4c5cb commit 788f69d
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 8 deletions.
6 changes: 3 additions & 3 deletions qiskit/primitives/backend_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _run_circuits(
circuits: QuantumCircuit | list[QuantumCircuit],
backend: BackendV1 | BackendV2,
**run_options,
) -> tuple[Result, list[dict]]:
) -> tuple[list[Result], list[dict]]:
"""Remove metadata of circuits and run the circuits on a backend.
Args:
circuits: The circuits
Expand Down Expand Up @@ -70,7 +70,7 @@ def _run_circuits(
return result, metadata


def _prepare_counts(results):
def _prepare_counts(results: Sequence[Result]):
counts = []
for res in results:
count = res.get_counts()
Expand Down Expand Up @@ -351,7 +351,7 @@ def _preprocessing(self) -> list[tuple[QuantumCircuit, list[QuantumCircuit]]]:
return preprocessed_circuits

def _postprocessing(
self, result: Result, accum: list[int], metadata: list[dict]
self, result: Sequence[Result], accum: list[int], metadata: list[dict]
) -> EstimatorResult:
"""
Postprocessing for evaluation of expectation value using pauli rotation gates.
Expand Down
4 changes: 3 additions & 1 deletion qiskit/primitives/backend_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,9 @@ def _call(
result, _metadata = _run_circuits(bound_circuits, self._backend, **run_options)
return self._postprocessing(result, bound_circuits)

def _postprocessing(self, result: Result, circuits: list[QuantumCircuit]) -> SamplerResult:
def _postprocessing(
self, result: Sequence[Result], circuits: list[QuantumCircuit]
) -> SamplerResult:
counts = _prepare_counts(result)
shots = sum(counts[0].values())

Expand Down
12 changes: 8 additions & 4 deletions qiskit/primitives/base/base_primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
from __future__ import annotations

from abc import ABC
from collections.abc import Sequence
from typing import Any
from typing import Any, cast, Union, Sequence

import numpy as np

Expand Down Expand Up @@ -84,22 +83,27 @@ def _validate_parameter_values(

# Support numpy ndarray
if isinstance(parameter_values, np.ndarray):
parameter_values = cast(np.ndarray,parameter_values)
parameter_values = parameter_values.tolist()
elif isinstance(parameter_values, Sequence):
parameter_values = cast(Union[Sequence[Sequence[float]],Sequence[float]],parameter_values)
parameter_values = tuple(
vector.tolist() if isinstance(vector, np.ndarray) else vector
cast(np.ndarray, vector).tolist() if isinstance(vector, np.ndarray) else vector
for vector in parameter_values
)

# Allow single value
if _isreal(parameter_values):
parameter_values = cast(float,parameter_values)
parameter_values = ((parameter_values,),)
elif isinstance(parameter_values, Sequence) and not any(
isinstance(vector, Sequence) for vector in parameter_values
):
parameter_values = cast(Sequence[float],parameter_values)
parameter_values = (parameter_values,)

# Validation
parameter_values = cast(Sequence[Sequence[float]], parameter_values)
if (
not isinstance(parameter_values, Sequence)
or not all(isinstance(vector, Sequence) for vector in parameter_values)
Expand Down Expand Up @@ -132,7 +136,7 @@ def _isint(obj: Any) -> bool:
return isinstance(obj, int_types) and not isinstance(obj, bool)


def _isreal(obj: Any) -> bool:
def _isreal(obj: Any) -> bool | np.bool_:
"""Check if object is a real number: int or float except ``±Inf`` and ``NaN``."""
float_types = (float, np.floating)
return _isint(obj) or isinstance(obj, float_types) and float("-Inf") < obj < float("Inf")

0 comments on commit 788f69d

Please sign in to comment.