Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring commutativity analysis and a new commutative inverse cancellation transpiler pass #8184

Merged
merged 30 commits into from
Aug 10, 2022
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
82f9db3
experimenting with transpiler passes
alexanderivrii Jun 14, 2022
35dbdc7
removing some prints
alexanderivrii Jun 15, 2022
f559f1a
minor cleanup
alexanderivrii Jun 15, 2022
29e8bb0
Merge branch 'main' into commutative_cancellation
alexanderivrii Jul 3, 2022
0924d62
Improving external test file
alexanderivrii Jul 3, 2022
8314da9
Simplified commutative inverse cancellation pass
alexanderivrii Jul 3, 2022
afd5d52
black
alexanderivrii Jul 3, 2022
bd19b7c
Removing duplicate code from CommutationAnalysis and DagDependency
alexanderivrii Jul 3, 2022
342cbbf
black and lint
alexanderivrii Jul 3, 2022
1c02bf4
commutation_checker cleanup
alexanderivrii Jul 5, 2022
b99ec05
Adding tests for CommutativeInverseCancellation pass
alexanderivrii Jul 5, 2022
da28f75
Removing external test python file
alexanderivrii Jul 5, 2022
00c4489
Merge branch 'main' into commutative_cancellation
alexanderivrii Jul 5, 2022
0490018
Merge branch 'main' into commutative_cancellation
alexanderivrii Jul 28, 2022
58c77dd
Update qiskit/transpiler/passes/optimization/commutative_inverse_canc…
alexanderivrii Jul 31, 2022
537fc23
Update qiskit/transpiler/passes/optimization/commutative_inverse_canc…
alexanderivrii Jul 31, 2022
a1284b9
Removing the use of dag node classes and taking args and op separately
alexanderivrii Aug 2, 2022
f40148e
Removing runtime import
alexanderivrii Aug 3, 2022
6c6417e
removing unnecessary pylint-disable
alexanderivrii Aug 3, 2022
1eeb2db
moving commutation_checker to qiskit.circuit and improving imports
alexanderivrii Aug 3, 2022
cccc9b5
Merge branch 'main' into commutative_cancellation
alexanderivrii Aug 3, 2022
c606d2a
Adding commutative_checker tests
alexanderivrii Aug 3, 2022
d654625
running black
alexanderivrii Aug 3, 2022
7342095
linting
alexanderivrii Aug 3, 2022
a996a75
Adding corner-cases to test_commutative_inverse_cancellation
alexanderivrii Aug 4, 2022
c5c4951
release notes
alexanderivrii Aug 4, 2022
3366d07
black
alexanderivrii Aug 4, 2022
3a959a6
release notes tweaks
alexanderivrii Aug 4, 2022
7699e9d
Merge branch 'main' into commutative_cancellation
alexanderivrii Aug 4, 2022
85367a0
Merge branch 'main' into commutative_cancellation
mergify[bot] Aug 10, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 8 additions & 73 deletions qiskit/dagcircuit/dagdependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,12 @@
from collections import OrderedDict, defaultdict
import warnings

import numpy as np
import retworkx as rx

from qiskit.circuit.quantumregister import QuantumRegister, Qubit
from qiskit.circuit.classicalregister import ClassicalRegister, Clbit
from qiskit.dagcircuit.exceptions import DAGDependencyError
from qiskit.dagcircuit.dagdepnode import DAGDepNode
from qiskit.quantum_info.operators import Operator
from qiskit.exceptions import MissingOptionalLibraryError


Expand Down Expand Up @@ -94,6 +92,11 @@ def __init__(self):
self.duration = None
self.unit = "dt"

# pylint: disable=cyclic-import
from qiskit.transpiler.passes.optimization.commutation_checker import CommutationChecker

self.comm_checker = CommutationChecker()

