Skip to content

Commit

Permalink
Fix utils mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
Randl committed May 23, 2023
1 parent e1056ee commit 2ed4b7d
Show file tree
Hide file tree
Showing 8 changed files with 128 additions and 116 deletions.
33 changes: 18 additions & 15 deletions qiskit/utils/deprecation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@
import functools
import inspect
import warnings
from typing import Any, Callable, Dict, Optional, Type, Tuple, Union
from collections.abc import Callable
from typing import Any, Type


def deprecate_func(
*,
since: str,
additional_msg: Optional[str] = None,
additional_msg: str | None = None,
pending: bool = False,
package_name: str = "qiskit-terra",
removal_timeline: str = "no earlier than 3 months after the release date",
Expand Down Expand Up @@ -104,12 +105,12 @@ def deprecate_arg(
name: str,
*,
since: str,
additional_msg: Optional[str] = None,
deprecation_description: Optional[str] = None,
additional_msg: str | None = None,
deprecation_description: str | None = None,
pending: bool = False,
package_name: str = "qiskit-terra",
new_alias: Optional[str] = None,
predicate: Optional[Callable[[Any], bool]] = None,
new_alias: str | None = None,
predicate: Callable[[Any], bool] | None = None,
removal_timeline: str = "no earlier than 3 months after the release date",
):
"""Decorator to indicate an argument has been deprecated in some way.
Expand Down Expand Up @@ -204,10 +205,10 @@ def wrapper(*args, **kwargs):


def deprecate_arguments(
kwarg_map: Dict[str, Optional[str]],
kwarg_map: dict[str, str | None],
category: Type[Warning] = DeprecationWarning,
*,
since: Optional[str] = None,
since: str | None = None,
):
"""Deprecated. Instead, use `@deprecate_arg`.
Expand Down Expand Up @@ -280,7 +281,7 @@ def deprecate_function(
stacklevel: int = 2,
category: Type[Warning] = DeprecationWarning,
*,
since: Optional[str] = None,
since: str | None = None,
):
"""Deprecated. Instead, use `@deprecate_func`.
Expand Down Expand Up @@ -313,15 +314,15 @@ def wrapper(*args, **kwargs):

def _maybe_warn_and_rename_kwarg(
args: tuple[Any, ...],
kwargs: Dict[str, Any],
kwargs: dict[str, Any],
*,
func_name: str,
original_func_co_varnames: tuple[str, ...],
old_arg_name: str,
new_alias: Optional[str],
new_alias: str | None,
warning_msg: str,
category: Type[Warning],
predicate: Optional[Callable[[Any], bool]],
predicate: Callable[[Any], bool] | None,
) -> None:
# In Python 3.10+, we should set `zip(strict=False)` (the default). That is, we want to
# stop iterating once `args` is done, since some args may have not been explicitly passed as
Expand Down Expand Up @@ -353,9 +354,11 @@ def _write_deprecation_msg(
pending: bool,
additional_msg: str,
removal_timeline: str,
) -> Tuple[str, Union[Type[DeprecationWarning], Type[PendingDeprecationWarning]]]:
) -> tuple[str, Type[DeprecationWarning] | Type[PendingDeprecationWarning]]:
if pending:
category = PendingDeprecationWarning
category: Type[DeprecationWarning] | Type[
PendingDeprecationWarning
] = PendingDeprecationWarning
deprecation_status = "pending deprecation"
removal_desc = f"marked deprecated in a future release, and then removed {removal_timeline}"
else:
Expand Down Expand Up @@ -412,7 +415,7 @@ def _write_deprecation_msg(


def add_deprecation_to_docstring(
func: Callable, msg: str, *, since: Optional[str], pending: bool
func: Callable, msg: str, *, since: str | None, pending: bool
) -> None:
"""Dynamically insert the deprecation message into ``func``'s docstring.
Expand Down
39 changes: 20 additions & 19 deletions qiskit/utils/measurement_error_mitigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
# that they have been altered from the originals.

"""Measurement error mitigation"""
from __future__ import annotations

import copy
from typing import List, Optional, Tuple, Dict, Callable
from typing import Type, Any

from qiskit import compiler
from qiskit.providers import Backend
Expand All @@ -35,8 +36,8 @@
additional_msg="For code migration guidelines, visit https://qisk.it/qi_migration.",
)
def get_measured_qubits(
transpiled_circuits: List[QuantumCircuit],
) -> Tuple[List[int], Dict[str, List[int]]]:
transpiled_circuits: list[QuantumCircuit],
) -> tuple[list[int], dict[str, list[int]]]:
"""
Deprecated: Retrieve the measured qubits from transpiled circuits.
Expand All @@ -51,7 +52,7 @@ def get_measured_qubits(
QiskitError: invalid qubit mapping
"""
qubit_index = None
qubit_mappings = {}
qubit_mappings: dict[str, list[int]] = {}
for idx, qc in enumerate(transpiled_circuits):
measured_qubits = []
for instruction in qc.data:
Expand Down Expand Up @@ -82,7 +83,7 @@ def get_measured_qubits(
since="0.24.0",
additional_msg="For code migration guidelines, visit https://qisk.it/qi_migration.",
)
def get_measured_qubits_from_qobj(qobj: QasmQobj) -> Tuple[List[int], Dict[str, List[int]]]:
def get_measured_qubits_from_qobj(qobj: QasmQobj) -> tuple[list[int], dict[str, list[int]]]:
"""
Deprecated: Retrieve the measured qubits from transpiled circuits.
Expand All @@ -98,7 +99,7 @@ def get_measured_qubits_from_qobj(qobj: QasmQobj) -> Tuple[List[int], Dict[str,
"""

