-
Notifications
You must be signed in to change notification settings - Fork 2.3k
/
dagnode.py
176 lines (151 loc) · 6.17 KB
/
dagnode.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
# This code is part of Qiskit.
#
# (C) Copyright IBM 2017, 2019, 2021.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""Objects to represent the information at a node in the DAGCircuit."""
from __future__ import annotations
import typing
import uuid
import qiskit._accelerate.circuit
from qiskit.circuit import (
Clbit,
ClassicalRegister,
IfElseOp,
WhileLoopOp,
SwitchCaseOp,
ForLoopOp,
Parameter,
QuantumCircuit,
)
from qiskit.circuit.classical import expr
if typing.TYPE_CHECKING:
from qiskit.dagcircuit import DAGCircuit
DAGNode = qiskit._accelerate.circuit.DAGNode
DAGOpNode = qiskit._accelerate.circuit.DAGOpNode
DAGInNode = qiskit._accelerate.circuit.DAGInNode
DAGOutNode = qiskit._accelerate.circuit.DAGOutNode
def _legacy_condition_eq(cond1, cond2, bit_indices1, bit_indices2) -> bool:
if cond1 is cond2 is None:
return True
elif None in (cond1, cond2):
return False
target1, val1 = cond1
target2, val2 = cond2
if val1 != val2:
return False
if isinstance(target1, Clbit) and isinstance(target2, Clbit):
return bit_indices1[target1] == bit_indices2[target2]
if isinstance(target1, ClassicalRegister) and isinstance(target2, ClassicalRegister):
return target1.size == target2.size and all(
bit_indices1[t1] == bit_indices2[t2] for t1, t2 in zip(target1, target2)
)
return False
def _circuit_to_dag(circuit: QuantumCircuit, node_qargs, node_cargs, bit_indices) -> DAGCircuit:
"""Get a :class:`.DAGCircuit` of the given :class:`.QuantumCircuit`. The bits in the output
will be ordered in a canonical order based on their indices in the outer DAG, as defined by the
``bit_indices`` mapping and the ``node_{q,c}args`` arguments."""
from qiskit.converters import circuit_to_dag # pylint: disable=cyclic-import
def sort_key(bits):
outer, _inner = bits
return bit_indices[outer]
return circuit_to_dag(
circuit,
copy_operations=False,
qubit_order=[
inner for _outer, inner in sorted(zip(node_qargs, circuit.qubits), key=sort_key)
],
clbit_order=[
inner for _outer, inner in sorted(zip(node_cargs, circuit.clbits), key=sort_key)
],
)
def _make_expr_key(bit_indices):
def key(var):
if isinstance(var, Clbit):
return bit_indices.get(var)
if isinstance(var, ClassicalRegister):
return [bit_indices.get(bit) for bit in var]
return None
return key
def _condition_op_eq(node1, node2, bit_indices1, bit_indices2):
cond1 = node1.op.condition
cond2 = node2.op.condition
if isinstance(cond1, expr.Expr) and isinstance(cond2, expr.Expr):
if not expr.structurally_equivalent(
cond1, cond2, _make_expr_key(bit_indices1), _make_expr_key(bit_indices2)
):
return False
elif isinstance(cond1, expr.Expr) or isinstance(cond2, expr.Expr):
return False
elif not _legacy_condition_eq(cond1, cond2, bit_indices1, bit_indices2):
return False
return len(node1.op.blocks) == len(node2.op.blocks) and all(
_circuit_to_dag(block1, node1.qargs, node1.cargs, bit_indices1)
== _circuit_to_dag(block2, node2.qargs, node2.cargs, bit_indices2)
for block1, block2 in zip(node1.op.blocks, node2.op.blocks)
)
def _switch_case_eq(node1, node2, bit_indices1, bit_indices2):
target1 = node1.op.target
target2 = node2.op.target
if isinstance(target1, expr.Expr) and isinstance(target2, expr.Expr):
if not expr.structurally_equivalent(
target1, target2, _make_expr_key(bit_indices1), _make_expr_key(bit_indices2)
):
return False
elif isinstance(target1, Clbit) and isinstance(target2, Clbit):
if bit_indices1[target1] != bit_indices2[target2]:
return False
elif isinstance(target1, ClassicalRegister) and isinstance(target2, ClassicalRegister):
if target1.size != target2.size or any(
bit_indices1[b1] != bit_indices2[b2] for b1, b2 in zip(target1, target2)
):
return False
else:
return False
cases1 = [case for case, _ in node1.op.cases_specifier()]
cases2 = [case for case, _ in node2.op.cases_specifier()]
return (
len(cases1) == len(cases2)
and all(set(labels1) == set(labels2) for labels1, labels2 in zip(cases1, cases2))
and len(node1.op.blocks) == len(node2.op.blocks)
and all(
_circuit_to_dag(block1, node1.qargs, node1.cargs, bit_indices1)
== _circuit_to_dag(block2, node2.qargs, node2.cargs, bit_indices2)
for block1, block2 in zip(node1.op.blocks, node2.op.blocks)
)
)
def _for_loop_eq(node1, node2, bit_indices1, bit_indices2):
indexset1, param1, body1 = node1.op.params
indexset2, param2, body2 = node2.op.params
if indexset1 != indexset2:
return False
if (param1 is None and param2 is not None) or (param1 is not None and param2 is None):
return False
if param1 is not None and param2 is not None:
sentinel = Parameter(str(uuid.uuid4()))
body1 = (
body1.assign_parameters({param1: sentinel}, inplace=False)
if param1 in body1.parameters
else body1
)
body2 = (
body2.assign_parameters({param2: sentinel}, inplace=False)
if param2 in body2.parameters
else body2
)
return _circuit_to_dag(body1, node1.qargs, node1.cargs, bit_indices1) == _circuit_to_dag(
body2, node2.qargs, node2.cargs, bit_indices2
)
_SEMANTIC_EQ_CONTROL_FLOW = {
IfElseOp: _condition_op_eq,
WhileLoopOp: _condition_op_eq,
SwitchCaseOp: _switch_case_eq,
ForLoopOp: _for_loop_eq,
}
_SEMANTIC_EQ_SYMMETRIC = frozenset({"barrier", "swap", "break_loop", "continue_loop"})