@property
def global_phase(self):
"""Return the global phase of the circuit."""
Expand Down Expand Up @@ -487,7 +490,9 @@ def _update_edges(self):
self._multi_graph.get_node_data(current_node_id).reachable = True
# Check the commutation relation with reachable node, it adds edges if it does not commute
for prev_node_id in range(max_node_id - 1, -1, -1):
if self._multi_graph.get_node_data(prev_node_id).reachable and not _does_commute(
if self._multi_graph.get_node_data(
prev_node_id
).reachable and not self.comm_checker.commute(
self._multi_graph.get_node_data(prev_node_id), max_node
):
self._multi_graph.add_edge(prev_node_id, max_node_id, {"commute": False})
Expand Down Expand Up @@ -565,73 +570,3 @@ def merge_no_duplicates(*iterables):
if val != last:
last = val
yield val


def _does_commute(node1, node2):
"""Function to verify commutation relation between two nodes in the DAG.

Args:
node1 (DAGnode): first node operation
node2 (DAGnode): second node operation

Return:
bool: True if the nodes commute and false if it is not the case.
"""

# Create set of qubits on which the operation acts
qarg1 = [node1.qargs[i] for i in range(0, len(node1.qargs))]
qarg2 = [node2.qargs[i] for i in range(0, len(node2.qargs))]

# Create set of cbits on which the operation acts
carg1 = [node1.cargs[i] for i in range(0, len(node1.cargs))]
carg2 = [node2.cargs[i] for i in range(0, len(node2.cargs))]

# Commutation for classical conditional gates
# if and only if the qubits are different.
# TODO: qubits can be the same if conditions are identical and
# the non-conditional gates commute.
if node1.type == "op" and node2.type == "op":
if node1.op.condition or node2.op.condition:
intersection = set(qarg1).intersection(set(qarg2))
return not intersection

# Commutation for non-unitary or parameterized or opaque ops
# (e.g. measure, reset, directives or pulse gates)
# if and only if the qubits and clbits are different.
non_unitaries = ["measure", "reset", "initialize", "delay"]

def _unknown_commutator(n):
return n.op._directive or n.name in non_unitaries or n.op.is_parameterized()

if _unknown_commutator(node1) or _unknown_commutator(node2):
intersection_q = set(qarg1).intersection(set(qarg2))
intersection_c = set(carg1).intersection(set(carg2))
return not (intersection_q or intersection_c)

# Gates over disjoint sets of qubits commute
if not set(qarg1).intersection(set(qarg2)):
return True

# Known non-commuting gates (TODO: add more).
non_commute_gates = [{"x", "y"}, {"x", "z"}]
if qarg1 == qarg2 and ({node1.name, node2.name} in non_commute_gates):
return False

# Create matrices to check commutation relation if no other criteria are matched
qarg = list(set(node1.qargs + node2.qargs))
qbit_num = len(qarg)

qarg1 = [qarg.index(q) for q in node1.qargs]
qarg2 = [qarg.index(q) for q in node2.qargs]

dim = 2**qbit_num
id_op = np.reshape(np.eye(dim), (2, 2) * qbit_num)

op1 = np.reshape(node1.op.to_matrix(), (2, 2) * len(qarg1))
op2 = np.reshape(node2.op.to_matrix(), (2, 2) * len(qarg2))

op = Operator._einsum_matmul(id_op, op1, qarg1)
op12 = Operator._einsum_matmul(op, op2, qarg2, right_mul=False)
op21 = Operator._einsum_matmul(op, op2, qarg2, shift=qbit_num, right_mul=True)

return np.allclose(op12, op21)
1 change: 1 addition & 0 deletions qiskit/dagcircuit/dagdepnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def qargs(self, new_qargs):
self.sort_key = str(new_qargs)

@staticmethod
# pylint: disable=arguments-differ
mtreinish marked this conversation as resolved.
Show resolved Hide resolved
def semantic_eq(node1, node2):
"""
Check if DAG nodes are considered equivalent, e.g., as a node_match for nx.is_isomorphic.
Expand Down
2 changes: 2 additions & 0 deletions qiskit/transpiler/passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
InverseCancellation
CommutationAnalysis
CommutativeCancellation
CommutativeInverseCancellation
Optimize1qGatesSimpleCommutation
RemoveDiagonalGatesBeforeMeasure
RemoveResetInZeroState
Expand Down Expand Up @@ -207,6 +208,7 @@
from .optimization import ConsolidateBlocks
from .optimization import CommutationAnalysis
from .optimization import CommutativeCancellation
from .optimization import CommutativeInverseCancellation
from .optimization import CXCancellation
from .optimization import Optimize1qGatesSimpleCommutation
from .optimization import OptimizeSwapBeforeMeasure
Expand Down
1 change: 1 addition & 0 deletions qiskit/transpiler/passes/optimization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .consolidate_blocks import ConsolidateBlocks
from .commutation_analysis import CommutationAnalysis
from .commutative_cancellation import CommutativeCancellation
from .commutative_inverse_cancellation import CommutativeInverseCancellation
from .cx_cancellation import CXCancellation
from .optimize_1q_commutation import Optimize1qGatesSimpleCommutation
from .optimize_swap_before_measure import OptimizeSwapBeforeMeasure
Expand Down
97 changes: 4 additions & 93 deletions qiskit/transpiler/passes/optimization/commutation_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,8 @@
"""Analysis pass to find commutation relations between DAG nodes."""

from collections import defaultdict
import numpy as np
from qiskit.transpiler.exceptions import TranspilerError
from qiskit.transpiler.basepasses import AnalysisPass
from qiskit.quantum_info.operators import Operator
from qiskit.dagcircuit import DAGOpNode

_CUTOFF_PRECISION = 1e-10


class CommutationAnalysis(AnalysisPass):
Expand All @@ -35,7 +30,9 @@ class CommutationAnalysis(AnalysisPass):

def __init__(self):
super().__init__()
self.cache = {}
from qiskit.transpiler.passes.optimization.commutation_checker import CommutationChecker

self.comm_checker = CommutationChecker()

def run(self, dag):
"""Run the CommutationAnalysis pass on `dag`.
Expand Down Expand Up @@ -74,7 +71,7 @@ def run(self, dag):
prev_gate = current_comm_set[-1][-1]
does_commute = False
try:
does_commute = _commute(current_gate, prev_gate, self.cache)
does_commute = self.comm_checker.commute(current_gate, prev_gate)
except TranspilerError:
pass
if does_commute:
Expand All @@ -85,89 +82,3 @@ def run(self, dag):

temp_len = len(current_comm_set)
self.property_set["commutation_set"][(current_gate, wire)] = temp_len - 1


_COMMUTE_ID_OP = {}


def _hashable_parameters(params):
"""Convert the parameters of a gate into a hashable format for lookup in a dictionary.

This aims to be fast in common cases, and is not intended to work outside of the lifetime of a
single commutation pass; it does not handle mutable state correctly if the state is actually
changed."""
try:
hash(params)
return params
except TypeError:
pass
if isinstance(params, (list, tuple)):
return tuple(_hashable_parameters(x) for x in params)
if isinstance(params, np.ndarray):
# We trust that the arrays will not be mutated during the commutation pass, since nothing
# would work if they were anyway. Using the id can potentially cause some additional cache
# misses if two UnitaryGate instances are being compared that have been separately
# constructed to have the same underlying matrix, but in practice the cost of string-ifying
# the matrix to get a cache key is far more expensive than just doing a small matmul.
return (np.ndarray, id(params))
# Catch anything else with a slow conversion.
return ("fallback", str(params))


def _commute(node1, node2, cache):
if not isinstance(node1, DAGOpNode) or not isinstance(node2, DAGOpNode):
return False
for nd in [node1, node2]:
if nd.op._directive or nd.name in {"measure", "reset", "delay"}:
return False
if node1.op.condition or node2.op.condition:
return False
if node1.op.is_parameterized() or node2.op.is_parameterized():
return False

# Assign indices to each of the qubits such that all `node1`'s qubits come first, followed by
# any _additional_ qubits `node2` addresses. This helps later when we need to compose one
# operator with the other, since we can easily expand `node1` with a suitable identity.
qarg = {q: i for i, q in enumerate(node1.qargs)}
num_qubits = len(qarg)
for q in node2.qargs:
if q not in qarg:
qarg[q] = num_qubits
num_qubits += 1
qarg1 = tuple(qarg[q] for q in node1.qargs)
qarg2 = tuple(qarg[q] for q in node2.qargs)

node1_key = (node1.op.name, _hashable_parameters(node1.op.params), qarg1)
node2_key = (node2.op.name, _hashable_parameters(node2.op.params), qarg2)
try:
# We only need to try one orientation of the keys, since if we've seen the compound key
# before, we've set it in both orientations.
return cache[node1_key, node2_key]
except KeyError:
pass

operator_1 = Operator(node1.op, input_dims=(2,) * len(qarg1), output_dims=(2,) * len(qarg1))
operator_2 = Operator(node2.op, input_dims=(2,) * len(qarg2), output_dims=(2,) * len(qarg2))

if qarg1 == qarg2:
# Use full composition if possible to get the fastest matmul paths.
op12 = operator_1.compose(operator_2)
op21 = operator_2.compose(operator_1)
else:
# Expand operator_1 to be large enough to contain operator_2 as well; this relies on qargs1
# being the lowest possible indices so the identity can be tensored before it.
extra_qarg2 = num_qubits - len(qarg1)
if extra_qarg2:
try:
id_op = _COMMUTE_ID_OP[extra_qarg2]
except KeyError:
id_op = _COMMUTE_ID_OP[extra_qarg2] = Operator(
np.eye(2**extra_qarg2),
input_dims=(2,) * extra_qarg2,
output_dims=(2,) * extra_qarg2,
)
operator_1 = id_op.tensor(operator_1)
op12 = operator_1.compose(operator_2, qargs=qarg2, front=False)
op21 = operator_1.compose(operator_2, qargs=qarg2, front=True)
cache[node1_key, node2_key] = cache[node2_key, node1_key] = ret = op12 == op21
return ret
Loading