qubit_index = None
qubit_mappings = {}
qubit_mappings: dict[str, list[int]] = {}

for idx, exp in enumerate(qobj.experiments):
measured_qubits = []
Expand Down Expand Up @@ -128,13 +129,13 @@ def get_measured_qubits_from_qobj(qobj: QasmQobj) -> Tuple[List[int], Dict[str,
additional_msg="For code migration guidelines, visit https://qisk.it/qi_migration.",
)
def build_measurement_error_mitigation_circuits(
qubit_list: List[int],
fitter_cls: Callable,
qubit_list: list[int],
fitter_cls: Type[CompleteMeasFitter] | Type[TensoredMeasFitter],
backend: Backend,
backend_config: Optional[Dict] = None,
compile_config: Optional[Dict] = None,
mit_pattern: Optional[List[List[int]]] = None,
) -> Tuple[QuantumCircuit, List[str], List[str]]:
backend_config: dict[str, Any] | None = None,
compile_config: dict[str, Any] | None = None,
mit_pattern: list[list[int]] | None = None,
) -> tuple[QuantumCircuit, list[str], list[str]]:
"""Deprecated: Build measurement error mitigation circuits
Args:
qubit_list: list of ordered qubits used in the algorithm
Expand Down Expand Up @@ -207,14 +208,14 @@ def build_measurement_error_mitigation_circuits(
additional_msg="For code migration guidelines, visit https://qisk.it/qi_migration.",
)
def build_measurement_error_mitigation_qobj(
qubit_list: List[int],
fitter_cls: Callable,
qubit_list: list[int],
fitter_cls: Type[CompleteMeasFitter] | Type[TensoredMeasFitter],
backend: Backend,
backend_config: Optional[Dict] = None,
compile_config: Optional[Dict] = None,
run_config: Optional[RunConfig] = None,
mit_pattern: Optional[List[List[int]]] = None,
) -> Tuple[QasmQobj, List[str], List[str]]:
backend_config: dict | None = None,
compile_config: dict | None = None,
run_config: RunConfig | None = None,
mit_pattern: list[list[int]] | None = None,
) -> tuple[QasmQobj, list[str], list[str]]:
"""
Args:
qubit_list: list of ordered qubits used in the algorithm
Expand Down
28 changes: 13 additions & 15 deletions qiskit/utils/mitigation/_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
Measurement correction filters.
"""

from typing import List
from __future__ import annotations
from copy import deepcopy

import numpy as np
Expand Down Expand Up @@ -65,6 +64,11 @@ def cal_matrix(self):
"""Return cal_matrix."""
return self._cal_matrix

@cal_matrix.setter
def cal_matrix(self, new_cal_matrix):
"""Set cal_matrix."""
self._cal_matrix = new_cal_matrix

@property
def state_labels(self):
"""return the state label ordering of the cal matrix"""
Expand All @@ -75,11 +79,6 @@ def state_labels(self, new_state_labels):
"""set the state label ordering of the cal matrix"""
self._state_labels = new_state_labels

@cal_matrix.setter
def cal_matrix(self, new_cal_matrix):
"""Set cal_matrix."""
self._cal_matrix = new_cal_matrix

def apply(self, raw_data, method="least_squares"):
"""Apply the calibration matrix to results.
Expand Down Expand Up @@ -229,7 +228,7 @@ class TensoredFilter:
since="0.24.0",
additional_msg="For code migration guidelines, visit https://qisk.it/qi_migration.",
)
def __init__(self, cal_matrices: np.matrix, substate_labels_list: list, mit_pattern: list):
def __init__(self, cal_matrices: np.matrix, substate_labels_list: list[str], mit_pattern: list):
"""
Initialize a tensored measurement error mitigation filter using
the cal_matrices from a tensored measurement calibration fitter.
Expand All @@ -245,9 +244,9 @@ def __init__(self, cal_matrices: np.matrix, substate_labels_list: list, mit_patt
"""

self._cal_matrices = cal_matrices
self._qubit_list_sizes = []
self._indices_list = []
self._substate_labels_list = []
self._qubit_list_sizes: list[int] = []
self._indices_list: list[dict[str, int]] = []
self._substate_labels_list: list[str] = []
self.substate_labels_list = substate_labels_list
self._mit_pattern = mit_pattern

Expand All @@ -267,7 +266,7 @@ def substate_labels_list(self):
return self._substate_labels_list

@substate_labels_list.setter
def substate_labels_list(self, new_substate_labels_list):
def substate_labels_list(self, new_substate_labels_list: list[str]):
"""Return _substate_labels_list"""
self._substate_labels_list = new_substate_labels_list

Expand All @@ -279,7 +278,6 @@ def substate_labels_list(self, new_substate_labels_list):
# get the indices in the calibration matrix
self._indices_list = []
for _, sub_labels in enumerate(self._substate_labels_list):

self._indices_list.append({lab: ind for ind, lab in enumerate(sub_labels)})

@property
Expand Down Expand Up @@ -474,7 +472,7 @@ def fun(x):

return new_count_dict

def flip_state(self, state: str, mat_index: int, flip_poses: List[int]) -> str:
def flip_state(self, state: str, mat_index: int, flip_poses: list[int]) -> str:
"""Flip the state according to the chosen qubit positions"""
flip_poses = [pos for i, pos in enumerate(flip_poses) if (mat_index >> i) & 1]
flip_poses = sorted(flip_poses)
Expand All @@ -487,7 +485,7 @@ def flip_state(self, state: str, mat_index: int, flip_poses: List[int]) -> str:
new_state += state[pos:]
return new_state

def compute_index_of_cal_mat(self, state: str, pos_qubits: List[int], indices: dict) -> int:
def compute_index_of_cal_mat(self, state: str, pos_qubits: list[int], indices: dict) -> int:
"""Return the index of (pseudo inverse) calibration matrix for the input quantum state"""
sub_state = ""
for pos in pos_qubits:
Expand Down
25 changes: 14 additions & 11 deletions qiskit/utils/mitigation/circuits.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,18 @@
Measurement calibration circuits. To apply the measurement mitigation
use the fitters to produce a filter.
"""
from typing import List, Tuple, Union
from __future__ import annotations

from collections.abc import Sequence
from typing import Any
from qiskit.utils.deprecation import deprecate_func


@deprecate_func(
since="0.24.0",
additional_msg="For code migration guidelines, visit https://qisk.it/qi_migration.",
)
def count_keys(num_qubits: int) -> List[str]:
def count_keys(num_qubits: int) -> list[str]:
"""Deprecated: Return ordered count keys.
Args:
Expand All @@ -45,11 +48,11 @@ def count_keys(num_qubits: int) -> List[str]:
additional_msg="For code migration guidelines, visit https://qisk.it/qi_migration.",
)
def complete_meas_cal(
qubit_list: List[int] = None,
qr: Union[int, List["QuantumRegister"]] = None,
cr: Union[int, List["ClassicalRegister"]] = None,
qubit_list: Sequence[int] | None = None,
qr: int | list | None = None,
cr: int | list | None = None,
circlabel: str = "",
) -> Tuple[List["QuantumCircuit"], List[str]]:
) -> tuple[list, list[str]]:
"""
Deprecated: Return a list of measurement calibration circuits for the full
Hilbert space.
Expand Down Expand Up @@ -126,11 +129,11 @@ def complete_meas_cal(
additional_msg="For code migration guidelines, visit https://qisk.it/qi_migration.",
)
def tensored_meas_cal(
mit_pattern: List[List[int]] = None,
qr: Union[int, List["QuantumRegister"]] = None,
cr: Union[int, List["ClassicalRegister"]] = None,
mit_pattern: list[list[int]] | None = None,
qr: int | Sequence[Any] | None = None,
cr: int | Sequence[Any] | None = None,
circlabel: str = "",
) -> Tuple[List["QuantumCircuit"], List[List[int]]]:
) -> tuple[list[Any], list[list[int]]]:
"""
Deprecated: Return a list of calibration circuits
Expand Down Expand Up @@ -166,7 +169,7 @@ def tensored_meas_cal(
QiskitError: if a qubit appears more than once in `mit_pattern`.
"""
# Runtime imports to avoid circular imports causeed by QuantumInstance
# Runtime imports to avoid circular imports caused by QuantumInstance
# getting initialized by imported utils/__init__ which is imported
# by qiskit.circuit
from qiskit.circuit.quantumregister import QuantumRegister
Expand Down
Loading

0 comments on commit 2ed4b7d

Please sign in to comment.